Higher-Order Functions
Stan provides a few higher-order functions that act on other functions. In all cases, the function arguments to the higher-order functions are defined as functions within the Stan language and passed by name to the higher-order functions.
Algebraic equation solvers
Stan provides two built-in algebraic equation solvers, respectively based on the Newton method and the Powell “dog leg” hybrid method. Empirically the Newton method is found to be faster and its use is recommended for most problems.
An algebraic solver is a higher-order function, i.e. it takes another function as one of its arguments. Other functions in Stan which share this feature are the differential equation solvers (see section Ordinary Differential Equation (ODE) Solvers and Differential Algebraic Equation (DAE) solver). Ordinary Stan functions do not allow functions as arguments.
Specifying an algebraic equation as a function
An algebraic system is specified as an ordinary function in Stan within the functions block. The function must return a vector and take, as its first argument, the unknowns \(y\) we wish to solve for, also passed as a vector. This argument is followed by additional arguments specified by the user; we call these variadic arguments and denote them by an ellipsis (...). The signature of the algebraic system is then:
vector algebra_system (vector y, ...)
There is no type restriction for the variadic arguments and each argument can be passed as data or parameter. However, users should use parameter arguments only when necessary and mark data arguments with the keyword data. In the example below, the last variadic argument, \(x\), is restricted to being data:
vector algebra_system (vector y, vector theta, data vector x)
Distinguishing data and parameter is important for computational reasons. Augmenting the total number of parameters increases the cost of propagating derivatives through the solution to the algebraic equation, and ultimately the computational cost of evaluating the gradients.
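As a minimal sketch of such a system function, the block below defines a hypothetical two-equation system (the equations themselves are an illustration, not taken from this reference). The solver would look for y such that both residuals are zero; theta may hold parameters, while x is marked data so that no derivatives are propagated with respect to it.

```stan
functions {
  // Hypothetical system, solved for y:
  //   y[1] - theta[1] * x[1] = 0
  //   y[1] * y[2] - theta[2] = 0
  vector algebra_system(vector y, vector theta, data vector x) {
    vector[2] residual;
    residual[1] = y[1] - theta[1] * x[1];
    residual[2] = y[1] * y[2] - theta[2];
    return residual;  // same size as y
  }
}
```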
Call to the algebraic solver
vector solve_newton(function algebra_system, vector y_guess, ...)
Solves the algebraic system, given an initial guess, using Newton’s method.
vector solve_newton_tol(function algebra_system, vector y_guess, data real scaling_step, data real f_tol, int max_steps, ...)
Solves the algebraic system, given an initial guess, using Newton’s method with additional control parameters for the solver.
vector solve_powell(function algebra_system, vector y_guess, ...)
Solves the algebraic system, given an initial guess, using Powell’s hybrid method.
vector solve_powell_tol(function algebra_system, vector y_guess, data real rel_tol, data real f_tol, int max_steps, ...)
Solves the algebraic system, given an initial guess, using Powell’s hybrid method with additional control parameters for the solver.
Arguments to the algebraic solver
The arguments to the algebraic solvers are as follows:
algebra_system: function literal referring to a function specifying the system of algebraic equations with signature (vector, ...):vector. The arguments represent (1) unknowns, (2) additional parameter and/or data arguments, and the return value contains the value of the algebraic function, which goes to 0 when we plug in the solution to the algebraic system,
y_guess: initial guess for the solution, type vector,
...: variadic arguments.
The algebraic solvers admit control parameters. While Stan provides default values, the user should be prepared to adjust the control parameters. The following controls are available:
scaling_step: for the Newton solver only, the scaled-step stopping tolerance, type real, data only. If a Newton step is smaller than the scaling step tolerance, the solver stops, assuming it is no longer making significant progress. If set to 0, this constraint is ignored. Default value is \(10^{-3}\).
rel_tol: for the Powell solver only, the relative tolerance, type real, data only. The relative tolerance is the estimated relative error of the solver and serves to test if a satisfactory solution has been found. Default value is \(10^{-10}\).
function_tol (written f_tol in the signatures above): function tolerance for the algebraic solver, type real, data only. After convergence of the solver, the proposed solution is plugged into the algebraic system and its norm is compared to the function tolerance. If the norm is below the function tolerance, the solution is deemed acceptable. Default value is \(10^{-6}\).
max_num_steps (written max_steps in the signatures above): maximum number of steps to take in the algebraic solver, type int, data only. If the solver reaches this number of steps, it stops and returns an error message. Default value is \(200\).
The two solvers expose different control parameters because their underlying implementations support different controls. The Newton solver is based on KINSOL from the SUNDIALS suite, while the Powell solver uses a module from the Eigen library.
Return value
The return value for the algebraic solver is an object of type vector, with values which, when plugged in as y, make the algebraic function go to 0 (approximately, within the specified function tolerance).
Sizes and parallel arrays
Certain sizes have to be consistent. The initial guess, return value of the solver, and return value of the algebraic function must all be the same size.
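The following sketch shows how a call to the solver might look inside a complete program. The system, the prior, and the lower bound on theta are assumptions made for illustration; the control values passed to solve_newton_tol are simply the defaults listed above. Note that y_guess, the returned vector, and the value returned by the system function all have size 2.

```stan
functions {
  // Hypothetical system: y[1] = theta[1] and y[1] * y[2] = theta[2].
  vector algebra_system(vector y, vector theta) {
    vector[2] residual;
    residual[1] = y[1] - theta[1];
    residual[2] = y[1] * y[2] - theta[2];
    return residual;
  }
}
transformed data {
  vector[2] y_guess = [1, 1]';  // initial guess, same size as the solution
}
parameters {
  vector<lower=0.1>[2] theta;   // bound keeps theta[1] away from 0 (assumption)
}
transformed parameters {
  // Default control parameters.
  vector[2] y = solve_newton(algebra_system, y_guess, theta);
  // Explicit controls: scaling_step, f_tol, max_steps, then the variadic arguments.
  vector[2] y_tol
      = solve_newton_tol(algebra_system, y_guess, 1e-3, 1e-6, 200, theta);
}
model {
  theta ~ lognormal(0, 1);
}
```

Replacing solve_newton with solve_powell leaves the call unchanged; solve_powell_tol takes rel_tol in place of scaling_step.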
Algorithmic details
Stan offers two methods to solve algebraic equations. solve_newton and solve_newton_tol use the Newton method, a first-order derivative based numerical solver. The Stan code builds on the implementation in KINSOL from the SUNDIALS suite (Hindmarsh et al. 2005). For many problems, we find that the Newton method is faster than the Powell method. If, however, Newton’s method performs poorly, either failing to converge or taking an excessively long time to do so, the user should be prepared to switch to the Powell method.
solve_powell and solve_powell_tol are based on the Powell hybrid method (Powell 1970), which also uses first-order derivatives. The Stan code builds on the implementation of the hybrid solver in the unsupported module for nonlinear optimization problems of the Eigen library (Guennebaud, Jacob, et al. 2010). This solver is in turn based on the algorithm developed for the package MINPACK-1 (Jorge J. More 1980).
For both solvers, derivatives are propagated through the solution to the algebraic equation using the implicit function theorem and an adjoint method of automatic differentiation; for a discussion on this topic, see (Gaebler 2021) and (Margossian and Betancourt 2022).
Ordinary differential equation (ODE) solvers
Stan provides several higher-order functions for solving initial value problems specified as Ordinary Differential Equations (ODEs).
Solving an initial value ODE means that, given a set of differential equations \(y'(t, \theta) = f(t, y, \theta)\) and initial conditions \(y(t_0, \theta)\), we solve for \(y\) at a sequence of times \(t_0 < t_1 < t_2 < \cdots < t_n\). The function \(f(t, y, \theta)\) is referred to here as the ODE system function.
\(f(t, y, \theta)\) will be defined as a function with a certain signature and provided along with the initial conditions and output times to one of the ODE solver functions.
To make it easier to write ODEs, the solve functions take extra arguments that are passed along unmodified to the user-supplied system function. Because there can be any number of these arguments and they can be of different types, they are denoted below by an ellipsis (...). The types of the arguments represented by ... in the ODE solve function call must match the types of the arguments represented by ... in the user-supplied system function.
Non-stiff solver
array[] vector ode_rk45(function ode, vector initial_state, real initial_time, array[] real times, ...)
Solves the ODE system for the times provided using the Dormand-Prince algorithm, a 4th/5th order Runge-Kutta method.
array[] vector ode_rk45_tol(function ode, vector initial_state, real initial_time, array[] real times, data real rel_tol, data real abs_tol, int max_num_steps, ...)
Solves the ODE system for the times provided using the Dormand-Prince algorithm, a 4th/5th order Runge-Kutta method with additional control parameters for the solver.
array[] vector ode_ckrk(function ode, vector initial_state, real initial_time, array[] real times, ...)
Solves the ODE system for the times provided using the Cash-Karp algorithm, a 4th/5th order explicit Runge-Kutta method.
array[] vector ode_ckrk_tol(function ode, vector initial_state, real initial_time, array[] real times, data real rel_tol, data real abs_tol, int max_num_steps, ...)
Solves the ODE system for the times provided using the Cash-Karp algorithm, a 4th/5th order explicit Runge-Kutta method with additional control parameters for the solver.
array[] vector ode_adams(function ode, vector initial_state, real initial_time, array[] real times, ...)
Solves the ODE system for the times provided using the Adams-Moulton method.
array[] vector ode_adams_tol(function ode, vector initial_state, real initial_time, array[] real times, data real rel_tol, data real abs_tol, int max_num_steps, ...)
Solves the ODE system for the times provided using the Adams-Moulton method with additional control parameters for the solver.
Stiff solver
array[] vector ode_bdf(function ode, vector initial_state, real initial_time, array[] real times, ...)
Solves the ODE system for the times provided using the backward differentiation formula (BDF) method.
array[] vector ode_bdf_tol(function ode, vector initial_state, real initial_time, array[] real times, data real rel_tol, data real abs_tol, int max_num_steps, ...)
Solves the ODE system for the times provided using the backward differentiation formula (BDF) method with additional control parameters for the solver.
Adjoint solver
array[] vector ode_adjoint_tol_ctl(function ode, vector initial_state, real initial_time, array[] real times, data real rel_tol_forward, data vector abs_tol_forward, data real rel_tol_backward, data vector abs_tol_backward, data real rel_tol_quadrature, data real abs_tol_quadrature, int max_num_steps, int num_steps_between_checkpoints, int interpolation_polynomial, int solver_forward, int solver_backward, ...)
Solves the ODE system for the times provided using the adjoint ODE solver method from CVODES. The adjoint ODE solver requires a checkpointed forward-in-time ODE integration, a backward-in-time integration that makes use of an interpolated version of the forward solution, and the solution of quadrature problems (the number of which depends on the number of parameters passed to the solve). The tolerances and numeric methods used for the forward solve, backward solve, quadratures, and interpolation can all be configured.
Available since 2.27
ODE system function
The first argument to one of the ODE solvers is always the ODE system function. The ODE system function must have a vector return type, and the first two arguments must be a real and vector in that order. These two arguments are followed by the variadic arguments that are passed through from the ODE solve function call:
vector ode(real time, vector state, ...)
The ODE system function should return the derivative of the state with respect to time at the time and state provided. The length of the returned vector must match the length of the state input into the function.
The arguments to this function are:
time, the time to evaluate the ODE system
state, the state of the ODE system at the time specified
..., sequence of arguments passed unmodified from the ODE solve function call. The types here must match the types in the ... arguments of the ODE solve function call.
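A minimal sketch of an ODE system function with this signature is given below. The damped harmonic oscillator and the name sho are illustrative choices, with the friction coefficient theta standing in for the variadic arguments.

```stan
functions {
  // Damped harmonic oscillator; y = (position, velocity).
  // theta is a single variadic argument (hypothetical friction coefficient).
  vector sho(real time, vector y, real theta) {
    vector[2] dydt;
    dydt[1] = y[2];
    dydt[2] = -y[1] - theta * y[2];
    return dydt;  // same length as the state y
  }
}
```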
Arguments to the ODE solvers
The arguments to the ODE solvers in both the stiff and non-stiff solvers are the same. The arguments to the adjoint ODE solver are different; see Arguments to the adjoint ODE solver.
ode: ODE system function,
initial_state: initial state, type vector,
initial_time: initial time, type real,
times: solution times, type array[] real,
...: sequence of arguments that will be passed through unmodified to the ODE system function. The types here must match the types in the ... arguments of the ODE system function.
For the versions of the ODE solver functions ending in _tol, these three parameters must be provided after times and before the ... arguments:
data rel_tol: relative tolerance for the ODE solver, type real, data only,
data abs_tol: absolute tolerance for the ODE solver, type real, data only, and
max_num_steps: maximum number of steps to take between output times in the ODE solver, type int, data only.
Because the tolerances are data arguments, they must be defined in either the data or transformed data blocks. They cannot be parameters, transformed parameters or functions of parameters or transformed parameters.
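The sketch below illustrates one way a _tol solver might be called, with the control parameters defined in transformed data as required. The exponential decay system, the variable names, the priors, and the control values are all assumptions chosen for illustration.

```stan
functions {
  // Hypothetical system: exponential decay of every state at rate k.
  vector decay(real time, vector y, real k) {
    return -k * y;
  }
}
data {
  int<lower=1> N;             // number of states
  int<lower=1> T;             // number of solution times
  vector[N] y0;               // initial state
  real t0;                    // initial time
  array[T] real ts;           // solution times, increasing and greater than t0
  array[T] vector[N] y_obs;   // noisy observations of the state
}
transformed data {
  real rel_tol = 1e-6;        // illustrative control values
  real abs_tol = 1e-6;
  int max_num_steps = 100000;
}
parameters {
  real<lower=0> k;
  real<lower=0> sigma;
}
transformed parameters {
  // Controls follow ts and precede the variadic argument k.
  array[T] vector[N] mu
      = ode_rk45_tol(decay, y0, t0, ts, rel_tol, abs_tol, max_num_steps, k);
}
model {
  k ~ lognormal(0, 1);
  sigma ~ normal(0, 1);
  for (i in 1:T) {
    y_obs[i] ~ normal(mu[i], sigma);
  }
}
```

Dropping the three control arguments and calling ode_rk45 uses the solver defaults; for a stiff system the same argument list can be passed to ode_bdf_tol.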
Arguments to the adjoint ODE solver
The arguments to the adjoint ODE solver are different from those for the other functions (for those see Arguments to the ODE solvers).
ode: ODE system function,
initial_state: initial state, type vector,
initial_time: initial time, type real,
times: solution times, type array[] real,
data rel_tol_forward: relative tolerance for forward solve, type real, data only,
data abs_tol_forward: absolute tolerance vector for each state for forward solve, type vector, data only,
data rel_tol_backward: relative tolerance for backward solve, type real, data only,
data abs_tol_backward: absolute tolerance vector for each state for backward solve, type vector, data only,
data rel_tol_quadrature: relative tolerance for backward quadrature, type real, data only,
data abs_tol_quadrature: absolute tolerance for backward quadrature, type real, data only,
data max_num_steps: maximum number of time-steps to take in integrating the ODE solution between output time points for forward and backward solve, type int, data only,
num_steps_between_checkpoints: number of steps between checkpointing forward solution, type int, data only,
interpolation_polynomial: can be 1 for Hermite or 2 for polynomial interpolation method of CVODES, type int, data only,
solver_forward: solver used for forward ODE problem: 1=Adams (non-stiff), 2=BDF (stiff), type int, data only,
solver_backward: solver used for backward ODE problem: 1=Adams (non-stiff), 2=BDF (stiff), type int, data only.
...: sequence of arguments that will be passed through unmodified to the ODE system function. The types here must match the types in the ... arguments of the ODE system function.
Because the tolerances are data arguments, they must be defined in either the data or transformed data blocks. They cannot be parameters, transformed parameters or functions of parameters or transformed parameters.
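As a hedged sketch, a call to the adjoint solver might look as follows. It assumes the control arguments are passed in the order of the argument list above, including the two quadrature tolerances. The linear system, the prior, and all control values are illustrative assumptions rather than recommendations.

```stan
functions {
  // Hypothetical linear system dy/dt = A * y; A is a variadic argument.
  vector linear_ode(real time, vector y, matrix A) {
    return A * y;
  }
}
data {
  int<lower=1> N;
  int<lower=1> T;
  vector[N] y0;
  real t0;
  array[T] real ts;
}
transformed data {
  real rel_tol = 1e-8;
  vector[N] abs_tol = rep_vector(1e-8, N);  // one absolute tolerance per state
}
parameters {
  matrix[N, N] A;
}
transformed parameters {
  array[T] vector[N] y
      = ode_adjoint_tol_ctl(linear_ode, y0, t0, ts,
                            rel_tol, abs_tol,  // forward solve tolerances
                            rel_tol, abs_tol,  // backward solve tolerances
                            1e-8, 1e-8,        // quadrature rel. and abs. tolerance
                            10000,             // max_num_steps
                            150,               // num_steps_between_checkpoints
                            1,                 // interpolation: 1 = Hermite
                            2,                 // forward solver: 2 = BDF
                            2,                 // backward solver: 2 = BDF
                            A);
}
model {
  to_vector(A) ~ normal(0, 1);
}
```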
Return values
The return value for the ODE solvers is an array of vectors (type array[] vector), one vector representing the state of the system at every time specified in the times argument.
Array and vector sizes
The sizes must match, and in particular, the following groups are of the same size:
state variables passed into the system function, derivatives returned by the system function, initial state passed into the solver, and length of each vector in the output,
number of solution times and number of vectors in the output.
Differential-Algebraic equation (DAE) solver
Stan provides two higher-order functions for solving initial value problems specified as index-1 Differential-Algebraic Equations (DAEs) (Serban et al. 2021).
Solving an initial value DAE means that, given a set of residual functions \(r(y'(t, \theta), y(t, \theta), t)\) and initial conditions \((y(t_0, \theta), y'(t_0, \theta))\), we solve for \(y\) at a sequence of times \(t_0 < t_1 \leq t_2 \leq \cdots \leq t_n\). The residual function \(r(y', y, t, \theta)\) will be defined as a function with a certain signature and provided along with the initial conditions and output times to one of the DAE solver functions.
Similar to the ODE solvers, the DAE solver function takes extra arguments that are passed along unmodified to the user-supplied system function. Because there can be any number of these arguments and they can be of different types, they are denoted below by an ellipsis (...), and the types of these arguments, also represented by ... in the DAE solver call, must match the types of the arguments represented by ... in the user-supplied system function.
The DAE solver
array[] vector dae(function residual, vector initial_state, vector initial_state_derivative, data real initial_time, data array[] real times, ...)
Solves the DAE system using the backward differentiation formula (BDF) method (Serban et al. 2021).
array[] vector dae_tol(function residual, vector initial_state, vector initial_state_derivative, data real initial_time, data array[] real times, data real rel_tol, data real abs_tol, int max_num_steps, ...)
Solves the DAE system for the times provided using the backward differentiation formula (BDF) method with additional control parameters for the solver.
DAE system function
The first argument to the DAE solver is the DAE residual function. The DAE residual function must have a vector return type, and the first three arguments must be a real, vector, and vector, in that order. These three arguments are followed by the variadic arguments that are passed through from the DAE solver function call:
vector residual(real time, vector state, vector state_derivative, ...)
The DAE residual function should return the residuals at the time and state provided. The length of the returned vector must match the length of the state input into the function.
The arguments to this function are:
time, the time to evaluate the DAE system
state, the state of the DAE system at the time specified
state_derivative, the time derivatives of the state of the DAE system at the time specified
..., sequence of arguments passed unmodified from the DAE solve function call. The types here must match the types in the ... arguments of the DAE solve function call.
Arguments to the DAE solver
The arguments to the DAE solver are
residual: DAE residual function,
initial_state: initial state, type vector,
initial_state_derivative: time derivative of the initial state, type vector,
initial_time: initial time, type data real,
times: solution times, type data array[] real,
...: sequence of arguments that will be passed through unmodified to the DAE residual function. The types here must match the types in the ... arguments of the DAE residual function.
For dae_tol, the following three parameters must be provided after times and before the ... arguments:
data rel_tol: relative tolerance for the DAE solver, type real, data only,
data abs_tol: absolute tolerance for the DAE solver, type real, data only, and
max_num_steps: maximum number of steps to take between output times in the DAE solver, type int, data only.
Because the tolerances are data arguments, they must be supplied as primitive numerics or defined in either the data or transformed data blocks. They cannot be parameters, transformed parameters or functions of parameters or transformed parameters.
Consistency of the initial conditions
The user is responsible for ensuring that the residual function evaluates to zero at the initial time, t0, when the arguments initial_state and initial_state_derivative are passed in as state and state_derivative, respectively.
Return values
The return value for the DAE solvers is an array of vectors (type array[] vector), one vector representing the state of the system at every time specified in the times argument.
Array and vector sizes
The sizes must match, and in particular, the following groups are of the same size:
state variables and state derivatives passed into the residual function, the residual returned by the residual function, initial state and initial state derivatives passed into the solver, and length of each vector in the output,
number of solution times and number of vectors in the output.
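To make the pieces above concrete, here is a sketch of a small DAE with two differential equations and one algebraic constraint. The chemical-kinetics system, the rate names k1, k2, k3, and the priors are assumptions for illustration. The initial derivative yp0 is chosen so that the residual is zero at the initial state, which is the consistency requirement described above.

```stan
functions {
  // Hypothetical kinetics system: two rate equations plus the
  // conservation constraint y[1] + y[2] + y[3] = 1.
  vector residual(real time, vector y, vector yp,
                  real k1, real k2, real k3) {
    vector[3] res;
    res[1] = yp[1] + k1 * y[1] - k2 * y[2] * y[3];
    res[2] = yp[2] - k1 * y[1] + k2 * y[2] * y[3] + k3 * y[2] * y[2];
    res[3] = y[1] + y[2] + y[3] - 1.0;
    return res;  // same length as the state y
  }
}
data {
  int<lower=1> T;
  real t0;
  array[T] real ts;
}
transformed data {
  vector[3] y0 = [1, 0, 0]';  // initial state
}
parameters {
  real<lower=0> k1;
  real<lower=0> k2;
  real<lower=0> k3;
}
transformed parameters {
  // With y0 = (1, 0, 0), residual(t0, y0, yp0, ...) = 0 requires
  // yp0 = (-k1, k1, 0); the third derivative is not constrained by the residual.
  vector[3] yp0 = [-k1, k1, 0]';
  array[T] vector[3] y = dae(residual, y0, yp0, t0, ts, k1, k2, k3);
}
model {
  k1 ~ lognormal(0, 1);
  k2 ~ lognormal(0, 1);
  k3 ~ lognormal(0, 1);
}
```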
1D integrator
Stan provides a built-in mechanism to perform 1D integration of a function via quadrature methods.
It operates similarly to the algebraic solver and the ordinary differential equation solvers in that it takes a function as an argument.
Like both of those utilities, some of the arguments are limited to data only expressions. These expressions must not contain variables other than those declared in the data or transformed data blocks.
Specifying an integrand as a function
Performing a 1D integration requires the integrand to be specified as a function. This is done by defining a function in the Stan functions block with the special signature:
real integrand(real x, real xc, array[] real theta,
array[] real x_r, array[] int x_i)
The function should return the value of the integrand evaluated at the point x.
The arguments to this function are:
x, the independent variable being integrated over
xc, a high precision version of the distance from x to the nearest endpoint in a definite integral (for more information, see the section on precision loss near limits of integration)
theta, parameter values used to evaluate the integral
x_r, data values used to evaluate the integral
x_i, integer data used to evaluate the integral
Like the algebraic solver and the differential equation solvers, the 1D integrator separates parameter values, theta, from data values, x_r.
Call to the 1D integrator
real integrate_1d(function integrand, real a, real b, array[] real theta, array[] real x_r, array[] int x_i)
Integrates the integrand from a to b.
real integrate_1d(function integrand, real a, real b, array[] real theta, array[] real x_r, array[] int x_i, real relative_tolerance)
Integrates the integrand from a to b with the given relative tolerance.
Arguments to the 1D integrator
The arguments to the 1D integrator are as follows:
integrand: function literal referring to a function specifying the integrand with signature (real, real, array[] real, array[] real, array[] int):real. The arguments represent (1) where the integrand is evaluated, (2) distance from evaluation point to integration limit for definite integrals, (3) parameters, (4) real data, (5) integer data, and the return value is the integrand evaluated at the given point,
a: left limit of integration, may be negative infinity, type real,
b: right limit of integration, may be positive infinity, type real,
theta: parameters only, type array[] real,
x_r: real data only, type array[] real,
x_i: integer data only, type array[] int.
A relative_tolerance argument can optionally be provided for more control over the algorithm:
relative_tolerance: relative tolerance for the 1D integrator, type real, data only.
Return value
The return value for the 1D integrator is a real, the value of the integral.
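The sketch below shows one possible use of the integrator: computing the normalizing constant of an unnormalized normal kernel over the whole real line, which analytically equals sigma * sqrt(2 * pi). The function name normal_kernel, the priors, and the tolerance value are assumptions for illustration; the empty arrays stand in for the unused data arguments.

```stan
functions {
  // Unnormalized normal density; theta = {mu, sigma}. xc, x_r, x_i are unused.
  real normal_kernel(real x, real xc, array[] real theta,
                     array[] real x_r, array[] int x_i) {
    real mu = theta[1];
    real sigma = theta[2];
    return exp(-0.5 * square((x - mu) / sigma));
  }
}
transformed data {
  array[0] real x_r;  // no real data needed
  array[0] int x_i;   // no integer data needed
}
parameters {
  real mu;
  real<lower=0> sigma;
}
transformed parameters {
  // Infinite limits are allowed; the last argument is the relative tolerance.
  real Z = integrate_1d(normal_kernel,
                        negative_infinity(), positive_infinity(),
                        {mu, sigma}, x_r, x_i, 1e-8);
}
model {
  mu ~ std_normal();
  sigma ~ lognormal(0, 1);
}
```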
Zero-crossing integrals
For numeric stability, integrals on the (possibly infinite) interval \((a, b)\) that cross zero are split into two integrals, one over \((a, 0)\) and one over \((0, b)\). Each integral is separately integrated to the given relative_tolerance.
Precision loss near limits of integration in definite integrals
When integrating certain definite integrals, there can be significant precision loss in evaluating the integrand near the endpoints. This has to do with the breakdown in precision of double precision floating point values when adding or subtracting a small number from a number much larger than it in magnitude (for instance, 1.0 - x). xc (as passed to the integrand) is a high-precision version of the distance between x and the definite integral endpoints and can be used to address this issue. More information (and an example where this is useful) is given in the User’s Guide. For zero-crossing integrals, xc will be a high-precision version of the distance to the endpoints of the two smaller integrals. For any integral with an endpoint at negative infinity or positive infinity, xc is set to NaN.
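As a sketch of how xc can be used, consider integrating a beta-type kernel over (0, 1). Near the upper limit, 1.0 - x loses precision, so the integrand below switches to xc there; for x greater than 0.5 the nearest endpoint is 1, so xc stands in for 1 - x. The function name beta_kernel and the priors are assumptions for illustration.

```stan
functions {
  // Kernel x^(a - 1) * (1 - x)^(b - 1) on (0, 1); theta = {a, b}.
  real beta_kernel(real x, real xc, array[] real theta,
                   array[] real x_r, array[] int x_i) {
    real a = theta[1];
    real b = theta[2];
    real v;
    if (x > 0.5) {
      // Close to the upper limit: use the high-precision distance xc.
      v = x^(a - 1.0) * xc^(b - 1.0);
    } else {
      v = x^(a - 1.0) * (1.0 - x)^(b - 1.0);
    }
    return v;
  }
}
transformed data {
  array[0] real x_r;
  array[0] int x_i;
}
parameters {
  real<lower=0> a;
  real<lower=0> b;
}
transformed parameters {
  real B = integrate_1d(beta_kernel, 0.0, 1.0, {a, b}, x_r, x_i);
}
model {
  a ~ lognormal(0, 1);
  b ~ lognormal(0, 1);
}
```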
Algorithmic details
Internally the 1D integrator uses the double-exponential methods in the Boost 1D quadrature library. Boost in turn makes use of quadrature methods developed in (Takahasi and Mori 1974), (Mori 1978), (Bailey, Jeyabalan, and Li 2005), and (Tanaka et al. 2009).
The gradients of the integral are computed in accordance with the Leibniz integral rule. Gradients of the integrand are computed internally with Stan’s automatic differentiation.
Reduce-sum function
Stan provides a higher-order reduce function for summation. A function which returns a scalar, g: U -> real, is mapped to every element of a list of type array[] U, { x1, x2, ... }, and all the results are accumulated,
g(x1) + g(x2) + ...
For efficiency reasons the reduce function doesn’t work with the element-wise evaluated function g itself, but instead works through evaluating partial sums, f: array[] U -> real, where:
f({ x1 }) = g(x1)
f({ x1, x2 }) = g(x1) + g(x2)
f({ x1, x2, ... }) = g(x1) + g(x2) + ...
Mathematically the summation reduction is associative and forming arbitrary partial sums in an arbitrary order will not change the result. However, floating point numerics on computers only have a limited precision such that associativity does not hold exactly. This implies that the order of summation determines the exact numerical result. For this reason, the higher-order reduce function is available in two variants:
reduce_sum: Automatically choose partial sums partitioning based on a dynamic scheduling algorithm.
reduce_sum_static: Compute the same sum as reduce_sum, but partition the input in the same way for a given data set (in reduce_sum this partitioning might change depending on computer load). This should result in stable numerical evaluations.
Specifying the reduce-sum function
The higher-order reduce function takes a partial sum function f, an array argument x (with one array element for each term in the sum), a recommended grainsize, and a set of shared arguments. This representation allows parallelization of the resultant sum.
real reduce_sum(F f, array[] T x, int grainsize, T1 s1, T2 s2, ...)
real reduce_sum_static(F f, array[] T x, int grainsize, T1 s1, T2 s2, ...)
Returns the equivalent of f(x, 1, size(x), s1, s2, ...), but computes the result in parallel by breaking the array x into independent partial sums. s1, s2, ... are shared between all terms in the sum.
f: function literal referring to a function specifying the partial sum operation. Refer to the partial sum function.
x: array of T, one for each term of the reduction, T can be any type,
grainsize: For reduce_sum, grainsize is the recommended size of the partial sum (grainsize = 1 means pick totally automatically). For reduce_sum_static, grainsize determines the maximum size of the partial sums, type int,
s1: first (optional) shared argument, type T1, where T1 can be any type,
s2: second (optional) shared argument, type T2, where T2 can be any type,
...: remainder of shared arguments, each of which can be any type.
The partial sum function
The partial sum function must have the following signature, where the type T and the types of all the shared arguments (T1, T2, …) match those of the original reduce_sum (reduce_sum_static) call.
(array[] T x_subset, int start, int end, T1 s1, T2 s2, ...):real
The partial sum function returns the sum of the start to end terms (inclusive) of the overall calculations. The arguments to the partial sum function are:
x_subset, the subset of x a given partial sum is responsible for computing, type array[] T, where T matches the type of x in reduce_sum (reduce_sum_static)
start, the index of the first term of the partial sum, type int
end, the index of the last term of the partial sum (inclusive), type int
s1, first shared argument, type T1, matching type of s1 in reduce_sum (reduce_sum_static)
s2, second shared argument, type T2, matching type of s2 in reduce_sum (reduce_sum_static)
..., remainder of shared arguments, with types matching those in reduce_sum (reduce_sum_static)
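As an illustrative sketch, the program below uses reduce_sum to accumulate the log likelihood of a simple logistic regression. The model, the names partial_sum and coef, and the choice grainsize = 1 are assumptions. The sliced argument is the outcome array y; the predictor x and the coefficients are shared arguments, so the partial sum function indexes x with start:end to pick out the terms matching y_slice.

```stan
functions {
  // Log likelihood contribution of terms start..end (inclusive).
  real partial_sum(array[] int y_slice, int start, int end,
                   vector x, vector coef) {
    return bernoulli_logit_lpmf(y_slice | coef[1] + coef[2] * x[start:end]);
  }
}
data {
  int<lower=0> N;
  array[N] int<lower=0, upper=1> y;
  vector[N] x;
}
parameters {
  vector[2] coef;
}
model {
  int grainsize = 1;  // let the scheduler choose the partial sum sizes
  coef ~ std_normal();
  target += reduce_sum(partial_sum, y, grainsize, x, coef);
}
```

Replacing reduce_sum with reduce_sum_static keeps the same arguments but fixes the partitioning, in which case grainsize would typically be set larger than 1.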
Map-rect function
Stan provides a higher-order map function. This allows map-reduce functionality to be coded in Stan as described in the user’s guide.
Specifying the mapped function
The function being mapped must have a signature identical to that of the function f in the following declaration.
vector f(vector phi, vector theta,
data array[] real x_r, data array[] int x_i);
The map function returns the sequence of results for the particular shard being evaluated. The arguments to the mapped function are:
phi, the sequence of parameters shared across shards
theta, the sequence of parameters specific to this shard
x_r, sequence of real-valued data
x_i, sequence of integer data
All input for the mapped function must be packed into these sequences and all output from the mapped function must be packed into a single vector. The vector of output from each mapped function is concatenated into the final result.
Rectangular map
The rectangular map function operates on rectangular (not ragged) data structures, with parallel data structures for job-specific parameters, job-specific real data, and job-specific integer data.
vector map_rect(F f, vector phi, array[] vector theta, data array[,] real x_r, data array[,] int x_i)
Return the concatenation of the results of applying the function f, of type (vector, vector, array[] real, array[] int):vector, elementwise, i.e., f(phi, theta[n], x_r[n], x_i[n]) for each n in 1:N, where N is the size of the parallel arrays of job-specific/local parameters theta, real data x_r, and integer data x_i. The shared/global parameters phi are passed to each invocation of f.
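A minimal sketch of a rectangular map is shown below, assuming a logistic regression with 12 observations split by hand into 3 shards of 4. The function name lr, the shard layout, and the priors are illustrative. There are no shard-specific parameters, so theta is an array of empty vectors; each shard returns its log likelihood as a vector of size 1, and the concatenated results are summed.

```stan
functions {
  // Log likelihood of one shard; coef is shared, theta is unused.
  vector lr(vector coef, vector theta,
            data array[] real x, data array[] int y) {
    real lp = bernoulli_logit_lpmf(y | coef[1] + to_vector(x) * coef[2]);
    return [lp]';
  }
}
data {
  array[12] int<lower=0, upper=1> y;
  array[12] real x;
}
transformed data {
  // Pack the data into rectangular (3 shards x 4 observations) arrays.
  array[3, 4] int ys = { y[1:4], y[5:8], y[9:12] };
  array[3, 4] real xs = { x[1:4], x[5:8], x[9:12] };
  array[3] vector[0] theta;  // no shard-specific parameters
}
parameters {
  vector[2] coef;
}
model {
  coef ~ std_normal();
  target += sum(map_rect(lr, coef, theta, xs, ys));
}
```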