Automatic Differentiation
 
Loading...
Searching...
No Matches
quad_form.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_QUAD_FORM_HPP
2#define STAN_MATH_REV_FUN_QUAD_FORM_HPP
3
12#include <type_traits>
13
14namespace stan {
15namespace math {
16
17namespace internal {
18template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
20 private:
/**
 * Computes the double-valued product Cd = B^T * A * B and stores each entry
 * in C_ as a freshly created var.
 *
 * When sym_ is set, Cd is first replaced by its symmetric part
 * 0.5 * (Cd + Cd^T).
 *
 * @param A double values of the A operand
 * @param B double values of the B operand
 */
21 inline void compute(const Eigen::Matrix<double, Ra, Ca>& A,
22 const Eigen::Matrix<double, Rb, Cb>& B) {
23 matrix_d Cd = B.transpose() * A * B;
24 if (sym_) {
// Build the symmetric part in a separate temporary M so that Cd is not
// overwritten while Cd.transpose() is still being read (Eigen aliasing).
25 matrix_d M = 0.5 * (Cd + Cd.transpose());
26 Cd = M;
27 }
// Fill C_ column by column, one new vari per entry.
// NOTE(review): the `false` flag presumably keeps these varis off the
// chain() stack, because quad_form_vari::chain() propagates their adjoints
// itself -- confirm against vari's constructor.
28 for (int j = 0; j < C_.cols(); j++) {
29 for (int i = 0; i < C_.rows(); i++) {
30 C_(i, j) = var(new vari(Cd(i, j), false));
31 }
32 }
33 }
34
35 public:
36 quad_form_vari_alloc(const Eigen::Matrix<Ta, Ra, Ca>& A,
37 const Eigen::Matrix<Tb, Rb, Cb>& B,
38 bool symmetric = false)
39 : A_(A), B_(B), C_(B_.cols(), B_.cols()), sym_(symmetric) {
41 }
42
43 Eigen::Matrix<Ta, Ra, Ca> A_;
44 Eigen::Matrix<Tb, Rb, Cb> B_;
45 Eigen::Matrix<var, Cb, Cb> C_;
46 bool sym_;
47};
48
49template <typename Ta, int Ra, int Ca, typename Tb, int Rb, int Cb>
50class quad_form_vari : public vari {
51 protected:
/**
 * No-op reverse-pass update for A when A holds plain doubles.
 *
 * chainAB selects this overload when Ta is double, and a double operand has
 * no adjoint to accumulate.
 */
52 inline void chainA(Eigen::Matrix<double, Ra, Ca>& A,
53 const Eigen::Matrix<double, Rb, Cb>& Bd,
54 const Eigen::Matrix<double, Cb, Cb>& adjC) {}
/**
 * No-op reverse-pass update for B when B holds plain doubles.
 *
 * chainAB selects this overload when Tb is double, and a double operand has
 * no adjoint to accumulate.
 */
55 inline void chainB(Eigen::Matrix<double, Rb, Cb>& B,
56 const Eigen::Matrix<double, Ra, Ca>& Ad,
57 const Eigen::Matrix<double, Rb, Cb>& Bd,
58 const Eigen::Matrix<double, Cb, Cb>& adjC) {}
59
/**
 * Reverse-pass update for A when A holds vars.
 *
 * For C = B^T * A * B, the adjoint of A is B * adjC * B^T. That quantity is
 * accumulated into A's existing adjoints.
 *
 * NOTE(review): when sym_ is set, the forward value is 0.5 * (C + C^T), but
 * adjC is used here unsymmetrized. The two agree only when adjC is
 * symmetric -- confirm this is intended.
 *
 * @param A var operand whose adjoints are updated
 * @param Bd double values of B
 * @param adjC adjoints of the result matrix C
 */
60 inline void chainA(Eigen::Matrix<var, Ra, Ca>& A,
61 const Eigen::Matrix<double, Rb, Cb>& Bd,
62 const Eigen::Matrix<double, Cb, Cb>& adjC) {
63 A.adj() += Bd * adjC * Bd.transpose();
64 }
/**
 * Reverse-pass update for B when B holds vars.
 *
 * For C = B^T * A * B, B appears twice. The adjoint is therefore the sum of
 * two terms:
 *   A * B * adjC^T   (from the right-hand B)
 *   A^T * B * adjC   (from the transposed left-hand B)
 *
 * @param B var operand whose adjoints are updated
 * @param Ad double values of A
 * @param Bd double values of B
 * @param adjC adjoints of the result matrix C
 */
65 inline void chainB(Eigen::Matrix<var, Rb, Cb>& B,
66 const Eigen::Matrix<double, Ra, Ca>& Ad,
67 const Eigen::Matrix<double, Rb, Cb>& Bd,
68 const Eigen::Matrix<double, Cb, Cb>& adjC) {
69 B.adj() += Ad * Bd * adjC.transpose() + Ad.transpose() * Bd * adjC;
70 }
71
/**
 * Propagates the result adjoints adjC to both operands.
 *
 * Overload resolution on the scalar types Ta / Tb selects either the var
 * overloads, which accumulate adjoints, or the double no-ops. Each operand
 * is therefore updated only if it actually holds vars.
 */
72 inline void chainAB(Eigen::Matrix<Ta, Ra, Ca>& A,
73 Eigen::Matrix<Tb, Rb, Cb>& B,
74 const Eigen::Matrix<double, Ra, Ca>& Ad,
75 const Eigen::Matrix<double, Rb, Cb>& Bd,
76 const Eigen::Matrix<double, Cb, Cb>& adjC) {
77 chainA(A, Bd, adjC);
78 chainB(B, Ad, Bd, adjC);
79 }
80
81 public:
82 quad_form_vari(const Eigen::Matrix<Ta, Ra, Ca>& A,
83 const Eigen::Matrix<Tb, Rb, Cb>& B, bool symmetric = false)
84 : vari(0.0) {
86 }
87
/**
 * Reverse-pass entry point for the quadratic form.
 *
 * Reads the adjoints of the result matrix impl_->C_ and pushes them back to
 * the stored operands impl_->A_ and impl_->B_ via chainAB, which uses the
 * operands' double values.
 *
 * NOTE(review): adjC is copied from C_.adj() without symmetrization, even
 * when the forward value was symmetrized in compute() -- confirm intended.
 */
88 virtual void chain() {
89 matrix_d adjC = impl_->C_.adj();
90
91 chainAB(impl_->A_, impl_->B_, value_of(impl_->A_), value_of(impl_->B_),
92 adjC);
93 }
94
96};
97
114template <typename Mat1, typename Mat2,
117inline auto quad_form_impl(const Mat1& A, const Mat2& B, bool symmetric) {
118 check_square("quad_form", "A", A);
119 check_multiplicable("quad_form", "A", A, "B", B);
120
121 using return_t
122 = return_var_matrix_t<decltype(value_of(B).transpose().eval()
123 * value_of(A) * value_of(B).eval()),
124 Mat1, Mat2>;
125
129
130 check_not_nan("multiply", "A", value_of(arena_A));
131 check_not_nan("multiply", "B", value_of(arena_B));
132
133 auto arena_res = to_arena(value_of(arena_B).transpose() * value_of(arena_A)
134 * value_of(arena_B));
135
136 if (symmetric) {
137 arena_res += arena_res.transpose().eval();
138 }
139
140 return_t res = arena_res;
141
142 reverse_pass_callback([arena_A, arena_B, res]() mutable {
143 auto C_adj_B_t = (res.adj() * value_of(arena_B).transpose()).eval();
144
146 arena_A.adj().noalias() += value_of(arena_B) * C_adj_B_t;
147 } else {
148 arena_A.adj() += value_of(arena_B) * C_adj_B_t;
149 }
150
152 arena_B.adj().noalias()
153 += value_of(arena_A) * C_adj_B_t.transpose()
154 + value_of(arena_A).transpose() * value_of(arena_B) * res.adj();
155 } else {
156 arena_B.adj()
157 += value_of(arena_A) * C_adj_B_t.transpose()
158 + value_of(arena_A).transpose() * value_of(arena_B) * res.adj();
159 }
160 });
161
162 return res;
163 } else if (!is_constant<Mat2>::value) {
166
167 check_not_nan("multiply", "A", arena_A);
168 check_not_nan("multiply", "B", arena_B.val());
169
170 auto arena_res
171 = to_arena(value_of(arena_B).transpose() * arena_A * value_of(arena_B));
172
173 if (symmetric) {
174 arena_res += arena_res.transpose().eval();
175 }
176
177 return_t res = arena_res;
178
179 reverse_pass_callback([arena_A, arena_B, res]() mutable {
180 auto C_adj_B_t = (res.adj() * value_of(arena_B).transpose());
181
183 arena_B.adj().noalias()
184 += arena_A * C_adj_B_t.transpose()
185 + arena_A.transpose() * value_of(arena_B) * res.adj();
186 } else {
187 arena_B.adj() += arena_A * C_adj_B_t.transpose()
188 + arena_A.transpose() * value_of(arena_B) * res.adj();
189 }
190 });
191
192 return res;
193 } else {
196
197 check_not_nan("multiply", "A", value_of(arena_A));
198 check_not_nan("multiply", "B", arena_B);
199
200 auto arena_res
201 = to_arena(arena_B.transpose() * value_of(arena_A) * arena_B);
202
203 if (symmetric) {
204 arena_res += arena_res.transpose().eval();
205 }
206
207 return_t res = arena_res;
208
209 reverse_pass_callback([arena_A, arena_B, res]() mutable {
210 auto C_adj_B_t = (res.adj() * arena_B.transpose());
211
213 arena_A.adj().noalias() += arena_B * C_adj_B_t;
214 } else {
215 arena_A.adj() += arena_B * C_adj_B_t;
216 }
217 });
218
219 return res;
220 }
221}
222} // namespace internal
223
240template <typename EigMat1, typename EigMat2,
245 const EigMat2& B,
246 bool symmetric = false) {
247 check_square("quad_form", "A", A);
248 check_multiplicable("quad_form", "A", A, "B", B);
249
250 auto* baseVari = new internal::quad_form_vari<
251 value_type_t<EigMat1>, EigMat1::RowsAtCompileTime,
252 EigMat1::ColsAtCompileTime, value_type_t<EigMat2>,
253 EigMat2::RowsAtCompileTime, EigMat2::ColsAtCompileTime>(A, B, symmetric);
254
255 return baseVari->impl_->C_;
256}
257
/**
 * Return the quadratic form B^T * A * B for an Eigen column vector B, when
 * at least one of A or B has var scalars. The result is a scalar var.
 *
 * A quad_form_vari is created for the product. Its 1x1 result matrix is
 * instantiated with Cb = 1, so the single entry C_(0, 0) is returned.
 *
 * @tparam EigMat Eigen matrix type of A
 * @tparam ColVec Eigen column-vector type of B
 * @param A square matrix
 * @param B column vector with as many rows as A has columns
 * @param symmetric forwarded to quad_form_vari. Symmetrizing a 1x1 result
 *   as 0.5 * (C + C^T) leaves its value unchanged.
 * @return the scalar B^T * A * B
 * @throws if A is not square, or if A and B are not multiplicable (raised by
 *   check_square / check_multiplicable)
 */
