Automatic Differentiation
 
indexing_rev.hpp
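#ifndef STAN_MATH_OPENCL_INDEXING_REV_HPP
#define STAN_MATH_OPENCL_INDEXING_REV_HPP
#ifdef STAN_OPENCL

namespace stan {
namespace math {

/**
 * Performs the reverse pass of an indexing operation on the OpenCL device.
 * Chooses one of three kernels depending on how much local memory the
 * device offers relative to the size of adj.
 *
 * @param[in,out] adj adjoints of the indexed matrix; incremented in place
 * @param idx indices used by the forward indexing operation
 * @param res adjoints of the result of the indexing operation
 */
inline void indexing_rev(matrix_cl<double>& adj, const matrix_cl<int>& idx,
                         const matrix_cl<double>& res) {
  int local_mem_size
      = opencl_context.device()[0].getInfo<CL_DEVICE_LOCAL_MEM_SIZE>();
  int preferred_work_groups
      = opencl_context.device()[0].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
  int local_size = 64;
  int n_threads = preferred_work_groups * 16 * local_size;

  try {
    if (local_mem_size > sizeof(double) * adj.size() * local_size * 2) {
      // Enough local memory for a private copy of adj per work-item.
      opencl_kernels::indexing_rev_local_independent(
          cl::NDRange(n_threads), cl::NDRange(local_size), adj, idx, res,
          cl::Local(sizeof(double) * adj.size() * local_size), res.size(),
          adj.size());
    } else if (local_mem_size > sizeof(double) * adj.size()) {
      // Enough local memory for one copy of adj per work-group.
      opencl_kernels::indexing_rev_local_atomic(
          cl::NDRange(n_threads), cl::NDRange(local_size), adj, idx, res,
          cl::Local(sizeof(double) * adj.size()), res.size(), adj.size());
    } else {
      // Fall back to atomic adds directly in global memory.
      opencl_kernels::indexing_rev_global_atomic(
          cl::NDRange(n_threads), cl::NDRange(local_size), adj, idx, res,
          res.size());
    }
  } catch (cl::Error& e) {
    check_opencl_error("indexing reverse pass", e);
  }
}

}  // namespace math
}  // namespace stan

#endif
#endif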
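The branch structure in indexing_rev() is a pure function of the local memory size, the number of adjoints, and the work-group size. The sketch below restates that decision on the host so the thresholds can be checked without a device. The enum IndexingRevKernel and the functions choose_indexing_rev_kernel() and indexing_rev_global_size() are hypothetical names for illustration, not part of Stan Math.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical host-side restatement of the kernel choice in indexing_rev().
enum class IndexingRevKernel { LocalIndependent, LocalAtomic, GlobalAtomic };

// local_mem_size: bytes of local memory (CL_DEVICE_LOCAL_MEM_SIZE)
// adj_size:       number of adjoints being accumulated into
// local_size:     work-items per work-group
IndexingRevKernel choose_indexing_rev_kernel(std::uint64_t local_mem_size,
                                             std::size_t adj_size,
                                             std::size_t local_size) {
  // A private copy of adj per work-item fits; the check requires twice the
  // bytes that this kernel actually allocates.
  if (local_mem_size > sizeof(double) * adj_size * local_size * 2) {
    return IndexingRevKernel::LocalIndependent;
  }
  // One copy of adj shared by the whole work-group fits.
  if (local_mem_size > sizeof(double) * adj_size) {
    return IndexingRevKernel::LocalAtomic;
  }
  // Otherwise accumulate directly into global memory.
  return IndexingRevKernel::GlobalAtomic;
}

// Total work-items launched: 16 work-groups of local_size per compute unit.
std::size_t indexing_rev_global_size(std::size_t compute_units,
                                     std::size_t local_size) {
  return compute_units * 16 * local_size;
}
```

For example, with 48 KiB (49152 bytes) of local memory and local_size = 64, an adj of 10 elements selects the independent kernel (10240 bytes checked), 100 elements falls through to the local-atomic kernel (102400 bytes is too much, 800 bytes fits), and 10000 elements (80000 bytes for a single copy) falls back to global atomics.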

Referenced symbols

matrix_cl
Represents an arithmetic matrix on the OpenCL device. Definition: matrix_cl.hpp:47

opencl_context
The API to access the methods and values in opencl_context_base.

void check_opencl_error(const char* function, const cl::Error& e)
Throws a domain error specifying the OpenCL error that occurred.

std::vector<cl::Device>& device() noexcept
Returns a vector containing the OpenCL device used to create the context.

const kernel_cl<in_out_buffer, in_buffer, in_buffer, cl::LocalSpaceArg, int, int> indexing_rev_local_independent("indexing_rev", {atomic_add_double_device_function, indexing_rev_local_independent_kernel_code})
See the docs for add_batch().

const kernel_cl<in_out_buffer, in_buffer, in_buffer, cl::LocalSpaceArg, int, int> indexing_rev_local_atomic("indexing_rev", {atomic_add_double_device_function, indexing_rev_local_atomic_kernel_code})
See the docs for add_batch().

const kernel_cl<in_out_buffer, in_buffer, in_buffer, int> indexing_rev_global_atomic("indexing_rev", {atomic_add_double_device_function, indexing_rev_global_atomic_kernel_code})
See the docs for add_batch().

void indexing_rev(matrix_cl<double>& adj, const matrix_cl<int>& idx, const matrix_cl<double>& res)
Performs the reverse pass of an indexing operation on the OpenCL device.
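What the three kernels compute can be stated sequentially: the reverse pass of indexing scatters each result adjoint back onto the adjoint it was read from. The reference below is a minimal CPU sketch, assuming idx holds one 0-based flat index into adj per element of res; indexing_rev_reference() is a hypothetical name, not a Stan Math function.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical sequential reference for the indexing reverse pass.
// Assumption: idx[i] is the 0-based position in adj that res[i] was read
// from during the forward pass.
void indexing_rev_reference(std::vector<double>& adj,
                            const std::vector<int>& idx,
                            const std::vector<double>& res) {
  if (idx.size() != res.size()) {
    throw std::invalid_argument("idx and res must have the same size");
  }
  for (std::size_t i = 0; i < res.size(); ++i) {
    // Repeated indices accumulate; at() throws on out-of-range indices
    // (a negative int converts to a huge size_t and is rejected too).
    adj.at(static_cast<std::size_t>(idx[i])) += res[i];
  }
}
```

Because idx may repeat an index, two work-items on the device can target the same element of adj. Judging by the local-memory sizes requested, the device versions handle this either with atomic adds or by giving each work-item its own copy of adj.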
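All three kernels are built together with atomic_add_double_device_function. OpenCL 1.2 has no built-in atomic add for double, and a common workaround is a compare-and-swap loop. The sketch below shows that technique as a host-side C++ analogue using std::atomic<double>; it is an assumption about the general approach, not the contents of atomic_add_double_device_function, and atomic_add_double() and parallel_scatter_add() are hypothetical names.

```cpp
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical CAS-based atomic add for double (host-side analogue).
void atomic_add_double(std::atomic<double>& target, double value) {
  double expected = target.load();
  // On failure, compare_exchange_weak reloads `expected` with the current
  // value, so the loop retries until no other thread intervened.
  while (!target.compare_exchange_weak(expected, expected + value)) {
  }
}

// Scatter-add res into adj from several threads; colliding indices are
// resolved by the CAS loop instead of losing updates.
void parallel_scatter_add(std::vector<std::atomic<double>>& adj,
                          const std::vector<int>& idx,
                          const std::vector<double>& res,
                          unsigned n_threads) {
  std::vector<std::thread> workers;
  for (unsigned t = 0; t < n_threads; ++t) {
    workers.emplace_back([&, t] {
      // Grid-stride loop: a fixed number of workers covers every element.
      for (std::size_t i = t; i < res.size(); i += n_threads) {
        atomic_add_double(adj[idx[i]], res[i]);
      }
    });
  }
  for (auto& w : workers) {
    w.join();
  }
}
```

On a device, the same loop would typically compare-and-swap the value reinterpreted as a 64-bit integer, which needs 64-bit atomics support. The independent kernel appears to sidestep this contention by accumulating into per-work-item copies instead.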