Automatic Differentiation
 
Loading...
Searching...
No Matches
cumulative_sum.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_PRIM_CUMULATIVE_SUM_HPP
2#define STAN_MATH_OPENCL_PRIM_CUMULATIVE_SUM_HPP
3#ifdef STAN_OPENCL
8
9namespace stan {
10namespace math {
11
/**
 * Return the cumulative sum of the specified vector, computed on the
 * OpenCL device.
 *
 * NOTE(review): this listing comes from a scraped documentation page and
 * several source lines are missing: the original doc comment, the
 * condition guarding `res = v;`, the kernel-option lookups for
 * `local_size`/`local_size2`, and the heads of the three kernel calls.
 * The comments below describe only what is visible; confirm the rest
 * against the real header.
 *
 * @tparam T_vec type of the input; constrained by
 *   require_all_kernel_expressions_and_none_scalar_t (a kernel
 *   expression, not a scalar)
 * @param v input; check_vector requires it to be a row vector, a column
 *   vector, or a matrix with a single row or column
 * @return a matrix_cl with the same rows/cols as `v` holding the
 *   cumulative sum (an empty input yields an empty result)
 * @throw a domain error via check_opencl_error if an OpenCL call fails;
 *   check_vector rejects non-vector input (exception type not shown here)
 */
23template <typename T_vec,
24 require_all_kernel_expressions_and_none_scalar_t<T_vec>* = nullptr>
25inline auto cumulative_sum(T_vec&& v) {
26 using T_scal = scalar_type_t<T_vec>;
27 check_vector("cumulative_sum(OpenCL)", "v", v);
28
// Result buffer with the same dimensions as the input; nothing to compute
// for an empty input, so it is returned immediately.
29 matrix_cl<T_scal> res(v.rows(), v.cols());
30 if (v.size() == 0) {
31 return res;
32 }
33
// NOTE(review): the condition opening this scope is missing from the
// listing. Given the static_select below (which reads from `res` whenever
// T_vec is not a matrix_cl), this presumably evaluates a non-matrix_cl
// expression into `res` so the kernels can read it. Confirm against source.
35 res = v;
36 }
// Work-group size for the first and third launches; presumably read from
// the cumulative_sum kernel's "LOCAL_SIZE_" option (lookup line missing).
37 const int local_size
39 "LOCAL_SIZE_");
// Number of work groups: enough groups of `local_size` threads to cover
// v.size() elements, capped at 16 per compute unit of the first device.
40 const int work_groups = std::min(
41 (v.size() + local_size - 1) / local_size,
42 static_cast<int>(
43 opencl_context.device()[0].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>())
44 * 16);
// Work-group size used by the second launch; also a "LOCAL_SIZE_" option
// (lookup line missing from the listing).
45 const int local_size2
47 "LOCAL_SIZE_");
// Kernel input: `v` itself when it is already a matrix_cl, otherwise `res`.
48 const matrix_cl<T_scal>& in
49 = static_select<is_matrix_cl<T_vec>::value>(v, res);
50
// Scratch buffers: one entry per launched thread and one per work group.
// NOTE(review): presumably per-thread and per-work-group partial sums;
// confirm against the kernel sources.
51 matrix_cl<T_scal> tmp_threads(local_size * work_groups, 1);
52 matrix_cl<T_scal> tmp_wgs(work_groups, 1);
53 try {
// Launch 1 (call head missing): local_size * work_groups threads in groups
// of local_size, with arguments tmp_wgs, tmp_threads, in, v.size().
55 cl::NDRange(local_size * work_groups), cl::NDRange(local_size), tmp_wgs,
56 tmp_threads, in, v.size());
// Launch 2 (call head and one range line missing): uses an NDRange of
// local_size2 and operates on tmp_wgs, which has work_groups entries.
58 cl::NDRange(local_size2),
59 tmp_wgs, work_groups);
// Launch 3 (call head missing): same geometry as launch 1; takes res, in,
// tmp_threads, tmp_wgs and v.size(). Presumably writes the final sums
// into res.
61 cl::NDRange(local_size * work_groups), cl::NDRange(local_size), res, in,
62 tmp_threads, tmp_wgs, v.size());
63 } catch (const cl::Error& e) {
// Re-raise OpenCL failures as an error naming this function.
64 check_opencl_error("cumulative_sum", e);
65 }
66 return res;
67}
68
69} // namespace math
70} // namespace stan
71
72#endif
73#endif
Represents an arithmetic matrix on the OpenCL device.
Definition matrix_cl.hpp:47
The API to access the methods and values in opencl_context_base.
void check_opencl_error(const char *function, const cl::Error &e)
Throws a domain error specifying the OpenCL error that occurred.
std::vector< cl::Device > & device() noexcept
Returns a vector containing the OpenCL device used to create the context.
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void check_vector(const char *function, const char *name, const Mat &x)
Check that the input is a row vector, a column vector, or a matrix with a single row or column.
auto cumulative_sum(T_vec &&v)
Return the cumulative sum of the specified vector.
typename scalar_type< T >::type scalar_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
Checks if the decayed type of T is a matrix_cl.
Struct containing the cumulative_sum kernels, grouped by scalar type.