Automatic Differentiation
 
Loading...
Searching...
No Matches
select.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUN_SELECT_HPP
2#define STAN_MATH_PRIM_FUN_SELECT_HPP
3
7
8namespace stan {
9namespace math {
10
23template <typename T_true, typename T_false,
24 typename ReturnT = return_type_t<T_true, T_false>,
25 require_all_stan_scalar_t<T_true, T_false>* = nullptr>
26inline ReturnT select(const bool c, const T_true y_true,
27 const T_false y_false) {
28 return c ? ReturnT(y_true) : ReturnT(y_false);
29}
30
47template <
48 typename T_true, typename T_false,
49 typename T_return = return_type_t<T_true, T_false>,
50 typename T_true_plain = promote_scalar_t<T_return, plain_type_t<T_true>>,
51 typename T_false_plain = promote_scalar_t<T_return, plain_type_t<T_false>>,
54inline T_true_plain select(const bool c, T_true&& y_true, T_false&& y_false) {
55 check_matching_dims("select", "left hand side", y_true, "right hand side",
56 y_false);
57 return c ? T_true_plain(std::forward<T_true>(y_true))
58 : T_true_plain(std::forward<T_false>(y_false));
59}
60
78template <typename T_true, typename T_false,
79 typename ReturnT = promote_scalar_t<return_type_t<T_true, T_false>,
83inline ReturnT select(const bool c, T_true&& y_true, const T_false& y_false) {
84 if (c) {
85 return y_true;
86 } else {
88 [](auto&& y_true_inner, auto&& y_false_inner) { return y_false_inner; },
89 std::forward<T_true>(y_true), y_false);
90 }
91}
92
110template <typename T_true, typename T_false,
111 typename ReturnT = promote_scalar_t<return_type_t<T_true, T_false>,
115inline ReturnT select(const bool c, const T_true& y_true, T_false&& y_false) {
116 if (c) {
117 return apply_scalar_binary(
118 [](auto&& y_true_inner, auto&& y_false_inner) { return y_true_inner; },
119 y_true, std::forward<T_false>(y_false));
120 } else {
121 return y_false;
122 }
123}
124
140template <typename T_bool, typename T_true, typename T_false,
143inline auto select(T_bool&& c, const T_true& y_true, const T_false& y_false) {
144 using ret_t = return_type_t<T_true, T_false>;
145 return make_holder(
146 [y_true, y_false](auto&& c_) {
147 return std::forward<decltype(c_)>(c_).unaryExpr(
148 [y_true, y_false](bool cond) {
149 return cond ? ret_t(y_true) : ret_t(y_false);
150 });
151 },
152 std::forward<T_bool>(c));
153}
154
167template <typename T_bool, typename T_true, typename T_false,
170inline auto select(T_bool&& c, T_true&& y_true, T_false&& y_false) {
171 check_consistent_sizes("select", "boolean", c, "left hand side", y_true,
172 "right hand side", y_false);
173 using ret_t = return_type_t<T_true, T_false>;
174 if constexpr (!std::is_same_v<std::decay_t<T_true>, std::decay_t<T_false>>) {
175 return make_holder(
176 [](auto&& c_, auto&& y_true_, auto&& y_false_) {
177 return std::forward<decltype(c_)>(c_).select(
178 std::forward<decltype(y_true_)>(y_true_),
179 std::forward<decltype(y_false_)>(y_false_));
180 },
181 std::forward<T_bool>(c), std::forward<T_true>(y_true),
182 std::forward<T_false>(y_false));
183 } else {
184 return make_holder(
185 [](auto&& c_, auto&& y_true_, auto&& y_false_) {
186 return std::forward<decltype(c_)>(c_)
187 .select(std::forward<decltype(y_true_)>(y_true_),
188 std::forward<decltype(y_false_)>(y_false_))
189 .template cast<ret_t>();
190 },
191 std::forward<T_bool>(c), std::forward<T_true>(y_true),
192 std::forward<T_false>(y_false));
193 }
194}
195} // namespace math
196} // namespace stan
197
198#endif
require_all_t< is_container< std::decay_t< Types > >... > require_all_container_t
Require all of the types satisfy is_container.
require_t< is_container< std::decay_t< T > > > require_container_t
Require type satisfies is_container.
require_any_t< is_eigen_array< std::decay_t< Types > >... > require_any_eigen_array_t
Require any of the types satisfy is_eigen_array.
Definition is_eigen.hpp:286
require_t< container_type_check_base< is_eigen_array, value_type_t, TypeCheck, Check... > > require_eigen_array_vt
Require type satisfies is_eigen_array.
Definition is_eigen.hpp:311
require_t< is_eigen_array< std::decay_t< T > > > require_eigen_array_t
Require type satisfies is_eigen_array.
Definition is_eigen.hpp:274
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
require_all_t< std::is_same< std::decay_t< T >, std::decay_t< Types > >... > require_all_same_t
Require T and all of the Types satisfy std::is_same.
require_all_t< is_stan_scalar< std::decay_t< Types > >... > require_all_stan_scalar_t
Require all of the types satisfy is_stan_scalar.
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
auto apply_scalar_binary(F &&f, T1 &&x, T2 &&y)
Base template function for vectorization of binary scalar functions defined by applying a functor to ...
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
auto make_holder(F &&func, Args &&... args)
Calls given function with given arguments.
Definition holder.hpp:481
void check_consistent_sizes(const char *)
Trivial no input case, this function is a no-op.
typename plain_type< std::decay_t< T > >::type plain_type_t
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...