Automatic Differentiation
 
Loading...
Searching...
No Matches
apply_scalar_ternary.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUNCTOR_APPLY_SCALAR_TERNARY_HPP
2#define STAN_MATH_PRIM_FUNCTOR_APPLY_SCALAR_TERNARY_HPP
3
12#include <vector>
13
14namespace stan {
15namespace math {
16
37template <typename F, typename T1, typename T2, typename T3,
38 require_all_stan_scalar_t<T1, T2, T3>* = nullptr>
39inline auto apply_scalar_ternary(const F& f, const T1& x, const T2& y,
40 const T3& z) {
41 return f(x, y, z);
42}
43
// Overload that builds a lazy Eigen::CwiseTernaryOp over x, y and z, so the
// three arguments are presumably Eigen expressions.
// NOTE(review): the template-constraint line (source line 60) is missing from
// this listing, so the exact enable-if condition is not visible here.
// - check_matching_dims compares x with y and y with z, reporting
//   "Ternary function" as the caller name. NOTE(review): it returns void, so it
//   presumably throws on a mismatch; confirm.
// - The inner lambda builds the CwiseTernaryOp (operands first, functor last).
//   It is passed to make_holder together with the perfectly forwarded f, x, y
//   and z. Presumably this lets the returned expression own rvalue arguments
//   instead of referencing temporaries; TODO confirm against make_holder.
59template <typename F, typename T1, typename T2, typename T3,
61inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
62 check_matching_dims("Ternary function", "x", x, "y", y);
63 check_matching_dims("Ternary function", "y", y, "z", z);
64 return make_holder(
65 [](auto&& f_inner, auto&& x_inner, auto&& y_inner, auto&& z_inner) {
66 return Eigen::CwiseTernaryOp<
67 std::decay_t<decltype(f_inner)>, std::decay_t<decltype(x_inner)>,
68 std::decay_t<decltype(y_inner)>, std::decay_t<decltype(z_inner)>>(
69 x_inner, y_inner, z_inner, f_inner);
70 },
71 std::forward<F>(f), std::forward<T1>(x), std::forward<T2>(y),
72 std::forward<T3>(z));
73}
74
// Overload for three indexable containers whose elements are passed directly
// to f. These are presumably std::vectors of scalars.
// NOTE(review): the constraint line (source line 95) is missing from this
// listing; confirm the exact requirement.
// - Sizes are checked pairwise (x vs y, y vs z) with check_matching_sizes.
// - Each input is viewed as a column vector via as_column_vector_or_scalar.
//   The views are handed to apply_scalar_ternary, which presumably dispatches
//   to the Eigen overload.
// - T_return is the decayed type of f applied to the first elements. decltype
//   does not evaluate its operand, so x[0], y[0] and z[0] are never actually
//   read, even when the inputs are empty.
// - The result is written straight into the storage of a
//   std::vector<T_return> of size x.size() through an Eigen::Map, and that
//   vector is returned.
94template <typename F, typename T1, typename T2, typename T3,
96inline auto apply_scalar_ternary(const F& f, const T1& x, const T2& y,
97 const T3& z) {
98 check_matching_sizes("Ternary function", "x", x, "y", y);
99 check_matching_sizes("Ternary function", "y", y, "z", z);
100 decltype(auto) x_vec = as_column_vector_or_scalar(x);
101 decltype(auto) y_vec = as_column_vector_or_scalar(y);
102 decltype(auto) z_vec = as_column_vector_or_scalar(z);
103 using T_return = std::decay_t<decltype(f(x[0], y[0], z[0]))>;
104 std::vector<T_return> result(x.size());
105 Eigen::Map<Eigen::Matrix<T_return, -1, 1>>(result.data(), result.size())
106 = apply_scalar_ternary(f, x_vec, y_vec, z_vec);
107 return result;
108}
109
// Overload for containers of containers: recurses elementwise, calling
// apply_scalar_ternary on x[i], y[i] and z[i] for every index.
// NOTE(review): the opening of the constraint (source line 127) is missing
// from this listing; only its tail "T3>* = nullptr>" is visible.
// - Sizes are checked pairwise (x vs y, y vs z). The loop bound is y.size(),
//   which presumably equals x.size() and z.size() once the checks pass.
// - T_return is plain_type_t of the recursive call on the first elements.
//   This is an unevaluated decltype, so no element is read for empty inputs.
// - Returns a std::vector<T_return> holding one result per index.
126template <typename F, typename T1, typename T2, typename T3,
128 T3>* = nullptr>
129inline auto apply_scalar_ternary(const F& f, const T1& x, const T2& y,
130 const T3& z) {
131 check_matching_sizes("Ternary function", "x", x, "y", y);
132 check_matching_sizes("Ternary function", "y", y, "z", z);
133 using T_return
134 = plain_type_t<decltype(apply_scalar_ternary(f, x[0], y[0], z[0]))>;
135 size_t y_size = y.size();
136 std::vector<T_return> result(y_size);
137 for (size_t i = 0; i < y_size; ++i) {
138 result[i] = apply_scalar_ternary(f, x[i], y[i], z[i]);
139 }
140 return result;
141}
142
159template <typename F, typename T1, typename T2, typename T3,
160 require_any_container_t<T1, T2>* = nullptr,
161 require_stan_scalar_t<T3>* = nullptr>
162inline auto apply_scalar_ternary(const F& f, const T1& x, const T2& y,
163 const T3& z) {
164 return apply_scalar_binary(
165 x, y, [f, z](const auto& a, const auto& b) { return f(a, b, z); });
166}
167
// Overload for a scalar second argument. f and y are captured by copy into a
// two-argument wrapper, and x and z are vectorized by apply_scalar_binary,
// which calls f(a, y, c) for each (a, c) it produces.
// NOTE(review): the other constraint (source line 185) is missing from this
// listing. It presumably requires x and z to be containers; confirm.
184template <typename F, typename T1, typename T2, typename T3,
186 require_stan_scalar_t<T2>* = nullptr>
187inline auto apply_scalar_ternary(const F& f, const T1& x, const T2& y,
188 const T3& z) {
189 return apply_scalar_binary(
190 x, z, [f, y](const auto& a, const auto& c) { return f(a, y, c); });
191}
192
209template <typename F, typename T1, typename T2, typename T3,
210 require_container_t<T3>* = nullptr,
211 require_stan_scalar_t<T1>* = nullptr>
212inline auto apply_scalar_ternary(const F& f, const T1& x, const T2& y,
213 const T3& z) {
214 return apply_scalar_binary(
215 y, z, [f, x](const auto& b, const auto& c) { return f(x, b, c); });
216}
217
218} // namespace math
219} // namespace stan
220#endif
require_all_t< is_container< std::decay_t< Types > >... > require_all_container_t
Require that all of the types satisfy is_container.
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require that all of the types satisfy is_eigen.
Definition is_eigen.hpp:120
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require that the type satisfies is_stan_scalar.
require_all_t< container_type_check_base< is_std_vector, value_type_t, TypeCheck, Check >... > require_all_std_vector_vt
Require that all of the types satisfy is_std_vector.
auto make_holder(const F &func, Args &&... args)
Constructs an expression from the given arguments using the given functor.
Definition holder.hpp:352
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
void check_matching_sizes(const char *function, const char *name1, const T_y1 &y1, const char *name2, const T_y2 &y2)
Check if two structures are the same size.
auto apply_scalar_ternary(const F &f, const T1 &x, const T2 &y, const T3 &z)
Base template function for vectorization of ternary scalar functions defined by applying a functor to...
auto apply_scalar_binary(const T1 &x, const T2 &y, const F &f)
Base template function for vectorization of binary scalar functions defined by applying a functor to ...
typename plain_type< T >::type plain_type_t
bool_constant< math::disjunction< is_container< Container >, is_var_matrix< Container > >::value > is_container_or_var_matrix
Deduces whether the type is an Eigen matrix, a standard vector, or a var<Matrix>.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...