28.3 Cross-validation
Cross-validation involves choosing multiple subsets of a data set as the test set and using the other data as training. This can be done by partitioning the data and using each subset in turn as the test set with the remaining subsets as training data. A partition into ten subsets is common to reduce computational overhead. In the limit, when the test set is just a single item, the result is known as leave-one-out (LOO) cross-validation (Vehtari, Gelman, and Gabry 2017).
Partitioning the data and reusing the partitions requires fiddly index manipulation and may not divide the data evenly. It's far easier to use random partitions, which support arbitrarily sized test/training splits and are easy to implement in Stan. The drawback is that the variance of the resulting estimate is higher than with a balanced block partition.
28.3.1 Stan implementation with random folds
For the simple linear regression model, randomized cross-validation can be implemented in a single model. To randomly permute a vector in Stan, the simplest approach is the following.
functions {
  int[] permutation_rng(int N) {
    int y[N];
    for (n in 1:N) {
      y[n] = n;
    }
    // uniform distribution over the indexes 1:N
    vector[N] theta = rep_vector(1.0 / N, N);
    for (n in 1:N) {
      // swap y[n] with a randomly chosen element y[i]
      int i = categorical_rng(theta);
      int temp = y[n];
      y[n] = y[i];
      y[i] = temp;
    }
    return y;
  }
}
The name of the function must end in _rng because it uses other random functions internally, which restricts its use to the transformed data and generated quantities blocks. The code walks through the array, exchanging each item with another randomly chosen item, and the result is a random permutation of the integers 1:N.⁴⁶
The transformed data block uses the permutation RNG to generate training data and test data by taking prefixes and suffixes of the permuted data.
data {
  int<lower = 0> N;
  vector[N] x;
  vector[N] y;
  int<lower = 0, upper = N> N_test;
}
transformed data {
  int N_train = N - N_test;
  int permutation[N] = permutation_rng(N);
  vector[N_train] x_train = x[permutation[1 : N_train]];
  vector[N_train] y_train = y[permutation[1 : N_train]];
  vector[N_test] x_test = x[permutation[N_train + 1 : N]];
  vector[N_test] y_test = y[permutation[N_train + 1 : N]];
}
Recall that in Stan, permutation[1:N_train] is an array of integers, so that x[permutation[1:N_train]] is the vector defined for i in 1:N_train by

x[permutation[1:N_train]][i] = x[permutation[1:N_train][i]] = x[permutation[i]]
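For example, with made-up values for illustration: if N = 3, N_train = 2, x = (10, 20, 30), and the permutation drawn is {3, 1, 2}, then permutation[1:2] is the integer array {3, 1} and x[permutation[1:2]] is the two-element vector (30, 10).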
Given the test/train split, the rest of the model is straightforward.
parameters {
  real alpha;
  real beta;
  real<lower = 0> sigma;
}
model {
  y_train ~ normal(alpha + beta * x_train, sigma);
  // vectorized standard normal priors on all three parameters
  { alpha, beta, sigma } ~ normal(0, 1);
}
generated quantities {
  // posterior predictive draws for the held-out items; normal_rng with a
  // vector location returns an array of reals, hence the to_vector conversion
  vector[N_test] y_test_hat
      = to_vector(normal_rng(alpha + beta * x_test, sigma));
  vector[N_test] err = y_test_hat - y_test;
}
The prediction y_test_hat is defined in the generated quantities block using the general form that includes all estimation uncertainty. The posterior mean of this quantity corresponds to a posterior mean estimator,
\[\begin{eqnarray*}
\hat{y}^{\textrm{test}}
& = & \mathbb{E}\left[ y^{\textrm{test}} \mid x^{\textrm{test}}, x^{\textrm{train}}, y^{\textrm{train}} \right]
\\[4pt]
& \approx & \frac{1}{M} \sum_{m = 1}^M \hat{y}^{\textrm{test}(m)}.
\end{eqnarray*}\]
Because the test set is constant and the expectation operator is
linear, the posterior mean of err
as defined in the Stan program
will be the error of the posterior mean estimate,
\[\begin{eqnarray*}
\hat{y}^{\textrm{test}} - y^{\textrm{test}}
& = &
\mathbb{E}\left[
\hat{y}^{\textrm{test}}
\mid x^{\textrm{test}}, x^{\textrm{train}}, y^{\textrm{train}}
\right]
- y^{\textrm{test}}
\\[4pt]
& = &
\mathbb{E}\left[
\hat{y}^{\textrm{test}} - y^{\textrm{test}}
\mid x^{\textrm{test}}, x^{\textrm{train}}, y^{\textrm{train}}
\right]
\\[4pt]
& \approx &
\frac{1}{M} \sum_{m = 1}^M \hat{y}^{\textrm{test}(m)} - y^{\textrm{test}},
\end{eqnarray*}\]
where
\[
\hat{y}^{\textrm{test}(m)}
\sim p(y \mid x^{\textrm{test}}, x^{\textrm{train}},
y^{\textrm{train}}).
\]
This only calculates the error itself; taking the absolute value or squaring the posterior mean of err after the fit yields absolute error and squared error. Neither the absolute value nor the square should be computed within the Stan program, because neither operation is linear: the average of squared values is not in general the square of the average.
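Concretely, for a single test item, the squared error of the posterior mean estimate differs from the posterior mean of the squared errors,
\[
\left( \frac{1}{M} \sum_{m = 1}^M \hat{y}^{\textrm{test}(m)} - y^{\textrm{test}} \right)^2
\neq
\frac{1}{M} \sum_{m = 1}^M \left( \hat{y}^{\textrm{test}(m)} - y^{\textrm{test}} \right)^2,
\]
where the left side is the quantity of interest here and the right side instead estimates the expected squared error of a single posterior predictive draw.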
Because the test-set size in cross-validation is chosen for computational convenience, results should be presented on a per-item scale, such as mean absolute error or root mean square error, rather than on the scale of total error in the fold being evaluated.
28.3.2 User-defined permutations
It is straightforward to declare the variable permutation in the data block instead of the transformed data block and read it in as data. This allows an external program to control the blocking, so that non-random partitions can be evaluated.
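A minimal sketch of this variant, keeping the same model as above; only the declaration of permutation moves, and the functions block is no longer needed.

data {
  int<lower = 0> N;
  vector[N] x;
  vector[N] y;
  int<lower = 0, upper = N> N_test;
  // user-supplied permutation controlling the test/train split
  int<lower = 1, upper = N> permutation[N];
}

The transformed data block is unchanged; it slices x and y using the supplied permutation rather than a randomly generated one.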
28.3.3 Cross-validation with structured data
Cross-validation must be done with care if the data is inherently structured. For example, in a simple natural language application, data might be structured by document. For cross-validation, one needs to cross-validate at the document level, not at the individual word level. This is related to mixed replication in posterior predictive checking, where there is a choice to simulate new elements of existing groups or generate entirely new groups.
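As a sketch of how document-level splitting might look, assuming the permutation_rng function defined earlier (the names D, doc, D_test, and is_test are illustrative, not part of the original program):

data {
  int<lower = 0> N;                  // observations, e.g., words
  int<lower = 1> D;                  // groups, e.g., documents
  int<lower = 1, upper = D> doc[N];  // group membership of each observation
  int<lower = 0, upper = D> D_test;  // number of groups to hold out
}
transformed data {
  // permute groups rather than individual observations
  int doc_permutation[D] = permutation_rng(D);
  int rank[D];     // position of each group in the permuted order
  int is_test[N];  // 1 if the observation's group is held out
  for (d in 1:D) {
    rank[doc_permutation[d]] = d;
  }
  for (n in 1:N) {
    is_test[n] = (rank[doc[n]] > D - D_test) ? 1 : 0;
  }
}

The model block then conditions only on observations with is_test[n] == 0, and the held-out groups are evaluated in generated quantities.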
Education testing applications are typically grouped by school district, by school, by classroom, and by demographic features of the individual students or the school as a whole. Depending on the variables of interest, different structured subsets should be evaluated. For example, the focus of interest may be on the performance of entire classrooms, so it would make sense to cross-validate at the class or school level on classroom performance.
28.3.4 Cross-validation with spatio-temporal data
Often data measurements have spatial or temporal properties. For example, home energy consumption varies by time of day, day of week, on holidays, by season, and by ambient temperature (e.g., a hot spell or a cold snap). Cross-validation must be tailored to the predictive goal. For example, in predicting energy consumption, the quantity of interest may be the prediction for next week’s energy consumption given historical data and current weather covariates. This suggests an alternative to cross-validation, wherein individual weeks are each tested given previous data. This often allows comparing how well prediction performs with more or less historical data.
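A minimal sketch of such a time-ordered split, assuming the observations are sorted by time and the size of the historical window (here N_hist, an illustrative name) is supplied as data:

data {
  int<lower = 0> N;
  vector[N] x;                       // covariates, ordered by time
  vector[N] y;                       // outcomes, ordered by time
  int<lower = 0, upper = N> N_hist;  // number of historical observations
}
transformed data {
  // train on the past, test on the future; no permutation is involved
  vector[N_hist] x_train = x[1 : N_hist];
  vector[N_hist] y_train = y[1 : N_hist];
  vector[N - N_hist] x_test = x[N_hist + 1 : N];
  vector[N - N_hist] y_test = y[N_hist + 1 : N];
}

Refitting with successively larger values of N_hist then shows how predictive performance changes as more history becomes available.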
28.3.5 Approximate cross-validation
Vehtari, Gelman, and Gabry (2017) introduce a method that inexpensively approximates leave-one-out cross-validation using only the log likelihoods of the individual data points from a single model fit. This method is documented and implemented in the R package loo (Gabry et al. 2019).
⁴⁶ The traditional approach is to walk through the array and swap each item with an element chosen from the remaining positions, which is guaranteed to move each item only once. That approach was not used here because it would require constructing a new categorical theta over the shrinking set of remaining positions at each step, and Stan does not have a uniform discrete RNG built in.