1#ifndef STAN_MATH_GPU_KERNELS_MRRR_HPP
2#define STAN_MATH_GPU_KERNELS_MRRR_HPP
11namespace opencl_kernels {
14static constexpr const char* eigenvals_bisect_kernel_code =
STRINGIFY(
27 const __global
double* subdiagonal_squared,
28 const double shift,
const int n) {
31 for (
int j = 1; j < n; j++) {
32 d =
diagonal[j] - shift - subdiagonal_squared[j - 1] / d;
52 const __global
double* subdiagonal_squared,
53 double* low_res,
double* high_res,
54 const double min_eigval,
const double max_eigval,
55 const int n,
const int i) {
56 const double eps = 2 * DBL_EPSILON;
58 double low = min_eigval;
59 double high = max_eigval;
61 while ((high - low) > eps *
fabs(high + low)
62 &&
fabs(high - low) > DBL_MIN) {
63 double mid = (high + low) * 0.5;
93 for (
int i = 0; i < n; i++) {
94 d_plus = add_dd_dd(s, d[i]);
95 count += ge_dd_d(d_plus, 0.0);
98 s = sub_dd_dd(mul_dd_dd(mul_dd_dd(l[i], l[i]), d[i]), shift);
100 s = sub_dd_dd(mul_dd_dd(mul_dd_dd(mul_dd_dd(l[i], l[i]), s),
101 div_dd_dd(d[i], d_plus)),
105 d_plus = add_dd_dd(s, d[n]);
106 count += ge_dd_d(d_plus, 0);
122 double_d* high_res,
const int n,
const int i) {
130 abs_dd(div_dd_dd(sub_dd_dd(high, low), add_dd_dd(high, low))),
132 && gt_dd_dd(abs_dd(sub_dd_dd(high, low)), min_norm)) {
133 double_d mid = mul_dd_d(add_dd_dd(high, low), 0.5);
163 const __global
double* subdiagonal_squared,
const __global
double_d* l,
164 const __global
double_d* d, __global
double* eigval_global,
165 __global
double_d* shifted_low_global,
166 __global
double_d* shifted_high_global,
const double min_eigval,
167 const double max_eigval,
const double shift,
const char do_refine) {
168 const int i = get_global_id(0);
169 const int n = get_global_size(0);
171 double low_eig, high_eig;
173 min_eigval, max_eigval, n, i);
174 eigval_global[i] = (low_eig + high_eig) * 0.5;
178 low_shifted = mul_dd_dd(
180 sub_dd_d((
double_d){1, 0}, copysign_d_dd(1
e-18 * n, low_shifted)));
181 high_shifted = mul_dd_dd(
183 add_dd_d((
double_d){1, 0}, copysign_d_dd(1
e-18 * n, high_shifted)));
185 shifted_low_global[i] = low_shifted;
186 shifted_high_global[i] = high_shifted;
// Declares the kernel object `eigenvals`, registered under the kernel name
// "eigenvals" and built from the two listed sources: the double_d source and
// eigenvals_bisect_kernel_code.
// The template arguments list the kind of each kernel argument in order:
// four input buffers, three output buffers, three double scalars and one
// char. This lines up with the kernel parameters visible in the kernel code
// (..., subdiagonal_squared, l, d as inputs; eigval_global,
// shifted_low_global, shifted_high_global as outputs; min_eigval,
// max_eigval, shift as doubles; do_refine as char).
193const kernel_cl<in_buffer, in_buffer, in_buffer, in_buffer, out_buffer,
194 out_buffer, out_buffer, double, double, double,
char>
195 eigenvals(
"eigenvals", {stan::math::internal::double_d_src,
196 eigenvals_bisect_kernel_code});
199static constexpr const char* get_eigenvectors_kernel_code =
STRINGIFY(
224 int n = get_global_size(0);
225 int gid = get_global_id(0);
229 for (
int i = 0; i < m; i++) {
230 double_d d_plus = add_dd_dd(s[i * n + gid], d[i * n + gid]);
232 = mul_dd_dd(l[i * n + gid], div_dd_dd(d[i * n + gid], d_plus));
233 if (isnan_dd(l_plus[i * n + gid])) {
235 if (lt_dd_dd(abs_dd(l[i * n + gid]), abs_dd(d[i * n + gid]))) {
237 = mul_dd_d(d[i * n + gid], copysign_d_dd(1., l[i * n + gid])
238 * copysign_d_dd(1., d_plus));
241 = mul_dd_d(l[i * n + gid], copysign_d_dd(1., d[i * n + gid])
242 * copysign_d_dd(1., d_plus));
245 s[(i + 1) * n + gid] = sub_dd_dd(
246 mul_dd_dd(mul_dd_dd(l_plus[i * n + gid], l[i * n + gid]),
249 if (isnan_dd(s[(i + 1) * n + gid])) {
250 if (gt_dd_dd(abs_dd(l_plus[i * n + gid]),
251 abs_dd(s[i * n + gid]))) {
252 if (gt_dd_dd(abs_dd(s[i * n + gid]),
253 abs_dd(l[i * n + gid]))) {
254 s[(i + 1) * n + gid] = sub_dd_dd(
255 mul_dd_d(s[i * n + gid],
256 copysign_d_dd(1., l[i * n + gid])
257 * copysign_d_dd(1., l_plus[i * n + gid])),
260 s[(i + 1) * n + gid] = sub_dd_dd(
261 mul_dd_d(l[i * n + gid],
262 copysign_d_dd(1., s[i * n + gid])
263 * copysign_d_dd(1., l_plus[i * n + gid])),
267 if (gt_dd_dd(abs_dd(l_plus[i * n + gid]),
268 abs_dd(l[i * n + gid]))) {
270 = sub_dd_dd(mul_dd_d(l_plus[i * n + gid],
271 copysign_d_dd(1., l[i * n + gid])
272 * copysign_d_dd(1., s[i * n + gid])),
275 s[(i + 1) * n + gid] = sub_dd_dd(
276 mul_dd_d(l[i * n + gid],
277 copysign_d_dd(1., s[i * n + gid])
278 * copysign_d_dd(1., l_plus[i * n + gid])),
285 double_d p = sub_dd_dd(d[m * n + gid], shift);
286 double_d min_gamma = abs_dd(add_dd_dd(s[m * n + gid], d[m * n + gid]));
289 for (
int i = m - 1; i >= 0; i--) {
291 = add_dd_dd(mul_dd_dd(mul_dd_dd(d[i * n + gid], l[i * n + gid]),
294 double_d t = div_dd_dd(d[i * n + gid], d_minus);
295 u_minus[i * n + gid] = mul_dd_dd(l[i * n + gid], t);
296 if (isnan_dd(u_minus[i * n + gid])) {
298 double t_high = copysign_d_dd(1., d[i * n + gid])
299 * copysign_d_dd(1., d_minus);
302 u_minus[i * n + gid] = mul_dd_d(l[i * n + gid], t_high);
305 = mul_dd_d(d[i * n + gid], copysign_d_dd(1., l[i * n + gid])
306 * copysign_d_dd(1., t));
309 double_d gamma = abs_dd(add_dd_dd(s[i * n + gid], mul_dd_dd(t, p)));
310 if (isnan_dd(gamma)) {
312 = mul_dd_d(d[i * n + gid],
313 copysign_d_dd(1., d_minus) * copysign_d_dd(1., t));
314 gamma = abs_dd(add_dd_dd(s[i * n + gid], d_sign));
315 p = sub_dd_dd(d_sign, shift);
317 p = sub_dd_dd(mul_dd_dd(p, t), shift);
319 if (lt_dd_dd(gamma, min_gamma)) {
338 const __global
double* subdiag,
int twist_idx,
340 int n = get_global_size(0);
341 int gid = get_global_id(0);
348 for (
int j = twist_idx + 1; j < n; j++) {
349 if (last.
high != 0 || last.
low != 0) {
351 last = neg(mul_dd_dd(u_minus[(j - 1) * n + gid], last));
355 last = mul_dd_d(last2, -subdiag[j - 2] / subdiag[j - 1]);
357 if (isnan(last.
high) || isinf(last.
high)) {
367 for (
int j = twist_idx - 1; j >= 0; j--) {
368 if (last.
high != 0 || last.
low != 0) {
370 last = neg(mul_dd_dd(l_plus[j * n + gid], last));
374 last = mul_dd_d(last2, -subdiag[j + 1] / subdiag[j]);
375 if (isnan(last.
high) || isinf(last.
high)) {
384 for (
int j = 0; j < n; j++) {
404 const __global
double* subdiag,
409 l, d, shifted_eigvals[get_global_id(0)], l_plus, u_minus, temp);
416const kernel_cl<in_buffer, in_buffer, in_buffer, in_buffer, in_out_buffer,
417 in_out_buffer, in_out_buffer, out_buffer>
419 get_eigenvectors_kernel_code});
auto diagonal(T &&a)
Diagonal of a kernel generator expression.
void eigenvals_bisect(const __global double *diagonal, const __global double *subdiagonal_squared, double *low_res, double *high_res, const double min_eigval, const double max_eigval, const int n, const int i)
Calculates i-th largest eigenvalue of a tridiagonal matrix, given by its diagonal and squared subdiagonal, using bisection.
int get_sturm_count_ldl(const __global double_d *l, const __global double_d *d, const double_d shift, const int n)
Calculates Sturm count of a LDL decomposition of a tridiagonal matrix - number of eigenvalues larger than or equal to shift.
int get_twisted_factorization(const __global double_d *l, const __global double_d *d, double_d shift, __global double_d *l_plus, __global double_d *u_minus, __global double_d *s)
Calculates shifted LDL and UDU factorizations.
void eigenvals_bisect_refine(const __global double_d *l, const __global double_d *d, double_d *low_res, double_d *high_res, const int n, const int i)
Refines bounds on the i-th largest eigenvalue of a LDL decomposition using bisection.
void calculate_eigenvector(const __global double_d *l_plus, const __global double_d *u_minus, const __global double *subdiag, int twist_idx, __global double *eigenvectors)
Calculates an eigenvector from twisted factorization T - shift * I = L+ * D+ * L+^T = U- * D- * U-^T.
const kernel_cl< in_buffer, in_buffer, in_buffer, in_buffer, in_out_buffer, in_out_buffer, in_out_buffer, out_buffer > get_eigenvectors("get_eigenvectors", {stan::math::internal::double_d_src, get_eigenvectors_kernel_code})
const kernel_cl< in_buffer, in_buffer, in_buffer, in_buffer, out_buffer, out_buffer, out_buffer, double, double, double, char > eigenvals("eigenvals", {stan::math::internal::double_d_src, eigenvals_bisect_kernel_code})
int get_sturm_count_tri(const __global double *diagonal, const __global double *subdiagonal_squared, const double shift, const int n)
Calculates lower Sturm count of a tridiagonal matrix T - number of eigenvalues lower than shift.
fvar< T > norm(const std::complex< fvar< T > > &z)
Return the squared magnitude of the complex argument.
static constexpr double e()
Return the base of the natural logarithm.
fvar< T > sqrt(const fvar< T > &x)
Eigen::Matrix< complex_return_t< value_type_t< EigMat > >, -1, -1 > eigenvectors(const EigMat &m)
Return the eigenvectors of a (real-valued) matrix.
fvar< T > fabs(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Double double - a 128 bit floating point number defined as an exact sum of 2 doubles.