1#ifndef STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP
2#define STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP
12namespace opencl_kernels {
// OpenCL source for the general GPU matrix product C = A * B, where A and B
// may carry triangular (LOWER/UPPER) views.  The code is stored as a string
// and built at runtime; comments below are removed by the C preprocessor
// before STRINGIFY, so they never reach the OpenCL compiler.
// NOTE(review): this chunk is a lossy extraction -- the kernel signature's
// opening (including the `A` argument), several conditions, else branches
// and closing braces are missing, and each surviving line carries a stray
// leading line number.  Recover the full file from upstream before editing;
// the comments below describe only what the visible lines show.
14static constexpr const char* matrix_multiply_kernel_code =
STRINGIFY(
// Kernel arguments: input buffers (A presumably precedes B -- TODO confirm
// against upstream), output C, dimensions M x K times K x N, and the two
// triangularity views.
29 const __global
double* B, __global
double* C,
30 const int M,
const int N,
const int K,
31 unsigned int view_A,
unsigned int view_B) {
// Work-item coordinates: (i, j) is this thread's base element of C.
33 const int row_in_block = get_local_id(0);
34 const int col_in_block = get_local_id(1);
36 const int group_id_row = get_group_id(0);
37 const int group_id_col = get_group_id(1);
39 const int i = THREAD_BLOCK_SIZE * group_id_row + row_in_block;
40 const int j = THREAD_BLOCK_SIZE * group_id_col + col_in_block;
// The 3rd NDRange dimension splits the K dimension; judging by the writes
// at the bottom, each split stores its partial product in its own M*N
// slice of C (summed elsewhere -- not visible here).
42 const int split_id = get_global_id(2);
43 const int split_size = get_global_size(2);
// Work-group shared tiles of A and B in local memory.
45 __local
double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
46 __local
double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
// Per-thread accumulators, one per WORK_PER_THREAD output element.
// NOTE(review): the body of the zero-initialization loop is missing here.
48 double acc[WORK_PER_THREAD];
49 for (
int w = 0; w < WORK_PER_THREAD; w++) {
// Number of K-tiles (rounded up), distributed over the splits; the first
// (num_tiles % split_size) splits get one extra tile each.
53 const int num_tiles = (K + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
62 int split_tiles = num_tiles / split_size;
63 const int split_remainder = num_tiles % split_size;
64 int split_offset_tiles = split_id * split_tiles;
65 if (split_id < split_remainder) {
66 split_offset_tiles = split_offset_tiles + split_id;
// NOTE(review): the else-branch header (and likely a split_tiles
// adjustment) from the original is missing around here.
69 split_offset_tiles = split_offset_tiles + split_remainder;
// Triangular views restrict which K-tiles can be nonzero.  The ternary
// expressions computing the per-matrix tile bounds survive only partially
// (these look like the tails of end_tile_A / end_tile_B and the headers of
// start_tile_A / start_tile_B -- TODO confirm upstream).
88 : (i / THREAD_BLOCK_SIZE);
91 : (j / THREAD_BLOCK_SIZE);
92 const int start_tile_A
94 const int start_tile_B
// Clamp the tile range to both views and to this split's slice of tiles.
101 int start_tile =
max(start_tile_A, start_tile_B);
102 start_tile =
max(start_tile, split_offset_tiles);
103 int end_tile =
min(end_tile_A, end_tile_B);
104 end_tile =
min(end_tile, split_offset_tiles + split_tiles - 1);
// Edge-clipped extents of this work-group's output tile.
106 const int total_work_n =
min(
107 THREAD_BLOCK_SIZE, N - THREAD_BLOCK_SIZE * group_id_col);
108 const int total_work_m =
min(
109 THREAD_BLOCK_SIZE, M - THREAD_BLOCK_SIZE * group_id_row);
110 const int total_work_nm = total_work_n * total_work_m;
111 const int threads_in_block = THREAD_BLOCK_SIZE_COL * THREAD_BLOCK_SIZE;
112 const int linear_index
113 = get_local_id(0) + get_local_id(1) * THREAD_BLOCK_SIZE;
// First (possibly partial) K-tile: loads are bounds- and view-checked.
// NOTE(review): part of this if-condition is missing.
115 if (start_tile <= end_tile
117 || view_B == LOWER)) {
118 const int tiled_i = THREAD_BLOCK_SIZE * start_tile + row_in_block;
119 const int tiled_j = THREAD_BLOCK_SIZE * start_tile + col_in_block;
120 for (
int w = 0; w < WORK_PER_THREAD; w++) {
124 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
125 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
126 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
// Zero-fill out-of-range or view-excluded elements so the inner product
// below needs no further branching.  Indexing `col * rows + row` shows A
// and B are stored column-major.  (The `else` lines between the branches
// are missing from this extraction.)
130 if (A_curr_j >= K || i >= M || (view_A == LOWER && A_curr_j > i)
131 || (view_A == UPPER && A_curr_j < i)) {
132 A_local[curr_k][row_in_block] = 0.0;
134 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
136 if (B_curr_j >= N || tiled_i >= K
137 || (view_B == LOWER && B_curr_j > tiled_i)
138 || (view_B == UPPER && B_curr_j < tiled_i)) {
139 B_local[curr_k][row_in_block] = 0.0;
141 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
144 barrier(CLK_LOCAL_MEM_FENCE);
145 const int total_work_k =
min(
146 THREAD_BLOCK_SIZE, K - THREAD_BLOCK_SIZE * start_tile);
// Threads stride linearly over the tile's output elements and accumulate
// partial dot products from local memory.
147 for (
int idx = linear_index, w = 0; idx < total_work_nm;
148 idx += threads_in_block, w++) {
149 const int row_B_local = idx / total_work_m;
150 const int col_A_local = idx % total_work_m;
151 for (
int idx_in_block = 0; idx_in_block < total_work_k;
153 acc[w] += A_local[idx_in_block][col_A_local]
154 * B_local[row_B_local][idx_in_block];
157 barrier(CLK_LOCAL_MEM_FENCE);
// Last (possibly partial) K-tile: same guarded load/accumulate pattern,
// anchored at end_tile.  Parts of the condition and the inner guards are
// missing from this extraction.
160 if (start_tile <= end_tile
161 && (view_A == LOWER || view_B == UPPER
162 || K % THREAD_BLOCK_SIZE
164 const int tiled_i = THREAD_BLOCK_SIZE * end_tile + row_in_block;
165 const int tiled_j = THREAD_BLOCK_SIZE * end_tile + col_in_block;
166 for (
int w = 0; w < WORK_PER_THREAD; w++) {
170 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
171 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
172 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
176 if (A_curr_j >= K || i >= M
179 A_local[curr_k][row_in_block] = 0.0;
181 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
183 if (B_curr_j >= N || tiled_i >= K
186 B_local[curr_k][row_in_block] = 0.0;
188 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
191 barrier(CLK_LOCAL_MEM_FENCE);
192 const int total_work_k =
min(
193 THREAD_BLOCK_SIZE, K - THREAD_BLOCK_SIZE * end_tile);
194 for (
int idx = linear_index, w = 0; idx < total_work_nm;
195 idx += threads_in_block, w++) {
196 const int row_B_local = idx / total_work_m;
197 const int col_A_local = idx % total_work_m;
198 for (
int idx_in_block = 0; idx_in_block < total_work_k;
200 acc[w] += A_local[idx_in_block][col_A_local]
201 * B_local[row_B_local][idx_in_block];
204 barrier(CLK_LOCAL_MEM_FENCE);
// Edge work-groups (partial output tile): strided path over the remaining
// tiles with unguarded loads -- presumably the interior full tiles; part of
// this condition is missing.
207 if (total_work_n < THREAD_BLOCK_SIZE
209 < THREAD_BLOCK_SIZE) {
210 for (
int tile_idx = start_tile; tile_idx <= end_tile; tile_idx++) {
211 const int tiled_i = THREAD_BLOCK_SIZE * tile_idx + row_in_block;
212 const int tiled_j = THREAD_BLOCK_SIZE * tile_idx + col_in_block;
215 for (
int w = 0; w < WORK_PER_THREAD; w++) {
216 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
217 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
218 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
221 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
224 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
227 barrier(CLK_LOCAL_MEM_FENCE);
228 int total_work_k =
min(THREAD_BLOCK_SIZE,
229 K - THREAD_BLOCK_SIZE * tile_idx);
230 for (
int idx = linear_index, w = 0; idx < total_work_nm;
231 idx += threads_in_block, w++) {
232 const int row_B_local = idx / total_work_m;
233 const int col_A_local = idx % total_work_m;
234 for (
int idx_in_block = 0; idx_in_block < total_work_k;
236 acc[w] += A_local[idx_in_block][col_A_local]
237 * B_local[row_B_local][idx_in_block];
240 barrier(CLK_LOCAL_MEM_FENCE);
// Write back this split's partial results, strided like the accumulation
// loops.  (The declarations of curr_i / B_curr_j here are truncated.)
242 for (
int idx = linear_index, w = 0; idx < total_work_nm;
243 idx += threads_in_block, w++) {
245 = THREAD_BLOCK_SIZE * get_group_id(0) + idx % total_work_m;
247 = THREAD_BLOCK_SIZE * get_group_id(1) + idx / total_work_m;
248 C[split_id * M * N + B_curr_j * M + curr_i] = acc[w];
// Interior work-groups (full output tile): fast path with no bounds checks.
251 for (
int tile_idx = start_tile; tile_idx <= end_tile; tile_idx++) {
252 const int tiled_i = THREAD_BLOCK_SIZE * tile_idx + row_in_block;
253 const int tiled_j = THREAD_BLOCK_SIZE * tile_idx + col_in_block;
256 for (
int w = 0; w < WORK_PER_THREAD; w++) {
257 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
258 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
259 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
260 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
261 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
263 barrier(CLK_LOCAL_MEM_FENCE);
// Fast-path accumulation (the tail of the B_local index expression is
// missing from this extraction).
264 for (
int w = 0; w < WORK_PER_THREAD; w++) {
265 for (
int idx_in_block = 0; idx_in_block < THREAD_BLOCK_SIZE;
267 acc[w] += A_local[idx_in_block][row_in_block]
268 * B_local[w * THREAD_BLOCK_SIZE_COL + col_in_block]
272 barrier(CLK_LOCAL_MEM_FENCE);
// Fast-path write-back of this split's partial results.
275 for (
int w = 0; w < WORK_PER_THREAD; w++) {
276 const int curr_j = j + w * THREAD_BLOCK_SIZE_COL;
277 C[split_id * M * N + curr_j * M + i] = acc[w];
// Compiled-kernel functor wrapping matrix_multiply_kernel_code, with the
// tuning macros THREAD_BLOCK_SIZE=32 and WORK_PER_THREAD=8 passed as build
// options.  NOTE(review): extraction dropped the second matrix_cl_view
// template argument, the variable name ("matrix_multiply"), and the helper
// source list (thread_block_helpers / view_kernel_helpers lines) -- the
// doxygen text later in this file shows the full declaration; restore from
// upstream before compiling.
288const kernel_cl<in_buffer, in_buffer, out_buffer, int, int, int,
matrix_cl_view,
292 matrix_multiply_kernel_code},
293 {{
"THREAD_BLOCK_SIZE", 32}, {
"WORK_PER_THREAD", 8}});
// OpenCL source for row-vector times matrix: R = A * B where A is a row
// vector of length N and B is N x K.  Each work-group computes one output
// element R[wgid] -- its threads stride over a slice of the dot product and
// then combine partial sums with a tree reduction in local memory.
// NOTE(review): lossy extraction -- the initialization of `acc`, the
// computation of `start`/`stop`, the store of `acc` into res_loc[lid], and
// the `if (lid < step)` guard around the reduction appear to be missing,
// along with several braces.  Restore from upstream before editing.
296static constexpr const char* row_vector_matrix_multiply_kernel_code =
STRINGIFY(
// Arguments: row vector A, matrix B (column-major, judging by the
// `i + wgid * N` index), output R, dimensions, and triangular views.
311 const __global
double* A,
const __global
double* B, __global
double* R,
312 const int N,
const int K,
unsigned int view_A,
unsigned int view_B) {
313 const int lid = get_local_id(0);
314 const int gid = get_global_id(0);
315 const int wgid = get_group_id(0);
// Strided partial dot product over this work-group's column of B.
// (`start`/`stop` are computed on lines missing from this extraction,
// presumably from the views -- TODO confirm.)
323 for (
int i = lid + start; i < stop; i += LOCAL_SIZE_) {
324 acc += A[i] * B[i + wgid * N];
// Local scratch for the intra-group reduction.
327 __local
double res_loc[LOCAL_SIZE_];
329 barrier(CLK_LOCAL_MEM_FENCE);
// Tree reduction: fold REDUCTION_STEP_SIZE partial sums per round until a
// single value remains in res_loc[0].
330 for (
int step = LOCAL_SIZE_ / REDUCTION_STEP_SIZE;
step > 0;
331 step /= REDUCTION_STEP_SIZE) {
333 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
334 res_loc[lid] += res_loc[lid +
step * i];
337 barrier(CLK_LOCAL_MEM_FENCE);
// One thread (presumably lid == 0; the guard line is missing) writes the
// group's result.
340 R[wgid] = res_loc[0];
// Compiled-kernel functor wrapping row_vector_matrix_multiply_kernel_code,
// built with LOCAL_SIZE_=64 and REDUCTION_STEP_SIZE=4.  NOTE(review):
// extraction dropped the second matrix_cl_view template argument and the
// variable name ("row_vector_matrix_multiply"); the doxygen text later in
// this file shows the full declaration -- restore from upstream.
351const kernel_cl<in_buffer, in_buffer, out_buffer, int, int,
matrix_cl_view,
354 {view_kernel_helpers,
355 row_vector_matrix_multiply_kernel_code},
356 {{
"LOCAL_SIZE_", 64},
357 {
"REDUCTION_STEP_SIZE", 4}});
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, matrix_cl_view, matrix_cl_view > row_vector_matrix_multiply("row_vector_matrix_multiply", {view_kernel_helpers, row_vector_matrix_multiply_kernel_code}, {{"LOCAL_SIZE_", 64}, {"REDUCTION_STEP_SIZE", 4}})
See the docs for row_vector_matrix_multiply().
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, int, matrix_cl_view, matrix_cl_view > matrix_multiply("matrix_multiply", {thread_block_helpers, view_kernel_helpers, matrix_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for matrix_multiply().
bool contains_nonzero(const matrix_cl_view view, const matrix_cl_view part)
Check whether a view contains certain nonzero part.
static const std::string thread_block_helpers
Defines a helper macro for kernels with 2D local size.
auto min(T1 x, T2 y)
Returns the minimum of the two specified scalar arguments.
T step(const T &y)
The step, or Heaviside, function.
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...