Automatic Differentiation
 
Loading...
Searching...
No Matches
mdivide_left_spd.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_MDIVIDE_LEFT_SPD_HPP
2#define STAN_MATH_REV_FUN_MDIVIDE_LEFT_SPD_HPP
3
11#include <vector>
12
13namespace stan {
14namespace math {
15namespace internal {
16
// Container for the double-valued intermediates shared by the forward pass
// (constructor) and the reverse pass (chain()) of the mdivide_left_spd_*_vari
// classes below: the LLT (Cholesky) factorization of the values of A and the
// values of the solution C = A^{-1} B.
//
// R1, C1: compile-time rows/cols of A.  R2, C2: compile-time rows/cols of B
// (and therefore of C).
//
// NOTE(review): the class-head line and the line following `public:` are not
// visible in this listing. The vari classes below create instances with plain
// `new` and no matching `delete` appears in the visible code, so this
// presumably derives from chainable_alloc (referenced in this page's
// tooltips), tying its lifetime to the autodiff stack -- confirm against the
// real header.
17template <int R1, int C1, int R2, int C2>
19 public:
21
 // Cholesky factorization of value_of(A); reused by chain() for adjoint solves.
22 Eigen::LLT<Eigen::Matrix<double, R1, C1>> llt_;
 // Values of the solution C = A^{-1} B (chain() of the var-A varis reads it).
23 Eigen::Matrix<double, R2, C2> C_;
24};
25
// Reverse-mode node for mdivide_left_spd(A, B) when both A (M x M) and
// B (M x N) hold var. The node itself is a vari whose base is initialized with
// value 0.0 and is not part of the returned result; the solution C = A^{-1} B
// is exposed through the M*N output varis stored in variRefC_.
//
// NOTE(review): the class-head line and the declarations of variRefA_,
// variRefB_, variRefC_ and alloc_ are not visible in this listing.
26template <int R1, int C1, int R2, int C2>
28 public:
29 int M_; // A.rows() = A.cols() = B.rows()
30 int N_; // B.cols()
35
 // Forward pass. The three pointer arrays are allocated from
 // ChainableStack::instance_->memalloc_; the double intermediates live in
 // alloc_. Throws via check_pos_definite if value_of(A) is not positive
 // definite. Symmetry of A is not checked here -- the mdivide_left_spd
 // wrappers in this file check it before constructing this node.
36 mdivide_left_spd_vv_vari(const Eigen::Matrix<var, R1, C1> &A,
37 const Eigen::Matrix<var, R2, C2> &B)
38 : vari(0.0),
39 M_(A.rows()),
40 N_(B.cols()),
41 variRefA_(reinterpret_cast<vari **>(
42 ChainableStack::instance_->memalloc_.alloc(sizeof(vari *) * A.rows()
43 * A.cols()))),
44 variRefB_(reinterpret_cast<vari **>(
45 ChainableStack::instance_->memalloc_.alloc(sizeof(vari *) * B.rows()
46 * B.cols()))),
47 variRefC_(reinterpret_cast<vari **>(
48 ChainableStack::instance_->memalloc_.alloc(sizeof(vari *) * B.rows()
49 * B.cols()))),
50 alloc_(new mdivide_left_spd_alloc<R1, C1, R2, C2>()) {
 // Remember the operand varis so chain() can update their adjoints.
51 Eigen::Map<matrix_vi>(variRefA_, M_, M_) = A.vi();
52 Eigen::Map<matrix_vi>(variRefB_, M_, N_) = B.vi();
 // C_ starts as value_of(B) and is overwritten in place with A^{-1} B once
 // the factorization of value_of(A) has been validated.
53 alloc_->C_ = B.val();
54 alloc_->llt_ = A.val().llt();
55 check_pos_definite("mdivide_left_spd", "A", alloc_->llt_);
56 alloc_->llt_.solveInPlace(alloc_->C_);
57
 // One new vari per entry of C carries the result values.
 // NOTE(review): the `false` argument presumably keeps these output varis
 // off the chaining stack so only this node's chain() propagates their
 // adjoints -- confirm against vari's constructor.
58 Eigen::Map<matrix_vi>(variRefC_, M_, N_)
59 = alloc_->C_.unaryExpr([](double x) { return new vari(x, false); });
60 }
61
 // Reverse pass. With adjC the adjoints of the output varis:
 //   adjB    = A^{-1} adjC       (in-place solve with the stored LLT)
 //   adj(A) -= adjB * C^T        i.e. -A^{-1} adjC C^T
 //   adj(B) += adjB
 // These match the general A^{-T} forms because A is required to be
 // symmetric by the callers in this file.
62 virtual void chain() {
63 matrix_d adjB = Eigen::Map<matrix_vi>(variRefC_, M_, N_).adj();
64 alloc_->llt_.solveInPlace(adjB);
65 Eigen::Map<matrix_vi>(variRefA_, M_, M_).adj()
66 -= adjB * alloc_->C_.transpose();
67 Eigen::Map<matrix_vi>(variRefB_, M_, N_).adj() += adjB;
68 }
69};
70
// Reverse-mode node for mdivide_left_spd(A, B) when A (M x M) is a plain
// double matrix and B (M x N) holds var. Only B receives adjoints. The node's
// own vari value (0.0) is unused; results are the output varis in variRefC_.
//
// NOTE(review): the class-head line and the declarations of variRefB_,
// variRefC_ and alloc_ are not visible in this listing.
71template <int R1, int C1, int R2, int C2>
73 public:
74 int M_; // A.rows() = A.cols() = B.rows()
75 int N_; // B.cols()
79
 // Forward pass: factor A, reject it via check_pos_definite if it is not
 // positive definite, and store C = A^{-1} value_of(B) in alloc_->C_.
80 mdivide_left_spd_dv_vari(const Eigen::Matrix<double, R1, C1> &A,
81 const Eigen::Matrix<var, R2, C2> &B)
82 : vari(0.0),
83 M_(A.rows()),
84 N_(B.cols()),
85 variRefB_(reinterpret_cast<vari **>(
86 ChainableStack::instance_->memalloc_.alloc(sizeof(vari *) * B.rows()
87 * B.cols()))),
88 variRefC_(reinterpret_cast<vari **>(
89 ChainableStack::instance_->memalloc_.alloc(sizeof(vari *) * B.rows()
90 * B.cols()))),
91 alloc_(new mdivide_left_spd_alloc<R1, C1, R2, C2>()) {
92 alloc_->C_ = B.val();
 // Remember B's varis so chain() can update their adjoints.
93 Eigen::Map<matrix_vi>(variRefB_, M_, N_) = B.vi();
94 alloc_->llt_ = A.llt();
95 check_pos_definite("mdivide_left_spd", "A", alloc_->llt_);
 // Overwrite C_ (currently value_of(B)) with A^{-1} B.
96 alloc_->llt_.solveInPlace(alloc_->C_);
97
 // One new vari per entry of C carries the result values.
 // NOTE(review): see the `false` flag note on the var/var vari -- presumably
 // keeps these output varis off the chaining stack; confirm.
98 Eigen::Map<matrix_vi>(variRefC_, M_, N_)
99 = alloc_->C_.unaryExpr([](double x) { return new vari(x, false); });
100 }
101
 // Reverse pass: adj(B) += A^{-1} adjC, solved in place with the stored LLT.
102 virtual void chain() {
103 matrix_d adjB = Eigen::Map<matrix_vi>(variRefC_, M_, N_).adj();
104 alloc_->llt_.solveInPlace(adjB);
105 Eigen::Map<matrix_vi>(variRefB_, M_, N_).adj() += adjB;
106 }
107};
108
// Reverse-mode node for mdivide_left_spd(A, B) when A (M x M) holds var and
// B (M x N) is a plain double matrix. Only A receives adjoints. The node's own
// vari value (0.0) is unused; results are the output varis in variRefC_.
//
// NOTE(review): the class-head line and the declarations of variRefA_,
// variRefC_ and alloc_ are not visible in this listing.
109template <int R1, int C1, int R2, int C2>
111 public:
112 int M_; // A.rows() = A.cols() = B.rows()
113 int N_; // B.cols()
117
 // Forward pass: factor value_of(A), reject it via check_pos_definite if it
 // is not positive definite, and store C = A^{-1} B in alloc_->C_ (needed by
 // chain() for the adjoint of A).
