Automatic Differentiation
 
Loading...
Searching...
No Matches
operands_and_partials.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_META_OPERANDS_AND_PARTIALS_HPP
2#define STAN_MATH_REV_META_OPERANDS_AND_PARTIALS_HPP
3
20#include <vector>
21#include <tuple>
22
23namespace stan {
24namespace math {
25
26namespace internal {
27
28// OpenCL
29template <typename T1, typename T2,
30 require_all_kernel_expressions_and_none_scalar_t<T1, T2>* = nullptr>
31inline void update_adjoints(var_value<T1>& x, const T2& y, const vari& z) {
32 x.adj() += z.adj() * y;
33}
34
// OpenCL overload for a result given as `const var&`: adds z.adj() * y into
// the adjoint of x.
// NOTE(review): this template's enable_if constraint line is missing from this
// listing; presumably it mirrors the `const vari&` overload (both operands
// non-scalar kernel generator expressions) — confirm.
35template <typename T1, typename T2,
37inline void update_adjoints(var_value<T1>& x, const T2& y, const var& z) {
38 x.adj() += z.adj() * y;
39}
40
41// Scalars
42template <typename Scalar1, typename Scalar2, require_var_t<Scalar1>* = nullptr,
43 require_not_var_matrix_t<Scalar1>* = nullptr,
44 require_arithmetic_t<Scalar2>* = nullptr>
45inline void update_adjoints(Scalar1 x, Scalar2 y, const vari& z) noexcept {
46 x.adj() += z.adj() * y;
47}
48
49template <typename Scalar1, typename Scalar2, require_var_t<Scalar1>* = nullptr,
50 require_not_var_matrix_t<Scalar1>* = nullptr,
51 require_arithmetic_t<Scalar2>* = nullptr>
52inline void update_adjoints(Scalar1 x, Scalar2 y, const var& z) noexcept {
53 x.adj() += z.adj() * y;
54}
55
56// Matrix
// Matrix overload (result given as `const vari&`): adds z.adj() * y into the
// adjoint of x coefficient-wise through .array(). x.adj() and y are expected
// to have matching dimensions; nothing here checks that.
// NOTE(review): the template constraint lines are missing from this listing;
// presumably they restrict Matrix1/Matrix2 to reverse-mode / Eigen matrix
// types — confirm.
57template <typename Matrix1, typename Matrix2,
60inline void update_adjoints(Matrix1& x, const Matrix2& y, const vari& z) {
61 x.adj().array() += z.adj() * y.array();
62}
63
// Matrix overload (result given as `const var&`): adds z.adj() * y into the
// adjoint of x coefficient-wise through .array(). x.adj() and y are expected
// to have matching dimensions; nothing here checks that.
// NOTE(review): the template constraint lines are missing from this listing;
// presumably they match the `const vari&` matrix overload — confirm.
64template <typename Matrix1, typename Matrix2,
67inline void update_adjoints(Matrix1& x, const Matrix2& y, const var& z) {
68 x.adj().array() += z.adj() * y.array();
69}
70
71template <typename Arith, typename Alt, require_st_arithmetic<Arith>* = nullptr>
72inline constexpr void update_adjoints(Arith&& /* x */, Alt&& /* y */,
73 const vari& /* z */) noexcept {}
74
75template <typename Arith, typename Alt, require_st_arithmetic<Arith>* = nullptr>
76inline constexpr void update_adjoints(Arith&& /* x */, Alt&& /* y */,
77 const var& /* z */) noexcept {}
78
79// Vectors
// Container overload (result given as `const vari&`): recurses element-wise,
// pairing x[i] with y[i] and dispatching to whichever update_adjoints overload
// matches each element. y must provide at least x.size() elements; this is
// not checked here.
// NOTE(review): the template constraint lines are missing from this listing;
// the name StdVec1 suggests x is a std::vector — confirm.
80template <typename StdVec1, typename Vec2,
83inline void update_adjoints(StdVec1& x, const Vec2& y, const vari& z) {
84 for (size_t i = 0; i < x.size(); ++i) {
85 update_adjoints(x[i], y[i], z);
86 }
87}
88
// Container overload (result given as `const var&`): recurses element-wise,
// pairing x[i] with y[i] and dispatching to whichever update_adjoints overload
// matches each element. y must provide at least x.size() elements; this is
// not checked here.
// NOTE(review): the template constraint lines are missing from this listing;
// the name StdVec1 suggests x is a std::vector — confirm.
89template <typename StdVec1, typename Vec2,
92inline void update_adjoints(StdVec1& x, const Vec2& y, const var& z) {
93 for (size_t i = 0; i < x.size(); ++i) {
94 update_adjoints(x[i], y[i], z);
95 }
96}
97
// Edge for a single scalar `var` operand. Its partial is one double,
// initialized to 0, and partials_ is a broadcast_array constructed from that
// double.
// The copy constructor (its declaration line is missing from this listing)
// rebinds partials_ to this object's own partial_ rather than to
// other.partial_, presumably so the copy does not reference the source
// object's storage.
// NOTE(review): the declaration of operands_ is also missing from this listing.
101template <>
102class ops_partials_edge<double, var> {
103 public:
104 double partial_{0};
105 broadcast_array<double> partials_{partial_};
106 explicit ops_partials_edge(const var& op) noexcept : operands_(op) {}
108 : partial_(other.partial_),
109 partials_(partial_),
110 operands_(other.operands_) {}
111 inline auto& partial() { return partial_; }
112 inline auto& operand() noexcept { return operands_; }
113
 // A scalar edge always holds exactly one operand.
115 static constexpr int size() { return 1; }
116};
117// Vectorized Univariate
// Edge for a std::vector<var> operand. The constructor copies the vars into an
// arena_allocator-backed vector (Op) and zero-initializes one partial per
// element via partials_t::Zero(op.size()).
// NOTE(review): the declarations of partials_t, partials_vec_ and operands_ are
// missing from this listing.
119template <>
120class ops_partials_edge<double, std::vector<var>> {
121 public:
122 using Op = std::vector<var, arena_allocator<var>>;
124 partials_t partials_; // For univariate use-cases
126 explicit ops_partials_edge(const std::vector<var>& op)
127 : partials_(partials_t::Zero(op.size())),
128 partials_vec_(partials_),
129 operands_(op.begin(), op.end()) {}
130
132
 // Number of scalar var operands.
133 inline int size() const noexcept { return this->operands_.size(); }
 // Returns an rvalue reference. A caller that initializes from it (as the
 // init-captures in operands_and_partials::build() do) moves operands_ out,
 // leaving it moved-from.
134 inline auto&& operand() noexcept { return std::move(this->operands_); }
135 inline auto& partial() noexcept { return this->partials_; }
136};
137
// Edge for an Eigen operand whose scalar type is var. The copy constructor's
// parameter type is ops_partials_edge<double, Op, require_eigen_st<is_var, Op>>;
// the class-head line itself is missing from this listing.
// The operand is converted with to_arena(), and the partials are
// zero-initialized with the operand's rows() and cols(). The copy constructor
// rebinds partials_vec_ to its own partials_, not to other's.
// NOTE(review): the declarations of partials_t, partials_vec_ and operands_,
// and the copy constructor's first line, are missing from this listing.
138template <typename Op>
140 public:
142 partials_t partials_; // For univariate use-cases
144 explicit ops_partials_edge(const Op& ops)
145 : partials_(partials_t::Zero(ops.rows(), ops.cols())),
146 partials_vec_(partials_),
147 operands_(to_arena(ops)) {}
148
150 const ops_partials_edge<double, Op, require_eigen_st<is_var, Op>>& other)
151 : partials_(other.partials_),
152 partials_vec_(partials_),
153 operands_(other.operands_) {}
154
155 inline auto& partial() noexcept { return partials_; }
156 inline auto& operand() noexcept { return operands_; }
 // Number of var coefficients held by the (arena-converted) operand.
158 inline auto size() const noexcept { return this->operands_.size(); }
159};
160
// Edge for a var_value<Op> operand. This is presumably the specialization
// ops_partials_edge<double, var_value<Op>, require_eigen_t<Op>>; its class-head
// line is missing from this listing.
// Partials are zero-initialized to the dimensions read through ops.vi_, and
// the var_value is stored as given (no to_arena() conversion). The copy
// constructor rebinds partials_vec_ to its own partials_, not to other's.
// NOTE(review): the declarations of partials_t, partials_vec_ and operands_,
// and the copy constructor's leading lines, are missing from this listing.
161template <typename Op>
163 public:
165 partials_t partials_; // For univariate use-cases
167 explicit ops_partials_edge(const var_value<Op>& ops)
168 : partials_(
169 plain_type_t<partials_t>::Zero(ops.vi_->rows(), ops.vi_->cols())),
170 partials_vec_(partials_),
171 operands_(ops) {}
172
175 other)
176 : partials_(other.partials_),
177 partials_vec_(partials_),
178 operands_(other.operands_) {}
179
180 inline auto& partial() noexcept { return partials_; }
181 inline auto& operand() noexcept { return operands_; }
182
 // NOTE(review): reports 0 rather than an element count (the Eigen-of-var
 // edge returns operands_.size()). Presumably intentional — confirm how
 // callers use size().
