Automatic Differentiation
 
Loading...
Searching...
No Matches
collect_adjoints.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CORE_COLLECT_ADJOINTS_HPP
2#define STAN_MATH_REV_CORE_COLLECT_ADJOINTS_HPP
3
10#include <type_traits>
11
12namespace stan::math::internal {
13
/// Named `true` for the `ZeroInput` template argument of collect_adjoints;
/// presumably meant to be spelled `collect_adjoints<ZeroOut>(output, input)`
/// so call sites read as "collect, then zero the input adjoints".
inline constexpr bool ZeroOut{true};
/**
 * Add the adjoints stored in `input` to the matching entries of `output`.
 *
 * iter_tuple_nested walks `output` and `input` together and calls the lambda
 * below on each pair of leaves:
 *  - std::vector: both sides are viewed through Eigen::Map (double for the
 *    output, var for the input) and output += input.adj() coefficient-wise;
 *  - Eigen type: output.array() += input.adj().array();
 *  - Stan scalar: output += input.adj();
 *  - any other type: compile-time error via static_assert.
 *
 * @tparam ZeroInput when true, each input adjoint is set to zero right after
 *   it has been added to output.
 * @param output destination, modified in place. Its std::vector leaves must
 *   hold doubles, since they are mapped as Eigen::Matrix<double, -1, 1>.
 * @param input source of the adjoints. Its std::vector leaves must hold vars
 *   and be non-const, because they are mapped through a non-const Eigen::Map
 *   even when ZeroInput is false.
 *
 * NOTE(review): no size check is made between paired output/input leaves;
 * matching sizes are assumed -- confirm at call sites.
 */
23template <bool ZeroInput = false, typename Output, typename Input,
// NOTE(review): source lines 24-25 (the remaining template parameters,
// presumably require_t<> constraints that select this overload) are missing
// from this rendered listing, so the template header above is incomplete.
26inline void collect_adjoints(Output& output, Input&& input) {
27 return iter_tuple_nested(
28 [](auto&& output_i, auto&& input_i) {
29 using output_i_t = std::decay_t<decltype(output_i)>;
30 if constexpr (is_std_vector_v<output_i_t>) {
// std::vector has no adj(); view its contiguous storage as Eigen vectors so
// the Eigen adj() expressions can be used on it.
31 Eigen::Map<Eigen::Matrix<double, -1, 1>> output_map(output_i.data(),
32 output_i.size());
33 Eigen::Map<Eigen::Matrix<var, -1, 1>> input_map(input_i.data(),
34 input_i.size());
35 output_map.array() += input_map.adj().array();
// Clear the input adjoints only after they have been added to output.
36 if constexpr (ZeroInput) {
37 input_map.adj().setZero();
38 }
39 } else if constexpr (is_eigen_v<output_i_t>) {
40 output_i.array() += input_i.adj().array();
41 if constexpr (ZeroInput) {
42 input_i.adj().setZero();
43 }
44 } else if constexpr (is_stan_scalar_v<output_i_t>) {
45 output_i += input_i.adj();
46 if constexpr (ZeroInput) {
47 input_i.adj() = 0;
48 }
49 } else {
// sizeof(T*) == 0 is never true but depends on the leaf type, so the
// assertion only fires if this branch is actually instantiated.
50 static_assert(
51 sizeof(std::decay_t<output_i_t>*) == 0,
52 "INTERNAL ERROR: collect_adjoints was "
53 "not able to deduce the actions needed for the given type. "
54 "This is an internal error, please report it: "
55 "https://github.com/stan-dev/math/issues");
56 }
57 },
58 std::forward<Output>(output), std::forward<Input>(input));
59}
60
/**
 * Add the plain values held in `input` to the matching entries of `output`.
 * Unlike the ZeroInput overload, no adj() is read here: `input` leaves are
 * added directly.
 *
 * iter_tuple_nested walks both arguments together. For each pair of leaves:
 *  - std::vector: both are viewed as Eigen::Matrix<double, -1, 1> maps and
 *    output += input coefficient-wise;
 *  - Eigen type: output.array() += input.array();
 *  - Stan scalar: output += input;
 *  - any other type: compile-time error via static_assert.
 *
 * @param output destination, modified in place
 * @param input values to add. These are presumably adjoints collected earlier
 *   -- confirm at call sites. Its std::vector leaves must hold doubles and be
 *   non-const, since they are mapped through a non-const Eigen::Map.
 *
 * NOTE(review): no size check is made between paired output/input leaves.
 */
68template <typename Output, typename Input,
// NOTE(review): source lines 69-70 (the remaining template parameters,
// presumably require_t<> constraints that select this overload) are missing
// from this rendered listing, so the template header above is incomplete.
71inline void collect_adjoints(Output&& output, Input&& input) {
72 return iter_tuple_nested(
73 [](auto&& output_i, auto&& input_i) {
74 using output_i_t = std::decay_t<decltype(output_i)>;
75 if constexpr (is_std_vector_v<output_i_t>) {
// View contiguous std::vector storage as Eigen vectors for a single
// coefficient-wise addition.
76 Eigen::Map<Eigen::Matrix<double, -1, 1>> output_map(output_i.data(),
77 output_i.size());
78 Eigen::Map<Eigen::Matrix<double, -1, 1>> input_map(input_i.data(),
79 input_i.size());
80 output_map.array() += input_map.array();
81 } else if constexpr (is_eigen_v<output_i_t>) {
82 output_i.array() += input_i.array();
83 } else if constexpr (is_stan_scalar_v<output_i_t>) {
84 output_i += input_i;
85 } else {
// sizeof(T*) == 0 is never true but depends on the leaf type, so the
// assertion only fires if this branch is actually instantiated.
86 static_assert(
87 sizeof(std::decay_t<output_i_t>*) == 0,
88 "INTERNAL ERROR: collect_adjoints was "
89 "not able to deduce the actions needed for the given type. "
90 "This is an internal error, please report it: "
91 "https://github.com/stan-dev/math/issues");
92 }
93 },
94 std::forward<Output>(output), std::forward<Input>(input));
95}
96
105template <typename Output, typename Input>
106inline void collect_adjoints(Output&& output, const vari* ret, Input&& input) {
107 if constexpr (is_tuple_v<Output>) {
108 static_assert(sizeof(std::decay_t<Output>*) == 0,
109 "INTERNAL ERROR: collect_adjoints was "
110 "not able to deduce the actions needed for the given type. "
111 "This is an internal error, please report it: "
112 "https://github.com/stan-dev/math/issues");
113 } else if constexpr (is_std_vector_v<Output>) {
114 if constexpr (!is_var_v<value_type_t<Output>>) {
115 const auto output_size = output.size();
116 for (std::size_t i = 0; i < output_size; ++i) {
117 collect_adjoints(output[i], ret, input[i]);
118 }
119 } else {
120 Eigen::Map<Eigen::Matrix<var, -1, 1>> output_map(output.data(),
121 output.size());
122 Eigen::Map<const Eigen::Matrix<double, -1, 1>> input_map(input.data(),
123 input.size());
124 output_map.array().adj() += ret->adj_ * input_map.array();
125 }
126 } else if constexpr (is_eigen_v<Output>) {
127 output.adj().array() += ret->adj_ * input.array();
128 } else if constexpr (is_var_v<Output>) {
129 output.adj() += ret->adj_ * input;
130 }
131}
132
133} // namespace stan::math::internal
134
135#endif
void collect_adjoints(Output &output, Input &&input)
Collect the adjoints from the input and add them to the output.
A comparator that works for any container type that has the brackets operator.
void iter_tuple_nested(F &&f, Types &&... args)
Iterate and nest into a tuple or std::vector to apply f to each matrix or scalar type.
var_value< double > var
Definition var.hpp:1187
std::enable_if_t< Check::value > require_t
If condition is true, template is enabled.