1#ifndef STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP 
    2#define STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP 
   11namespace opencl_kernels {
 
   13static constexpr const char* inv_lower_tri_multiply_kernel_code = 
STRINGIFY(
 
   47                                         __global 
double* temp,
 
   48                                         const int A_rows, 
const int rows) {
 
   49      int result_matrix_id = get_global_id(2);
 
   50      int offset = result_matrix_id * 
rows * 2;
 
   51      const int thread_block_row = get_local_id(0);
 
   52      const int thread_block_col = get_local_id(1);
 
   53      const int global_thread_row
 
   54          = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
 
   55      const int global_thread_col
 
   56          = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
 
   58      __local 
double C2_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
 
   59      __local 
double A3_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
 
   61      double acc[WORK_PER_THREAD] = {0};
 
   63      const int num_tiles = (
rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
 
   64      for (
int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
 
   67        for (
int w = 0; w < WORK_PER_THREAD; w++) {
 
   68          const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
 
   69          const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
 
   72          const int C2_global_col
 
   73              = offset + 
rows + tiled_j + w * THREAD_BLOCK_SIZE_COL;
 
   74          const int C2_global_row = offset + global_thread_row + 
rows;
 
   75          const int A3_global_col
 
   76              = offset + global_thread_col + w * THREAD_BLOCK_SIZE_COL;
 
   77          const int A3_global_row = tiled_i + 
rows + offset;
 
   80          const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
 
   81          const int local_row = thread_block_row;
 
   83          if (C2_global_col <= C2_global_row && C2_global_col < A_rows
 
   84              && C2_global_row < A_rows) {
 
   85            C2_local[local_col][local_row]
 
   86                = A[C2_global_col * A_rows + C2_global_row];
 
   88            C2_local[local_col][local_row] = 0;
 
   90          if (A3_global_col < A_rows && A3_global_row < A_rows) {
 
   91            A3_local[local_col][local_row]
 
   92                = A[A3_global_col * A_rows + A3_global_row];
 
   94            A3_local[local_col][local_row] = 0.0;
 
   98        barrier(CLK_LOCAL_MEM_FENCE);
 
   99        for (
int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
 
  100          for (
int w = 0; w < WORK_PER_THREAD; w++) {
 
  101            const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
 
  102            const int local_row = thread_block_row;
 
  103            acc[w] += C2_local[block_ind][local_row]
 
  104                      * A3_local[local_col][block_ind];
 
  107        barrier(CLK_LOCAL_MEM_FENCE);
 
  110      const int batch_offset = result_matrix_id * 
rows * 
rows;
 
  113      const int temp_global_row = global_thread_row;
 
  115      for (
int w = 0; w < WORK_PER_THREAD; w++) {
 
  117        const int temp_global_col
 
  118            = global_thread_col + w * THREAD_BLOCK_SIZE_COL;
 
  119        temp[batch_offset + temp_global_col * 
rows + temp_global_row] = acc[w];
 
  130    "inv_lower_tri_multiply",
 
  132    {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 8}});
 
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply_kernel_code, the OpenCL kernel source this functor launches.
 
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
 
static const std::string thread_block_helpers
Defines a helper macro for kernels with a 2D local size.
 
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...
 
Creates a functor for kernels.