Automatic Differentiation
 
Loading...
Searching...
No Matches
colwise_reduction.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_COLWISE_REDUCTION_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_COLWISE_REDUCTION_HPP
3#ifdef STAN_OPENCL
4
14#include <map>
15#include <string>
16#include <type_traits>
17#include <utility>
18
19namespace stan {
20namespace math {
25namespace internal {
27
35inline int colwise_reduction_wgs_rows(int n_rows, int n_cols) {
36 int local = opencl_context.base_opts().at("LOCAL_SIZE_");
37 int preferred_work_groups
38 = opencl_context.device()[0].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>() * 16;
39 // round up n_rows/local/n_cols
40 return (std::min(preferred_work_groups, (n_rows + local - 1) / local) + n_cols
41 - 1)
42 / n_cols;
43}
44} // namespace internal
45
60template <typename Derived, typename T, typename Operation>
63 public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
64 T> {
65 public:
66 using Scalar = typename std::remove_reference_t<T>::Scalar;
68 using base::var_name_;
69 static constexpr bool require_specific_local_size = true;
70
71 protected:
72 std::string init_;
73 using base::derived;
74
75 public:
76 using base::cols;
82 explicit colwise_reduction(T&& a, const std::string& init)
83 : base(std::forward<T>(a)), init_(init) {}
84
/**
 * Generates kernel code for assigning this expression into result
 * expression.
 *
 * The method:
 * - collects the kernel parts of this expression;
 * - collects the left-hand-side kernel parts of `result`;
 * - merges `result`'s kernel arguments into the returned parts;
 * - appends code to the 1D reduction section. In that code, the thread with
 *   `lid_i == 0` writes its work group's reduced value (`<this>_local[0]`)
 *   to `<result>_global[j * n_groups_i + wg_id_i]`.
 *
 * @tparam T_result type of the result expression
 * @param[in,out] generated bookkeeping map keyed by object address, passed
 * through to get_kernel_parts and get_kernel_parts_lhs
 * @param[in,out] generated_all bookkeeping map keyed by object address,
 * passed through in the same way
 * @param ng name generator for this kernel
 * @param row_index_name row index variable name
 * @param col_index_name column index variable name
 * @param result expression that receives the per-work-group results
 * @return kernel parts for this expression and the assignment into `result`
 */
97 template <typename T_result>
99 std::unordered_map<const void*, const char*>& generated,
100 std::unordered_map<const void*, const char*>& generated_all,
101 name_generator& ng, const std::string& row_index_name,
102 const std::string& col_index_name, const T_result& result) const {
// NOTE(review): the trailing `false` is presumably the `view_handled` flag
// (cf. generate()); confirm against operation_cl::get_kernel_parts.
103 kernel_parts parts = derived().get_kernel_parts(
104 generated, generated_all, ng, row_index_name, col_index_name, false);
105 kernel_parts out_parts = result.get_kernel_parts_lhs(
106 generated, generated_all, ng, row_index_name, col_index_name);
107
108 parts.args += out_parts.args;
// Only the first thread of each work group stores that group's result.
// `j`, `n_groups_i` and `wg_id_i` are presumably defined by the enclosing
// kernel template; they are not declared here.
109 parts.reduction_1d += "if (lid_i == 0) {\n"
110 + result.var_name_
111 + "_global[j * n_groups_i + wg_id_i] = "
112 + derived().var_name_ + "_local[0];\n"
113 "}\n";
114 return parts;
115 }
116
/**
 * Generates kernel code for this and nested expressions.
 *
 * Emits the following kernel parts:
 * - declarations: a `__local` buffer `<var>_local[LOCAL_SIZE_]` and an
 *   accumulator `<var>`, both of type `Scalar`;
 * - initialization: `<var> = <init_>;`
 * - body: `<var> = Operation::generate(<var>, var_name_arg);`, which folds
 *   the argument into the accumulator;
 * - reduction_1d: a work-group tree reduction over `<var>_local`. At each
 *   step, threads with `lid_i < step` combine REDUCTION_STEP_SIZE - 1 more
 *   elements at stride `step`. The work group's result ends up in
 *   `<var>_local[0]`.
 *
 * @param row_index_name row index variable name (not used by this method)
 * @param col_index_name column index variable name (not used by this method)
 * @param view_handled not used by this method
 * @param var_name_arg variable name of the argument expression, combined
 * into the accumulator
 * @return part of kernel with code for this expression
 */
126 inline kernel_parts generate(const std::string& row_index_name,
127 const std::string& col_index_name,
128 const bool view_handled,
129 const std::string& var_name_arg) const {
130 kernel_parts res;
131 res.declarations = "__local " + type_str<Scalar>() + " " + var_name_
132 + "_local[LOCAL_SIZE_];\n" + type_str<Scalar>() + " "
133 + var_name_ + ";\n";
134 res.initialization = var_name_ + " = " + init_ + ";\n";
135 res.body = var_name_ + " = " + Operation::generate(var_name_, var_name_arg)
136 + ";\n";
// NOTE(review): the step loop starts at lsize_i / REDUCTION_STEP_SIZE and
// divides by REDUCTION_STEP_SIZE each round. Every element of <var>_local is
// folded into index 0 only if lsize_i is a power of REDUCTION_STEP_SIZE.
// Example: lsize_i = 8 with a step size of 4 leaves <var>_local[1]
// uncombined. Presumably the local size is chosen to satisfy this (see
// require_specific_local_size); confirm.
137 res.reduction_1d =
138 var_name_ + "_local[lid_i] = " + var_name_ + ";\n"
139 "barrier(CLK_LOCAL_MEM_FENCE);\n"
140 "for (int step = lsize_i / REDUCTION_STEP_SIZE; "
141 "step > 0; step /= REDUCTION_STEP_SIZE) {\n"
142 " if (lid_i < step) {\n"
143 " for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {\n"
144 " " + var_name_ + "_local[lid_i] = " +
145 Operation::generate(var_name_ + "_local[lid_i]",
146 var_name_ + "_local[lid_i + step * i]") + ";\n"
147 " }\n"
148 " }\n"
149 " barrier(CLK_LOCAL_MEM_FENCE);\n"
150 "}\n";
151 return res;
152 }
153
159 inline int rows() const {
160 int arg_rows = this->template get_arg<0>().rows();
161 int arg_cols = this->template get_arg<0>().cols();
162 if (arg_cols == 0) {
163 return 1;
164 }
165 if (arg_cols == -1) {
166 return -1;
167 }
168 return internal::colwise_reduction_wgs_rows(arg_rows, arg_cols);
169 }
170
175 inline int thread_rows() const { return this->template get_arg<0>().rows(); }
176
181 inline std::pair<int, int> extreme_diagonals() const {
182 return {-rows() + 1, cols() - 1};
183 }
184}; // namespace math
185
190template <typename T>
193 using base::arguments_;
194
195 public:
196 explicit colwise_sum_(T&& a)
197 : colwise_reduction<colwise_sum_<T>, T, sum_op>(std::forward<T>(a), "0") {
198 }
203 inline auto deep_copy() const {
204 auto&& arg_copy = this->template get_arg<0>().deep_copy();
205 return colwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
206 std::move(arg_copy));
207 }
208};
209
/**
 * Column wise sum - reduction of a kernel generator expression.
 *
 * The returned expression is presumably a colwise_sum_. Its rows() comes from
 * colwise_reduction and equals the number of work groups in the rows
 * direction, so the result can hold several partial sums per column.
 * NOTE(review): callers presumably need to reduce those partial sums further
 * to get the final column sums; confirm.
 *
 * @tparam T type of the input expression (must satisfy
 * require_all_kernel_expressions_t)
 * @param a the expression to reduce
 * @return column wise sum
 */
