Automatic Differentiation
 
Loading...
Searching...
No Matches
trace_quad_form.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
2#define STAN_MATH_REV_FUN_TRACE_QUAD_FORM_HPP
3
13#include <type_traits>
14
15namespace stan {
16namespace math {
17namespace internal {
// Holder for the two operands of trace_quad_form. A and B are kept as member
// copies (A_, B_) so that a vari holding a pointer to this object can read
// them again later.
// NOTE(review): the class-head line (listing line 19) is missing from this
// extract, so the base class is not visible here - confirm against the header.
18template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
20 public:
// Copies A and B into the members A_ and B_.
21 trace_quad_form_vari_alloc(const Eigen::Matrix<Ta, Ra, Ca>& A,
22 const Eigen::Matrix<Tb, Rb, Cb>& B)
23 : A_(A), B_(B) {}
24
// Evaluates trace_quad_form on value_of(A_) and value_of(B_) and returns the
// result as a double.
25 double compute() { return trace_quad_form(value_of(A_), value_of(B_)); }
26
// Stored operands, in their original scalar types Ta and Tb.
27 Eigen::Matrix<Ta, Ra, Ca> A_;
28 Eigen::Matrix<Tb, Rb, Cb> B_;
29};
30
// vari whose value is impl->compute() and whose chain() adds adjoint
// contributions to the operands stored in *impl_.
// NOTE(review): the constructor declaration (listing lines 63-64) and the
// impl_ member declaration (listing line 72) are missing from this extract.
31template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
32class trace_quad_form_vari : public vari {
33 protected:
// Overloads taking a matrix of double: empty bodies, nothing is accumulated.
34 static inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
35 const Eigen::Matrix<double, Rb, Cb>& Bd,
36 double adjC) {}
37 static inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
38 const Eigen::Matrix<double, Ra, Ca>& Ad,
39 const Eigen::Matrix<double, Rb, Cb>& Bd,
40 double adjC) {}
41
// Overload taking a matrix of var: adds adjC * Bd * Bd^T to the adjoints of A.
42 static inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
43 const Eigen::Matrix<double, Rb, Cb>& Bd,
44 double adjC) {
45 A.adj() += adjC * Bd * Bd.transpose();
46 }
// Overload taking a matrix of var: adds adjC * (Ad + Ad^T) * Bd to the
// adjoints of B.
47 static inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
48 const Eigen::Matrix<double, Ra, Ca>& Ad,
49 const Eigen::Matrix<double, Rb, Cb>& Bd,
50 double adjC) {
51 B.adj() += adjC * (Ad + Ad.transpose()) * Bd;
52 }
53
// Calls chainA and then chainB; overload resolution on the scalar types Ta and
// Tb picks the empty double overload or the accumulating var overload.
54 inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
55 Eigen::Matrix<Tb, Rb, Cb>& B,
56 const Eigen::Matrix<double, Ra, Ca>& Ad,
57 const Eigen::Matrix<double, Rb, Cb>& Bd, double adjC) {
58 chainA(A, Bd, adjC);
59 chainB(B, Ad, Bd, adjC);
60 }
61
62 public:
// Constructor initialiser list: the vari value is impl->compute(), and the
// pointer is stored in impl_.
65 : vari(impl->compute()), impl_(impl) {}
66
// Passes the stored operands, their values, and this vari's adjoint adj_ to
// chainAB.
67 virtual void chain() {
68 chainAB(impl_->A_, impl_->B_, value_of(impl_->A_), value_of(impl_->B_),
69 adj_);
70 }
71
73};
74} // namespace internal
75
// trace_quad_form overload for Eigen matrix arguments A and B.
// NOTE(review): the template constraints and the start of the signature
// (listing lines 77-79), the initialiser of baseVari (line 91) and the
// returned expression (line 94) are missing from this extract.
76template <typename EigMat1, typename EigMat2,
80 const EigMat2& B) {
// Scalar types and compile-time dimensions of the two argument types.
81 using Ta = value_type_t<EigMat1>;
82 using Tb = value_type_t<EigMat2>;
83 constexpr int Ra = EigMat1::RowsAtCompileTime;
84 constexpr int Ca = EigMat1::ColsAtCompileTime;
85 constexpr int Rb = EigMat2::RowsAtCompileTime;
86 constexpr int Cb = EigMat2::ColsAtCompileTime;
// Argument validation: A must be square, and A and B must be multiplicable.
87 check_square("trace_quad_form", "A", A);
88 check_multiplicable("trace_quad_form", "A", A, "B", B);
89
// NOTE(review): presumably baseVari points to a trace_quad_form_vari_alloc
// built from A and B, and the returned var wraps a trace_quad_form_vari -
// the defining lines are not visible here, confirm against the header.
90 auto* baseVari
92
93 return var(
95}
96
// trace_quad_form overload returning a single var. It validates the arguments,
// computes trace(B^T * A * B), and registers a reverse-pass callback that
// accumulates adjoints into whichever operands are updated in each branch.
// NOTE(review): this extract is missing the template constraints (listing
// lines 115-116), the first branch condition and the arena_A / arena_B
// declarations (lines 123-125, 151-152, 168-169), and the inner if conditions
// (lines 132, 140, 159, 174). The comments below describe only visible code.
114template <typename Mat1, typename Mat2,
117inline var trace_quad_form(const Mat1& A, const Mat2& B) {
// Argument validation: A must be square, and A and B must be multiplicable.
118 check_square("trace_quad_form", "A", A);
119 check_multiplicable("trace_quad_form", "A", A, "B", B);
120
121 var res;
122
126
// First branch: adjoints of both arena_A and arena_B are updated.
// Forward value: trace(B^T * A * B), computed on value_of(...) of both.
127 res = (value_of(arena_B).transpose() * value_of(arena_A)
128 * value_of(arena_B))
129 .trace();
130
131 reverse_pass_callback([arena_A, arena_B, res]() mutable {
// arena_A.adj() += res.adj() * B * B^T. The two arms differ only in the use
// of .noalias(); the condition choosing between them is not visible here.
133 arena_A.adj().noalias()
134 += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
135 } else {
136 arena_A.adj()
137 += res.adj() * value_of(arena_B) * value_of(arena_B).transpose();
138 }
139
// arena_B.adj() += res.adj() * (A + A^T) * B, again with and without
// .noalias().
141 arena_B.adj().noalias()
142 += res.adj() * (value_of(arena_A) + value_of(arena_A).transpose())
143 * value_of(arena_B);
144 } else {
145 arena_B.adj() += res.adj()
146 * (value_of(arena_A) + value_of(arena_A).transpose())
147 * value_of(arena_B);
148 }
149 });
// Second branch: taken when Mat2 is not constant; only arena_B.adj() is
// updated in the callback.
150 } else if (!is_constant<Mat2>::value) {
153
154 res = (value_of(arena_B).transpose() * value_of(arena_A)
155 * value_of(arena_B))
156 .trace();
157
158 reverse_pass_callback([arena_A, arena_B, res]() mutable {
// arena_B.adj() += res.adj() * (A + A^T) * B. arena_A is used directly here,
// without value_of.
160 arena_B.adj().noalias()
161 += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
162 } else {
163 arena_B.adj()
164 += res.adj() * (arena_A + arena_A.transpose()) * value_of(arena_B);
165 }
166 });
// Remaining branch: only arena_A.adj() is updated in the callback. arena_B is
// used directly, without value_of.
167 } else {
170
171 res = (arena_B.transpose() * value_of(arena_A) * arena_B).trace();
172
173 reverse_pass_callback([arena_A, arena_B, res]() mutable {
// arena_A.adj() += res.adj() * B * B^T, with and without .noalias().
175 arena_A.adj().noalias() += res.adj() * arena_B * arena_B.transpose();
176 } else {
177 arena_A.adj() += res.adj() * arena_B * arena_B.transpose();
178 }
179 });
180 }
181
182 return res;
183}
184
185} // namespace math
186} // namespace stan
187#endif
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
trace_quad_form_vari_alloc(const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B)
static void chainB(Eigen::Matrix< double, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
trace_quad_form_vari(trace_quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > *impl)
void chainAB(Eigen::Matrix< Ta, Ra, Ca > &A, Eigen::Matrix< Tb, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
static void chainA(Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
trace_quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > * impl_
static void chainA(Eigen::Matrix< var, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
static void chainB(Eigen::Matrix< var, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, double adjC)
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
Definition is_eigen.hpp:120
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
auto transpose(Arg &&a)
Transposes a kernel generator expression.
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_any_t< is_var< scalar_type_t< std::decay_t< Types > > >... > require_any_st_var
Require any of the scalar types satisfy is_var.
Definition is_var.hpp:131
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
value_type_t< T > trace(const T &m)
Calculates trace (sum of diagonal) of given kernel generator expression.
Definition trace.hpp:22
var_value< double > var
Definition var.hpp:1187
return_type_t< EigMat1, EigMat2 > trace_quad_form(const EigMat1 &A, const EigMat2 &B)
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation from C or the boost::math::lgamma implementation.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).
Check if a type is a var_value whose value_type is derived from Eigen::EigenBase