1#ifndef STAN_MATH_REV_FUN_MDIVIDE_LEFT_TRI_HPP
2#define STAN_MATH_REV_FUN_MDIVIDE_LEFT_TRI_HPP
16template <Eigen::UpLoType TriView,
int R1,
int C1,
int R2,
int C2>
28 const Eigen::Matrix<var, R2, C2> &B)
32 A_(reinterpret_cast<double *>(
35 C_(reinterpret_cast<double *>(
40 * (A.
rows() + 1) / 2))),
50 if (TriView == Eigen::Lower) {
56 }
else if (TriView == Eigen::Upper) {
64 Map<matrix_d> c_map(
C_,
M_,
N_);
65 Map<matrix_d> a_map(
A_,
M_,
M_);
70 c_map = a_map.template triangularView<TriView>().solve(c_map);
73 = c_map.unaryExpr([](
double x) {
return new vari(x,
false); });
81 adjB = Map<matrix_d>(
A_,
M_,
M_)
82 .template triangularView<TriView>()
85 adjA = -adjB * Map<matrix_d>(
C_,
M_,
N_).transpose();
88 if (TriView == Eigen::Lower) {
89 for (
size_type j = 0; j < adjA.cols(); j++) {
90 for (
size_type i = j; i < adjA.rows(); i++) {
94 }
else if (TriView == Eigen::Upper) {
95 for (
size_type j = 0; j < adjA.cols(); j++) {
105template <Eigen::UpLoType TriView,
int R1,
int C1,
int R2,
int C2>
116 const Eigen::Matrix<var, R2, C2> &B)
120 A_(reinterpret_cast<double *>(
123 C_(reinterpret_cast<double *>(
134 Map<matrix_d>(
A_,
M_,
M_) = A;
136 Map<matrix_d> c_map(
C_,
M_,
N_);
139 c_map = Map<matrix_d>(
A_,
M_,
M_)
140 .template triangularView<TriView>()
144 = c_map.unaryExpr([](
double x) {
return new vari(x,
false); });
151 += Map<matrix_d>(
A_,
M_,
M_)
152 .template triangularView<TriView>()
158template <Eigen::UpLoType TriView,
int R1,
int C1,
int R2,
int C2>
169 const Eigen::Matrix<double, R2, C2> &B)
173 A_(reinterpret_cast<double *>(
176 C_(reinterpret_cast<double *>(
181 * (A.
rows() + 1) / 2))),
189 if (TriView == Eigen::Lower) {
195 }
else if (TriView == Eigen::Upper) {
202 Map<matrix_d> Ad(
A_,
M_,
M_);
203 Map<matrix_d> Cd(
C_,
M_,
N_);
206 Cd = Ad.template triangularView<TriView>().solve(B);
209 = Cd.unaryExpr([](
double x) {
return new vari(x,
false); });
215 Matrix<double, R1, C1> adjA(
M_,
M_);
216 Matrix<double, R1, C2> adjC(
M_,
N_);
221 = -Map<Matrix<double, R1, C1>>(
A_,
M_,
M_)
222 .
template triangularView<TriView>()
225 * Map<Matrix<double, R1, C2>>(
C_,
M_,
N_).transpose());
228 if (TriView == Eigen::Lower) {
229 for (
size_type j = 0; j < adjA.cols(); j++) {
230 for (
size_type i = j; i < adjA.rows(); i++) {
234 }
else if (TriView == Eigen::Upper) {
235 for (
size_type j = 0; j < adjA.cols(); j++) {
245template <Eigen::UpLoType TriView,
typename T1,
typename T2,
247inline Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime>
252 return {0, b.cols()};
260 TriView, T1::RowsAtCompileTime, T1::ColsAtCompileTime,
261 T2::RowsAtCompileTime, T2::ColsAtCompileTime>(A, b);
263 Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime> res(
266 = Eigen::Map<matrix_vi>(&(baseVari->variRefC_[0]), b.rows(), b.cols());
270template <Eigen::UpLoType TriView,
typename T1,
typename T2,
273inline Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime>
278 return {0, b.cols()};
286 TriView, T1::RowsAtCompileTime, T1::ColsAtCompileTime,
287 T2::RowsAtCompileTime, T2::ColsAtCompileTime>(A, b);
289 Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime> res(
292 = Eigen::Map<matrix_vi>(&(baseVari->variRefC_[0]), b.rows(), b.cols());
296template <Eigen::UpLoType TriView,
typename T1,
typename T2,
299inline Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime>
304 return {0, b.cols()};
311 auto *baseVari =
new internal::mdivide_left_tri_vd_vari<
312 TriView, T1::RowsAtCompileTime, T1::ColsAtCompileTime,
313 T2::RowsAtCompileTime, T2::ColsAtCompileTime>(A, b);
315 Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime> res(
318 = Eigen::Map<matrix_vi>(&(baseVari->variRefC_[0]), b.rows(), b.cols());
342template <Eigen::UpLoType TriView,
typename T1,
typename T2,
343 require_all_matrix_t<T1, T2> * =
nullptr,
344 require_any_var_matrix_t<T1, T2> * =
nullptr>
347 using ret_type = var_value<ret_val_type>;
350 return ret_type(ret_val_type(0, B.cols()));
356 if (!is_constant<T1>::value && !is_constant<T2>::value) {
357 arena_t<promote_scalar_t<var, T1>> arena_A = A;
358 arena_t<promote_scalar_t<var, T2>> arena_B = B;
359 auto arena_A_val =
to_arena(arena_A.val());
361 arena_t<ret_type> res
362 = arena_A_val.template triangularView<TriView>().solve(arena_B.val());
365 promote_scalar_t<double, T2> adjB
366 = arena_A_val.template triangularView<TriView>().transpose().solve(
369 arena_B.adj() += adjB;
370 arena_A.adj() -= (adjB * res.val().transpose().eval())
371 .
template triangularView<TriView>();
374 return ret_type(res);
375 }
else if (!is_constant<T1>::value) {
376 arena_t<promote_scalar_t<var, T1>> arena_A = A;
377 auto arena_A_val =
to_arena(arena_A.val());
379 arena_t<ret_type> res
380 = arena_A_val.template triangularView<TriView>().solve(
value_of(B));
383 promote_scalar_t<double, T2> adjB
384 = arena_A_val.template triangularView<TriView>().transpose().solve(
387 arena_A.adj() -= (adjB * res.val().transpose().eval())
388 .
template triangularView<TriView>();
391 return ret_type(res);
393 arena_t<promote_scalar_t<double, T1>> arena_A =
value_of(A);
394 arena_t<promote_scalar_t<var, T2>> arena_B = B;
396 arena_t<ret_type> res
397 = arena_A.template triangularView<TriView>().solve(arena_B.val());
400 promote_scalar_t<double, T2> adjB
401 = arena_A.template triangularView<TriView>().transpose().solve(
404 arena_B.adj() += adjB;
407 return ret_type(res);
mdivide_left_tri_dv_vari(const Eigen::Matrix< double, R1, C1 > &A, const Eigen::Matrix< var, R2, C2 > &B)
mdivide_left_tri_vd_vari(const Eigen::Matrix< var, R1, C1 > &A, const Eigen::Matrix< double, R2, C2 > &B)
mdivide_left_tri_vv_vari(const Eigen::Matrix< var, R1, C1 > &A, const Eigen::Matrix< var, R2, C2 > &B)
require_all_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check >... > require_all_eigen_vt
Require that all of the types satisfy is_eigen.
require_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check... > > require_eigen_vt
Require that the type satisfies is_eigen.
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
vari_value< double > vari
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either has any dynamic allocation on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic >::Index size_type
Type for sizes and indexes in an Eigen matrix with double elements.
auto mdivide_left_tri(const T1 &A, const T2 &b)
Returns the solution of the system Ax=b when A is triangular.
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > matrix_d
Type for matrix of double values.
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
This struct always provides access to the autodiff stack using the singleton pattern.