118 mdivide_left_spd_vd_vari(const Eigen::Matrix<var, R1, C1> &A,
119 const Eigen::Matrix<double, R2, C2> &B)
120 : vari(0.0),
121 M_(A.rows()),
122 N_(B.cols()),
123 variRefA_(reinterpret_cast<vari **>(
124 ChainableStack::instance_->memalloc_.alloc(sizeof(vari *) * A.rows()
125 * A.cols()))),
126 variRefC_(reinterpret_cast<vari **>(
127 ChainableStack::instance_->memalloc_.alloc(sizeof(vari *) * B.rows()
128 * B.cols()))),
129 alloc_(new mdivide_left_spd_alloc<R1, C1, R2, C2>()) {
 // Remember A's varis so chain() can update their adjoints.
130 Eigen::Map<matrix_vi>(variRefA_, M_, M_) = A.vi();
131 alloc_->llt_ = A.val().llt();
132 check_pos_definite("mdivide_left_spd", "A", alloc_->llt_);
133 alloc_->C_ = alloc_->llt_.solve(B);
134
 // One new vari per entry of C carries the result values.
 // NOTE(review): see the `false` flag note on the var/var vari -- presumably
 // keeps these output varis off the chaining stack; confirm.
135 Eigen::Map<matrix_vi>(variRefC_, M_, N_)
136 = alloc_->C_.unaryExpr([](double x) { return new vari(x, false); });
137 }
138
 // Reverse pass: adj(A) -= A^{-1} (adjC * C^T). This is the same quantity as
 // in the var/var node (there the solve happens before multiplying by C^T).
