Automatic Differentiation
 
Loading...
Searching...
No Matches
operands_and_partials.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_META_OPERANDS_AND_PARTIALS_HPP
2#define STAN_MATH_REV_META_OPERANDS_AND_PARTIALS_HPP
3
19#include <vector>
20#include <tuple>
21
22namespace stan {
23namespace math {
24
25namespace internal {
26
27// OpenCL
/**
 * Adds `z.adj() * y` to the adjoint of `x`.
 *
 * Enabled only when
 * `require_all_kernel_expressions_and_none_scalar_t<T1, T2>` is satisfied.
 *
 * @tparam T1 type wrapped by the `var_value` operand
 * @tparam T2 type of the partials
 * @param x operand whose adjoint is incremented in place
 * @param y partials that scale the adjoint read from `z`
 * @param z node whose adjoint is propagated
 */
28template <typename T1, typename T2,
29 require_all_kernel_expressions_and_none_scalar_t<T1, T2>* = nullptr>
30inline void update_adjoints(var_value<T1>& x, const T2& y, const vari& z) {
31 x.adj() += z.adj() * y;
32}
33
/**
 * Adds `z.adj() * y` to the adjoint of `x`, reading the adjoint from a `var`
 * rather than a `vari`.
 *
 * NOTE(review): the second line of the template parameter list is not shown
 * in this listing; presumably it carries the same requirement as the `vari`
 * overload - confirm against the full header.
 *
 * @param x operand whose adjoint is incremented in place
 * @param y partials that scale the adjoint read from `z`
 * @param z variable whose adjoint is propagated
 */
34template <typename T1, typename T2,
36inline void update_adjoints(var_value<T1>& x, const T2& y, const var& z) {
37 x.adj() += z.adj() * y;
38}
39
40// Scalars
/**
 * Scalar overload: adds `z.adj() * y` to the adjoint of `x`.
 *
 * Enabled when `Scalar1` satisfies `require_var_t` and
 * `require_not_var_matrix_t`, and `Scalar2` satisfies `require_arithmetic_t`.
 *
 * `x` is received by value and updated through `x.adj()`; presumably that
 * accessor refers to storage shared with the caller's variable - confirm.
 *
 * @param x operand whose adjoint is incremented
 * @param y partial that scales the adjoint read from `z`
 * @param z node whose adjoint is propagated
 */
41template <typename Scalar1, typename Scalar2, require_var_t<Scalar1>* = nullptr,
42 require_not_var_matrix_t<Scalar1>* = nullptr,
43 require_arithmetic_t<Scalar2>* = nullptr>
44inline void update_adjoints(Scalar1 x, Scalar2 y, const vari& z) noexcept {
45 x.adj() += z.adj() * y;
46}
47
/**
 * Scalar overload taking the propagated adjoint from a `var`: adds
 * `z.adj() * y` to the adjoint of `x`.
 *
 * Enabled when `Scalar1` satisfies `require_var_t` and
 * `require_not_var_matrix_t`, and `Scalar2` satisfies `require_arithmetic_t`.
 *
 * @param x operand whose adjoint is incremented (received by value)
 * @param y partial that scales the adjoint read from `z`
 * @param z variable whose adjoint is propagated
 */
48template <typename Scalar1, typename Scalar2, require_var_t<Scalar1>* = nullptr,
49 require_not_var_matrix_t<Scalar1>* = nullptr,
50 require_arithmetic_t<Scalar2>* = nullptr>
51inline void update_adjoints(Scalar1 x, Scalar2 y, const var& z) noexcept {
52 x.adj() += z.adj() * y;
53}
54
55// Matrix
/**
 * Matrix overload: adds `z.adj() * y` to the adjoints of `x`, with both sides
 * viewed through `.array()`.
 *
 * NOTE(review): the constraint lines of the template header are not shown in
 * this listing, so the exact enabling conditions cannot be stated here.
 *
 * @param x operand whose adjoints are incremented in place
 * @param y partials, combined with `z.adj()` through `y.array()`
 * @param z node whose adjoint is propagated
 */
56template <typename Matrix1, typename Matrix2,
59inline void update_adjoints(Matrix1& x, const Matrix2& y, const vari& z) {
60 x.adj().array() += z.adj() * y.array();
61}
62
/**
 * Matrix overload taking the propagated adjoint from a `var`: adds
 * `z.adj() * y` to the adjoints of `x`, with both sides viewed through
 * `.array()`.
 *
 * NOTE(review): the constraint lines of the template header are not shown in
 * this listing, so the exact enabling conditions cannot be stated here.
 *
 * @param x operand whose adjoints are incremented in place
 * @param y partials, combined with `z.adj()` through `y.array()`
 * @param z variable whose adjoint is propagated
 */
63template <typename Matrix1, typename Matrix2,
66inline void update_adjoints(Matrix1& x, const Matrix2& y, const var& z) {
67 x.adj().array() += z.adj() * y.array();
68}
69
/**
 * No-op overload selected when `Arith` satisfies `require_st_arithmetic`.
 *
 * All three arguments are ignored and the body is empty, so calling
 * `update_adjoints` with such an operand changes nothing.
 */
70template <typename Arith, typename Alt, require_st_arithmetic<Arith>* = nullptr>
71inline constexpr void update_adjoints(Arith&& /* x */, Alt&& /* y */,
72 const vari& /* z */) noexcept {}
73
/**
 * No-op overload for a `var` third argument, selected when `Arith` satisfies
 * `require_st_arithmetic`.
 *
 * All three arguments are ignored and the body is empty.
 */
74template <typename Arith, typename Alt, require_st_arithmetic<Arith>* = nullptr>
75inline constexpr void update_adjoints(Arith&& /* x */, Alt&& /* y */,
76 const var& /* z */) noexcept {}
77
78// Vectors
/**
 * Container overload: forwards every pair `x[i]`, `y[i]` to
 * `update_adjoints` together with `z`.
 *
 * The loop bound is `x.size()` and `y` is indexed with the same counter, so
 * `y` must provide at least `x.size()` elements.
 *
 * NOTE(review): the constraint lines of the template header are not shown in
 * this listing, so the exact enabling conditions cannot be stated here.
 *
 * @param x operands whose adjoints are updated element by element
 * @param y partials, one per element of `x`
 * @param z node whose adjoint is propagated
 */
79template <typename StdVec1, typename Vec2,
82inline void update_adjoints(StdVec1& x, const Vec2& y, const vari& z) {
83 for (size_t i = 0; i < x.size(); ++i) {
84 update_adjoints(x[i], y[i], z);
85 }
86}
87
/**
 * Container overload for a `var` third argument: forwards every pair `x[i]`,
 * `y[i]` to `update_adjoints` together with `z`.
 *
 * The loop bound is `x.size()` and `y` is indexed with the same counter, so
 * `y` must provide at least `x.size()` elements.
 *
 * NOTE(review): the constraint lines of the template header are not shown in
 * this listing, so the exact enabling conditions cannot be stated here.
 *
 * @param x operands whose adjoints are updated element by element
 * @param y partials, one per element of `x`
 * @param z variable whose adjoint is propagated
 */
88template <typename StdVec1, typename Vec2,
91inline void update_adjoints(StdVec1& x, const Vec2& y, const var& z) {
92 for (size_t i = 0; i < x.size(); ++i) {
93 update_adjoints(x[i], y[i], z);
94 }
95}
96
/**
 * Edge specialization for a single `var` operand.
 *
 * Holds one `double` partial, initialized to zero, together with a
 * `broadcast_array<double>` constructed from that same member.
 * `size()` always reports 1.
 *
 * NOTE(review): the declaration of `operands_` and the signature line of the
 * copy constructor are not shown in this listing.
 */
100template <>
101class ops_partials_edge<double, var> {
102 public:
103 double partial_{0};
104 broadcast_array<double> partials_{partial_};
105 explicit ops_partials_edge(const var& op) noexcept : operands_(op) {}
// Copy-constructor initializers: the partial value and operand are taken from
// `other`, while `partials_` is rebuilt from this object's own `partial_`
// instead of being copied from `other.partials_`.
107 : partial_(other.partial_),
108 partials_(partial_),
109 operands_(other.operands_) {}
110 inline auto& partial() { return partial_; }
111 inline auto& operand() noexcept { return operands_; }
112
114 static constexpr int size() { return 1; }
115};
116// Vectorized Univariate
117// (std::vector<var> operands)
/**
 * Edge specialization for a `std::vector<var>` operand.
 *
 * The constructor copies the operands into an arena-allocated vector (`Op`)
 * and zero-initializes one partial per operand.
 *
 * NOTE(review): the `partials_t` alias and the declarations of
 * `partials_vec_` and `operands_` are not shown in this listing.
 */
118template <>
119class ops_partials_edge<double, std::vector<var>> {
120 public:
121 using Op = std::vector<var, arena_allocator<var>>;
123 partials_t partials_; // For univariate use-cases
125 explicit ops_partials_edge(const std::vector<var>& op)
126 : partials_(partials_t::Zero(op.size())),
127 partials_vec_(partials_),
128 operands_(op.begin(), op.end()) {}
129
131
// Number of stored operands.
132 inline int size() const noexcept { return this->operands_.size(); }
// Returns an rvalue reference to `operands_`: an object initialized from this
// call is move-constructed from the edge's storage.
133 inline auto&& operand() noexcept { return std::move(this->operands_); }
134 inline auto& partial() noexcept { return this->partials_; }
135};
136
/**
 * Edge specialization whose own type is spelled
 * `ops_partials_edge<double, Op, require_eigen_st<is_var, Op>>` in the copy
 * constructor below.
 *
 * The constructor zero-initializes partials with the same `rows()` and
 * `cols()` as the operand and stores `to_arena(ops)` as the operand.
 *
 * NOTE(review): the class header, the `partials_t` alias and the member
 * declarations for `partials_vec_` and `operands_` are not shown in this
 * listing.
 */
137template <typename Op>
139 public:
141 partials_t partials_; // For univariate use-cases
143 explicit ops_partials_edge(const Op& ops)
144 : partials_(partials_t::Zero(ops.rows(), ops.cols())),
145 partials_vec_(partials_),
146 operands_(to_arena(ops)) {}
147
// Copy constructor (first signature line not shown): `partials_vec_` is
// rebuilt from this object's own `partials_` rather than copied from `other`.
149 const ops_partials_edge<double, Op, require_eigen_st<is_var, Op>>& other)
150 : partials_(other.partials_),
151 partials_vec_(partials_),
152 operands_(other.operands_) {}
153
154 inline auto& partial() noexcept { return partials_; }
155 inline auto& operand() noexcept { return operands_; }
// Number of elements reported by the stored operand.
157 inline auto size() const noexcept { return this->operands_.size(); }
158};
159
/**
 * Edge specialization constructed from a `var_value<Op>` operand.
 *
 * The constructor zero-initializes partials sized from `ops.vi_->rows()` and
 * `ops.vi_->cols()` and stores the operand itself. `size()` is a compile-time
 * constant 0 regardless of the operand's dimensions.
 *
 * NOTE(review): the class header, the `partials_t` alias, the member
 * declarations and the start of the copy-constructor signature are not shown
 * in this listing.
 */
160template <typename Op>
162 public:
164 partials_t partials_; // For univariate use-cases
166 explicit ops_partials_edge(const var_value<Op>& ops)
167 : partials_(
168 plain_type_t<partials_t>::Zero(ops.vi_->rows(), ops.vi_->cols())),
169 partials_vec_(partials_),
170 operands_(ops) {}
171
// Copy constructor (signature mostly not shown): `partials_vec_` is rebuilt
// from this object's own `partials_` rather than copied from `other`.
174 other)
175 : partials_(other.partials_),
176 partials_vec_(partials_),
177 operands_(other.operands_) {}
178
179 inline auto& partial() noexcept { return partials_; }
180 inline auto& operand() noexcept { return operands_; }
181
183 static constexpr int size() { return 0; }
184};
185
186// SPECIALIZATIONS FOR MULTIVARIATE VECTORIZATIONS
187// (i.e. nested containers)
/**
 * Edge specialization for a `std::vector` of `Eigen::Matrix<var, R, C>`
 * operands.
 *
 * The constructor copies the operands into an arena-allocated vector and
 * creates one zero matrix of partials per operand, with that operand's
 * `rows()` and `cols()`.
 *
 * NOTE(review): the `inner_op` and `partial_t` aliases and the declaration of
 * `operands_` are not shown in this listing.
 */
