1#ifndef STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP 
    2#define STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP 
   11namespace opencl_kernels {
 
   13static constexpr const char* multiply_transpose_kernel_code = 
STRINGIFY(
 
   25                                     __global 
double* B, 
const int M,
 
   28      const int thread_block_row = get_local_id(0);
 
   29      const int thread_block_col = get_local_id(1);
 
   32      const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
 
   33      const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
 
   38      const int j_min = THREAD_BLOCK_SIZE * get_group_id(1);
 
   39      const int i_max = THREAD_BLOCK_SIZE * get_group_id(0) + get_local_size(0);
 
   42      __local 
double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
 
   43      __local 
double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
 
   45      double acc[WORK_PER_THREAD];
 
   46      for (
int w = 0; w < WORK_PER_THREAD; w++) {
 
   50        const int num_tiles = (N + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
 
   52        for (
int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
 
   54          const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
 
   55          const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
 
   59          for (
int w = 0; w < WORK_PER_THREAD; w++) {
 
   60            const int A_temp_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
 
   61            const int AT_temp_j = j + w * THREAD_BLOCK_SIZE_COL;
 
   62            if (A_temp_j >= N || i >= M) {
 
   63              A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
 
   67              A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
 
   69                  = A[A_temp_j * M + i];
 
   71            if (AT_temp_j >= M || tiled_i >= N) {
 
   72              B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
 
   76              B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
 
   78                  = A[AT_temp_j + tiled_i * M];
 
   82          barrier(CLK_LOCAL_MEM_FENCE);
 
   84          for (
int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
 
   86            for (
int w = 0; w < WORK_PER_THREAD; w++) {
 
   87              if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
 
   88                acc[w] += A_local[block_ind][thread_block_row]
 
   89                          * B_local[thread_block_col
 
   90                                    + w * THREAD_BLOCK_SIZE_COL][block_ind];
 
   94          barrier(CLK_LOCAL_MEM_FENCE);
 
   97        for (
int w = 0; w < WORK_PER_THREAD; w++) {
 
  103          if ((j + w * THREAD_BLOCK_SIZE_COL) < M && i < M) {
 
  104            if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
 
  105              B[i + (j + w * THREAD_BLOCK_SIZE_COL) * M] = acc[w];
 
  106              B[(j + w * THREAD_BLOCK_SIZE_COL) + i * M] = acc[w];
 
  120    "multiply_transpose",
 
  122    {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 4}});
 
const kernel_cl< in_buffer, out_buffer, int, int > multiply_transpose("multiply_transpose", {thread_block_helpers, multiply_transpose_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 4}})
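Computes the product of the input matrix and its transpose (A * A^T). The result is symmetric, so the kernel only accumulates entries on or below the diagonal and writes each one to both mirrored positions of the output.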
 
static const std::string thread_block_helpers
Defines a helper macro for kernels with 2D local size.
 
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
 
Creates functor for kernels.