1#ifndef STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
2#define STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
31template <
typename T_x,
typename T_sigma, require_st_arithmetic<T_x>* =
nullptr,
32 require_stan_scalar_t<T_sigma>* =
nullptr>
35 const var length_scale) {
38 size_t x_size = x.size();
39 for (
size_t i = 0; i < x_size; ++i) {
43 Eigen::Matrix<
var, -1, -1> cov(x_size, x_size);
47 size_t l_tri_size = x_size * (x_size - 1) / 2;
53 double l_val =
value_of(length_scale);
55 double neg_half_inv_l_sq = -0.5 /
square(l_val);
57 size_t block_size = 10;
59 for (
size_t jb = 0; jb < x_size; jb += block_size) {
60 size_t j_end = std::min(x_size, jb + block_size);
61 size_t j_size = j_end - jb;
62 cov.diagonal().segment(jb, j_size)
63 = Eigen::VectorXd::Constant(j_size, sigma_sq);
65 cov_diag.segment(jb, j_size) = cov.diagonal().segment(jb, j_size);
67 for (
size_t ib = jb; ib < x_size; ib += block_size) {
68 size_t i_end = std::min(x_size, ib + block_size);
69 for (
size_t j = jb; j < j_end; ++j) {
70 for (
size_t i = std::max(ib, j + 1); i < i_end; ++i) {
72 cov_l_tri_lin.coeffRef(pos) = cov.coeffRef(j, i) = cov.coeffRef(i, j)
73 = sigma_sq *
exp(sq_dists_lin.coeff(pos) * neg_half_inv_l_sq);
81 [cov_l_tri_lin, cov_diag, sq_dists_lin, sigma, length_scale, x_size]() {
82 size_t l_tri_size = x_size * (x_size - 1) / 2;
85 for (Eigen::Index pos = 0; pos < l_tri_size; pos++) {
87 = cov_l_tri_lin.coeff(pos).val() * cov_l_tri_lin.coeff(pos).adj();
88 adjl += prod_add * sq_dists_lin.coeff(pos);
94 adjsigma += (cov_diag.val().array() * cov_diag.adj().array()).
sum();
97 double l_val =
value_of(length_scale);
98 length_scale.adj() += adjl / (l_val * l_val * l_val);
Equivalent to Eigen::Matrix, except that the data is stored on the AD stack.
matrix_cl< return_type_t< T1, T2, T3 > > gp_exp_quad_cov(const matrix_cl< T1 > &x, const T2 sigma, const T3 length_scale)
Squared exponential kernel on the GPU.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called during the reverse pass.
auto & adjoint_of(const T &x)
Returns a reference to a variable's adjoint.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
auto squared_distance(const T_a &a, const T_b &b)
Returns the squared distance.
fvar< T > square(const fvar< T > &x)
fvar< T > exp(const fvar< T > &x)
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...