Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_gen_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
2#define STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
3
10#include <type_traits>
11
12namespace stan {
13namespace math {
14
30template <typename Td, typename Ta, typename Tb,
31 require_not_col_vector_t<Td>* = nullptr,
32 require_all_matrix_t<Td, Ta, Tb>* = nullptr,
33 require_any_st_var<Td, Ta, Tb>* = nullptr>
35 const Tb& B) {
36 check_square("trace_gen_inv_quad_form_ldlt", "D", D);
37 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
38 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
39
40 if (D.size() == 0 || A.matrix().size() == 0) {
41 return 0;
42 }
43
44 if constexpr (is_all_autodiff_v<Ta, Tb, Td>) {
45 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
48 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
49 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
50
51 var res = (arena_D.val() * BTAsolveB).trace();
52
54 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
55 double C_adj = res.adj();
56
57 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
58 * AsolveB.transpose();
59 arena_B.adj() += C_adj * AsolveB
60 * (arena_D.val_op() + arena_D.val_op().transpose());
61 arena_D.adj() += C_adj * BTAsolveB;
62 });
63
64 return res;
65 } else if constexpr (is_all_autodiff_v<Ta, Tb> && is_constant_v<Td>) {
66 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
69 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
70
71 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace();
72
73 reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable {
74 double C_adj = res.adj();
75
76 arena_A.adj()
77 -= C_adj * AsolveB * arena_D.transpose() * AsolveB.transpose();
78 arena_B.adj() += C_adj * AsolveB * (arena_D + arena_D.transpose());
79 });
80
81 return res;
82 } else if constexpr (is_all_autodiff_v<Ta, Td> && is_constant_v<Tb>) {
83 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
84 const auto& B_ref = to_ref(B);
86 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
87 auto BTAsolveB = to_arena(value_of(B_ref).transpose() * AsolveB);
88
89 var res = (arena_D.val() * BTAsolveB).trace();
90
92 [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable {
93 double C_adj = res.adj();
94
95 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
96 * AsolveB.transpose();
97 arena_D.adj() += C_adj * BTAsolveB;
98 });
99
100 return res;
101 } else if constexpr (is_autodiff_v<Ta> && is_constant_all_v<Tb, Td>) {
102 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
103 const auto& B_ref = to_ref(B);
105 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
106
107 var res = (arena_D * value_of(B_ref).transpose() * AsolveB).trace();
108
109 reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable {
110 double C_adj = res.adj();
111
112 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
113 * AsolveB.transpose();
114 });
115
116 return res;
117 } else if constexpr (is_constant_v<Ta> && is_all_autodiff_v<Tb, Td>) {
120 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
121 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
122
123 var res = (arena_D.val() * BTAsolveB).trace();
124
126 [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
127 double C_adj = res.adj();
128
129 arena_B.adj() += C_adj * AsolveB
130 * (arena_D.val_op() + arena_D.val_op().transpose());
131 arena_D.adj() += C_adj * BTAsolveB;
132 });
133
134 return res;
135 } else if constexpr (is_constant_all_v<Ta, Td> && is_autodiff_v<Tb>) {
138 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
139
140 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).trace();
141
142 reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable {
143 arena_B.adj() += res.adj() * AsolveB * (arena_D + arena_D.transpose());
144 });
145
146 return res;
147 } else if constexpr (is_constant_all_v<Ta, Tb> && is_autodiff_v<Td>) {
148 const auto& B_ref = to_ref(B);
150 auto BTAsolveB = to_arena(value_of(B_ref).transpose()
151 * A.ldlt().solve(value_of(B_ref)));
152
153 var res = (arena_D.val() * BTAsolveB).trace();
154
155 reverse_pass_callback([BTAsolveB, arena_D, res]() mutable {
156 arena_D.adj() += res.adj() * BTAsolveB;
157 });
158
159 return res;
160 }
161}
162
180template <typename Td, typename Ta, typename Tb,
181 require_col_vector_t<Td>* = nullptr,
185 const Tb& B) {
186 check_multiplicable("trace_gen_inv_quad_form_ldlt", "A", A.matrix(), "B", B);
187 check_multiplicable("trace_gen_inv_quad_form_ldlt", "B", B, "D", D);
188
189 if (D.size() == 0 || A.matrix().size() == 0) {
190 return 0;
191 }
192
193 if constexpr (is_all_autodiff_v<Ta, Tb, Td>) {
194 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
197 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
198 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
199
200 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
201
203 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
204 double C_adj = res.adj();
205
206 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
207 * AsolveB.transpose();
208 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
209 arena_D.adj() += C_adj * BTAsolveB.diagonal();
210 });
211
212 return res;
213 } else if constexpr (is_all_autodiff_v<Ta, Tb> && is_constant_v<Td>) {
214 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
217 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
218
219 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
220 .trace();
221
222 reverse_pass_callback([arena_A, AsolveB, arena_B, arena_D, res]() mutable {
223 double C_adj = res.adj();
224
225 arena_A.adj()
226 -= C_adj * AsolveB * arena_D.asDiagonal() * AsolveB.transpose();
227 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.asDiagonal();
228 });
229
230 return res;
231 } else if constexpr (is_all_autodiff_v<Ta, Td> && is_constant_v<Tb>) {
232 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
233 const auto& B_ref = to_ref(B);
235 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
236 auto BTAsolveB = to_arena(value_of(B_ref).transpose() * AsolveB);
237
238 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
239
241 [arena_A, BTAsolveB, AsolveB, arena_D, res]() mutable {
242 double C_adj = res.adj();
243
244 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
245 * AsolveB.transpose();
246 arena_D.adj() += C_adj * BTAsolveB.diagonal();
247 });
248
249 return res;
250 } else if constexpr (is_autodiff_v<Ta> && is_constant_all_v<Tb, Td>) {
251 arena_t<promote_scalar_t<var, Ta>> arena_A = A.matrix();
252 const auto& B_ref = to_ref(B);
254 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
255
256 var res = (arena_D.asDiagonal() * value_of(B_ref).transpose() * AsolveB)
257 .trace();
258
259 reverse_pass_callback([arena_A, AsolveB, arena_D, res]() mutable {
260 double C_adj = res.adj();
261
262 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
263 * AsolveB.transpose();
264 });
265
266 return res;
267 } else if constexpr (is_constant_v<Ta> && is_all_autodiff_v<Tb, Td>) {
270 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
271 auto BTAsolveB = to_arena(arena_B.val_op().transpose() * AsolveB);
272
273 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
274
276 [BTAsolveB, AsolveB, arena_B, arena_D, res]() mutable {
277 double C_adj = res.adj();
278
279 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
280 arena_D.adj() += C_adj * BTAsolveB.diagonal();
281 });
282
283 return res;
284 } else if constexpr (is_constant_all_v<Ta, Td> && is_autodiff_v<Tb>) {
287 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
288
289 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
290 .trace();
291
292 reverse_pass_callback([AsolveB, arena_B, arena_D, res]() mutable {
293 arena_B.adj() += res.adj() * AsolveB * 2 * arena_D.asDiagonal();
294 });
295
296 return res;
297 } else if constexpr (is_constant_all_v<Ta, Tb> && is_autodiff_v<Td>) {
298 const auto& B_ref = to_ref(B);
300 auto BTAsolveB = to_arena(value_of(B_ref).transpose()
301 * A.ldlt().solve(value_of(B_ref)));
302
303 var res = (arena_D.val().asDiagonal() * BTAsolveB).trace();
304
305 reverse_pass_callback([BTAsolveB, arena_D, res]() mutable {
306 arena_D.adj() += res.adj() * BTAsolveB.diagonal();
307 });
308
309 return res;
310 }
311}
312
313} // namespace math
314} // namespace stan
315#endif
LDLT_factor is a structure that holds a matrix of type T and the LDLT of its values.
require_t< is_col_vector< std::decay_t< T > > > require_col_vector_t
Require type satisfies is_col_vector.
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
auto transpose(Arg &&a)
Transposes a kernel generator expression.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
Definition is_var.hpp:196
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
Definition trace.hpp:22
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
return_type_t< EigMat1, T2, EigMat3 > trace_gen_inv_quad_form_ldlt(const EigMat1 &D, LDLT_factor< T2 > &A, const EigMat3 &B)
Compute the trace of a generalized inverse quadratic form, trace(D * B^T * A^{-1} * B), where A is supplied through its LDLT factorization.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:18
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...