Automatic Differentiation
 
Loading...
Searching...
No Matches
rowwise_reduction.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
3#ifdef STAN_OPENCL
4
13#include <map>
14#include <string>
15#include <type_traits>
16#include <utility>
17
18namespace stan {
19namespace math {
20namespace internal {
21
26template <typename Arg>
// Primary template of internal::matvec_mul_opt. It is the fallback for every
// rowwise-reduction argument type that does not match the specialization
// below, and it disables the matrix-vector multiplication optimization.
// NOTE(review): listing line 27 (presumably the struct head
// `struct matvec_mul_opt {`, matching the specialization below) is missing
// from this extraction.
28 // in general the optimization is not possible
29 enum { is_possible = 0 };
30
// Returns matrix_cl_view::Entire for any argument. rowwise_reduction::set_args
// passes this value as the `<name>_vec_view` kernel argument.
31 static matrix_cl_view view(const Arg&) { return matrix_cl_view::Entire; }
32
// NOTE(review): listing line 33, the head of this function, is missing. The
// generated documentation gives it as
// `static kernel_parts get_kernel_parts(const Arg& a, ...)`.
// Emits no code and always returns an empty kernel_parts.
34 const Arg& a, std::unordered_map<const void*, const char*>& generated,
35 std::unordered_map<const void*, const char*>& generated_all,
36 name_generator& name_gen, const std::string& row_index_name,
37 const std::string& col_index_name) {
38 return {};
39 }
40};
41
42template <typename Mat, typename VecT>
// Specialization of internal::matvec_mul_opt for an elementwise multiplication
// of a matrix expression with a broadcast vector,
// elt_multiply_<Mat, broadcast_<VecT, true, false>>. This is presumably the
// shape a matrix-vector product takes when written as a rowwise sum (TODO
// confirm the meaning of the two bool flags in broadcast.hpp).
43struct matvec_mul_opt<elt_multiply_<Mat, broadcast_<VecT, true, false>>> {
44 // if the argument of rowwise reduction is multiplication with a broadcast
45 // vector we can do the optimization
46 enum { is_possible = 1 };
// NOTE(review): listing line 47 is missing. `Arg` is used below but is not a
// template parameter here, so the line presumably declares
// `using Arg = elt_multiply_<Mat, broadcast_<VecT, true, false>>;`.
48
// Returns the view of the broadcast vector: argument 0 of the broadcast, which
// is itself argument 1 of the multiplication. rowwise_reduction::set_args
// passes it as the `<name>_vec_view` kernel argument.
55 static matrix_cl_view view(const Arg& a) {
56 return a.template get_arg<1>().template get_arg<0>().view();
57 }
58
// NOTE(review): listing lines 59-73 are missing. They include the head of this
// function, given by the generated documentation as
// `static kernel_parts get_kernel_parts(const Arg& mul, ...)`.
// Generates kernel code for the multiplication, its matrix operand, the
// broadcast and the broadcast vector. The multiplication and the broadcast are
// each generated at most once, guarded by `generated`. Every nested call passes
// view_handled = true, so no view checks are emitted here. The enclosing
// rowwise_reduction bounds its loop with `<name>_view` and `<name>_vec_view`
// instead.
74 const Arg& mul, std::unordered_map<const void*, const char*>& generated,
75 std::unordered_map<const void*, const char*>& generated_all,
76 name_generator& name_gen, const std::string& row_index_name,
77 const std::string& col_index_name) {
78 kernel_parts res{};
79 if (generated.count(&mul) == 0) {
80 mul.var_name_ = name_gen.generate();
81 generated[&mul] = "";
82
83 const auto& matrix = mul.template get_arg<0>();
84 const auto& broadcast = mul.template get_arg<1>();
85 res = matrix.get_kernel_parts(generated, generated_all, name_gen,
86 row_index_name, col_index_name, true);
87 if (generated.count(&broadcast) == 0) {
88 broadcast.var_name_ = name_gen.generate();
89 generated[&broadcast] = "";
90
91 const auto& vec = broadcast.template get_arg<0>();
// Copies of the index names are adjusted by the broadcast and then used to
// read the vector; the originals are still used for the broadcast itself.
92 std::string row_index_name_bc = row_index_name;
93 std::string col_index_name_bc = col_index_name;
94 broadcast.modify_argument_indices(row_index_name_bc, col_index_name_bc);
95 res += vec.get_kernel_parts(generated, generated_all, name_gen,
96 row_index_name_bc, col_index_name_bc, true);
97 res += broadcast.generate(row_index_name, col_index_name, true,
98 vec.var_name_);
99 }
100 res += mul.generate(row_index_name, col_index_name, true,
101 matrix.var_name_, broadcast.var_name_);
102 }
103 return res;
104 }
105};
106
107} // namespace internal
108
// Base class of the rowwise reductions (rowwise_sum_, rowwise_prod_,
// rowwise_max_ and rowwise_min_ below). It reduces each row of the argument
// expression to one value, so the result has a single column (see cols()).
// Template parameters:
//   Derived   - the concrete reduction type, passed on to operation_cl
//   T         - type of the argument expression
//   operation - type whose static generate(acc, x) returns the kernel code of
//               one reduction step
//   PassZero  - when true, the loop bounds are narrowed using the argument's
//               view (contains_nonzero checks in generate()); among the
//               derived classes below only the sum sets it
121template <typename Derived, typename T, typename operation, bool PassZero>
// NOTE(review): listing line 122, the class head, is missing from this
// extraction. The constructor and the generated documentation name the class
// rowwise_reduction.
123 : public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
124 T> {
125 public:
126 using T_no_ref = std::remove_reference_t<T>;
127 using Scalar = typename T_no_ref::Scalar;
// NOTE(review): listing line 128 is missing. `base` is used below, so it
// presumably declares `using base = operation_cl<...>;` (TODO confirm).
129 using base::var_name_;
130
131 protected:
// Kernel source for the accumulator's initial value (e.g. "0" for the sum).
132 std::string init_;
133
134 public:
135 using base::rows;
// Stores the argument expression in the base, and `init` as the kernel source
// for the accumulator's initial value.
141 explicit rowwise_reduction(T&& a, const std::string& init)
142 : base(std::forward<T>(a)), init_(init) {}
143
// NOTE(review): listing lines 144-156 are missing. They include the head of
// get_kernel_parts (see the generated documentation for its full signature).
// Generates kernel code for this reduction and its argument, once per
// expression (guarded by `generated`). The argument is generated with a fresh,
// empty map `generated2`, with `<name>_j` as its column index, so its code is
// emitted anew for the reduction loop.
157 std::unordered_map<const void*, const char*>& generated,
158 std::unordered_map<const void*, const char*>& generated_all,
159 name_generator& name_gen, const std::string& row_index_name,
160 const std::string& col_index_name, bool view_handled) const {
161 kernel_parts res{};
162 if (generated.count(this) == 0) {
163 this->var_name_ = name_gen.generate();
164 generated[this] = "";
165
166 std::unordered_map<const void*, const char*> generated2;
// NOTE(review): listing lines 167-168 are missing. The continued call below
// matches internal::matvec_mul_opt<...>::get_kernel_parts, so they presumably
// open an `if` on internal::matvec_mul_opt<T_no_ref>::is_possible (possibly
// combined with PassZero) and begin that call (TODO confirm).
169 this->template get_arg<0>(), generated2, generated_all, name_gen,
170 row_index_name, var_name_ + "_j");
171 } else {
// With PassZero the loop bounds already account for the view, presumably
// why the argument is told its view is handled in that case.
172 res = this->template get_arg<0>().get_kernel_parts(
173 generated2, generated_all, name_gen, row_index_name,
174 var_name_ + "_j", view_handled || PassZero);
175 }
176 kernel_parts my_part
177 = generate(row_index_name, col_index_name, view_handled,
178 this->template get_arg<0>().var_name_);
179 res += my_part;
// Move the accumulated body_prefix, which includes this reduction's
// accumulator declaration and loop header from generate(), in front of the
// body, then clear it.
180 res.body = res.body_prefix + res.body;
181 res.body_prefix = "";
182 }
183 return res;
184 }
185
// Generates the kernel code of the reduction itself:
// - declares the accumulator `<name>`, initialized to init_;
// - opens a loop over columns `<name>_j`;
// - in the body, assigns operation::generate(<name>, var_name_arg) to the
//   accumulator and closes the loop.
// It also declares the kernel arguments `<name>_view` and `<name>_cols`, and
// conditionally `<name>_vec_view`. view_handled is not read in the visible
// code.
195 inline kernel_parts generate(const std::string& row_index_name,
196 const std::string& col_index_name,
197 const bool view_handled,
198 const std::string& var_name_arg) const {
199 kernel_parts res;
200 res.body_prefix
201 = type_str<Scalar>() + " " + var_name_ + " = " + init_ + ";\n";
202 if (PassZero) {
// Start at column 0 only if the view may have a nonzero lower part;
// otherwise start at the diagonal (the row index).
203 res.body_prefix += "int " + var_name_ + "_start = contains_nonzero("
204 + var_name_ + "_view, LOWER) ? 0 : " + row_index_name
205 + ";\n";
// NOTE(review): listing line 206 is missing. The `} else {` below shows it
// opens a conditional, presumably on
// internal::matvec_mul_opt<T_no_ref>::is_possible, since this branch reads
// `<name>_vec_view` (TODO confirm).
// End at the last column if the view may have a nonzero upper part;
// otherwise stop just after the diagonal.
207 res.body_prefix += "int " + var_name_ + "_end_temp = contains_nonzero("
208 + var_name_ + "_view, UPPER) ? " + var_name_
209 + "_cols : min(" + var_name_ + "_cols, "
210 + row_index_name + " + 1);\n";
// If the vector's view reports no nonzero upper part, limit the loop to
// column 0.
211 res.body_prefix += "int " + var_name_ + "_end = contains_nonzero("
212 + var_name_ + "_vec_view, UPPER) ? " + var_name_
213 + "_end_temp : min(1, " + var_name_
214 + "_end_temp);\n";
215 } else {
216 res.body_prefix += "int " + var_name_ + "_end = contains_nonzero("
217 + var_name_ + "_view, UPPER) ? " + var_name_
218 + "_cols : min(" + var_name_ + "_cols, "
219 + row_index_name + " + 1);\n";
220 }
221 res.body_prefix += "for(int " + var_name_ + "_j = " + var_name_
222 + "_start; " + var_name_ + "_j < " + var_name_
223 + "_end; " + var_name_ + "_j++){\n";
224 } else {
// Without PassZero every column is visited.
225 res.body_prefix += "for(int " + var_name_ + "_j = 0; " + var_name_
226 + "_j < " + var_name_ + "_cols; " + var_name_
227 + "_j++){\n";
228 }
229 res.body += var_name_ + " = " + operation::generate(var_name_, var_name_arg)
230 + ";\n}\n";
// Kernel parameters; their order must match the order set_args() sets them.
231 res.args = "int " + var_name_ + "_view, int " + var_name_ + "_cols, ";
// NOTE(review): listing line 232 is missing. It opens a conditional closed
// below, presumably on internal::matvec_mul_opt<T_no_ref>::is_possible
// (TODO confirm).
233 res.args += "int " + var_name_ + "_vec_view, ";
234 }
235 return res;
236 }
237
// Sets the kernel arguments of the argument expression first (again with a
// fresh map `generated2`). It then sets this reduction's own arguments in the
// order generate() declares them: view, cols, and conditionally the vector
// view.
248 inline void set_args(
249 std::unordered_map<const void*, const char*>& generated,
250 std::unordered_map<const void*, const char*>& generated_all,
251 cl::Kernel& kernel, int& arg_num) const {
252 if (generated.count(this) == 0) {
253 generated[this] = "";
254 std::unordered_map<const void*, const char*> generated2;
255 this->template get_arg<0>().set_args(generated2, generated_all, kernel,
256 arg_num);
257 kernel.setArg(arg_num++, this->template get_arg<0>().view());
258 kernel.setArg(arg_num++, this->template get_arg<0>().cols());
// NOTE(review): listing line 259 is missing. It opens a conditional closed
// below, presumably matching the one guarding `<name>_vec_view` in
// generate() (TODO confirm).
260 kernel.setArg(arg_num++, internal::matvec_mul_opt<T_no_ref>::view(
261 this->template get_arg<0>()));
262 }
263 }
264 }
265
// A rowwise reduction produces a single column.
271 inline int cols() const { return 1; }
272
// Returns {-rows() + 1, cols() - 1}; with cols() == 1 this is {1 - rows(), 0}.
277 inline std::pair<int, int> extreme_diagonals() const {
278 return {-rows() + 1, cols() - 1};
279 }
280};
281
/**
 * Operation for sum reduction.
 */
