Automatic Differentiation
 
Loading...
Searching...
No Matches
inv_lower_tri_multiply.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP
2#define STAN_MATH_OPENCL_KERNELS_INVERSE_LOWER_TRI_MULTIPLY_HPP
3#ifdef STAN_OPENCL
4
7#include <string>
8
9namespace stan {
10namespace math {
11namespace opencl_kernels {
// \cond
static constexpr const char* inv_lower_tri_multiply_kernel_code = STRINGIFY(
    // \endcond
    /**
     * Batched block product, presumably one step of a blocked
     * lower-triangular inverse (per the kernel name; confirm with caller).
     * For every batch it computes temp_b = C2 * A3, where C2 and A3 are
     * rows x rows sub-blocks of A.
     *
     * For batch b = get_global_id(2), let offset = b * rows * 2. Then
     *  - C2 is the block of A whose top-left corner is at
     *    (offset + rows, offset + rows). Only entries on or below its
     *    diagonal are read; entries above the diagonal are treated as 0.
     *  - A3 is the block of A whose top-left corner is at
     *    (offset + rows, offset).
     * A is indexed as A[col * A_rows + row] (column major, leading
     * dimension A_rows); entries with row or column >= A_rows are read
     * as 0. The result for batch b is written column major into temp at
     * b * rows * rows + col * rows + row.
     *
     * Each work-item computes WORK_PER_THREAD entries of a single result
     * row, in columns spaced THREAD_BLOCK_SIZE_COL apart, accumulating
     * over THREAD_BLOCK_SIZE x THREAD_BLOCK_SIZE tiles of C2 and A3 that
     * the work-group stages in local memory.
     *
     * NOTE(review): the final store into temp is not bounds checked, so
     * this presumably relies on the caller making rows a multiple of
     * THREAD_BLOCK_SIZE and sizing the NDRange to match; confirm.
     *
     * @param[in] A input matrix holding the C2 and A3 blocks
     * @param[out] temp output buffer receiving one rows x rows product
     *   per batch
     * @param A_rows row/column bound and leading dimension of A
     * @param rows size of each square sub-block
     */
    __kernel void inv_lower_tri_multiply(__global double* A,
                                         __global double* temp,
                                         const int A_rows, const int rows) {
      // Batch index along the third NDRange dimension.
      int result_matrix_id = get_global_id(2);
      // Offset of this batch along the diagonal of A: batches are
      // 2 * rows apart.
      int offset = result_matrix_id * rows * 2;
      // Position of this work-item within its tile and within the result.
      const int thread_block_row = get_local_id(0);
      const int thread_block_col = get_local_id(1);
      const int global_thread_row
          = THREAD_BLOCK_SIZE * get_group_id(0) + thread_block_row;
      const int global_thread_col
          = THREAD_BLOCK_SIZE * get_group_id(1) + thread_block_col;

      // Work-group tiles of C2 and A3, indexed as [col][row].
      __local double C2_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];
      __local double A3_local[THREAD_BLOCK_SIZE][THREAD_BLOCK_SIZE];

      // Running sums for the WORK_PER_THREAD outputs of this work-item.
      double acc[WORK_PER_THREAD] = {0};

      // Walk the shared (inner) dimension of length rows one tile at a time.
      const int num_tiles = (rows + THREAD_BLOCK_SIZE - 1) / THREAD_BLOCK_SIZE;
      for (int tile_ind = 0; tile_ind < num_tiles; tile_ind++) {
        // Each thread copies WORK_PER_THREAD values to the local
        // memory
        for (int w = 0; w < WORK_PER_THREAD; w++) {
          const int tiled_i = THREAD_BLOCK_SIZE * tile_ind + thread_block_row;
          const int tiled_j = THREAD_BLOCK_SIZE * tile_ind + thread_block_col;
          // {C2}{A3}_global_{col}{row} specifies which global element of each
          // matrix the thread is in charge of moving to local memory.
          const int C2_global_col
              = offset + rows + tiled_j + w * THREAD_BLOCK_SIZE_COL;
          const int C2_global_row = offset + global_thread_row + rows;
          const int A3_global_col
              = offset + global_thread_col + w * THREAD_BLOCK_SIZE_COL;
          const int A3_global_row = tiled_i + rows + offset;
          // Which {col}{row} location in the local memory the thread is in
          // charge of.
          const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
          const int local_row = thread_block_row;
          // Element above the diagonal will not be transferred: those
          // entries, and any entries outside A, are stored as zero instead.
          if (C2_global_col <= C2_global_row && C2_global_col < A_rows
              && C2_global_row < A_rows) {
            C2_local[local_col][local_row]
                = A[C2_global_col * A_rows + C2_global_row];
          } else {
            C2_local[local_col][local_row] = 0;
          }
          // Zero-pad A3 entries that fall outside A.
          if (A3_global_col < A_rows && A3_global_row < A_rows) {
            A3_local[local_col][local_row]
                = A[A3_global_col * A_rows + A3_global_row];
          } else {
            A3_local[local_col][local_row] = 0.0;
          }
        }
        // Wait until all tile values are loaded to the local memory
        barrier(CLK_LOCAL_MEM_FENCE);
        // acc[w] += C2(row, k) * A3(k, col) for every k in this tile.
        for (int block_ind = 0; block_ind < THREAD_BLOCK_SIZE; block_ind++) {
          for (int w = 0; w < WORK_PER_THREAD; w++) {
            const int local_col = thread_block_col + w * THREAD_BLOCK_SIZE_COL;
            const int local_row = thread_block_row;
            acc[w] += C2_local[block_ind][local_row]
                      * A3_local[local_col][block_ind];
          }
        }
        // Keep the tiles intact until every work-item in the group has
        // finished reading them; the next iteration overwrites them.
        barrier(CLK_LOCAL_MEM_FENCE);
      }
      // Global offset for each resulting submatrix
      const int batch_offset = result_matrix_id * rows * rows;
      // temp_global_{row}{col} is the element of this batch result in temp
      // (column major, leading dimension rows) that each accumulator is
      // written to.
      const int temp_global_row = global_thread_row;
      // save the values
      for (int w = 0; w < WORK_PER_THREAD; w++) {
        // each thread saves WORK_PER_THREAD values
        const int temp_global_col
            = global_thread_col + w * THREAD_BLOCK_SIZE_COL;
        // NOTE(review): unguarded store; see the kernel doc comment.
        temp[batch_offset + temp_global_col * rows + temp_global_row] = acc[w];
      }
    }
    // \cond
);
// \endcond
125
130 "inv_lower_tri_multiply",
131 {thread_block_helpers, inv_lower_tri_multiply_kernel_code},
132 {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}});
133
134} // namespace opencl_kernels
135} // namespace math
136} // namespace stan
137#endif
138#endif
const kernel_cl< in_buffer, out_buffer, int, int > inv_lower_tri_multiply("inv_lower_tri_multiply", {thread_block_helpers, inv_lower_tri_multiply_kernel_code}, {{"THREAD_BLOCK_SIZE", 32}, {"WORK_PER_THREAD", 8}})
See the docs for inv_lower_tri_multiply().
int64_t rows(const T_x &x)
Returns the number of rows in the specified kernel generator expression.
Definition rows.hpp:22
static const std::string thread_block_helpers
Defines a helper macro for kernels with 2D local size.
Definition helpers.hpp:24
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9
Creates a functor for kernels.