Automatic Differentiation
 
Loading...
Searching...
No Matches
select.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_SELECT_HPP
2#define STAN_MATH_OPENCL_KERNEL_GENERATOR_SELECT_HPP
3#ifdef STAN_OPENCL
4
13#include <algorithm>
14#include <set>
15#include <string>
16#include <type_traits>
17#include <utility>
18
19namespace stan {
20namespace math {
21
35template <typename T_condition, typename T_then, typename T_else>
36class select_ : public operation_cl<select_<T_condition, T_then, T_else>,
37 common_scalar_t<T_then, T_else>,
38 T_condition, T_then, T_else> {
39 public:
42 T_condition, T_then, T_else>;
43 using base::var_name_;
44
51 select_(T_condition&& condition, T_then&& then, T_else&& els) // NOLINT
52 : base(std::forward<T_condition>(condition), std::forward<T_then>(then),
53 std::forward<T_else>(els)) {
54 if (condition.rows() != base::dynamic && then.rows() != base::dynamic) {
55 check_size_match("select", "Rows of ", "condition", condition.rows(),
56 "rows of ", "then", then.rows());
57 }
58 if (condition.cols() != base::dynamic && then.cols() != base::dynamic) {
59 check_size_match("select", "Columns of ", "condition", condition.cols(),
60 "columns of ", "then", then.cols());
61 }
62
63 if (condition.rows() != base::dynamic && els.rows() != base::dynamic) {
64 check_size_match("select", "Rows of ", "condition", condition.rows(),
65 "rows of ", "else", els.rows());
66 }
67 if (condition.cols() != base::dynamic && els.cols() != base::dynamic) {
68 check_size_match("select", "Columns of ", "condition", condition.cols(),
69 "columns of ", "else", els.cols());
70 }
71 }
72
77 inline auto deep_copy() const {
78 auto&& condition_copy = this->template get_arg<0>().deep_copy();
79 auto&& then_copy = this->template get_arg<1>().deep_copy();
80 auto&& else_copy = this->template get_arg<2>().deep_copy();
81 return select_<std::remove_reference_t<decltype(condition_copy)>,
82 std::remove_reference_t<decltype(then_copy)>,
83 std::remove_reference_t<decltype(else_copy)>>(
84 std::move(condition_copy), std::move(then_copy), std::move(else_copy));
85 }
86
97 inline kernel_parts generate(const std::string& row_index_name,
98 const std::string& col_index_name,
99 const bool view_handled,
100 const std::string& var_name_condition,
101 const std::string& var_name_then,
102 const std::string& var_name_else) const {
103 kernel_parts res{};
104 res.body = type_str<Scalar>() + " " + var_name_ + " = " + var_name_condition
105 + " ? " + var_name_then + " : " + var_name_else + ";\n";
106 return res;
107 }
108
113 inline std::pair<int, int> extreme_diagonals() const {
114 using std::max;
115 using std::min;
116 std::pair<int, int> condition_diags
117 = this->template get_arg<0>().extreme_diagonals();
118 std::pair<int, int> then_diags
119 = this->template get_arg<1>().extreme_diagonals();
120 std::pair<int, int> else_diags
121 = this->template get_arg<2>().extreme_diagonals();
122 // Where the condition is 0 we get else's values. Otherwise we get the more
123 // extreme of then's and else's.
124 return {max(min(then_diags.first, else_diags.first),
125 min(condition_diags.first, else_diags.first)),
126 min(max(then_diags.second, else_diags.second),
127 max(condition_diags.second, else_diags.second))};
128 }
129};
130
/**
 * Selection operation on kernel generator expressions. Builds a `select_`
 * expression whose elements are `condition ? then : els`.
 *
 * Each argument is converted with `as_operation_cl` and perfectly forwarded,
 * so rvalue arguments are moved into the resulting expression.
 *
 * NOTE(review): the template-constraint lines (source lines 144-145) are
 * missing from this listing. Presumably they apply
 * require_all_kernel_expressions_t / require_any_not_arithmetic_t; confirm
 * against the header.
 *
 * @param condition condition expression
 * @param then expression selected where the condition is true
 * @param els expression selected where the condition is false
 * @return selection operation expression
 */
142template <
 143 typename T_condition, typename T_then, typename T_else,
146inline select_<as_operation_cl_t<T_condition>, as_operation_cl_t<T_then>,
 147 as_operation_cl_t<T_else>>
148select(T_condition&& condition, T_then&& then, T_else&& els) { // NOLINT
// Wrap each argument as a kernel generator operation before constructing.
 149 return {as_operation_cl(std::forward<T_condition>(condition)),
 150 as_operation_cl(std::forward<T_then>(then)),
 151 as_operation_cl(std::forward<T_else>(els))};
152}
153
155} // namespace math
156} // namespace stan
157#endif
158#endif
static constexpr int dynamic
Base for all kernel generator operations.
common_scalar_t< T_then, T_else > Scalar
Definition select.hpp:40
auto deep_copy() const
Creates a deep copy of this expression.
Definition select.hpp:77
std::pair< int, int > extreme_diagonals() const
Determine indices of extreme sub- and superdiagonals written.
Definition select.hpp:113
select_(T_condition &&condition, T_then &&then, T_else &&els)
Constructor.
Definition select.hpp:51
kernel_parts generate(const std::string &row_index_name, const std::string &col_index_name, const bool view_handled, const std::string &var_name_condition, const std::string &var_name_then, const std::string &var_name_else) const
Generates kernel code for this (select) operation.
Definition select.hpp:97
Represents a selection operation in kernel generator expressions.
Definition select.hpp:38
require_any_not_t< std::is_arithmetic< std::decay_t< Types > >... > require_any_not_arithmetic_t
Require that at least one of the types does not satisfy std::is_arithmetic.
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
require_all_t< is_kernel_expression< Types >... > require_all_kernel_expressions_t
Enables a template if all given types are valid kernel generator expressions.
T_operation && as_operation_cl(T_operation &&a)
Converts any valid kernel generator expression into an operation.
auto min(T1 x, T2 y)
Returns the minimum of the two specified scalar arguments.
Definition min.hpp:24
auto max(T1 x, T2 y)
Returns the maximum value of the two specified scalar arguments.
Definition max.hpp:25
typename std::common_type_t< typename std::remove_reference_t< Types >::Scalar... > common_scalar_t
Wrapper for std::common_type_t
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Parts of an OpenCL kernel, generated by an expression.