Automatic Differentiation
 
Loading...
Searching...
No Matches
gp_matern52_cov.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_GP_MATERN52_COV_HPP
2#define STAN_MATH_OPENCL_PRIM_GP_MATERN52_COV_HPP
3#ifdef STAN_OPENCL
4
10#include <CL/opencl.hpp>
11
12namespace stan {
13namespace math {
26template <typename T1, typename T2, typename T3,
27 require_all_kernel_expressions_and_none_scalar_t<T1>* = nullptr,
28 require_all_arithmetic_t<T2, T3>* = nullptr>
30 const T1& x, const T2 sigma, const T3 length_scale) {
31 const auto& x_eval = x.eval();
32 matrix_cl<return_type_t<T1, T2, T3>> res(x.cols(), x.cols());
33 int block_size = 16;
34 int n_blocks = (x.cols() + block_size - 1) / block_size;
35 int blocked_size = block_size * n_blocks;
36 try {
38 cl::NDRange(blocked_size, blocked_size),
39 cl::NDRange(block_size, block_size), x_eval, res, sigma * sigma,
40 std::sqrt(5.0) / length_scale, 5.0 / (3.0 * square(length_scale)),
41 x.cols(), x.rows());
42 } catch (const cl::Error& e) {
43 check_opencl_error("gp_matern52_cov", e);
44 }
45 return res;
46}
47
65template <typename T1, typename T2, typename T3, typename T4,
69 const T1& x, const T2& y, const T3 sigma, const T4 length_scale) {
70 check_size_match("gp_matern52_cov_cross", "x", x.rows(), "y", y.rows());
71 matrix_cl<return_type_t<T1, T2, T3, T4>> res(x.cols(), y.cols());
72 const auto& x_eval = x.eval();
73 const auto& y_eval = y.eval();
74 int block_size = 16;
75 int x_blocks = (x.cols() + block_size - 1) / block_size;
76 int x_blocked_size = block_size * x_blocks;
77 int y_blocks = (y.cols() + block_size - 1) / block_size;
78 int y_blocked_size = block_size * y_blocks;
79 try {
81 cl::NDRange(x_blocked_size, y_blocked_size),
82 cl::NDRange(block_size, block_size), x_eval, y_eval, res, sigma * sigma,
83 std::sqrt(5.0) / length_scale, 5.0 / (3.0 * square(length_scale)),
84 x.cols(), y.cols(), x.rows());
85 } catch (const cl::Error& e) {
86 check_opencl_error("gp_matern52_cov_cross", e);
87 }
88 return res;
89}
90
103template <typename T1, typename T2, typename T3,
106inline matrix_cl<return_type_t<T1, T2, T3>> gp_matern52_cov(
107 const T1& x, const T2 sigma, const T3 length_scale) {
108 const auto& x_eval = elt_divide(x, rowwise_broadcast(length_scale)).eval();
109 matrix_cl<return_type_t<T1, T2, T3>> res(x.cols(), x.cols());
110 int block_size = 16;
111 int n_blocks = (x.cols() + block_size - 1) / block_size;
112 int blocked_size = block_size * n_blocks;
113 try {
114 opencl_kernels::gp_matern52_cov(cl::NDRange(blocked_size, blocked_size),
115 cl::NDRange(block_size, block_size), x_eval,
116 res, sigma * sigma, std::sqrt(5.0),
117 5.0 / 3.0, x.cols(), x.rows());
118 } catch (const cl::Error& e) {
119 check_opencl_error("gp_matern52_cov", e);
120 }
121 return res;
122}
123
141template <
142 typename T1, typename T2, typename T3, typename T4,
143 require_all_kernel_expressions_and_none_scalar_t<T1, T2, T4>* = nullptr,
144 require_all_arithmetic_t<T3>* = nullptr>
145inline matrix_cl<return_type_t<T1, T2, T3, T4>> gp_matern52_cov(
146 const T1& x, const T2& y, const T3 sigma, const T4 length_scale) {
147 check_size_match("gp_matern52_cov_cross", "x", x.rows(), "y", y.rows());
148 matrix_cl<return_type_t<T1, T2, T3, T4>> res(x.cols(), y.cols());
149 const auto& x_eval = elt_divide(x, rowwise_broadcast(length_scale)).eval();
150 const auto& y_eval = elt_divide(y, rowwise_broadcast(length_scale)).eval();
151 int block_size = 16;
152 int x_blocks = (x.cols() + block_size - 1) / block_size;
153 int x_blocked_size = block_size * x_blocks;
154 int y_blocks = (y.cols() + block_size - 1) / block_size;
155 int y_blocked_size = block_size * y_blocks;
156 try {
158 cl::NDRange(x_blocked_size, y_blocked_size),
159 cl::NDRange(block_size, block_size), x_eval, y_eval, res, sigma * sigma,
160 std::sqrt(5.0), 5.0 / 3.0, x.cols(), y.cols(), x.rows());
161 } catch (const cl::Error& e) {
162 check_opencl_error("gp_matern52_cov_cross", e);
163 }
164 return res;
165}
166
167} // namespace math
168} // namespace stan
169
170#endif
171#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
require_all_t< std::is_arithmetic< std::decay_t< Types > >... > require_all_arithmetic_t
Require all of the types satisfy std::is_arithmetic.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error that specifies the OpenCL error that occurred.
elt_divide_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_divide(T_a &&a, T_b &&b)
auto rowwise_broadcast(T &&a)
Broadcast an expression in the rowwise dimension.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
const kernel_cl< in_buffer, out_buffer, double, double, double, int, int > gp_matern52_cov("gp_matern52_cov", {gp_matern52_cov_kernel_code})
See the docs for gp_matern52_cov() .
const kernel_cl< in_buffer, in_buffer, out_buffer, double, double, double, int, int, int > gp_matern52_cov_cross("gp_matern52_cov_cross", {gp_matern52_cov_cross_kernel_code})
See the docs for gp_matern52_cov_cross() .
matrix_cl< return_type_t< T1, T2, T3 > > gp_matern52_cov(const T1 &x, const T2 sigma, const T3 length_scale)
Matern 5/2 kernel on the GPU.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9