Stan Math Library (Automatic Differentiation) — source listing for
matrix_multiply.hpp, the OpenCL matrix-multiplication kernels.
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP
2#define STAN_MATH_OPENCL_KERNELS_MATRIX_MULTIPLY_HPP
3#ifdef STAN_OPENCL
4
8#include <string>
9
10namespace stan {
11namespace math {
12namespace opencl_kernels {
13// \cond
14static constexpr const char* matrix_multiply_kernel_code = STRINGIFY(
15 // \endcond
28 __kernel void matrix_multiply(const __global double* A,
29 const __global double* B, __global double* C,
30 const int M, const int N, const int K,
31 unsigned int view_A, unsigned int view_B) {
32 // thread index inside the thread_block
33 const int row_in_block = get_local_id(0);
34 const int col_in_block = get_local_id(1);
35
36 const int group_id_row = get_group_id(0);
37 const int group_id_col = get_group_id(1);
38 // global thread index
39 const int i = THREAD_BLOCK_SIZE * group_id_row + row_in_block;
40 const int j = THREAD_BLOCK_SIZE * group_id_col + col_in_block;
41 // identify if the matrix multiply is split
42 const int split_id = get_global_id(2);
43 const int split_size = get_global_size(2);
44 // local memory
45 __local double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
46 __local double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
47
48 double acc[WORK_PER_THREAD];
49 for (int w = 0; w < WORK_PER_THREAD; w++) {
50 acc[w] = 0.0;
51 }
52 // the number of tiles for each scalar product in the matrix multiply
53 const int num_tiles = (K + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
54 // in case of splitting the matrix multiply we need
55 // use split_offset_tiles the threads assigned part
56 // of the scalar products, while the split_tiles
57 // determines the number of tiles a thread multiplies
58 // if split_size = 1, each thread calculates the
59 // the entire scalar product for all assigned
60 // elements of the resulting matrix, meaning that
61 // split_offset_tiles is 0 and split_tiles = num_tiles
62 int split_tiles = num_tiles / split_size;
63 const int split_remainder = num_tiles % split_size;
64 int split_offset_tiles = split_id * split_tiles;
65 if (split_id < split_remainder) {
66 split_offset_tiles = split_offset_tiles + split_id;
67 split_tiles++;
68 } else {
69 split_offset_tiles = split_offset_tiles + split_remainder;
70 }
71 // This kernel is based on the well known
72 // general matrix multiplication kernels that
73 // use tiling for shared memory
74 // In cases where a matrix is lower triangular
75 // its not necessary to multiply the elements
76 // over the diagonal, therefore those tiles
77 // in the matrix multiply can be skipped.
78 // With upper triangular matrices we dont need
79 // to multiply the elements under the diagonal,
80 // so those tiles can be skipped.
81 // The following code determines the start and
82 // end tile based on triangularity of the input matrices
83 // If no matrices are triangular the starting tile
84 // is 0 and the end tile is num_tiles-1 which
85 // is then a general matrix multiply
86 const int end_tile_A = contains_nonzero(view_A, UPPER)
87 ? (num_tiles - 1)
88 : (i / THREAD_BLOCK_SIZE);
89 const int end_tile_B = contains_nonzero(view_B, LOWER)
90 ? (num_tiles - 1)
91 : (j / THREAD_BLOCK_SIZE);
92 const int start_tile_A
93 = contains_nonzero(view_A, LOWER) ? 0 : (i / THREAD_BLOCK_SIZE);
94 const int start_tile_B
95 = contains_nonzero(view_B, UPPER) ? 0 : (j / THREAD_BLOCK_SIZE);
96 // the starting and end tiles for a thread are determined by
97 // split_offset_tiles and split_tiles. If the input matrix is
98 // triangular some tiles can be skipped in which case we
99 // either start the scalar product at larger cols/rows
100 // or end them at smaller cols/rows.
101 int start_tile = max(start_tile_A, start_tile_B);
102 start_tile = max(start_tile, split_offset_tiles);
103 int end_tile = min(end_tile_A, end_tile_B); // NOLINT
104 end_tile = min(end_tile, split_offset_tiles + split_tiles - 1); // NOLINT
105
106 const int total_work_n = min( // NOLINT
107 THREAD_BLOCK_SIZE, N - THREAD_BLOCK_SIZE * group_id_col);
108 const int total_work_m = min( // NOLINT
109 THREAD_BLOCK_SIZE, M - THREAD_BLOCK_SIZE * group_id_row);
110 const int total_work_nm = total_work_n * total_work_m;
111 const int threads_in_block = THREAD_BLOCK_SIZE_COL * THREAD_BLOCK_SIZE;
112 const int linear_index
113 = get_local_id(0) + get_local_id(1) * THREAD_BLOCK_SIZE;
114
115 if (start_tile <= end_tile
116 && (view_A == UPPER
117 || view_B == LOWER)) { // special handling of first block
118 const int tiled_i = THREAD_BLOCK_SIZE * start_tile + row_in_block;
119 const int tiled_j = THREAD_BLOCK_SIZE * start_tile + col_in_block;
120 for (int w = 0; w < WORK_PER_THREAD; w++) {
121 // For the tiles on the diagonal we can ignore the values over
122 // the diagonal if the matrix is lower triangular or under
123 // the diagonal if the matrix is upper triangular
124 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
125 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
126 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
127 // check if the indexes are outside the matrix
128 // or under/above the diagonal with upper/lower
129 // triangular matrices
130 if (A_curr_j >= K || i >= M || (view_A == LOWER && A_curr_j > i)
131 || (view_A == UPPER && A_curr_j < i)) {
132 A_local[curr_k][row_in_block] = 0.0;
133 } else {
134 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
135 }
136 if (B_curr_j >= N || tiled_i >= K
137 || (view_B == LOWER && B_curr_j > tiled_i)
138 || (view_B == UPPER && B_curr_j < tiled_i)) {
139 B_local[curr_k][row_in_block] = 0.0;
140 } else {
141 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
142 }
143 }
144 barrier(CLK_LOCAL_MEM_FENCE);
145 const int total_work_k = min(
146 THREAD_BLOCK_SIZE, K - THREAD_BLOCK_SIZE * start_tile); // NOLINT
147 for (int idx = linear_index, w = 0; idx < total_work_nm;
148 idx += threads_in_block, w++) {
149 const int row_B_local = idx / total_work_m;
150 const int col_A_local = idx % total_work_m;
151 for (int idx_in_block = 0; idx_in_block < total_work_k;
152 idx_in_block++) {
153 acc[w] += A_local[idx_in_block][col_A_local]
154 * B_local[row_B_local][idx_in_block];
155 }
156 }
157 barrier(CLK_LOCAL_MEM_FENCE);
158 start_tile++;
159 }
160 if (start_tile <= end_tile
161 && (view_A == LOWER || view_B == UPPER
162 || K % THREAD_BLOCK_SIZE
163 != 0)) { // special handling of last block
164 const int tiled_i = THREAD_BLOCK_SIZE * end_tile + row_in_block;
165 const int tiled_j = THREAD_BLOCK_SIZE * end_tile + col_in_block;
166 for (int w = 0; w < WORK_PER_THREAD; w++) {
167 // For the tiles on the diagonal we can ignore the values over
168 // the diagonal if the matrix is lower triangular or under
169 // the diagonal if the matrix is upper triangular
170 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
171 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
172 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
173 // check if the indexes are outside the matrix
174 // or under/above the diagonal with upper/lower
175 // triangular matrices
176 if (A_curr_j >= K || i >= M
177 || (!contains_nonzero(view_A, UPPER) && A_curr_j > i)
178 || (!contains_nonzero(view_A, LOWER) && A_curr_j < i)) {
179 A_local[curr_k][row_in_block] = 0.0;
180 } else {
181 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
182 }
183 if (B_curr_j >= N || tiled_i >= K
184 || (!contains_nonzero(view_B, UPPER) && B_curr_j > tiled_i)
185 || (!contains_nonzero(view_B, LOWER) && B_curr_j < tiled_i)) {
186 B_local[curr_k][row_in_block] = 0.0;
187 } else {
188 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
189 }
190 }
191 barrier(CLK_LOCAL_MEM_FENCE);
192 const int total_work_k = min(
193 THREAD_BLOCK_SIZE, K - THREAD_BLOCK_SIZE * end_tile); // NOLINT
194 for (int idx = linear_index, w = 0; idx < total_work_nm;
195 idx += threads_in_block, w++) {
196 const int row_B_local = idx / total_work_m;
197 const int col_A_local = idx % total_work_m;
198 for (int idx_in_block = 0; idx_in_block < total_work_k;
199 idx_in_block++) {
200 acc[w] += A_local[idx_in_block][col_A_local]
201 * B_local[row_B_local][idx_in_block];
202 }
203 }
204 barrier(CLK_LOCAL_MEM_FENCE);
205 end_tile--;
206 }
207 if (total_work_n < THREAD_BLOCK_SIZE
208 || total_work_m
209 < THREAD_BLOCK_SIZE) { // special handling of edge blocks
210 for (int tile_idx = start_tile; tile_idx <= end_tile; tile_idx++) {
211 const int tiled_i = THREAD_BLOCK_SIZE * tile_idx + row_in_block;
212 const int tiled_j = THREAD_BLOCK_SIZE * tile_idx + col_in_block;
213 // each thread copies WORK_PER_THREAD values to the local
214 // memory
215 for (int w = 0; w < WORK_PER_THREAD; w++) {
216 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
217 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
218 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
219 // check if the indexes are outside the matrix
220 if (i < M) {
221 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
222 }
223 if (B_curr_j < N) {
224 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
225 }
226 }
227 barrier(CLK_LOCAL_MEM_FENCE);
228 int total_work_k = min(THREAD_BLOCK_SIZE, // NOLINT
229 K - THREAD_BLOCK_SIZE * tile_idx);
230 for (int idx = linear_index, w = 0; idx < total_work_nm;
231 idx += threads_in_block, w++) {
232 const int row_B_local = idx / total_work_m;
233 const int col_A_local = idx % total_work_m;
234 for (int idx_in_block = 0; idx_in_block < total_work_k;
235 idx_in_block++) {
236 acc[w] += A_local[idx_in_block][col_A_local]
237 * B_local[row_B_local][idx_in_block];
238 }
239 }
240 barrier(CLK_LOCAL_MEM_FENCE);
241 }
242 for (int idx = linear_index, w = 0; idx < total_work_nm;
243 idx += threads_in_block, w++) {
244 const int curr_i
245 = THREAD_BLOCK_SIZE * get_group_id(0) + idx % total_work_m;
246 const int B_curr_j
247 = THREAD_BLOCK_SIZE * get_group_id(1) + idx / total_work_m;
248 C[split_id * M * N + B_curr_j * M + curr_i] = acc[w];
249 }
250 } else { // general case that is not on the edge - all threads have work
251 for (int tile_idx = start_tile; tile_idx <= end_tile; tile_idx++) {
252 const int tiled_i = THREAD_BLOCK_SIZE * tile_idx + row_in_block;
253 const int tiled_j = THREAD_BLOCK_SIZE * tile_idx + col_in_block;
254 // each thread copies WORK_PER_THREAD values to the local
255 // memory
256 for (int w = 0; w < WORK_PER_THREAD; w++) {
257 const int A_curr_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
258 const int B_curr_j = j + w * THREAD_BLOCK_SIZE_COL;
259 const int curr_k = col_in_block + w * THREAD_BLOCK_SIZE_COL;
260 A_local[curr_k][row_in_block] = A[A_curr_j * M + i];
261 B_local[curr_k][row_in_block] = B[B_curr_j * K + tiled_i];
262 }
263 barrier(CLK_LOCAL_MEM_FENCE);
264 for (int w = 0; w < WORK_PER_THREAD; w++) {
265 for (int idx_in_block = 0; idx_in_block < THREAD_BLOCK_SIZE;
266 idx_in_block++) {
267 acc[w] += A_local[idx_in_block][row_in_block]
268 * B_local[w * THREAD_BLOCK_SIZE_COL + col_in_block]
269 [idx_in_block];
270 }
271 }
272 barrier(CLK_LOCAL_MEM_FENCE);
273 }
274 // each thread saves WORK_PER_THREAD values
275 for (int w = 0; w < WORK_PER_THREAD; w++) {
276 const int curr_j = j + w * THREAD_BLOCK_SIZE_COL;
277 C[split_id * M * N + curr_j * M + i] = acc[w];
278 }
279 }
280 }
281 // \cond
282);
283// \endcond
284
288const kernel_cl<in_buffer, in_buffer, out_buffer, int, int, int, matrix_cl_view,
290 matrix_multiply("matrix_multiply",
291 {thread_block_helpers, view_kernel_helpers,
292 matrix_multiply_kernel_code},
293 {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}});
294
295// \cond
296static constexpr const char* row_vector_matrix_multiply_kernel_code = STRINGIFY(
297 // \endcond
311 const __global double* A, const __global double* B, __global double* R,
312 const int N, const int K, unsigned int view_A, unsigned int view_B) {
313 const int lid = get_local_id(0);
314 const int gid = get_global_id(0);
315 const int wgid = get_group_id(0);
316
317 const int start = contains_nonzero(view_B, UPPER) ? 0 : wgid;
318 const int stop = contains_nonzero(view_A, UPPER)
319 ? contains_nonzero(view_B, LOWER) ? N : wgid + 1
320 : 1;
321
322 double acc = 0;
323 for (int i = lid + start; i < stop; i += LOCAL_SIZE_) {
324 acc += A[i] * B[i + wgid * N];
325 }
326
327 __local double res_loc[LOCAL_SIZE_];
328 res_loc[lid] = acc;
329 barrier(CLK_LOCAL_MEM_FENCE);
330 for (int step = LOCAL_SIZE_ / REDUCTION_STEP_SIZE; step > 0;
331 step /= REDUCTION_STEP_SIZE) {
332 if (lid < step) {
333 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
334 res_loc[lid] += res_loc[lid + step * i];
335 }
336 }
337 barrier(CLK_LOCAL_MEM_FENCE);
338 }
339 if (lid == 0) {
340 R[wgid] = res_loc[0];
341 }
342 }
343 // \cond
344);
345// \endcond
346
351const kernel_cl<in_buffer, in_buffer, out_buffer, int, int, matrix_cl_view,
353 row_vector_matrix_multiply("row_vector_matrix_multiply",
354 {view_kernel_helpers,
355 row_vector_matrix_multiply_kernel_code},
356 {{"LOCAL_SIZE_", 64},
357 {"REDUCTION_STEP_SIZE", 4}});
358
359} // namespace opencl_kernels
360} // namespace math
361} // namespace stan
362#endif
363#endif
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, matrix_cl_view, matrix_cl_view > row_vector_matrix_multiply("row_vector_matrix_multiply", {view_kernel_helpers, row_vector_matrix_multiply_kernel_code}, {{"LOCAL_SIZE_", 64}, {"REDUCTION_STEP_SIZE", 4}})
See the docs for row_vector_matrix_multiply() .
const kernel_cl< in_buffer, in_buffer, out_buffer, int, int, int, matrix_cl_view, matrix_cl_view > matrix_multiply("matrix_multiply", {thread_block_helpers, view_kernel_helpers, matrix_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for matrix_multiply() .
bool contains_nonzero(const matrix_cl_view view, const matrix_cl_view part)
Check whether a view contains certain nonzero part.
static const std::string thread_block_helpers
Defines a helper macro for kernels with 2D local size.
Definition helpers.hpp:24
auto min(T1 x, T2 y)
Returns the minimum coefficient of the two specified scalar arguments.
Definition min.hpp:24
T step(const T &y)
The step, or Heaviside, function.
Definition step.hpp:31
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
Definition max.hpp:25
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9