Automatic Differentiation
 
Loading...
Searching...
No Matches
coupled_ode_system.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_COUPLED_ODE_SYSTEM_HPP
2#define STAN_MATH_REV_FUNCTOR_COUPLED_ODE_SYSTEM_HPP
3
12#include <stdexcept>
13#include <ostream>
14#include <vector>
15
16namespace stan {
17namespace math {
18
68template <typename F, typename T_y0, typename... Args>
69struct coupled_ode_system_impl<false, F, T_y0, Args...> {
70 const F& f_;
71 const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0_;
72 std::tuple<decltype(deep_copy_vars(std::declval<const Args&>()))...>
74 const size_t num_y0_vars_;
75 const size_t num_args_vars;
76 const size_t N_;
77 Eigen::VectorXd args_adjoints_;
78 Eigen::VectorXd y_adjoints_;
79 std::ostream* msgs_;
80
92 const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0,
93 std::ostream* msgs, const Args&... args)
94 : f_(f),
95 y0_(y0),
96 local_args_tuple_(deep_copy_vars(args)...),
97 num_y0_vars_(count_vars(y0_)),
98 num_args_vars(count_vars(args...)),
99 N_(y0.size()),
100 args_adjoints_(num_args_vars),
101 y_adjoints_(N_),
102 msgs_(msgs) {}
103
116 void operator()(const std::vector<double>& z, std::vector<double>& dz_dt,
117 double t) {
118 using std::vector;
119
120 dz_dt.resize(size());
121
122 // Run nested autodiff in this scope
123 nested_rev_autodiff nested;
124
125 Eigen::Matrix<var, Eigen::Dynamic, 1> y_vars(N_);
126 for (size_t n = 0; n < N_; ++n)
127 y_vars.coeffRef(n) = z[n];
128
129 Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars = math::apply(
130 [&](auto&&... args) { return f_(t, y_vars, msgs_, args...); },
131 local_args_tuple_);
132
133 check_size_match("coupled_ode_system", "dy_dt", f_y_t_vars.size(), "states",
134 N_);
135
136 for (size_t i = 0; i < N_; ++i) {
137 dz_dt[i] = f_y_t_vars.coeffRef(i).val();
138 f_y_t_vars.coeffRef(i).grad();
139
140 y_adjoints_ = y_vars.adj();
141
142 if (args_adjoints_.size() > 0) {
143 memset(args_adjoints_.data(), 0,
144 sizeof(double) * args_adjoints_.size());
145 }
146
148 [&](auto&&... args) {
149 accumulate_adjoints(args_adjoints_.data(), args...);
150 },
151 local_args_tuple_);
152
153 // The vars here do not live on the nested stack so must be zero'd
154 // separately
155 stan::math::for_each([](auto&& arg) { zero_adjoints(arg); },
156 local_args_tuple_);
157
158 // No need to zero adjoints after last sweep
159 if (i + 1 < N_) {
160 nested.set_zero_all_adjoints();
161 }
162
163 // Compute the right hand side for the sensitivities with respect to the
164 // initial conditions
165 for (size_t j = 0; j < num_y0_vars_; ++j) {
166 double temp_deriv = 0.0;
167 for (size_t k = 0; k < N_; ++k) {
168 temp_deriv += z[N_ + N_ * j + k] * y_adjoints_.coeffRef(k);
169 }
170
171 dz_dt[N_ + N_ * j + i] = temp_deriv;
172 }
173
174 // Compute the right hand size for the sensitivities with respect to the
175 // parameters
176 for (size_t j = 0; j < num_args_vars; ++j) {
177 double temp_deriv = args_adjoints_.coeffRef(j);
178 for (size_t k = 0; k < N_; ++k) {
179 temp_deriv += z[N_ + N_ * num_y0_vars_ + N_ * j + k]
180 * y_adjoints_.coeffRef(k);
181 }
182
183 dz_dt[N_ + N_ * num_y0_vars_ + N_ * j + i] = temp_deriv;
184 }
185 }
186 }
187
193 size_t size() const { return N_ + N_ * num_y0_vars_ + N_ * num_args_vars; }
194
213 std::vector<double> initial_state() const {
214 std::vector<double> initial(size(), 0.0);
215 for (size_t i = 0; i < N_; i++) {
216 initial[i] = value_of(y0_(i));
217 }
218 for (size_t i = 0; i < num_y0_vars_; i++) {
219 initial[N_ + i * N_ + i] = 1.0;
220 }
221 return initial;
222 }
223};
224
225} // namespace math
226} // namespace stan
227#endif
void set_zero_all_adjoints()
Reset all adjoint values in this nested stack to zero.
A class following the RAII idiom to start and recover nested autodiff scopes.
int64_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:19
constexpr auto for_each(F &&f, T &&t)
Apply a function to each element of a tuple.
Definition for_each.hpp:66
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
Arith deep_copy_vars(Arith &&arg)
Forward arguments that do not contain vars.
size_t count_vars(Pargs &&... args)
Count the number of vars in the input argument list.
void zero_adjoints() noexcept
End of recursion for zero_adjoints.
double * accumulate_adjoints(double *dest, const var &x, Pargs &&... args)
Accumulate adjoints from x into storage pointed to by dest, increment the adjoint storage pointer,...
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:52
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
size_t size() const
Returns the size of the coupled system.
std::vector< double > initial_state() const
Returns the initial state of the coupled system.
coupled_ode_system_impl(const F &f, const Eigen::Matrix< T_y0, Eigen::Dynamic, 1 > &y0, std::ostream *msgs, const Args &... args)
Construct a coupled ODE system from the base system function, the initial state of the base system, a message stream, and the additional arguments to the base system function.
const Eigen::Matrix< T_y0, Eigen::Dynamic, 1 > & y0_
std::tuple< decltype(deep_copy_vars(std::declval< const Args & >()))... > local_args_tuple_
void operator()(const std::vector< double > &z, std::vector< double > &dz_dt, double t)
Calculates the right-hand side of the coupled ODE system (the regular ODE system with forward sensitivities).