Automatic Differentiation
 
Loading...
Searching...
No Matches
accumulate_adjoints.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP
2#define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP
3
7
8#include <utility>
9#include <vector>
10
11namespace stan {
12namespace math {
13
14template <typename... Pargs>
15inline double* accumulate_adjoints(double* dest, const var& x, Pargs&&... args);
16
17template <typename VarVec, require_std_vector_vt<is_var, VarVec>* = nullptr,
18 typename... Pargs>
19inline double* accumulate_adjoints(double* dest, VarVec&& x, Pargs&&... args);
20
21template <typename VecContainer,
22 require_std_vector_st<is_var, VecContainer>* = nullptr,
23 require_std_vector_vt<is_container, VecContainer>* = nullptr,
24 typename... Pargs>
25inline double* accumulate_adjoints(double* dest, VecContainer&& x,
26 Pargs&&... args);
27
28template <typename EigT, require_eigen_vt<is_var, EigT>* = nullptr,
29 typename... Pargs>
30inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args);
31
32template <typename Arith, require_st_arithmetic<Arith>* = nullptr,
33 typename... Pargs>
34inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args);
35
36inline double* accumulate_adjoints(double* dest);
37
50template <typename... Pargs>
51inline double* accumulate_adjoints(double* dest, const var& x,
52 Pargs&&... args) {
53 *dest += x.adj();
54 return accumulate_adjoints(dest + 1, std::forward<Pargs>(args)...);
55}
56
69template <typename VarVec, require_std_vector_vt<is_var, VarVec>*,
70 typename... Pargs>
71inline double* accumulate_adjoints(double* dest, VarVec&& x, Pargs&&... args) {
72 for (auto&& x_iter : x) {
73 *dest += x_iter.adj();
74 ++dest;
75 }
76 return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
77}
78
94template <typename VecContainer, require_std_vector_st<is_var, VecContainer>*,
95 require_std_vector_vt<is_container, VecContainer>*, typename... Pargs>
96inline double* accumulate_adjoints(double* dest, VecContainer&& x,
97 Pargs&&... args) {
98 for (auto&& x_iter : x) {
99 dest = accumulate_adjoints(dest, x_iter);
100 }
101 return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
102}
103
118template <typename EigT, require_eigen_vt<is_var, EigT>*, typename... Pargs>
119inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) {
120 Eigen::Map<Eigen::MatrixXd>(dest, x.rows(), x.cols()) += x.adj();
121 return accumulate_adjoints(dest + x.size(), std::forward<Pargs>(args)...);
122}
123
138template <typename Arith, require_st_arithmetic<Arith>*, typename... Pargs>
139inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args) {
140 return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
141}
142
148inline double* accumulate_adjoints(double* dest) { return dest; }
149
150} // namespace math
151} // namespace stan
152
153#endif
var_value< double > var
Definition var.hpp:1187
double * accumulate_adjoints(double *dest, const var &x, Pargs &&... args)
Accumulate adjoints from x into storage pointed to by dest, increment the adjoint storage pointer, and recursively accumulate the adjoints of the remaining arguments.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...