1#ifndef STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
2#define STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
29template <
typename Td,
typename Ta,
typename Tb,
30 require_not_col_vector_t<Td>* =
nullptr,
31 require_all_matrix_t<Td, Ta, Tb>* =
nullptr,
32 require_any_st_var<Td, Ta, Tb>* =
nullptr>
39 if (D.size() == 0 || A.matrix().size() == 0) {
48 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
49 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
51 var res = (arena_D.val() * BTAsolveB).
trace();
54 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
55 double C_adj = res.adj();
57 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
58 * AsolveB.transpose();
59 arena_B.adj() += C_adj * AsolveB
60 * (arena_D.val_op() + arena_D.val_op().transpose());
61 arena_D.adj() += C_adj * BTAsolveB;
70 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
72 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).
trace();
75 double C_adj = res.adj();
78 -= C_adj * AsolveB * arena_D.transpose() * AsolveB.transpose();
79 arena_B.adj() += C_adj * AsolveB * (arena_D + arena_D.transpose());
86 const auto& B_ref =
to_ref(B);
91 var res = (arena_D.val() * BTAsolveB).
trace();
94 [arena_A, BTAsolveB, AsolveB, arena_D, res]()
mutable {
95 double C_adj = res.adj();
97 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
98 * AsolveB.transpose();
99 arena_D.adj() += C_adj * BTAsolveB;
106 const auto& B_ref =
to_ref(B);
113 double C_adj = res.adj();
115 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
116 * AsolveB.transpose();
124 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
125 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
127 var res = (arena_D.val() * BTAsolveB).
trace();
130 [BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
131 double C_adj = res.adj();
133 arena_B.adj() += C_adj * AsolveB
134 * (arena_D.val_op() + arena_D.val_op().transpose());
135 arena_D.adj() += C_adj * BTAsolveB;
143 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
145 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).
trace();
148 arena_B.adj() += res.adj() * AsolveB * (arena_D + arena_D.transpose());
154 const auto& B_ref =
to_ref(B);
159 var res = (arena_D.val() * BTAsolveB).
trace();
162 arena_D.adj() += res.adj() * BTAsolveB;
186template <
typename Td,
typename Ta,
typename Tb,
195 if (D.size() == 0 || A.matrix().size() == 0) {
204 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
205 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
207 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
210 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
211 double C_adj = res.adj();
213 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
214 * AsolveB.transpose();
215 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
216 arena_D.adj() += C_adj * BTAsolveB.diagonal();
225 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
227 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
231 double C_adj = res.adj();
234 -= C_adj * AsolveB * arena_D.asDiagonal() * AsolveB.transpose();
235 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.asDiagonal();
242 const auto& B_ref =
to_ref(B);
247 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
250 [arena_A, BTAsolveB, AsolveB, arena_D, res]()
mutable {
251 double C_adj = res.adj();
253 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
254 * AsolveB.transpose();
255 arena_D.adj() += C_adj * BTAsolveB.diagonal();
262 const auto& B_ref =
to_ref(B);
266 var res = (arena_D.asDiagonal() *
value_of(B_ref).transpose() * AsolveB)
270 double C_adj = res.adj();
272 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
273 * AsolveB.transpose();
281 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
282 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
284 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
287 [BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
288 double C_adj = res.adj();
290 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
291 arena_D.adj() += C_adj * BTAsolveB.diagonal();
299 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
301 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
305 arena_B.adj() += res.adj() * AsolveB * 2 * arena_D.asDiagonal();
311 const auto& B_ref =
to_ref(B);
316 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
319 arena_D.adj() += res.adj() * BTAsolveB.diagonal();
LDLT_factor is a structure that holds a matrix of type T and the LDLT of its values.
require_t< is_col_vector< std::decay_t< T > > > require_col_vector_t
Requires that the type satisfies is_col_vector.
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Requires that all of the types satisfy is_matrix.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Requires that at least one of the scalar types satisfies is_var.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either has any dynamic allocation on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
return_type_t< EigMat1, T2, EigMat3 > trace_gen_inv_quad_form_ldlt(const EigMat1 &D, LDLT_factor< T2 > &A, const EigMat3 &B)
Compute the trace of an inverse quadratic form.
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation from C or the boost::math::lgamma implementation.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).