1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_REDUCTION_2D_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_REDUCTION_2D_HPP
44template <
typename Derived,
typename T,
typename Operation>
47 public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
50 using Scalar =
typename std::remove_reference_t<T>::Scalar;
81 template <
typename T_result>
83 std::unordered_map<const void*, const char*>& generated,
84 std::unordered_map<const void*, const char*>& generated_all,
86 const std::string& col_index_name,
const T_result& result)
const {
88 generated, generated_all, ng, row_index_name, col_index_name,
false);
90 generated, generated_all, ng, row_index_name, col_index_name);
95 +
"_global[wg_id_j * n_groups_i + wg_id_i] = "
96 +
derived().var_name_ +
"_local[0];\n"
111 const std::string& col_index_name,
112 const bool view_handled,
113 const std::string& var_name_arg)
const {
116 +
"_local[LOCAL_SIZE_];\n" + type_str<Scalar>() +
" "
122 "barrier(CLK_LOCAL_MEM_FENCE);\n"
123 "for (int step = lsize_i / REDUCTION_STEP_SIZE; "
124 "step > 0; step /= REDUCTION_STEP_SIZE) {\n"
125 " if (lid_i < step) {\n"
126 " for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {\n"
128 Operation::generate(
var_name_ +
"_local[lid_i]",
129 var_name_ +
"_local[lid_i + step * i]") +
";\n"
132 " barrier(CLK_LOCAL_MEM_FENCE);\n"
143 int arg_rows = this->
template get_arg<0>().rows();
144 int arg_cols = this->
template get_arg<0>().cols();
160 int arg_rows = this->
template get_arg<0>().rows();
161 int arg_cols = this->
template get_arg<0>().cols();
172 return (arg_cols + wgs_rows - 1) / wgs_rows;
179 inline int thread_rows()
const {
return this->
template get_arg<0>().rows(); }
185 inline int thread_cols()
const {
return this->
template get_arg<0>().cols(); }
203 using base::arguments_;
213 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
214 return sum_2d_<std::remove_reference_t<
decltype(arg_copy)>>(
215 std::move(arg_copy));
232template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
256 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
257 return prod_2d_<std::remove_reference_t<
decltype(arg_copy)>>(
258 std::move(arg_copy));
275template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
288 max_op<typename std::remove_reference_t<T>::Scalar>> {
303 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
304 return max_2d_<std::remove_reference_t<
decltype(arg_copy)>>(
305 std::move(arg_copy));
322template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
335 min_op<typename std::remove_reference_t<T>::Scalar>> {
350 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
351 return min_2d_<std::remove_reference_t<
decltype(arg_copy)>>(
352 std::move(arg_copy));
369template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
377 :
public std::is_base_of<internal::reduction_2d_base, std::decay_t<T>> {};
380 :
public std::is_base_of<internal::reduction_2d_base, std::decay_t<T>> {};
Represents a calc_if in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional max - reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional min - reduction in kernel generator expressions.
Unique name generator for variables used in generated kernels.
static constexpr int dynamic
Derived & derived()
Casts the instance into its derived type.
std::tuple< Args... > arguments_
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional product - reduction in kernel generator expressions.
kernel_parts get_whole_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const T_result &result) const
Generates kernel code for assigning this expression into result expression.
int thread_rows() const
Number of rows threads need to be launched for.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Derived & derived()
Casts the instance into its derived type.
int thread_cols() const
Number of rows threads need to be launched for.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
typename std::remove_reference_t< T >::Scalar Scalar
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
reduction_2d(T &&a, const std::string &init)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions.
static constexpr bool require_specific_local_size
Represents a two dimensional reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents two dimensional sum - reduction in kernel generator expressions.
auto sum_2d(T &&a)
Two dimensional sum - reduction of a kernel generator expression.
auto min_2d(T &&a)
Two dimensional min - reduction of a kernel generator expression.
auto prod_2d(T &&a)
Two dimensional product - reduction of a kernel generator expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
auto max_2d(T &&a)
Two dimensional max - reduction of a kernel generator expression.
int colwise_reduction_wgs_rows(int n_rows, int n_cols)
Determine number of work groups in rows direction that will be run fro colwise reduction of given siz...
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.
Operation for max reduction.
Operation for min reduction.
Operation for product reduction.
Operation for sum reduction.