Automatic Differentiation
 
Loading...
Searching...
No Matches
cvodes_integrator.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_INTEGRATE_ODE_CVODES_HPP
2#define STAN_MATH_REV_FUNCTOR_INTEGRATE_ODE_CVODES_HPP
3
11#include <sundials/sundials_context.h>
12#include <cvodes/cvodes.h>
13#include <nvector/nvector_serial.h>
14#include <sunlinsol/sunlinsol_dense.h>
15#include <algorithm>
16#include <ostream>
17#include <vector>
18
19namespace stan {
20namespace math {
21
33template <int Lmm, typename F, typename T_y0, typename T_t0, typename T_ts,
34 typename... T_Args>
36 using T_Return = return_type_t<T_y0, T_t0, T_ts, T_Args...>;
38
39 const char* function_name_;
40 sundials::Context sundials_context_;
41 const F& f_;
42 const Eigen::Matrix<T_y0_t0, Eigen::Dynamic, 1> y0_;
43 const T_t0 t0_;
44 const std::vector<T_ts>& ts_;
45 std::tuple<const T_Args&...> args_tuple_;
46 std::tuple<plain_type_t<decltype(value_of(std::declval<const T_Args&>()))>...>
48 const size_t N_;
49 std::ostream* msgs_;
52 long int max_num_steps_; // NOLINT(runtime/int)
53
54 const size_t num_y0_vars_;
55 const size_t num_args_vars_;
56
58
59 std::vector<double> coupled_state_;
60 N_Vector nv_state_;
61 N_Vector* nv_state_sens_;
62 SUNMatrix A_;
63 SUNLinearSolver LS_;
64
69 static int cv_rhs(realtype t, N_Vector y, N_Vector ydot, void* user_data) {
70 cvodes_integrator* integrator = static_cast<cvodes_integrator*>(user_data);
71 integrator->rhs(t, NV_DATA_S(y), NV_DATA_S(ydot));
72 return 0;
73 }
74
79 static int cv_rhs_sens(int Ns, realtype t, N_Vector y, N_Vector ydot,
80 N_Vector* yS, N_Vector* ySdot, void* user_data,
81 N_Vector tmp1, N_Vector tmp2) {
82 cvodes_integrator* integrator = static_cast<cvodes_integrator*>(user_data);
83 integrator->rhs_sens(t, NV_DATA_S(y), yS, ySdot);
84 return 0;
85 }
86
93 static int cv_jacobian_states(realtype t, N_Vector y, N_Vector fy,
94 SUNMatrix J, void* user_data, N_Vector tmp1,
95 N_Vector tmp2, N_Vector tmp3) {
96 cvodes_integrator* integrator = static_cast<cvodes_integrator*>(user_data);
97 integrator->jacobian_states(t, NV_DATA_S(y), J);
98 return 0;
99 }
100
/**
 * Calculates the ODE right-hand side dy_dt = f_(t, y, msgs_, args...) using
 * the user-supplied functor at time t and state y.
 *
 * @param t time at which to evaluate the right-hand side
 * @param y pointer to N_ state values
 * @param[out] dy_dt pointer to storage for N_ derivative values
 * @throws via check_size_match if f_ returns a vector whose size is not N_
 */
105 inline void rhs(double t, const double y[], double dy_dt[]) const {
// Copy the raw state array into an owning Eigen vector of length N_.
106 const Eigen::VectorXd y_vec = Eigen::Map<const Eigen::VectorXd>(y, N_);
107
// Expand the stored ODE arguments into the user functor call.
108 Eigen::VectorXd dy_dt_vec = math::apply(
109 [&](auto&&... args) { return f_(t, y_vec, msgs_, args...); },
111
// Reject functors that return the wrong number of derivatives.
112 check_size_match("cvodes_integrator", "dy_dt", dy_dt_vec.size(), "states",
113 N_);
114
// Write the derivatives into the caller-provided output buffer.
115 std::copy(dy_dt_vec.data(), dy_dt_vec.data() + dy_dt_vec.size(), dy_dt);
116 }
117
/**
 * Calculates the Jacobian of the ODE right-hand side with respect to its
 * states at time t and state y, and stores it element-wise in J.
 *
 * The functor is differentiated by calling jacobian() on a wrapper that
 * accepts the state as a vector of autodiff `var`s.
 *
 * @param t time at which to evaluate the Jacobian
 * @param y pointer to N_ state values
 * @param[out] J SUNDIALS dense matrix receiving df/dy; assumed to hold at
 *   least Jfy.rows() x Jfy.cols() entries (presumably the N_ x N_ matrix
 *   registered with the linear solver -- confirm)
 */
122 inline void jacobian_states(double t, const double y[], SUNMatrix J) const {
// fy receives f(t, y) and Jfy receives df/dy; fy is not used afterwards.
123 Eigen::VectorXd fy;
124 Eigen::MatrixXd Jfy;
125
// Wrap f_ as a function of the state only. The lambda parameter `y` (a var
// vector) intentionally shadows the raw `y` pointer argument.
126 auto f_wrapped = [&](const Eigen::Matrix<var, Eigen::Dynamic, 1>& y) {
127 return math::apply(
128 [&](auto&&... args) { return f_(t, y, msgs_, args...); },
130 };
131
132 jacobian(f_wrapped, Eigen::Map<const Eigen::VectorXd>(y, N_), fy, Jfy);
133
// Copy Jfy into the SUNDIALS matrix element by element.
// NOTE(review): size_t indices are compared against Eigen's signed Index
// returned by cols()/rows(), which may trigger -Wsign-compare.
134 for (size_t j = 0; j < Jfy.cols(); ++j) {
135 for (size_t i = 0; i < Jfy.rows(); ++i) {
136 SM_ELEMENT_D(J, i, j) = Jfy(i, j);
137 }
138 }
139 }
140
147 inline void rhs_sens(double t, const double y[], N_Vector* yS,
148 N_Vector* ySdot) {
149 std::vector<double> z(coupled_state_.size());
150 std::vector<double> dz_dt;
151 std::copy(y, y + N_, z.data());
152 for (std::size_t s = 0; s < num_y0_vars_ + num_args_vars_; s++) {
153 std::copy(NV_DATA_S(yS[s]), NV_DATA_S(yS[s]) + N_,
154 z.data() + (s + 1) * N_);
155 }
156 coupled_ode_(z, dz_dt, t);
157 for (std::size_t s = 0; s < num_y0_vars_ + num_args_vars_; s++) {
158 std::move(dz_dt.data() + (s + 1) * N_, dz_dt.data() + (s + 2) * N_,
159 NV_DATA_S(ySdot[s]));
160 }
161 }
162
163 public:
190 template <require_eigen_col_vector_t<T_y0>* = nullptr>
191 cvodes_integrator(const char* function_name, const F& f, const T_y0& y0,
192 const T_t0& t0, const std::vector<T_ts>& ts,
193 double relative_tolerance, double absolute_tolerance,
194 long int max_num_steps, // NOLINT(runtime/int)
195 std::ostream* msgs, const T_Args&... args)
196 : function_name_(function_name),
198 f_(f),
199 y0_(y0.template cast<T_y0_t0>()),
200 t0_(t0),
201 ts_(ts),
202 args_tuple_(args...),
204 N_(y0.size()),
205 msgs_(msgs),
206 relative_tolerance_(relative_tolerance),
207 absolute_tolerance_(absolute_tolerance),
208 max_num_steps_(max_num_steps),
210 num_args_vars_(count_vars(args...)),
211 coupled_ode_(f, y0_, msgs, args...),
212 coupled_state_(coupled_ode_.initial_state()) {
213 check_finite(function_name, "initial state", y0_);
214 check_finite(function_name, "initial time", t0_);
215 check_finite(function_name, "times", ts_);
216
217 // Code from: https://stackoverflow.com/a/17340003 . Should probably do
218 // something better
220 [&](auto&&... args) {
221 std::vector<int> unused_temp{
222 0, (check_finite(function_name, "ode parameters and data", args),
223 0)...};
224 },
226
227 check_nonzero_size(function_name, "times", ts_);
228 check_nonzero_size(function_name, "initial state", y0_);
229 check_sorted(function_name, "times", ts_);
230 check_less(function_name, "initial time", t0_, ts_[0]);
231 check_positive_finite(function_name, "relative_tolerance",
233 check_positive_finite(function_name, "absolute_tolerance",
235 check_positive(function_name, "max_num_steps", max_num_steps_);
236
237 nv_state_ = N_VMake_Serial(N_, &coupled_state_[0], sundials_context_);
238 nv_state_sens_ = nullptr;
239 A_ = SUNDenseMatrix(N_, N_, sundials_context_);
240 LS_ = SUNLinSol_Dense(nv_state_, A_, sundials_context_);
241
242 if (num_y0_vars_ + num_args_vars_ > 0) {
244 = N_VCloneEmptyVectorArray(num_y0_vars_ + num_args_vars_, nv_state_);
245 for (std::size_t i = 0; i < num_y0_vars_ + num_args_vars_; i++) {
246 NV_DATA_S(nv_state_sens_[i]) = &coupled_state_[N_] + i * N_;
247 }
248 }
249 }
250
252 SUNLinSolFree(LS_);
253 SUNMatDestroy(A_);
254 N_VDestroy_Serial(nv_state_);
255 if (num_y0_vars_ + num_args_vars_ > 0) {
256 N_VDestroyVectorArray(nv_state_sens_, num_y0_vars_ + num_args_vars_);
257 }
258 }
259
268 std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> operator()() {
269 std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> y;
270
271 void* cvodes_mem = CVodeCreate(Lmm, sundials_context_);
272 if (cvodes_mem == nullptr) {
273 throw std::runtime_error("CVodeCreate failed to allocate memory");
274 }
275
276 try {
277 CHECK_CVODES_CALL(CVodeInit(cvodes_mem, &cvodes_integrator::cv_rhs,
279
280 // Assign pointer to this as user data
282 CVodeSetUserData(cvodes_mem, reinterpret_cast<void*>(this)));
283
285
286 CHECK_CVODES_CALL(CVodeSStolerances(cvodes_mem, relative_tolerance_,
288
289 CHECK_CVODES_CALL(CVodeSetLinearSolver(cvodes_mem, LS_, A_));
291 CVodeSetJacFn(cvodes_mem, &cvodes_integrator::cv_jacobian_states));
292
293 // initialize forward sensitivity system of CVODES as needed
294 if (num_y0_vars_ + num_args_vars_ > 0) {
295 CHECK_CVODES_CALL(CVodeSensInit(
296 cvodes_mem, static_cast<int>(num_y0_vars_ + num_args_vars_),
298
299 CHECK_CVODES_CALL(CVodeSetSensErrCon(cvodes_mem, SUNTRUE));
300
301 CHECK_CVODES_CALL(CVodeSensEEtolerances(cvodes_mem));
302 }
303
304 double t_init = value_of(t0_);
305 for (size_t n = 0; n < ts_.size(); ++n) {
306 double t_final = value_of(ts_[n]);
307
308 if (t_final != t_init) {
310 CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL));
311
312 if (num_y0_vars_ + num_args_vars_ > 0) {
314 CVodeGetSens(cvodes_mem, &t_init, nv_state_sens_));
315 }
316 }
317
318 y.emplace_back(math::apply(
319 [&](auto&&... args) {
321 ts_[n], msgs_, args...);
322 },
323 args_tuple_));
324
325 t_init = t_final;
326 }
327 } catch (const std::exception& e) {
328 CVodeFree(&cvodes_mem);
329 throw;
330 }
331
332 CVodeFree(&cvodes_mem);
333
334 return y;
335 }
336}; // cvodes integrator
337
338} // namespace math
339} // namespace stan
340#endif
#define CHECK_CVODES_CALL(call)
std::tuple< plain_type_t< decltype(value_of(std::declval< const T_Args & >()))> value_of_args_tuple_
void rhs(double t, const double y[], double dy_dt[]) const
Calculates the ODE RHS, dy_dt, using the user-supplied functor at the given time t and state y.
static int cv_rhs(realtype t, N_Vector y, N_Vector ydot, void *user_data)
Implements the function of type CVRhsFn which is the user-defined ODE RHS passed to CVODES.
cvodes_integrator(const char *function_name, const F &f, const T_y0 &y0, const T_t0 &t0, const std::vector< T_ts > &ts, double relative_tolerance, double absolute_tolerance, long int max_num_steps, std::ostream *msgs, const T_Args &... args)
Construct cvodes_integrator object.
static int cv_rhs_sens(int Ns, realtype t, N_Vector y, N_Vector ydot, N_Vector *yS, N_Vector *ySdot, void *user_data, N_Vector tmp1, N_Vector tmp2)
Implements the function of type CVSensRhsFn which is the RHS of the sensitivity ODE system.
static int cv_jacobian_states(realtype t, N_Vector y, N_Vector fy, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Implements the function of type CVDlsJacFn which is the user-defined callback for CVODES to calculate...
std::vector< double > coupled_state_
const std::vector< T_ts > & ts_
void rhs_sens(double t, const double y[], N_Vector *yS, N_Vector *ySdot)
Calculates the RHS of the sensitivity ODE system which corresponds to the coupled ode system from whi...
return_type_t< T_y0, T_t0 > T_y0_t0
return_type_t< T_y0, T_t0, T_ts, T_Args... > T_Return
const Eigen::Matrix< T_y0_t0, Eigen::Dynamic, 1 > y0_
coupled_ode_system< F, T_y0_t0, T_Args... > coupled_ode_
void jacobian_states(double t, const double y[], SUNMatrix J) const
Calculates the Jacobian of the ODE RHS with respect to its states y at the given time point t and state y.
std::tuple< const T_Args &... > args_tuple_
std::vector< Eigen::Matrix< T_Return, Eigen::Dynamic, 1 > > operator()()
Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of times, { t1,...
Integrator interface for CVODES' ODE solvers (Adams & BDF methods).
auto cast(T &&a)
Typecast a kernel generator expression scalar.
Definition cast.hpp:80
size_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
void jacobian(const F &f, const Eigen::Matrix< T, Eigen::Dynamic, 1 > &x, Eigen::Matrix< T, Eigen::Dynamic, 1 > &fx, Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &J)
Definition jacobian.hpp:11
static constexpr double e()
Return the base of the natural logarithm.
Definition constants.hpp:20
void cvodes_set_options(void *cvodes_mem, long int max_num_steps)
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
size_t count_vars(Pargs &&... args)
Count the number of vars in the input argument list.
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
Eigen::VectorXd ode_store_sensitivities(const F &f, const std::vector< double > &coupled_state, const Eigen::Matrix< T_y0_t0, Eigen::Dynamic, 1 > &y0, T_t0 t0, T_t t, std::ostream *msgs, const Args &... args)
When all arguments are arithmetic, there are no sensitivities to store, so the function just returns ...
void check_nonzero_size(const char *function, const char *name, const T_y &y)
Check if the specified matrix/vector is of non-zero size.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
void check_sorted(const char *function, const char *name, const EigVec &y)
Check if the specified vector is sorted into increasing order (repeated values are okay).
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:52
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9