Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_gen_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
2#define STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
3
9#include <type_traits>
10
11namespace stan {
12namespace math {
13
// trace_gen_inv_quad_form_ldlt (reverse mode, square-matrix D overload).
// Returns, as a var, trace(D * B^T * A^{-1} * B), where A exposes .matrix()
// and .ldlt() (an LDLT_factor) and A^{-1} * B is obtained via
// A.ldlt().solve(...).
// Checks: D must be square, A*B and B*D must be multiplicable.
// Returns 0 when D or A.matrix() is empty.
// NOTE(review): this rendering omits several source lines (the function
// signature, the if/else-if branch conditions and the arena_B / arena_D
// declarations). The branch roles annotated below are inferred from how each
// operand is used. Presumably each branch handles one combination of which of
// Ta/Tb/Td carry var scalars (cf. is_constant); confirm against the full header.
29template <typename Td, typename Ta, typename Tb,
30 require_not_col_vector_t<Td>* = nullptr,
31 require_all_matrix_t<Td, Ta, Tb>* = nullptr,
32 require_any_st_var<Td, Ta, Tb>* = nullptr>
34 const Tb& B) {
35 check_square("trace_gen_inv_quad_form_ldlt", "D", D);
36 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
37 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
38
 // Empty inputs: the trace is zero and no reverse-pass callback is registered.
39 if (D.size() == 0 || A.matrix().size() == 0) {
40 return 0;
41 }
42
 // Branch: A, B and D all appear to be autodiff operands (.val() taken on
 // arena_B and arena_D).
45 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
 // AsolveB = A^{-1} * val(B); kept in the arena for reuse in the reverse pass.
48 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
49 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
50
51 var res = (arena_D.val() * BTAsolveB).trace();
52
 // Reverse pass, with C_adj = adjoint of res and X = AsolveB:
 //   adj(A) -= C_adj * X * D^T * X^T
 //   adj(B) += C_adj * X * (D + D^T)
 //   adj(D) += C_adj * B^T * X
 // The A and B formulas use A^{-1} where A^{-T} would appear in general,
 // i.e. they rely on A being symmetric (which an LDLT factorization presumes).
54 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
55 double C_adj = res.adj();
56
57 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
58 * AsolveB.transpose();
59 arena_B.adj() += C_adj * AsolveB
60 * (arena_D.val_op() + arena_D.val_op().transpose());
61 arena_D.adj() += C_adj * BTAsolveB;
62 });
63
64 return res;
 // Branch: A and B are autodiff operands; D appears to be a constant
 // matrix (used directly, without .val()), so it gets no adjoint.
67 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
70 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
71
72 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace();
73
74 reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable {
75 double C_adj = res.adj();
76
 // Same A and B adjoint formulas as the all-var branch, with constant D.
77 arena_A.adj()
78 -= C_adj * AsolveB * arena_D.transpose() * AsolveB.transpose();
79 arena_B.adj() += C_adj * AsolveB * (arena_D + arena_D.transpose());
80 });
81
82 return res;
 // Branch: A and D are autodiff operands; B enters only through
 // value_of(B_ref), so it gets no adjoint.
85 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
86 const auto& B_ref = to_ref(B);
88 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
89 auto BTAsolveB = to_arena(value_of(B_ref).transpose() * AsolveB);
90
91 var res = (arena_D.val() * BTAsolveB).trace();
92
 // Reverse pass: adj(A) -= C_adj * X * D^T * X^T, adj(D) += C_adj * B^T * X.
94 [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable {
95 double C_adj = res.adj();
96
97 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
98 * AsolveB.transpose();
99 arena_D.adj() += C_adj * BTAsolveB;
100 });
101
102 return res;
 // Branch: only A appears to be an autodiff operand. B is read via value_of
 // and D is used without .val() in the forward product.
105 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
106 const auto& B_ref = to_ref(B);
108 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
109
110 var res = (arena_D * value_of(B_ref).transpose() * AsolveB).trace();
111
112 reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable {
113 double C_adj = res.adj();
114
 // NOTE(review): the other constant-D branch uses arena_D.transpose()
 // directly. val_op() here is presumably a pass-through for a double
 // matrix; confirm the two forms are intended to be equivalent.
115 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
116 * AsolveB.transpose();
117 });
118
119 return res;
 // Branch: B and D are autodiff operands; A appears constant
 // (no arena_A is created, so A gets no adjoint).
124 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
125 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
126
127 var res = (arena_D.val() * BTAsolveB).trace();
128
 // Reverse pass: adj(B) += C_adj * X * (D + D^T), adj(D) += C_adj * B^T * X.
130 [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
131 double C_adj = res.adj();
132
133 arena_B.adj() += C_adj * AsolveB
134 * (arena_D.val_op() + arena_D.val_op().transpose());
135 arena_D.adj() += C_adj * BTAsolveB;
136 });
137
138 return res;
 // Branch: only B appears to be an autodiff operand (no arena_A; D used
 // without .val()).
143 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
144
145 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace();
146
147 reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable {
148 arena_B.adj() += res.adj() * AsolveB * (arena_D + arena_D.transpose());
149 });
150
151 return res;
 // Branch: only D appears to be an autodiff operand. B^T A^{-1} B is a
 // constant, so only it needs to be kept for the reverse pass.
154 const auto& B_ref = to_ref(B);
156 auto BTAsolveB = to_arena(value_of(B_ref).transpose()
157 * A.ldlt().solve(value_of(B_ref)));
158
159 var res = (arena_D.val() * BTAsolveB).trace();
160
 // Reverse pass: adj(D) += C_adj * B^T * A^{-1} * B.
161 reverse_pass_callback([BTAsolveB, arena_D, res]() mutable {
162 arena_D.adj() += res.adj() * BTAsolveB;
163 });
164
165 return res;
166 }
167}
168
// trace_gen_inv_quad_form_ldlt (reverse mode, column-vector D overload).
// D is treated as the diagonal of a square matrix (via asDiagonal()). The
// function returns, as a var, trace(diag(D) * B^T * A^{-1} * B), where
// A^{-1} * B is obtained via A.ldlt().solve(...).
// Checks: A*B and B*D must be multiplicable; no squareness check on D,
// since it is a vector. Returns 0 when D or A.matrix() is empty.
// NOTE(review): this rendering omits several source lines (the remaining
// template constraints and the signature, the if/else-if branch conditions,
// and the arena_B / arena_D declarations). The branch roles annotated below
// are inferred from operand usage; confirm against the full header.
186template <typename Td, typename Ta, typename Tb,
187 require_col_vector_t<Td>* = nullptr,
191 const Tb& B) {
192 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
193 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
194
 // Empty inputs: the trace is zero and no reverse-pass callback is registered.
195 if (D.size() == 0 || A.matrix().size() == 0) {
196 return 0;
197 }
198
 // Branch: A, B and D all appear to be autodiff operands.
201 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
 // AsolveB = A^{-1} * val(B); kept in the arena for reuse in the reverse pass.
204 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
205 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
206
207 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
208
 // Reverse pass, with C_adj = adjoint of res and X = AsolveB:
 //   adj(A) -= C_adj * X * diag(D) * X^T
 //   adj(B) += C_adj * 2 * X * diag(D)   (diag(D) + diag(D)^T = 2 diag(D))
 //   adj(D) += C_adj * diagonal(B^T * X)
 // As in the matrix-D overload, these rely on A being symmetric.
210 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
211 double C_adj = res.adj();
212
213 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
214 * AsolveB.transpose();
215 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
216 arena_D.adj() += C_adj * BTAsolveB.diagonal();
217 });
218
219 return res;
 // Branch: A and B are autodiff operands; D appears to be a constant
 // vector (used directly, without .val()), so it gets no adjoint.
222 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
225 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
226
227 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
228 .trace();
229
230 reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable {
231 double C_adj = res.adj();
232
233 arena_A.adj()
234 -= C_adj * AsolveB * arena_D.asDiagonal() * AsolveB.transpose();
235 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.asDiagonal();
236 });
237
238 return res;
 // Branch: A and D are autodiff operands; B enters only through
 // value_of(B_ref), so it gets no adjoint.
241 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
242 const auto& B_ref = to_ref(B);
244 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
245 auto BTAsolveB = to_arena(value_of(B_ref).transpose() * AsolveB);
246
247 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
248
 // Reverse pass: adj(A) -= C_adj * X * diag(D) * X^T,
 //               adj(D) += C_adj * diagonal(B^T * X).
250 [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable {
251 double C_adj = res.adj();
252
253 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
254 * AsolveB.transpose();
255 arena_D.adj() += C_adj * BTAsolveB.diagonal();
256 });
257
258 return res;
 // Branch: only A appears to be an autodiff operand. B is read via value_of
 // and D is used without .val() in the forward product.
261 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
262 const auto& B_ref = to_ref(B);
264 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
265
266 var res = (arena_D.asDiagonal() * value_of(B_ref).transpose() * AsolveB)
267 .trace();
268
269 reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable {
270 double C_adj = res.adj();
271
 // NOTE(review): the other constant-D branch uses arena_D.asDiagonal()
 // directly. val_op() here is presumably a pass-through for a double
 // vector; confirm the two forms are intended to be equivalent.
272 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
273 * AsolveB.transpose();
274 });
275
276 return res;
 // Branch: B and D are autodiff operands; A appears constant
 // (no arena_A is created, so A gets no adjoint).
281 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
282 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
283
284 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
285
 // Reverse pass: adj(B) += C_adj * 2 * X * diag(D),
 //               adj(D) += C_adj * diagonal(B^T * X).
287 [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
288 double C_adj = res.adj();
289
290 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
291 arena_D.adj() += C_adj * BTAsolveB.diagonal();
292 });
293
294 return res;
 // Branch: only B appears to be an autodiff operand (no arena_A; D used
 // without .val()).
299 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
300
301 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
302 .trace();
303
304 reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable {
305 arena_B.adj() += res.adj() * AsolveB * 2 * arena_D.asDiagonal();
306 });
307
308 return res;
 // Branch: only D appears to be an autodiff operand. B^T A^{-1} B is a
 // constant, so only it needs to be kept for the reverse pass.
311 const auto& B_ref = to_ref(B);
313 auto BTAsolveB = to_arena(value_of(B_ref).transpose()
314 * A.ldlt().solve(value_of(B_ref)));
315
316 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
317
 // Reverse pass: adj(D) += C_adj * diagonal(B^T * A^{-1} * B).
318 reverse_pass_callback([BTAsolveB, arena_D, res]() mutable {
319 arena_D.adj() += res.adj() * BTAsolveB.diagonal();
320 });
321
322 return res;
323 }
324}
325
326} // namespace math
327} // namespace stan
328#endif
LDLT_factor is a structure that holds a matrix of type T and the LDLT of its values.
require_t< is_col_vector< std::decay_t< T > > > require_col_vector_t
Require type satisfies is_col_vector.
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
auto transpose(Arg &&a)
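Transposes a kernel generator expression. (This cross-reference appears to point at the kernel-generator free-function overload; the code on this page calls Eigen's member `.transpose()` instead.)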
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
Definition is_var.hpp:131
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
value_type_t< T > trace(const T &m)
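Calculates trace (sum of diagonal) of given kernel generator expression. (This cross-reference appears to point at the kernel-generator free-function overload; the code on this page calls Eigen's member `.trace()` instead.)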
Definition trace.hpp:22
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
Definition to_arena.hpp:25
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
return_type_t< EigMat1, T2, EigMat3 > trace_gen_inv_quad_form_ldlt(const EigMat1 &D, LDLT_factor< T2 > &A, const EigMat3 &B)
Compute the trace of an inverse quadratic form.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...