6.12 Chain Rule and Derivatives

Stan uses derivatives of the log probability function defined by a model in several ways. The Hamiltonian Monte Carlo samplers, including NUTS, use gradients to guide updates, and the BFGS optimizers use gradients to guide the search for posterior modes.

Errors Due to the Chain Rule

Unlike symbolic differentiation in pure mathematics, Stan evaluates derivatives by applying the chain rule mechanically, one expression at a time, using floating-point arithmetic. As a result, models such as the following are problematic for inference that relies on derivatives.

parameters {
  real x;
}
model {
  x ~ normal(sqrt(x - x), 1);
}

Algebraically, the sampling statement in the model could be reduced to

  x ~ normal(0, 1);

and it would seem that the model should produce unit normal draws for x. But rather than canceling, the expression sqrt(x - x) breaks the derivative computation. The cause is the mechanical evaluation of the chain rule,

\[ \begin{array}{rcl} \frac{d}{dx} \sqrt{x - x} & = & \frac{1}{2 \sqrt{x - x}} \times \frac{d}{dx} (x - x) \\[4pt] & = & \frac{1}{0} \times (1 - 1) \\[4pt] & = & \infty \times 0 \\[4pt] & = & \mathrm{NaN}. \end{array} \]

Rather than canceling out, the \(x - x\) term introduces a 0 into both the numerator and the denominator of the chain-rule evaluation.
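To see the mechanism concretely, here is a minimal forward-mode sketch in Python using hypothetical `Dual`, `dsqrt`, and `ieee_div` helpers. It is not how Stan's C++ automatic differentiation is implemented; it only applies the same local chain-rule steps as the derivation above, one operation at a time. Because Python raises `ZeroDivisionError` for float division by zero, `ieee_div` emulates the IEEE 754 result (an infinity, or NaN for 0/0) that default C++ floating-point arithmetic produces.

```python
import math
from dataclasses import dataclass


def ieee_div(a: float, b: float) -> float:
    """Divide with IEEE 754 semantics; plain Python raises ZeroDivisionError."""
    if b != 0.0:
        return a / b
    if a == 0.0 or math.isnan(a):
        return math.nan  # 0/0 is not a number
    # nonzero / (signed) zero is an infinity carrying the combined sign
    return math.copysign(math.inf, a) * math.copysign(1.0, b)


@dataclass(frozen=True)
class Dual:
    """Hypothetical forward-mode dual number: a value and its derivative w.r.t. x."""
    val: float
    dot: float

    def __sub__(self, other: "Dual") -> "Dual":
        # chain rule for subtraction: d(u - v) = du - dv
        return Dual(self.val - other.val, self.dot - other.dot)


def dsqrt(u: Dual) -> Dual:
    """sqrt with the chain rule d sqrt(u) = 1 / (2 sqrt(u)) * du."""
    s = math.sqrt(u.val)
    return Dual(s, ieee_div(1.0, 2.0 * s) * u.dot)


x = Dual(-0.887393, 1.0)  # seed dx/dx = 1
diff = x - x              # value 0.0, derivative 1 - 1 = 0.0
y = dsqrt(diff)           # 1 / (2 * 0.0) -> inf, then inf * 0.0 -> nan
print(y.val, y.dot)       # 0.0 nan
```

The value of sqrt(x - x) is computed correctly as 0; only the derivative is lost, because the local factor for sqrt is infinite at 0 and is multiplied by the exactly-zero derivative of x - x.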

The only way to avoid this kind of problem is to perform the necessary algebraic reductions in the model itself and to avoid expressions such as sqrt(x - x), for which the chain rule produces not-a-number values.

Diagnosing Problems with Derivatives

The best way to diagnose whether something is going wrong with the derivatives is to use the test-gradient option, which compares the gradients computed by the built-in automatic differentiation against finite-difference approximations. This option is available in both Stan and RStan, though it may be slow because it relies on finite differences.

For example, after compiling the sqrt(x - x) model above to a CmdStan executable named sqrt-x-minus-x, the test can be run as

> ./sqrt-x-minus-x diagnose test=gradient

which produces

...
TEST GRADIENT MODE

 Log probability=-0.393734

 param idx           value           model     finite diff           error
         0       -0.887393             nan               0             nan

Even though the finite-difference approximation returns a finite gradient (0 in this output), automatic differentiation follows the chain rule and produces a not-a-number output.
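The comparison behind this diagnostic can be sketched in a few lines of Python. This is a minimal illustration, not CmdStan's implementation: the hypothetical `check_gradient` helper uses a plain central difference with an assumed step size of 1e-6 and reports the error as the supplied gradient minus the finite-difference estimate, loosely mirroring the model, finite diff, and error columns above. The example log density and the point x = 1.5 are chosen only for illustration.

```python
import math
from typing import Callable, Tuple


def check_gradient(
    lp: Callable[[float], float],
    grad: Callable[[float], float],
    x: float,
    eps: float = 1e-6,  # assumed step size, not necessarily CmdStan's
) -> Tuple[float, float, float]:
    """Compare a supplied gradient against a central finite difference.

    Returns (model, finite_diff, error), where error = model - finite_diff.
    """
    model = grad(x)
    finite_diff = (lp(x + eps) - lp(x - eps)) / (2.0 * eps)
    return model, finite_diff, model - finite_diff


def std_normal_lp(x: float) -> float:
    # Standard normal log density, dropping the additive constant.
    return -0.5 * x * x


# Exact gradient -x: model and finite difference agree up to rounding.
print(check_gradient(std_normal_lp, lambda x: -x, 1.5))

# A NaN gradient (standing in for a broken chain-rule result) yields a NaN error.
_, _, err = check_gradient(std_normal_lp, lambda x: math.nan, 1.5)
print(math.isnan(err))  # True
```

Because any NaN in the model gradient propagates into the error column, a derivative broken by an expression like sqrt(x - x) stands out immediately, even when the finite-difference estimate itself is finite.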