271template <typename EigMat, typename ColVec, require_eigen_t<EigMat>* = nullptr,
272 require_eigen_col_vector_t<ColVec>* = nullptr,
273 require_any_vt_var<EigMat, ColVec>* = nullptr>
274inline var quad_form(const EigMat& A, const ColVec& B, bool symmetric = false) {
275 check_square("quad_form", "A", A);
276 check_multiplicable("quad_form", "A", A, "B", B);
277
// NOTE(review): raw `new` with no matching delete -- presumably varis are
// owned by the autodiff arena; confirm.
278 auto* baseVari = new internal::quad_form_vari<
279 value_type_t<EigMat>, EigMat::RowsAtCompileTime,
280 EigMat::ColsAtCompileTime, value_type_t<ColVec>,
281 ColVec::RowsAtCompileTime, 1>(A, B, symmetric);
282
283 return baseVari->impl_->C_(0, 0);
284}
285
307template <typename Mat1, typename Mat2,
311inline auto quad_form(const Mat1& A, const Mat2& B, bool symmetric = false) {
312 return internal::quad_form_impl(A, B, symmetric);
313}
314
/**
 * Return the quadratic form B^T * A * B for a column vector B, when at least
 * one operand is a var_value matrix type. The result is a scalar var.
 *
 * The product is 1x1, so it is already symmetric, and `symmetric` cannot
 * change its mathematical value. The flag is therefore deliberately not
 * forwarded, for two reasons:
 *  - quad_form_impl symmetrizes with `res += res^T` and no 0.5 factor. For a
 *    1x1 result that doubles the value.
 *  - Its reverse-pass callback still propagates the gradient of the
 *    undoubled form, so the value and the gradient would disagree.
 * Passing `false` keeps value and gradient consistent. It also matches the
 * Eigen<var> overload, whose vari symmetrizes as 0.5 * (C + C^T).
 *
 * @tparam Mat type of A
 * @tparam Vec column-vector type of B
 * @param A square matrix
 * @param B column vector with as many rows as A has columns
 * @param symmetric accepted for interface compatibility; it has no effect on
 *   a 1x1 result
 * @return the scalar B^T * A * B
 * @throws if A is not square, or if A and B are not multiplicable (raised by
 *   check_square / check_multiplicable inside quad_form_impl)
 */
