Automatic Differentiation
 
Loading...
Searching...
No Matches
cvodes_integrator_adjoint.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUNCTOR_CVODES_INTEGRATOR_ADJOINT_HPP
2#define STAN_MATH_REV_FUNCTOR_CVODES_INTEGRATOR_ADJOINT_HPP
3
12#include <sundials/sundials_context.h>
13#include <cvodes/cvodes.h>
14#include <nvector/nvector_serial.h>
15#include <sunmatrix/sunmatrix_dense.h>
16#include <sunlinsol/sunlinsol_dense.h>
17#include <algorithm>
18#include <ostream>
19#include <utility>
20#include <vector>
21
22namespace stan {
23namespace math {
24
35template <typename F, typename T_y0, typename T_t0, typename T_ts,
36 typename... T_Args>
38 using T_Return = return_type_t<T_y0, T_t0, T_ts, T_Args...>;
40
41 static constexpr bool is_var_ts_{is_var<T_ts>::value};
42 static constexpr bool is_var_t0_{is_var<T_t0>::value};
43 static constexpr bool is_var_y0_{is_var<T_y0>::value};
44 static constexpr bool is_var_y0_t0_{is_var<T_y0_t0>::value};
45 static constexpr bool is_any_var_args_{
48 static constexpr bool is_var_only_ts_{
50
51 const size_t num_args_vars_;
52
57 const long int max_num_steps_; // NOLINT(runtime/int)
58 const long int num_steps_between_checkpoints_; // NOLINT(runtime/int)
59 const size_t N_;
60 std::ostream* msgs_;
64 const int solver_forward_;
68
75 sundials::Context sundials_context_;
76 const std::string function_name_str_;
77 const std::decay_t<F> f_;
78 const size_t N_;
79 std::vector<Eigen::VectorXd> y_;
80
81 std::vector<T_ts> ts_;
82 Eigen::Matrix<T_y0_t0, Eigen::Dynamic, 1> y0_;
85 Eigen::VectorXd state_forward_;
86 Eigen::VectorXd state_backward_;
87 Eigen::VectorXd quad_;
88 T_t0 t0_;
89
92 N_Vector nv_quad_;
95 SUNMatrix A_forward_;
96 SUNMatrix A_backward_;
97 SUNLinearSolver LS_forward_;
98 SUNLinearSolver LS_backward_;
100 std::tuple<T_Args...> local_args_tuple_;
101 const std::tuple<
104
105 template <typename FF>
106 cvodes_solver(const char* function_name, FF&& f, size_t N, const T_y0& y0,
107 const T_t0& t0, const std::vector<T_ts>& ts,
108 const Eigen::VectorXd& absolute_tolerance_forward,
109 const Eigen::VectorXd& absolute_tolerance_backward,
110 size_t num_args_vars, int solver_forward,
111 const T_Args&... args)
112 : chainable_alloc(),
114 function_name_str_(function_name),
115 f_(std::forward<FF>(f)),
116 N_(N),
117 y_(ts.size()),
118 ts_(ts.begin(), ts.end()),
119 y0_(y0),
120 absolute_tolerance_forward_(absolute_tolerance_forward),
121 absolute_tolerance_backward_(absolute_tolerance_backward),
123 state_backward_(Eigen::VectorXd::Zero(N)),
124 quad_(Eigen::VectorXd::Zero(num_args_vars)),
125 t0_(t0),
127 N_VMake_Serial(N, state_forward_.data(), sundials_context_)),
129 N_VMake_Serial(N, state_backward_.data(), sundials_context_)),
130 nv_quad_(
131 N_VMake_Serial(num_args_vars, quad_.data(), sundials_context_)),
132 nv_absolute_tolerance_forward_(N_VMake_Serial(
134 nv_absolute_tolerance_backward_(N_VMake_Serial(
136 A_forward_(SUNDenseMatrix(N, N, sundials_context_)),
137 A_backward_(SUNDenseMatrix(N, N, sundials_context_)),
138 LS_forward_(N == 0 ? nullptr
139 : SUNLinSol_Dense(nv_state_forward_, A_forward_,
141 LS_backward_(N == 0 ? nullptr
142 : SUNLinSol_Dense(nv_state_backward_, A_backward_,
144 cvodes_mem_(CVodeCreate(solver_forward, sundials_context_)),
147 if (cvodes_mem_ == nullptr) {
148 throw std::runtime_error("CVodeCreate failed to allocate memory");
149 }
150 }
151
152 virtual ~cvodes_solver() {
153 SUNMatDestroy(A_forward_);
154 SUNMatDestroy(A_backward_);
155 if (N_ > 0) {
156 SUNLinSolFree(LS_forward_);
157 SUNLinSolFree(LS_backward_);
158 }
159 N_VDestroy_Serial(nv_state_forward_);
160 N_VDestroy_Serial(nv_state_backward_);
161 N_VDestroy_Serial(nv_quad_);
162 N_VDestroy_Serial(nv_absolute_tolerance_forward_);
163 N_VDestroy_Serial(nv_absolute_tolerance_backward_);
164
165 CVodeFree(&cvodes_mem_);
166 }
167 };
169
170 public:
210 template <typename FF, require_eigen_col_vector_t<T_y0>* = nullptr>
212 const char* function_name, FF&& f, const T_y0& y0, const T_t0& t0,
213 const std::vector<T_ts>& ts, double relative_tolerance_forward,
214 const Eigen::VectorXd& absolute_tolerance_forward,
215 double relative_tolerance_backward,
216 const Eigen::VectorXd& absolute_tolerance_backward,
217 double relative_tolerance_quadrature,
218 double absolute_tolerance_quadrature,
219 long int max_num_steps, // NOLINT(runtime/int)
220 long int num_steps_between_checkpoints, // NOLINT(runtime/int)
221 int interpolation_polynomial, int solver_forward, int solver_backward,
222 std::ostream* msgs, const T_Args&... args)
223 : vari_base(),
224 num_args_vars_(count_vars(args...)),
225 relative_tolerance_forward_(relative_tolerance_forward),
226 relative_tolerance_backward_(relative_tolerance_backward),
227 relative_tolerance_quadrature_(relative_tolerance_quadrature),
228 absolute_tolerance_quadrature_(absolute_tolerance_quadrature),
229 max_num_steps_(max_num_steps),
230 num_steps_between_checkpoints_(num_steps_between_checkpoints),
231 N_(y0.size()),
232 msgs_(msgs),
233 y_return_varis_(is_var_return_ ? ChainableStack::instance_->memalloc_
234 .alloc_array<vari*>(N_ * ts.size())
235 : nullptr),
236 args_varis_([&args..., num_vars = this->num_args_vars_]() {
237 vari** vari_mem
239 num_vars);
240 save_varis(vari_mem, args...);
241 return vari_mem;
242 }()),
243 interpolation_polynomial_(interpolation_polynomial),
244 solver_forward_(solver_forward),
245 solver_backward_(solver_backward),
247 solver_(nullptr) {
248 check_finite(function_name, "initial state", y0);
249 check_finite(function_name, "initial time", t0);
250 check_finite(function_name, "times", ts);
251
252 check_nonzero_size(function_name, "times", ts);
253 check_nonzero_size(function_name, "initial state", y0);
254 check_sorted(function_name, "times", ts);
255 check_less(function_name, "initial time", t0, ts[0]);
256 check_positive_finite(function_name, "relative_tolerance_forward",
258 check_positive_finite(function_name, "absolute_tolerance_forward",
259 absolute_tolerance_forward);
260 check_size_match(function_name, "absolute_tolerance_forward",
261 absolute_tolerance_forward.size(), "states", N_);
262 check_positive_finite(function_name, "relative_tolerance_backward",
264 check_positive_finite(function_name, "absolute_tolerance_backward",
265 absolute_tolerance_backward);
266 check_size_match(function_name, "absolute_tolerance_backward",
267 absolute_tolerance_backward.size(), "states", N_);
268 check_positive_finite(function_name, "relative_tolerance_quadrature",
270 check_positive_finite(function_name, "absolute_tolerance_quadrature",
272 check_positive(function_name, "max_num_steps", max_num_steps_);
273 check_positive(function_name, "num_steps_between_checkpoints",
275 // for polynomial: 1=CV_HERMITE / 2=CV_POLYNOMIAL
277 invalid_argument(function_name, "interpolation_polynomial",
279 ", must be 1 for Hermite or 2 for polynomial "
280 "interpolation of ODE solution");
281 // 1=Adams=CV_ADAMS, 2=BDF=CV_BDF
282 if (solver_forward_ != 1 && solver_forward_ != 2)
283 invalid_argument(function_name, "solver_forward", solver_forward_, "",
284 ", must be 1 for Adams or 2 for BDF forward solver");
285 if (solver_backward_ != 1 && solver_backward_ != 2)
286 invalid_argument(function_name, "solver_backward", solver_backward_, "",
287 ", must be 1 for Adams or 2 for BDF backward solver");
288
289 solver_ = new cvodes_solver(function_name, std::forward<FF>(f), N_, y0, t0,
290 ts, absolute_tolerance_forward,
291 absolute_tolerance_backward, num_args_vars_,
292 solver_forward_, args...);
293
295 [func_name = function_name](auto&& arg) {
296 check_finite(func_name, "ode parameters and data", arg);
297 },
299
303
304 // Assign pointer to this as user data
306 CVodeSetUserData(solver_->cvodes_mem_, reinterpret_cast<void*>(this)));
307
309
313
314 CHECK_CVODES_CALL(CVodeSetLinearSolver(
316
318 CVodeSetJacFn(solver_->cvodes_mem_,
320
321 // initialize backward sensitivity system of CVODES as needed
326 }
327
332 const auto ts_dbl = value_of(solver_->ts_);
333
334 double t_init = value_of(solver_->t0_);
335 for (size_t n = 0; n < ts_dbl.size(); ++n) {
336 double t_final = ts_dbl[n];
337 if (t_final != t_init) {
339 int ncheck;
340
341 CHECK_CVODES_CALL(CVodeF(solver_->cvodes_mem_, t_final,
342 solver_->nv_state_forward_, &t_init,
343 CV_NORMAL, &ncheck));
344 } else {
345 CHECK_CVODES_CALL(CVode(solver_->cvodes_mem_, t_final,
346 solver_->nv_state_forward_, &t_init,
347 CV_NORMAL));
348 }
349 }
351 if (is_var_return_) {
352 for (std::size_t i = 0; i < N_; ++i)
353 y_return_varis_[N_ * n + i]
354 = new vari(solver_->state_forward_.coeff(i), false);
355 }
356
357 t_init = t_final;
358 }
359 ChainableStack::instance_->var_stack_.push_back(this);
360 }
361
362 private:
368 void store_state(std::size_t n, const Eigen::VectorXd& state,
369 Eigen::Matrix<var, Eigen::Dynamic, 1>& state_return) {
370 state_return.resize(N_);
371 for (size_t i = 0; i < N_; i++) {
372 state_return.coeffRef(i) = var(y_return_varis_[N_ * n + i]);
373 }
374 }
375
376 void store_state(std::size_t n, const Eigen::VectorXd& state,
377 Eigen::Matrix<double, Eigen::Dynamic, 1>& state_return) {
378 state_return = state;
379 }
380
381 public:
388 std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> solution() noexcept {
389 std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> y_return(
390 solver_->ts_.size());
391 for (std::size_t n = 0; n < solver_->ts_.size(); ++n)
392 store_state(n, solver_->y_[n], y_return[n]);
393 return y_return;
394 }
395
399 void set_zero_adjoint() final{};
400
401 void chain() final {
402 if (!is_var_return_) {
403 return;
404 }
405
406 // for sensitivities wrt to ts we do not need to run the backward
407 // integration
408 if (is_var_ts_) {
409 Eigen::VectorXd step_sens = Eigen::VectorXd::Zero(N_);
410 for (int i = 0; i < solver_->ts_.size(); ++i) {
411 for (int j = 0; j < N_; ++j) {
412 step_sens.coeffRef(j) += y_return_varis_[N_ * i + j]->adj_;
413 }
414
416 += step_sens.dot(rhs(value_of(solver_->ts_[i]), solver_->y_[i],
418 step_sens.setZero();
419 }
420
421 if (is_var_only_ts_) {
422 return;
423 }
424 }
425
426 solver_->state_backward_.setZero();
427 solver_->quad_.setZero();
428
429 // At every time step, collect the adjoints from the output
430 // variables and re-initialize the solver
431 double t_init = value_of(solver_->ts_.back());
432 for (int i = solver_->ts_.size() - 1; i >= 0; --i) {
433 // Take in the adjoints from all the output variables at this point
434 // in time
435 for (int j = 0; j < N_; ++j) {
436 solver_->state_backward_.coeffRef(j)
437 += y_return_varis_[N_ * i + j]->adj_;
438 }
439
440 double t_final = value_of((i > 0) ? solver_->ts_[i - 1] : solver_->t0_);
441 if (t_final != t_init) {
445
446 CHECK_CVODES_CALL(CVodeSetUserDataB(solver_->cvodes_mem_,
448 reinterpret_cast<void*>(this)));
449
450 // initialize CVODES backward machinery.
451 // the states of the backward problem *are* the adjoints
452 // of the ode states
457
459 CVodeSVtolerancesB(solver_->cvodes_mem_, index_backward_,
462
463 CHECK_CVODES_CALL(CVodeSetMaxNumStepsB(
465
466 CHECK_CVODES_CALL(CVodeSetLinearSolverB(
469
470 CHECK_CVODES_CALL(CVodeSetJacFnB(
473
474 // Allocate space for backwards quadrature needed when
475 // parameters vary.
476 if (is_any_var_args_) {
478 CVodeQuadInitB(solver_->cvodes_mem_, index_backward_,
480 solver_->nv_quad_));
481
483 CVodeQuadSStolerancesB(solver_->cvodes_mem_, index_backward_,
486
487 CHECK_CVODES_CALL(CVodeSetQuadErrConB(solver_->cvodes_mem_,
488 index_backward_, SUNTRUE));
489 }
490
492 } else {
493 // just re-initialize the solver
494
496 t_init, solver_->nv_state_backward_));
497
498 if (is_any_var_args_) {
499 CHECK_CVODES_CALL(CVodeQuadReInitB(
501 }
502 }
503
504 CHECK_CVODES_CALL(CVodeB(solver_->cvodes_mem_, t_final, CV_NORMAL));
505
506 // obtain adjoint states and update t_init to time point
507 // reached of t_final
509 &t_init, solver_->nv_state_backward_));
510
511 if (is_any_var_args_) {
513 &t_init, solver_->nv_quad_));
514 }
515 }
516 }
517
518 // After integrating all the way back to t0, we finally have the
519 // the adjoints we wanted
520
521 // This is the dlog_density / d(initial_time_point) adjoint
522 if (is_var_t0_) {
525 }
526
527 // These are the dlog_density / d(initial_conditions[s]) adjoints
528 if (is_var_y0_t0_) {
529 forward_as<Eigen::Matrix<var, Eigen::Dynamic, 1>>(solver_->y0_).adj()
531 }
532
533 // These are the dlog_density / d(parameters[s]) adjoints
534 if (is_any_var_args_) {
535 for (size_t s = 0; s < num_args_vars_; ++s) {
536 args_varis_[s]->adj_ += solver_->quad_.coeff(s);
537 }
538 }
539 }
540
541 private:
545 template <typename yT, typename... ArgsT>
546 constexpr auto rhs(double t, const yT& y,
547 const std::tuple<ArgsT...>& args_tuple) const {
548 return math::apply(
549 [&](auto&&... args) { return solver_->f_(t, y, msgs_, args...); },
550 args_tuple);
551 }
552
557 constexpr static cvodes_integrator_adjoint_vari* cast_to_self(void* mem) {
558 return static_cast<cvodes_integrator_adjoint_vari*>(mem);
559 }
560
565 inline int rhs(double t, const double* y, double*& dy_dt) const {
566 const Eigen::VectorXd y_vec = Eigen::Map<const Eigen::VectorXd>(y, N_);
567 const Eigen::VectorXd dy_dt_vec
568 = rhs(t, y_vec, solver_->value_of_args_tuple_);
570 dy_dt_vec.size(), "states", N_);
571 Eigen::Map<Eigen::VectorXd>(dy_dt, N_) = dy_dt_vec;
572 return 0;
573 }
574
579 constexpr static int cv_rhs(realtype t, N_Vector y, N_Vector ydot,
580 void* user_data) {
581 return cast_to_self(user_data)->rhs(t, NV_DATA_S(y), NV_DATA_S(ydot));
582 }
583
584 /*
585 * Calculate the adjoint sensitivity RHS for varying initial conditions
586 * and parameters
587 *
588 * Equation 2.23 in the cvs_guide.
589 *
590 * @param[in] t time
591 * @param[in] y state of the base ODE system
592 * @param[in] yB state of the adjoint ODE system
593 * @param[out] yBdot evaluation of adjoint ODE RHS
594 */
595 inline int rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector yBdot) const {
596 const nested_rev_autodiff nested;
597
598 Eigen::Matrix<var, Eigen::Dynamic, 1> y_vars(
599 Eigen::Map<const Eigen::VectorXd>(NV_DATA_S(y), N_));
600 Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars
601 = rhs(t, y_vars, solver_->value_of_args_tuple_);
603 f_y_t_vars.size(), "states", N_);
604 f_y_t_vars.adj() = -Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yB), N_);
605 grad();
606 Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yBdot), N_) = y_vars.adj();
607 return 0;
608 }
609
614 constexpr static int cv_rhs_adj(realtype t, N_Vector y, N_Vector yB,
615 N_Vector yBdot, void* user_data) {
616 return cast_to_self(user_data)->rhs_adj(t, y, yB, yBdot);
617 }
618
619 /*
620 * Calculate the RHS for the quadrature part of the adjoint ODE
621 * problem.
622 *
623 * This is the integrand of equation 2.22 in the cvs_guide.
624 *
625 * @param[in] t time
626 * @param[in] y state of the base ODE system
627 * @param[in] yB state of the adjoint ODE system
628 * @param[out] qBdot evaluation of adjoint ODE quadrature RHS
629 */
630 inline int quad_rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector qBdot) {
631 Eigen::Map<const Eigen::VectorXd> y_vec(NV_DATA_S(y), N_);
632 const nested_rev_autodiff nested;
633
634 // The vars here do not live on the nested stack so must be zero'd
635 // separately
636 stan::math::for_each([](auto&& arg) { zero_adjoints(arg); },
638
639 Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars
640 = rhs(t, y_vec, solver_->local_args_tuple_);
642 f_y_t_vars.size(), "states", N_);
643 f_y_t_vars.adj() = -Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yB), N_);
644 grad();
646 [&qBdot](auto&&... args) {
647 accumulate_adjoints(NV_DATA_S(qBdot), args...);
648 },
650 return 0;
651 }
652
657 constexpr static int cv_quad_rhs_adj(realtype t, N_Vector y, N_Vector yB,
658 N_Vector qBdot, void* user_data) {
659 return cast_to_self(user_data)->quad_rhs_adj(t, y, yB, qBdot);
660 }
661
666 inline int jacobian_rhs_states(double t, N_Vector y, SUNMatrix J) const {
667 Eigen::Map<Eigen::MatrixXd> Jfy(SM_DATA_D(J), N_, N_);
668
669 nested_rev_autodiff nested;
670
671 Eigen::Matrix<var, Eigen::Dynamic, 1> y_var(
672 Eigen::Map<const Eigen::VectorXd>(NV_DATA_S(y), N_));
673 Eigen::Matrix<var, Eigen::Dynamic, 1> fy_var
674 = rhs(t, y_var, solver_->value_of_args_tuple_);
675
677 fy_var.size(), "states", N_);
678
679 grad(fy_var.coeffRef(0).vi_);
680 Jfy.col(0) = y_var.adj();
681 for (int i = 1; i < fy_var.size(); ++i) {
682 nested.set_zero_all_adjoints();
683 grad(fy_var.coeffRef(i).vi_);
684 Jfy.col(i) = y_var.adj();
685 }
686 Jfy.transposeInPlace();
687 return 0;
688 }
689
696 constexpr static int cv_jacobian_rhs_states(realtype t, N_Vector y,
697 N_Vector fy, SUNMatrix J,
698 void* user_data, N_Vector tmp1,
699 N_Vector tmp2, N_Vector tmp3) {
700 return cast_to_self(user_data)->jacobian_rhs_states(t, y, J);
701 }
702
703 /*
704 * Calculate the Jacobian of the RHS of the adjoint ODE (see rhs_adj
705 * below for citation for how this is done)
706 *
707 * @param[in] t Time
708 * @param[in] y State of system
709 * @param[out] J CVode structure where output is to be stored
710 */
711 inline int jacobian_rhs_adj_states(double t, N_Vector y, SUNMatrix J) const {
712 // J_adj_y = -1 * transpose(J_y)
713 int error_code = jacobian_rhs_states(t, y, J);
714
715 Eigen::Map<Eigen::MatrixXd> J_adj_y(SM_DATA_D(J), N_, N_);
716 J_adj_y.transposeInPlace();
717 J_adj_y.array() *= -1.0;
718 return error_code;
719 }
720
725 constexpr static int cv_jacobian_rhs_adj_states(realtype t, N_Vector y,
726 N_Vector yB, N_Vector fyB,
727 SUNMatrix J, void* user_data,
728 N_Vector tmp1, N_Vector tmp2,
729 N_Vector tmp3) {
730 return cast_to_self(user_data)->jacobian_rhs_adj_states(t, y, J);
731 }
732}; // cvodes integrator adjoint vari
733
734} // namespace math
735} // namespace stan
736#endif
#define CHECK_CVODES_CALL(call)
A chainable_alloc is an object which is constructed and destructed normally but the memory lifespan i...
void store_state(std::size_t n, const Eigen::VectorXd &state, Eigen::Matrix< var, Eigen::Dynamic, 1 > &state_return)
Overloads which setup the states returned from the forward solve.
void set_zero_adjoint() final
No-op for setting adjoints since this class does not own any adjoints.
static constexpr int cv_jacobian_rhs_states(realtype t, N_Vector y, N_Vector fy, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Implements the function of type CVLsJacFn which is the user-defined callback for CVODES to calculate...
void store_state(std::size_t n, const Eigen::VectorXd &state, Eigen::Matrix< double, Eigen::Dynamic, 1 > &state_return)
int rhs(double t, const double *y, double *&dy_dt) const
Calculates the ODE RHS, dy_dt, using the user-supplied functor at the given time t and state y.
return_type_t< T_y0, T_t0, T_ts, T_Args... > T_Return
int rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector yBdot) const
cvodes_integrator_adjoint_vari(const char *function_name, FF &&f, const T_y0 &y0, const T_t0 &t0, const std::vector< T_ts > &ts, double relative_tolerance_forward, const Eigen::VectorXd &absolute_tolerance_forward, double relative_tolerance_backward, const Eigen::VectorXd &absolute_tolerance_backward, double relative_tolerance_quadrature, double absolute_tolerance_quadrature, long int max_num_steps, long int num_steps_between_checkpoints, int interpolation_polynomial, int solver_forward, int solver_backward, std::ostream *msgs, const T_Args &... args)
Construct cvodes_integrator object.
int jacobian_rhs_adj_states(double t, N_Vector y, SUNMatrix J) const
std::vector< Eigen::Matrix< T_Return, Eigen::Dynamic, 1 > > solution() noexcept
Obtain solution of ODE.
static constexpr int cv_quad_rhs_adj(realtype t, N_Vector y, N_Vector yB, N_Vector qBdot, void *user_data)
Implements the function of type CVQuadRhsFnB which is the RHS of the backward ODE system's quadrature...
constexpr auto rhs(double t, const yT &y, const std::tuple< ArgsT... > &args_tuple) const
Call the ODE RHS with given tuple.
void chain() final
Apply the chain rule to this variable based on the variables on which it depends.
static constexpr int cv_rhs(realtype t, N_Vector y, N_Vector ydot, void *user_data)
Implements the function of type CVRhsFn which is the user-defined ODE RHS passed to CVODES.
static constexpr int cv_jacobian_rhs_adj_states(realtype t, N_Vector y, N_Vector yB, N_Vector fyB, SUNMatrix J, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
Implements the CVLsJacFnB function for evaluating the Jacobian of the adjoint problem with respect to the back...
static constexpr cvodes_integrator_adjoint_vari * cast_to_self(void *mem)
Utility to cast user memory pointer passed in from CVODES to actual typed object pointer.
int jacobian_rhs_states(double t, N_Vector y, SUNMatrix J) const
Calculates the Jacobian of the ODE RHS with respect to its states y at the given time point t and state y.
static constexpr int cv_rhs_adj(realtype t, N_Vector y, N_Vector yB, N_Vector yBdot, void *user_data)
Implements the function of type CVRhsFnB which is the RHS of the backward ODE system.
int quad_rhs_adj(double t, N_Vector y, N_Vector yB, N_Vector qBdot)
Integrator interface for CVODES' adjoint ODE solvers (Adams & BDF methods).
void set_zero_all_adjoints()
Reset all adjoint values in this nested stack to zero.
A class following the RAII idiom to start and recover nested autodiff scopes.
T * alloc_array(size_t n)
Allocate an array on the arena of the specified size to hold values of the specified template paramet...
Abstract base class that all vari_value and its derived classes inherit.
Definition vari.hpp:28
#define unlikely(x)
size_t size(const T &m)
Returns the size (number of elements) of a matrix_cl or var_value<matrix_cl<T>>.
Definition size.hpp:18
typename return_type< Ts... >::type return_type_t
Convenience type for the return type of the specified template parameters.
(Expert) Numerical traits for algorithmic differentiation variables.
constexpr auto for_each(F &&f, T &&t)
Apply a function to each element of a tuple.
Definition for_each.hpp:66
typename promote_scalar_type< std::decay_t< T >, std::decay_t< S > >::type promote_scalar_t
void cvodes_set_options(void *cvodes_mem, long int max_num_steps)
auto & adjoint_of(const T &x)
Returns a reference to a variable's adjoint.
fvar< T > arg(const std::complex< fvar< T > > &z)
Return the phase angle of the complex argument.
Definition arg.hpp:19
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition value_of.hpp:18
vari_value< double > vari
Definition vari.hpp:197
Arith deep_copy_vars(Arith &&arg)
Forward arguments that do not contain vars.
size_t count_vars(Pargs &&... args)
Count the number of vars in the input argument list.
vari ** save_varis(vari **dest, const var &x, Pargs &&... args)
Save the vari pointer in x into the memory pointed to by dest, increment the dest storage pointer,...
void invalid_argument(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw an invalid_argument exception with a consistently formatted message.
void check_finite(const char *function, const char *name, const T_y &y)
Check that all values in y are finite.
void zero_adjoints() noexcept
End of recursion for zero_adjoints.
var_value< double > var
Definition var.hpp:1187
void check_nonzero_size(const char *function, const char *name, const T_y &y)
Check if the specified matrix/vector is of non-zero size.
void check_positive(const char *function, const char *name, const T_y &y)
Check if y is positive.
void check_sorted(const char *function, const char *name, const EigVec &y)
Check if the specified vector is sorted into increasing order (repeated values are okay).
double * accumulate_adjoints(double *dest, const var &x, Pargs &&... args)
Accumulate adjoints from x into storage pointed to by dest, increment the adjoint storage pointer,...
void check_size_match(const char *function, const char *name_i, T_size1 i, const char *name_j, T_size2 j)
Check if the provided sizes match.
void check_less(const char *function, const char *name, const T_y &y, const T_high &high, Idxs... idxs)
Throw an exception if y is not strictly less than high.
static void grad()
Compute the gradient for all variables starting from the end of the AD tape.
Definition grad.hpp:26
constexpr decltype(auto) apply(F &&f, Tuple &&t, PreArgs &&... pre_args)
Definition apply.hpp:52
void check_positive_finite(const char *function, const char *name, const T_y &y)
Check if y is positive and finite.
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...
Definition fvar.hpp:9
STL namespace.
Defines a static member named value which is defined to be false as the primitive scalar types cannot...
Definition is_var.hpp:14
static thread_local AutodiffStackStorage * instance_
This struct always provides access to the autodiff stack using the singleton pattern.
cvodes_solver(const char *function_name, FF &&f, size_t N, const T_y0 &y0, const T_t0 &t0, const std::vector< T_ts > &ts, const Eigen::VectorXd &absolute_tolerance_forward, const Eigen::VectorXd &absolute_tolerance_backward, size_t num_args_vars, int solver_forward, const T_Args &... args)
const std::tuple< promote_scalar_t< partials_type_t< scalar_type_t< T_Args > >, T_Args >... > value_of_args_tuple_
Since the CVODES solver manages memory with malloc calls, these resources must be freed using a destr...
Extends std::false_type when instantiated with zero or more template parameters, all of which extend ...