Automatic Differentiation
 
Loading...
Searching...
No Matches
reduction_2d.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_REDUCTION_2D_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_REDUCTION_2D_HPP
3#ifdef STAN_OPENCL
4
15#include <map>
16#include <string>
17#include <type_traits>
18#include <utility>
19
20namespace stan {
21namespace math {
26namespace internal {
28} // namespace internal
29
44template <typename Derived, typename T, typename Operation>
47 public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
48 T> {
49 public:
50 using Scalar = typename std::remove_reference_t<T>::Scalar;
52 using base::var_name_;
53 static constexpr bool require_specific_local_size = true;
54
55 protected:
56 std::string init_;
57 using base::derived;
58
59 public:
60 using base::cols;
66 explicit reduction_2d(T&& a, const std::string& init)
67 : base(std::forward<T>(a)), init_(init) {}
68
81 template <typename T_result>
83 std::unordered_map<const void*, const char*>& generated,
84 std::unordered_map<const void*, const char*>& generated_all,
85 name_generator& ng, const std::string& row_index_name,
86 const std::string& col_index_name, const T_result& result) const {
87 kernel_parts parts = derived().get_kernel_parts(
88 generated, generated_all, ng, row_index_name, col_index_name, false);
89 kernel_parts out_parts = result.get_kernel_parts_lhs(
90 generated, generated_all, ng, row_index_name, col_index_name);
91
92 parts.args += out_parts.args;
93 parts.reduction_2d += "if (lid_i == 0) {\n"
94 + result.var_name_
95 + "_global[wg_id_j * n_groups_i + wg_id_i] = "
96 + derived().var_name_ + "_local[0];\n"
97 "}\n";
98 return parts;
99 }
100
// Generates kernel code for this expression.
// Fills three fields of the returned kernel_parts: `declarations`, `body` and
// `reduction_2d`. `row_index_name`, `col_index_name` and `view_handled` are
// not referenced in this function body.
// `var_name_arg` is the text passed to Operation::generate as the second
// operand when updating the accumulator.
110 inline kernel_parts generate(const std::string& row_index_name,
111 const std::string& col_index_name,
112 const bool view_handled,
113 const std::string& var_name_arg) const {
114 kernel_parts res;
// Declares a `__local` array `<var_name_>_local[LOCAL_SIZE_]` and an
// accumulator `<var_name_>` that starts at the text held in init_.
115 res.declarations = "__local " + type_str<Scalar>() + " " + var_name_
116 + "_local[LOCAL_SIZE_];\n" + type_str<Scalar>() + " "
117 + var_name_ + " = " + init_ + ";\n";
// Folds the argument into the accumulator with the reduction Operation.
118 res.body = var_name_ + " = " + Operation::generate(var_name_, var_name_arg)
119 + ";\n";
// Stores the accumulator at `_local[lid_i]`, then loops with `step` starting
// at lsize_i / REDUCTION_STEP_SIZE and divided by REDUCTION_STEP_SIZE each
// pass. Entries with lid_i < step combine the values at lid_i + step * i for
// i in [1, REDUCTION_STEP_SIZE). A barrier follows the initial store and
// every pass.
120 res.reduction_2d =
121 var_name_ + "_local[lid_i] = " + var_name_ + ";\n"
122 "barrier(CLK_LOCAL_MEM_FENCE);\n"
123 "for (int step = lsize_i / REDUCTION_STEP_SIZE; "
124 "step > 0; step /= REDUCTION_STEP_SIZE) {\n"
125 " if (lid_i < step) {\n"
126 " for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {\n"
127 " " + var_name_ + "_local[lid_i] = " +
128 Operation::generate(var_name_ + "_local[lid_i]",
129 var_name_ + "_local[lid_i + step * i]") + ";\n"
130 " }\n"
131 " }\n"
132 " barrier(CLK_LOCAL_MEM_FENCE);\n"
133 "}\n";
134 return res;
135 }
136
142 inline int rows() const {
143 int arg_rows = this->template get_arg<0>().rows();
144 int arg_cols = this->template get_arg<0>().cols();
145 if (arg_cols == 0) {
146 return 1;
147 }
148 if (arg_cols == base::dynamic) {
149 return base::dynamic;
150 }
151 return internal::colwise_reduction_wgs_rows(arg_rows, arg_cols);
152 }
153
159 inline int cols() const {
160 int arg_rows = this->template get_arg<0>().rows();
161 int arg_cols = this->template get_arg<0>().cols();
162 if (arg_cols == 0) {
163 return 0;
164 }
165 if (arg_cols == base::dynamic) {
166 return base::dynamic;
167 }
168 int wgs_rows = internal::colwise_reduction_wgs_rows(arg_rows, arg_cols);
169 if (wgs_rows == 0) {
170 return 0;
171 }
172 return (arg_cols + wgs_rows - 1) / wgs_rows;
173 }
174
179 inline int thread_rows() const { return this->template get_arg<0>().rows(); }
180
185 inline int thread_cols() const { return this->template get_arg<0>().cols(); }
186
191 inline std::pair<int, int> extreme_diagonals() const {
192 return {-rows() + 1, cols() - 1};
193 }
194}; // class reduction_2d
195
200template <typename T>
203 using base::arguments_;
204
205 public:
206 explicit sum_2d_(T&& a)
207 : reduction_2d<sum_2d_<T>, T, sum_op>(std::forward<T>(a), "0") {}
212 inline auto deep_copy() const {
213 auto&& arg_copy = this->template get_arg<0>().deep_copy();
214 return sum_2d_<std::remove_reference_t<decltype(arg_copy)>>(
215 std::move(arg_copy));
216 }
217};
218
232template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
233inline auto sum_2d(T&& a) {
234 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
235 return sum_2d_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a)));
236}
237
243template <typename T>
244class prod_2d_ : public reduction_2d<prod_2d_<T>, T, prod_op> {
246 using base::arguments_;
247
248 public:
249 explicit prod_2d_(T&& a)
250 : reduction_2d<prod_2d_<T>, T, prod_op>(std::forward<T>(a), "1") {}
255 inline auto deep_copy() const {
256 auto&& arg_copy = this->template get_arg<0>().deep_copy();
257 return prod_2d_<std::remove_reference_t<decltype(arg_copy)>>(
258 std::move(arg_copy));
259 }
260};
261
275template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
276inline auto prod_2d(T&& a) {
277 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
278 return prod_2d_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a)));
279}
280
285template <typename T>
287 : public reduction_2d<max_2d_<T>, T,
288 max_op<typename std::remove_reference_t<T>::Scalar>> {
289 using base
292 using base::arguments_;
293
294 public:
296 explicit max_2d_(T&& a)
297 : reduction_2d<max_2d_<T>, T, op>(std::forward<T>(a), op::init()) {}
302 inline auto deep_copy() const {
303 auto&& arg_copy = this->template get_arg<0>().deep_copy();
304 return max_2d_<std::remove_reference_t<decltype(arg_copy)>>(
305 std::move(arg_copy));
306 }
307};
308
322template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
323inline auto max_2d(T&& a) {
324 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
325 return max_2d_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a)));
326}
327
332template <typename T>
334 : public reduction_2d<min_2d_<T>, T,
335 min_op<typename std::remove_reference_t<T>::Scalar>> {
336 using base
339 using base::arguments_;
340
341 public:
343 explicit min_2d_(T&& a)
344 : reduction_2d<min_2d_<T>, T, op>(std::forward<T>(a), op::init()) {}
349 inline auto deep_copy() const {
350 auto&& arg_copy = this->template get_arg<0>().deep_copy();
351 return min_2d_<std::remove_reference_t<decltype(arg_copy)>>(
352 std::move(arg_copy));
353 }
354};
355
369template <typename T, require_all_kernel_expressions_t<T>* = nullptr>
370inline auto min_2d(T&& a) {
371 return min_2d_<as_operation_cl_t<T>>(as_operation_cl(std::forward<T>(a)));
372}
373
374namespace internal {
375template <typename T>
377 : public std::is_base_of<internal::reduction_2d_base, std::decay_t<T>> {};
378template <typename T>
380 : public std::is_base_of<internal::reduction_2d_base, std::decay_t<T>> {};
381} // namespace internal
382
386template <typename T>
388
390} // namespace math
391} // namespace stan
392#endif
393#endif
Represents a calc_if in kernel generator expressions.
Definition calc_if.hpp:31
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional max - reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional min - reduction in kernel generator expressions.
Unique name generator for variables used in generated kernels.
static constexpr int dynamic
Derived & derived()
Casts the instance into its derived type.
std::tuple< Args... > arguments_
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional product - reduction in kernel generator expressions.
kernel_parts get_whole_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const T_result &result) const
Generates kernel code for assigning this expression into result expression.
int thread_rows() const
Number of rows threads need to be launched for.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Derived & derived()
Casts the instance into its derived type.
int thread_cols() const
Number of columns threads need to be launched for.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
typename std::remove_reference_t< T >::Scalar Scalar
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
reduction_2d(T &&a, const std::string &init)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions.
static constexpr bool require_specific_local_size
Represents a two dimensional reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional sum - reduction in kernel generator expressions.
auto sum_2d(T &&a)
Two dimensional sum - reduction of a kernel generator expression.
auto min_2d(T &&a)
Two dimensional min - reduction of a kernel generator expression.
auto prod_2d(T &&a)
Two dimensional product - reduction of a kernel generator expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
auto max_2d(T &&a)
Two dimensional max - reduction of a kernel generator expression.
int colwise_reduction_wgs_rows(int n_rows, int n_cols)
Determine number of work groups in rows direction that will be run for colwise reduction of given siz...
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.
Operation for max reduction.
Operation for min reduction.
Operation for product reduction.
Operation for sum reduction.