1#ifndef STAN_MATH_REV_FUNCTOR_COUPLED_ODE_SYSTEM_HPP
2#define STAN_MATH_REV_FUNCTOR_COUPLED_ODE_SYSTEM_HPP
68template <
typename F,
typename T_y0,
typename... Args>
71 const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>&
y0_;
72 std::tuple<decltype(deep_copy_vars(std::declval<const Args&>()))...>
92 const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0,
93 std::ostream* msgs,
const Args&... args)
100 args_adjoints_(num_args_vars),
116 void operator()(
const std::vector<double>& z, std::vector<double>& dz_dt,
120 dz_dt.resize(
size());
125 Eigen::Matrix<var, Eigen::Dynamic, 1> y_vars(N_);
126 for (
size_t n = 0; n < N_; ++n)
127 y_vars.coeffRef(n) = z[n];
129 Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars =
math::apply(
130 [&](
auto&&... args) {
return f_(t, y_vars, msgs_, args...); },
133 check_size_match(
"coupled_ode_system",
"dy_dt", f_y_t_vars.size(),
"states",
136 for (
size_t i = 0; i < N_; ++i) {
137 dz_dt[i] = f_y_t_vars.coeffRef(i).val();
138 f_y_t_vars.coeffRef(i).grad();
140 y_adjoints_ = y_vars.adj();
142 if (args_adjoints_.size() > 0) {
143 memset(args_adjoints_.data(), 0,
144 sizeof(
double) * args_adjoints_.size());
148 [&](
auto&&... args) {
165 for (
size_t j = 0; j < num_y0_vars_; ++j) {
166 double temp_deriv = 0.0;
167 for (
size_t k = 0; k < N_; ++k) {
168 temp_deriv += z[N_ + N_ * j + k] * y_adjoints_.coeffRef(k);
171 dz_dt[N_ + N_ * j + i] = temp_deriv;
176 for (
size_t j = 0; j < num_args_vars; ++j) {
177 double temp_deriv = args_adjoints_.coeffRef(j);
178 for (
size_t k = 0; k < N_; ++k) {
179 temp_deriv += z[N_ + N_ * num_y0_vars_ + N_ * j + k]
180 * y_adjoints_.coeffRef(k);
183 dz_dt[N_ + N_ * num_y0_vars_ + N_ * j + i] = temp_deriv;
193 size_t size()
const {
return N_ + N_ * num_y0_vars_ + N_ * num_args_vars; }
214 std::vector<double> initial(
size(), 0.0);
215 for (
size_t i = 0; i < N_; i++) {
218 for (
size_t i = 0; i < num_y0_vars_; i++) {
219 initial[N_ + i * N_ + i] = 1.0;
void set_zero_all_adjoints()
Reset all adjoint values in this nested stack to zero.
A class following the RAII idiom to start and recover nested autodiff scopes.
int64_t size(const T &m)
Returns the size (number of the elements) of a matrix_cl or var_value<matrix_cl<T>>.
constexpr auto for_each(F &&f, T &&t)
Apply a function to each element of a tuple.
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Arith deep_copy_vars(Arith &&arg)
Forward arguments that do not contain vars.
size_t count_vars(Pargs &&... args)
Count the number of vars in the input argument list.
void zero_adjoints() noexcept
End of recursion for set_zero_adjoints.
double * accumulate_adjoints(double *dest, const var &x, Pargs &&... args)
Accumulate adjoints from x into storage pointed to by dest, increment the adjoint storage pointer,...
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Calls the functor f with the elements of tuple t unpacked as its arguments (any pre_args are passed ahead of them).
Eigen::VectorXd args_adjoints_
size_t size() const
Returns the size of the coupled system.
std::vector< double > initial_state() const
Returns the initial state of the coupled system.
Eigen::VectorXd y_adjoints_
coupled_ode_system_impl(const F &f, const Eigen::Matrix< T_y0, Eigen::Dynamic, 1 > &y0, std::ostream *msgs, const Args &... args)
Construct a coupled ode system from the base system function, initial state of the base system,...
const Eigen::Matrix< T_y0, Eigen::Dynamic, 1 > & y0_
const size_t num_y0_vars_
std::tuple< decltype(deep_copy_vars(std::declval< const Args & >()))... > local_args_tuple_
const size_t num_args_vars
void operator()(const std::vector< double > &z, std::vector< double > &dz_dt, double t)
Calculates the right hand side of the coupled ode system (the regular ode system with forward sensiti...