1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_SELECT_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_SELECT_HPP
35template <
typename T_condition,
typename T_then,
typename T_else>
37 common_scalar_t<T_then, T_else>,
38 T_condition, T_then, T_else> {
42 T_condition, T_then, T_else>;
51 select_(T_condition&& condition, T_then&& then, T_else&& els)
52 :
base(
std::forward<T_condition>(condition),
std::forward<T_then>(then),
53 std::forward<T_else>(els)) {
56 "rows of ",
"then", then.rows());
60 "columns of ",
"then", then.cols());
65 "rows of ",
"else", els.rows());
69 "columns of ",
"else", els.cols());
78 auto&& condition_copy = this->
template get_arg<0>().deep_copy();
79 auto&& then_copy = this->
template get_arg<1>().deep_copy();
80 auto&& else_copy = this->
template get_arg<2>().deep_copy();
81 return select_<std::remove_reference_t<
decltype(condition_copy)>,
82 std::remove_reference_t<
decltype(then_copy)>,
83 std::remove_reference_t<
decltype(else_copy)>>(
84 std::move(condition_copy), std::move(then_copy), std::move(else_copy));
98 const std::string& col_index_name,
99 const bool view_handled,
100 const std::string& var_name_condition,
101 const std::string& var_name_then,
102 const std::string& var_name_else)
const {
104 res.
body = type_str<Scalar>() +
" " +
var_name_ +
" = " + var_name_condition
105 +
" ? " + var_name_then +
" : " + var_name_else +
";\n";
116 std::pair<int, int> condition_diags
117 = this->
template get_arg<0>().extreme_diagonals();
118 std::pair<int, int> then_diags
119 = this->
template get_arg<1>().extreme_diagonals();
120 std::pair<int, int> else_diags
121 = this->
template get_arg<2>().extreme_diagonals();
124 return {
max(
min(then_diags.first, else_diags.first),
125 min(condition_diags.first, else_diags.first)),
126 min(
max(then_diags.second, else_diags.second),
127 max(condition_diags.second, else_diags.second))};
143 typename T_condition,
typename T_then,
typename T_else,
146inline select_<as_operation_cl_t<T_condition>, as_operation_cl_t<T_then>,
147 as_operation_cl_t<T_else>>
148select(T_condition&& condition, T_then&& then, T_else&& els) {
static constexpr int dynamic
Base for all kernel generator operations.
common_scalar_t< T_then, T_else > Scalar
auto deep_copy() const
Creates a deep copy of this expression.
std::pair< int, int > extreme_diagonals() const
Determines the indices of the extreme sub- and superdiagonals written.
select_(T_condition &&condition, T_then &&then, T_else &&els)
Constructor.
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_condition, const std::string &var_name_then, const std::string &var_name_else) const
Generates kernel code for this (select) operation.
Represents a selection operation in kernel generator expressions.
require_any_not_t< std::is_arithmetic< std::decay_t< Types > >... > require_any_not_arithmetic_t
Requires that at least one of the types does not satisfy std::is_arithmetic.
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
auto min(T1 x, T2 y)
Returns the minimum value of the two specified scalar arguments.
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
typename std::common_type_t< typename std::remove_reference_t< Types >::Scalar... > common_scalar_t
Wrapper for std::common_type_t
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Checks whether the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant-safe lgamma_r implementation ...
Parts of an OpenCL kernel, generated by an expression.