139 virtual void chain() {
140 matrix_d adjC = Eigen::Map<matrix_vi>(variRefC_, M_, N_).adj();
141 Eigen::Map<matrix_vi>(variRefA_, M_, M_).adj()
142 -= alloc_->llt_.solve(adjC * alloc_->C_.transpose());
143 }
144};
145} // namespace internal
146
/**
 * Returns the solution x of A x = b where A is symmetric positive definite,
 * and wires a reverse-mode vari into the expression graph so adjoints flow
 * back to the autodiff operands.
 *
 * NOTE(review): the template constraint line and the statement that creates
 * `baseVari` are missing from this listing. Judging from the vari classes in
 * namespace internal, this is presumably the overload where A and b both hold
 * var (mdivide_left_spd_vv_vari) -- confirm against the real header.
 *
 * @tparam EigMat1 Eigen type of A
 * @tparam EigMat2 Eigen type of b
 * @param A symmetric positive-definite matrix
 * @param b right-hand side
 * @return b.rows() x b.cols() matrix of var holding A^{-1} b; a matrix with
 *   0 rows and b.cols() columns when A is empty
 * @throw via check_multiplicable / check_symmetric / check_not_nan when A and
 *   b are not multiplicable, A is not symmetric, or A contains NaN; the vari
 *   constructors in this file also call check_pos_definite on A
 */
147template <
148 typename EigMat1, typename EigMat2,
150inline Eigen::Matrix<var, EigMat1::RowsAtCompileTime,
151 EigMat2::ColsAtCompileTime>
152mdivide_left_spd(const EigMat1 &A, const EigMat2 &b) {
153 constexpr int R1 = EigMat1::RowsAtCompileTime;
154 constexpr int C1 = EigMat1::ColsAtCompileTime;
155 constexpr int R2 = EigMat2::RowsAtCompileTime;
156 constexpr int C2 = EigMat2::ColsAtCompileTime;
157 static constexpr const char *function = "mdivide_left_spd";
158 check_multiplicable(function, "A", A, "b", b);
 // Evaluate A once so the checks below (and the vari) do not re-evaluate an
 // Eigen expression.
159 const auto &A_ref = to_ref(A);
160 check_symmetric(function, "A", A_ref);
161 check_not_nan(function, "A", A_ref);
162 if (A.size() == 0) {
163 return {0, b.cols()};
164 }
165
166 // NOTE: this is not a memory leak, this vari is used in the
167 // expression graph to evaluate the adjoint, but is not needed
168 // for the returned matrix. Memory will be cleaned up with the
169 // arena allocator.
172
 // The result simply points at the output varis created by the node.
173 Eigen::Matrix<var, R1, C2> res(b.rows(), b.cols());
174 res.vi() = Eigen::Map<matrix_vi>(&baseVari->variRefC_[0], b.rows(), b.cols());
175 return res;
176}
177
/**
 * Returns the solution x of A x = b where A is symmetric positive definite,
 * and wires a reverse-mode vari into the expression graph so adjoints flow
 * back to the autodiff operand(s).
 *
 * NOTE(review): the template constraint lines and the statement that creates
 * `baseVari` are missing from this listing. Judging from the vari classes in
 * namespace internal, this is presumably the overload where A holds var and
 * b holds double (mdivide_left_spd_vd_vari) -- confirm against the real
 * header.
 *
 * @tparam EigMat1 Eigen type of A
 * @tparam EigMat2 Eigen type of b
 * @param A symmetric positive-definite matrix
 * @param b right-hand side
 * @return b.rows() x b.cols() matrix of var holding A^{-1} b; a matrix with
 *   0 rows and b.cols() columns when A is empty
 * @throw via check_multiplicable / check_symmetric / check_not_nan when A and
 *   b are not multiplicable, A is not symmetric, or A contains NaN; the vari
 *   constructors in this file also call check_pos_definite on A
 */