184 static constexpr int size() { return 0; }
185};
186
187// SPECIALIZATIONS FOR MULTIVARIATE VECTORIZATIONS
188// (i.e. nested containers)
// Edge for a std::vector of Eigen matrices of var. The matrices are copied
// into an arena_allocator-backed vector (Op), and each matrix gets its own
// partial, zero-initialized to that matrix's rows() and cols().
// NOTE(review): the declarations of inner_op, partial_t and operands_ are
// missing from this listing.
189template <int R, int C>
190class ops_partials_edge<double, std::vector<Eigen::Matrix<var, R, C>>> {
191 public:
193 using Op = std::vector<inner_op, arena_allocator<inner_op>>;
195 std::vector<partial_t, arena_allocator<partial_t>> partials_vec_;
196 explicit ops_partials_edge(const std::vector<Eigen::Matrix<var, R, C>>& ops)
197 : partials_vec_(ops.size()), operands_(ops.begin(), ops.end()) {
198 for (size_t i = 0; i < ops.size(); ++i) {
199 partials_vec_[i] = partial_t::Zero(ops[i].rows(), ops[i].cols());
200 }
201 }
202
204
 // Number of matrices times the size of the first matrix; 0 when there are
 // no matrices (guarded so operands_[0] is never touched on an empty vector).
 // NOTE(review): this assumes every matrix has the first one's size, although
 // the constructor sizes each partial individually; a vector of differently
 // sized matrices would be miscounted.
205 inline int size() const noexcept {
206 if (unlikely(this->operands_.size() == 0)) {
207 return 0;
208 }
209 return this->operands_.size() * this->operands_[0].size();
210 }
 // Returns an rvalue reference, so a caller that initializes from it (as
 // operands_and_partials::build() does) moves operands_ out.
211 inline auto&& operand() noexcept { return std::move(this->operands_); }
212 inline auto& partial() noexcept { return this->partials_vec_; }
213};
214
215template <>
216class ops_partials_edge<double, std::vector<std::vector<var>>> {
217 public:
218 using inner_vec = std::vector<var, arena_allocator<var>>;
219 using Op = std::vector<inner_vec, arena_allocator<inner_vec>>;
220 using partial_t = std::vector<double, arena_allocator<double>>;
221 std::vector<partial_t, arena_allocator<partial_t>> partials_vec_;
222 explicit ops_partials_edge(const std::vector<std::vector<var>>& ops)
223 : partials_vec_(stan::math::size(ops)), operands_(ops.size()) {
224 for (size_t i = 0; i < stan::math::size(ops); ++i) {
225 operands_[i] = inner_vec(ops[i].begin(), ops[i].end());
226 partials_vec_[i] = partial_t(stan::math::size(ops[i]), 0.0);
227 }
228 }
229
231 inline int size() const noexcept {
232 return this->operands_.size() * this->operands_[0].size();
233 }
234 inline auto&& operand() noexcept { return std::move(this->operands_); }
235 inline auto&& partial() noexcept { return std::move(this->partials_vec_); }
236};
237
// Edge for a std::vector of var_value<Op> operands, where Op is an Eigen type.
// The var_values are copied into an arena_allocator-backed vector. Each
// partial is zero-initialized to the dimensions read through ops[i].vi_.
// NOTE(review): the partials_vec_ declaration is missing from this listing;
// presumably it has type partials_t — confirm.
238template <typename Op>
239class ops_partials_edge<double, std::vector<var_value<Op>>,
240 require_eigen_t<Op>> {
241 public:
242 using partials_t = std::vector<arena_t<Op>, arena_allocator<arena_t<Op>>>;
244 explicit ops_partials_edge(const std::vector<var_value<Op>>& ops)
245 : partials_vec_(ops.size()), operands_(ops.begin(), ops.end()) {
246 for (size_t i = 0; i < ops.size(); ++i) {
247 partials_vec_[i]
248 = plain_type_t<Op>::Zero(ops[i].vi_->rows(), ops[i].vi_->cols());
249 }
250 }
251
252 std::vector<var_value<Op>, arena_allocator<var_value<Op>>> operands_;
253
 // Reports 0 rather than an element count, like the single var_value<Op> edge.
254 static constexpr int size() noexcept { return 0; }
 // Both accessors return rvalue references, so a caller that initializes
 // from them (as operands_and_partials::build() does) moves the stored
 // vectors out of this edge.
255 inline auto&& operand() noexcept { return std::move(this->operands_); }
256 inline auto&& partial() noexcept { return std::move(this->partials_vec_); }
257};
258} // namespace internal
259
// Reverse-mode (var) specialization of operands_and_partials for up to eight
// operands. Each constructor initializes edge1_..edgeK_ from o1..oK. Edges not
// listed in a constructor's initializer list are default-initialized.
// NOTE(review): the edge member declarations and build()'s original
// documentation are missing from this listing.
297template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
298 typename Op6, typename Op7, typename Op8>
299class operands_and_partials<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8, var> {
300 public:
309
310 explicit operands_and_partials(const Op1& o1) : edge1_(o1) {}
311 operands_and_partials(const Op1& o1, const Op2& o2)
312 : edge1_(o1), edge2_(o2) {}
313 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3)
314 : edge1_(o1), edge2_(o2), edge3_(o3) {}
315 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
316 const Op4& o4)
317 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4) {}
318 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
319 const Op4& o4, const Op5& o5)
320 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
321 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
322 const Op4& o4, const Op5& o5, const Op6& o6)
323 : edge1_(o1),
324 edge2_(o2),
325 edge3_(o3),
326 edge4_(o4),
327 edge5_(o5),
328 edge6_(o6) {}
329 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
330 const Op4& o4, const Op5& o5, const Op6& o6,
331 const Op7& o7)
332 : edge1_(o1),
333 edge2_(o2),
334 edge3_(o3),
335 edge4_(o4),
336 edge5_(o5),
337 edge6_(o6),
338 edge7_(o7) {}
339 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
340 const Op4& o4, const Op5& o5, const Op6& o6,
341 const Op7& o7, const Op8& o8)
342 : edge1_(o1),
343 edge2_(o2),
344 edge3_(o3),
345 edge4_(o4),
346 edge5_(o5),
347 edge6_(o6),
348 edge7_(o7),
349 edge8_(o8) {}
350
 /**
  * Build the node to be stored on the autodiff graph.
  *
  * Returns make_callback_var(value, callback). The callback init-captures
  * every edge's operand() and partial() by value. For edges whose accessors
  * return rvalue references (std::move), this moves the stored data out of
  * the edge, so build() is presumably meant to be called once per object.
  *
  * The callback is `mutable` because the matrix and container
  * update_adjoints overloads take the operand by non-const reference. When
  * it runs (presumably receiving the new node as `vi`), each operand whose
  * type OpN satisfies is_autodiff_v is updated through
  * internal::update_adjoints(operandN, partialN, vi). That call adds
  * partialN, scaled by vi's adjoint, into the operand's adjoint.
  *
  * @param value value of the resulting var
  * @return var wrapping the callback node
  */
