1#ifndef STAN_MATH_REV_FUN_MDIVIDE_LEFT_TRI_HPP
2#define STAN_MATH_REV_FUN_MDIVIDE_LEFT_TRI_HPP
17template <Eigen::UpLoType TriView,
int R1,
int C1,
int R2,
int C2>
29 const Eigen::Matrix<var, R2, C2> &B)
33 A_(reinterpret_cast<double *>(
36 C_(reinterpret_cast<double *>(
41 * (A.
rows() + 1) / 2))),
51 if (TriView == Eigen::Lower) {
57 }
else if (TriView == Eigen::Upper) {
65 Map<matrix_d> c_map(
C_,
M_,
N_);
66 Map<matrix_d> a_map(
A_,
M_,
M_);
71 c_map = a_map.template triangularView<TriView>().solve(c_map);
74 = c_map.unaryExpr([](
double x) {
return new vari(x,
false); });
82 adjB = Map<matrix_d>(
A_,
M_,
M_)
83 .template triangularView<TriView>()
86 adjA = -adjB * Map<matrix_d>(
C_,
M_,
N_).transpose();
89 if (TriView == Eigen::Lower) {
90 for (
size_type j = 0; j < adjA.cols(); j++) {
91 for (
size_type i = j; i < adjA.rows(); i++) {
95 }
else if (TriView == Eigen::Upper) {
96 for (
size_type j = 0; j < adjA.cols(); j++) {
106template <Eigen::UpLoType TriView,
int R1,
int C1,
int R2,
int C2>
117 const Eigen::Matrix<var, R2, C2> &B)
121 A_(reinterpret_cast<double *>(
124 C_(reinterpret_cast<double *>(
135 Map<matrix_d>(
A_,
M_,
M_) = A;
137 Map<matrix_d> c_map(
C_,
M_,
N_);
140 c_map = Map<matrix_d>(
A_,
M_,
M_)
141 .template triangularView<TriView>()
145 = c_map.unaryExpr([](
double x) {
return new vari(x,
false); });
152 += Map<matrix_d>(
A_,
M_,
M_)
153 .template triangularView<TriView>()
159template <Eigen::UpLoType TriView,
int R1,
int C1,
int R2,
int C2>
170 const Eigen::Matrix<double, R2, C2> &B)
174 A_(reinterpret_cast<double *>(
177 C_(reinterpret_cast<double *>(
182 * (A.
rows() + 1) / 2))),
190 if (TriView == Eigen::Lower) {
196 }
else if (TriView == Eigen::Upper) {
203 Map<matrix_d> Ad(
A_,
M_,
M_);
204 Map<matrix_d> Cd(
C_,
M_,
N_);
207 Cd = Ad.template triangularView<TriView>().solve(B);
210 = Cd.unaryExpr([](
double x) {
return new vari(x,
false); });
216 Matrix<double, R1, C1> adjA(
M_,
M_);
217 Matrix<double, R1, C2> adjC(
M_,
N_);
222 = -Map<Matrix<double, R1, C1>>(
A_,
M_,
M_)
223 .
template triangularView<TriView>()
226 * Map<Matrix<double, R1, C2>>(
C_,
M_,
N_).transpose());
229 if (TriView == Eigen::Lower) {
230 for (
size_type j = 0; j < adjA.cols(); j++) {
231 for (
size_type i = j; i < adjA.rows(); i++) {
235 }
else if (TriView == Eigen::Upper) {
236 for (
size_type j = 0; j < adjA.cols(); j++) {
246template <Eigen::UpLoType TriView,
typename T1,
typename T2,
248inline Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime>
253 return {0, b.cols()};
261 TriView, T1::RowsAtCompileTime, T1::ColsAtCompileTime,
262 T2::RowsAtCompileTime, T2::ColsAtCompileTime>(A, b);
264 Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime> res(
267 = Eigen::Map<matrix_vi>(&(baseVari->variRefC_[0]), b.rows(), b.cols());
271template <Eigen::UpLoType TriView,
typename T1,
typename T2,
274inline Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime>
279 return {0, b.cols()};
287 TriView, T1::RowsAtCompileTime, T1::ColsAtCompileTime,
288 T2::RowsAtCompileTime, T2::ColsAtCompileTime>(A, b);
290 Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime> res(
293 = Eigen::Map<matrix_vi>(&(baseVari->variRefC_[0]), b.rows(), b.cols());
297template <Eigen::UpLoType TriView,
typename T1,
typename T2,
300inline Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime>
305 return {0, b.cols()};
312 auto *baseVari =
new internal::mdivide_left_tri_vd_vari<
313 TriView, T1::RowsAtCompileTime, T1::ColsAtCompileTime,
314 T2::RowsAtCompileTime, T2::ColsAtCompileTime>(A, b);
316 Eigen::Matrix<var, T1::RowsAtCompileTime, T2::ColsAtCompileTime> res(
319 = Eigen::Map<matrix_vi>(&(baseVari->variRefC_[0]), b.rows(), b.cols());
343template <Eigen::UpLoType TriView,
typename T1,
typename T2,
344 require_all_matrix_t<T1, T2> * =
nullptr,
345 require_any_var_matrix_t<T1, T2> * =
nullptr>
348 using ret_type = var_value<ret_val_type>;
351 return ret_type(ret_val_type(0, B.cols()));
357 if (!is_constant<T1>::value && !is_constant<T2>::value) {
358 arena_t<promote_scalar_t<var, T1>> arena_A = A;
359 arena_t<promote_scalar_t<var, T2>> arena_B = B;
360 auto arena_A_val =
to_arena(arena_A.val());
362 arena_t<ret_type> res
363 = arena_A_val.template triangularView<TriView>().solve(arena_B.val());
366 promote_scalar_t<double, T2> adjB
367 = arena_A_val.template triangularView<TriView>().transpose().solve(
370 arena_B.adj() += adjB;
371 arena_A.adj() -= (adjB * res.val().transpose().eval())
372 .
template triangularView<TriView>();
375 return ret_type(res);
376 }
else if (!is_constant<T1>::value) {
377 arena_t<promote_scalar_t<var, T1>> arena_A = A;
378 auto arena_A_val =
to_arena(arena_A.val());
380 arena_t<ret_type> res
381 = arena_A_val.template triangularView<TriView>().solve(
value_of(B));
384 promote_scalar_t<double, T2> adjB
385 = arena_A_val.template triangularView<TriView>().transpose().solve(
388 arena_A.adj() -= (adjB * res.val().transpose().eval())
389 .
template triangularView<TriView>();
392 return ret_type(res);
394 arena_t<promote_scalar_t<double, T1>> arena_A =
value_of(A);
395 arena_t<promote_scalar_t<var, T2>> arena_B = B;
397 arena_t<ret_type> res
398 = arena_A.template triangularView<TriView>().solve(arena_B.val());
401 promote_scalar_t<double, T2> adjB
402 = arena_A.template triangularView<TriView>().transpose().solve(
405 arena_B.adj() += adjB;
408 return ret_type(res);
mdivide_left_tri_dv_vari(const Eigen::Matrix< double, R1, C1 > &A, const Eigen::Matrix< var, R2, C2 > &B)
mdivide_left_tri_vd_vari(const Eigen::Matrix< var, R1, C1 > &A, const Eigen::Matrix< double, R2, C2 > &B)
mdivide_left_tri_vv_vari(const Eigen::Matrix< var, R1, C1 > &A, const Eigen::Matrix< var, R2, C2 > &B)
require_all_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check >... > require_all_eigen_vt
Require that all of the types satisfy is_eigen.
require_t< container_type_check_base< is_eigen, value_type_t, TypeCheck, Check... > > require_eigen_vt
Require that the type satisfies is_eigen.
int64_t cols(const T_x &x)
Returns the number of columns in the specified kernel generator expression.
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
void check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the matrices can be multiplied.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
vari_value< double > vari
arena_t< T > to_arena(const T &a)
Converts the given argument into a type that either has any dynamic allocation on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic >::Index size_type
Type for sizes and indexes in an Eigen matrix with double elements.
auto mdivide_left_tri(const T1 &A, const T2 &b)
Returns the solution of the system Ax=b when A is triangular.
Eigen::Matrix< double, Eigen::Dynamic, Eigen::Dynamic > matrix_d
Type for matrix of double values.
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation from C or the boost::math::lgamma implementation.
This struct always provides access to the autodiff stack using the singleton pattern.