Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_gen_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
2#define STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
3
10#include <type_traits>
11
12namespace stan {
13namespace math {
14
// Reverse-mode trace_gen_inv_quad_form_ldlt for a matrix D (Td is not a column
// vector): returns trace(D * B^T * A^{-1} * B), where A^{-1} * B is computed
// from the LDLT factorization held by A (A.ldlt().solve(...)).
// Enabled when Td, Ta and Tb are all matrix types and at least one of them has
// var scalars.
// NOTE(review): this listing is a documentation-page scrape. The signature
// line, the per-case branch conditions, and the declarations of arena_B /
// arena_D are not visible here. The case descriptions below are inferred from
// which adjoints each path updates.
30template <typename Td, typename Ta, typename Tb,
31 require_not_col_vector_t<Td>* = nullptr,
32 require_all_matrix_t<Td, Ta, Tb>* = nullptr,
33 require_any_st_var<Td, Ta, Tb>* = nullptr>
35 const Tb& B) {
// D must be square, and the products A * B and B * D must be well-formed.
36 check_square("trace_gen_inv_quad_form_ldlt", "D", D);
37 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
38 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
39
// Empty D or A: the trace is zero and no reverse-pass callback is registered.
40 if (D.size() == 0 || A.matrix().size() == 0) {
41 return 0;
42 }
43
// Case: adjoints flow to A, B and D. Both A^{-1}B and B^T A^{-1}B are kept in
// the arena for the reverse pass.
46 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
49 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
50 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
51
52 var res = (arena_D.val() * BTAsolveB).trace();
53
55 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
56 double C_adj = res.adj();
57
// d/dA = -A^{-1} B D^T B^T A^{-1}. This relies on A being symmetric, which
// an LDLT factorization presumes.
58 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
59 * AsolveB.transpose();
// d/dB = A^{-1} B (D + D^T)
60 arena_B.adj() += C_adj * AsolveB
61 * (arena_D.val_op() + arena_D.val_op().transpose());
// d/dD = (B^T A^{-1} B)^T, which equals B^T A^{-1} B for symmetric A.
62 arena_D.adj() += C_adj * BTAsolveB;
63 });
64
65 return res;
// Case: adjoints flow to A and B only. D is used directly (no .val()/.adj()).
68 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
71 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
72
73 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace();
74
75 reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable {
76 double C_adj = res.adj();
77
78 arena_A.adj()
79 -= C_adj * AsolveB * arena_D.transpose() * AsolveB.transpose();
80 arena_B.adj() += C_adj * AsolveB * (arena_D + arena_D.transpose());
81 });
82
83 return res;
// Case: adjoints flow to A and D only. B contributes only its values (via
// value_of) and is not captured by the callback.
86 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
87 const auto& B_ref = to_ref(B);
89 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
90 auto BTAsolveB = to_arena(value_of(B_ref).transpose() * AsolveB);
91
92 var res = (arena_D.val() * BTAsolveB).trace();
93
95 [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable {
96 double C_adj = res.adj();
97
98 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
99 * AsolveB.transpose();
100 arena_D.adj() += C_adj * BTAsolveB;
101 });
102
103 return res;
// Case: adjoints flow to A only.
106 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
107 const auto& B_ref = to_ref(B);
109 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
110
111 var res = (arena_D * value_of(B_ref).transpose() * AsolveB).trace();
112
113 reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable {
114 double C_adj = res.adj();
115
116 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
117 * AsolveB.transpose();
118 });
119
120 return res;
// Case: adjoints flow to B and D only. A is not captured by the callback.
125 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
126 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
127
128 var res = (arena_D.val() * BTAsolveB).trace();
129
131 [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
132 double C_adj = res.adj();
133
134 arena_B.adj() += C_adj * AsolveB
135 * (arena_D.val_op() + arena_D.val_op().transpose());
136 arena_D.adj() += C_adj * BTAsolveB;
137 });
138
139 return res;
// Case: adjoints flow to B only.
144 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
145
146 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace();
147
148 reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable {
149 arena_B.adj() += res.adj() * AsolveB * (arena_D + arena_D.transpose());
150 });
151
152 return res;
// Case: adjoints flow to D only. Only B^T A^{-1} B is kept for the reverse pass.
155 const auto& B_ref = to_ref(B);
157 auto BTAsolveB = to_arena(value_of(B_ref).transpose()
158 * A.ldlt().solve(value_of(B_ref)));
159
160 var res = (arena_D.val() * BTAsolveB).trace();
161
162 reverse_pass_callback([BTAsolveB, arena_D, res]() mutable {
163 arena_D.adj() += res.adj() * BTAsolveB;
164 });
165
166 return res;
167 }
168}
169
// Reverse-mode trace_gen_inv_quad_form_ldlt for a column-vector D, which is
// treated as the diagonal of a matrix (asDiagonal). Returns
// trace(diag(D) * B^T * A^{-1} * B), with A^{-1} * B computed from the LDLT
// factorization held by A.
// NOTE(review): this listing is a documentation-page scrape. The remaining
// template constraints, the signature line, the per-case branch conditions,
// and the declarations of arena_B / arena_D are not visible here.
187template <typename Td, typename Ta, typename Tb,
188 require_col_vector_t<Td>* = nullptr,
192 const Tb& B) {
// There is no check_square on D here, because D holds only diagonal entries.
193 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
194 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
195
// Empty D or A: the trace is zero and no reverse-pass callback is registered.
196 if (D.size() == 0 || A.matrix().size() == 0) {
197 return 0;
198 }
199
// Case: adjoints flow to A, B and D.
202 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
205 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
206 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
207
208 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
209
211 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
212 double C_adj = res.adj();
213
// d/dA = -A^{-1} B diag(D) B^T A^{-1} (A symmetric, as LDLT presumes).
214 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
215 * AsolveB.transpose();
// d/dB = 2 A^{-1} B diag(D). The factor 2 appears because diag(D) is
// symmetric.
216 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
// d/dD = the diagonal of B^T A^{-1} B.
217 arena_D.adj() += C_adj * BTAsolveB.diagonal();
218 });
219
220 return res;
// Case: adjoints flow to A and B only. D is used directly (no .val()/.adj()).
223 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
226 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
227
228 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
229 .trace();
230
231 reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable {
232 double C_adj = res.adj();
233
234 arena_A.adj()
235 -= C_adj * AsolveB * arena_D.asDiagonal() * AsolveB.transpose();
236 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.asDiagonal();
237 });
238
239 return res;
// Case: adjoints flow to A and D only. B contributes only its values (via
// value_of).
242 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
243 const auto& B_ref = to_ref(B);
245 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
246 auto BTAsolveB = to_arena(value_of(B_ref).transpose() * AsolveB);
247
248 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
249
251 [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable {
252 double C_adj = res.adj();
253
254 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
255 * AsolveB.transpose();
256 arena_D.adj() += C_adj * BTAsolveB.diagonal();
257 });
258
259 return res;
// Case: adjoints flow to A only.
262 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
263 const auto& B_ref = to_ref(B);
265 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
266
267 var res = (arena_D.asDiagonal() * value_of(B_ref).transpose() * AsolveB)
268 .trace();
269
270 reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable {
271 double C_adj = res.adj();
272
273 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
274 * AsolveB.transpose();
275 });
276
277 return res;
// Case: adjoints flow to B and D only. A is not captured by the callback.
282 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
283 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
284
285 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
286
288 [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
289 double C_adj = res.adj();
290
291 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
292 arena_D.adj() += C_adj * BTAsolveB.diagonal();
293 });
294
295 return res;
// Case: adjoints flow to B only.
300 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
301
302 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
303 .trace();
304
305 reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable {
306 arena_B.adj() += res.adj() * AsolveB * 2 * arena_D.asDiagonal();
307 });
308
309 return res;
// Case: adjoints flow to D only. Only B^T A^{-1} B is kept for the reverse pass.
312 const auto& B_ref = to_ref(B);
314 auto BTAsolveB = to_arena(value_of(B_ref).transpose()
315 * A.ldlt().solve(value_of(B_ref)));
316
317 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
318
319 reverse_pass_callback([BTAsolveB, arena_D, res]() mutable {
320 arena_D.adj() += res.adj() * BTAsolveB.diagonal();
321 });
322
323 return res;
324 }
325}
326
327} // namespace math
328} // namespace stan
329#endif
LDLT_factor is a structure that holds a matrix of type T and the LDLT of its values.
require_t< is_col_vector< std::decay_t< T > > > require_col_vector_t
Require type satisfies is_col_vector.
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
auto transpose(Arg &&a)
Transposes a kernel generator expression.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
Definition is_var.hpp:131
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
Definition trace.hpp:22
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
return_type_t< EigMat1, T2, EigMat3 > trace_gen_inv_quad_form_ldlt(const EigMat1 &D, LDLT_factor< T2 > &A, const EigMat3 &B)
Compute the trace of an inverse quadratic form.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).