1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_OPTIONALBROADCAST_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_OPTIONALBROADCAST_HPP
27template <
typename T,
bool Colwise,
bool Rowwise>
29 :
public operation_cl<optional_broadcast_<T, Colwise, Rowwise>,
30 typename std::remove_reference_t<T>::Scalar, T> {
32 using Scalar =
typename std::remove_reference_t<T>::Scalar;
48 auto&& arg_copy = this->
template get_arg<0>().deep_copy();
50 Colwise, Rowwise>{std::move(arg_copy)};
63 const std::string& col_idx_name,
64 const bool view_handled,
65 const std::string& var_name_arg)
const {
68 += type_str<Scalar>() +
" " +
var_name_ +
" = " + var_name_arg +
";\n";
84 std::string& col_idx_name)
const {
86 row_idx_name =
"(" + row_idx_name +
" * " +
var_name_ +
"is_multirow)";
89 col_idx_name =
"(" + col_idx_name +
" * " +
var_name_ +
"is_multicol)";
104 std::unordered_map<const void*, const char*>& generated,
105 std::unordered_map<const void*, const char*>& generated_all,
106 cl::Kernel& kernel,
int& arg_num)
const {
107 if (generated.count(
this) == 0) {
108 generated[
this] =
"";
109 std::unordered_map<const void*, const char*> generated2;
110 this->
template get_arg<0>().set_args(generated2, generated_all, kernel,
113 kernel.setArg(arg_num++,
static_cast<int>(
114 this->
template get_arg<0>().
rows() != 1));
117 kernel.setArg(arg_num++,
static_cast<int>(
118 this->
template get_arg<0>().
cols() != 1));
129 return Colwise && this->
template get_arg<0>().rows() == 1
131 : this->
template get_arg<0>().rows();
140 return Rowwise && this->
template get_arg<0>().cols() == 1
142 : this->
template get_arg<0>().cols();
151 if (Colwise && this->
template get_arg<0>().
rows() == 1) {
154 if (Rowwise && this->
template get_arg<0>().
cols() == 1) {
165 if (Colwise && this->
template get_arg<0>().
rows() == 1) {
166 return std::numeric_limits<int>::min();
168 return this->
template get_arg<0>().bottom_diagonal();
177 if (Rowwise && this->
template get_arg<0>().
cols() == 1) {
178 return std::numeric_limits<int>::max();
180 return this->
template get_arg<0>().top_diagonal();
200template <
bool Colwise,
bool Rowwise,
typename T,
202inline optional_broadcast_<as_operation_cl_t<T>, Colwise, Rowwise>
206 std::move(a_operation));
224 return optional_broadcast<false, true>(std::forward<T>(a));
242 return optional_broadcast<true, false>(std::forward<T>(a));
static constexpr int dynamic
Base for all kernel generator operations.
typename std::remove_reference_t< T >::Scalar Scalar
int rows() const
Number of rows of a matrix that would be the result of evaluating this expression.
int bottom_diagonal() const
Determine index of bottom diagonal written.
matrix_cl_view view() const
View of a matrix that would be the result of evaluating this expression.
int cols() const
Number of columns of a matrix that would be the result of evaluating this expression.
kernel_parts generate(const std::string &row_idx_name, const std::string &col_idx_name, const bool view_handled, const std::string &var_name_arg) const
Generates kernel code for this and nested expressions.
void modify_argument_indices(std::string &row_idx_name, std::string &col_idx_name) const
Sets the index/indices along the broadcast dimension(s) to 0.
int top_diagonal() const
Determine index of top diagonal written.
optional_broadcast_(T &&a)
Constructor.
auto deep_copy() const
Creates a deep copy of this expression.
void set_args(std::unordered_map< const void *, const char * > &generated, std::unordered_map< const void *, const char * > &generated_all, cl::Kernel &kernel, int &arg_num) const
Sets kernel arguments for this and nested expressions.
Represents an optional broadcasting operation in kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
require_all_t< is_kernel_expression_and_not_scalar< Types >... > require_all_kernel_expressions_and_none_scalar_t
Enables a template if all given types are non-scalar types that are valid kernel generator expressions.
const matrix_cl_view either(const matrix_cl_view left_view, const matrix_cl_view right_view)
Determines which parts are nonzero in any of the input views.
optional_broadcast_< as_operation_cl_t< T >, Colwise, Rowwise > optional_broadcast(T &&a)
Broadcast an expression in the specified dimension(s) if the size along that dimension equals 1.
auto colwise_optional_broadcast(T &&a)
Broadcast an expression in the colwise dimension if the number of rows equals 1.
auto rowwise_optional_broadcast(T &&a)
Broadcast an expression in the rowwise dimension if the number of columns equals 1.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.