223template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
224inline auto colwise_sum(T&& a) {
// NOTE(review): `a` is passed through std::forward<T>(a) twice. The first
// time builds `arg_copy`; the second builds the argument handed on to the
// result. If T is an rvalue and as_operation_cl moves from its argument, the
// second call sees a moved-from object. Consider constructing the result
// from `std::move(arg_copy)`, as the deep_copy() members do.
225 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
227 as_operation_cl(std::forward<T>(a)));
228}
229
234template <typename T>
235class colwise_prod_ : public colwise_reduction<colwise_prod_<T>, T, prod_op> {
237 using base::arguments_;
238
239 public:
240 explicit colwise_prod_(T&& a)
241 : colwise_reduction<colwise_prod_<T>, T, prod_op>(std::forward<T>(a),
242 "1") {}
247 inline auto deep_copy() const {
248 auto&& arg_copy = this->template get_arg<0>().deep_copy();
249 return colwise_prod_<std::remove_reference_t<decltype(arg_copy)>>(
250 std::move(arg_copy));
251 }
252};
253
/**
 * Column wise product - reduction of a kernel generator expression.
 *
 * The returned expression is presumably a colwise_prod_. Its rows() comes
 * from colwise_reduction and equals the number of work groups in the rows
 * direction, so the result can hold several partial products per column.
 * NOTE(review): callers presumably need to reduce those partial products
 * further; confirm.
 *
 * @tparam T type of the input expression (must satisfy
 * require_all_kernel_expressions_t)
 * @param a the expression to reduce
 * @return column wise product
 */