178template <typename EigMat1, typename EigMat2,
181inline Eigen::Matrix<var, EigMat1::RowsAtCompileTime,
182 EigMat2::ColsAtCompileTime>
183mdivide_left_spd(const EigMat1 &A, const EigMat2 &b) {
184 constexpr int R1 = EigMat1::RowsAtCompileTime;
185 constexpr int C1 = EigMat1::ColsAtCompileTime;
186 constexpr int R2 = EigMat2::RowsAtCompileTime;
187 constexpr int C2 = EigMat2::ColsAtCompileTime;
188 static constexpr const char *function = "mdivide_left_spd";
189 check_multiplicable(function, "A", A, "b", b);
 // Evaluate A once so the checks below (and the vari) do not re-evaluate an
 // Eigen expression.
190 const auto &A_ref = to_ref(A);
191 check_symmetric(function, "A", A_ref);
192 check_not_nan(function, "A", A_ref);
193 if (A.size() == 0) {
194 return {0, b.cols()};
195 }
196
197 // NOTE: this is not a memory leak, this vari is used in the
198 // expression graph to evaluate the adjoint, but is not needed
199 // for the returned matrix. Memory will be cleaned up with the
200 // arena allocator.
203
 // The result simply points at the output varis created by the node.
204 Eigen::Matrix<var, R1, C2> res(b.rows(), b.cols());
205 res.vi() = Eigen::Map<matrix_vi>(&baseVari->variRefC_[0], b.rows(), b.cols());
206 return res;
207}
208
/**
 * Returns the solution x of A x = b where A is a symmetric positive-definite
 * double matrix and b holds var (this overload builds an
 * internal::mdivide_left_spd_dv_vari, so only b receives adjoints).
 *
 * NOTE(review): the template constraint lines are missing from this listing.
 *
 * @tparam EigMat1 Eigen type of A
 * @tparam EigMat2 Eigen type of b
 * @param A symmetric positive-definite matrix
 * @param b right-hand side
 * @return b.rows() x b.cols() matrix of var holding A^{-1} b; a matrix with
 *   0 rows and b.cols() columns when A is empty
 * @throw via check_multiplicable / check_symmetric / check_not_nan when A and
 *   b are not multiplicable, A is not symmetric, or A contains NaN; the
 *   dv_vari constructor also calls check_pos_definite on A
 */
