Automatic Differentiation
 
Loading...
Searching...
No Matches
gp_exp_quad_cov.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
2#define STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
3
12#include <cmath>
13#include <type_traits>
14#include <vector>
15
16namespace stan {
17namespace math {
18
31template <typename T_x, typename T_sigma, require_st_arithmetic<T_x>* = nullptr,
32 require_stan_scalar_t<T_sigma>* = nullptr>
33inline Eigen::Matrix<var, -1, -1> gp_exp_quad_cov(const std::vector<T_x>& x,
34 const T_sigma sigma,
35 const var length_scale) {
36 check_positive("gp_exp_quad_cov", "sigma", sigma);
37 check_positive("gp_exp_quad_cov", "length_scale", length_scale);
38 size_t x_size = x.size();
39 for (size_t i = 0; i < x_size; ++i) {
40 check_not_nan("gp_exp_quad_cov", "x", x[i]);
41 }
42
43 Eigen::Matrix<var, -1, -1> cov(x_size, x_size);
44 if (x_size == 0) {
45 return cov;
46 }
47 size_t l_tri_size = x_size * (x_size - 1) / 2;
48 arena_matrix<Eigen::VectorXd> sq_dists_lin(l_tri_size);
49 arena_matrix<Eigen::Matrix<var, -1, 1>> cov_l_tri_lin(l_tri_size);
50 arena_matrix<Eigen::Matrix<var, -1, 1>> cov_diag(
51 is_constant<T_sigma>::value ? 0 : x_size);
52
53 double l_val = value_of(length_scale);
54 double sigma_sq = square(value_of(sigma));
55 double neg_half_inv_l_sq = -0.5 / square(l_val);
56
57 size_t block_size = 10;
58 size_t pos = 0;
59 for (size_t jb = 0; jb < x_size; jb += block_size) {
60 size_t j_end = std::min(x_size, jb + block_size);
61 size_t j_size = j_end - jb;
62 cov.diagonal().segment(jb, j_size)
63 = Eigen::VectorXd::Constant(j_size, sigma_sq);
65 cov_diag.segment(jb, j_size) = cov.diagonal().segment(jb, j_size);
66 }
67 for (size_t ib = jb; ib < x_size; ib += block_size) {
68 size_t i_end = std::min(x_size, ib + block_size);
69 for (size_t j = jb; j < j_end; ++j) {
70 for (size_t i = std::max(ib, j + 1); i < i_end; ++i) {
71 sq_dists_lin.coeffRef(pos) = squared_distance(x[i], x[j]);
72 cov_l_tri_lin.coeffRef(pos) = cov.coeffRef(j, i) = cov.coeffRef(i, j)
73 = sigma_sq * exp(sq_dists_lin.coeff(pos) * neg_half_inv_l_sq);
74 pos++;
75 }
76 }
77 }
78 }
79
81 [cov_l_tri_lin, cov_diag, sq_dists_lin, sigma, length_scale, x_size]() {
82 size_t l_tri_size = x_size * (x_size - 1) / 2;
83 double adjl = 0;
84 double adjsigma = 0;
85 for (Eigen::Index pos = 0; pos < l_tri_size; pos++) {
86 double prod_add
87 = cov_l_tri_lin.coeff(pos).val() * cov_l_tri_lin.coeff(pos).adj();
88 adjl += prod_add * sq_dists_lin.coeff(pos);
90 adjsigma += prod_add;
91 }
92 }
94 adjsigma += (cov_diag.val().array() * cov_diag.adj().array()).sum();
95 adjoint_of(sigma) += adjsigma * 2 / value_of(sigma);
96 }
97 double l_val = value_of(length_scale);
98 length_scale.adj() += adjl / (l_val * l_val * l_val);
99 });
100
101 return cov;
102}
103
104} // namespace math
105} // namespace stan
106#endif
Equivalent to Eigen::Matrix, except that the data is stored on AD stack.
matrix_cl< return_type_t< T1, T2, T3 > > gp_exp_quad_cov(const matrix_cl< T1 > &x, const T2 sigma, const T3 length_scale)
Squared exponential kernel on the GPU.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
auto & adjoint_of(const T &x)
Returns a reference to a variable's adjoint.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:22
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
var_value< double > var
Definition var.hpp:1187
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
auto squared_distance(const T_a &a, const T_b &b)
Returns the squared distance.
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:13
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation from C or the boost::math::lgamma implementation.
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the C++ const sense).