1#ifndef STAN_MATH_GPU_KERNELS_TRIDIAGONALIZATION_HPP
2#define STAN_MATH_GPU_KERNELS_TRIDIAGONALIZATION_HPP
10namespace opencl_kernels {
// OpenCL C source, held as a C string, for the Householder step of the
// tridiagonalization.
// NOTE(review): this listing has gaps. The initializer opening, the kernel
// header line, the declarations of acc / q / alpha, any guards and the closing
// braces are not shown, so the comments below describe only the visible lines.
13static constexpr const char* tridiagonalization_householder_kernel_code
// Visible parameters: three global double buffers (P, V, q_glob) followed by
// four ints. P_rows and V_rows are used below as strides when indexing P and V;
// j and k enter every offset computation. q_glob is not referenced on any
// visible line.
28 __global
double* P, __global
double* V, __global
double* q_glob,
29 const int P_rows,
const int V_rows,
const int j,
const int k) {
// Work-item / work-group identifiers along dimension 0.
30 const int lid = get_local_id(0);
31 const int gid = get_global_id(0);
32 const int gsize = get_global_size(0);
33 const int lsize = get_local_size(0);
34 const int ngroups = get_num_groups(0);
35 const int wgid = get_group_id(0);
// P_start is the linear index P_rows * (k + j) + (k + j); P_span is the number
// of elements from P_start up to (not including) index P_rows * (k + j + 1).
// Presumably this is the diagonal entry of column k + j and everything below
// it in column-major storage - TODO confirm the storage order.
39 const int P_start = P_rows * (k + j) + k + j;
40 const int P_span = P_rows * (k + j + 1) - P_start;
// Each work-item visits offsets lid, lid + lsize, ... below P_span.
41 for (
int i = lid; i < P_span; i += lsize) {
// For every l in [0, j), add a pair of P * V products to acc.
45 for (
int l = 0; l < j; l++) {
46 acc += P[P_rows * (k + l) + k + j + i] * V[V_rows * l + j - 1]
47 + V[V_rows * l + j - 1 + i] * P[P_rows * (k + l) + k + j];
// The accumulated correction is subtracted from the current P entry.
49 double tmp = P[P_start + i] - acc;
// Per-work-group scratch array with one slot per LOCAL_SIZE_ entry.
56 __local
double q_local[LOCAL_SIZE_];
58 barrier(CLK_LOCAL_MEM_FENCE);
// Reduction over q_local: step starts at lsize / REDUCTION_STEP_SIZE and is
// divided by REDUCTION_STEP_SIZE on every pass; in a pass q_local[lid] absorbs
// the entries at offsets step, 2 * step, ..., (REDUCTION_STEP_SIZE - 1) * step.
// NOTE(review): any guard limiting which work-items take part is not shown.
59 for (
int step = lsize / REDUCTION_STEP_SIZE;
step > 0;
60 step /= REDUCTION_STEP_SIZE) {
62 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
63 q_local[lid] += q_local[lid +
step * i];
// Presumably executed at the end of each reduction pass - TODO confirm.
66 barrier(CLK_LOCAL_MEM_FENCE);
72 double p1 = P[P_start + 1];
// alpha has magnitude sqrt(q) and the sign opposite to that of P[P_start].
74 alpha = -copysign(
sqrt(q), P[P_start]);
84 barrier(CLK_LOCAL_MEM_FENCE);
// Scale factor applied to the entries after P_start. NOTE(review): q may be
// reassigned on lines that are not shown between here and its use above.
88 double multi = M_SQRT2 / q;
// Starts at lid + 1, so the entry at P_start itself is not scaled by this loop.
90 for (
int i = lid + 1; i < P_span; i += lsize) {
91 P[P_start + i] *= multi;
// Writes index P_rows * (k + j + 1) + (k + j) from the entry at
// P_rows * (k + j) + (k + j + 1), multiplied by q / M_SQRT2 and shifted by alpha.
95 P[P_rows * (k + j + 1) + k + j]
96 = P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha;
// OpenCL C source, held as a C string, for the first of three "v" steps of the
// tridiagonalization.
// NOTE(review): this listing has gaps. The kernel header line, the declarations
// and updates of acc1 / acc2, the loop increment and body, any guards and the
// closing braces are not shown; comments cover only the visible lines.
102static constexpr const char* tridiagonalization_v_step_1_kernel_code
// Visible parameters: P and V are read-only global buffers, Uu and Vu are
// writable global buffers; P_rows and V_rows are used as strides below.
120 const __global
double* P,
const __global
double* V,
121 __global
double* Uu, __global
double* Vu,
const int P_rows,
122 const int V_rows,
const int k) {
// Work-item / work-group identifiers along dimension 0.
123 const int lid = get_local_id(0);
124 const int gid = get_global_id(0);
125 const int gsize = get_global_size(0);
126 const int lsize = get_local_size(0);
127 const int ngroups = get_num_groups(0);
128 const int wgid = get_group_id(0);
// Two per-work-group scratch arrays, one per accumulated quantity.
130 __local
double res_loc1[LOCAL_SIZE_];
131 __local
double res_loc2[LOCAL_SIZE_];
// The number of work-groups (ngroups) is used as an index offset here, and the
// work-group id (wgid) selects which column of P and of V this group reads.
// NOTE(review): presumably the launch uses one work-group per previously
// processed column - confirm against the caller.
135 const __global
double* vec
136 = P + P_rows * (k + ngroups) + k + ngroups + 1;
137 const __global
double* M1 = P + P_rows * (k + wgid) + k + ngroups + 1;
138 const __global
double* M2 = V + V_rows * wgid + ngroups;
// Loop over offsets below P_rows - k - ngroups - 1, starting at lid; the
// increment and body are on lines not shown.
139 for (
int i = lid; i < P_rows - k - ngroups - 1;
// Each work-item publishes its partial sums to local memory.
146 res_loc1[lid] = acc1;
147 res_loc2[lid] = acc2;
148 barrier(CLK_LOCAL_MEM_FENCE);
// Reduction over both scratch arrays in lock step: step starts at
// lsize / REDUCTION_STEP_SIZE and is divided by REDUCTION_STEP_SIZE per pass;
// slot lid absorbs the entries at offsets step, 2 * step, ....
// NOTE(review): any guard limiting which work-items take part is not shown.
150 for (
int step = lsize / REDUCTION_STEP_SIZE;
step > 0;
151 step /= REDUCTION_STEP_SIZE) {
153 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
154 res_loc1[lid] += res_loc1[lid +
step * i];
155 res_loc2[lid] += res_loc2[lid +
step * i];
// Presumably executed at the end of each reduction pass - TODO confirm.
158 barrier(CLK_LOCAL_MEM_FENCE);
// Slot 0 of each scratch array is stored at this work-group's index.
161 Uu[wgid] = res_loc1[0];
162 Vu[wgid] = res_loc2[0];
// OpenCL C source, held as a C string, for the second of three "v" steps of
// the tridiagonalization.
// NOTE(review): this listing has gaps. The kernel header line, the declarations
// of acc / i / work_per_group, any guards, the store into res_loc and the
// closing braces are not shown; comments cover only the visible lines.
168static constexpr const char* tridiagonalization_v_step_2_kernel_code
// Visible parameters: P, Uu and Vu are read-only global buffers and V is a
// writable global buffer; P_rows and V_rows are used as strides below.
188 const __global
double* P, __global
double* V,
189 const __global
double* Uu,
const __global
double* Vu,
190 const int P_rows,
const int V_rows,
const int k,
const int j) {
// Work-item / work-group identifiers along dimension 0.
191 const int lid = get_local_id(0);
192 const int gid = get_global_id(0);
193 const int gsize = get_global_size(0);
194 const int lsize = get_local_size(0);
195 const int ngroups = get_num_groups(0);
196 const int wgid = get_group_id(0);
// Problem size for this step; it bounds the inner loop further down.
198 int work = P_rows - k - j - 1;
// Base pointers: vec and M1 start one row below the (k + j) diagonal position
// in columns k + j and k + j + 1 of P; M2 starts at the same row in column k
// of P; M3 starts at offset j in V. (Column interpretation assumes column-major
// storage with stride P_rows / V_rows - TODO confirm.)
201 const __global
double* vec = P + P_rows * (k + j) + k + j + 1;
202 const __global
double* M1 = P + P_rows * (k + j + 1) + k + j + 1;
203 const __global
double* M2 = P + P_rows * k + k + j + 1;
204 const __global
double* M3 = V + j;
// First part: work-item gid accumulates M1[P_rows * i + gid] * vec[i] for
// i = 0 .. gid inclusive.
207 for (i = 0; i <= gid; i++) {
208 acc += M1[P_rows * i + gid] * vec[i];
// Then subtracts, for each i in [0, j), the M2 * Vu and M3 * Uu contributions.
210 for (
int i = 0; i < j; i++) {
211 acc -= M2[P_rows * i + gid] * Vu[i];
212 acc -= M3[V_rows * i + gid] * Uu[i];
// Result for this work-item is stored at offset gid + j with stride index j.
214 V[V_rows * j + gid + j] = acc;
// Second part: the range [0, work) is divided between work-groups using a
// float quotient; start and end are its products with wgid and wgid + 1,
// converted to int on assignment. The declarator for this initializer is on a
// line not shown.
217 = (float)work / ngroups;
218 int start = work_per_group * wgid;
219 int end = work_per_group * (wgid + 1);
220 __local
double res_loc[LOCAL_SIZE_];
// For each i assigned to this work-group ...
221 for (
int i = start; i < end; i += 1) {
// ... work-items stride by LOCAL_SIZE_ over l = i + 1 + lid, ... < work and
// accumulate M1[P_rows * i + l] * vec[l].
223 for (
int l = i + 1 + lid; l < work; l += LOCAL_SIZE_) {
224 acc += M1[P_rows * i + l] * vec[l];
227 barrier(CLK_LOCAL_MEM_FENCE);
// Reduction over res_loc: step starts at lsize / REDUCTION_STEP_SIZE and is
// divided by REDUCTION_STEP_SIZE per pass; slot lid absorbs the entries at
// offsets step, 2 * step, .... NOTE(review): any guard limiting which
// work-items take part is not shown.
228 for (
int step = lsize / REDUCTION_STEP_SIZE;
step > 0;
229 step /= REDUCTION_STEP_SIZE) {
231 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
232 res_loc[lid] += res_loc[lid +
step * i];
// Presumably executed at the end of each reduction pass - TODO confirm.
235 barrier(CLK_LOCAL_MEM_FENCE);
// Stores res_loc[lid] at offset i + j with stride index j + 1.
// NOTE(review): presumably guarded so only one work-item writes - confirm.
238 V[V_rows * (j + 1) + i + j] = res_loc[lid];
240 barrier(CLK_LOCAL_MEM_FENCE);
// OpenCL C source, held as a C string, for the third of three "v" steps of the
// tridiagonalization.
// NOTE(review): this listing has gaps. The kernel header line, the declaration
// of acc, the loop bodies, the store into res_loc, any guards and the closing
// braces are not shown; comments cover only the visible lines.
246static constexpr const char* tridiagonalization_v_step_3_kernel_code
// Visible parameters: three global double buffers (P, V, q) followed by four
// ints; P_rows and V_rows are used as strides below. q is dereferenced (read)
// on the last visible line.
261 __global
double* P, __global
double* V, __global
double* q,
262 const int P_rows,
const int V_rows,
const int k,
const int j) {
// Work-item / work-group identifiers along dimension 0.
263 const int lid = get_local_id(0);
264 const int gid = get_global_id(0);
265 const int gsize = get_global_size(0);
266 const int lsize = get_local_size(0);
267 const int ngroups = get_num_groups(0);
268 const int wgid = get_group_id(0);
// u starts one element after index P_rows * (k + j) + (k + j) in P;
// v starts at index V_rows * j + j in V.
270 __global
double* u = P + P_rows * (k + j) + k + j + 1;
271 __global
double* v = V + V_rows * j + j;
// Work-items stride by LOCAL_SIZE_ over i < P_rows - k - j - 1.
274 for (
int i = lid; i < P_rows - k - j - 1; i += LOCAL_SIZE_) {
// Sum of the entry at offset i and the entry V_rows elements further on.
275 double vi = v[i] + v[i + V_rows];
279 __local
double res_loc[LOCAL_SIZE_];
281 barrier(CLK_LOCAL_MEM_FENCE);
// Reduction over res_loc: step starts at lsize / REDUCTION_STEP_SIZE and is
// divided by REDUCTION_STEP_SIZE per pass; slot lid absorbs the entries at
// offsets step, 2 * step, .... NOTE(review): any guard limiting which
// work-items take part is not shown.
282 for (
int step = lsize / REDUCTION_STEP_SIZE;
step > 0;
283 step /= REDUCTION_STEP_SIZE) {
285 for (
int i = 1; i < REDUCTION_STEP_SIZE; i++) {
286 res_loc[lid] += res_loc[lid +
step * i];
// Presumably executed at the end of each reduction pass - TODO confirm.
289 barrier(CLK_LOCAL_MEM_FENCE);
// Half of the reduced total is kept in acc for the second sweep.
291 acc = res_loc[0] * 0.5;
// Second sweep over the same index range; its body is on lines not shown.
292 for (
int i = lid; i < P_rows - k - j - 1; i += LOCAL_SIZE_) {
// Subtracts (*q / M_SQRT2) * u[0] from the entry at
// P_rows * (k + j + 1) + (k + j).
296 P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0];
// kernel_cl object built from tridiagonalization_householder_kernel_code with
// the name/value pairs REDUCTION_STEP_SIZE = 4 and LOCAL_SIZE_ = 1024.
// The template arguments (three buffer tags, then four ints) follow the order
// of the kernel's three pointer parameters and four int parameters.
// NOTE(review): the line carrying the object's name is not shown in this
// listing.
301const kernel_cl<in_out_buffer, in_out_buffer, out_buffer, int, int, int, int>
303 {tridiagonalization_householder_kernel_code},
304 {{
"REDUCTION_STEP_SIZE", 4},
305 {
"LOCAL_SIZE_", 1024}});
// kernel_cl object built from tridiagonalization_v_step_1_kernel_code with the
// name/value pairs REDUCTION_STEP_SIZE = 4 and LOCAL_SIZE_ = 64.
// The template arguments (two in_buffer, two out_buffer, then three ints)
// follow the order of the kernel's parameters: P, V, Uu, Vu, then three ints.
// NOTE(review): the line carrying the object's name is not shown in this
// listing.
307const kernel_cl<in_buffer, in_buffer, out_buffer, out_buffer, int, int, int>
309 {tridiagonalization_v_step_1_kernel_code},
310 {{
"REDUCTION_STEP_SIZE", 4},
311 {
"LOCAL_SIZE_", 64}});
// kernel_cl object built from tridiagonalization_v_step_2_kernel_code with the
// name/value pairs REDUCTION_STEP_SIZE = 4 and LOCAL_SIZE_ = 64.
// The template arguments (in_buffer, out_buffer, in_buffer, in_buffer, then
// four ints) follow the order of the kernel's parameters: P, V, Uu, Vu, then
// four ints.
// NOTE(review): the line carrying the object's name is not shown in this
// listing.
313const kernel_cl<in_buffer, out_buffer, in_buffer, in_buffer, int, int, int, int>
315 {tridiagonalization_v_step_2_kernel_code},
316 {{
"REDUCTION_STEP_SIZE", 4},
317 {
"LOCAL_SIZE_", 64}});
// kernel_cl object built from tridiagonalization_v_step_3_kernel_code with the
// name/value pairs REDUCTION_STEP_SIZE = 4 and LOCAL_SIZE_ = 1024.
// The template arguments (two in_out_buffer, one out_buffer, then four ints)
// follow the order of the kernel's parameters: P, V, q, then four ints.
// NOTE(review): the third buffer (q) is tagged out_buffer, yet the visible
// kernel code reads *q; confirm the tag is intended.
// NOTE(review): the line carrying the object's name is not shown in this
// listing.
319const kernel_cl<in_out_buffer, in_out_buffer, out_buffer, int, int, int, int>
321 {tridiagonalization_v_step_3_kernel_code},
322 {{
"REDUCTION_STEP_SIZE", 4},
323 {
"LOCAL_SIZE_", 1024}});
// One-line forms of the four kernel_cl objects. Each pairs a kernel name string
// with its source string and the name/value pairs REDUCTION_STEP_SIZE and
// LOCAL_SIZE_. Template arguments and option values match the multi-line
// definitions earlier in this listing.
// NOTE(review): these lines carry no terminating semicolons and look like
// documentation-tool copies rather than extra definitions - confirm before
// treating them as code.
// v_step_1: LOCAL_SIZE_ = 64.
const kernel_cl< in_buffer, in_buffer, out_buffer, out_buffer, int, int, int > tridiagonalization_v_step_1("tridiagonalization_v_step_1", {tridiagonalization_v_step_1_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
// v_step_2: LOCAL_SIZE_ = 64.
const kernel_cl< in_buffer, out_buffer, in_buffer, in_buffer, int, int, int, int > tridiagonalization_v_step_2("tridiagonalization_v_step_2", {tridiagonalization_v_step_2_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
// v_step_3: LOCAL_SIZE_ = 1024.
const kernel_cl< in_out_buffer, in_out_buffer, out_buffer, int, int, int, int > tridiagonalization_v_step_3("tridiagonalization_v_step_3", {tridiagonalization_v_step_3_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 1024}})
// householder: LOCAL_SIZE_ = 1024.
const kernel_cl< in_out_buffer, in_out_buffer, out_buffer, int, int, int, int > tridiagonalization_householder("tridiagonalization_householder", {tridiagonalization_householder_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 1024}})
T step(const T &y)
The step, or Heaviside, function.
fvar< T > sqrt(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...