1#ifndef STAN_MATH_REV_META_OPERANDS_AND_PARTIALS_HPP
2#define STAN_MATH_REV_META_OPERANDS_AND_PARTIALS_HPP
29template <
typename T1,
typename T2,
30 require_all_kernel_expressions_and_none_scalar_t<T1, T2>* =
nullptr>
32 x.adj() += z.adj() * y;
35template <
typename T1,
typename T2,
38 x.adj() += z.adj() * y;
42template <
typename Scalar1,
typename Scalar2, require_var_t<Scalar1>* =
nullptr,
43 require_not_var_matrix_t<Scalar1>* =
nullptr,
44 require_arithmetic_t<Scalar2>* =
nullptr>
46 x.adj() += z.adj() * y;
49template <
typename Scalar1,
typename Scalar2, require_var_t<Scalar1>* =
nullptr,
50 require_not_var_matrix_t<Scalar1>* =
nullptr,
51 require_arithmetic_t<Scalar2>* =
nullptr>
53 x.adj() += z.adj() * y;
57template <
typename Matrix1,
typename Matrix2,
61 x.adj().array() += z.adj() * y.array();
64template <
typename Matrix1,
typename Matrix2,
68 x.adj().array() += z.adj() * y.array();
71template <
typename Arith,
typename Alt, require_st_arithmetic<Arith>* =
nullptr>
73 const vari& )
noexcept {}
75template <
typename Arith,
typename Alt, require_st_arithmetic<Arith>* =
nullptr>
77 const var& )
noexcept {}
80template <
typename StdVec1,
typename Vec2,
84 for (
size_t i = 0; i < x.size(); ++i) {
89template <
typename StdVec1,
typename Vec2,
93 for (
size_t i = 0; i < x.size(); ++i) {
108 : partial_(other.partial_),
110 operands_(other.operands_) {}
112 inline auto&
operand() noexcept {
return operands_; }
115 static constexpr int size() {
return 1; }
122 using Op = std::vector<var, arena_allocator<var>>;
128 partials_vec_(partials_),
129 operands_(op.begin(), op.end()) {}
133 inline int size() const noexcept {
return this->operands_.size(); }
134 inline auto&&
operand() noexcept {
return std::move(this->operands_); }
135 inline auto&
partial() noexcept {
return this->partials_; }
138template <
typename Op>
146 partials_vec_(partials_),
151 : partials_(other.partials_),
152 partials_vec_(partials_),
153 operands_(other.operands_) {}
155 inline auto&
partial() noexcept {
return partials_; }
156 inline auto&
operand() noexcept {
return operands_; }
158 inline auto size() const noexcept {
return this->operands_.size(); }
161template <
typename Op>
170 partials_vec_(partials_),
176 : partials_(other.partials_),
177 partials_vec_(partials_),
178 operands_(other.operands_) {}
180 inline auto&
partial() noexcept {
return partials_; }
181 inline auto&
operand() noexcept {
return operands_; }
184 static constexpr int size() {
return 0; }
189template <
int R,
int C>
193 using Op = std::vector<inner_op, arena_allocator<inner_op>>;
197 : partials_vec_(ops.
size()), operands_(ops.begin(), ops.end()) {
198 for (
size_t i = 0; i < ops.size(); ++i) {
199 partials_vec_[i] = partial_t::Zero(ops[i].
rows(), ops[i].
cols());
205 inline int size() const noexcept {
206 if (
unlikely(this->operands_.size() == 0)) {
209 return this->operands_.size() * this->operands_[0].size();
211 inline auto&&
operand() noexcept {
return std::move(this->operands_); }
212 inline auto&
partial() noexcept {
return this->partials_vec_; }
218 using inner_vec = std::vector<var, arena_allocator<var>>;
219 using Op = std::vector<inner_vec, arena_allocator<inner_vec>>;
220 using partial_t = std::vector<double, arena_allocator<double>>;
223 : partials_vec_(
stan::math::
size(ops)), operands_(ops.
size()) {
225 operands_[i] =
inner_vec(ops[i].begin(), ops[i].end());
231 inline int size() const noexcept {
232 return this->operands_.size() * this->operands_[0].size();
234 inline auto&&
operand() noexcept {
return std::move(this->operands_); }
235 inline auto&&
partial() noexcept {
return std::move(this->partials_vec_); }
238template <
typename Op>
245 : partials_vec_(ops.
size()), operands_(ops.begin(), ops.end()) {
246 for (
size_t i = 0; i < ops.size(); ++i) {
254 static constexpr int size() noexcept {
return 0; }
255 inline auto&&
operand() noexcept {
return std::move(this->operands_); }
256 inline auto&&
partial() noexcept {
return std::move(this->partials_vec_); }
297template <
typename Op1,
typename Op2,
typename Op3,
typename Op4,
typename Op5,
298 typename Op6,
typename Op7,
typename Op8>
312 : edge1_(o1), edge2_(o2) {}
314 : edge1_(o1), edge2_(o2), edge3_(o3) {}
317 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4) {}
319 const Op4& o4,
const Op5& o5)
320 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
322 const Op4& o4,
const Op5& o5,
const Op6& o6)
330 const Op4& o4,
const Op5& o5,
const Op6& o6,
340 const Op4& o4,
const Op5& o5,
const Op6& o6,
341 const Op7& o7,
const Op8& o8)
366 value, [operand1 = edge1_.operand(), partial1 = edge1_.partial(),
367 operand2 = edge2_.operand(), partial2 = edge2_.partial(),
368 operand3 = edge3_.operand(), partial3 = edge3_.partial(),
369 operand4 = edge4_.operand(), partial4 = edge4_.partial(),
370 operand5 = edge5_.operand(), partial5 = edge5_.partial(),
371 operand6 = edge6_.operand(), partial6 = edge6_.partial(),
372 operand7 = edge7_.operand(), partial7 = edge7_.partial(),
373 operand8 = edge8_.operand(),
374 partial8 = edge8_.partial()](
const auto& vi)
mutable {
375 if constexpr (is_autodiff_v<Op1>) {
376 internal::update_adjoints(operand1, partial1, vi);
378 if constexpr (is_autodiff_v<Op2>) {
381 if constexpr (is_autodiff_v<Op3>) {
384 if constexpr (is_autodiff_v<Op4>) {
387 if constexpr (is_autodiff_v<Op5>) {
390 if constexpr (is_autodiff_v<Op6>) {
393 if constexpr (is_autodiff_v<Op7>) {
396 if constexpr (is_autodiff_v<Op8>) {
ops_partials_edge(const ops_partials_edge< double, Op, require_eigen_st< is_var, Op > > &other)
auto & partial() noexcept
broadcast_array< partials_t > partials_vec_
auto & operand() noexcept
auto size() const noexcept
arena_t< promote_scalar_t< double, Op > > partials_t
ops_partials_edge(const Op &ops)
auto && operand() noexcept
int size() const noexcept
arena_t< Eigen::Matrix< var, R, C > > inner_op
ops_partials_edge(const std::vector< Eigen::Matrix< var, R, C > > &ops)
std::vector< inner_op, arena_allocator< inner_op > > Op
auto & partial() noexcept
std::vector< partial_t, arena_allocator< partial_t > > partials_vec_
arena_t< Eigen::Matrix< double, R, C > > partial_t
int size() const noexcept
auto && operand() noexcept
std::vector< inner_vec, arena_allocator< inner_vec > > Op
std::vector< partial_t, arena_allocator< partial_t > > partials_vec_
auto && partial() noexcept
std::vector< var, arena_allocator< var > > inner_vec
std::vector< double, arena_allocator< double > > partial_t
ops_partials_edge(const std::vector< std::vector< var > > &ops)
auto && operand() noexcept
std::vector< var, arena_allocator< var > > Op
auto & partial() noexcept
int size() const noexcept
broadcast_array< partials_t > partials_vec_
ops_partials_edge(const std::vector< var > &op)
arena_t< Eigen::VectorXd > partials_t
std::vector< var_value< Op >, arena_allocator< var_value< Op > > > operands_
static constexpr int size() noexcept
auto && partial() noexcept
ops_partials_edge(const std::vector< var_value< Op > > &ops)
auto && operand() noexcept
std::vector< arena_t< Op >, arena_allocator< arena_t< Op > > > partials_t
ops_partials_edge(const ops_partials_edge< double, var > &other)
static constexpr int size()
auto & operand() noexcept
ops_partials_edge(const var &op) noexcept
ops_partials_edge(const var_value< Op > &ops)
ops_partials_edge(const ops_partials_edge< double, var_value< Op >, require_eigen_t< Op > > &other)
auto & partial() noexcept
auto & operand() noexcept
broadcast_array< partials_t > partials_vec_
var_value< Op > operands_
static constexpr int size()
An edge holds both the operands and its associated partial derivatives.
internal::ops_partials_edge< double, std::decay_t< Op5 > > edge5_
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5)
operands_and_partials(const Op1 &o1, const Op2 &o2)
internal::ops_partials_edge< double, std::decay_t< Op6 > > edge6_
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7)
internal::ops_partials_edge< double, std::decay_t< Op4 > > edge4_
internal::ops_partials_edge< double, std::decay_t< Op8 > > edge8_
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6)
internal::ops_partials_edge< double, std::decay_t< Op2 > > edge2_
internal::ops_partials_edge< double, std::decay_t< Op7 > > edge7_
operands_and_partials(const Op1 &o1)
internal::ops_partials_edge< double, std::decay_t< Op1 > > edge1_
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7, const Op8 &o8)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3)
internal::ops_partials_edge< double, std::decay_t< Op3 > > edge3_
This template builds partial derivatives with respect to a set of operands.
require_t< std::is_arithmetic< scalar_type_t< std::decay_t< T > > > > require_st_arithmetic
Require scalar type satisfies std::is_arithmetic.
require_t< container_type_check_base< is_eigen, scalar_type_t, TypeCheck, Check... > > require_eigen_st
Require type satisfies is_eigen.
require_t< is_eigen< std::decay_t< T > > > require_eigen_t
Require type satisfies is_eigen.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Require type satisfies is_rev_matrix.
require_t< is_std_vector< std::decay_t< T > > > require_std_vector_t
Require type satisfies is_std_vector.
var build(double value)
Build the node to be stored on the autodiff graph.
int64_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
void update_adjoints(var_value< T1 > &x, const T2 &y, const vari &z)
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
typename plain_type< std::decay_t< T > >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Defines a static member named value which is defined to be false as the primitive scalar types cannot...
std library compatible allocator that uses AD stack.