Automatic Differentiation
 
Loading...
Searching...
No Matches
gp_dot_prod_cov.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_GP_DOT_PROD_COV_HPP
2#define STAN_MATH_OPENCL_PRIM_GP_DOT_PROD_COV_HPP
3#ifdef STAN_OPENCL
4
9
10namespace stan {
11namespace math {
12
24template <typename T_x, typename T_sigma,
25 require_all_prim_or_rev_kernel_expression_t<T_x>* = nullptr,
26 require_stan_scalar_t<T_sigma>* = nullptr>
27inline auto gp_dot_prod_cov(const T_x& x, const T_sigma sigma) {
28 const char* fun = "gp_dot_prod_cov(OpenCL)";
29 check_nonnegative(fun, "sigma", sigma);
30 check_finite(fun, "sigma", sigma);
31 const auto& x_val = value_of(x);
32 check_cl(fun, "x", x_val, "not NaN") = !isnan(x_val);
33 return add(square(sigma), transpose(x) * x);
34}
35
48template <typename T_x, typename T_y, typename T_sigma,
51inline auto gp_dot_prod_cov(const T_x& x, const T_y& y, const T_sigma sigma) {
52 const char* fun = "gp_dot_prod_cov(OpenCL)";
53 check_nonnegative(fun, "sigma", sigma);
54 check_finite(fun, "sigma", sigma);
55 const auto& x_val = value_of(x);
56 const auto& y_val = value_of(y);
57 check_cl(fun, "x", x_val, "not NaN") = !isnan(x_val);
58 check_cl(fun, "y", y_val, "not NaN") = !isnan(y_val);
59 return add(square(sigma), transpose(x) * y);
60}
61} // namespace math
62} // namespace stan
63#endif
64#endif
auto check_cl(const char *function, const char *var_name, T &&y, const char *must_be)
Constructs a check on an OpenCL matrix or expression.
Definition check_cl.hpp:219
addition_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > add(T_a &&a, T_b &&b)
auto transpose(Arg &&a)
Transposes a kernel generator expression.
auto gp_dot_prod_cov(const T_x &x, const T_sigma sigma)
Dot product kernel on the GPU.
require_all_t< is_prim_or_rev_kernel_expression< std::decay_t< Types > >... > require_all_prim_or_rev_kernel_expression_t
Require type satisfies is_prim_or_rev_kernel_expression.
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
void check_nonnegative(const char *function, const char *name, const T_y &y)
Check if y is non-negative.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
fvar< T > square(const fvar< T > &x)
Definition square.hpp:12
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9