1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
26template <
typename Arg>
34 const Arg& a, std::unordered_map<const void*, const char*>& generated,
35 std::unordered_map<const void*, const char*>& generated_all,
37 const std::string& col_index_name) {
42template <
typename Mat,
typename VecT>
56 return a.template get_arg<1>().template get_arg<0>().
view();
74 const Arg& mul, std::unordered_map<const void*, const char*>& generated,
75 std::unordered_map<const void*, const char*>& generated_all,
77 const std::string& col_index_name) {
79 if (generated.count(&mul) == 0) {
83 const auto& matrix = mul.template get_arg<0>();
84 const auto&
broadcast = mul.template get_arg<1>();
85 res = matrix.get_kernel_parts(generated, generated_all, name_gen,
86 row_index_name, col_index_name,
true);
91 const auto& vec =
broadcast.template get_arg<0>();
92 std::string row_index_name_bc = row_index_name;
93 std::string col_index_name_bc = col_index_name;
94 broadcast.modify_argument_indices(row_index_name_bc, col_index_name_bc);
95 res += vec.get_kernel_parts(generated, generated_all, name_gen,
96 row_index_name_bc, col_index_name_bc,
true);
97 res +=
broadcast.generate(row_index_name, col_index_name,
true,
100 res += mul.
generate(row_index_name, col_index_name,
true,
121template <
typename Derived,
typename T,
typename operation,
bool PassZero>
123 :
public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
127 using Scalar =
typename T_no_ref::Scalar;
157 std::unordered_map<const void*, const char*>& generated,
158 std::unordered_map<const void*, const char*>& generated_all,
160 const std::string& col_index_name,
bool view_handled)
const {
162 if (generated.count(
this) == 0) {
164 generated[
this] =
"";
166 std::unordered_map<const void*, const char*> generated2;
169 this->
template get_arg<0>(), generated2, generated_all, name_gen,
172 res = this->
template get_arg<0>().get_kernel_parts(
173 generated2, generated_all, name_gen, row_index_name,
174 var_name_ +
"_j", view_handled || PassZero);
177 =
generate(row_index_name, col_index_name, view_handled,
180 res.
body = res.body_prefix + res.body;
181 res.body_prefix =
"";
196 const std::string& col_index_name,
197 const bool view_handled,
198 const std::string& var_name_arg)
const {
204 +
var_name_ +
"_view, LOWER) ? 0 : " + row_index_name
210 + row_index_name +
" + 1);\n";
219 + row_index_name +
" + 1);\n";
249 std::unordered_map<const void*, const char*>& generated,
250 std::unordered_map<const void*, const char*>& generated_all,
251 cl::Kernel& kernel,
int& arg_num)
const {
252 if (generated.count(
this) == 0) {
253 generated[
this] =
"";
254 std::unordered_map<const void*, const char*> generated2;
255 this->
template get_arg<0>().set_args(generated2, generated_all, kernel,
257 kernel.setArg(arg_num++, this->
template get_arg<0>().
view());
258 kernel.setArg(arg_num++, this->
template get_arg<0>().
cols());
261 this->
template get_arg<0>()));
271 inline int cols()
const {
return 1; }
292 inline static std::string
generate(
const std::string& a,
293 const std::string& b) {
294 return a +
" + " + b;
306 using base::arguments_;
316 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
317 return rowwise_sum_<std::remove_reference_t<
decltype(arg_copy)>>(
318 std::move(arg_copy));
332 return rowwise_sum_<std::remove_reference_t<
decltype(arg_copy)>>(
333 std::move(arg_copy));
346 inline static std::string
generate(
const std::string& a,
347 const std::string& b) {
348 return a +
" * " + b;
370 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
371 return rowwise_prod_<std::remove_reference_t<
decltype(arg_copy)>>(
372 std::move(arg_copy));
386 return rowwise_prod_<std::remove_reference_t<
decltype(arg_copy)>>(
387 std::move(arg_copy));
402 inline static std::string
generate(
const std::string& a,
403 const std::string& b) {
404 if (std::is_floating_point<T>()) {
405 return "fmax(" + a +
", " + b +
")";
407 return "max(" + a +
", " + b +
")";
410 inline static std::string
init() {
411 if (std::is_floating_point<T>()) {
426 max_op<typename std::remove_reference_t<T>::Scalar>, false> {
438 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
439 return rowwise_max_<std::remove_reference_t<
decltype(arg_copy)>>(
440 std::move(arg_copy));
454 return rowwise_max_<std::remove_reference_t<
decltype(arg_copy)>>(
455 std::move(arg_copy));
469 inline static std::string
generate(
const std::string& a,
470 const std::string& b) {
471 if (std::is_floating_point<T>()) {
472 return "fmin(" + a +
", " + b +
")";
474 return "min(" + a +
", " + b +
")";
477 inline static std::string
init() {
478 if (std::is_floating_point<T>()) {
493 min_op<typename std::remove_reference_t<T>::Scalar>, false> {
505 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
506 return rowwise_min_<std::remove_reference_t<
decltype(arg_copy)>>(
507 std::move(arg_copy));
521 return rowwise_min_<std::remove_reference_t<
decltype(arg_copy)>>(
522 std::move(arg_copy));
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_a, const std::string &var_name_b) const
Generates kernel code for this expression.
Represents a broadcasting operation in kernel generator expressions.
std::string generate()
Generates a unique variable name.
Unique name generator for variables used in generated kernels.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
std::tuple< Args... > arguments_
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
auto deep_copy() const
Creates a deep copy of this expression.
max_op< typename std::remove_reference_t< T >::Scalar > op
Represents rowwise max reduction in kernel generator expressions.
min_op< typename std::remove_reference_t< T >::Scalar > op
auto deep_copy() const
Creates a deep copy of this expression.
Represents rowwise min reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents rowwise product reduction in kernel generator expressions.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
kernel_parts get_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name, bool view_handled) const
Generates kernel code for this and nested expressions.
typename T_no_ref::Scalar Scalar
std::remove_reference_t< T > T_no_ref
rowwise_reduction(T &&a, const std::string &init)
Constructor.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
Represents a rowwise reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents rowwise sum reduction in kernel generator expressions.
auto rowwise_max(T &&a)
Rowwise max reduction of a kernel generator expression.
auto rowwise_min(T &&a)
Rowwise min reduction of a kernel generator expression.
auto rowwise_sum(T &&a)
Rowwise sum reduction of a kernel generator expression.
auto rowwise_prod(T &&a)
Rowwise product reduction of a kernel generator expression.
auto broadcast(T &&a)
Broadcast an expression in specified dimension(s).
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are a valid kernel generator expression.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
static matrix_cl_view view(const Arg &a)
Return view of the vector.
static kernel_parts get_kernel_parts(const Arg &mul, std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name)
Generates kernel code for the argument of rowwise reduction, applying the optimization - ignoring the triangular view of the vector, as it is already handled by rowwise reduction.
static matrix_cl_view view(const Arg &)
static kernel_parts get_kernel_parts(const Arg &a, std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &name_gen, const std::string &row_index_name, const std::string &col_index_name)
Implementation of an optimization for usage of rowwise reduction in matrix-vector multiplication.
Parts of an OpenCL kernel, generated by an expression.
static std::string generate(const std::string &a, const std::string &b)
Generates max reduction kernel code.
static std::string init()
Operation for max reduction.
static std::string generate(const std::string &a, const std::string &b)
Generates min reduction kernel code.
static std::string init()
Operation for min reduction.
static std::string generate(const std::string &a, const std::string &b)
Generates prod reduction kernel code.
Operation for product reduction.
static std::string generate(const std::string &a, const std::string &b)
Generates sum reduction kernel code.
Operation for sum reduction.