Automatic Differentiation
 
Loading...
Searching...
No Matches
tridiagonalization.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_GPU_KERNELS_TRIDIAGONALIZATION_HPP
2#define STAN_MATH_GPU_KERNELS_TRIDIAGONALIZATION_HPP
3
4#ifdef STAN_OPENCL
5
7
8namespace stan {
9namespace math {
10namespace opencl_kernels {
11
12// \cond
// Source of the OpenCL kernel built as "tridiagonalization_householder"
// (registered at the bottom of this file with REDUCTION_STEP_SIZE = 4 and
// LOCAL_SIZE_ = 1024). For in-block step j of the block starting at k it:
//   1. subtracts the contributions of the block's previous j reflections
//      (held in P and V) from the P_span entries starting at P_start,
//   2. builds a Householder vector from entries 1 .. P_span - 1 of that span,
//   3. rescales it to norm sqrt(2) when its norm q is nonzero,
//   4. writes q to *q_glob and a value derived from alpha back into P.
// The indexing reads as column-major storage with leading dimension P_rows
// (P_start is then the diagonal entry of column k + j) - assumption, confirm.
// NOTE(review): original lines 16-27 are missing from this listing; they
// presumably hold the kernel doc comment and the "__kernel void ...(" line.
// NOTE(review): all work distribution uses lid/lsize only, so the kernel
// presumably runs as a single work-group; the host launch is not visible.
13static constexpr const char* tridiagonalization_householder_kernel_code
14 = STRINGIFY(
15 // \endcond
28 __global double* P, __global double* V, __global double* q_glob,
29 const int P_rows, const int V_rows, const int j, const int k) {
30 const int lid = get_local_id(0);
31 const int gid = get_global_id(0);
32 const int gsize = get_global_size(0);
33 const int lsize = get_local_size(0);
34 const int ngroups = get_num_groups(0);
35 const int wgid = get_group_id(0);
// gsize, ngroups and wgid are not used below.
36
// Partial sum of squares over this work-item entries, skipping entry 0.
37 double q = 0;
38
// P_start: offset of entry k + j of column k + j; P_span: number of entries
// from there to the end of that column (under the column-major reading).
39 const int P_start = P_rows * (k + j) + k + j;
40 const int P_span = P_rows * (k + j + 1) - P_start;
41 for (int i = lid; i < P_span; i += lsize) {
42 double acc = 0;
43 // apply previous householder reflections from current block to the
44 // column we are making the Householder vector from
45 for (int l = 0; l < j; l++) {
46 acc += P[P_rows * (k + l) + k + j + i] * V[V_rows * l + j - 1]
47 + V[V_rows * l + j - 1 + i] * P[P_rows * (k + l) + k + j];
48 }
49 double tmp = P[P_start + i] - acc;
50 P[P_start + i] = tmp;
51 if (i != 0) {
52 q += tmp * tmp;
53 }
54 }
55 // calculate column norm between threads
// Tree reduction with fan-in REDUCTION_STEP_SIZE. Every slot is summed only
// when lsize is a power of REDUCTION_STEP_SIZE (1024 = 4^5 as registered),
// and q_local[lid] requires lsize <= LOCAL_SIZE_.
56 __local double q_local[LOCAL_SIZE_];
57 q_local[lid] = q;
58 barrier(CLK_LOCAL_MEM_FENCE);
59 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
60 step /= REDUCTION_STEP_SIZE) {
61 if (lid < step) {
62 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
63 q_local[lid] += q_local[lid + step * i];
64 }
65 }
66 barrier(CLK_LOCAL_MEM_FENCE);
67 }
68
// Work-item 0 finishes the Householder vector, then broadcasts q and alpha to
// the rest of the work-group through q_local[0] and q_local[1].
69 double alpha;
70 if (lid == 0) {
71 q = q_local[0];
// NOTE(review): when lsize > 1, P[P_start + 1] was stored by work-item 1 in
// the first loop, and only CLK_LOCAL_MEM_FENCE barriers lie in between.
// OpenCL orders global-memory accesses across work-items only with
// CLK_GLOBAL_MEM_FENCE, so these barriers may need
// CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE - confirm.
72 double p1 = P[P_start + 1];
73 // make Householder vector
// alpha gets the opposite sign of P[P_start]. q is then adjusted from the
// squared norm of entries 1.. to that of the vector with p1 replaced by
// p1 - alpha, and finally square-rooted.
74 alpha = -copysign(sqrt(q), P[P_start]);
75 q -= p1 * p1;
76 p1 -= alpha;
77 P[P_start + 1] = p1;
78 q += p1 * p1;
79 q = sqrt(q);
80 q_local[0] = q;
81 q_local[1] = alpha;
82 *q_glob = q;
83 }
84 barrier(CLK_LOCAL_MEM_FENCE);
85 q = q_local[0];
86 alpha = q_local[1];
87 if (q != 0) {
88 double multi = M_SQRT2 / q;
89 // normalize the Householder vector
// NOTE(review): work-item lid rescales entries lid + 1, lid + 1 + lsize, ...,
// which the first loop stored from other work-items; the same global-fence
// caveat as for P[P_start + 1] applies.
90 for (int i = lid + 1; i < P_span; i += lsize) {
91 P[P_start + i] *= multi;
92 }
93 }
// gid 0 last wrote P[P_start + 1] itself (in the lid == 0 branch and, when
// q != 0, in the rescaling loop), so it reads its own write here.
94 if (gid == 0) {
95 P[P_rows * (k + j + 1) + k + j]
96 = P[P_rows * (k + j) + k + j + 1] * q / M_SQRT2 + alpha;
97 }
98 }); // \cond
99// \endcond
100
101// \cond
// Source of the OpenCL kernel built as "tridiagonalization_v_step_1"
// (registered at the bottom of this file with REDUCTION_STEP_SIZE = 4 and
// LOCAL_SIZE_ = 64). Work-group wgid computes two dot products of length
// P_rows - k - ngroups - 1 against vec, the strip of P starting at offset
// P_rows * (k + ngroups) + k + ngroups + 1:
//   Uu[wgid] = dot(vec, P strip starting at P_rows * (k + wgid) + k + ngroups + 1)
//   Vu[wgid] = dot(vec, V strip starting at V_rows * wgid + ngroups)
// The work-group count is used where the sibling kernels use j, so vec sits
// where the Householder vector of step j = ngroups is stored.
// NOTE(review): original lines 105-119 are missing from this listing; they
// presumably hold the kernel doc comment and the "__kernel void ...(" line.
102static constexpr const char* tridiagonalization_v_step_1_kernel_code
103 = STRINGIFY(
104 // \endcond
120 const __global double* P, const __global double* V,
121 __global double* Uu, __global double* Vu, const int P_rows,
122 const int V_rows, const int k) {
123 const int lid = get_local_id(0);
124 const int gid = get_global_id(0);
125 const int gsize = get_global_size(0);
126 const int lsize = get_local_size(0);
127 const int ngroups = get_num_groups(0);
128 const int wgid = get_group_id(0);
// gid and gsize are not used below.
129
130 __local double res_loc1[LOCAL_SIZE_];
131 __local double res_loc2[LOCAL_SIZE_];
132 double acc1 = 0;
133 double acc2 = 0;
134
135 const __global double* vec
136 = P + P_rows * (k + ngroups) + k + ngroups + 1;
137 const __global double* M1 = P + P_rows * (k + wgid) + k + ngroups + 1;
138 const __global double* M2 = V + V_rows * wgid + ngroups;
// The stride is LOCAL_SIZE_, not lsize, so each entry is visited exactly once
// only when the launch local size equals LOCAL_SIZE_ (64 as registered).
139 for (int i = lid; i < P_rows - k - ngroups - 1;
140 i
141 += LOCAL_SIZE_) { // go over column of the matrix in steps of 64
142 double v = vec[i];
143 acc1 += M1[i] * v;
144 acc2 += M2[i] * v;
145 }
// Both partial sums are reduced together with fan-in REDUCTION_STEP_SIZE
// (complete only when lsize is a power of REDUCTION_STEP_SIZE).
146 res_loc1[lid] = acc1;
147 res_loc2[lid] = acc2;
148 barrier(CLK_LOCAL_MEM_FENCE);
149
150 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
151 step /= REDUCTION_STEP_SIZE) {
152 if (lid < step) {
153 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
154 res_loc1[lid] += res_loc1[lid + step * i];
155 res_loc2[lid] += res_loc2[lid + step * i];
156 }
157 }
158 barrier(CLK_LOCAL_MEM_FENCE);
159 }
// One result pair per work-group.
160 if (lid == 0) {
161 Uu[wgid] = res_loc1[0];
162 Vu[wgid] = res_loc2[0];
163 }
164 }); // \cond
165// \endcond
166
167// \cond
// Source of the OpenCL kernel built as "tridiagonalization_v_step_2"
// (registered at the bottom of this file with REDUCTION_STEP_SIZE = 4 and
// LOCAL_SIZE_ = 64). Over work = P_rows - k - j - 1 entries, with vec the
// Householder vector of step j (offset P_rows * (k + j) + k + j + 1 in P)
// and M1 the trailing block starting at offset
// P_rows * (k + j + 1) + k + j + 1, it produces two partial sums:
//  - one global work-item per entry gid < work sums M1[P_rows * i + gid] *
//    vec[i] over i <= gid, subtracts the corrections of the previous j
//    reflections (M2 with Vu, M3 with Uu) and writes V[V_rows * j + j + gid];
//  - each work-group takes a contiguous range of entries i, reduces
//    M1[P_rows * i + l] * vec[l] over l > i, and writes
//    V[V_rows * (j + 1) + j + i].
// Both halves read only entries of M1 with row offset >= column offset, the
// second one transposed, which presumably relies on the matrix being
// symmetric. tridiagonalization_v_step_3 later adds the second partial sum
// into the first (v[i] += v[i + V_rows]).
// NOTE(review): original lines 171-187 are missing from this listing; they
// presumably hold the kernel doc comment and the "__kernel void ...(" line.
// NOTE(review): the first half assumes at least work global work-items; the
// host launch is not visible here.
168static constexpr const char* tridiagonalization_v_step_2_kernel_code
169 = STRINGIFY(
170 // \endcond
188 const __global double* P, __global double* V,
189 const __global double* Uu, const __global double* Vu,
190 const int P_rows, const int V_rows, const int k, const int j) {
191 const int lid = get_local_id(0);
192 const int gid = get_global_id(0);
193 const int gsize = get_global_size(0);
194 const int lsize = get_local_size(0);
195 const int ngroups = get_num_groups(0);
196 const int wgid = get_group_id(0);
197
198 int work = P_rows - k - j - 1;
199 double acc = 0;
200
201 const __global double* vec = P + P_rows * (k + j) + k + j + 1;
202 const __global double* M1 = P + P_rows * (k + j + 1) + k + j + 1;
203 const __global double* M2 = P + P_rows * k + k + j + 1;
204 const __global double* M3 = V + j;
// First half: one entry per global work-item. Note that M3 reads V at
// offsets V_rows * i + j + gid for i < j, so V is read as well as written.
205 int i;
206 if (gid < work) {
207 for (i = 0; i <= gid; i++) {
208 acc += M1[P_rows * i + gid] * vec[i];
209 }
// This loop declares its own i, shadowing the one declared above.
210 for (int i = 0; i < j; i++) {
211 acc -= M2[P_rows * i + gid] * Vu[i];
212 acc -= M3[V_rows * i + gid] * Uu[i];
213 }
214 V[V_rows * j + gid + j] = acc;
215 }
// Second half: entries [start, end) of this work-group, one reduction each.
// NOTE(review): start/end come from a float work_per_group. A float product
// can round just below the exact integer (and OpenCL single-precision
// division need not be correctly rounded), so truncation may leave the last
// group with end == work - 1 and skip an entry; work above 2^24 is not even
// representable. Integer arithmetic such as work * wgid / ngroups would avoid
// this - confirm.
216 float work_per_group
217 = (float)work / ngroups; // NOLINT(readability/casting)
218 int start = work_per_group * wgid;
219 int end = work_per_group * (wgid + 1);
220 __local double res_loc[LOCAL_SIZE_];
221 for (int i = start; i < end; i += 1) {
222 acc = 0;
// Stride LOCAL_SIZE_ assumes the launch local size equals LOCAL_SIZE_.
223 for (int l = i + 1 + lid; l < work; l += LOCAL_SIZE_) {
224 acc += M1[P_rows * i + l] * vec[l];
225 }
226 res_loc[lid] = acc;
227 barrier(CLK_LOCAL_MEM_FENCE);
228 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
229 step /= REDUCTION_STEP_SIZE) {
230 if (lid < step) {
231 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
232 res_loc[lid] += res_loc[lid + step * i];
233 }
234 }
235 barrier(CLK_LOCAL_MEM_FENCE);
236 }
237 if (lid == 0) {
238 V[V_rows * (j + 1) + i + j] = res_loc[lid];
239 }
240 barrier(CLK_LOCAL_MEM_FENCE);
241 }
242 }); // \cond
243// \endcond
244
245// \cond
// Source of the OpenCL kernel built as "tridiagonalization_v_step_3"
// (registered at the bottom of this file with REDUCTION_STEP_SIZE = 4 and
// LOCAL_SIZE_ = 1024). With u the Householder vector of step j (offset
// P_rows * (k + j) + k + j + 1 in P) and v the V strip at offset
// V_rows * j + j, over P_rows - k - j - 1 entries it:
//   1. folds in the partial sums stored V_rows further on (v[i] += v[i + V_rows]),
//   2. reduces dot(u, v) across the work-group,
//   3. updates v -= 0.5 * dot(u, v) * u,
// and gid 0 subtracts *q / M_SQRT2 * u[0] from the P entry at offset
// P_rows * (k + j + 1) + k + j. q is only read here.
// If *q and u[0] still hold what tridiagonalization_householder left for this
// j, that entry ends up equal to alpha (up to rounding); confirm against the
// host code.
// NOTE(review): original lines 249-260 are missing from this listing; they
// presumably hold the kernel doc comment and the "__kernel void ...(" line.
// NOTE(review): work distribution uses lid only, so the kernel presumably
// runs as a single work-group; extra groups would repeat the fold in step 1.
246static constexpr const char* tridiagonalization_v_step_3_kernel_code
247 = STRINGIFY(
248 // \endcond
261 __global double* P, __global double* V, __global double* q,
262 const int P_rows, const int V_rows, const int k, const int j) {
263 const int lid = get_local_id(0);
264 const int gid = get_global_id(0);
265 const int gsize = get_global_size(0);
266 const int lsize = get_local_size(0);
267 const int ngroups = get_num_groups(0);
268 const int wgid = get_group_id(0);
// gsize, ngroups and wgid are not used below.
269
270 __global double* u = P + P_rows * (k + j) + k + j + 1;
271 __global double* v = V + V_rows * j + j;
272 double acc = 0;
273
// Stride LOCAL_SIZE_ assumes the launch local size equals LOCAL_SIZE_.
274 for (int i = lid; i < P_rows - k - j - 1; i += LOCAL_SIZE_) {
275 double vi = v[i] + v[i + V_rows];
276 v[i] = vi;
277 acc += u[i] * vi;
278 }
// Tree reduction of dot(u, v) with fan-in REDUCTION_STEP_SIZE (complete only
// when lsize is a power of REDUCTION_STEP_SIZE).
279 __local double res_loc[LOCAL_SIZE_];
280 res_loc[lid] = acc;
281 barrier(CLK_LOCAL_MEM_FENCE);
282 for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
283 step /= REDUCTION_STEP_SIZE) {
284 if (lid < step) {
285 for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
286 res_loc[lid] += res_loc[lid + step * i];
287 }
288 }
289 barrier(CLK_LOCAL_MEM_FENCE);
290 }
291 acc = res_loc[0] * 0.5;
// Each work-item revisits exactly the v entries it wrote in the first loop,
// so no barrier is needed before this update.
292 for (int i = lid; i < P_rows - k - j - 1; i += LOCAL_SIZE_) {
293 v[i] -= acc * u[i];
294 }
295 if (gid == 0) {
296 P[P_rows * (k + j + 1) + k + j] -= *q / M_SQRT2 * u[0];
297 }
298 }); // \cond
299// \endcond
300
301const kernel_cl<in_out_buffer, in_out_buffer, out_buffer, int, int, int, int>
302 tridiagonalization_householder("tridiagonalization_householder",
303 {tridiagonalization_householder_kernel_code},
304 {{"REDUCTION_STEP_SIZE", 4},
305 {"LOCAL_SIZE_", 1024}});
306
307const kernel_cl<in_buffer, in_buffer, out_buffer, out_buffer, int, int, int>
308 tridiagonalization_v_step_1("tridiagonalization_v_step_1",
309 {tridiagonalization_v_step_1_kernel_code},
310 {{"REDUCTION_STEP_SIZE", 4},
311 {"LOCAL_SIZE_", 64}});
312
313const kernel_cl<in_buffer, out_buffer, in_buffer, in_buffer, int, int, int, int>
314 tridiagonalization_v_step_2("tridiagonalization_v_step_2",
315 {tridiagonalization_v_step_2_kernel_code},
316 {{"REDUCTION_STEP_SIZE", 4},
317 {"LOCAL_SIZE_", 64}});
318
319const kernel_cl<in_out_buffer, in_out_buffer, out_buffer, int, int, int, int>
320 tridiagonalization_v_step_3("tridiagonalization_v_step_3",
321 {tridiagonalization_v_step_3_kernel_code},
322 {{"REDUCTION_STEP_SIZE", 4},
323 {"LOCAL_SIZE_", 1024}});
324
325} // namespace opencl_kernels
326} // namespace math
327} // namespace stan
328#endif
329#endif
const kernel_cl< in_buffer, in_buffer, out_buffer, out_buffer, int, int, int > tridiagonalization_v_step_1("tridiagonalization_v_step_1", {tridiagonalization_v_step_1_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
const kernel_cl< in_buffer, out_buffer, in_buffer, in_buffer, int, int, int, int > tridiagonalization_v_step_2("tridiagonalization_v_step_2", {tridiagonalization_v_step_2_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}})
const kernel_cl< in_out_buffer, in_out_buffer, out_buffer, int, int, int, int > tridiagonalization_v_step_3("tridiagonalization_v_step_3", {tridiagonalization_v_step_3_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 1024}})
const kernel_cl< in_out_buffer, in_out_buffer, out_buffer, int, int, int, int > tridiagonalization_householder("tridiagonalization_householder", {tridiagonalization_householder_kernel_code}, {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 1024}})
T step(const T &y)
The step, or Heaviside, function.
Definition step.hpp:31
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:18
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9