Automatic Differentiation
 
Loading...
Searching...
No Matches
multiply_transpose.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP
2#define STAN_MATH_OPENCL_KERNELS_MULTIPLY_TRANSPOSE_HPP
3#ifdef STAN_OPENCL
4
7#include <string>
8
9namespace stan {
10namespace math {
11namespace opencl_kernels {
12// \cond
13static constexpr const char* multiply_transpose_kernel_code = STRINGIFY(
14 // \endcond
24 __kernel void multiply_transpose(const __global double* A,
25 __global double* B, const int M,
26 const int N) {
27 // thread index inside the thread block
28 const int thread_block_row = get_local_id(0);
29 const int thread_block_col = get_local_id(1);
30
31 // global thread index
32 const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
33 const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
34
35 // indexes that determine the last indexes that need to compute
36 // in order to remove the unnecessary multiplications in the special
37 // multiplication of A*A^T
38 const int j_min = THREAD_BLOCK_SIZE * get_group_id(1);
39 const int i_max = THREAD_BLOCK_SIZE * get_group_id(0) + get_local_size(0);
40
41 // local memory
42 __local double A_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
43 __local double B_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
44
45 double acc[WORK_PER_THREAD];
46 for (int w = 0; w < WORK_PER_THREAD; w++) {
47 acc[w] = 0.0;
48 }
49 if (j_min <= i_max) {
50 const int num_tiles = (N + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
51 // iterate over all tiles
52 for (int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
53 // in each tile
54 const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
55 const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
56 // if the data needs to be loaded to local memory
57 // each thread copies WORK_PER_THREAD values to the
58 // local memory
59 for (int w = 0; w < WORK_PER_THREAD; w++) {
60 const int A_temp_j = tiled_j + w * THREAD_BLOCK_SIZE_COL;
61 const int AT_temp_j = j + w * THREAD_BLOCK_SIZE_COL;
62 if (A_temp_j >= N || i >= M) {
63 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
64 [thread_block_row]
65 = 0.0;
66 } else {
67 A_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
68 [thread_block_row]
69 = A[A_temp_j * M + i];
70 }
71 if (AT_temp_j >= M || tiled_i >= N) {
72 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
73 [thread_block_row]
74 = 0.0;
75 } else {
76 B_local[thread_block_col + w * THREAD_BLOCK_SIZE_COL]
77 [thread_block_row]
78 = A[AT_temp_j + tiled_i * M];
79 }
80 }
81 // wait till all tile values are loaded to the local memory
82 barrier(CLK_LOCAL_MEM_FENCE);
83 // multiply the tile products
84 for (int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
85 // each thread multiplies WORK_PER_THREAD values
86 for (int w = 0; w < WORK_PER_THREAD; w++) {
87 if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
88 acc[w] += A_local[block_ind][thread_block_row]
89 * B_local[thread_block_col
90 + w * THREAD_BLOCK_SIZE_COL][block_ind];
91 }
92 }
93 }
94 barrier(CLK_LOCAL_MEM_FENCE);
95 }
96 // each thread saves WORK_PER_THREAD values to C
97 for (int w = 0; w < WORK_PER_THREAD; w++) {
98 // This prevents threads from accessing elements
99 // outside the allocated memory for C. The check
100 // is in the loop because some threads
101 // can be assigned elements in and out of
102 // the allocated memory.
103 if ((j + w * THREAD_BLOCK_SIZE_COL) < M && i < M) {
104 if ((j + w * THREAD_BLOCK_SIZE_COL) <= i) {
105 B[i + (j + w * THREAD_BLOCK_SIZE_COL) * M] = acc[w];
106 B[(j + w * THREAD_BLOCK_SIZE_COL) + i * M] = acc[w];
107 }
108 }
109 }
110 }
111 }
112 // \cond
113);
114// \endcond
115
120 "multiply_transpose",
121 {thread_block_helpers, multiply_transpose_kernel_code},
122 {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 4}});
123
124} // namespace opencl_kernels
125} // namespace math
126} // namespace stan
127#endif
128#endif
const kernel_cl< in_buffer, out_buffer, int, int > multiply_transpose("multiply_transpose", {thread_block_helpers, multiply_transpose_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 4}})
See the docs for multiply_transpose().
static const std::string thread_block_helpers
Defines a helper macro for kernels with 2D local size.
Definition helpers.hpp:24
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9
Creates functor for kernels.