188template <int R, int C>
189class ops_partials_edge<double, std::vector<Eigen::Matrix<var, R, C>>> {
190 public:
192 using Op = std::vector<inner_op, arena_allocator<inner_op>>;
194 std::vector<partial_t, arena_allocator<partial_t>> partials_vec_;
195 explicit ops_partials_edge(const std::vector<Eigen::Matrix<var, R, C>>& ops)
196 : partials_vec_(ops.size()), operands_(ops.begin(), ops.end()) {
197 for (size_t i = 0; i < ops.size(); ++i) {
198 partials_vec_[i] = partial_t::Zero(ops[i].rows(), ops[i].cols());
199 }
200 }
201
203
// Returns 0 when there are no operands; otherwise the operand count times the
// size of the FIRST operand (later operands are not inspected, so this
// presumably assumes they all share one size - confirm).
204 inline int size() const noexcept {
205 if (unlikely(this->operands_.size() == 0)) {
206 return 0;
207 }
208 return this->operands_.size() * this->operands_[0].size();
209 }
// Returns an rvalue reference to `operands_`: an object initialized from this
// call is move-constructed from the edge's storage.
210 inline auto&& operand() noexcept { return std::move(this->operands_); }
211 inline auto& partial() noexcept { return this->partials_vec_; }
212};
213
/**
 * Edge specialization for a `std::vector<std::vector<var>>` operand.
 *
 * The constructor copies every inner vector of operands into arena-allocated
 * storage and creates a matching inner vector of partials filled with 0.0.
 *
 * NOTE(review): the declaration of `operands_` is not shown in this listing.
 */
