30.3 Cross-validation

Cross-validation involves choosing multiple subsets of a data set as the test set and using the remaining data for training. This can be done by partitioning the data and using each subset in turn as the test set, with the remaining subsets as training data. A partition into ten subsets is a common choice that keeps computational overhead manageable, because the model is fit only ten times. In the limit, when each test set is just a single item, the result is known as leave-one-out (LOO) cross-validation (Vehtari, Gelman, and Gabry 2017).
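The partition-based scheme can be sketched outside of Stan. The following is a minimal illustration in Python, not part of the Stan program; the function name k_fold_splits and the zero-based indexing are assumptions made for this example (add one to each index before passing it to Stan).

```python
def k_fold_splits(n_items, n_folds):
    """Yield (train, test) index lists; each fold is the test set exactly once."""
    indices = list(range(n_items))
    # Fold k takes every n_folds-th index starting at k,
    # so fold sizes differ by at most one item.
    folds = [indices[k::n_folds] for k in range(n_folds)]
    for k, test in enumerate(folds):
        train = [i for j, fold in enumerate(folds) if j != k for i in fold]
        yield train, test


# Ten folds over 23 items: the folds cannot all be the same size.
sizes = [len(test) for _, test in k_fold_splits(n_items=23, n_folds=10)]
print(sizes)  # [3, 3, 3, 2, 2, 2, 2, 2, 2, 2]
```

Setting n_folds equal to n_items gives one item per test set, which is the leave-one-out case.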

Partitioning the data and reusing the partitions requires fiddly index bookkeeping and may not divide the data evenly. It’s far easier to use random partitions, which support arbitrarily sized test/training splits and can be easily implemented in Stan. The drawback is that the variance of the resulting estimate is higher than with a balanced block partition.

30.3.1 Stan implementation with random folds

For the simple linear regression model, randomized cross-validation can be implemented in a single model. To randomly permute a vector in Stan, the simplest approach is the following.

functions {
  array[] int permutation_rng(int N) {
    array[N] int y;
    for (n in 1 : N) {
      y[n] = n;
    }
    vector[N] theta = rep_vector(1.0 / N, N);
    for (n in 1 : size(y)) {
      int i = categorical_rng(theta);
      int temp = y[n];
      y[n] = y[i];
      y[i] = temp;
    }
    return y;
  }
}

The name of the function must end in _rng because it uses other random functions internally. This restricts its usage to the transformed data and generated quantities blocks. The code walks through an array of integers, exchanging each item with another randomly chosen item, resulting in a random permutation of the integers 1:N (see footnote 1). Because each swap partner is drawn from all N positions rather than only from those not yet visited, the permutation is not exactly uniformly distributed for N > 2.

The transformed data block uses the permutation RNG to generate training data and test data by taking prefixes and suffixes of the permuted data.

data {
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
  int<lower=0, upper=N> N_test;
}
transformed data {
  int N_train = N - N_test;
  array[N] int permutation = permutation_rng(N);
  vector[N_train] x_train = x[permutation[1 : N_train]];
  vector[N_train] y_train = y[permutation[1 : N_train]];
  vector[N_test] x_test = x[permutation[N_train + 1 : N]];
  vector[N_test] y_test = y[permutation[N_train + 1 : N]];
}

Recall that in Stan, permutation[1:N_train] is an array of integers, so that x[permutation[1 : N_train]] is a vector defined for i in 1:N_train by

x[permutation[1 : N_train]][i] = x[permutation[1:N_train][i]]
                               = x[permutation[i]]

Given the test/train split, the rest of the model is straightforward.

parameters {
  real alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y_train ~ normal(alpha + beta * x_train, sigma);
  { alpha, beta, sigma } ~ normal(0, 1);
}
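generated quantities {
  // posterior predictive draws for the held-out items;
  // normal_rng returns an array, so convert it to a vector
  vector[N_test] y_test_hat
    = to_vector(normal_rng(alpha + beta * x_test, sigma));
  // signed prediction error for each held-out item
  vector[N_test] err = y_test_hat - y_test;
}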

The prediction y_test_hat is defined in the generated quantities block using the general form, which accounts for both parameter uncertainty and sampling uncertainty. The posterior mean of this quantity corresponds to using a posterior mean estimator, \[\begin{eqnarray*} \hat{y}^{\textrm{test}} & = & \mathbb{E}\left[ y^{\textrm{test}} \mid x^{\textrm{test}}, x^{\textrm{train}}, y^{\textrm{train}} \right] \\[4pt] & \approx & \frac{1}{M} \sum_{m = 1}^M \hat{y}^{\textrm{test}(m)}. \end{eqnarray*}\]

Because the test set is constant and the expectation operator is linear, the posterior mean of err as defined in the Stan program will be the error of the posterior mean estimate, \[\begin{eqnarray*} \hat{y}^{\textrm{test}} - y^{\textrm{test}} & = & \mathbb{E}\left[ \hat{y}^{\textrm{test}} \mid x^{\textrm{test}}, x^{\textrm{train}}, y^{\textrm{train}} \right] - y^{\textrm{test}} \\[4pt] & = & \mathbb{E}\left[ \hat{y}^{\textrm{test}} - y^{\textrm{test}} \mid x^{\textrm{test}}, x^{\textrm{train}}, y^{\textrm{train}} \right] \\[4pt] & \approx & \frac{1}{M} \sum_{m = 1}^M \left( \hat{y}^{\textrm{test}(m)} - y^{\textrm{test}} \right), \end{eqnarray*}\] where \[ \hat{y}^{\textrm{test}(m)} \sim p(y \mid x^{\textrm{test}}, x^{\textrm{train}}, y^{\textrm{train}}). \] This just calculates error; taking the absolute value or squaring yields absolute error and squared error, which can then be averaged over the test items. The absolute value and square operations should not be applied within the Stan program, because neither is a linear function; in general, averaging squares over draws is not the same as squaring an average.

Because the test set size is chosen for convenience in cross-validation, results should be presented on a per-item scale, such as average absolute error or root mean square error, not on the scale of error in the fold being evaluated.
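As an illustration of this post-processing step, the following Python sketch averages the posterior draws of err first and only then applies the nonlinear operations, reporting both summaries per item. The function name and the nested-list layout err_draws[m][i] (draw m, test item i) are assumptions of this example, not something Stan's interfaces produce in this form.

```python
import math


def per_item_error_summaries(err_draws):
    """Return (mean absolute error, root mean square error) per test item.

    err_draws[m][i] is the value of err[i] in posterior draw m.
    """
    n_draws = len(err_draws)
    n_test = len(err_draws[0])
    # Average over draws first: this is the error of the posterior mean estimate.
    mean_err = [sum(draw[i] for draw in err_draws) / n_draws for i in range(n_test)]
    # Only now apply the nonlinear operations, then average over test items.
    mae = sum(abs(e) for e in mean_err) / n_test
    rmse = math.sqrt(sum(e * e for e in mean_err) / n_test)
    return mae, rmse


# Two draws, two test items: posterior mean errors are 2.0 and -4.0.
mae, rmse = per_item_error_summaries([[1.0, -3.0], [3.0, -5.0]])
print(mae, rmse)
```

Dividing by the number of test items keeps the results comparable across folds of different sizes.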

30.3.2 User-defined permutations

It is straightforward to declare the variable permutation in the data block instead of the transformed data block and read it in as data. This allows an external program to control the blocking, allowing non-random partitions to be evaluated.
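One way an external program might build such a permutation is sketched below in Python. It assumes the Stan program has been modified to declare permutation in its data block and that the first N - N_test entries are used for training, as in the transformed data block shown earlier; the function name and the choice of JSON as the data format are assumptions of this example.

```python
import json


def permutation_from_test_items(n_items, test_items):
    """Return a 1-based permutation with training items first, test items last."""
    test = sorted(set(test_items))
    if any(i < 1 or i > n_items for i in test):
        raise ValueError("test items must lie in 1..n_items")
    held_out = set(test)
    train = [i for i in range(1, n_items + 1) if i not in held_out]
    return train + test


# Hold out items 2 and 5 of 6; Stan indexes from 1, so no shift is needed.
perm = permutation_from_test_items(n_items=6, test_items=[2, 5])
print(json.dumps({"N_test": 2, "permutation": perm}))
# {"N_test": 2, "permutation": [1, 3, 4, 6, 2, 5]}
```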

30.3.3 Cross-validation with structured data

Cross-validation must be done with care if the data is inherently structured. For example, in a simple natural language application, data might be structured by document. For cross-validation, one needs to cross-validate at the document level, not at the individual word level. This is related to mixed replication in posterior predictive checking, where there is a choice to simulate new elements of existing groups or generate entirely new groups.

Data from educational testing applications are typically grouped by school district, by school, by classroom, and by demographic features of the individual students or the school as a whole. Depending on the variables of interest, different structured subsets should be evaluated. For example, the focus of interest may be on the performance of entire classrooms, so it would make sense to cross-validate at the class or school level on classroom performance.
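Group-level cross-validation can be sketched by assigning whole groups, rather than individual observations, to folds. The Python example below is a minimal illustration; the function name, the seed argument, and the toy document labels are hypothetical, and the indices are zero-based.

```python
import random


def group_folds(groups, n_folds, seed=0):
    """Assign whole groups to folds; return one list of test indices per fold.

    groups[i] is the group label (e.g., document or classroom) of item i.
    """
    labels = sorted(set(groups))
    random.Random(seed).shuffle(labels)
    # Deal the shuffled group labels out to folds in round-robin order.
    fold_of = {label: k % n_folds for k, label in enumerate(labels)}
    folds = [[] for _ in range(n_folds)]
    for i, label in enumerate(groups):
        folds[fold_of[label]].append(i)
    return folds


# Eight words from four documents; no document is split across folds.
doc_ids = ["a", "a", "b", "c", "c", "c", "d", "b"]
for k, test in enumerate(group_folds(doc_ids, n_folds=2)):
    print(k, sorted({doc_ids[i] for i in test}))
```

Because groups can differ in size, the folds will generally contain different numbers of items, which is another reason to report errors on a per-item scale.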

30.3.4 Cross-validation with spatio-temporal data

Often data measurements have spatial or temporal properties. For example, home energy consumption varies by time of day, day of week, on holidays, by season, and by ambient temperature (e.g., a hot spell or a cold snap). Cross-validation must be tailored to the predictive goal. For example, in predicting energy consumption, the quantity of interest may be the prediction for next week’s energy consumption given historical data and current weather covariates. This suggests an alternative to cross-validation, wherein individual weeks are each tested given previous data. This often allows comparing how well prediction performs with more or less historical data.
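The week-by-week evaluation described above can be sketched as a sequence of splits in which each test week is predicted only from earlier weeks. This Python example is illustrative; the function name and the min_train parameter (the number of initial weeks reserved for training) are assumptions of the sketch.

```python
def forward_splits(n_periods, min_train):
    """Yield (train_periods, test_period) with training strictly before testing."""
    for t in range(min_train, n_periods):
        # Train on periods 0..t-1 and test on period t.
        yield list(range(t)), t


for train, test in forward_splits(n_periods=6, min_train=3):
    print(f"train on weeks {train[0]}..{train[-1]}, test on week {test}")
# train on weeks 0..2, test on week 3
# train on weeks 0..3, test on week 4
# train on weeks 0..4, test on week 5
```

Truncating each training list to its most recent entries is one way to compare predictions made with more or less historical data.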

30.3.5 Approximate cross-validation

Vehtari, Gelman, and Gabry (2017) introduce a method that approximates the evaluation of leave-one-out cross-validation inexpensively using only the pointwise log likelihoods from a single model fit. This method is documented and implemented in the R package loo (Gabry et al. 2019).
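To give a sense of how a single fit can stand in for refitting, the Python sketch below computes the plain importance-sampling estimate of each leave-one-out log predictive density from a matrix of pointwise log likelihoods. This is deliberately not the Pareto-smoothed method of Vehtari, Gelman, and Gabry (2017) and not the loo package's API; it omits the smoothing step that stabilizes the estimate, and the function names and the layout log_lik[s][i] (draw s, item i) are assumptions of this example.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def raw_is_loo(log_lik):
    """Unsmoothed importance-sampling LOO log predictive density per item.

    log_lik[s][i] = log p(y_i | theta_s) for posterior draw s and item i.
    """
    n_draws = len(log_lik)
    n_items = len(log_lik[0])
    result = []
    for i in range(n_items):
        neg = [-log_lik[s][i] for s in range(n_draws)]
        # Log of the harmonic mean of the likelihoods over draws.
        result.append(math.log(n_draws) - log_sum_exp(neg))
    return result


# One item, two draws with likelihoods 0.5 and 0.25: harmonic mean is 1/3.
print(raw_is_loo([[math.log(0.5)], [math.log(0.25)]]))
```

The raw estimate can have very high variance when a few draws dominate the sum, which is the problem the smoothed method is designed to address.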

References

Gabry, Jonah, Daniel Simpson, Aki Vehtari, Michael Betancourt, and Andrew Gelman. 2019. “Visualization in Bayesian Workflow.” Journal of the Royal Statistical Society: Series A (Statistics in Society) 182 (2): 389–402.
Vehtari, Aki, Andrew Gelman, and Jonah Gabry. 2017. “Practical Bayesian Model Evaluation Using Leave-One-Out Cross-Validation and WAIC.” Statistics and Computing 27 (5): 1413–32.

  1. The traditional approach is to walk through a vector and swap each item with a random element drawn from the remaining elements, so that each position is assigned exactly once. This was not done here because it would require a new categorical theta at each step, as Stan does not have a uniform discrete RNG built in.