struct sum_op {
  /**
   * Generates sum reduction kernel code.
   * @param a kernel code of the first summand (the accumulator)
   * @param b kernel code of the second summand
   * @return kernel code computing `a + b`
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    return a + " + " + b;
  }
};
297
// Rowwise sum: a rowwise_reduction using sum_op, initial value "0" and
// PassZero = true.
302template <typename T>
// NOTE(review): listing line 303 (presumably the class head
// `class rowwise_sum_`) is missing from this extraction.
304 : public rowwise_reduction<rowwise_sum_<T>, T, sum_op, true> {
// NOTE(review): listing line 305 is missing. `base` is used below, so it
// presumably declares `using base = rowwise_reduction<...>;` (TODO confirm).
306 using base::arguments_;
307
308 public:
309 explicit rowwise_sum_(T&& a) : base(std::forward<T>(a), "0") {}
310
// Returns a new rowwise_sum_ constructed from a deep copy of the argument
// expression.
315 inline auto deep_copy() const {
316 auto&& arg_copy = this->template get_arg<0>().deep_copy();
317 return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
318 std::move(arg_copy));
319 }
320};
321
// Rowwise sum reduction of a kernel generator expression. Converts `a` with
// as_operation_cl, deep-copies it, and returns a rowwise_sum_ built from the
// moved copy.
328template <typename T,
// NOTE(review): listing line 329, the rest of the template parameter list, is
// missing. It is presumably a constraint requiring T to be a kernel generator
// expression (TODO confirm).
330inline auto rowwise_sum(T&& a) {
331 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
332 return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
333 std::move(arg_copy));
334}
335
/**
 * Operation for product reduction.
 */
struct prod_op {
  /**
   * Generates prod reduction kernel code.
   * @param a kernel code of the first factor (the accumulator)
   * @param b kernel code of the second factor
   * @return kernel code computing `a * b`
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    return a + " * " + b;
  }
};
351
// Rowwise product: a rowwise_reduction using prod_op, initial value "1" and
// PassZero = false, so every column is visited.
356template <typename T>
// NOTE(review): listing line 357 (presumably the class head
// `class rowwise_prod_`) is missing from this extraction.
358 : public rowwise_reduction<rowwise_prod_<T>, T, prod_op, false> {
// NOTE(review): listing line 359 is missing. `base` is used below, so it
// presumably declares `using base = rowwise_reduction<...>;` (TODO confirm).
360 using base::arguments_;
361
362 public:
363 explicit rowwise_prod_(T&& a) : base(std::forward<T>(a), "1") {}
364
// Returns a new rowwise_prod_ constructed from a deep copy of the argument
// expression.
369 inline auto deep_copy() const {
370 auto&& arg_copy = this->template get_arg<0>().deep_copy();
371 return rowwise_prod_<std::remove_reference_t<decltype(arg_copy)>>(
372 std::move(arg_copy));
373 }
374};
375
// Rowwise product reduction of a kernel generator expression. Converts `a`
// with as_operation_cl, deep-copies it, and returns a rowwise_prod_ built from
// the moved copy.
382template <typename T,
// NOTE(review): listing line 383, the rest of the template parameter list, is
// missing. It is presumably a constraint requiring T to be a kernel generator
// expression (TODO confirm).
384inline auto rowwise_prod(T&& a) {
385 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
386 return rowwise_prod_<std::remove_reference_t<decltype(arg_copy)>>(
387 std::move(arg_copy));
388}
389
/**
 * Operation for max reduction.
 * @tparam T scalar type of the reduced expression
 */
template <typename T>
struct max_op {
  /**
   * Generates max reduction kernel code.
   * @param a kernel code of the first argument (the accumulator)
   * @param b kernel code of the second argument
   * @return kernel code computing the larger of `a` and `b`: OpenCL `fmax` for
   * floating point `T`, integer `max` otherwise
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    if (std::is_floating_point<T>::value) {
      return "fmax(" + a + ", " + b + ")";
    }
    return "max(" + a + ", " + b + ")";
  }

  /**
   * Generates the initial value of a max reduction.
   * @return "-INFINITY" for floating point `T`, "INT_MIN" otherwise
   */
  inline static std::string init() {
    // NOTE(review): every non-floating-point T is assumed to be int; a wider
    // integer scalar would need a different limit (e.g. LONG_MIN).
    if (std::is_floating_point<T>::value) {
      return "-INFINITY";
    }
    return "INT_MIN";
  }
};
417
// Rowwise max: a rowwise_reduction using max_op for the argument's scalar type
// and PassZero = false, so every column is visited.
422template <typename T>
// NOTE(review): listing line 423 (presumably the class head
// `class rowwise_max_`) is missing from this extraction.
424 : public rowwise_reduction<
425 rowwise_max_<T>, T,
426 max_op<typename std::remove_reference_t<T>::Scalar>, false> {
// NOTE(review): listing lines 427-428 are missing. `base` and `op` are used
// below, so they presumably declare `using base = rowwise_reduction<...>;` and
// `using op = max_op<typename std::remove_reference_t<T>::Scalar>;` (the
// generated documentation lists `op` with that type).
429 using base::arguments_;
430
431 public:
// The accumulator starts at op::init(): "-INFINITY" for floating point
// scalars, "INT_MIN" otherwise.
432 explicit rowwise_max_(T&& a) : base(std::forward<T>(a), op::init()) {}
// Returns a new rowwise_max_ constructed from a deep copy of the argument
// expression.
437 inline auto deep_copy() const {
438 auto&& arg_copy = this->template get_arg<0>().deep_copy();
439 return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
440 std::move(arg_copy));
441 }
442};
443
// Rowwise max reduction of a kernel generator expression. Converts `a` with
// as_operation_cl, deep-copies it, and returns a rowwise_max_ built from the
// moved copy.
450template <typename T,
// NOTE(review): listing line 451, the rest of the template parameter list, is
// missing. It is presumably a constraint requiring T to be a kernel generator
// expression (TODO confirm).
452inline auto rowwise_max(T&& a) {
453 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
454 return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
455 std::move(arg_copy));
456}
/**
 * Operation for min reduction.
 * @tparam T scalar type of the reduced expression
 */
