1#ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP
2#define STAN_MATH_OPENCL_TRI_INVERSE_HPP
39 require_matrix_cl_st<std::is_floating_point, T>* =
nullptr>
57 int thread_block_2D_dim = 32;
60 int thread_block_size_1D
61 = (((A.rows() / 32) + thread_block_2D_dim - 1) / thread_block_2D_dim)
62 * thread_block_2D_dim;
63 if (max_1D_thread_block_size < thread_block_size_1D) {
64 thread_block_size_1D = max_1D_thread_block_size;
66 int max_2D_thread_block_dim = std::sqrt(max_1D_thread_block_size);
67 if (max_2D_thread_block_dim < thread_block_2D_dim) {
68 thread_block_2D_dim = max_2D_thread_block_dim;
71 if (thread_block_size_1D < 64) {
72 thread_block_size_1D = 32;
74 if (A.rows() < thread_block_size_1D) {
75 thread_block_size_1D = A.rows();
80 = ((A.rows() + thread_block_size_1D - 1) / thread_block_size_1D)
81 * thread_block_size_1D;
87 =
constant(0.0, A_rows_padded - A.rows(), A_rows_padded);
95 int parts = inv_padded.rows() / thread_block_size_1D;
96 block_zero_based(inv_padded, 0, 0, inv_mat.rows(), inv_mat.rows()) = inv_mat;
100 cl::NDRange(parts, thread_block_size_1D, thread_block_size_1D), temp,
101 thread_block_size_1D, temp.size());
104 cl::NDRange(thread_block_size_1D), inv_padded,
105 temp, inv_padded.rows());
106 }
catch (cl::Error&
e) {
113 inv_padded.template zeros_strict_tri<stan::math::matrix_cl_view::Upper>();
123 parts =
ceil(parts / 2.0);
125 auto result_matrix_dim = thread_block_size_1D;
126 auto thread_block_work2d_dim = thread_block_2D_dim / work_per_thread;
128 = cl::NDRange(thread_block_2D_dim, thread_block_work2d_dim, 1);
130 int result_matrix_dim_x = result_matrix_dim;
133 if (parts == 1 && (inv_padded.rows() - result_matrix_dim * 2) < 0) {
134 result_matrix_dim_x = inv_padded.rows() - result_matrix_dim;
136 auto result_work_dim = result_matrix_dim / work_per_thread;
138 = cl::NDRange(result_matrix_dim_x, result_work_dim, parts);
140 inv_padded, temp, inv_padded.rows(),
143 result_ndrange, ndrange_2d, inv_padded, temp, inv_padded.rows(),
149 parts =
ceil(parts / 2.0);
151 result_matrix_dim *= 2;
156 inv_padded.template zeros_strict_tri<stan::math::matrix_cl_view::Upper>();
159 inv_mat =
block_zero_based(inv_padded, 0, 0, inv_mat.rows(), inv_mat.rows());
163 inv_mat.view(tri_view);
The API to access the methods and values in opencl_context_base.
void check_triangular(const char *function, const char *name, const T &A)
Check if the matrix_cl is either upper triangular or lower triangular.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
int max_thread_block_size() noexcept
Returns the maximum thread block size defined by CL_DEVICE_MAX_WORK_GROUP_SIZE for the device in the ...
auto block_zero_based(T &&a, int start_row, int start_col, int rows, int cols)
Block of a kernel generator expression.
auto transpose(Arg &&a)
Transposes a kernel generator expression.
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
auto diagonal(T &&a)
Diagonal of a kernel generator expression.
const kernel_cl< in_out_buffer, in_out_buffer, int > diag_inv("diag_inv", {indexing_helpers, diag_inv_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}})
See the docs for diag_inv() .
const kernel_cl< out_buffer, int, int > batch_identity("batch_identity", {indexing_helpers, batch_identity_kernel_code})
See the docs for batch_identity() .
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply() .
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply() .
plain_type_t< T > tri_inverse(const T &A)
Computes the inverse of a triangular matrix.
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
static constexpr double e()
Return the base of the natural logarithm.
fvar< T > ceil(const fvar< T > &x)
typename plain_type< T >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...