Automatic Differentiation
 
Loading...
Searching...
No Matches
reverse_pass_collect_adjoints.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_REVERSE_PASS_COLLECT_ADJOINTS_HPP
2#define STAN_MATH_REV_FUNCTOR_REVERSE_PASS_COLLECT_ADJOINTS_HPP
3
10#include <cstddef>
11#include <utility>
12
13namespace stan::math::internal {
14
25template <typename Output, typename Input>
26inline void reverse_pass_collect_adjoints(var ret, Output&& output,
27 Input&& input) {
28 if constexpr (is_tuple_v<Output>) {
30 [ret](auto&& inner_arg, auto&& inner_input) mutable {
32 ret, std::forward<decltype(inner_arg)>(inner_arg),
33 std::forward<decltype(inner_input)>(inner_input));
34 },
35 std::forward<Output>(output), std::forward<Input>(input));
36 } else if constexpr (is_std_vector_containing_tuple_v<Output>) {
37 for (std::size_t i = 0; i < output.size(); ++i) {
38 reverse_pass_collect_adjoints(ret, output[i], input[i]);
39 }
40 } else {
42 [vi = ret.vi_, arg_arena = to_arena(std::forward<Output>(output)),
43 input_arena = to_arena(std::forward<Input>(input))]() mutable {
44 collect_adjoints(arg_arena, vi, input_arena);
45 });
46 }
47}
48
49} // namespace stan::math::internal
50
51#endif
void reverse_pass_collect_adjoints(var ret, Output &&output, Input &&input)
Collects adjoints from a tuple or std::vector of tuples.
void collect_adjoints(Output &output, Input &&input)
Collect the adjoints from the input and add them to the output.
A comparator that works for any container type that has the brackets operator.
void reverse_pass_callback(F &&functor)
Puts a callback on the autodiff stack to be called in reverse pass.
constexpr void for_each(F &&f, const std::tuple<> &)
Apply a function to each element of a tuple.
Definition for_each.hpp:80
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on the AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25