Automatic Differentiation
 
Loading...
Searching...
No Matches
columns_dot_product.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_COLUMNS_DOT_PRODUCT_HPP
2#define STAN_MATH_OPENCL_PRIM_COLUMNS_DOT_PRODUCT_HPP
3#ifdef STAN_OPENCL
9
10namespace stan {
11namespace math {
12
25template <typename T_a, typename T_b,
26 require_all_kernel_expressions_and_none_scalar_t<T_a, T_b>* = nullptr>
27inline auto columns_dot_product(const T_a& a, const T_b& b) {
28 using res_scal = std::common_type_t<value_type_t<T_a>, value_type_t<T_b>>;
29 check_matching_sizes("columns_dot_product", "a", a, "b", b);
31
32 if (size_zero(a, b)) {
33 res = constant(res_scal(0), 1, a.cols());
34 return res;
35 }
36
37 res = colwise_sum(elt_multiply(a, b));
38 while (res.rows() > 1) {
39 res = colwise_sum(res).eval();
40 }
41 return res;
42}
43
44} // namespace math
45} // namespace stan
46
47#endif
48#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
elt_multiply_< as_operation_cl_t< T_a >, as_operation_cl_t< T_b > > elt_multiply(T_a &&a, T_b &&b)
auto constant(const T a, int rows, int cols)
Matrix of repeated values in kernel generator expressions.
Definition constant.hpp:130
auto colwise_sum(T &&a)
Column wise sum - reduction of a kernel generator expression.
typename value_type< T >::type value_type_t
Helper alias for accessing the underlying type.
bool size_zero(const T &x)
Returns true if the input is of length 0, false otherwise.
Definition size_zero.hpp:19
void check_matching_sizes(const char *function, const char *name1, const T_y1 &y1, const char *name2, const T_y2 &y2)
Check if two structures are the same size.
auto columns_dot_product(const T_a &a, const T_b &b)
Returns the dot product of columns of the specified matrices.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...