Automatic Differentiation
 
Loading...
Searching...
No Matches
mrrr.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_GPU_KERNELS_MRRR_HPP
2#define STAN_MATH_GPU_KERNELS_MRRR_HPP
3
4#ifdef STAN_OPENCL
5
8
9namespace stan {
10namespace math {
11namespace opencl_kernels {
12
13// \cond
14static constexpr const char* eigenvals_bisect_kernel_code = STRINGIFY(
15 // \endcond
16
/**
 * Sturm count of a symmetric tridiagonal matrix T at a given shift.
 *
 * Runs the pivot recurrence
 *   p_0 = diagonal[0] - shift,
 *   p_j = diagonal[j] - shift - subdiagonal_squared[j - 1] / p_(j-1)
 * and counts how many pivots compare `>= 0.0`. These are the pivots of the
 * LDL factorization of T - shift * I, so the result is the number of
 * eigenvalues of T that are not below `shift` (up to rounding).
 *
 * @param diagonal Diagonal of T; entries 0 .. n-1 are read. Entry 0 is read
 * unconditionally, so `n` is expected to be at least 1.
 * @param subdiagonal_squared Squared subdiagonal of T; entries 0 .. n-2 are
 * read.
 * @param shift Point at which the count is evaluated.
 * @param n Size of T.
 * @return Number of non-negative pivots.
 */
int get_sturm_count_tri(const __global double* diagonal,
                        const __global double* subdiagonal_squared,
                        const double shift, const int n) {
  double pivot = diagonal[0] - shift;
  int non_negative = (pivot >= 0.0) ? 1 : 0;
  int row = 1;
  while (row < n) {
    pivot = (diagonal[row] - shift) - subdiagonal_squared[row - 1] / pivot;
    if (pivot >= 0.0) {
      ++non_negative;
    }
    ++row;
  }
  return non_negative;
}
37
/**
 * Brackets one eigenvalue of a symmetric tridiagonal matrix by bisection.
 *
 * Starting from [min_eigval, max_eigval], the interval is halved repeatedly.
 * At every midpoint get_sturm_count_tri() is evaluated: if it reports more
 * than `i` eigenvalues, the lower bound moves up to the midpoint, otherwise
 * the upper bound moves down. Iteration ends as soon as the width is no
 * longer greater than 2 * DBL_EPSILON * |high + low|, or no longer greater
 * than DBL_MIN.
 *
 * @param diagonal Diagonal of the matrix.
 * @param subdiagonal_squared Squared subdiagonal of the matrix.
 * @param[out] low_res Receives the final lower bound.
 * @param[out] high_res Receives the final upper bound.
 * @param min_eigval Initial lower bound.
 * @param max_eigval Initial upper bound.
 * @param n Size of the matrix.
 * @param i Index of the wanted eigenvalue; the bracket moves up while the
 * Sturm count at the midpoint exceeds this value.
 */
void eigenvals_bisect(const __global double* diagonal,
                      const __global double* subdiagonal_squared,
                      double* low_res, double* high_res,
                      const double min_eigval, const double max_eigval,
                      const int n, const int i) {
  const double rel_tol = 2 * DBL_EPSILON;

  double lower = min_eigval;
  double upper = max_eigval;

  for (;;) {
    // Relative width test first, absolute (underflow) test second.
    if (!((upper - lower) > rel_tol * fabs(upper + lower))) {
      break;
    }
    if (!(fabs(upper - lower) > DBL_MIN)) {
      break;
    }
    const double middle = (upper + lower) * 0.5;
    const int above
        = get_sturm_count_tri(diagonal, subdiagonal_squared, middle, n);
    if (above <= i) {
      upper = middle;
    } else {
      lower = middle;
    }
  }
  *low_res = lower;
  *high_res = upper;
}
74
/**
 * Sturm count of a tridiagonal matrix given by its LDL decomposition,
 * evaluated in double-double arithmetic.
 *
 * Forms the pivots `d_plus = s + d[i]` of the shifted factorization one at a
 * time and counts those that compare `>= 0`. The auxiliary value `s` starts
 * at `-shift` and is updated after every pivot.
 *
 * @param l Subdiagonal of L; entries 0 .. n-1 are read.
 * @param d Diagonal of D; entries 0 .. n are read (one more than `l`).
 * @param shift Point at which the count is evaluated.
 * @param n Number of recurrence steps, i.e. one less than the number of
 * entries read from `d`.
 * @return Number of non-negative shifted pivots.
 */
int get_sturm_count_ldl(const __global double_d* l,
                        const __global double_d* d, const double_d shift,
                        const int n) {
  double_d s = neg(shift);
  double_d d_plus;
  int count = 0;
  for (int i = 0; i < n; i++) {
    d_plus = add_dd_dd(s, d[i]);
    count += ge_dd_d(d_plus, 0.0);
    if (isinf_dd(s)) {  // this happens if d_plus==0 -> in next iteration
                        // d_plus==inf and s==inf
      // s / d_plus would be inf / inf here, so that factor is left out.
      s = sub_dd_dd(mul_dd_dd(mul_dd_dd(l[i], l[i]), d[i]), shift);
    } else {
      s = sub_dd_dd(mul_dd_dd(mul_dd_dd(mul_dd_dd(l[i], l[i]), s),
                              div_dd_dd(d[i], d_plus)),
                    shift);
    }
  }
  // Last pivot: d has one entry more than l.
  d_plus = add_dd_dd(s, d[n]);
  count += ge_dd_d(d_plus, 0.0);
  return count;
}
109
/**
 * Tightens an existing bracket around one eigenvalue of an LDL-represented
 * matrix, using bisection in double-double arithmetic.
 *
 * The bracket is read from `*low_res` / `*high_res`, halved until
 * |(high - low) / (high + low)| is no longer greater than 3e-20 or
 * |high - low| is no longer greater than DBL_MIN, and written back.
 * Each step calls get_sturm_count_ldl() with `n - 1` recurrence steps.
 *
 * @param l Subdiagonal of L.
 * @param d Diagonal of D.
 * @param[in,out] low_res Lower bound; refined in place.
 * @param[in,out] high_res Upper bound; refined in place.
 * @param n Size of the matrix.
 * @param i Index of the wanted eigenvalue; the bracket moves up while the
 * Sturm count at the midpoint exceeds this value.
 */
void eigenvals_bisect_refine(const __global double_d* l,
                             const __global double_d* d, double_d* low_res,
                             double_d* high_res, const int n, const int i) {
  const double_d rel_tol = (double_d){3e-20, 0};
  const double_d abs_tol = (double_d){DBL_MIN, 0};

  double_d lower = *low_res;
  double_d upper = *high_res;

  for (;;) {
    const double_d gap = sub_dd_dd(upper, lower);
    const double_d sum = add_dd_dd(upper, lower);
    // Relative width test first, absolute (underflow) test second.
    if (!gt_dd_dd(abs_dd(div_dd_dd(gap, sum)), rel_tol)) {
      break;
    }
    if (!gt_dd_dd(abs_dd(gap), abs_tol)) {
      break;
    }
    const double_d middle = mul_dd_d(sum, 0.5);
    const int above = get_sturm_count_ldl(l, d, middle, n - 1);
    if (above <= i) {
      upper = middle;
    } else {
      lower = middle;
    }
  }
  *low_res = lower;
  *high_res = upper;
}
144
/**
 * Kernel: each work item computes one eigenvalue of a tridiagonal matrix.
 *
 * The work item's global id selects the eigenvalue and the global size is
 * used as the matrix size. The eigenvalue is first bracketed in double
 * precision from `diagonal` / `subdiagonal_squared`, and the midpoint of that
 * bracket is stored in `eigval_global`. When `do_refine` is non-zero, the
 * bracket is moved by `-shift`, widened by a relative margin of 1e-18 * n on
 * each side, refined against the LDL representation `l` / `d`, and the
 * refined bounds are stored.
 *
 * @param diagonal Diagonal of the tridiagonal matrix.
 * @param subdiagonal_squared Squared subdiagonal of the tridiagonal matrix.
 * @param l Subdiagonal of L of the LDL representation.
 * @param d Diagonal of D of the LDL representation.
 * @param[out] eigval_global Midpoint of the double-precision bracket.
 * @param[out] shifted_low_global Refined lower bound (written only when
 * refining).
 * @param[out] shifted_high_global Refined upper bound (written only when
 * refining).
 * @param min_eigval Initial lower bound for every eigenvalue.
 * @param max_eigval Initial upper bound for every eigenvalue.
 * @param shift Value subtracted from the bracket before refinement.
 * @param do_refine Non-zero to run the double-double refinement.
 */
__kernel void eigenvals(
    const __global double* diagonal,
    const __global double* subdiagonal_squared, const __global double_d* l,
    const __global double_d* d, __global double* eigval_global,
    __global double_d* shifted_low_global,
    __global double_d* shifted_high_global, const double min_eigval,
    const double max_eigval, const double shift, const char do_refine) {
  const int idx = get_global_id(0);
  const int size = get_global_size(0);

  double lower;
  double upper;
  eigenvals_bisect(diagonal, subdiagonal_squared, &lower, &upper, min_eigval,
                   max_eigval, size, idx);
  eigval_global[idx] = (lower + upper) * 0.5;
  if (!do_refine) {
    return;
  }

  double_d lower_shifted = (double_d){lower - shift, 0};
  double_d upper_shifted = (double_d){upper - shift, 0};
  // Push each bound outwards (away from the other one) by a relative margin
  // before refining.
  lower_shifted = mul_dd_dd(
      lower_shifted,
      sub_dd_d((double_d){1, 0}, copysign_d_dd(1e-18 * size, lower_shifted)));
  upper_shifted = mul_dd_dd(
      upper_shifted,
      add_dd_d((double_d){1, 0}, copysign_d_dd(1e-18 * size, upper_shifted)));
  eigenvals_bisect_refine(l, d, &lower_shifted, &upper_shifted, size, idx);
  shifted_low_global[idx] = lower_shifted;
  shifted_high_global[idx] = upper_shifted;
}
189 // \cond
190);
191// \endcond
192
// Host-side handle for the `eigenvals` OpenCL kernel. The kernel program is
// assembled from two source strings, in this order: the double_d helper
// source and eigenvals_bisect_kernel_code. The template arguments follow the
// kernel's parameter list one-to-one: four input buffers (diagonal,
// subdiagonal_squared, l, d), three output buffers (eigval_global,
// shifted_low_global, shifted_high_global), three doubles (min_eigval,
// max_eigval, shift) and one char (do_refine).
193const kernel_cl<in_buffer, in_buffer, in_buffer, in_buffer, out_buffer,
194 out_buffer, out_buffer, double, double, double, char>
195 eigenvals("eigenvals", {stan::math::internal::double_d_src,
196 eigenvals_bisect_kernel_code});
197
198// \cond
199static constexpr const char* get_eigenvectors_kernel_code = STRINGIFY(
200 // \endcond
201
/**
 * Calculates the shifted LDL and UDU factorizations of an LDL-represented
 * tridiagonal matrix and picks the twist index.
 *
 * All arrays hold one column per work item: element `i` belonging to the
 * calling work item lives at index `i * n + gid`, where `n` is the global
 * size and `gid` the global id.
 *
 * The first pass (top to bottom) fills `l_plus[0 .. n-2]` and the auxiliary
 * values `s[0 .. n-1]`. The second pass (bottom to top) fills
 * `u_minus[0 .. n-2]` and tracks gamma = |s_i + t * p|; the index with the
 * smallest gamma is returned. NaN results of the recurrences (caused by
 * 0 * inf or inf / inf style operations) are replaced case by case, see the
 * inline comments.
 *
 * @param l Subdiagonal of L.
 * @param d Diagonal of D.
 * @param shift Shift (the eigenvalue the factorization is built for).
 * @param[out] l_plus Subdiagonal of the shifted LDL factor.
 * @param[out] u_minus Superdiagonal of the shifted UDU factor.
 * @param[out] s Scratch storage for the auxiliary values of the first pass.
 * @return Twist index: the position with the smallest gamma.
 */
int get_twisted_factorization(
    const __global double_d* l, const __global double_d* d, double_d shift,
    __global double_d* l_plus, __global double_d* u_minus,
    __global double_d* s) {
  int n = get_global_size(0);
  int gid = get_global_id(0);
  int m = n - 1;
  // calculate shifted ldl
  s[gid] = neg(shift);
  for (int i = 0; i < m; i++) {
    double_d d_plus = add_dd_dd(s[i * n + gid], d[i * n + gid]);
    l_plus[i * n + gid]
        = mul_dd_dd(l[i * n + gid], div_dd_dd(d[i * n + gid], d_plus));
    if (isnan_dd(l_plus[i * n + gid])) {  // d_plus==0
      // one (or both) of d[i], l[i] is very close to 0
      if (lt_dd_dd(abs_dd(l[i * n + gid]), abs_dd(d[i * n + gid]))) {
        l_plus[i * n + gid]
            = mul_dd_d(d[i * n + gid], copysign_d_dd(1., l[i * n + gid])
                                           * copysign_d_dd(1., d_plus));
      } else {
        l_plus[i * n + gid]
            = mul_dd_d(l[i * n + gid], copysign_d_dd(1., d[i * n + gid])
                                           * copysign_d_dd(1., d_plus));
      }
    }
    s[(i + 1) * n + gid] = sub_dd_dd(
        mul_dd_dd(mul_dd_dd(l_plus[i * n + gid], l[i * n + gid]),
                  s[i * n + gid]),
        shift);
    if (isnan_dd(s[(i + 1) * n + gid])) {
      // The triple product hit 0 * inf; keep the finite, non-zero factor and
      // only the signs of the other two.
      if (gt_dd_dd(abs_dd(l_plus[i * n + gid]),
                   abs_dd(s[i * n + gid]))) {  // l_plus[i * n + gid] == inf
        if (gt_dd_dd(abs_dd(s[i * n + gid]),
                     abs_dd(l[i * n + gid]))) {  // l[i*n+gid]==0
          s[(i + 1) * n + gid] = sub_dd_dd(
              mul_dd_d(s[i * n + gid],
                       copysign_d_dd(1., l[i * n + gid])
                           * copysign_d_dd(1., l_plus[i * n + gid])),
              shift);
        } else {  // s[i*n+gid]==0
          s[(i + 1) * n + gid] = sub_dd_dd(
              mul_dd_d(l[i * n + gid],
                       copysign_d_dd(1., s[i * n + gid])
                           * copysign_d_dd(1., l_plus[i * n + gid])),
              shift);
        }
      } else {  // s[i*n+gid]==inf
        if (gt_dd_dd(abs_dd(l_plus[i * n + gid]),
                     abs_dd(l[i * n + gid]))) {  // l[i]==0
          s[(i + 1) * n + gid]
              = sub_dd_dd(mul_dd_d(l_plus[i * n + gid],
                                   copysign_d_dd(1., l[i * n + gid])
                                       * copysign_d_dd(1., s[i * n + gid])),
                          shift);
        } else {  // l_plus[i]==0
          s[(i + 1) * n + gid] = sub_dd_dd(
              mul_dd_d(l[i * n + gid],
                       copysign_d_dd(1., s[i * n + gid])
                           * copysign_d_dd(1., l_plus[i * n + gid])),
              shift);
        }
      }
    }
  }
  // calculate shifted udu and twist index
  double_d p = sub_dd_dd(d[m * n + gid], shift);
  // The last position is the initial twist candidate.
  double_d min_gamma = abs_dd(add_dd_dd(s[m * n + gid], d[m * n + gid]));
  int twist_index = m;

  for (int i = m - 1; i >= 0; i--) {
    double_d d_minus
        = add_dd_dd(mul_dd_dd(mul_dd_dd(d[i * n + gid], l[i * n + gid]),
                              l[i * n + gid]),
                    p);
    double_d t = div_dd_dd(d[i * n + gid], d_minus);
    u_minus[i * n + gid] = mul_dd_dd(l[i * n + gid], t);
    if (isnan_dd(u_minus[i * n + gid])) {
      if (isnan_dd(t)) {
        // t itself is unusable: replace it by +-1 built from the signs of
        // its numerator and denominator.
        double t_high = copysign_d_dd(1., d[i * n + gid])
                        * copysign_d_dd(1., d_minus);
        t.high = t_high;
        t.low = 0;
        u_minus[i * n + gid] = mul_dd_d(l[i * n + gid], t_high);
      } else {  // t==inf, l[i*n+gid]==0
        u_minus[i * n + gid]
            = mul_dd_d(d[i * n + gid], copysign_d_dd(1., l[i * n + gid])
                                           * copysign_d_dd(1., t));
      }
    }
    double_d gamma = abs_dd(add_dd_dd(s[i * n + gid], mul_dd_dd(t, p)));
    if (isnan_dd(gamma)) {  // t==inf, p==0 OR t==0, p==inf
      double_d d_sign
          = mul_dd_d(d[i * n + gid],
                     copysign_d_dd(1., d_minus) * copysign_d_dd(1., t));
      gamma = abs_dd(add_dd_dd(s[i * n + gid], d_sign));
      p = sub_dd_dd(d_sign, shift);
    } else {  // usual case
      p = sub_dd_dd(mul_dd_dd(p, t), shift);
    }
    // Strict comparison: on ties the larger index (seen first) is kept.
    if (lt_dd_dd(gamma, min_gamma)) {
      min_gamma = gamma;
      twist_index = i;
    }
  }
  return twist_index;
}
326
/**
 * Calculates one eigenvector from a twisted factorization and normalizes it.
 *
 * Every work item handles its own vector: component `j` of the vector for
 * the calling work item is stored at `eigenvectors[j * n + gid]` (and
 * likewise for `l_plus` / `u_minus`), where `n` is the global size and `gid`
 * the global id. `subdiag` is indexed by component only, i.e. it is the same
 * for all work items.
 *
 * The component at `twist_idx` is set to 1. Components after it are built
 * with `u_minus`, components before it with `l_plus`, each one from its
 * neighbour closer to the twist. If that neighbour is exactly zero, the
 * component is instead obtained from the component two positions away using
 * the ratio of the neighbouring subdiagonal entries; a NaN or infinite
 * result is replaced by 0. Only the high part of the double-double value is
 * stored. Finally the vector is scaled to unit Euclidean length.
 *
 * @param l_plus Subdiagonal of the shifted LDL factor.
 * @param u_minus Superdiagonal of the shifted UDU factor.
 * @param subdiag Subdiagonal of the tridiagonal matrix.
 * @param twist_idx Twist index.
 * @param[out] eigenvectors Receives the normalized eigenvector.
 */
void calculate_eigenvector(const __global double_d* l_plus,
                           const __global double_d* u_minus,
                           const __global double* subdiag, int twist_idx,
                           __global double* eigenvectors) {
  int n = get_global_size(0);
  int gid = get_global_id(0);
  eigenvectors[twist_idx * n + gid] = 1;
  // `last` is the most recently computed component, `last2` the one before.
  double_d last = (double_d){1, 0};
  double_d last2 = (double_d){1, 0};
  double norm = 1;
  // part of the eigenvector after the twist index
  for (int j = twist_idx + 1; j < n; j++) {
    if (last.high != 0 || last.low != 0) {
      last2 = last;
      last = neg(mul_dd_dd(u_minus[(j - 1) * n + gid], last));
      eigenvectors[j * n + gid] = last.high;
    } else {
      // Previous component is zero: step over it. This branch cannot be
      // taken at j == twist_idx + 1 (the twist component is 1), so j - 2 is
      // a valid index here.
      double_d tmp = last;
      last = mul_dd_d(last2, -subdiag[j - 2] / subdiag[j - 1]);
      last2 = tmp;
      if (isnan(last.high) || isinf(last.high)) {  // subdiag[j - 1]==0
        last = (double_d){0, 0};
      }
      eigenvectors[j * n + gid] = last.high;
    }
    norm += eigenvectors[j * n + gid] * eigenvectors[j * n + gid];
  }
  last = (double_d){eigenvectors[twist_idx * n + gid], 0};
  last2 = (double_d){1, 0};
  // part of the eigenvector before the twist index
  for (int j = twist_idx - 1; j >= 0; j--) {
    if (last.high != 0 || last.low != 0) {
      last2 = last;
      last = neg(mul_dd_dd(l_plus[j * n + gid], last));
      eigenvectors[j * n + gid] = last.high;
    } else {
      double_d tmp = last;
      last = mul_dd_d(last2, -subdiag[j + 1] / subdiag[j]);
      // Keep `last2` one step behind `last`, exactly as in the loop above;
      // otherwise a second consecutive zero would reuse a stale component.
      last2 = tmp;
      if (isnan(last.high) || isinf(last.high)) {  // subdiag[j]==0
        last = (double_d){0, 0};
      }
      eigenvectors[j * n + gid] = last.high;
    }
    norm += eigenvectors[j * n + gid] * eigenvectors[j * n + gid];
  }
  norm = 1 / sqrt(norm);
  // normalize the eigenvector
  for (int j = 0; j < n; j++) {
    eigenvectors[j * n + gid] *= norm;
  }
}
388
/**
 * Kernel: each work item computes one eigenvector.
 *
 * The work item reads its shifted eigenvalue from `shifted_eigvals` at its
 * global id, builds the twisted factorization for it with
 * get_twisted_factorization(), and passes the resulting twist index to
 * calculate_eigenvector().
 *
 * @param l Subdiagonal of L of the LDL representation.
 * @param d Diagonal of D of the LDL representation.
 * @param subdiag Subdiagonal of the tridiagonal matrix.
 * @param shifted_eigvals Shifted eigenvalues, one per work item.
 * @param l_plus Work buffer: subdiagonal of the shifted LDL factor.
 * @param u_minus Work buffer: superdiagonal of the shifted UDU factor.
 * @param temp Work buffer handed to get_twisted_factorization() as scratch.
 * @param[out] eigenvectors Receives the eigenvectors.
 */
__kernel void get_eigenvectors(
    const __global double_d* l, const __global double_d* d,
    const __global double* subdiag,
    const __global double_d* shifted_eigvals, __global double_d* l_plus,
    __global double_d* u_minus, __global double_d* temp,
    __global double* eigenvectors) {
  const double_d own_eigval = shifted_eigvals[get_global_id(0)];
  const int twist
      = get_twisted_factorization(l, d, own_eigval, l_plus, u_minus, temp);
  calculate_eigenvector(l_plus, u_minus, subdiag, twist, eigenvectors);
}
412 // \cond
413);
414// \endcond
415
// Host-side handle for the `get_eigenvectors` OpenCL kernel. The kernel
// program is assembled from two source strings, in this order: the double_d
// helper source and get_eigenvectors_kernel_code. The template arguments
// follow the kernel's parameter list one-to-one: four input buffers (l, d,
// subdiag, shifted_eigvals), three in/out work buffers (l_plus, u_minus,
// temp) and one output buffer (eigenvectors).
416const kernel_cl<in_buffer, in_buffer, in_buffer, in_buffer, in_out_buffer,
417 in_out_buffer, in_out_buffer, out_buffer>
418 get_eigenvectors("get_eigenvectors", {stan::math::internal::double_d_src,
419 get_eigenvectors_kernel_code});
420
421} // namespace opencl_kernels
422} // namespace math
423} // namespace stan
424#endif
425#endif
auto diagonal(T &&a)
Diagonal of a kernel generator expression.
Definition diagonal.hpp:136
void eigenvals_bisect(const __global double *diagonal, const __global double *subdiagonal_squared, double *low_res, double *high_res, const double min_eigval, const double max_eigval, const int n, const int i)
Calculates bounds on the i-th largest eigenvalue of a tridiagonal matrix, given by its diagonal and squared subdiagonal, using bisection.
Definition mrrr.hpp:51
int get_sturm_count_ldl(const __global double_d *l, const __global double_d *d, const double_d shift, const int n)
Calculates Sturm count of a LDL decomposition of a tridiagonal matrix - number of eigenvalues larger ...
Definition mrrr.hpp:86
int get_twisted_factorization(const __global double_d *l, const __global double_d *d, double_d shift, __global double_d *l_plus, __global double_d *u_minus, __global double_d *s)
Calculates shifted LDL and UDU factorizations.
Definition mrrr.hpp:220
void eigenvals_bisect_refine(const __global double_d *l, const __global double_d *d, double_d *low_res, double_d *high_res, const int n, const int i)
Refines bounds on the i-th largest eigenvalue of a LDL decomposition using bisection.
Definition mrrr.hpp:120
void calculate_eigenvector(const __global double_d *l_plus, const __global double_d *u_minus, const __global double *subdiag, int twist_idx, __global double *eigenvectors)
Calculates an eigenvector from twisted factorization T - shift * I = L+.
Definition mrrr.hpp:336
const kernel_cl< in_buffer, in_buffer, in_buffer, in_buffer, in_out_buffer, in_out_buffer, in_out_buffer, out_buffer > get_eigenvectors("get_eigenvectors", {stan::math::internal::double_d_src, get_eigenvectors_kernel_code})
const kernel_cl< in_buffer, in_buffer, in_buffer, in_buffer, out_buffer, out_buffer, out_buffer, double, double, double, char > eigenvals("eigenvals", {stan::math::internal::double_d_src, eigenvals_bisect_kernel_code})
int get_sturm_count_tri(const __global double *diagonal, const __global double *subdiagonal_squared, const double shift, const int n)
Calculates the Sturm count of a tridiagonal matrix T - the number of eigenvalues not lower than shift (it counts the non-negative pivots of T - shift * I).
Definition mrrr.hpp:26
fvar< T > norm(const std::complex< fvar< T > > &z)
Return the squared magnitude of the complex argument.
Definition norm.hpp:19
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
fvar< T > sqrt(const fvar< T > &x)
Definition sqrt.hpp:17
Eigen::Matrix< complex_return_t< value_type_t< EigMat > >, -1, -1 > eigenvectors(const EigMat &m)
Return the eigenvectors of a (real-valued) matrix.
fvar< T > fabs(const fvar< T > &x)
Definition fabs.hpp:15
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
#define STRINGIFY(...)
Definition stringify.hpp:9
double low
Definition double_d.hpp:27
double high
Definition double_d.hpp:26
Double double - a 128 bit floating point number defined as an exact sum of 2 doubles.
Definition double_d.hpp:25