Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_inv_quad_form_ldlt.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2#define STAN_MATH_REV_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
3
12#include <type_traits>
13
14namespace stan {
15namespace math {
16
// Reverse-mode trace_inv_quad_form_ldlt: returns trace(B^T * A^{-1} * B),
// where A is an LDLT-factored matrix and A.ldlt().solve(...) supplies
// A^{-1} * B.
//
// The template constraints (require_all_matrix_t, require_any_st_var) enable
// this overload only for matrix operands where at least one of T1 / T2 has a
// `var` scalar type.
//
// Gradients registered in the reverse pass, with g = res.adj():
//   dA += -g * (A^{-1} B) (A^{-1} B)^T
//   dB +=  2 * g * A^{-1} B
// Both expressions rely on A being symmetric (A^{-T} == A^{-1}), which an
// LDLT factorization presupposes.
//
// NOTE(review): this rendered listing skips source lines 17-29, 32, 38, 40
// and 65 (the fused line numbers jump). They presumably hold the doc
// comment, the function signature, the first branch condition and the
// construction of `arena_B`. Confirm against the real header before editing.
30template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr,
31 require_any_st_var<T1, T2>* = nullptr>
// NOTE(review): the error label "trace_quad_form" does not match this
// function's name (trace_inv_quad_form_ldlt), so dimension-mismatch errors
// will report the wrong function. Consider correcting the string.
33 check_multiplicable("trace_quad_form", "A", A.matrix(), "B", B);
34
// An empty A short-circuits to 0 without registering any callback.
35 if (A.matrix().size() == 0)
36 return 0.0;
37
// Branch where both A and B receive adjoints. Its condition (source line 38)
// is not shown; presumably it is !is_constant<T1> && !is_constant<T2>.
// A is copied into arena memory so its adjoint can be updated in the
// reverse pass.
39 arena_t<promote_scalar_t<var, T1>> arena_A = A.matrix();
// A^{-1} B is solved once on values. It is stored via to_arena so that the
// reverse-pass callback can reuse it after this function returns.
41 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
42
// Forward value: trace(B^T * (A^{-1} B)).
43 var res = (arena_B.val_op().transpose() * AsolveB).trace();
44
// Reverse pass: push res's adjoint into both A and B.
45 reverse_pass_callback([arena_A, AsolveB, arena_B, res]() mutable {
46 arena_A.adj() += -res.adj() * AsolveB * AsolveB.transpose();
47 arena_B.adj() += 2 * res.adj() * AsolveB;
48 });
49
50 return res;
// Only A is var here; B is used purely as data (only its values are read).
51 } else if (!is_constant<T1>::value) {
52 arena_t<promote_scalar_t<var, T1>> arena_A = A.matrix();
// to_ref evaluates B once, because its values are used twice below.
53 const auto& B_ref = to_ref(B);
54
55 auto AsolveB = to_arena(A.ldlt().solve(value_of(B_ref)));
56
57 var res = (value_of(B_ref).transpose() * AsolveB).trace();
58
// Reverse pass: only A's adjoint is updated.
59 reverse_pass_callback([arena_A, AsolveB, res]() mutable {
60 arena_A.adj() += -res.adj() * AsolveB * AsolveB.transpose();
61 });
62
63 return res;
// Only B is var here. T1 is constant in this branch, and
// require_any_st_var guarantees at least one var operand. `arena_B` is
// defined on source line 65, which is not shown in this listing.
64 } else {
66 auto AsolveB = to_arena(A.ldlt().solve(arena_B.val()));
67
68 var res = (arena_B.val_op().transpose() * AsolveB).trace();
69
// Reverse pass: only B's adjoint is updated.
70 reverse_pass_callback([AsolveB, arena_B, res]() mutable {
71 arena_B.adj() += 2 * res.adj() * AsolveB;
72 });
73
74 return res;
75 }
76}
77
78} // namespace math
79} // namespace stan
80#endif
LDLT_factor is a structure that holds a matrix of type T and the LDLT of its values.
return_type_t< T, EigMat2 > trace_inv_quad_form_ldlt(LDLT_factor< T > &A, const EigMat2 &B)
Compute the trace of an inverse quadratic form.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
Definition trace.hpp:22
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).