6.12 Chain Rule and Derivatives
Derivatives of the log probability function defined by a model are used in several ways by Stan. The Hamiltonian Monte Carlo samplers, including NUTS, use gradients to guide updates. The BFGS optimizers also use gradients to guide search for posterior modes.
Errors Due to Chain Rule
Unlike evaluations in pure mathematics, evaluation of derivatives in Stan is done by applying the chain rule on an expression-by-expression basis, evaluating using floating-point arithmetic. As a result, models such as the following are problematic for inference involving derivatives.
parameters {
real x;
}
model {
x ~ normal(sqrt(x - x), 1);
}
Algebraically, the sampling statement in the model could be reduced to
x ~ normal(0, 1);
and it would seem the model should produce unit normal draws for x. But rather than canceling, the expression sqrt(x - x) causes a problem for derivatives. The cause is the mechanistic evaluation of the chain rule,
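\[
\frac{d}{dx} \sqrt{x - x}
= \frac{1}{2 \sqrt{x - x}} \times \frac{d}{dx} (x - x)
= \frac{1}{0} \times (1 - 1)
= \infty \times 0
= \mathrm{NaN}.
\]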
Rather than the x−x canceling out, it introduces a 0 into the numerator and denominator of the chain-rule evaluation.
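The floating-point arithmetic behind this evaluation can be reproduced outside of Stan. The following is a minimal Python sketch, not Stan's automatic differentiation code. The names ieee_div and d_sqrt_x_minus_x are hypothetical helpers introduced only for illustration. Python raises an exception on floating-point division by zero, so the sketch assumes a small helper that mimics the IEEE 754 result instead.

```python
import math

def ieee_div(num, den):
    """Divide like IEEE 754 arithmetic: x/0 gives +/-inf, 0/0 gives nan.

    Hypothetical helper; plain Python division would raise ZeroDivisionError.
    """
    if den == 0.0:
        if num == 0.0 or math.isnan(num):
            return math.nan
        return math.copysign(math.inf, num) * math.copysign(1.0, den)
    return num / den

def d_sqrt_x_minus_x(x):
    """Apply the chain rule to sqrt(x - x) one expression at a time."""
    inner = x - x                                     # inner value: 0.0
    d_inner = 1.0 - 1.0                               # d/dx (x - x) = 0.0
    d_outer = ieee_div(1.0, 2.0 * math.sqrt(inner))   # 1 / (2 * sqrt(0)) = inf
    return d_outer * d_inner                          # inf * 0.0 = nan

grad = d_sqrt_x_minus_x(-0.887393)
print(grad)  # nan
```

The input value plays no role in the result. Every finite x yields an inner value of 0, so the product of an infinite outer derivative and a zero inner derivative is not-a-number everywhere.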
The only way to avoid this kind of problem is to be careful to do the necessary algebraic reductions as part of the model and not introduce expressions like sqrt(x - x) for which the chain rule produces not-a-number values.
Diagnosing Problems with Derivatives
The best way to diagnose whether something is going wrong with the derivatives is to use the test-gradient option to the sampler or optimizer inputs. This option is available in both Stan and RStan, though it may be slow because it relies on finite differences to make a comparison to the built-in automatic differentiation.
For example, after compiling the above model to an executable sqrt-x-minus-x in CmdStan, the test can be run as
> ./sqrt-x-minus-x diagnose test=gradient
which produces
...
TEST GRADIENT MODE
Log probability=-0.393734
 param idx           value           model     finite diff           error
         0       -0.887393             nan               0             nan
Even though the finite-difference approximation calculates the right gradient of 0, automatic differentiation follows the chain rule and produces a not-a-number output.
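The disagreement between the two columns can be illustrated with a small comparison in Python. This is a sketch under stated assumptions, not the CmdStan diagnostic itself. It looks only at the sub-expression sqrt(x - x) rather than the full log density. The function names, the step size h, and the central-difference formula are choices made for illustration.

```python
import math

def f(x):
    """The problematic sub-expression from the model."""
    return math.sqrt(x - x)

def finite_diff(func, x, h=1e-6):
    """Central finite difference; it only ever evaluates func, never its derivative."""
    return (func(x + h) - func(x - h)) / (2.0 * h)

def chain_rule_grad(x):
    """Expression-by-expression chain rule for sqrt(x - x).

    The outer derivative 1 / (2 * sqrt(0)) is infinite in IEEE arithmetic.
    Python would raise on that division, so inf is written out explicitly
    (an assumption made to mimic the hardware result).
    """
    inner = x - x
    d_outer = math.inf if inner == 0.0 else 1.0 / (2.0 * math.sqrt(inner))
    d_inner = 1.0 - 1.0
    return d_outer * d_inner

x = -0.887393
model = chain_rule_grad(x)      # nan
fd = finite_diff(f, x)          # 0.0
error = model - fd              # nan, since nan propagates through subtraction
print(x, model, fd, error)
```

Finite differences succeed here because they only compare function values, and sqrt(x - x) evaluates to exactly 0 at every finite point. The chain rule fails because it multiplies intermediate derivatives, one of which is infinite.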