214template <>
215class ops_partials_edge<double, std::vector<std::vector<var>>> {
216 public:
217 using inner_vec = std::vector<var, arena_allocator<var>>;
218 using Op = std::vector<inner_vec, arena_allocator<inner_vec>>;
219 using partial_t = std::vector<double, arena_allocator<double>>;
220 std::vector<partial_t, arena_allocator<partial_t>> partials_vec_;
221 explicit ops_partials_edge(const std::vector<std::vector<var>>& ops)
222 : partials_vec_(stan::math::size(ops)), operands_(ops.size()) {
223 for (size_t i = 0; i < stan::math::size(ops); ++i) {
224 operands_[i] = inner_vec(ops[i].begin(), ops[i].end());
225 partials_vec_[i] = partial_t(stan::math::size(ops[i]), 0.0);
226 }
227 }
228
// Returns 0 when there are no operands; otherwise the number of inner vectors
// times the size of the FIRST inner vector (later ones are not inspected).
230 inline int size() const noexcept {
 // An edge built from an empty outer vector has no first element to measure.
 // Report zero instead of indexing operands_[0] out of range, matching the
 // std::vector<Eigen::Matrix<var, R, C>> specialization.
 if (this->operands_.size() == 0) {
 return 0;
 }
231 return this->operands_.size() * this->operands_[0].size();
232 }
// Both accessors return rvalue references: an object initialized from either
// call is move-constructed from the edge's storage.
233 inline auto&& operand() noexcept { return std::move(this->operands_); }
234 inline auto&& partial() noexcept { return std::move(this->partials_vec_); }
235};
236
/**
 * Edge specialization for a `std::vector` of `var_value<Op>` operands,
 * enabled through `require_eigen_t<Op>`.
 *
 * The constructor copies the operands into an arena-allocated vector and
 * creates one zero-valued partial per operand, sized from that operand's
 * `vi_->rows()` and `vi_->cols()`. `size()` is a compile-time constant 0
 * regardless of how many operands are stored.
 *
 * NOTE(review): the declaration of `partials_vec_` is not shown in this
 * listing.
 */