template <typename T>
struct min_op {
  /**
   * Generates min reduction kernel code.
   * @param a kernel code of the first argument (the accumulator)
   * @param b kernel code of the second argument
   * @return kernel code computing the smaller of `a` and `b`: OpenCL `fmin`
   * for floating point `T`, integer `min` otherwise
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    if (std::is_floating_point<T>::value) {
      return "fmin(" + a + ", " + b + ")";
    }
    return "min(" + a + ", " + b + ")";
  }

  /**
   * Generates the initial value of a min reduction.
   * @return "INFINITY" for floating point `T`, "INT_MAX" otherwise
   */
  inline static std::string init() {
    // NOTE(review): every non-floating-point T is assumed to be int; a wider
    // integer scalar would need a different limit (e.g. LONG_MAX).
    if (std::is_floating_point<T>::value) {
      return "INFINITY";
    }
    return "INT_MAX";
  }
};
484
// Rowwise min: a rowwise_reduction using min_op for the argument's scalar type
// and PassZero = false, so every column is visited.
489template <typename T>
// NOTE(review): listing line 490 (presumably the class head
// `class rowwise_min_`) is missing from this extraction.
491 : public rowwise_reduction<
492 rowwise_min_<T>, T,
493 min_op<typename std::remove_reference_t<T>::Scalar>, false> {
// NOTE(review): listing lines 494-495 are missing. `base` and `op` are used
// below, so they presumably declare `using base = rowwise_reduction<...>;` and
// `using op = min_op<typename std::remove_reference_t<T>::Scalar>;` (the
// generated documentation lists `op` with that type).
496 using base::arguments_;
497
498 public:
// The accumulator starts at op::init(): "INFINITY" for floating point
// scalars, "INT_MAX" otherwise.
499 explicit rowwise_min_(T&& a) : base(std::forward<T>(a), op::init()) {}
// Returns a new rowwise_min_ constructed from a deep copy of the argument
// expression.
504 inline auto deep_copy() const {
505 auto&& arg_copy = this->template get_arg<0>().deep_copy();
506 return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
507 std::move(arg_copy));
508 }
509};
510
// Rowwise min reduction of a kernel generator expression. Converts `a` with
// as_operation_cl, deep-copies it, and returns a rowwise_min_ built from the
// moved copy.
517template <typename T,
// NOTE(review): listing line 518, the rest of the template parameter list, is
// missing. It is presumably a constraint requiring T to be a kernel generator
// expression (TODO confirm).
519inline auto rowwise_min(T&& a) {
520 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
521 return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
522 std::move(arg_copy));
523}
525} // namespace math
526} // namespace stan
527
528#endif
529#endif
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_a, const std::string &var_name_b) const
Generates kernel code for this expression.
Represents a broadcasting operation in kernel generator expressions.
Definition broadcast.hpp:33
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
std::tuple< Args... > arguments_
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
auto deep_copy() const
Creates a deep copy of this expression.
max_op< typename std::remove_reference_t< T >::Scalar > op
Represents rowwise max reduction in kernel generator expressions.
min_op< typename std::remove_reference_t< T >::Scalar > op
auto deep_copy() const
Creates a deep copy of this expression.
Represents rowwise min reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents rowwise product reduction in kernel generator expressions.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
typename T_no_ref::Scalar Scalar
std::remove_reference_t< T > T_no_ref
rowwise_reduction(T &&a, const std::string &init)
Constructor.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Represents a rowwise reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents rowwise sum reduction in kernel generator expressions.
auto rowwise_max(T &&a)
Rowwise max reduction of a kernel generator expression.
auto rowwise_min(T &&a)
Rowwise min reduction of a kernel generator expression.
auto rowwise_sum(T &&a)
Rowwise sum reduction of a kernel generator expression.
auto rowwise_prod(T &&a)
Rowwise product reduction of a kernel generator expression.
auto broadcast(T &&a)
Broadcast an expression in specified dimension(s).
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expressi...
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
STL namespace.
static kernel_parts get_kernel_parts(const Arg &mul, std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name)
Generates kernel code for the argument of rowwise reduction, applying the optimization - ignoring the...
static matrix_cl_view view(const Arg &)
static kernel_parts get_kernel_parts(const Arg &a, std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name)
Implementation of an optimization for usage of rowwise reduction in matrix-vector multiplication.
Parts of an OpenCL kernel, generated by an expression.
static std::string generate(const std::string &a, const std::string &b)
Generates max reduction kernel code.
static std::string init()
Operation for max reduction.
static std::string generate(const std::string &a, const std::string &b)
Generates min reduction kernel code.
static std::string init()
Operation for min reduction.
static std::string generate(const std::string &a, const std::string &b)
Generates prod reduction kernel code.
Operation for product reduction.
static std::string generate(const std::string &a, const std::string &b)
Generates sum reduction kernel code.
Operation for sum reduction.