1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_COLWISE_REDUCTION_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_COLWISE_REDUCTION_HPP
37 int preferred_work_groups
40 return (std::min(preferred_work_groups, (n_rows + local - 1) / local) + n_cols
60template <
typename Derived,
typename T,
typename Operation>
63 public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
66 using Scalar =
typename std::remove_reference_t<T>::Scalar;
97 template <
typename T_result>
99 std::unordered_map<const void*, const char*>& generated,
100 std::unordered_map<const void*, const char*>& generated_all,
102 const std::string& col_index_name,
const T_result& result)
const {
104 generated, generated_all, ng, row_index_name, col_index_name,
false);
106 generated, generated_all, ng, row_index_name, col_index_name);
111 +
"_global[j * n_groups_i + wg_id_i] = "
112 +
derived().var_name_ +
"_local[0];\n"
127 const std::string& col_index_name,
128 const bool view_handled,
129 const std::string& var_name_arg)
const {
132 +
"_local[LOCAL_SIZE_];\n" + type_str<Scalar>() +
" "
139 "barrier(CLK_LOCAL_MEM_FENCE);\n"
140 "for (int step = lsize_i / REDUCTION_STEP_SIZE; "
141 "step > 0; step /= REDUCTION_STEP_SIZE) {\n"
142 " if (lid_i < step) {\n"
143 " for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {\n"
145 Operation::generate(
var_name_ +
"_local[lid_i]",
146 var_name_ +
"_local[lid_i + step * i]") +
";\n"
149 " barrier(CLK_LOCAL_MEM_FENCE);\n"
160 int arg_rows = this->
template get_arg<0>().rows();
161 int arg_cols = this->
template get_arg<0>().cols();
165 if (arg_cols == -1) {
175 inline int thread_rows()
const {
return this->
template get_arg<0>().rows(); }
193 using base::arguments_;
204 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
205 return colwise_sum_<std::remove_reference_t<
decltype(arg_copy)>>(
206 std::move(arg_copy));
223template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
248 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
249 return colwise_prod_<std::remove_reference_t<
decltype(arg_copy)>>(
250 std::move(arg_copy));
267template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
281 max_op<typename std::remove_reference_t<T>::Scalar>> {
297 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
298 return colwise_max_<std::remove_reference_t<
decltype(arg_copy)>>(
299 std::move(arg_copy));
316template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
330 min_op<typename std::remove_reference_t<T>::Scalar>> {
346 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
347 return colwise_min_<std::remove_reference_t<
decltype(arg_copy)>>(
348 std::move(arg_copy));
365template <
typename T, require_all_kernel_expressions_t<T>* =
nullptr>
374 :
public std::is_base_of<internal::colwise_reduction_base,
378 :
public std::is_base_of<internal::colwise_reduction_base,
Represents a calc_if in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise max - reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise min - reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise product - reduction in kernel generator expressions.
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions.
static constexpr bool require_specific_local_size
kernel_parts get_whole_kernel_parts(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, name_generator &ng, const std::string &row_index_name, const std::string &col_index_name, const T_result &result) const
Generates kernel code for assigning this expression into result expression.
int thread_rows() const
Number of rows threads need to be launched for.
Derived & derived()
Casts the instance into its derived type.
typename std::remove_reference_t< T >::Scalar Scalar
colwise_reduction(T &&a, const std::string &init)
Constructor.
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Represents a column wise reduction in kernel generator expressions.
auto deep_copy() const
Creates a deep copy of this expression.
Represents column wise sum - reduction in kernel generator expressions.
Unique name generator for variables used in generated kernels.
The API to access the methods and values in opencl_context_base.
Derived & derived()
Casts the instance into its derived type.
std::tuple< Args... > arguments_
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
Base for all kernel generator operations.
opencl_context_base::map_base_opts & base_opts() noexcept
Returns a reference to the map of kernel defines.
std::vector< cl::Device > & device() noexcept
Returns a vector containing the OpenCL device used to create the context.
auto colwise_min(T &&a)
Column wise min - reduction of a kernel generator expression.
auto colwise_prod(T &&a)
Column wise product - reduction of a kernel generator expression.
auto colwise_max(T &&a)
Column wise max - reduction of a kernel generator expression.
auto colwise_sum(T &&a)
Column wise sum - reduction of a kernel generator expression.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
int colwise_reduction_wgs_rows(int n_rows, int n_cols)
Determines the number of work groups in the rows direction that will be run for a column-wise reduction of the given size.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
std::string initialization
Parts of an OpenCL kernel, generated by an expression.
Operation for max reduction.
Operation for min reduction.
Operation for product reduction.
Operation for sum reduction.