Automatic Differentiation
 
Loading...
Searching...
No Matches
apply_scalar_ternary.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_PRIM_FUNCTOR_APPLY_SCALAR_TERNARY_HPP
2#define STAN_MATH_PRIM_FUNCTOR_APPLY_SCALAR_TERNARY_HPP
3
12#include <vector>
13
14namespace stan {
15namespace math {
16
37template <typename F, typename T1, typename T2, typename T3,
38 require_all_stan_scalar_t<T1, T2, T3>* = nullptr>
39inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
40 return std::forward<F>(f)(std::forward<T1>(x), std::forward<T2>(y),
41 std::forward<T3>(z));
42}
43
59template <typename F, typename T1, typename T2, typename T3,
61inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
62 check_matching_dims("Ternary function", "x", x, "y", y);
63 check_matching_dims("Ternary function", "y", y, "z", z);
64 return make_holder(
65 [](auto&& f_inner, auto&& x_inner, auto&& y_inner, auto&& z_inner) {
66 return Eigen::CwiseTernaryOp<
67 std::decay_t<decltype(f_inner)>, std::decay_t<decltype(x_inner)>,
68 std::decay_t<decltype(y_inner)>, std::decay_t<decltype(z_inner)>>(
69 x_inner, y_inner, z_inner, f_inner);
70 },
71 std::forward<F>(f), std::forward<T1>(x), std::forward<T2>(y),
72 std::forward<T3>(z));
73}
74
94template <typename F, typename T1, typename T2, typename T3,
95 require_all_std_vector_vt<is_stan_scalar, T1, T2, T3>* = nullptr>
96inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
97 check_matching_sizes("Ternary function", "x", x, "y", y);
98 check_matching_sizes("Ternary function", "y", y, "z", z);
99 using T_return = std::decay_t<decltype(f(x[0], y[0], z[0]))>;
100 decltype(auto) x_vec = as_column_vector_or_scalar(std::forward<T1>(x));
101 decltype(auto) y_vec = as_column_vector_or_scalar(std::forward<T2>(y));
102 decltype(auto) z_vec = as_column_vector_or_scalar(std::forward<T3>(z));
103 std::vector<T_return> result(x_vec.size());
104 Eigen::Map<Eigen::Matrix<T_return, -1, 1>>(result.data(), result.size())
105 = apply_scalar_ternary(std::forward<F>(f),
106 std::forward<decltype(x_vec)>(x_vec),
107 std::forward<decltype(y_vec)>(y_vec),
108 std::forward<decltype(z_vec)>(z_vec));
109 return result;
110}
111
128template <typename F, typename T1, typename T2, typename T3,
130 T3>* = nullptr>
131inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
132 check_matching_sizes("Ternary function", "x", x, "y", y);
133 check_matching_sizes("Ternary function", "y", y, "z", z);
134 using T_return
135 = plain_type_t<decltype(apply_scalar_ternary(f, x[0], y[0], z[0]))>;
136 size_t y_size = y.size();
137 std::vector<T_return> result(y_size);
138 for (size_t i = 0; i < y_size; ++i) {
139 result[i] = apply_scalar_ternary(f, x[i], y[i], z[i]);
140 }
141 return result;
142}
143
160template <typename F, typename T1, typename T2, typename T3,
161 require_any_container_t<T1, T2>* = nullptr,
162 require_stan_scalar_t<T3>* = nullptr>
163inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
164 return apply_scalar_binary(
165 [f_ = std::forward<F>(f), z](auto&& a, auto&& b) {
166 return f_(std::forward<decltype(a)>(a), std::forward<decltype(b)>(b),
167 z);
168 },
169 std::forward<T1>(x), std::forward<T2>(y));
170}
171
188template <typename F, typename T1, typename T2, typename T3,
190 require_stan_scalar_t<T2>* = nullptr>
191inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
192 return apply_scalar_binary(
193 [f_ = std::forward<F>(f), y](auto&& a, auto&& c) {
194 return f_(std::forward<decltype(a)>(a), y,
195 std::forward<decltype(c)>(c));
196 },
197 std::forward<T1>(x), std::forward<T3>(z));
198}
199
216template <typename F, typename T1, typename T2, typename T3,
217 require_container_t<T3>* = nullptr,
218 require_stan_scalar_t<T1>* = nullptr>
219inline auto apply_scalar_ternary(F&& f, T1&& x, T2&& y, T3&& z) {
220 return apply_scalar_binary(
221 [f_ = std::forward<F>(f), x](auto&& b, auto&& c) {
222 return f_(x, std::forward<decltype(b)>(b),
223 std::forward<decltype(c)>(c));
224 },
225 std::forward<T2>(y), std::forward<T3>(z));
226}
227
228} // namespace math
229} // namespace stan
230#endif
require_all_t< is_container< std::decay_t< Types > >... > require_all_container_t
Require all of the types satisfy is_container.
require_all_t< is_eigen< std::decay_t< Types > >... > require_all_eigen_t
Require all of the types satisfy is_eigen.
Definition is_eigen.hpp:123
auto as_column_vector_or_scalar(T &&a)
as_column_vector_or_scalar of a kernel generator expression.
require_t< is_stan_scalar< std::decay_t< T > > > require_stan_scalar_t
Require type satisfies is_stan_scalar.
require_all_t< container_type_check_base< is_std_vector, value_type_t, TypeCheck, Check >... > require_all_std_vector_vt
Require all of the types satisfy is_std_vector.
auto apply_scalar_ternary(F &&f, T1 &&x, T2 &&y, T3 &&z)
Base template function for vectorization of ternary scalar functions defined by applying a functor to...
auto apply_scalar_binary(F &&f, T1 &&x, T2 &&y)
Base template function for vectorization of binary scalar functions defined by applying a functor to ...
void check_matching_dims(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Check if the two containers have the same dimensions.
auto make_holder(F &&func, Args &&... args)
Calls given function with given arguments.
Definition holder.hpp:437
void check_matching_sizes(const char *function, const char *name1, const T_y1 &y1, const char *name2, const T_y2 &y2)
Check if two structures are the same size.
bool_constant< math::disjunction< is_container< Container >, is_var_matrix< Container > >::value > is_container_or_var_matrix
Deduces whether type is eigen matrix, standard vector, or var<Matrix>.
typename plain_type< std::decay_t< T > >::type plain_type_t
Shorthand for `typename plain_type<std::decay_t<T>>::type`, i.e. the plain type that `plain_type` associates with the decayed `T`.