Automatic Differentiation
 
Loading...
Searching...
No Matches
gp_exp_quad_cov.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_GP_EXP_QUAD_COV_HPP
2#define STAN_MATH_OPENCL_PRIM_GP_EXP_QUAD_COV_HPP
3#ifdef STAN_OPENCL
4
10#include <CL/opencl.hpp>
11
12namespace stan {
13namespace math {
/**
 * Squared exponential kernel evaluated on the OpenCL device (GPU).
 *
 * Launches opencl_kernels::gp_exp_quad_cov over an x.cols() x x.cols()
 * global range and passes x.cols() and x.rows() as sizes. Each column of x
 * therefore appears to be one input point with x.rows() coordinates.
 * NOTE(review): the kernel source (gp_exp_quad_cov_kernel_code) is not shown
 * here, so confirm this layout there.
 *
 * @tparam T1 scalar type of x
 * @tparam T2 type of sigma
 * @tparam T3 type of length_scale
 *   (T1, T2 and T3 must all satisfy std::is_arithmetic)
 * @param x input points stored in a matrix_cl on the OpenCL device
 * @param sigma scale parameter; only sigma * sigma reaches the kernel
 * @param length_scale length scale; the kernel receives
 *   -0.5 / square(length_scale)
 * @return res, a matrix_cl<return_type_t<T1, T2, T3>> (per the declared
 *   signature) filled by the kernel
 * @throws a domain error (raised by check_opencl_error) if the kernel launch
 *   throws cl::Error
 *
 * NOTE(review): this rendered listing omits original lines 28 and 30 (the
 * return type / function name line and the declaration of `res`). Consult the
 * header itself for them.
 */
26template <typename T1, typename T2, typename T3,
27 typename = require_all_arithmetic_t<T1, T2, T3>>
29 const matrix_cl<T1>& x, const T2 sigma, const T3 length_scale) {
31 try {
// sigma^2 and -0.5 / length_scale^2 are computed once on the host and passed
// to the kernel as plain scalar arguments.
32 opencl_kernels::gp_exp_quad_cov(cl::NDRange(x.cols(), x.cols()), x, res,
33 sigma * sigma, -0.5 / square(length_scale),
34 x.cols(), x.rows());
35 } catch (const cl::Error& e) {
// Convert the OpenCL exception into Stan's domain-error reporting.
36 check_opencl_error("gp_exp_quad_cov", e);
37 }
38 return res;
39}
40
/**
 * Cross version of the squared exponential kernel on the OpenCL device,
 * evaluated between the columns of x and the columns of y.
 *
 * First requires x.rows() == y.rows() via check_size_match. It then launches
 * the cross kernel over an x.cols() x y.cols() global range, passing
 * x.cols(), y.cols() and x.rows() as sizes.
 *
 * @tparam T1 scalar type of x
 * @tparam T2 scalar type of y
 * @tparam T3 type of sigma
 * @tparam T4 type of length_scale
 * @param x first set of input points on the OpenCL device
 * @param y second set of input points; must have as many rows as x
 * @param sigma scale parameter; only sigma * sigma reaches the kernel
 * @param length_scale length scale; the kernel receives
 *   -0.5 / square(length_scale)
 * @return res (declared on a line omitted from this listing) filled by the
 *   kernel
 * @throws whatever check_size_match raises when the row counts differ
 *   (presumably std::invalid_argument; confirm)
 * @throws a domain error (raised by check_opencl_error) if the kernel launch
 *   throws cl::Error
 *
 * NOTE(review): this rendered listing omits original lines 59-60 (rest of the
 * template header and the return type / function name), 64 (declaration of
 * `res`) and 66 (first line of the kernel call). The visible argument list
 * matches the opencl_kernels::gp_exp_quad_cov_cross kernel signature.
 * Consult the header itself for the missing lines.
 */
58template <typename T1, typename T2, typename T3, typename T4,
61 const matrix_cl<T1>& x, const matrix_cl<T2>& y, const T3 sigma,
62 const T4 length_scale) {
// Points in x and y must have the same dimensionality (row count).
63 check_size_match("gp_exp_quad_cov_cross", "x", x.rows(), "y", y.rows());
65 try {
// As above, the squared scale and -0.5 / length_scale^2 are host-side scalars.
67 cl::NDRange(x.cols(), y.cols()), x, y, res, sigma * sigma,
68 -0.5 / square(length_scale), x.cols(), y.cols(), x.rows());
69 } catch (const cl::Error& e) {
// Convert the OpenCL exception into Stan's domain-error reporting.
70 check_opencl_error("gp_exp_quad_cov_cross", e);
71 }
72 return res;
73}
74
75} // namespace math
76} // namespace stan
77
78#endif
79#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
require_all_t< std::is_arithmetic< std::decay_t< Types > >... > require_all_arithmetic_t
Requires that all of the types satisfy std::is_arithmetic.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
const kernel_cl< in_buffer, in_buffer, out_buffer, double, double, int, int, int > gp_exp_quad_cov_cross("gp_exp_quad_cov_cross", {gp_exp_quad_cov_cross_kernel_code})
See the docs for gp_exp_quad_cov_cross().
const kernel_cl< in_buffer, out_buffer, double, double, int, int > gp_exp_quad_cov("gp_exp_quad_cov", {gp_exp_quad_cov_kernel_code})
See the docs for gp_exp_quad_cov().
matrix_cl< return_type_t< T1, T2, T3 > > gp_exp_quad_cov(const matrix_cl< T1 > &x, const T2 sigma, const T3 length_scale)
Squared exponential kernel on the GPU.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9