13.7 Adjoint ODE solver
The adjoint ODE solver method differs mathematically from the forward ODE solvers in how the gradients of the ODE solution are obtained. The forward approach augments the original system of \(N\) states with \(N\) additional states for each parameter for which gradients are needed. If sensitivities are required for \(M\) parameters, the augmented ODE system has a total of \(N \cdot (M + 1)\) states. Because the system size is the product of the number of states and the number of parameters, this can produce very large ODE systems and a correspondingly large computational cost.
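To make the \(N \cdot (M + 1)\) count concrete, here is a minimal Python sketch of the forward sensitivity approach for \(N = 2\) states and \(M = 1\) parameter. It assumes a damped oscillator \(y_1' = y_2\), \(y_2' = -y_1 - \theta y_2\) (an illustrative choice) and integrates the augmented 4-state system, whose extra states obey \(s' = J(y)\,s + \partial f/\partial\theta\), with a fixed-step RK4 scheme. The function names are hypothetical, and CVODES uses adaptive multistep methods rather than RK4.

```python
def sho(y, theta):
    """Assumed damped oscillator: y1' = y2, y2' = -y1 - theta * y2."""
    return [y[1], -y[0] - theta * y[1]]


def augmented_rhs(z, theta):
    """Forward-sensitivity system for N = 2 states and M = 1 parameter.

    z = [y1, y2, s1, s2] with s = dy/dtheta, and
    ds/dt = J(y) s + df/dtheta, where J = [[0, 1], [-1, -theta]]
    and df/dtheta = [0, -y2].
    """
    y1, y2, s1, s2 = z
    dy = sho([y1, y2], theta)
    ds = [s2, -s1 - theta * s2 - y2]
    return dy + ds


def rk4(rhs, z0, t_end, n_steps, theta):
    """Classical fixed-step RK4 from t = 0 to t_end; returns the final state."""
    h = t_end / n_steps
    z = list(z0)
    for _ in range(n_steps):
        k1 = rhs(z, theta)
        k2 = rhs([a + 0.5 * h * b for a, b in zip(z, k1)], theta)
        k3 = rhs([a + 0.5 * h * b for a, b in zip(z, k2)], theta)
        k4 = rhs([a + h * b for a, b in zip(z, k3)], theta)
        z = [a + h / 6 * (p + 2 * q + 2 * r + s)
             for a, p, q, r, s in zip(z, k1, k2, k3, k4)]
    return z


def augmented_size(n_states, n_params):
    """Number of states the forward-sensitivity system has to integrate."""
    return n_states * (n_params + 1)


theta = 0.15
y0 = [1.0, 0.0]
# Sensitivities start at zero because y0 does not depend on theta.
z_end = rk4(augmented_rhs, y0 + [0.0, 0.0], 10.0, 2000, theta)
print(augmented_size(2, 1), z_end[2:])
```

The last two entries of z_end approximate \(\partial y(T)/\partial\theta\). With ten parameters, the same two-state model would already need 22 states.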
In contrast, the adjoint ODE solver integrates a system of \(N\) equations forward in time to compute the ODE solution. It then integrates another system of \(N\) equations backward in time to get the sensitivities. Additionally, for \(M\) parameters there are \(M\) additional equations to integrate during the backward solve. The total work therefore grows additively, with \(N\) forward plus \(N + M\) backward equations, rather than as \(N \cdot (M + 1)\). For this reason the adjoint sensitivity problem scales better in the number of parameters than the forward sensitivity problem. The adjoint solver in Stan uses CVODES, the same library that underlies the bdf and adams forward sensitivity interfaces.
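The adjoint idea can be sketched for the scalar quantity \(g = y_1(T)\) of an assumed damped oscillator \(y_1' = y_2\), \(y_2' = -y_1 - \theta y_2\). The backward system carries \(N = 2\) adjoint states \(\lambda\), with \(d\lambda/dt = -J(y)^\top \lambda\) and \(\lambda(T) = \partial g/\partial y(T)\). It also carries \(M = 1\) quadrature state accumulating \(\int_0^T \lambda^\top \partial f/\partial\theta \, dt\), which equals \(dg/d\theta\) here because the initial state does not depend on \(\theta\).

This minimal Python sketch uses fixed-step RK4 and is not Stan's or CVODES's implementation. In particular, it stores the whole forward trajectory instead of using checkpoints and interpolation. All function names are hypothetical.

```python
def f(y, theta):
    """Assumed damped oscillator: y1' = y2, y2' = -y1 - theta * y2."""
    return [y[1], -y[0] - theta * y[1]]


def rk4_combine(z, h, k1, k2, k3, k4):
    """Final RK4 update from the four stage derivatives."""
    return [a + h / 6 * (p + 2 * q + 2 * r + s)
            for a, p, q, r, s in zip(z, k1, k2, k3, k4)]


def forward_trajectory(y0, theta, t_end, n_steps):
    """Forward RK4 solve that stores the state at every grid point."""
    h = t_end / n_steps
    ys = [list(y0)]
    for _ in range(n_steps):
        y = ys[-1]
        k1 = f(y, theta)
        k2 = f([a + 0.5 * h * b for a, b in zip(y, k1)], theta)
        k3 = f([a + 0.5 * h * b for a, b in zip(y, k2)], theta)
        k4 = f([a + h * b for a, b in zip(y, k3)], theta)
        ys.append(rk4_combine(y, h, k1, k2, k3, k4))
    return ys


def backward_rhs(w, y, theta):
    """Adjoint plus quadrature system in reversed time tau = T - t.

    w = [lam1, lam2, q]: N = 2 adjoint states and M = 1 quadrature state.
    dlam/dtau = J(y)^T lam and dq/dtau = lam^T df/dtheta, with
    J = [[0, 1], [-1, -theta]] and df/dtheta = [0, -y2].
    """
    lam1, lam2, _ = w
    return [-lam2, lam1 - theta * lam2, -lam2 * y[1]]


def adjoint_gradient(y0, theta, t_end, n_steps):
    """d y1(T) / d theta from one forward and one backward RK4 sweep."""
    # Store the forward solution at half-step spacing so that every
    # backward RK4 stage can read y(t) directly (no checkpointing here).
    ys = forward_trajectory(y0, theta, t_end, 2 * n_steps)
    h = t_end / n_steps
    w = [1.0, 0.0, 0.0]  # lam(T) = dg/dy(T) for g = y1; quadrature starts at 0
    for k in range(n_steps, 0, -1):
        y_hi, y_mid, y_lo = ys[2 * k], ys[2 * k - 1], ys[2 * k - 2]
        k1 = backward_rhs(w, y_hi, theta)
        k2 = backward_rhs([a + 0.5 * h * b for a, b in zip(w, k1)], y_mid, theta)
        k3 = backward_rhs([a + 0.5 * h * b for a, b in zip(w, k2)], y_mid, theta)
        k4 = backward_rhs([a + h * b for a, b in zip(w, k3)], y_lo, theta)
        w = rk4_combine(w, h, k1, k2, k3, k4)
    return w[2]


print(adjoint_gradient([1.0, 0.0], 0.15, 10.0, 1000))
```

Adding a second parameter would add one more quadrature state to the backward system. The number of adjoint states would stay at \(N\).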
The solution computed in the forward integration is required during the backward integration. CVODES uses a checkpointing scheme that saves the forward solver state regularly. The number of steps between saving checkpoints is configurable in the interface. These checkpoints are then interpolated during the backward solve using one of two interpolation schemes.
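Checkpointing can be sketched in Python with a fixed-step RK4 solver and an assumed damped oscillator \(y_1' = y_2\), \(y_2' = -y_1 - \theta y_2\). The forward pass keeps only every K-th state. During the backward pass, each segment is re-integrated from its checkpoint to regenerate the dense states it needs. Values between stored steps are then obtained by cubic Hermite interpolation from the states and their derivatives, which is the idea behind the Hermite option.

This is a minimal illustration only. CVODES's actual scheme uses adaptive steps and also offers a variable-degree polynomial alternative, so it differs from what is shown. Every name here is hypothetical.

```python
def f(y, theta):
    """Assumed damped oscillator: y1' = y2, y2' = -y1 - theta * y2."""
    return [y[1], -y[0] - theta * y[1]]


def rk4_step(y, theta, h):
    """One classical RK4 step of size h."""
    k1 = f(y, theta)
    k2 = f([a + 0.5 * h * b for a, b in zip(y, k1)], theta)
    k3 = f([a + 0.5 * h * b for a, b in zip(y, k2)], theta)
    k4 = f([a + h * b for a, b in zip(y, k3)], theta)
    return [a + h / 6 * (p + 2 * q + 2 * r + s)
            for a, p, q, r, s in zip(y, k1, k2, k3, k4)]


def forward_with_checkpoints(y0, theta, h, n_steps, steps_between):
    """Forward pass that keeps only every `steps_between`-th state."""
    checkpoints = {0: list(y0)}
    y = list(y0)
    for i in range(1, n_steps + 1):
        y = rk4_step(y, theta, h)
        if i % steps_between == 0:
            checkpoints[i] = y
    return checkpoints, y


def recompute_segment(checkpoints, start, n, theta, h):
    """Re-integrate n steps from the checkpoint at step `start` (n + 1 states)."""
    ys = [checkpoints[start]]
    for _ in range(n):
        ys.append(rk4_step(ys[-1], theta, h))
    return ys


def hermite(t, t0, t1, y0, y1, d0, d1):
    """Cubic Hermite interpolation from values and derivatives at t0 and t1."""
    h = t1 - t0
    s = (t - t0) / h
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return [h00 * a + h10 * h * da + h01 * b + h11 * h * db
            for a, b, da, db in zip(y0, y1, d0, d1)]


theta, h, n_steps, k = 0.15, 0.01, 1000, 150
cps, y_end = forward_with_checkpoints([1.0, 0.0], theta, h, n_steps, k)
print(len(cps), "checkpoints instead of", n_steps + 1, "stored states")
```

Suppose there are 1000 steps and 150 steps between checkpoints, the value used in the example call later in this section. Then only 7 states are kept instead of 1001. The cost is re-integrating each segment once during the backward pass.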
The solver type (either bdf or adams) can be set individually for both the forward and backward solves.
The tolerances for each phase of the solve must be specified in the interface. Note that the absolute tolerances for the forward and backward ODE integration phases must be set for each ODE state separately. The harmonic oscillator example call from above becomes:
vector[2] y_sim[T]
    = ode_adjoint_tol_ctl(sho, y0, t0, ts,
                          relative_tolerance/9.0,                 // forward relative tolerance
                          rep_vector(absolute_tolerance/9.0, 2),  // forward absolute tolerances, one per state
                          relative_tolerance/3.0,                 // backward relative tolerance
                          rep_vector(absolute_tolerance/3.0, 2),  // backward absolute tolerances, one per state
                          relative_tolerance,                     // quadrature relative tolerance
                          absolute_tolerance,                     // quadrature absolute tolerance
                          max_num_steps,
                          150,  // number of steps between checkpoints
                          1,    // interpolation polynomial: 1=Hermite, 2=polynomial
                          2,    // solver for forward phase: 1=Adams, 2=BDF
                          2,    // solver for backward phase: 1=Adams, 2=BDF
                          theta);
For detailed information on each argument, please see the Stan Functions Reference.
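The per-state absolute tolerances follow from how SUNDIALS solvers such as CVODES judge a step. The local error estimate is measured in a weighted root-mean-square norm with weights \(w_i = 1/(\mathrm{rtol} \cdot |y_i| + \mathrm{atol}_i)\), and a step passes when that norm is at most 1. The Python sketch below mirrors that test. It assumes a local error estimate err is already available; CVODES computes this internally. The function names and numbers are hypothetical.

```python
import math


def error_weights(y, rtol, atol):
    """Weights 1 / (rtol * |y_i| + atol_i); atol holds one entry per state."""
    if len(atol) != len(y):
        raise ValueError("need one absolute tolerance per ODE state")
    return [1.0 / (rtol * abs(yi) + ai) for yi, ai in zip(y, atol)]


def wrms_norm(err, weights):
    """Weighted root-mean-square norm of a local error estimate."""
    return math.sqrt(sum((e * w) ** 2 for e, w in zip(err, weights)) / len(err))


def step_accepted(err, y, rtol, atol):
    """A step passes the error test when the weighted norm is at most 1."""
    return wrms_norm(err, error_weights(y, rtol, atol)) <= 1.0


y = [1.0, 0.0]        # second state currently at zero
err = [1e-7, 5e-8]    # hypothetical local error estimate for one step
print(step_accepted(err, y, 1e-6, [1e-6, 1e-6]))
print(step_accepted(err, y, 1e-6, [1e-6, 1e-9]))
```

The first call accepts the step and the second rejects it. A state near zero is controlled almost entirely by its own absolute tolerance, so tightening that single entry can reject a step that the other entries would allow.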