209template <typename EigMat1, typename EigMat2,
212inline Eigen::Matrix<var, EigMat1::RowsAtCompileTime,
213 EigMat2::ColsAtCompileTime>
214mdivide_left_spd(const EigMat1 &A, const EigMat2 &b) {
215 constexpr int R1 = EigMat1::RowsAtCompileTime;
216 constexpr int C1 = EigMat1::ColsAtCompileTime;
217 constexpr int R2 = EigMat2::RowsAtCompileTime;
218 constexpr int C2 = EigMat2::ColsAtCompileTime;
219 static constexpr const char *function = "mdivide_left_spd";
220 check_multiplicable(function, "A", A, "b", b);
 // Evaluate A once so the checks below (and the vari) do not re-evaluate an
 // Eigen expression.
221 const auto &A_ref = to_ref(A);
222 check_symmetric(function, "A", A_ref);
223 check_not_nan(function, "A", A_ref);
224 if (A.size() == 0) {
225 return {0, b.cols()};
226 }
227
228 // NOTE: this is not a memory leak, this vari is used in the
229 // expression graph to evaluate the adjoint, but is not needed
230 // for the returned matrix. Memory will be cleaned up with the
231 // arena allocator.
232 internal::mdivide_left_spd_dv_vari<R1, C1, R2, C2> *baseVari
233 = new internal::mdivide_left_spd_dv_vari<R1, C1, R2, C2>(A_ref, b);
234
 // The result simply points at the output varis created by the node.
235 Eigen::Matrix<var, R1, C2> res(b.rows(), b.cols());
236 res.vi() = Eigen::Map<matrix_vi>(&baseVari->variRefC_[0], b.rows(), b.cols());
237
238 return res;
239}
240
/**
 * Overload of mdivide_left_spd used when at least one argument is a
 * var_value matrix (require_any_var_matrix_t). Solves A x = B with A
 * symmetric positive definite and returns the result as a single var_value
 * whose value type is the plain type of value_of(A) * value_of(B).
 *
 * Every branch factors the values of A with LLT, copies the lower Cholesky
 * factor L into the arena, and registers a reverse_pass_callback that forms
 * adjB = (L L^T)^{-1} res.adj() = A^{-1} res.adj() with two triangular solves
 * (L, then L^T), then applies adj(A) -= adjB * res^T and/or adj(B) += adjB.
 *
 * NOTE(review): this listing omits the original doc comment, the first branch
 * condition together with the arena_A / arena_B declarations, one line after
 * the `else if`, and one line after A_ref in the last branch. Given the
 * visible `else if (!is_constant<T1>::value)`, the first branch presumably
 * handles A and B both autodiff, the second only A, and the last only B --
 * confirm against the real header.
 *
 * NOTE(review): unlike the Eigen overloads in this file, the A.size() == 0
 * early return runs before check_multiplicable, so an empty A is returned as
 * an empty result without the size check ever running. Consider moving
 * check_multiplicable ahead of the early return for consistency.
 *
 * @tparam T1 type of A
 * @tparam T2 type of B
 * @param A symmetric positive-definite matrix
 * @param B right-hand side
 * @return var_value holding A^{-1} B
 */
