Automatic Differentiation
 
Loading...
Searching...
No Matches
cholesky_decompose.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_REV_CHOLESKY_DECOMPOSE_HPP
2#define STAN_MATH_OPENCL_REV_CHOLESKY_DECOMPOSE_HPP
3#ifdef STAN_OPENCL
4
10
11namespace stan {
12namespace math {
13
/**
 * Reverse-mode Cholesky decomposition for an autodiff matrix whose value is a
 * matrix_cl (a matrix on the OpenCL device); compiled only when STAN_OPENCL
 * is defined.
 *
 * NOTE(review): the documentation page this text was extracted from dropped
 * the function's signature line (original line 26) and its original doc
 * comment (original lines 14-23). From the body, the argument is named `A`
 * and exposes val(), adj() and rows(); presumably the signature is
 * `cholesky_decompose(const var_value<T>& A)` returning
 * `var_value<matrix_cl<double>>` -- confirm against the real header.
 *
 * Forward pass: checks that no element of A.val() is NaN, computes the
 * Cholesky factor cholesky_decompose(A.val()), and wraps it with
 * make_callback_var so that the lambda below runs during the reverse pass.
 *
 * Reverse pass: walks the factor in square diagonal blocks of size
 * block_size, from the bottom-right corner towards the top-left. It
 * propagates the factor's adjoint into a working adjoint matrix, marks that
 * matrix lower-triangular, and adds it to A.adj(). The block updates appear
 * to follow the blocked algorithm in Murray (2016), "Differentiation of the
 * Cholesky decomposition" -- confirm.
 *
 * @tparam T type of A's value. It is constrained by
 *   require_all_kernel_expressions_and_none_scalar_t; presumably a matrix_cl
 *   or a kernel-generator expression -- confirm.
 * @param A matrix to decompose (see the NOTE above about the missing
 *   signature).
 * @return var holding the Cholesky factor of A.val(), with a reverse-pass
 *   callback attached.
 * @throw presumably whatever check_cl raises when A.val() contains a NaN --
 *   confirm.
 */
24template <typename T,
 25 require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
 // Input validation: every element of A.val() must be non-NaN.
 27 check_cl("cholesky_decompose (OpenCL)", "A", A.val(), "not NaN")
 28 = !isnan(A.val());
 29
 // Forward value: the Cholesky factor of A.val(). The lambda captures A by
 // value and runs in the reverse pass with L_A being the resulting vari.
 30 return make_callback_var(
 31 cholesky_decompose(A.val()),
 32 [A](vari_value<matrix_cl<double>>& L_A) mutable {
 33 int M_ = A.rows();
 // Block size: M_ / cholesky_rev_block_partition, raised to at least 8,
 // then capped by cholesky_rev_min_block_size.
 // NOTE(review): despite its name, cholesky_rev_min_block_size acts as an
 // upper bound here (std::min). If it is below 8 it overrides the floor of
 // 8. If block_size ends up <= 0, the loop below never terminates (k never
 // decreases). A zero cholesky_rev_block_partition would divide by zero.
 // Presumably both tuning options are positive -- confirm.
 34 int block_size
 35 = M_ / opencl_context.tuning_opts().cholesky_rev_block_partition;
 36 block_size = std::max(block_size, 8);
 37 block_size = std::min(
 38 block_size,
 39 opencl_context.tuning_opts().cholesky_rev_min_block_size);
 // Local working matrix initialized from the factor's adjoint. It is
 // updated block by block and added to A.adj() at the end.
 40 matrix_cl<double> A_adj = L_A.adj();
 // Diagonal blocks cover rows/cols [j, k), stepping from the last rows
 // upward. Only the final (top-left) block can be smaller than block_size,
 // because j is clamped at 0.
 41 for (int k = M_; k > 0; k -= block_size) {
 42 const int j = std::max(0, k - block_size);
 43 const int k_j_ind = k - j;
 44 const int m_k_ind = M_ - k;
 45
 // Sub-blocks for this step. The block_zero_based arguments are
 // presumably (start_row, start_col, n_rows, n_cols) -- confirm:
 //   R: rows [j, k),  cols [0, j)    D: rows [j, k),  cols [j, k)
 //   B: rows [k, M_), cols [0, j)    C: rows [k, M_), cols [j, k)
 // R_adj, B_adj and C_adj are bound to blocks of A_adj and are never
 // written back explicitly, so assigning to them presumably writes
 // through into A_adj. D_val and D_adj are separate matrix_cl copies;
 // D_adj is written back explicitly at the end of the iteration.
 46 auto&& R_val = block_zero_based(L_A.val(), j, 0, k_j_ind, j);
 47 auto&& R_adj = block_zero_based(A_adj, j, 0, k_j_ind, j);
 48 matrix_cl<double> D_val
 49 = block_zero_based(L_A.val(), j, j, k_j_ind, k_j_ind);
 50 matrix_cl<double> D_adj
 51 = block_zero_based(A_adj, j, j, k_j_ind, k_j_ind);
 52 auto&& B_val = block_zero_based(L_A.val(), k, 0, m_k_ind, j);
 53 auto&& B_adj = block_zero_based(A_adj, k, 0, m_k_ind, j);
 54 auto&& C_val = block_zero_based(L_A.val(), k, j, m_k_ind, k_j_ind);
 55 auto&& C_adj = block_zero_based(A_adj, k, j, m_k_ind, k_j_ind);
 56
 // Statement order matters: C_adj is updated first, and its new value
 // feeds the B_adj, D_adj and R_adj updates below.
 // NOTE(review): tri_inverse(D_val) is evaluated here and again a few
 // lines below on the same, still unmodified D_val. Computing it once and
 // reusing it would save one triangular inversion per block.
 57 C_adj = C_adj * tri_inverse(D_val);
 58 B_adj = B_adj - C_adj * R_val;
 59 D_adj = D_adj - transpose(C_adj) * C_val;
 60
 // D_adj <- sym(D^T * D_adj), where sym is symmetrize_from_lower_tri.
 // Then D_val is overwritten with D^{-T}; this changes only the local
 // copy, not L_A.val(). Finally D_adj <- sym(D^{-T} * D_adj^T * D^{-1}),
 // written using transposes.
 61 D_adj = symmetrize_from_lower_tri(transpose(D_val) * D_adj);
 62 D_val = transpose(tri_inverse(D_val));
 63 D_adj = symmetrize_from_lower_tri(D_val * transpose(D_val * D_adj));
 64
 // This uses the symmetrized D_adj, before its diagonal is halved.
 65 R_adj = R_adj - transpose(C_adj) * B_val - D_adj * R_val;
 66 diagonal(D_adj) = diagonal(D_adj) * 0.5;
 67
 // Store the updated diagonal-block adjoint back into A_adj.
 68 block_zero_based(A_adj, j, j, k_j_ind, k_j_ind) = D_adj;
 69 }
 // Mark A_adj as lower-triangular. The strictly-upper entries of the
 // diagonal blocks (filled by symmetrize_from_lower_tri) are presumably
 // ignored under this view -- confirm. Then accumulate into A's adjoint.
 70 A_adj.view(matrix_cl_view::Lower);
 71 A.adj() += A_adj;
 72 });
 73}
 74
 75} // namespace math
 76} // namespace stan
77
78#endif
79#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
auto check_cl(const char *function, const char *var_name, T &&y, const char *must_be)
Constructs a check on opencl matrix or expression.
Definition check_cl.hpp:219
matrix_cl< double > cholesky_decompose(const matrix_cl< double > &A)
Returns the lower-triangular Cholesky factor (i.e., matrix square root) of the specified square, symmetric matrix.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...