Automatic Differentiation
 
Loading...
Searching...
No Matches
gp_exp_quad_cov.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
2#define STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
3
13#include <cmath>
14#include <type_traits>
15#include <vector>
16
17namespace stan {
18namespace math {
19
32template <typename T_x, typename T_sigma, require_st_arithmetic<T_x>* = nullptr,
33 require_stan_scalar_t<T_sigma>* = nullptr>
34inline Eigen::Matrix<var, -1, -1> gp_exp_quad_cov(const std::vector<T_x>& x,
35 const T_sigma sigma,
36 const var length_scale) {
37 check_positive("gp_exp_quad_cov", "sigma", sigma);
38 check_positive("gp_exp_quad_cov", "length_scale", length_scale);
39 size_t x_size = x.size();
40 for (size_t i = 0; i < x_size; ++i) {
41 check_not_nan("gp_exp_quad_cov", "x", x[i]);
42 }
43
44 Eigen::Matrix<var, -1, -1> cov(x_size, x_size);
45 if (x_size == 0) {
46 return cov;
47 }
48 size_t l_tri_size = x_size * (x_size - 1) / 2;
49 arena_matrix<Eigen::VectorXd> sq_dists_lin(l_tri_size);
50 arena_matrix<Eigen::Matrix<var, -1, 1>> cov_l_tri_lin(l_tri_size);
51 arena_matrix<Eigen::Matrix<var, -1, 1>> cov_diag(
52 is_constant<T_sigma>::value ? 0 : x_size);
53
54 double l_val = value_of(length_scale);
55 double sigma_sq = square(value_of(sigma));
56 double neg_half_inv_l_sq = -0.5 / square(l_val);
57
58 size_t block_size = 10;
59 size_t pos = 0;
60 for (size_t jb = 0; jb < x_size; jb += block_size) {
61 size_t j_end = std::min(x_size, jb + block_size);
62 size_t j_size = j_end - jb;
63 cov.diagonal().segment(jb, j_size)
64 = Eigen::VectorXd::Constant(j_size, sigma_sq);
66 cov_diag.segment(jb, j_size) = cov.diagonal().segment(jb, j_size);
67 }
68 for (size_t ib = jb; ib < x_size; ib += block_size) {
69 size_t i_end = std::min(x_size, ib + block_size);
70 for (size_t j = jb; j < j_end; ++j) {
71 for (size_t i = std::max(ib, j + 1); i < i_end; ++i) {
72 sq_dists_lin.coeffRef(pos) = squared_distance(x[i], x[j]);
73 cov_l_tri_lin.coeffRef(pos) = cov.coeffRef(j, i) = cov.coeffRef(i, j)
74 = sigma_sq * exp(sq_dists_lin.coeff(pos) * neg_half_inv_l_sq);
75 pos++;
76 }
77 }
78 }
79 }
80
82 [cov_l_tri_lin, cov_diag, sq_dists_lin, sigma, length_scale, x_size]() {
83 size_t l_tri_size = x_size * (x_size - 1) / 2;
84 double adjl = 0;
85 double adjsigma = 0;
86 for (Eigen::Index pos = 0; pos < l_tri_size; pos++) {
87 double prod_add
88 = cov_l_tri_lin.coeff(pos).val() * cov_l_tri_lin.coeff(pos).adj();
89 adjl += prod_add * sq_dists_lin.coeff(pos);
91 adjsigma += prod_add;
92 }
93 }
95 adjsigma += (cov_diag.val().array() * cov_diag.adj().array()).sum();
96 adjoint_of(sigma) += adjsigma * 2 / value_of(sigma);
97 }
98 double l_val = value_of(length_scale);
99 length_scale.adj() += adjl / (l_val * l_val * l_val);
100 });
101
102 return cov;
103}
104
105} // namespace math
106} // namespace stan
107#endif
Equivalent to Eigen::Matrix, except that the data is stored on AD stack.
matrix_cl< return_type_t< T1, T2, T3 > > gp_exp_quad_cov(const matrix_cl< T1 > &x, const T2 sigma, const T3 length_scale)
Squared exponential kernel on the GPU.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
auto & adjoint_of(const T &x)
Returns a reference to a variable's adjoint.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_not_nan(const char *function, const char *name, const T_y &y)
Check if y is not NaN.
auto sum(const std::vector< T > &m)
Return the sum of the entries of the specified standard vector.
Definition sum.hpp:23
var_value< double > var
Definition var.hpp:1187
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
auto squared_distance(const T_a &a, const T_b &b)
Returns the squared distance.
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
fvar< T > exp(const fvar< T > &x)
Definition exp.hpp:15
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...