23.1 Reduce-Sum
It is often necessary in probabilistic modeling to compute the sum of a number of independent function evaluations. This occurs, for instance, when evaluating a number of conditionally independent terms in a log-likelihood. If g: U -> real is the function and { x1, x2, ... } is an array of inputs, then that sum looks like:

g(x1) + g(x2) + ...

reduce_sum and reduce_sum_static are tools for parallelizing these calculations.
For efficiency reasons the reduce function doesn't work with the element-wise evaluated function g, but instead the partial sum function f: U[] -> real, where f computes the partial sum corresponding to a slice of the sequence x passed in. Due to the associativity of the sum reduction it holds that:

g(x1) + g(x2) + g(x3) = f({ x1, x2, x3 })
                      = f({ x1, x2 }) + f({ x3 })
                      = f({ x1 }) + f({ x2, x3 })
                      = f({ x1 }) + f({ x2 }) + f({ x3 })
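This identity is what makes the parallelization correct: any partitioning of the inputs yields the same total. A short Python sketch illustrates it (the function g and the inputs are arbitrary examples, not part of Stan):

```python
# Partial-sum identity: summing g over any partitioning of the inputs
# gives the same total (up to floating-point rounding). The function g
# and the input values here are illustrative only.

def g(x):
    return x * x  # some per-term function g: U -> real

def f(x_slice):
    return sum(g(x) for x in x_slice)  # partial sum over a slice

x = [1.0, 2.0, 3.0]

total = g(x[0]) + g(x[1]) + g(x[2])
assert total == f(x)                  # f({x1, x2, x3})
assert total == f(x[:2]) + f(x[2:])   # f({x1, x2}) + f({x3})
assert total == f(x[:1]) + f(x[1:])   # f({x1}) + f({x2, x3})
```

In exact arithmetic every partitioning agrees; in floating point the results can differ slightly with the order of summation, which is exactly why Stan offers a static variant.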
With the partial sum function f: U[] -> real, a reduction over a large number of terms can be evaluated in parallel automatically, since the overall sum can be partitioned into arbitrary smaller partial sums. The exact partitioning into partial sums is not under the control of the user. However, since the exact numerical result will depend on the order of summation, Stan provides two versions of the reduce summation facility:

reduce_sum: Automatically choose the partial sums partitioning based on a dynamic scheduling algorithm.
reduce_sum_static: Compute the same sum as reduce_sum, but partition the input in the same way for a given data set (in reduce_sum this partitioning might change depending on computer load).
grainsize is the one tuning parameter. For reduce_sum, grainsize is a suggested partial sum size. A grainsize of 1 leaves the partitioning entirely up to the scheduler. This should be the default way of using reduce_sum unless time is spent carefully picking grainsize. For picking a grainsize, see details below.
For reduce_sum_static, grainsize specifies the maximal partial sum size. With reduce_sum_static it is more important to choose grainsize carefully since it entirely determines the partitioning of work. See details below.
For efficiency and convenience, additional shared arguments can be passed to every term in the sum. So for the array { x1, x2, ... } and the shared arguments s1, s2, ... the effective sum (with individual terms) looks like:

g(x1, s1, s2, ...) + g(x2, s1, s2, ...) + g(x3, s1, s2, ...) + ...

which can be written equivalently with partial sums as:

f({ x1, x2 }, s1, s2, ...) + f({ x3 }, s1, s2, ...)

where the particular slicing of the x array can change.
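The shared arguments do not affect the partitioning identity: every slice receives the same trailing arguments unchanged. A small Python sketch (g, f, and the values are illustrative only):

```python
# Partial sums with shared arguments: each slice gets the same trailing
# arguments s1, s2 unchanged, and the split total equals the full total.
# The functions and values here are illustrative, not part of Stan.

def g(x, s1, s2):
    return s1 * x + s2

def f(x_slice, s1, s2):
    return sum(g(x, s1, s2) for x in x_slice)

x = [1.0, 2.0, 3.0]
s1, s2 = 2.0, 0.5

assert f(x, s1, s2) == f(x[:2], s1, s2) + f(x[2:], s1, s2)
```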
Given this, the signatures are:

real reduce_sum(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)
real reduce_sum_static(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)

f - User-defined function that computes partial sums
x - Array to slice; each element corresponds to a term in the summation
grainsize - Target for the size of slices
s1, s2, ... - Arguments shared in every term
The user-defined partial sum functions have the signature:

real f(T[] x_slice, int start, int end, T1 s1, T2 s2, ...)

and take the arguments:

x_slice - The subset of x (from reduce_sum / reduce_sum_static) for which this partial sum is responsible (x_slice = x[start:end])
start - An integer specifying the first term in the partial sum
end - An integer specifying the last term in the partial sum (inclusive)
s1, s2, ... - Arguments shared in every term (passed on without modification from the reduce_sum / reduce_sum_static call)
The user-provided function f is expected to compute the partial sum over the terms start through end of the overall sum. The user function is passed the subset x[start:end] as x_slice. start and end are passed so that f can index any of the trailing sM arguments as necessary. The trailing sM arguments are passed without modification to every call of f.
A reduce_sum (or reduce_sum_static) call:

real sum = reduce_sum(f, x, grainsize, s1, s2, ...);

can be replaced by either:

real sum = f(x, 1, size(x), s1, s2, ...);

or the code:

real sum = 0.0;
for(i in 1:size(x)) {
  sum += f({ x[i] }, i, i, s1, s2, ...);
}
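The semantics of the call can also be mocked up serially. The Python sketch below splits x into chunks of roughly grainsize elements and calls the partial-sum function on each slice with its 1-based start/end indices; this only mimics the result, not Stan's actual dynamic scheduler:

```python
# Serial mock of reduce_sum's semantics: split x into chunks of about
# `grainsize` elements, call the partial-sum function f on each slice
# with its 1-based start/end indices, and add up the results. The real
# implementation partitions work dynamically and runs slices in parallel.

def reduce_sum_mock(f, x, grainsize, *shared):
    total = 0.0
    start = 0
    while start < len(x):
        end = min(start + grainsize, len(x))               # slice boundary
        total += f(x[start:end], start + 1, end, *shared)  # 1-based indices
        start = end
    return total

# Example partial-sum function: scaled sum of squares over its slice.
def partial_sum(x_slice, start, end, scale):
    return scale * sum(v * v for v in x_slice)

x = [1.0, 2.0, 3.0, 4.0, 5.0]
assert reduce_sum_mock(partial_sum, x, 2, 1.0) == sum(v * v for v in x)
```

Changing grainsize changes how the slices are cut but not the mathematical result, which is the point of the identity above.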
23.1.1 Example: Logistic Regression
Logistic regression is a useful example to clarify both the syntax and semantics of reduce summation and how it can be used to speed up a typical model. A basic logistic regression can be coded in Stan as:
data {
int N;
int y[N];
vector[N] x;
}
parameters {
vector[2] beta;
}
model {
beta ~ std_normal();
y ~ bernoulli_logit(beta[1] + beta[2] * x);
}
In this model, predictions are made about the N outputs y using the covariate x. The intercept and slope of the linear equation are to be estimated. The key to getting this calculation to use reduce summation is recognizing that the statement:

y ~ bernoulli_logit(beta[1] + beta[2] * x);

can be rewritten (up to a proportionality constant) as:

for(n in 1:N) {
  target += bernoulli_logit_lpmf(y[n] | beta[1] + beta[2] * x[n]);
}
Now it is clear that the calculation is the sum of a number of conditionally independent Bernoulli log probability statements, which is the condition where reduce summation is useful. To use the reduce summation, a function must be written that can be used to compute arbitrary partial sums of the total sum. Using the interface defined in Reduce-Sum, such a function can be written like:
functions {
real partial_sum(int[] y_slice,
int start, int end,
vector x,
vector beta) {
return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
}
}
The likelihood statement in the model can now be written:
target += partial_sum(y, 1, N, x, beta); // Sum terms 1 to N of the likelihood
In this example, y was chosen to be sliced over because there is one term in the summation per value of y. Technically, x would have worked as well. Use whatever conceptually makes the most sense for a given model, e.g., slice over independent terms like conditionally independent observations or groups of observations as in hierarchical models. Because x is a shared argument, it is subset accordingly with start:end. With this function, reduce summation can be used to automatically parallelize the likelihood:
int grainsize = 1;
target += reduce_sum(partial_sum, y,
grainsize,
x, beta);
The reduce summation facility automatically breaks the sum into pieces and computes them in parallel. grainsize = 1 specifies that the grainsize should be estimated automatically. The final model is:
functions {
real partial_sum(int[] y_slice,
int start, int end,
vector x,
vector beta) {
return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
}
}
data {
int N;
int y[N];
vector[N] x;
}
parameters {
vector[2] beta;
}
model {
int grainsize = 1;
beta ~ std_normal();
target += reduce_sum(partial_sum, y,
grainsize,
x, beta);
}
23.1.2 Picking the Grainsize
The rationale for choosing a sensible grainsize is based on balancing the overhead of creating many small tasks against creating a few large tasks, which limits the potential parallelism.
In reduce_sum, grainsize is a recommendation on how to partition the work in the partial sum into smaller pieces. A grainsize of 1 leaves this entirely up to the internal scheduler and should be chosen if no benchmarking of other grainsizes is done. Ideally this will be efficient, but there are no guarantees.
In reduce_sum_static, grainsize is an upper limit on the work size. Work will be split until all partial sums are just smaller than grainsize (and the split will happen the same way every time for the same inputs). For the static version it is more important to select a sensible grainsize.
In order to figure out an optimal grainsize, if there are N terms and M cores, run a quick test model with grainsize set roughly to N / M. Record the time, cut the grainsize in half, and run the test again. Repeat this iteratively until the model runtime begins to increase. This is a suitable grainsize for the model, because it ensures the calculations can be carried out with the most parallelism without losing too much efficiency.
For instance, in a model with N = 10000 and M = 4, start with grainsize = 2500, and sequentially try grainsize = 1250, grainsize = 625, etc.
It is important to repeat this process until performance gets worse. It is possible that after many halvings nothing changes, but there might still be a smaller grainsize that performs better. Even if a sum has many tens of thousands of terms, depending on the internal calculations, a grainsize of thirty or forty or smaller might be the best, and it is difficult to predict this behavior. Without doing these halvings until performance actually gets worse, it is easy to miss this.
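The halving procedure can be sketched as a simple loop. In the Python sketch below, time_model_run is a hypothetical placeholder for actually running and timing the Stan model at a given grainsize; the toy timing curve is illustrative only:

```python
# Sketch of the grainsize search: start near N / M, then halve and
# re-benchmark until the measured runtime starts increasing.
# `time_model_run` stands in for timing a real Stan model run.

def pick_grainsize(time_model_run, N, M):
    grainsize = max(1, N // M)          # starting suggestion: N / M
    best_g, best_t = grainsize, time_model_run(grainsize)
    while grainsize > 1:
        grainsize //= 2                 # halve and re-benchmark
        t = time_model_run(grainsize)
        if t > best_t:                  # runtime increased: stop searching
            break
        best_g, best_t = grainsize, t
    return best_g

# Toy timing curve with its minimum at grainsize = 625 (made up for
# the example; real runtimes must be measured, not assumed).
def fake_timer(g):
    return abs(g - 625) + 100.0

assert pick_grainsize(fake_timer, 10000, 4) == 625
```

With N = 10000 and M = 4 this tries 2500, 1250, 625, then stops at 312 once the (toy) runtime increases, matching the halving schedule described above.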