Automatic Differentiation
 
Loading...
Searching...
No Matches
neg_rect_lower_tri_multiply.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_NEGATIVE_RECT_LOWER_TRI_MULTIPLY_HPP
2#define STAN_MATH_OPENCL_KERNELS_NEGATIVE_RECT_LOWER_TRI_MULTIPLY_HPP
3#ifdef STAN_OPENCL
4
7#include <string>
8
9namespace stan {
10namespace math {
11namespace opencl_kernels {
12// \cond
13static constexpr const char* neg_rect_lower_tri_multiply_kernel_code
14 = STRINGIFY(
15 // \endcond
44 __global double* A, const __global double* temp, const int A_rows,
45 const int rows) {
46 int result_matrix_id = get_global_id(2);
47 int offset = result_matrix_id * rows * 2;
48 const int thread_block_row = get_local_id(0);
49 const int thread_block_col = get_local_id(1);
50 const int i = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
51 const int j = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;
52
53 __local double temp_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
54 __local double C1_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
55
56 double acc[WORK_PER_THREAD] = {0};
57
58 const int num_tiles
59 = (rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
60 for (int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
61 // each thread copies WORK_PER_THREAD values to the local
62 // memory
63 for (int w = 0; w < WORK_PER_THREAD; w++) {
64 const int tiled_i
65 = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
66 const int tiled_j
67 = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
68 const int temp_global_col = tiled_j + w * THREAD_BLOCK_SIZE_COL;
69 // {C2}{A2}_global_{col}{row} specifies which global element for
70 // each matrix the thread is in charge of moving to local memory.
71 const int C1_global_col = offset + j + w * THREAD_BLOCK_SIZE_COL;
72 const int C1_global_row = tiled_i + offset;
73 // Which {col}{row} location in the local memory the thread is in
74 // charge of.
75 const int local_col
76 = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
77 const int local_row = thread_block_row;
78 if ((temp_global_col) < rows && i < rows) {
79 temp_local[local_col][local_row]
80 = temp[result_matrix_id * rows * rows
81 + temp_global_col * rows + i];
82 } else {
83 temp_local[local_col][local_row] = 0.0;
84 }
85 // Element above the diagonal will not be transferred.
86 if (C1_global_col <= C1_global_row && C1_global_col < A_rows
87 && C1_global_row < A_rows) {
88 C1_local[local_col][local_row]
89 = A[C1_global_col * A_rows + C1_global_row];
90 } else {
91 C1_local[local_col][local_row] = 0;
92 }
93 }
94 // wait until all tile values are loaded to the local memory
95 barrier(CLK_LOCAL_MEM_FENCE);
96 for (int block_ind = 0; block_ind < THREAD_BLOCK_SIZE;
97 block_ind++) {
98 for (int w = 0; w < WORK_PER_THREAD; w++) {
99 // Which {col}{row} location in the local memory the thread is
100 // in
101 // charge of.
102 const int local_col
103 = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
104 const int local_row = thread_block_row;
105 acc[w] += temp_local[block_ind][local_row]
106 * C1_local[local_col][block_ind];
107 }
108 }
109 barrier(CLK_LOCAL_MEM_FENCE);
110 }
111 // A_global_{row}{col} tells the thread which local memory it needs
112 // to move to the final output
113 const int A_global_row = i + rows + offset;
114 const int A_global_col_offset = offset + j;
115 // each thread saves WORK_PER_THREAD values
116 for (int w = 0; w < WORK_PER_THREAD; w++) {
117 const int A_global_col
118 = A_global_col_offset + w * THREAD_BLOCK_SIZE_COL;
119 if (A_global_col < A_rows && (i + rows + offset) < A_rows) {
120 A[A_global_col * A_rows + i + rows + offset] = -acc[w];
121 }
122 }
123 }
124 // \cond
125 ); // NOLINT(whitespace/parens)
126// \endcond
127
134 "neg_rect_lower_tri_multiply",
135 {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code},
136 {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}});
137} // namespace opencl_kernels
138} // namespace math
139} // namespace stan
140#endif
141#endif
const kernel_cl< in_out_buffer, in_buffer, int, int > neg_rect_lower_tri_multiply("neg_rect_lower_tri_multiply", {thread_block_helpers, neg_rect_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for neg_rect_lower_tri_multiply().
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:22
static const std::string thread_block_helpers
Defines a helper macro for kernels with 2D local size.
Definition helpers.hpp:24
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9
Creates a functor for kernels.