237template <typename Op>
238class ops_partials_edge<double, std::vector<var_value<Op>>,
239 require_eigen_t<Op>> {
240 public:
241 using partials_t = std::vector<arena_t<Op>, arena_allocator<arena_t<Op>>>;
243 explicit ops_partials_edge(const std::vector<var_value<Op>>& ops)
244 : partials_vec_(ops.size()), operands_(ops.begin(), ops.end()) {
245 for (size_t i = 0; i < ops.size(); ++i) {
246 partials_vec_[i]
247 = plain_type_t<Op>::Zero(ops[i].vi_->rows(), ops[i].vi_->cols());
248 }
249 }
250
251 std::vector<var_value<Op>, arena_allocator<var_value<Op>>> operands_;
252
253 static constexpr int size() noexcept { return 0; }
// Both accessors return rvalue references: an object initialized from either
// call is move-constructed from the edge's storage.
254 inline auto&& operand() noexcept { return std::move(this->operands_); }
255 inline auto&& partial() noexcept { return std::move(this->partials_vec_); }
256};
257} // namespace internal
258
/**
 * Specialization of `operands_and_partials` whose last template argument is
 * `var`.
 *
 * Each constructor accepts between one and eight operands and passes operand
 * N to the constructor of member `edgeN_`; edges with no matching argument
 * are left out of the initializer list. `build()` then produces the
 * resulting `var`.
 *
 * NOTE(review): the declarations of `edge1_` ... `edge8_` are not shown in
 * this listing.
 */
