Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_gen_quad_form.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_GEN_QUAD_FORM_HPP
2#define STAN_MATH_REV_FUN_TRACE_GEN_QUAD_FORM_HPP
3
13#include <type_traits>
14
15namespace stan {
16namespace math {
17namespace internal {
18
19template <typename Td, int Rd, int Cd, typename Ta, int Ra, int Ca, typename Tb,
20 int Rb, int Cb>
22 public:
23 trace_gen_quad_form_vari_alloc(const Eigen::Matrix<Td, Rd, Cd>& D,
24 const Eigen::Matrix<Ta, Ra, Ca>& A,
25 const Eigen::Matrix<Tb, Rb, Cb>& B)
26 : D_(D), A_(A), B_(B) {}
27
28 double compute() {
30 }
31
32 Eigen::Matrix<Td, Rd, Cd> D_;
33 Eigen::Matrix<Ta, Ra, Ca> A_;
34 Eigen::Matrix<Tb, Rb, Cb> B_;
35};
36
37template <typename Td, int Rd, int Cd, typename Ta, int Ra, int Ca, typename Tb,
38 int Rb, int Cb>
40 protected:
41 static inline void computeAdjoints(double adj,
42 const Eigen::Matrix<double, Rd, Cd>& D,
43 const Eigen::Matrix<double, Ra, Ca>& A,
44 const Eigen::Matrix<double, Rb, Cb>& B,
45 Eigen::Matrix<var, Rd, Cd>* varD,
46 Eigen::Matrix<var, Ra, Ca>* varA,
47 Eigen::Matrix<var, Rb, Cb>* varB) {
48 Eigen::Matrix<double, Ca, Cb> AtB;
49 Eigen::Matrix<double, Ra, Cb> BD;
50 if (varB || varA) {
51 BD.noalias() = B * D;
52 }
53 if (varB || varD) {
54 AtB.noalias() = A.transpose() * B;
55 }
56
57 if (varB) {
58 (*varB).adj() += adj * (A * BD + AtB * D.transpose());
59 }
60 if (varA) {
61 (*varA).adj() += adj * (B * BD.transpose());
62 }
63 if (varD) {
64 (*varD).adj() += adj * (B.transpose() * AtB);
65 }
66 }
67
68 public:
71 : vari(impl->compute()), impl_(impl) {}
72
73 virtual void chain() {
75 value_of(impl_->B_),
76 reinterpret_cast<Eigen::Matrix<var, Rd, Cd>*>(
77 std::is_same<Td, var>::value ? (&impl_->D_) : NULL),
78 reinterpret_cast<Eigen::Matrix<var, Ra, Ca>*>(
79 std::is_same<Ta, var>::value ? (&impl_->A_) : NULL),
80 reinterpret_cast<Eigen::Matrix<var, Rb, Cb>*>(
81 std::is_same<Tb, var>::value ? (&impl_->B_) : NULL));
82 }
83
85};
86} // namespace internal
87
88template <typename Td, typename Ta, typename Tb,
92inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) {
93 using Td_scal = value_type_t<Td>;
94 using Ta_scal = value_type_t<Ta>;
95 using Tb_scal = value_type_t<Tb>;
96 constexpr int Rd = Td::RowsAtCompileTime;
97 constexpr int Cd = Td::ColsAtCompileTime;
98 constexpr int Ra = Ta::RowsAtCompileTime;
99 constexpr int Ca = Ta::ColsAtCompileTime;
100 constexpr int Rb = Tb::RowsAtCompileTime;
101 constexpr int Cb = Tb::ColsAtCompileTime;
102 check_square("trace_gen_quad_form", "A", A);
103 check_square("trace_gen_quad_form", "D", D);
104 check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
105 check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
106
107 auto* baseVari
108 = new internal::trace_gen_quad_form_vari_alloc<Td_scal, Rd, Cd, Ta_scal,
109 Ra, Ca, Tb_scal, Rb, Cb>(
110 D, A, B);
111
112 return var(
113 new internal::trace_gen_quad_form_vari<Td_scal, Rd, Cd, Ta_scal, Ra, Ca,
114 Tb_scal, Rb, Cb>(baseVari));
115}
116
136template <typename Td, typename Ta, typename Tb,
139inline var trace_gen_quad_form(const Td& D, const Ta& A, const Tb& B) {
140 check_square("trace_gen_quad_form", "A", A);
141 check_square("trace_gen_quad_form", "D", D);
142 check_multiplicable("trace_gen_quad_form", "A", A, "B", B);
143 check_multiplicable("trace_gen_quad_form", "B", B, "D", D);
144
150
151 auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op().transpose());
152 auto arena_AB = to_arena(arena_A.val_op() * arena_B.val_op());
153
154 var res = (arena_BDT.transpose() * arena_AB).trace();
155
157 [arena_A, arena_B, arena_D, arena_BDT, arena_AB, res]() mutable {
158 double C_adj = res.adj();
159
160 arena_A.adj() += C_adj * arena_BDT * arena_B.val_op().transpose();
161
162 arena_B.adj() += C_adj
163 * (arena_AB * arena_D.val_op()
164 + arena_A.val_op().transpose() * arena_BDT);
165
166 arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val_op());
167 });
168
169 return res;
170 } else if (!is_constant<Ta>::value && !is_constant<Tb>::value
171 && is_constant<Td>::value) {
172 arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
173 arena_t<promote_scalar_t<var, Ta>> arena_A = A;
174 arena_t<promote_scalar_t<var, Tb>> arena_B = B;
175
176 auto arena_BDT = to_arena(arena_B.val_op() * arena_D.transpose());
177 auto arena_AB = to_arena(arena_A.val_op() * arena_B.val_op());
178
179 var res = (arena_BDT.transpose() * arena_AB).trace();
180
181 reverse_pass_callback([arena_A, arena_B, arena_D, arena_BDT, arena_AB,
182 res]() mutable {
183 double C_adj = res.adj();
184
185 arena_A.adj() += C_adj * arena_BDT * arena_B.val_op().transpose();
186 arena_B.adj()
187 += C_adj
188 * (arena_AB * arena_D + arena_A.val_op().transpose() * arena_BDT);
189 });
190
191 return res;
192 } else if (!is_constant<Ta>::value && is_constant<Tb>::value
193 && !is_constant<Td>::value) {
194 arena_t<promote_scalar_t<var, Td>> arena_D = D;
195 arena_t<promote_scalar_t<var, Ta>> arena_A = A;
196 arena_t<promote_scalar_t<double, Tb>> arena_B = value_of(B);
197
198 auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op().transpose());
199 auto arena_AB = to_arena(arena_A.val_op() * arena_B.val_op());
200
201 var res = (arena_BDT.transpose() * arena_A.val_op() * arena_B).trace();
202
204 [arena_A, arena_B, arena_D, arena_BDT, arena_AB, res]() mutable {
205 double C_adj = res.adj();
206
207 arena_A.adj() += C_adj * arena_BDT * arena_B.transpose();
208 arena_D.adj() += C_adj * arena_AB.transpose() * arena_B;
209 });
210
211 return res;
212 } else if (!is_constant<Ta>::value && is_constant<Tb>::value
213 && is_constant<Td>::value) {
214 arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
215 arena_t<promote_scalar_t<var, Ta>> arena_A = A;
216 arena_t<promote_scalar_t<double, Tb>> arena_B = value_of(B);
217
218 auto arena_BDT = to_arena(arena_B * arena_D);
219
220 var res = (arena_BDT.transpose() * arena_A.val_op() * arena_B).trace();
221
222 reverse_pass_callback([arena_A, arena_B, arena_BDT, res]() mutable {
223 arena_A.adj() += res.adj() * arena_BDT * arena_B.val_op().transpose();
224 });
225
226 return res;
227 } else if (is_constant<Ta>::value && !is_constant<Tb>::value
228 && !is_constant<Td>::value) {
229 arena_t<promote_scalar_t<var, Td>> arena_D = D;
230 arena_t<promote_scalar_t<double, Ta>> arena_A = value_of(A);
231 arena_t<promote_scalar_t<var, Tb>> arena_B = B;
232
233 auto arena_AB = to_arena(arena_A * arena_B.val_op());
234 auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op());
235
236 var res = (arena_BDT.transpose() * arena_AB).trace();
237
238 reverse_pass_callback([arena_A, arena_B, arena_D, arena_AB, arena_BDT,
239 res]() mutable {
240 double C_adj = res.adj();
241
242 arena_B.adj()
243 += C_adj
244 * (arena_AB * arena_D.val_op() + arena_A.transpose() * arena_BDT);
245
246 arena_D.adj() += C_adj * (arena_AB.transpose() * arena_B.val_op());
247 });
248
249 return res;
250 } else if (is_constant<Ta>::value && !is_constant<Tb>::value
251 && is_constant<Td>::value) {
252 arena_t<promote_scalar_t<double, Td>> arena_D = value_of(D);
253 arena_t<promote_scalar_t<double, Ta>> arena_A = value_of(A);
254 arena_t<promote_scalar_t<var, Tb>> arena_B = B;
255
256 auto arena_AB = to_arena(arena_A * arena_B.val_op());
257 auto arena_BDT = to_arena(arena_B.val_op() * arena_D.val_op());
258
259 var res = (arena_BDT.transpose() * arena_AB).trace();
260
262 [arena_A, arena_B, arena_D, arena_AB, arena_BDT, res]() mutable {
263 arena_B.adj() += res.adj()
264 * (arena_AB * arena_D.val_op()
265 + arena_A.val_op().transpose() * arena_BDT);
266 });
267
268 return res;
269 } else if (is_constant<Ta>::value && is_constant<Tb>::value
270 && !is_constant<Td>::value) {
271 arena_t<promote_scalar_t<var, Td>> arena_D = D;
272 arena_t<promote_scalar_t<double, Ta>> arena_A = value_of(A);
273 arena_t<promote_scalar_t<double, Tb>> arena_B = value_of(B);
274
275 auto arena_AB = to_arena(arena_A * arena_B);
276
277 var res = (arena_D.val_op() * arena_B.transpose() * arena_AB).trace();
278
279 reverse_pass_callback([arena_AB, arena_B, arena_D, res]() mutable {
280 arena_D.adj() += res.adj() * (arena_AB.transpose() * arena_B);
281 });
282
283 return res;
284 }
285}
286
287} // namespace math
288} // namespace stan
289#endif
A chainable_alloc is an object which is constructed and destructed normally but whose memory lifespan is managed along with the arena allocator for the gradient calculation.
trace_gen_quad_form_vari_alloc(const Eigen::Matrix< Td, Rd, Cd > &D, const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B)
trace_gen_quad_form_vari(trace_gen_quad_form_vari_alloc< Td, Rd, Cd, Ta, Ra, Ca, Tb, Rb, Cb > *impl)
static void computeAdjoints(double adj, const Eigen::Matrix< double, Rd, Cd > &D, const Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &B, Eigen::Matrix< var, Rd, Cd > *varD, Eigen::Matrix< var, Ra, Ca > *varA, Eigen::Matrix< var, Rb, Cb > *varB)
trace_gen_quad_form_vari_alloc< Td, Rd, Cd, Ta, Ra, Ca, Tb, Rb, Cb > * impl_
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
Definition is_eigen.hpp:65
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_any_t< is_var< std::decay_t< Types > >... > require_any_var_t
Require any of the types satisfy is_var.
Definition is_var.hpp:39
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
Definition trace.hpp:22
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either has all of its dynamic allocations on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
auto trace_gen_quad_form(const TD &D, const TA &A, const TB &B)
Return the trace of D times the quadratic form of B and A.
var_value< double > var
Definition var.hpp:1187
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).