1#ifndef STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
2#define STAN_MATH_REV_FUN_TRACE_GEN_INV_QUAD_FORM_LDLT_HPP
30template <
typename Td,
typename Ta,
typename Tb,
31 require_not_col_vector_t<Td>* =
nullptr,
32 require_all_matrix_t<Td, Ta, Tb>* =
nullptr,
33 require_any_st_var<Td, Ta, Tb>* =
nullptr>
40 if (D.size() == 0 || A.matrix().size() == 0) {
49 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
50 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
52 var res = (arena_D.val() * BTAsolveB).
trace();
55 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
56 double C_adj = res.adj();
58 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
59 * AsolveB.transpose();
60 arena_B.adj() += C_adj * AsolveB
61 * (arena_D.val_op() + arena_D.val_op().transpose());
62 arena_D.adj() += C_adj * BTAsolveB;
71 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
73 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).
trace();
76 double C_adj = res.adj();
79 -= C_adj * AsolveB * arena_D.transpose() * AsolveB.transpose();
80 arena_B.adj() += C_adj * AsolveB * (arena_D + arena_D.transpose());
87 const auto& B_ref =
to_ref(B);
92 var res = (arena_D.val() * BTAsolveB).
trace();
95 [arena_A, BTAsolveB, AsolveB, arena_D, res]()
mutable {
96 double C_adj = res.adj();
98 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
99 * AsolveB.transpose();
100 arena_D.adj() += C_adj * BTAsolveB;
107 const auto& B_ref =
to_ref(B);
114 double C_adj = res.adj();
116 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().transpose()
117 * AsolveB.transpose();
125 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
126 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
128 var res = (arena_D.val() * BTAsolveB).
trace();
131 [BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
132 double C_adj = res.adj();
134 arena_B.adj() += C_adj * AsolveB
135 * (arena_D.val_op() + arena_D.val_op().transpose());
136 arena_D.adj() += C_adj * BTAsolveB;
144 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
146 var res = (arena_D * arena_B.val_op().transpose() * AsolveB).
trace();
149 arena_B.adj() += res.adj() * AsolveB * (arena_D + arena_D.transpose());
155 const auto& B_ref =
to_ref(B);
160 var res = (arena_D.val() * BTAsolveB).
trace();
163 arena_D.adj() += res.adj() * BTAsolveB;
187template <
typename Td,
typename Ta,
typename Tb,
196 if (D.size() == 0 || A.matrix().size() == 0) {
205 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
206 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
208 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
211 [arena_A, BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
212 double C_adj = res.adj();
214 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
215 * AsolveB.transpose();
216 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
217 arena_D.adj() += C_adj * BTAsolveB.diagonal();
226 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
228 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
232 double C_adj = res.adj();
235 -= C_adj * AsolveB * arena_D.asDiagonal() * AsolveB.transpose();
236 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.asDiagonal();
243 const auto& B_ref =
to_ref(B);
248 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
251 [arena_A, BTAsolveB, AsolveB, arena_D, res]()
mutable {
252 double C_adj = res.adj();
254 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
255 * AsolveB.transpose();
256 arena_D.adj() += C_adj * BTAsolveB.diagonal();
263 const auto& B_ref =
to_ref(B);
267 var res = (arena_D.asDiagonal() *
value_of(B_ref).transpose() * AsolveB)
271 double C_adj = res.adj();
273 arena_A.adj() -= C_adj * AsolveB * arena_D.val_op().asDiagonal()
274 * AsolveB.transpose();
282 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
283 auto BTAsolveB =
to_arena(arena_B.val_op().transpose() * AsolveB);
285 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
288 [BTAsolveB, AsolveB, arena_B, arena_D, res]()
mutable {
289 double C_adj = res.adj();
291 arena_B.adj() += C_adj * AsolveB * 2 * arena_D.val_op().asDiagonal();
292 arena_D.adj() += C_adj * BTAsolveB.diagonal();
300 auto AsolveB =
to_arena(A.ldlt().solve(arena_B.val()));
302 var res = (arena_D.asDiagonal() * arena_B.val_op().transpose() * AsolveB)
306 arena_B.adj() += res.adj() * AsolveB * 2 * arena_D.asDiagonal();
312 const auto& B_ref =
to_ref(B);
317 var res = (arena_D.val().asDiagonal() * BTAsolveB).
trace();
320 arena_D.adj() += res.adj() * BTAsolveB.diagonal();
LDLT_factor is a structure that holds a matrix of type T and the LDLT of its values.
require_t< is_col_vector< std::decay_t< T > > > require_col_vector_t
Require type satisfies is_col_vector.
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either keeps any dynamic allocation on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
return_type_t< EigMat1, T2, EigMat3 > trace_gen_inv_quad_form_ldlt(const EigMat1 &D, LDLT_factor< T2 > &A, const EigMat3 &B)
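Compute the trace of an inverse quadratic form premultiplied by a square matrix, i.e. trace(D * B^T * A^-1 * B), where A is supplied through its LDLT factorization.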
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation from C or the Boost lgamma implementation.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).