296template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
297 typename Op6, typename Op7, typename Op8>
298class operands_and_partials<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8, var> {
299 public:
308
309 explicit operands_and_partials(const Op1& o1) : edge1_(o1) {}
310 operands_and_partials(const Op1& o1, const Op2& o2)
311 : edge1_(o1), edge2_(o2) {}
312 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3)
313 : edge1_(o1), edge2_(o2), edge3_(o3) {}
314 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
315 const Op4& o4)
316 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4) {}
317 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
318 const Op4& o4, const Op5& o5)
319 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
320 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
321 const Op4& o4, const Op5& o5, const Op6& o6)
322 : edge1_(o1),
323 edge2_(o2),
324 edge3_(o3),
325 edge4_(o4),
326 edge5_(o5),
327 edge6_(o6) {}
328 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
329 const Op4& o4, const Op5& o5, const Op6& o6,
330 const Op7& o7)
331 : edge1_(o1),
332 edge2_(o2),
333 edge3_(o3),
334 edge4_(o4),
335 edge5_(o5),
336 edge6_(o6),
337 edge7_(o7) {}
338 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
339 const Op4& o4, const Op5& o5, const Op6& o6,
340 const Op7& o7, const Op8& o8)
341 : edge1_(o1),
342 edge2_(o2),
343 edge3_(o3),
344 edge4_(o4),
345 edge5_(o5),
346 edge6_(o6),
347 edge7_(o7),
348 edge8_(o8) {}
349
/**
 * Creates the result `var` through `make_callback_var`, pairing `value` with
 * a callback that applies `internal::update_adjoints` to each operand and
 * its partial.
 *
 * The callback's init-captures take the result of every edge's `operand()`
 * and `partial()` accessor. Where an accessor returns an rvalue reference,
 * the capture is move-constructed from the edge's storage, so the edge
 * should not be relied on after this call.
 *
 * @param value value of the resulting variable
 * @return the `var` returned by `make_callback_var`
 */
