1#ifndef STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP
2#define STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP
11namespace opencl_kernels {
// Source text of the "multiply_transpose" OpenCL kernel, stored in a
// const char* via the STRINGIFY macro.
// NOTE(review): this listing is incomplete -- the kernel signature, several
// branch bodies and the closing braces are not visible here, so the comments
// below describe only the statements that are actually shown.
13static constexpr const char* multiply_transpose_kernel_code =
STRINGIFY(
// B lives in __global memory and is the buffer written at the end of the
// kernel.
25 __global
double* B,
// M is the bound used for the row index i and for the output column index
// in the range checks below; it is also the stride in every A/B index
// expression (presumably the row count of a column-major A -- TODO confirm).
const int M,
// Position of this work item inside its work group, along dimensions 0 and 1.
28 const int thread_block_row = get_local_id(0);
29 const int thread_block_col = get_local_id(1);
// Global indices: work-group offset (group id * THREAD_BLOCK_SIZE) plus the
// local position.
32 const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
33 const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
// First column index covered by this work group and one past the last row
// index covered by it.
// NOTE(review): neither value is used in the visible lines; the statement
// that consumes them is presumably among the missing ones.
38 const int j_min = THREAD_BLOCK_SIZE * get_group_id(1);
39 const int i_max = THREAD_BLOCK_SIZE * get_group_id(0) + get_local_size(0);
// Two THREAD_BLOCK_SIZE x THREAD_BLOCK_SIZE tiles in __local memory, shared
// by the work items of a work group.
42 __local
double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
43 __local
double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
// Each work item keeps WORK_PER_THREAD accumulators, one per output element
// it is responsible for.
45 double acc[WORK_PER_THREAD];
// Loop over the accumulators; the loop body is not visible in this listing
// (presumably it zero-initialises acc[w] -- TODO confirm).
46 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Number of tiles needed to cover N, rounded up.
50 const int num_tiles = (N + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
// Walk over the tiles along the N dimension.
52 for (
int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
// Index of this work item along the N dimension within the current tile,
// derived from its local row and local column respectively.
54 const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
55 const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
// Each work item fills WORK_PER_THREAD entries of both local tiles.
// THREAD_BLOCK_SIZE_COL is not defined in this listing (presumably supplied
// by the thread_block_helpers source -- TODO confirm).
59 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Column offsets for the w-th element handled by this work item: one along
// the tiled N dimension, one along the output column dimension.
60 const int A_temp_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
61 const int AT_temp_j = j + w * THREAD_BLOCK_SIZE_COL;
// Out-of-range guard for the A_local entry. The value stored in the guarded
// branch is not visible here; otherwise the entry is loaded from
// A[A_temp_j * M + i].
62 if (A_temp_j >= N || i >= M) {
63 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
67 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
69 = A[A_temp_j * M + i];
// Same pattern for B_local, which is also filled from A but with the roles
// of the two indices swapped: A[AT_temp_j + tiled_i * M].
71 if (AT_temp_j >= M || tiled_i >= N) {
72 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
76 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
78 = A[AT_temp_j + tiled_i * M];
// Work-group barrier with a local-memory fence, placed between the tile
// writes above and the tile reads below.
82 barrier(CLK_LOCAL_MEM_FENCE);
// Accumulate the products of the two local tiles over the tile dimension.
84 for (
int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
86 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Only elements whose column index does not exceed the row index i are
// accumulated.
87 if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
88 acc[w] += A_local[block_ind][thread_block_row]
89 * B_local[thread_block_col
90 + w * THREAD_BLOCK_SIZE_COL][block_ind];
// Second barrier after the tiles have been read (presumably so that no work
// item overwrites them for the next tile while others still read -- TODO
// confirm against the full loop structure).
94 barrier(CLK_LOCAL_MEM_FENCE);
// Write the accumulated results back to B.
97 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Skip positions outside the M x M range, and positions whose column index
// exceeds the row index.
103 if ((j + w * THREAD_BLOCK_SIZE_COL) < M && i < M) {
104 if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
// The same value is stored at (i, col) and at the mirrored position
// (col, i), so both halves of B receive it.
105 B[i + (j + w * THREAD_BLOCK_SIZE_COL) * M] = acc[w];
106 B[(j + w * THREAD_BLOCK_SIZE_COL) + i * M] = acc[w];
120 "multiply_transpose",
122 {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 4}});
const kernel_cl< in_buffer, out_buffer, int, int > multiply_transpose("multiply_transpose", {thread_block_helpers, multiply_transpose_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 4}})
See the docs for add().
static const std::string thread_block_helpers
Defines a helper macro for kernels with 2D local size.
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...
Creates functor for kernels.