1#ifndef STAN_MATH_REV_FUNCTOR_CVODES_INTEGRATOR_ADJOINT_HPP
2#define STAN_MATH_REV_FUNCTOR_CVODES_INTEGRATOR_ADJOINT_HPP
12#include <sundials/sundials_context.h>
13#include <cvodes/cvodes.h>
14#include <nvector/nvector_serial.h>
15#include <sunmatrix/sunmatrix_dense.h>
16#include <sunlinsol/sunlinsol_dense.h>
35template <
typename F,
typename T_y0,
typename T_t0,
typename T_ts,
77 const std::decay_t<F>
f_;
79 std::vector<Eigen::VectorXd>
y_;
81 std::vector<T_ts>
ts_;
82 Eigen::Matrix<T_y0_t0, Eigen::Dynamic, 1>
y0_;
105 template <
typename FF>
107 const T_t0& t0,
const std::vector<T_ts>& ts,
108 const Eigen::VectorXd& absolute_tolerance_forward,
109 const Eigen::VectorXd& absolute_tolerance_backward,
110 size_t num_args_vars,
int solver_forward,
111 const T_Args&... args)
115 f_(
std::forward<FF>(f)),
118 ts_(ts.begin(), ts.end()),
148 throw std::runtime_error(
"CVodeCreate failed to allocate memory");
210 template <
typename FF, require_eigen_col_vector_t<T_y0>* =
nullptr>
212 const char* function_name, FF&& f,
const T_y0& y0,
const T_t0& t0,
213 const std::vector<T_ts>& ts,
double relative_tolerance_forward,
214 const Eigen::VectorXd& absolute_tolerance_forward,
215 double relative_tolerance_backward,
216 const Eigen::VectorXd& absolute_tolerance_backward,
217 double relative_tolerance_quadrature,
218 double absolute_tolerance_quadrature,
219 long int max_num_steps,
220 long int num_steps_between_checkpoints,
221 int interpolation_polynomial,
int solver_forward,
int solver_backward,
222 std::ostream* msgs,
const T_Args&... args)
255 check_less(function_name,
"initial time", t0, ts[0]);
259 absolute_tolerance_forward);
261 absolute_tolerance_forward.size(),
"states",
N_);
265 absolute_tolerance_backward);
267 absolute_tolerance_backward.size(),
"states",
N_);
279 ", must be 1 for Hermite or 2 for polynomial "
280 "interpolation of ODE solution");
284 ", must be 1 for Adams or 2 for BDF forward solver");
287 ", must be 1 for Adams or 2 for BDF backward solver");
289 solver_ =
new cvodes_solver(function_name, std::forward<FF>(f),
N_, y0, t0,
290 ts, absolute_tolerance_forward,
295 [func_name = function_name](
auto&&
arg) {
335 for (
size_t n = 0; n < ts_dbl.size(); ++n) {
336 double t_final = ts_dbl[n];
337 if (t_final != t_init) {
343 CV_NORMAL, &ncheck));
352 for (std::size_t i = 0; i <
N_; ++i)
369 Eigen::Matrix<var, Eigen::Dynamic, 1>& state_return) {
370 state_return.resize(
N_);
371 for (
size_t i = 0; i <
N_; i++) {
377 Eigen::Matrix<double, Eigen::Dynamic, 1>& state_return) {
378 state_return = state;
388 std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>>
solution() noexcept {
389 std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> y_return(
391 for (std::size_t n = 0; n <
solver_->
ts_.size(); ++n)
409 Eigen::VectorXd step_sens = Eigen::VectorXd::Zero(
N_);
410 for (
int i = 0; i <
solver_->
ts_.size(); ++i) {
411 for (
int j = 0; j <
N_; ++j) {
432 for (
int i =
solver_->
ts_.size() - 1; i >= 0; --i) {
435 for (
int j = 0; j <
N_; ++j) {
441 if (t_final != t_init) {
448 reinterpret_cast<void*
>(
this)));
529 forward_as<Eigen::Matrix<var, Eigen::Dynamic, 1>>(
solver_->
y0_).adj()
545 template <
typename yT,
typename... ArgsT>
546 constexpr auto rhs(
double t,
const yT& y,
547 const std::tuple<ArgsT...>& args_tuple)
const {
549 [&](
auto&&... args) {
return solver_->
f_(t, y,
msgs_, args...); },
565 inline int rhs(
double t,
const double* y,
double*& dy_dt)
const {
566 const Eigen::VectorXd y_vec = Eigen::Map<const Eigen::VectorXd>(y,
N_);
567 const Eigen::VectorXd dy_dt_vec
570 dy_dt_vec.size(),
"states",
N_);
571 Eigen::Map<Eigen::VectorXd>(dy_dt,
N_) = dy_dt_vec;
579 constexpr static int cv_rhs(realtype t, N_Vector y, N_Vector ydot,
595 inline int rhs_adj(
double t, N_Vector y, N_Vector yB, N_Vector yBdot)
const {
598 Eigen::Matrix<var, Eigen::Dynamic, 1> y_vars(
599 Eigen::Map<const Eigen::VectorXd>(NV_DATA_S(y),
N_));
600 Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars
603 f_y_t_vars.size(),
"states",
N_);
604 f_y_t_vars.adj() = -Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yB),
N_);
606 Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yBdot),
N_) = y_vars.adj();
614 constexpr static int cv_rhs_adj(realtype t, N_Vector y, N_Vector yB,
615 N_Vector yBdot,
void* user_data) {
630 inline int quad_rhs_adj(
double t, N_Vector y, N_Vector yB, N_Vector qBdot) {
631 Eigen::Map<const Eigen::VectorXd> y_vec(NV_DATA_S(y),
N_);
639 Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars
642 f_y_t_vars.size(),
"states",
N_);
643 f_y_t_vars.adj() = -Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yB),
N_);
646 [&qBdot](
auto&&... args) {
658 N_Vector qBdot,
void* user_data) {
667 Eigen::Map<Eigen::MatrixXd> Jfy(SM_DATA_D(J),
N_,
N_);
671 Eigen::Matrix<var, Eigen::Dynamic, 1> y_var(
672 Eigen::Map<const Eigen::VectorXd>(NV_DATA_S(y),
N_));
673 Eigen::Matrix<var, Eigen::Dynamic, 1> fy_var
677 fy_var.size(),
"states",
N_);
679 grad(fy_var.coeffRef(0).vi_);
680 Jfy.col(0) = y_var.adj();
681 for (
int i = 1; i < fy_var.size(); ++i) {
683 grad(fy_var.coeffRef(i).vi_);
684 Jfy.col(i) = y_var.adj();
686 Jfy.transposeInPlace();
697 N_Vector fy, SUNMatrix J,
698 void* user_data, N_Vector tmp1,
699 N_Vector tmp2, N_Vector tmp3) {
715 Eigen::Map<Eigen::MatrixXd> J_adj_y(SM_DATA_D(J),
N_,
N_);
716 J_adj_y.transposeInPlace();
717 J_adj_y.array() *= -1.0;
726 N_Vector yB, N_Vector fyB,
727 SUNMatrix J,
void* user_data,
728 N_Vector tmp1, N_Vector tmp2,
#define CHECK_CVODES_CALL(call)
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan i...
const long int num_steps_between_checkpoints_
void store_state(std::size_t n, const Eigen::VectorXd &state, Eigen::Matrix< var, Eigen::Dynamic, 1 > &state_return)
Overloads which set up the states returned from the forward solve.
void set_zero_adjoint() final
No-op for setting adjoints since this class does not own any adjoints.
static constexpr int cv_jacobian_rhs_states(realtype t, N_Vector y, N_Vector fy, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Implements the function of type CVDlsJacFn which is the user-defined callback for CVODES to calculate...
const long int max_num_steps_
void store_state(std::size_t n, const Eigen::VectorXd &state, Eigen::Matrix< double, Eigen::Dynamic, 1 > &state_return)
int rhs(double t, const double *y, double *&dy_dt) const
Calculates the ODE RHS, dy_dt, using the user-supplied functor at the given time t and state y.
const int interpolation_polynomial_
const int solver_backward_
return_type_t< T_y0, T_t0, T_ts, T_Args... > T_Return
const double relative_tolerance_quadrature_
return_type_t< T_y0, T_t0 > T_y0_t0
int rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector yBdot) const
cvodes_integrator_adjoint_vari(const char *function_name, FF &&f, const T_y0 &y0, const T_t0 &t0, const std::vector< T_ts > &ts, double relative_tolerance_forward, const Eigen::VectorXd &absolute_tolerance_forward, double relative_tolerance_backward, const Eigen::VectorXd &absolute_tolerance_backward, double relative_tolerance_quadrature, double absolute_tolerance_quadrature, long int max_num_steps, long int num_steps_between_checkpoints, int interpolation_polynomial, int solver_forward, int solver_backward, std::ostream *msgs, const T_Args &... args)
Construct cvodes_integrator object.
int jacobian_rhs_adj_states(double t, N_Vector y, SUNMatrix J) const
bool backward_is_initialized_
std::vector< Eigen::Matrix< T_Return, Eigen::Dynamic, 1 > > solution() noexcept
Obtain solution of ODE.
static constexpr int cv_quad_rhs_adj(realtype t, N_Vector y, N_Vector yB, N_Vector qBdot, void *user_data)
Implements the function of type CVQuadRhsFnB which is the RHS of the backward ODE system's quadrature...
const double absolute_tolerance_quadrature_
constexpr auto rhs(double t, const yT &y, const std::tuple< ArgsT... > &args_tuple) const
Call the ODE RHS with given tuple.
void chain() final
Apply the chain rule to this variable based on the variables on which it depends.
static constexpr int cv_rhs(realtype t, N_Vector y, N_Vector ydot, void *user_data)
Implements the function of type CVRhsFn which is the user-defined ODE RHS passed to CVODES.
const size_t num_args_vars_
static constexpr int cv_jacobian_rhs_adj_states(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Implements the CVLsJacFnB function for evaluating the Jacobian of the adjoint problem with respect to the back...
static constexpr bool is_var_ts_
static constexpr cvodes_integrator_adjoint_vari * cast_to_self(void *mem)
Utility to cast user memory pointer passed in from CVODES to actual typed object pointer.
const double relative_tolerance_forward_
static constexpr bool is_var_y0_
static constexpr bool is_any_var_args_
const double relative_tolerance_backward_
const int solver_forward_
static constexpr bool is_var_only_ts_
int jacobian_rhs_states(double t, N_Vector y, SUNMatrix J) const
Calculates the Jacobian of the ODE RHS with respect to its states y at the given time point t and state y.
static constexpr int cv_rhs_adj(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_data)
Implements the function of type CVRhsFnB which is the RHS of the backward ODE system.
static constexpr bool is_var_t0_
int quad_rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector qBdot)
static constexpr bool is_var_y0_t0_
static constexpr bool is_var_return_
Integrator interface for CVODES' adjoint ODE solvers (Adams & BDF methods).
void set_zero_all_adjoints()
Reset all adjoint values in this nested stack to zero.
A class following the RAII idiom to start and recover nested autodiff scopes.
T * alloc_array(size_t n)
Allocate an array on the arena of the specified size to hold values of the specified template paramet...
Abstract base class from which all vari_value and its derived classes inherit.
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
int64_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
(Expert) Numerical traits for algorithmic differentiation variables.
constexpr auto for_each(F &&f, T &&t)
Apply a function to each element of a tuple.
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
void cvodes_set_options(void *cvodes_mem, long int max_num_steps)
auto & adjoint_of(const T &x)
Returns a reference to a variable's adjoint.
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
T value_of(const fvar< T > &v)
Return the value of the specified variable.
vari_value< double > vari
Arith deep_copy_vars(Arith &&arg)
Forward arguments that do not contain vars.
size_t count_vars(Pargs &&... args)
Count the number of vars in the input argument list.
vari ** save_varis(vari **dest, const var &x, Pargs &&... args)
Save the vari pointer in x into the memory pointed to by dest, increment the dest storage pointer,...
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
void check_finite(const char *function, const char *name, const T_y &y)
Check if all values in y are finite.
void zero_adjoints() noexcept
End of recursion for set_zero_adjoints.
void check_nonzero_size(const char *function, const char *name, const T_y &y)
Check if the specified matrix/vector is of non-zero size.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
void check_sorted(const char *function, const char *name, const EigVec &y)
Check if the specified vector is sorted into increasing order (repeated values are okay).
double * accumulate_adjoints(double *dest, const var &x, Pargs &&... args)
Accumulate adjoints from x into storage pointed to by dest, increment the adjoint storage pointer,...
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
static void grad()
Compute the gradient for all variables starting from the end of the AD tape.
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Defines a static member named value which is defined to be false as the primitive scalar types cannot...
std::vector< ChainableT * > var_stack_
static thread_local AutodiffStackStorage * instance_
This struct always provides access to the autodiff stack using the singleton pattern.
Eigen::VectorXd absolute_tolerance_backward_
Eigen::VectorXd state_forward_
const std::decay_t< F > f_
const std::string function_name_str_
SUNLinearSolver LS_backward_
Eigen::VectorXd state_backward_
cvodes_solver(const char *function_name, FF &&f, size_t N, const T_y0 &y0, const T_t0 &t0, const std::vector< T_ts > &ts, const Eigen::VectorXd &absolute_tolerance_forward, const Eigen::VectorXd &absolute_tolerance_backward, size_t num_args_vars, int solver_forward, const T_Args &... args)
const std::tuple< promote_scalar_t< partials_type_t< scalar_type_t< T_Args > >, T_Args >... > value_of_args_tuple_
N_Vector nv_absolute_tolerance_backward_
std::tuple< T_Args... > local_args_tuple_
N_Vector nv_absolute_tolerance_forward_
N_Vector nv_state_forward_
Eigen::Matrix< T_y0_t0, Eigen::Dynamic, 1 > y0_
N_Vector nv_state_backward_
sundials::Context sundials_context_
std::vector< Eigen::VectorXd > y_
Eigen::VectorXd absolute_tolerance_forward_
SUNLinearSolver LS_forward_
Since the CVODES solver manages memory with malloc calls, these resources must be freed using a destr...
Extends std::false_type when instantiated with zero or more template parameters, all of which extend ...