Automatic Differentiation
 
Loading...
Searching...
No Matches
operands_and_partials.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_FWD_META_OPERANDS_AND_PARTIALS_HPP
2#define STAN_MATH_FWD_META_OPERANDS_AND_PARTIALS_HPP
3
16#include <vector>
17
18namespace stan {
19namespace math {
20namespace internal {
21
/**
 * End of recursion for summing `.dx()` over operand edges: with no edges
 * left to visit, the contribution is zero.
 *
 * Declared `inline` instead of `static`: this lives in a header and is called
 * from the inline `sum_dx` templates, so every translation unit should refer
 * to one and the same function rather than to a private internal-linkage
 * copy.
 *
 * @return `0.0` (a `double`)
 */
inline constexpr auto sum_dx() { return 0.0; }
26
/**
 * Single-edge case of the `.dx()` summation.
 *
 * @tparam T1 type of the edge; must provide a `dx()` member function
 * @param edge the only remaining edge
 * @return whatever `edge.dx()` returns
 */
template <typename T1>
inline auto sum_dx(T1& edge) {
  return edge.dx();
}
36
/**
 * Recursive case of the `.dx()` summation: adds the `.dx()` of the first two
 * edges and then the result of `sum_dx` applied to the remaining ones.
 *
 * @tparam T1 type of the first edge; must provide `dx()`
 * @tparam T2 type of the second edge; must provide `dx()`
 * @tparam Types types of the remaining edges
 * @param first first edge
 * @param second second edge
 * @param rest remaining edges, forwarded to the next `sum_dx` call
 * @return `first.dx() + second.dx() + sum_dx(rest...)`
 */
template <typename T1, typename T2, typename... Types>
inline auto sum_dx(T1& first, T2& second, Types&... rest) {
  return first.dx() + second.dx() + sum_dx(rest...);
}
50
51template <typename InnerType, typename T>
52class ops_partials_edge<InnerType, T, require_fvar_t<T>> {
53 public:
54 using Op = std::decay_t<T>;
55 using Dx = std::decay_t<InnerType>;
56 Dx partial_{0};
57 broadcast_array<Dx> partials_{partial_};
58
59 explicit ops_partials_edge(const T& op)
60 : partial_(0), partials_(partial_), operands_(op) {}
61
63 const ops_partials_edge<InnerType, T, require_fvar_t<T>>& other)
64 : partial_(other.partial_),
65 partials_(partial_),
66 operands_(other.operands_) {}
67
69 ops_partials_edge<InnerType, T, require_fvar_t<T>>&& other)
70 : partial_(other.partial_),
71 partials_(partial_),
72 operands_(other.operands_) {}
73
74 const Op& operands_;
75
76 inline Dx dx() { return this->partial_ * this->operands_.d_; }
77};
78
79// Vectorized Univariate
80template <typename InnerType, typename T>
82 public:
83 using Op = std::decay_t<T>;
84 using Dx = std::decay_t<InnerType>;
85 using partials_t = Eigen::Matrix<Dx, -1, 1>;
86 partials_t partials_; // For univariate use-cases
87 broadcast_array<partials_t> partials_vec_{partials_}; // For multivariate
88 explicit ops_partials_edge(const Op& ops)
89 : partials_(partials_t::Zero(ops.size()).eval()), operands_(ops) {}
90
93 other)
94 : partials_(other.partials_),
95 partials_vec_(partials_),
96 operands_(other.operands_) {}
97
100 other)
101 : partials_(std::move(other.partials_)),
102 partials_vec_(partials_),
103 operands_(other.operands_) {}
104
105 const Op& operands_;
106 inline Dx dx() {
107 return dot_product(as_column_vector_or_scalar(this->partials_),
108 as_column_vector_or_scalar(this->operands_).d());
109 }
110};
111
// Edge specialization selected through require_eigen_vt<is_fvar, ViewElt>.
// Partials are zero-filled with the operand's rows() x cols() on
// construction; dx() returns sum(elt_multiply(partials_, operands_.d())).
// NOTE(review): the declarations of `partials_t` and `operands_`, and the
// opening lines of the two constructors taking `other`, are absent from this
// listing (its embedded numbering jumps 114->117, 122->125, 129->132,
// 136->138) -- confirm against the original header.
112template <typename Dx, typename ViewElt>
113class ops_partials_edge<Dx, ViewElt, require_eigen_vt<is_fvar, ViewElt>> {
114 public:
117 partials_t partials_; // For univariate use-cases
118 broadcast_array<partials_t> partials_vec_{partials_}; // For multivariate
// Constructor restricted to types satisfying require_eigen_vt<is_fvar, OpT>.
119 template <typename OpT, require_eigen_vt<is_fvar, OpT>* = nullptr>
120 explicit ops_partials_edge(const OpT& ops)
121 : partials_(partials_t::Zero(ops.rows(), ops.cols())), operands_(ops) {}
122
// Tail of a constructor taking `other`: copies other.partials_ and builds
// partials_vec_ from this object's own partials_.
125 other)
126 : partials_(other.partials_),
127 partials_vec_(partials_),
128 operands_(other.operands_) {}
129
// Tail of a constructor taking `other`: moves other.partials_ and builds
// partials_vec_ from this object's own partials_.
132 other)
133 : partials_(std::move(other.partials_)),
134 partials_vec_(partials_),
135 operands_(other.operands_) {}
136
138
// Sum of the elementwise product of the partials with the operands' d().
139 inline Dx dx() {
140 return sum(elt_multiply(this->partials_, this->operands_.d()));
141 }
142};
143
144// Multivariate; vectors of eigen types
// Edge specialization for std::vector<Eigen::Matrix<fvar<Dx>, R, C>>.
// Keeps one Eigen::Matrix<Dx, R, C> of partials per operand matrix.
// NOTE(review): the declaration of `operands_` and the left-hand side of the
// `+=` inside dx() are absent from this listing (embedded numbering jumps
// 157->159 and 162->164) -- confirm against the original header.
145template <typename Dx, int R, int C>
146class ops_partials_edge<Dx, std::vector<Eigen::Matrix<fvar<Dx>, R, C>>> {
147 public:
148 using Op = std::vector<Eigen::Matrix<fvar<Dx>, R, C>>;
149 using partial_t = Eigen::Matrix<Dx, R, C>;
150 std::vector<partial_t> partials_vec_;
// Allocates ops.size() partial matrices and zero-fills each one with the
// dimensions of the matching operand matrix.
151 explicit ops_partials_edge(const Op& ops)
152 : partials_vec_(ops.size()), operands_(ops) {
153 for (size_t i = 0; i < ops.size(); ++i) {
154 partials_vec_[i] = partial_t::Zero(ops[i].rows(), ops[i].cols());
155 }
156 }
157
159
// Loops over the operand matrices; each iteration computes the sum of the
// elementwise product of partials_vec_[i] with operands_[i].d().
// Returns `derivative`, which starts at 0 (presumably the accumulation
// target of the `+=` -- see NOTE above).
160 inline Dx dx() {
161 Dx derivative(0);
162 for (size_t i = 0; i < this->operands_.size(); ++i) {
164 += sum(elt_multiply(this->partials_vec_[i], this->operands_[i].d()));
165 }
166 return derivative;
167 }
168};
169
170template <typename Dx>
171class ops_partials_edge<Dx, std::vector<std::vector<fvar<Dx>>>> {
172 public:
173 using Op = std::vector<std::vector<fvar<Dx>>>;
174 using partial_t = std::vector<Dx>;
175 std::vector<partial_t> partials_vec_;
176 explicit ops_partials_edge(const Op& ops)
177 : partials_vec_(stan::math::size(ops)), operands_(ops) {
178 for (size_t i = 0; i < stan::math::size(ops); ++i) {
179 partials_vec_[i] = partial_t(stan::math::size(ops[i]), 0.0);
180 }
181 }
182
184 inline Dx dx() {
185 Dx derivative(0);
186 for (size_t i = 0; i < this->operands_.size(); ++i) {
187 for (size_t j = 0; j < this->operands_[i].size(); ++j) {
188 derivative += this->partials_vec_[i][j] * this->operands_[i][j].d_;
189 }
190 }
191 return derivative;
192 }
193};
194
195} // namespace internal
196
// Specialization of operands_and_partials whose final template argument is
// fvar<Dx>. It takes up to eight operands; the k-argument constructor
// initialises edge1_ .. edgek_ from its arguments and leaves the remaining
// edges out of its initializer list.
// NOTE(review): the declarations of edge1_ .. edge8_ and T_return_type, and
// the header of the function whose body ends this class, are absent from this
// listing (embedded numbering jumps 239->249 and 289->304) -- confirm against
// the original header.
236template <typename Op1, typename Op2, typename Op3, typename Op4, typename Op5,
237 typename Op6, typename Op7, typename Op8, typename Dx>
238class operands_and_partials<Op1, Op2, Op3, Op4, Op5, Op6, Op7, Op8, fvar<Dx>> {
239 public:
// Only the single-operand constructor is marked explicit.
249 explicit operands_and_partials(const Op1& o1) : edge1_(o1) {}
250 operands_and_partials(const Op1& o1, const Op2& o2)
251 : edge1_(o1), edge2_(o2) {}
252 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3)
253 : edge1_(o1), edge2_(o2), edge3_(o3) {}
254 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
255 const Op4& o4)
256 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4) {}
257 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
258 const Op4& o4, const Op5& o5)
259 : edge1_(o1), edge2_(o2), edge3_(o3), edge4_(o4), edge5_(o5) {}
260 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
261 const Op4& o4, const Op5& o5, const Op6& o6)
262 : edge1_(o1),
263 edge2_(o2),
264 edge3_(o3),
265 edge4_(o4),
266 edge5_(o5),
267 edge6_(o6) {}
268 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
269 const Op4& o4, const Op5& o5, const Op6& o6,
270 const Op7& o7)
271 : edge1_(o1),
272 edge2_(o2),
273 edge3_(o3),
274 edge4_(o4),
275 edge5_(o5),
276 edge6_(o6),
277 edge7_(o7) {}
278 operands_and_partials(const Op1& o1, const Op2& o2, const Op3& o3,
279 const Op4& o4, const Op5& o5, const Op6& o6,
280 const Op7& o7, const Op8& o8)
281 : edge1_(o1),
282 edge2_(o2),
283 edge3_(o3),
284 edge4_(o4),
285 edge5_(o5),
286 edge6_(o6),
287 edge7_(o7),
288 edge8_(o8) {}
289
// Adds up dx() from all eight edges and returns T_return_type(value, deriv).
304 Dx deriv = edge1_.dx() + edge2_.dx() + edge3_.dx() + edge4_.dx()
305 + edge5_.dx() + edge6_.dx() + edge7_.dx() + edge8_.dx();
306 return T_return_type(value, deriv);
307 }
308};
309
310} // namespace math
311} // namespace stan
312#endif
ops_partials_edge(ops_partials_edge< Dx, ViewElt, require_eigen_vt< is_fvar, ViewElt > > &&other)
ops_partials_edge(const ops_partials_edge< Dx, ViewElt, require_eigen_vt< is_fvar, ViewElt > > &other)
ops_partials_edge(ops_partials_edge< InnerType, T, require_fvar_t< T > > &&other)
ops_partials_edge(const ops_partials_edge< InnerType, T, require_fvar_t< T > > &other)
ops_partials_edge(ops_partials_edge< InnerType, T, require_std_vector_vt< is_fvar, T > > &&other)
ops_partials_edge(const ops_partials_edge< InnerType, T, require_std_vector_vt< is_fvar, T > > &other)
An edge holds both the operands and their associated partial derivatives.
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7, const Op8 &o8)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5, const Op6 &o6, const Op7 &o7)
operands_and_partials(const Op1 &o1, const Op2 &o2, const Op3 &o3, const Op4 &o4, const Op5 &o5)
internal::ops_partials_edge< double, std::decay_t< Op2 > > edge2_
internal::ops_partials_edge< double, std::decay_t< Op7 > > edge7_
internal::ops_partials_edge< double, std::decay_t< Op4 > > edge4_
internal::ops_partials_edge< double, std::decay_t< Op6 > > edge6_
internal::ops_partials_edge< double, std::decay_t< Op1 > > edge1_
internal::ops_partials_edge< double, std::decay_t< Op8 > > edge8_
internal::ops_partials_edge< double, std::decay_t< Op5 > > edge5_
internal::ops_partials_edge< double, std::decay_t< Op3 > > edge3_
This template builds partial derivatives with respect to a set of operands.
require_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check... > > require_eigen_vt
Require type satisfies is_eigen.
Definition is_eigen.hpp:97
require_t< is_fvar< std::decay_t< T > > > require_fvar_t
Require type satisfies is_fvar.
Definition is_fvar.hpp:25
elt_multiply_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_multiply(T_a &&a, T_b &&b)
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
int rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:21
int cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:20
require_t< container_type_check_base< is_std_vector, value_type_t, TypeCheck, Check... > > require_std_vector_vt
Require type satisfies is_std_vector.
T_return_type build(Dx value)
Build the node to be stored on the autodiff graph.
size_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
static constexpr auto sum_dx()
End of recursion for summing .dx() for fvar<T> ops and partials.
void derivative(const F &f, const T &x, T &fx, T &dfx_dx)
Return the derivative of the specified univariate function at the specified argument.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
T eval(T &&arg)
Inputs whose plain_type is equal to their own type are forwarded unmodified (for Eigen expressions ...
Definition eval.hpp:20
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
auto dot_product(const T_a &a, const T_b &b)
Returns the dot product of the specified vectors.
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Defines a static member function type which is defined to be false as the primitive scalar types cann...
Definition is_fvar.hpp:15
This template class represents scalars used in forward-mode automatic differentiation,...
Definition fvar.hpp:40