259template <typename T1, typename T2, require_all_matrix_t<T1, T2> * = nullptr,
260 require_any_var_matrix_t<T1, T2> * = nullptr>
261inline auto mdivide_left_spd(const T1 &A, const T2 &B) {
262 using ret_val_type = plain_type_t<decltype(value_of(A) * value_of(B))>;
263 using ret_type = var_value<ret_val_type>;
264
265 if (A.size() == 0) {
266 return ret_type(ret_val_type(0, B.cols()));
267 }
268
269 check_multiplicable("mdivide_left_spd", "A", A, "B", B);
270
274
275 check_symmetric("mdivide_left_spd", "A", arena_A.val());
276 check_not_nan("mdivide_left_spd", "A", arena_A.val());
277
278 auto A_llt = arena_A.val().llt();
279
280 check_pos_definite("mdivide_left_spd", "A", A_llt);
281
 // Only the lower Cholesky factor is kept for the reverse pass.
282 arena_t<Eigen::MatrixXd> arena_A_llt = A_llt.matrixL();
283 arena_t<ret_type> res = A_llt.solve(arena_B.val());
284
285 reverse_pass_callback([arena_A, arena_B, arena_A_llt, res]() mutable {
286 promote_scalar_t<double, T2> adjB = res.adj();
287
 // adjB <- (L L^T)^{-1} adjB = A^{-1} res.adj()
288 arena_A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjB);
289 arena_A_llt.template triangularView<Eigen::Lower>()
290 .transpose()
291 .solveInPlace(adjB);
292
293 arena_A.adj() -= adjB * res.val_op().transpose();
294 arena_B.adj() += adjB;
295 });
296
297 return ret_type(res);
298 } else if (!is_constant<T1>::value) {
300
301 check_symmetric("mdivide_left_spd", "A", arena_A.val());
302 check_not_nan("mdivide_left_spd", "A", arena_A.val());
303
304 auto A_llt = arena_A.val().llt();
305
306 check_pos_definite("mdivide_left_spd", "A", A_llt);
307
308 arena_t<Eigen::MatrixXd> arena_A_llt = A_llt.matrixL();
309 arena_t<ret_type> res = A_llt.solve(value_of(B));
310
311 reverse_pass_callback([arena_A, arena_A_llt, res]() mutable {
312 promote_scalar_t<double, T2> adjB = res.adj();
313
 // adjB <- A^{-1} res.adj(), via the same two triangular solves.
314 arena_A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjB);
315 arena_A_llt.template triangularView<Eigen::Lower>()
316 .transpose()
317 .solveInPlace(adjB);
318
 // NOTE(review): spelled res.val().transpose().eval() here but
 // res.val_op().transpose() in the first branch; presumably equivalent --
 // consider unifying.
319 arena_A.adj() -= adjB * res.val().transpose().eval();
320 });
321
322 return ret_type(res);
323 } else {
 // A is constant here: work directly on its values (evaluated once).
324 const auto &A_ref = to_ref(value_of(A));
326
327 check_symmetric("mdivide_left_spd", "A", A_ref);
328 check_not_nan("mdivide_left_spd", "A", A_ref);
329
330 auto A_llt = A_ref.llt();
331
332 check_pos_definite("mdivide_left_spd", "A", A_llt);
333
334 arena_t<Eigen::MatrixXd> arena_A_llt = A_llt.matrixL();
335 arena_t<ret_type> res = A_llt.solve(arena_B.val());
336
337 reverse_pass_callback([arena_B, arena_A_llt, res]() mutable {
338 promote_scalar_t<double, T2> adjB = res.adj();
339
 // adjB <- A^{-1} res.adj(); only B receives adjoints in this branch.
340 arena_A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjB);
341 arena_A_llt.template triangularView<Eigen::Lower>()
342 .transpose()
343 .solveInPlace(adjB);
344
345 arena_B.adj() += adjB;
346 });
347
348 return ret_type(res);
349 }
350}
351
352} // namespace math
353} // namespace stan
354#endif
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan i...
Eigen::LLT< Eigen::Matrix< double, R1, C1 > > llt_
mdivide_left_spd_dv_vari(const Eigen::Matrix< double, R1, C1 > &A, const Eigen::Matrix< var, R2, C2 > &B)
mdivide_left_spd_alloc< R1, C1, R2, C2 > * alloc_
mdivide_left_spd_alloc< R1, C1, R2, C2 > * alloc_
mdivide_left_spd_vd_vari(const Eigen::Matrix< var, R1, C1 > &A, const Eigen::Matrix< double, R2, C2 > &B)
mdivide_left_spd_alloc< R1, C1, R2, C2 > * alloc_
mdivide_left_spd_vv_vari(const Eigen::Matrix< var, R1, C1 > &A, const Eigen::Matrix< var, R2, C2 > &B)
require_t< container_type_check_base< is_eigen_matrix_base, value_type_t, TypeCheck, Check... > > require_eigen_matrix_base_vt
Require type satisfies is_eigen_matrix_base.
require_all_t< container_type_check_base< is_eigen_matrix_base, value_type_t, TypeCheck, Check >... > require_all_eigen_matrix_base_vt
Require all of the types satisfy is_eigen_matrix_base.
void check_symmetric(const char *function, const char *name, const matrix_cl< T > &y)
Check if the matrix_cl is symmetric.
int rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:21
int cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
Definition cols.hpp:20
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
Eigen::Matrix< return_type_t< EigMat1, EigMat2 >, EigMat1::RowsAtCompileTime, EigMat2::ColsAtCompileTime > mdivide_left_spd(const EigMat1 &A, const EigMat2 &b)
Returns the solution of the system Ax=b where A is symmetric positive definite.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_pos_definite(const char *function, const char *name, const EigMat &y)
Check if the specified square, symmetric matrix is positive definite.
vari_value< double > vari
Definition vari.hpp:197
ref_type_t< T && > to_ref(T &&a)
This evaluates expensive Eigen expressions.
Definition to_ref.hpp:17
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
var_value< double > var
Definition var.hpp:1187
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > matrix_d
Type for matrix of double values.
Definition typedefs.hpp:19
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
This struct always provides access to the autodiff stack using the singleton pattern.