333template <typename Mat, typename Vec, require_matrix_t<Mat>* = nullptr,
334 require_col_vector_t<Vec>* = nullptr,
335 require_any_var_matrix_t<Mat, Vec>* = nullptr>
336inline var quad_form(const Mat& A, const Vec& B, bool symmetric = false) {
static_cast<void>(symmetric);  // 1x1 result is already symmetric; see above
337 return internal::quad_form_impl(A, B, false)(0, 0);
338}
339
340} // namespace math
341} // namespace stan
342#endif
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan is managed along with the arena allocator for the gradient calculation.
Eigen::Matrix< var, Cb, Cb > C_
Definition quad_form.hpp:45
quad_form_vari_alloc(const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B, bool symmetric=false)
Definition quad_form.hpp:36
void compute(const Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &B)
Definition quad_form.hpp:21
quad_form_vari(const Eigen::Matrix< Ta, Ra, Ca > &A, const Eigen::Matrix< Tb, Rb, Cb > &B, bool symmetric=false)
Definition quad_form.hpp:82
void chainA(Eigen::Matrix< double, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition quad_form.hpp:52
void chainB(Eigen::Matrix< var, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition quad_form.hpp:65
void chainB(Eigen::Matrix< double, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition quad_form.hpp:55
void chainAB(Eigen::Matrix< Ta, Ra, Ca > &A, Eigen::Matrix< Tb, Rb, Cb > &B, const Eigen::Matrix< double, Ra, Ca > &Ad, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition quad_form.hpp:72
quad_form_vari_alloc< Ta, Ra, Ca, Tb, Rb, Cb > * impl_
Definition quad_form.hpp:95
void chainA(Eigen::Matrix< var, Ra, Ca > &A, const Eigen::Matrix< double, Rb, Cb > &Bd, const Eigen::Matrix< double, Cb, Cb > &adjC)
Definition quad_form.hpp:60
require_not_t< is_col_vector< std::decay_t< T > > > require_not_col_vector_t
Require type does not satisfy is_col_vector.
require_not_t< is_eigen_col_vector< std::decay_t< T > > > require_not_eigen_col_vector_t
Require type does not satisfy is_eigen_col_vector.
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
Definition is_eigen.hpp:65
require_all_t< is_matrix< std::decay_t< Types > >... > require_all_matrix_t
Require all of the types satisfy is_matrix.
Definition is_matrix.hpp:38
auto transpose(Arg &&a)
Transposes a kernel generator expression.
int cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:20
typename value_type< T >::type value_type_t
Helper function for accessing underlying type.
require_any_t< is_var_matrix< std::decay_t< Types > >... > require_any_var_matrix_t
Require any of the types satisfy is_var_matrix.
require_any_t< is_var< value_type_t< std::decay_t< Types > > >... > require_any_vt_var
Require any of the value_types satisfy is_var.
Definition is_var.hpp:99
auto quad_form_impl(const Mat1 &A, const Mat2 &B, bool symmetric)
Return the quadratic form $B^T A B$.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
vari_value< double > vari
Definition vari.hpp:197
promote_scalar_t< return_type_t< EigMat1, EigMat2 >, EigMat2 > quad_form(const EigMat1 &A, const EigMat2 &B)
Return the quadratic form $B^T A B$.
Definition quad_form.hpp:31
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
var_value< double > var
Definition var.hpp:1187
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > matrix_d
Type for matrix of double values.
Definition typedefs.hpp:19
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
std::conditional_t< is_any_var_matrix< ReturnType, Types... >::value, stan::math::var_value< stan::math::promote_scalar_t< double, plain_type_t< ReturnType > > >, stan::math::promote_scalar_t< stan::math::var_value< double >, plain_type_t< ReturnType > > > return_var_matrix_t
Given an Eigen type and several inputs, determine if a matrix should be var<Matrix> or Matrix<var>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ `const` sense).
Check if a type is a var_value whose value_type is derived from Eigen::EigenBase