13.7 Adjoint ODE solver

The adjoint ODE solver differs mathematically from the forward ODE solvers in how gradients of the ODE solution are obtained. The forward sensitivity approach augments the original system of \(N\) states with \(N\) additional states for each parameter for which gradients are needed. If sensitivities are required for \(M\) parameters, the augmented ODE system therefore has a total of \(N \cdot (M + 1)\) states. Because this size grows with the product of \(N\) and \(M\), the augmented system can become very large, and so can the computational effort needed to solve it.
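The \(N \cdot M\) extra states come from the standard forward sensitivity equations. The notation below (\(y\), \(f\), \(\theta\), \(s_m\)) is introduced here only for illustration. Write the ODE as \(dy/dt = f(t, y, \theta)\) with \(y \in \mathbb{R}^N\) and \(\theta \in \mathbb{R}^M\). Each sensitivity vector \(s_m = \partial y / \partial \theta_m \in \mathbb{R}^N\) then obeys its own \(N\)-dimensional system:

\[
\frac{ds_m}{dt} = \frac{\partial f}{\partial y}\, s_m + \frac{\partial f}{\partial \theta_m}, \qquad s_m(t_0) = \frac{\partial y_0}{\partial \theta_m}, \qquad m = 1, \dots, M.
\]

Solving for \(y\) together with all \(M\) sensitivity vectors gives \(N + N M = N (M + 1)\) coupled equations.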

In contrast, the adjoint ODE solver first integrates a system of \(N\) equations forward in time to compute the ODE solution. It then integrates another system of \(N\) equations backward in time to obtain the sensitivities. For \(M\) parameters, the backward solve integrates \(M\) additional equations, so it involves \(N + M\) equations in total. Because this count grows additively rather than multiplicatively in \(M\), the adjoint method scales better in the number of parameters than the forward sensitivity method. The adjoint solver in Stan uses CVODES, the same library that underlies the BDF and Adams forward sensitivity interfaces.
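As a sketch in the same illustrative notation \(dy/dt = f(t, y, \theta)\), take a scalar function \(L\) of the final state \(y(t_{\mathrm{end}})\). The backward phase integrates the adjoint system from \(t_{\mathrm{end}}\) back to \(t_0\), with \(\lambda \in \mathbb{R}^N\) and \(\mu \in \mathbb{R}^M\):

\[
\frac{d\lambda}{dt} = -\left(\frac{\partial f}{\partial y}\right)^{\top} \lambda, \qquad \lambda(t_{\mathrm{end}}) = \frac{\partial L}{\partial y(t_{\mathrm{end}})},
\]
\[
\frac{d\mu}{dt} = -\left(\frac{\partial f}{\partial \theta}\right)^{\top} \lambda, \qquad \mu(t_{\mathrm{end}}) = 0,
\]

The gradients are then

\[
\frac{dL}{d\theta} = \mu(t_0) + \left(\frac{\partial y_0}{\partial \theta}\right)^{\top} \lambda(t_0), \qquad \frac{dL}{dy_0} = \lambda(t_0).
\]

The \(\lambda\) equations are the \(N\) backward equations, and the \(\mu\) equations are the \(M\) additional ones. The \(\mu\) equations do not depend on \(\mu\) itself, so they are pure integrals (quadratures) of \(\lambda\). This is consistent with the interface taking a separate quadrature tolerance. When outputs are requested at several times \(t_i\), \(\lambda\) is also incremented by \(\partial L / \partial y(t_i)\) as the backward solve passes each output time.

Counting equations only (per-step costs differ), consider \(N = 10\) states and \(M = 100\) parameters. The forward sensitivity method solves one system of \(10 \cdot 101 = 1010\) equations. The adjoint method solves \(10\) equations forward and \(10 + 100 = 110\) equations backward.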

The forward solution is needed during the backward integration. Rather than storing it at every step, CVODES uses a checkpointing scheme that saves the forward solver state at regular intervals. The number of steps between checkpoints is configurable in the interface. During the backward solve, the forward solution between checkpoints is recomputed from the saved states and interpolated using one of two schemes: Hermite or polynomial interpolation.
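The checkpointing idea can be sketched in a few lines of Python. This is a minimal illustration under simplifying assumptions, not how CVODES is implemented. It uses a fixed-step RK4 integrator on a scalar ODE instead of adaptive BDF or Adams steps. All names (CheckpointedSolution, state_at, steps_between_checkpoints) are hypothetical. The forward pass keeps only every steps_between_checkpoints-th state. A query during the backward pass re-integrates the checkpoint interval containing t and then applies cubic Hermite interpolation, loosely mirroring the Hermite option (1) of the interface.

```python
import math


def rk4_step(f, t, y, h):
    """One classical Runge-Kutta (RK4) step for a scalar ODE y' = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


class CheckpointedSolution:
    """Hypothetical sketch: forward solve that keeps only periodic checkpoints."""

    def __init__(self, f, t0, y0, t_end, n_steps, steps_between_checkpoints):
        self.f, self.t0, self.n_steps = f, t0, n_steps
        self.h = (t_end - t0) / n_steps
        self.k = steps_between_checkpoints
        self.checkpoints = {}  # step index -> saved state
        y = y0
        for i in range(n_steps + 1):
            if i % self.k == 0:
                self.checkpoints[i] = y
            if i < n_steps:
                y = rk4_step(f, t0 + i * self.h, y, self.h)
        self._segment_start = None  # checkpoint index of the cached segment
        self._segment = []          # re-integrated states for that segment

    def state_at(self, t):
        """Approximate y(t) from the checkpoints (for t0 <= t <= t_end)."""
        i = min(int((t - self.t0) / self.h), self.n_steps - 1)  # step interval [i, i+1]
        c = (i // self.k) * self.k  # last checkpoint at or before step i
        if self._segment_start != c:
            # Re-integrate forward from checkpoint c over one checkpoint interval
            # and cache it, since backward queries tend to cluster in time.
            ys = [self.checkpoints[c]]
            for j in range(c, min(c + self.k, self.n_steps)):
                ys.append(rk4_step(self.f, self.t0 + j * self.h, ys[-1], self.h))
            self._segment_start, self._segment = c, ys
        ya, yb = self._segment[i - c], self._segment[i - c + 1]
        ta = self.t0 + i * self.h
        da, db = self.f(ta, ya), self.f(ta + self.h, yb)
        # Cubic Hermite interpolation using states and derivatives at both ends.
        s = (t - ta) / self.h
        h00 = 2 * s**3 - 3 * s**2 + 1
        h10 = s**3 - 2 * s**2 + s
        h01 = -2 * s**3 + 3 * s**2
        h11 = s**3 - s**2
        return h00 * ya + h10 * self.h * da + h01 * yb + h11 * self.h * db


if __name__ == "__main__":
    # y' = -y, y(0) = 1, exact solution exp(-t); 200 steps, checkpoint every 25.
    sol = CheckpointedSolution(lambda t, y: -y, 0.0, 1.0, 2.0, 200, 25)
    # Query in reverse time order, as a backward (adjoint) pass would.
    for t in (2.0, 1.5, 0.75, 0.0):
        print(t, sol.state_at(t), math.exp(-t))
```

A larger steps_between_checkpoints stores fewer states and so saves memory. The cost is more re-integration work during the backward pass, which is the trade-off the checkpoint argument of the interface controls.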

The solver type (either BDF or Adams) can be set independently for the forward and backward solves.

The tolerances for each phase of the solve (forward, backward, and quadrature) must be specified in the interface. The absolute tolerances for the forward and backward integration phases must be given separately for each ODE state, as a vector. Using the adjoint solver, the harmonic oscillator example call from above becomes:

```stan
array[T] vector[2] y_sim
    = ode_adjoint_tol_ctl(sho, y0, t0, ts,
                          relative_tolerance/9.0,                // forward tolerance
                          rep_vector(absolute_tolerance/9.0, 2), // forward tolerance
                          relative_tolerance/3.0,                // backward tolerance
                          rep_vector(absolute_tolerance/3.0, 2), // backward tolerance
                          relative_tolerance,                    // quadrature tolerance
                          absolute_tolerance,                    // quadrature tolerance
                          max_num_steps,                         // maximum number of solver steps
                          150,                                   // number of steps between checkpoints
                          1,                                     // interpolation polynomial: 1=Hermite, 2=polynomial
                          2,                                     // solver for forward phase: 1=Adams, 2=BDF
                          2,                                     // solver for backward phase: 1=Adams, 2=BDF
                          theta);                                // ODE parameters passed on to sho
```

For detailed information on each argument, see the Stan Functions Reference.