1#ifndef STAN_MATH_REV_MAT_FUN_CHOLESKY_DECOMPOSE_HPP
2#define STAN_MATH_REV_MAT_FUN_CHOLESKY_DECOMPOSE_HPP
// Excerpt of initialize_return(LMat& L, const LAMat& L_A, vari*& dummy):
// fills the output matrix L with vari pointers built from the double
// matrix L_A. Several original lines (the signature, the branch condition
// and the closing braces) are not present in this excerpt.
21template <
typename LMat,
typename LAMat>
// Visit every (i, j) entry, column by column (j outer, i inner).
23 for (Eigen::Index j = 0; j < L_A.rows(); ++j) {
24 for (Eigen::Index i = 0; i < L_A.rows(); ++i) {
// One branch (its condition is not shown here) points the entry at the
// shared `dummy` vari instead of allocating a new one.
// NOTE(review): presumably this covers the structurally-zero upper part of
// the factor — confirm against the full source.
26 L.coeffRef(i, j) = dummy;
// The other branch allocates a fresh vari holding the value L_A(i, j).
// The second constructor argument is `false`; presumably this keeps the
// vari off the reverse-pass chain stack — TODO confirm against vari's
// constructor.
28 L.coeffRef(i, j) =
new vari(L_A.coeffRef(i, j),
false);
// Excerpt of unblocked_cholesky_lambda(T1& L_A, T2& L, T3& A): returns the
// reverse-pass callback for the element-by-element Cholesky adjoint. Some
// original lines (branch conditions, one statement head, closing braces)
// are missing from this excerpt.
46template <
typename T1,
typename T2,
typename T3>
// The callback captures L_A, L and A by value. It is declared `mutable`,
// so those copies are non-const inside the body.
48 return [L_A, L, A]()
mutable {
49 const size_t N = A.rows();
// adjL: row-major scratch matrix whose lower triangle is filled from
// L.adj(). adjA: accumulator for A's adjoint, zero-initialised.
51 Eigen::Matrix<double, -1, -1, Eigen::RowMajor> adjL(L.rows(), L.cols());
52 Eigen::MatrixXd adjA = Eigen::MatrixXd::Zero(L.rows(), L.cols());
53 adjL.template triangularView<Eigen::Lower>() = L.adj();
// Sweep the lower triangle in reverse: rows from N - 1 down to 0, and
// within row i, columns from i down to 0.
// NOTE(review): N is size_t while the indices are int. When N == 0, N - 1
// wraps, and the conversion to int is relied upon to yield -1 so that the
// loop body never runs — confirm this is intended.
54 for (
int i = N - 1; i >= 0; --i) {
55 for (
int j = i; j >= 0; --j) {
// Presumably the diagonal (i == j) branch; its condition is not in this
// excerpt. The update is 0.5 * adjL(i, j) / L_A(i, j).
57 adjA.coeffRef(i, j) = 0.5 * adjL.coeff(i, j) / L_A.coeff(i, j);
// Presumably the off-diagonal branch: divide by the column's diagonal
// entry L_A(j, j). The left-hand side of the following `-=` is on a line
// not shown here.
59 adjA.coeffRef(i, j) = adjL.coeff(i, j) / L_A.coeff(j, j);
61 -= adjL.coeff(i, j) * L_A.coeff(i, j) / L_A.coeff(j, j);
// Push adjA(i, j) back into the adjoints of the earlier columns k < j, in
// rows i and j.
63 for (
int k = j - 1; k >= 0; --k) {
64 adjL.coeffRef(i, k) -= adjA.coeff(i, j) * L_A.coeff(j, k);
65 adjL.coeffRef(j, k) -= adjA.coeff(i, j) * L_A.coeff(i, k);
// Excerpt of cholesky_lambda(T1& L_A, T2& L, T3& A): returns the
// reverse-pass callback for the blocked Cholesky adjoint. Some original
// lines (one statement tail, closing braces, one statement before the
// final accumulation) are missing from this excerpt.
79template <
typename T1,
typename T2,
typename T3>
// Captures L_A, L and A by value. It is declared `mutable`, so those
// copies are non-const inside the body.
81 return [L_A, L, A]()
mutable {
83 using Eigen::StrictlyUpper;
// Dense adjoint of L: zeros, with the lower triangle taken from L.adj().
85 Eigen::MatrixXd L_adj = Eigen::MatrixXd::Zero(L.rows(), L.cols());
86 L_adj.template triangularView<Eigen::Lower>() = L.adj();
87 const int M_ = L_A.rows();
// Block size is M_ / 8, clamped to the range [8, 128].
88 int block_size_ = std::max(M_ / 8, 8);
89 block_size_ = std::min(block_size_, 128);
// Walk the diagonal blocks from the bottom-right corner upwards. The
// current block covers rows/columns [j, k).
90 for (
int k = M_; k > 0; k -= block_size_) {
91 int j = std::max(0, k - block_size_);
// Partition of L_A around the current block:
//   R: rows [j, k),  columns [0, j)    D: diagonal block [j, k) x [j, k)
//   B: rows [k, M_), columns [0, j)    C: rows [k, M_), columns [j, k)
// D is evaluated into its own matrix; the others are Eigen block views.
92 auto R = L_A.block(j, 0, k - j, j);
93 auto D = L_A.block(j, j, k - j, k - j).eval();
94 auto B = L_A.block(k, 0, M_ - k, j);
95 auto C = L_A.block(k, j, M_ - k, k - j);
// Views into L_adj with the same partition; writes through them update
// L_adj directly.
96 auto R_adj = L_adj.block(j, 0, k - j, j);
97 auto D_adj = L_adj.block(j, j, k - j, k - j);
98 auto B_adj = L_adj.block(k, 0, M_ - k, j);
99 auto C_adj = L_adj.block(k, j, M_ - k, k - j);
// Transpose the diagonal block; the solves below use its upper triangle.
100 D.transposeInPlace();
// C_adj is non-empty only when there are rows below the current block
// (M_ - k > 0). In that case, solve against D for C_adj and then propagate
// the result into B_adj and D_adj.
101 if (C_adj.size() > 0) {
102 C_adj = D.template triangularView<Upper>()
103 .solve(C_adj.transpose())
105 B_adj.noalias() -= C_adj * R;
106 D_adj.noalias() -= C_adj.transpose() * C;
// D_adj <- D * tril(D_adj). Then copy the transpose of its strictly lower
// triangle into the strictly upper one, so that D_adj is symmetric.
108 D_adj = (D * D_adj.template triangularView<Lower>()).
eval();
109 D_adj.template triangularView<StrictlyUpper>()
110 = D_adj.adjoint().template triangularView<StrictlyUpper>();
// Two in-place triangular solves with D's upper triangle; the second is
// applied through a transpose of D_adj.
111 D.template triangularView<Upper>().solveInPlace(D_adj);
112 D.template triangularView<Upper>().solveInPlace(D_adj.transpose());
// Propagate into R's adjoint from C_adj and from D_adj, read as a
// self-adjoint matrix stored in its lower triangle.
113 R_adj.noalias() -= C_adj.transpose() * B;
114 R_adj.noalias() -= D_adj.template selfadjointView<Lower>() * R;
115 D_adj.diagonal() *= 0.5;
// Accumulate L_adj into the lower triangle of A's adjoint.
117 A.adj().template triangularView<Eigen::Lower>() += L_adj;
// Excerpt of cholesky_decompose for Eigen matrices whose value type is var
// (enabled via require_eigen_vt<is_var, EigMat>). The signature and several
// body lines are missing from this excerpt.
133template <
typename EigMat, require_eigen_vt<is_var, EigMat>* =
nullptr>
// Copy the values of arena_A into a double matrix of type arena_t, whose
// dynamic allocations go on the AD stack.
137 arena_t<Eigen::Matrix<double, -1, -1>> L_A(arena_A.val());
// An LLT over an Eigen::Ref factors L_A's own storage in place (Eigen's
// in-place decomposition pattern), leaving the lower factor in L_A.
140 Eigen::LLT<Eigen::Ref<Eigen::MatrixXd>, Eigen::Lower> L_factor(L_A);
// Clear the strictly upper triangle so that L_A holds only the lower
// factor.
143 L_A.template triangularView<Eigen::StrictlyUpper>().setZero();
// 35-row size threshold choosing between two code paths. The branch bodies
// are not in this excerpt (presumably the unblocked adjoint for small
// matrices — TODO confirm).
148 if (L_A.rows() <= 35) {
// Excerpt of cholesky_decompose for types satisfying require_var_matrix_t<T>
// (presumably var_value-wrapped matrices — confirm). The signature and the
// body lines are missing from this excerpt.
169template <
typename T, require_var_matrix_t<T>* =
nullptr>
// Same 35-row size threshold as the Eigen-of-var overload. The branch
// bodies are not in this excerpt.
173 if (A.rows() <= 35) {
void check_symmetric(const char *function, const char *name, const matrix_cl< T > &y)
Check if the matrix_cl is symmetric.
void initialize_return(LMat &L, const LAMat &L_A, vari *&dummy)
auto unblocked_cholesky_lambda(T1 &L_A, T2 &L, T3 &A)
Reverse mode differentiation algorithm reference:
auto cholesky_lambda(T1 &L_A, T2 &L, T3 &A)
Reverse mode differentiation algorithm reference:
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
matrix_cl< double > cholesky_decompose(const matrix_cl< double > &A)
Returns the lower-triangular Cholesky factor (i.e., matrix square root) of the specified square,...
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
T eval(T &&arg)
Inputs whose plain_type is equal to their own type are forwarded unmodified (for Eigen expressions ...
void check_pos_definite(const char *function, const char *name, const EigMat &y)
Check if the specified square, symmetric matrix is positive definite.
vari_value< double > vari
typename plain_type< T >::type plain_type_t
typename internal::arena_type_impl< std::decay_t< T > >::type arena_t
Determines a type that can be used in place of T that does any dynamic allocations on the AD stack.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...