Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_quad_form.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
2#define STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
3
14#include <type_traits>
15
16namespace stan {
17namespace math {
18namespace internal {
19template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
21 public:
22 trace_quad_form_vari_alloc(const Eigen::Matrix<Ta, Ra, Ca>& A,
23 const Eigen::Matrix<Tb, Rb, Cb>& B)
24 : A_(A), B_(B) {}
25
26 double compute() { return trace_quad_form(value_of(A_), value_of(B_)); }
27
28 Eigen::Matrix<Ta, Ra, Ca> A_;
29 Eigen::Matrix<Tb, Rb, Cb> B_;
30};
31
32template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
33class trace_quad_form_vari : public vari {
34 protected:
35 static inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
36 const Eigen::Matrix<double, Rb, Cb>& Bd,
37 double adjC) {}
38 static inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
39 const Eigen::Matrix<double, Ra, Ca>& Ad,
40 const Eigen::Matrix<double, Rb, Cb>& Bd,
41 double adjC) {}
42
43 static inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
44 const Eigen::Matrix<double, Rb, Cb>& Bd,
45 double adjC) {
46 A.adj() += adjC * Bd * Bd.transpose();
47 }
48 static inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
49 const Eigen::Matrix<double, Ra, Ca>& Ad,
50 const Eigen::Matrix<double, Rb, Cb>& Bd,
51 double adjC) {
52 B.adj() += adjC * (Ad + Ad.transpose()) * Bd;
53 }
54
55 inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
56 Eigen::Matrix<Tb, Rb, Cb>& B,
57 const Eigen::Matrix<double, Ra, Ca>& Ad,
58 const Eigen::Matrix<double, Rb, Cb>& Bd, double adjC) {
59 chainA(A, Bd, adjC);
60 chainB(B, Ad, Bd, adjC);
61 }
62
63 public:
66 : vari(impl->compute()), impl_(impl) {}
67
68 virtual void chain() {
69 chainAB(impl_->A_, impl_->B_, value_of(impl_->A_), value_of(impl_->B_),
70 adj_);
71 }
72
74};
75} // namespace internal
76
77template <typename EigMat1, typename EigMat2,
81 const EigMat2& B) {
82 using Ta = value_type_t<EigMat1>;
83 using Tb = value_type_t<EigMat2>;
84 constexpr int Ra = EigMat1::RowsAtCompileTime;
85 constexpr int Ca = EigMat1::ColsAtCompileTime;
86 constexpr int Rb = EigMat2::RowsAtCompileTime;
87 constexpr int Cb = EigMat2::ColsAtCompileTime;
88 check_square("trace_quad_form", "A", A);
89 check_multiplicable("trace_quad_form", "A", A, "B", B);
90
91 auto* baseVari
93
94 return var(
96}
97
115template <typename Mat1, typename Mat2,
118inline var trace_quad_form(const Mat1& A, const Mat2& B) {
119 check_square("trace_quad_form", "A", A);
120 check_multiplicable("trace_quad_form", "A", A, "B", B);
121
122 var res;
123
127
128 res = (value_of(arena_B).transpose() * value_of(arena_A)
129 * value_of(arena_B))
130 .trace();
131
132 reverse_pass_callback([arena_A, arena_B, res]() mutable {
134 arena_A.adj().noalias()
135 += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
136 } else {
137 arena_A.adj()
138 += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
139 }
140
142 arena_B.adj().noalias()
143 += res.adj() * (value_of(arena_A) + value_of(arena_A).transpose())
144 * value_of(arena_B);
145 } else {
146 arena_B.adj() += res.adj()
147 * (value_of(arena_A) + value_of(arena_A).transpose())
148 * value_of(arena_B);
149 }
150 });
151 } else if (!is_constant<Mat2>::value) {
154
155 res = (value_of(arena_B).transpose() * value_of(arena_A)
156 * value_of(arena_B))
157 .trace();
158
159 reverse_pass_callback([arena_A, arena_B, res]() mutable {
161 arena_B.adj().noalias()
162 += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
163 } else {
164 arena_B.adj()
165 += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
166 }
167 });
168 } else {
171
172 res = (arena_B.transpose() * value_of(arena_A) * arena_B).trace();
173
174 reverse_pass_callback([arena_A, arena_B, res]() mutable {
176 arena_A.adj().noalias() += res.adj() * arena_B * arena_B.transpose();
177 } else {
178 arena_A.adj() += res.adj() * arena_B * arena_B.transpose();
179 }
180 });
181 }
182
183 return res;
184}
185
186} // namespace math
187} // namespace stan
188#endif
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
trace_quad_form_vari_alloc(const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B)
static void chainB(Eigen::Matrix< double, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
trace_quad_form_vari(trace_quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > *impl)
void chainAB(Eigen::Matrix< Ta, Ra, Ca > &A, Eigen::Matrix< Tb, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
static void chainA(Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
trace_quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > * impl_
static void chainA(Eigen::Matrix< var, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
static void chainB(Eigen::Matrix< var, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
Definition is_eigen.hpp:65
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
auto transpose(Arg &&a)
Transposes a kernel generator expression.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
Definition is_var.hpp:131
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
Definition trace.hpp:22
var_value< double > var
Definition var.hpp:1187
return_type_t< EigMat1, EigMat2 > trace_quad_form(const EigMat1 &A, const EigMat2 &B)
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).
Check if a type is a var_value whose value_type is derived from Eigen::EigenBase