364 var build(double value) {
365 return make_callback_var(
366 value, [operand1 = edge1_.operand(), partial1 = edge1_.partial(),
367 operand2 = edge2_.operand(), partial2 = edge2_.partial(),
368 operand3 = edge3_.operand(), partial3 = edge3_.partial(),
369 operand4 = edge4_.operand(), partial4 = edge4_.partial(),
370 operand5 = edge5_.operand(), partial5 = edge5_.partial(),
371 operand6 = edge6_.operand(), partial6 = edge6_.partial(),
372 operand7 = edge7_.operand(), partial7 = edge7_.partial(),
373 operand8 = edge8_.operand(),
374 partial8 = edge8_.partial()](const auto& vi) mutable {
375 if constexpr (is_autodiff_v<Op1>) {
376 internal::update_adjoints(operand1, partial1, vi);
377 }
378 if constexpr (is_autodiff_v<Op2>) {
379 internal::update_adjoints(operand2, partial2, vi);
380 }
381 if constexpr (is_autodiff_v<Op3>) {
382 internal::update_adjoints(operand3, partial3, vi);
383 }
384 if constexpr (is_autodiff_v<Op4>) {
385 internal::update_adjoints(operand4, partial4, vi);
386 }
387 if constexpr (is_autodiff_v<Op5>) {
388 internal::update_adjoints(operand5, partial5, vi);
389 }
390 if constexpr (is_autodiff_v<Op6>) {
391 internal::update_adjoints(operand6, partial6, vi);
392 }
393 if constexpr (is_autodiff_v<Op7>) {
394 internal::update_adjoints(operand7, partial7, vi);
395 }
396 if constexpr (is_autodiff_v<Op8>) {
397 internal::update_adjoints(operand8, partial8, vi);
398 }
399 });
400 }
401};
402
403} // namespace math
404} // namespace stan
405#endif
ops_partials_edge(const ops_partials_edge< double, Op, require_eigen_st< is_var, Op > > &other)
ops_partials_edge(const ops_partials_edge< double, var > &other)
ops_partials_edge(const ops_partials_edge< double, var_value< Op >, require_eigen_t< Op > > &other)
An edge holds both the operands and its associated partial derivatives.
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7, const Op8 &o8)
This template builds partial derivatives with respect to a set of operands.
#define unlikely(x)
require_t< std::is_arithmetic< scalar_type_t< std::decay_t< T > > > > require_st_arithmetic
Require scalar type satisfies std::is_arithmetic.
require_t< container_type_check_base< is_eigen, scalar_type_t, TypeCheck, Check... > > require_eigen_st
Require type satisfies is_eigen.
Definition is_eigen.hpp:209
require_t< is_eigen< std::decay_t< T > > > require_eigen_t
Require type satisfies is_eigen.
Definition is_eigen.hpp:113
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:21
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:22
require_t< is_rev_matrix< std::decay_t< T > > > require_rev_matrix_t
Require type satisfies is_rev_matrix.
require_t< is_std_vector< std::decay_t< T > > > require_std_vector_t
Require type satisfies is_std_vector.
var build(double value)
Build the node to be stored on the autodiff graph.
int64_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:19
void update_adjoints(var_value< T1 > &x, const T2 &y, const vari &z)
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback function.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
Definition to_arena.hpp:25
typename plain_type< std::decay_t< T > >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
Defines a static member named value which is defined to be false as the primitive scalar types cannot...
Definition is_var.hpp:14
std library compatible allocator that uses AD stack.