1#ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP 
    2#define STAN_MATH_OPENCL_TRI_INVERSE_HPP 
   39          require_matrix_cl_st<std::is_floating_point, T>* = 
nullptr>
 
   57  int thread_block_2D_dim = 32;
 
   60  int thread_block_size_1D
 
   61      = (((A.rows() / 32) + thread_block_2D_dim - 1) / thread_block_2D_dim)
 
   62        * thread_block_2D_dim;
 
   63  if (max_1D_thread_block_size < thread_block_size_1D) {
 
   64    thread_block_size_1D = max_1D_thread_block_size;
 
   66  int max_2D_thread_block_dim = std::sqrt(max_1D_thread_block_size);
 
   67  if (max_2D_thread_block_dim < thread_block_2D_dim) {
 
   68    thread_block_2D_dim = max_2D_thread_block_dim;
 
   71  if (thread_block_size_1D < 64) {
 
   72    thread_block_size_1D = 32;
 
   74  if (A.rows() < thread_block_size_1D) {
 
   75    thread_block_size_1D = A.rows();
 
   80      = ((A.rows() + thread_block_size_1D - 1) / thread_block_size_1D)
 
   81        * thread_block_size_1D;
 
   87      = 
constant(0.0, A_rows_padded - A.rows(), A_rows_padded);
 
   95  int parts = inv_padded.rows() / thread_block_size_1D;
 
   96  block_zero_based(inv_padded, 0, 0, inv_mat.rows(), inv_mat.rows()) = inv_mat;
 
  100        cl::NDRange(parts, thread_block_size_1D, thread_block_size_1D), temp,
 
  101        thread_block_size_1D, temp.size());
 
  104                             cl::NDRange(thread_block_size_1D), inv_padded,
 
  105                             temp, inv_padded.rows());
 
  106  } 
catch (cl::Error& 
e) {
 
  113  inv_padded.template zeros_strict_tri<stan::math::matrix_cl_view::Upper>();
 
  123  parts = 
ceil(parts / 2.0);
 
  125  auto result_matrix_dim = thread_block_size_1D;
 
  126  auto thread_block_work2d_dim = thread_block_2D_dim / work_per_thread;
 
  128      = cl::NDRange(thread_block_2D_dim, thread_block_work2d_dim, 1);
 
  130    int result_matrix_dim_x = result_matrix_dim;
 
  133    if (parts == 1 && (inv_padded.rows() - result_matrix_dim * 2) < 0) {
 
  134      result_matrix_dim_x = inv_padded.rows() - result_matrix_dim;
 
  136    auto result_work_dim = result_matrix_dim / work_per_thread;
 
  138        = cl::NDRange(result_matrix_dim_x, result_work_dim, parts);
 
  140                                           inv_padded, temp, inv_padded.rows(),
 
  143        result_ndrange, ndrange_2d, inv_padded, temp, inv_padded.rows(),
 
  149      parts = 
ceil(parts / 2.0);
 
  151    result_matrix_dim *= 2;
 
  156    inv_padded.template zeros_strict_tri<stan::math::matrix_cl_view::Upper>();
 
  159  inv_mat = 
block_zero_based(inv_padded, 0, 0, inv_mat.rows(), inv_mat.rows());
 
  163  inv_mat.view(tri_view);
 
The API to access the methods and values in opencl_context_base.
 
void check_triangular(const char *function, const char *name, const T &A)
Check if the matrix_cl is either upper triangular or lower triangular.
 
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
 
int max_thread_block_size() noexcept
Returns the maximum thread block size defined by CL_DEVICE_MAX_WORK_GROUP_SIZE for the device in the ...
 
auto block_zero_based(T &&a, int start_row, int start_col, int rows, int cols)
Block of a kernel generator expression.
 
auto transpose(Arg &&a)
Transposes a kernel generator expression.
 
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
 
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
 
auto diagonal(T &&a)
Diagonal of a kernel generator expression.
 
const kernel_cl< in_out_buffer, in_out_buffer, int > diag_inv("diag_inv", {indexing_helpers, diag_inv_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}})
See the docs for diag_inv() .
 
const kernel_cl< out_buffer, int, int > batch_identity("batch_identity", {indexing_helpers, batch_identity_kernel_code})
See the docs for batch_identity() .
 
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply() .
 
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply() .
 
plain_type_t< T > tri_inverse(const T &A)
Computes the inverse of a triangular matrix.
 
void check_square(const char *function, const char *name, const T_y &y)
Check if the specified matrix is square.
 
static constexpr double e()
Return the base of the natural logarithm.
 
fvar< T > ceil(const fvar< T > &x)
 
typename plain_type< std::decay_t< T > >::type plain_type_t
 
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...