363 var build(double value) {
364 return make_callback_var(
365 value, [operand1 = edge1_.operand(), partial1 = edge1_.partial(),
366 operand2 = edge2_.operand(), partial2 = edge2_.partial(),
367 operand3 = edge3_.operand(), partial3 = edge3_.partial(),
368 operand4 = edge4_.operand(), partial4 = edge4_.partial(),
369 operand5 = edge5_.operand(), partial5 = edge5_.partial(),
370 operand6 = edge6_.operand(), partial6 = edge6_.partial(),
371 operand7 = edge7_.operand(), partial7 = edge7_.partial(),
372 operand8 = edge8_.operand(),
373 partial8 = edge8_.partial()](const auto& vi) mutable {
// Operand 1 is skipped when its type is constant. NOTE(review): the guard
// lines for operands 2-8 are not shown in this listing; presumably they
// mirror this check - confirm against the full header.
374 if (!is_constant<Op1>::value) {
375 internal::update_adjoints(operand1, partial1, vi);
376 }
378 internal::update_adjoints(operand2, partial2, vi);
379 }
381 internal::update_adjoints(operand3, partial3, vi);
382 }
384 internal::update_adjoints(operand4, partial4, vi);
385 }
387 internal::update_adjoints(operand5, partial5, vi);
388 }
390 internal::update_adjoints(operand6, partial6, vi);
391 }
393 internal::update_adjoints(operand7, partial7, vi);
394 }
396 internal::update_adjoints(operand8, partial8, vi);
397 }
398 });
399 }
400};
401
402} // namespace math
403} // namespace stan
404#endif
ops_partials_edge(const ops_partials_edge< double, Op, require_eigen_st< is_var, Op > > &other)
ops_partials_edge(const ops_partials_edge< double, var > &other)
ops_partials_edge(const ops_partials_edge< double, var_value< Op >, require_eigen_t< Op > > &other)
An edge holds both the operands and its associated partial derivatives.
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7, const Op8 &o8)
internal::ops_partials_edge< double, std::decay_t< Op2 > > edge2_
internal::ops_partials_edge< double, std::decay_t< Op7 > > edge7_
internal::ops_partials_edge< double, std::decay_t< Op4 > > edge4_
internal::ops_partials_edge< double, std::decay_t< Op6 > > edge6_
internal::ops_partials_edge< double, std::decay_t< Op1 > > edge1_
internal::ops_partials_edge< double, std::decay_t< Op8 > > edge8_
internal::ops_partials_edge< double, std::decay_t< Op5 > > edge5_
internal::ops_partials_edge< double, std::decay_t< Op3 > > edge3_
This template builds partial derivatives with respect to a set of operands.
#define unlikely(x)
require_t< std::is_arithmetic< scalar_type_t< std::decay_t< T > > > > require_st_arithmetic
Require scalar type satisfies std::is_arithmetic.
require_t< container_type_check_base< is_eigen, scalar_type_t, TypeCheck, Check... > > require_eigen_st
Require type satisfies is_eigen and its scalar type satisfies TypeCheck.
Definition is_eigen.hpp:151
require_t< is_eigen< std::decay_t< T > > > require_eigen_t
Require type satisfies is_eigen.
Definition is_eigen.hpp:55
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
int rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:21
int cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:20
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Require type satisfies is_rev_matrix.
require_t< is_std_vector< std::decay_t< T > > > require_std_vector_t
Require type satisfies is_std_vector.
size_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
var build(double value)
Build the node to be stored on the autodiff graph.
void update_adjoints(var_value< T1 > &x, const T2 &y, const vari &z)
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
Definition to_arena.hpp:25
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
Defines a static member named value which is defined to be false as the primitive scalar types cannot...
Definition is_var.hpp:14
std library compatible allocator that uses AD stack.