1#ifndef STAN_MATH_REV_CORE_COLLECT_ADJOINTS_HPP
2#define STAN_MATH_REV_CORE_COLLECT_ADJOINTS_HPP
/**
 * NOTE(review): this listing is fragmentary -- the function name, its
 * parameter list, the call that receives the lambda below, and several body
 * lines are missing from this view. The comments describe only what the
 * visible lines do.
 *
 * Adds adjoints from an input into an output. The visible lambda receives a
 * matching (output_i, input_i) pair and adds the adjoints of input_i into
 * output_i; the branch is selected at compile time from the decayed type of
 * output_i.
 *
 * @tparam ZeroInput if true (default false), the adjoints of input_i are set
 *   to zero after they have been added to output_i
 * @tparam Output type of the accumulator; for std::vector its buffer is
 *   mapped as doubles
 * @tparam Input type whose adjoints are read; for std::vector its buffer is
 *   mapped as var
 */
23template <
bool ZeroInput =
false,
typename Output,
typename Input,
// Per-element worker applied to each (output_i, input_i) pair.
28 [](
auto&& output_i,
auto&& input_i) {
29 using output_i_t = std::decay_t<
decltype(output_i)>;
// std::vector: view both buffers as Eigen column vectors (double for the
// output, var for the input) so the adjoints can be added element-wise.
30 if constexpr (is_std_vector_v<output_i_t>) {
31 Eigen::Map<Eigen::Matrix<double, -1, 1>> output_map(output_i.data(),
33 Eigen::Map<Eigen::Matrix<
var, -1, 1>> input_map(input_i.data(),
35 output_map.array() += input_map.adj().array();
36 if constexpr (ZeroInput) {
37 input_map.adj().setZero();
39 }
// Eigen types: add the input's adjoint array directly into the output.
else if constexpr (is_eigen_v<output_i_t>) {
40 output_i.array() += input_i.adj().array();
41 if constexpr (ZeroInput) {
42 input_i.adj().setZero();
44 }
// Stan scalars: add the single adjoint (the ZeroInput body for this branch
// is not visible here).
else if constexpr (is_stan_scalar_v<output_i_t>) {
45 output_i += input_i.adj();
46 if constexpr (ZeroInput) {
// Fallback for any other type (the start of this branch is not visible):
// sizeof of a pointer is never 0, so this type-dependent static_assert fails
// whenever the branch is instantiated, making unsupported types a
// compile-time error.
51 sizeof(std::decay_t<output_i_t>*) == 0,
52 "INTERNAL ERROR: collect_adjoints was "
53 "not able to deduce the actions needed for the given type. "
54 "This is an internal error, please report it: "
55 "https://github.com/stan-dev/math/issues");
// Arguments forwarded to the call that applies the lambda above (the callee
// is not visible in this fragment).
58 std::forward<Output>(output), std::forward<Input>(input));
/**
 * NOTE(review): this listing is fragmentary -- the function name, its
 * parameter list, the call that receives the lambda below, and several body
 * lines are missing from this view. The comments describe only what the
 * visible lines do.
 *
 * Value-accumulating variant: unlike the adjoint-reading overload, the
 * visible lambda adds the *values* of input_i (not adjoints) into output_i,
 * element-wise for containers. The branch is selected at compile time from
 * the decayed type of output_i.
 *
 * @tparam Output type of the accumulator; for std::vector its buffer is
 *   mapped as doubles
 * @tparam Input type of the values being added; for std::vector its buffer is
 *   also mapped as doubles
 */
68template <
typename Output,
typename Input,
// Per-element worker applied to each (output_i, input_i) pair.
73 [](
auto&& output_i,
auto&& input_i) {
74 using output_i_t = std::decay_t<
decltype(output_i)>;
// std::vector: view both buffers as Eigen double column vectors and add
// element-wise.
75 if constexpr (is_std_vector_v<output_i_t>) {
76 Eigen::Map<Eigen::Matrix<double, -1, 1>> output_map(output_i.data(),
78 Eigen::Map<Eigen::Matrix<double, -1, 1>> input_map(input_i.data(),
80 output_map.array() += input_map.array();
81 }
// Eigen types: add the input array directly into the output.
else if constexpr (is_eigen_v<output_i_t>) {
82 output_i.array() += input_i.array();
83 }
// Stan scalars: the body of this branch is not visible here.
else if constexpr (is_stan_scalar_v<output_i_t>) {
// Fallback for any other type (the start of this branch is not visible):
// sizeof of a pointer is never 0, so this type-dependent static_assert fails
// whenever the branch is instantiated, making unsupported types a
// compile-time error.
87 sizeof(std::decay_t<output_i_t>*) == 0,
88 "INTERNAL ERROR: collect_adjoints was "
89 "not able to deduce the actions needed for the given type. "
90 "This is an internal error, please report it: "
91 "https://github.com/stan-dev/math/issues");
// Arguments forwarded to the call that applies the lambda above (the callee
// is not visible in this fragment).
94 std::forward<Output>(output), std::forward<Input>(input));
/**
 * NOTE(review): this listing is fragmentary -- the function signature, the
 * declaration of `ret`, the body of the element loop and the closing braces
 * are missing from this view. The comments describe only the visible lines.
 *
 * Scales `input` by `ret->adj_` and adds the result into the adjoint(s) of
 * `output`. The branch is selected at compile time from Output:
 *  - tuples are rejected with a compile-time error;
 *  - std::vector whose value type is not var is walked index by index (loop
 *    body not visible -- presumably recurses per element; confirm);
 *  - otherwise a std::vector is mapped as a var column vector and `input` as
 *    a const double column vector;
 *  - Eigen types and a single var update `.adj()` directly.
 *
 * NOTE(review): `ret` is not declared in the visible lines; it is used as a
 * pointer to something with an `adj_` member -- confirm against the full
 * signature.
 *
 * @tparam Output type whose adjoints receive the contribution
 * @tparam Input type of the (double-valued) multiplier applied to ret->adj_
 */
105template <
typename Output,
typename Input>
// Tuples are not handled here: sizeof of a pointer is never 0, so this
// type-dependent static_assert fails whenever the branch is instantiated.
107 if constexpr (is_tuple_v<Output>) {
108 static_assert(
sizeof(std::decay_t<Output>*) == 0,
109 "INTERNAL ERROR: collect_adjoints was "
110 "not able to deduce the actions needed for the given type. "
111 "This is an internal error, please report it: "
112 "https://github.com/stan-dev/math/issues");
113 }
else if constexpr (is_std_vector_v<Output>) {
// Containers of non-var elements: iterate over every index.
114 if constexpr (!is_var_v<value_type_t<Output>>) {
115 const auto output_size = output.size();
116 for (std::size_t i = 0; i < output_size; ++i) {
// Vector of var: map both buffers as Eigen column vectors and add the scaled
// input to the output adjoints element-wise.
120 Eigen::Map<Eigen::Matrix<
var, -1, 1>> output_map(output.data(),
122 Eigen::Map<
const Eigen::Matrix<double, -1, 1>> input_map(input.data(),
124 output_map.array().adj() += ret->adj_ * input_map.array();
126 }
// Eigen types: add the scaled input array into the output adjoints.
else if constexpr (is_eigen_v<Output>) {
127 output.adj().array() += ret->adj_ * input.array();
128 }
// Single var: add the scaled input into its adjoint.
else if constexpr (is_var_v<Output>) {
129 output.adj() += ret->adj_ * input;
void collect_adjoints(Output &output, Input &&input)
Collect the adjoints from the input and add them to the output.
A comparator that works for any container type that has the brackets operator.
void iter_tuple_nested(F &&f, Types &&... args)
Iterate and nest into a tuple or std::vector to apply f to each matrix or scalar type.
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.