1#ifndef STAN_MATH_OPENCL_REV_CHOLESKY_DECOMPOSE_HPP
2#define STAN_MATH_OPENCL_REV_CHOLESKY_DECOMPOSE_HPP
25 require_all_kernel_expressions_and_none_scalar_t<T>* =
nullptr>
27 check_cl(
"cholesky_decompose (OpenCL)",
"A", A.val(),
"not NaN")
35 = M_ / opencl_context.tuning_opts().cholesky_rev_block_partition;
36 block_size = std::max(block_size, 8);
37 block_size = std::min(
39 opencl_context.tuning_opts().cholesky_rev_min_block_size);
40 matrix_cl<double> A_adj = L_A.adj();
41 for (int k = M_; k > 0; k -= block_size) {
42 const int j = std::max(0, k - block_size);
43 const int k_j_ind = k - j;
44 const int m_k_ind = M_ - k;
46 auto&& R_val = block_zero_based(L_A.val(), j, 0, k_j_ind, j);
47 auto&& R_adj = block_zero_based(A_adj, j, 0, k_j_ind, j);
48 matrix_cl<double> D_val
49 = block_zero_based(L_A.val(), j, j, k_j_ind, k_j_ind);
50 matrix_cl<double> D_adj
51 = block_zero_based(A_adj, j, j, k_j_ind, k_j_ind);
52 auto&& B_val = block_zero_based(L_A.val(), k, 0, m_k_ind, j);
53 auto&& B_adj = block_zero_based(A_adj, k, 0, m_k_ind, j);
54 auto&& C_val = block_zero_based(L_A.val(), k, j, m_k_ind, k_j_ind);
55 auto&& C_adj = block_zero_based(A_adj, k, j, m_k_ind, k_j_ind);
57 C_adj = C_adj * tri_inverse(D_val);
58 B_adj = B_adj - C_adj * R_val;
59 D_adj = D_adj - transpose(C_adj) * C_val;
61 D_adj = symmetrize_from_lower_tri(transpose(D_val) * D_adj);
62 D_val = transpose(tri_inverse(D_val));
63 D_adj = symmetrize_from_lower_tri(D_val * transpose(D_val * D_adj));
65 R_adj = R_adj - transpose(C_adj) * B_val - D_adj * R_val;
66 diagonal(D_adj) = diagonal(D_adj) * 0.5;
68 block_zero_based(A_adj, j, j, k_j_ind, k_j_ind) = D_adj;
Represents an arithmetic matrix on the OpenCL device.
auto check_cl(const char *function, const char *var_name, T &&y, const char *must_be)
Constructs a check on an OpenCL matrix or expression.
matrix_cl< double > cholesky_decompose(const matrix_cl< double > &A)
Returns the lower-triangular Cholesky factor (i.e., matrix square root) of the specified square, symmetric matrix.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...