267template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
268inline auto colwise_prod(T&& a) {
// NOTE(review): `a` is passed through std::forward<T>(a) twice. The first
// time builds `arg_copy`; the second builds the argument handed on to the
// result. If T is an rvalue and as_operation_cl moves from its argument, the
// second call sees a moved-from object. Consider constructing the result
// from `std::move(arg_copy)`, as the deep_copy() members do.
269 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
271 as_operation_cl(std::forward<T>(a)));
272}
273
278template <typename T>
280 colwise_max_<T>, T,
281 max_op<typename std::remove_reference_t<T>::Scalar>> {
282 using base
285 using base::arguments_;
286
287 public:
289 explicit colwise_max_(T&& a)
290 : colwise_reduction<colwise_max_<T>, T, op>(std::forward<T>(a),
291 op::init()) {}
296 inline auto deep_copy() const {
297 auto&& arg_copy = this->template get_arg<0>().deep_copy();
298 return colwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
299 std::move(arg_copy));
300 }
301};
302
/**
 * Column wise max - reduction of a kernel generator expression.
 *
 * The returned expression is presumably a colwise_max_. Its rows() comes from
 * colwise_reduction and equals the number of work groups in the rows
 * direction, so the result can hold several partial maxima per column.
 * NOTE(review): callers presumably need to reduce those partial maxima
 * further; confirm.
 *
 * @tparam T type of the input expression (must satisfy
 * require_all_kernel_expressions_t)
 * @param a the expression to reduce
 * @return column wise max
 */
316template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
317inline auto colwise_max(T&& a) {
// NOTE(review): `a` is passed through std::forward<T>(a) twice. The first
// time builds `arg_copy`; the second builds the argument handed on to the
// result. If T is an rvalue and as_operation_cl moves from its argument, the
// second call sees a moved-from object. Consider constructing the result
// from `std::move(arg_copy)`, as the deep_copy() members do.
318 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
320 as_operation_cl(std::forward<T>(a)));
321}
322
327template <typename T>
329 colwise_min_<T>, T,
330 min_op<typename std::remove_reference_t<T>::Scalar>> {
331 using base
334 using base::arguments_;
335
336 public:
338 explicit colwise_min_(T&& a)
339 : colwise_reduction<colwise_min_<T>, T, op>(std::forward<T>(a),
340 op::init()) {}
345 inline auto deep_copy() const {
346 auto&& arg_copy = this->template get_arg<0>().deep_copy();
347 return colwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
348 std::move(arg_copy));
349 }
350};
351
/**
 * Column wise min - reduction of a kernel generator expression.
 *
 * The returned expression is presumably a colwise_min_. Its rows() comes from
 * colwise_reduction and equals the number of work groups in the rows
 * direction, so the result can hold several partial minima per column.
 * NOTE(review): unlike colwise_sum/colwise_prod/colwise_max, no deep_copy()
 * call is visible here; confirm whether that asymmetry is intended.
 *
 * @tparam T type of the input expression (must satisfy
 * require_all_kernel_expressions_t)
 * @param a the expression to reduce
 * @return column wise min
 */
365template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
366inline auto colwise_min(T&& a) {
368 as_operation_cl(std::forward<T>(a)));
369}
370
371namespace internal {
372template <typename T>
374 : public std::is_base_of<internal::colwise_reduction_base,
375 std::decay_t<T>> {};
376template <typename T>
378 : public std::is_base_of<internal::colwise_reduction_base,
379 std::decay_t<T>> {};
380} // namespace internal
381
385template <typename T>
388
390} // namespace math
391} // namespace stan
392#endif
393#endif
Represents a calc_if in kernel generator expressions.
Definition calc_if.hpp:31
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise max - reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise min - reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise product - reduction in kernel generator expressions.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions.
static constexpr bool require_specific_local_size
kernel_parts get_whole_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const T_result &result) const
Generates kernel code for assigning this expression into result expression.
int thread_rows() const
Number of rows threads need to be launched for.
Derived & derived()
Casts the instance into its derived type.
typename std::remove_reference_t< T >::Scalar Scalar
colwise_reduction(T &&a, const std::string &init)
Constructor.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Represents a column wise reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise sum - reduction in kernel generator expressions.
Unique name generator for variables used in generated kernels.
The API to access the methods and values in opencl_context_base.
Derived & derived()
Casts the instance into its derived type.
std::tuple< Args... > arguments_
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
opencl_context_base::map_base_opts & base_opts() noexcept
Returns a reference to the map of kernel defines.
std::vector< cl::Device > & device() noexcept
Returns a vector containing the OpenCL device used to create the context.
auto colwise_min(T &&a)
Column wise min - reduction of a kernel generator expression.
auto colwise_prod(T &&a)
Column wise product - reduction of a kernel generator expression.
auto colwise_max(T &&a)
Column wise max - reduction of a kernel generator expression.
auto colwise_sum(T &&a)
Column wise sum - reduction of a kernel generator expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
int colwise_reduction_wgs_rows(int n_rows, int n_cols)
Determine the number of work groups in the rows direction that will be run for a colwise reduction of the given siz...
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.
Operation for max reduction.
Operation for min reduction.
Operation for product reduction.
Operation for sum reduction.