25.1 Reduce-sum
It is often necessary in probabilistic modeling to compute the sum of
a number of independent function evaluations. This occurs, for instance, when
evaluating a number of conditionally independent terms in a log-likelihood.
If g: U -> real is the function and { x1, x2, ... } is an array of inputs, then that sum looks like:
g(x1) + g(x2) + ...
reduce_sum and reduce_sum_static are tools for parallelizing these calculations.
For efficiency reasons the reduce function doesn’t work with the element-wise evaluated function g, but instead with the partial sum function f: U[] -> real, where f computes the partial sum corresponding to the slice of the sequence x passed in. Due to the associativity of the sum reduction it holds that:
g(x1) + g(x2) + g(x3) = f({ x1, x2, x3 })
= f({ x1, x2 }) + f({ x3 })
= f({ x1 }) + f({ x2, x3 })
= f({ x1 }) + f({ x2 }) + f({ x3 })
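To make this identity concrete, here is a small Stan sketch. The names g and f are illustrative: g is assumed, hypothetically, to be the standard normal log density, and f uses the simplified one-argument form f: U[] -> real from above rather than the full reduce_sum signature. The program has no parameters, so it only needs to run once (for example with the fixed_param sampler); the printed values should agree up to floating-point rounding.

```stan
functions {
  // Hypothetical element-wise term g: the standard normal log density.
  real g(real x) {
    return std_normal_lpdf(x);
  }

  // Partial sum f: adds g over every element of the slice it receives.
  real f(array[] real x_slice) {
    real total = 0;
    for (i in 1:size(x_slice)) {
      total += g(x_slice[i]);
    }
    return total;
  }
}
transformed data {
  array[3] real x = {0.5, -1.0, 2.0};

  // Different groupings of the same three terms; the printed values
  // should agree up to floating-point rounding.
  print("f({x1, x2, x3})             = ", f(x));
  print("f({x1, x2}) + f({x3})       = ", f(x[1:2]) + f(x[3:3]));
  print("f({x1}) + f({x2, x3})       = ", f(x[1:1]) + f(x[2:3]));
  print("f({x1}) + f({x2}) + f({x3}) = ",
        f(x[1:1]) + f(x[2:2]) + f(x[3:3]));
}
```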
With the partial sum function f: U[] -> real, the reduction of a large number of terms can be evaluated in parallel automatically, since the overall sum can be partitioned into arbitrary smaller partial sums. The exact partitioning into the partial sums is not under the control of the user. However, because floating-point addition is not exactly associative, the numerical result depends slightly on the order of summation, so Stan provides two versions of the reduce summation facility:
reduce_sum: Automatically chooses the partitioning into partial sums based on a dynamic scheduling algorithm.
reduce_sum_static: Computes the same sum as reduce_sum, but partitions the input the same way every time for a given data set (in reduce_sum this partitioning might change depending on computer load).
grainsize is the one tuning parameter. For reduce_sum, grainsize is a suggested partial sum size. A grainsize of 1 leaves the partitioning entirely up to the scheduler. This should be the default way of using reduce_sum unless time is spent carefully picking grainsize. For picking a grainsize, see details below.
For reduce_sum_static, grainsize specifies the maximal partial sum size. With reduce_sum_static it is more important to choose grainsize carefully, since it entirely determines the partitioning of work. See details below.
For efficiency and convenience, additional shared arguments can be passed to every term in the sum. So for the array { x1, x2, ... } and the shared arguments s1, s2, ... the effective sum (with individual terms) looks like:
g(x1, s1, s2, ...) + g(x2, s1, s2, ...) + g(x3, s1, s2, ...) + ...
which can be written equivalently with partial sums to look like:
f({ x1, x2 }, s1, s2, ...) + f({ x3 }, s1, s2, ...)
where the particular slicing of the x array can change.
Given this, the signatures are:
real reduce_sum(F f, array[] T x, int grainsize, T1 s1, T2 s2, ...)
real reduce_sum_static(F f, array[] T x, int grainsize, T1 s1, T2 s2, ...)
f - User-defined function that computes partial sums
x - Array to slice; each element corresponds to a term in the summation
grainsize - Target for the size of slices
s1, s2, ... - Arguments shared in every term
The user-defined partial sum functions have the signature:
real f(array[] T x_slice, int start, int end, T1 s1, T2 s2, ...)
and take the arguments:
x_slice - The subset of x (from reduce_sum / reduce_sum_static) for which this partial sum is responsible (x_slice = x[start:end])
start - An integer specifying the first term in the partial sum
end - An integer specifying the last term in the partial sum (inclusive)
s1, s2, ... - Arguments shared in every term (passed on without modification from the reduce_sum / reduce_sum_static call)
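As a minimal sketch of this signature, assuming a simple normal model, the hypothetical function partial_normal below slices over the observations y and receives the location mu and the scale sigma as the shared arguments s1 and s2. Because both shared arguments are scalars that apply to every term, start and end are not needed in this case.

```stan
functions {
  // Hypothetical partial sum: log likelihood of the terms in y_slice.
  // mu and sigma are shared arguments passed unchanged to every call;
  // start and end are unused because mu and sigma apply to every term.
  real partial_normal(array[] real y_slice, int start, int end,
                      real mu, real sigma) {
    return normal_lpdf(y_slice | mu, sigma);
  }
}
data {
  int<lower=1> N;
  array[N] real y;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  mu ~ std_normal();
  sigma ~ exponential(1);
  // Same posterior as y ~ normal(mu, sigma), computed as partial sums.
  target += reduce_sum(partial_normal, y, 1, mu, sigma);
}
```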
The user-provided function f is expected to compute the partial sum with the terms start through end of the overall sum. The user function is passed the subset x[start:end] as x_slice. start and end are passed so that f can index any of the trailing sM arguments as necessary. The trailing sM arguments are passed without modification to every call of f.
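When a shared argument holds one value per term, start maps positions in x_slice back to positions in the full sum: x_slice[i] is x[start + i - 1], and size(x_slice) equals end - start + 1. The sketch below assumes a hypothetical grouped model in which every group has the same number K of observations; it slices over groups and uses start to pick each group's mean out of the shared vector mu. The names partial_groups, G, K, mu0 and tau are illustrative.

```stan
functions {
  // Hypothetical partial sum over groups. y_slice[i] holds the data of
  // group start + i - 1, so that index selects the matching mean from the
  // shared vector mu. sigma is shared by every group.
  real partial_groups(array[] vector y_slice, int start, int end,
                      vector mu, real sigma) {
    real lp = 0;
    for (i in 1:size(y_slice)) {
      lp += normal_lpdf(y_slice[i] | mu[start + i - 1], sigma);
    }
    return lp;
  }
}
data {
  int<lower=1> G;          // number of groups
  int<lower=1> K;          // observations per group (assumed equal)
  array[G] vector[K] y;
}
parameters {
  real mu0;
  real<lower=0> tau;
  vector[G] mu;
  real<lower=0> sigma;
}
model {
  mu0 ~ std_normal();
  tau ~ exponential(1);
  sigma ~ exponential(1);
  mu ~ normal(mu0, tau);
  // Slice over groups: each term of the sum is one group's log likelihood.
  target += reduce_sum(partial_groups, y, 1, mu, sigma);
}
```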
A reduce_sum (or reduce_sum_static) call:
real sum = reduce_sum(f, x, grainsize, s1, s2, ...);
can be replaced, up to floating-point rounding, by either:
real sum = f(x, 1, size(x), s1, s2, ...);
or the code:
real sum = 0.0;
for (i in 1:size(x)) {
  sum += f({ x[i] }, i, i, s1, s2, ...);
}
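One way to test a partial sum function is to evaluate all of these forms for fixed inputs and compare them. The sketch below does this in transformed data for a hypothetical Poisson partial sum, partial_poisson, with an arbitrary fixed log rate, and prints the results of reduce_sum, reduce_sum_static, the direct call and the loop. The values should agree up to floating-point rounding; a larger discrepancy usually points at an indexing mistake involving start and end.

```stan
functions {
  // Hypothetical partial sum: Poisson log likelihood of terms start..end.
  real partial_poisson(array[] int y_slice, int start, int end,
                       real log_rate) {
    return poisson_log_lpmf(y_slice | log_rate);
  }
}
data {
  int<lower=1> N;
  array[N] int<lower=0> y;
}
transformed data {
  real log_rate = 0.5;  // arbitrary fixed value, used only for this check
  int grainsize = 1;
  real by_reduce = reduce_sum(partial_poisson, y, grainsize, log_rate);
  real by_static = reduce_sum_static(partial_poisson, y, grainsize, log_rate);
  real by_direct = partial_poisson(y, 1, N, log_rate);
  real by_loop = 0;
  for (i in 1:N) {
    by_loop += partial_poisson({ y[i] }, i, i, log_rate);
  }
  print("reduce_sum: ", by_reduce, ", reduce_sum_static: ", by_static,
        ", direct: ", by_direct, ", loop: ", by_loop);
}
```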
25.1.1 Example: logistic regression
Logistic regression is a useful example to clarify both the syntax and semantics of reduce summation and how it can be used to speed up a typical model. A basic logistic regression can be coded in Stan as:
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  y ~ bernoulli_logit(beta[1] + beta[2] * x);
}
In this model, predictions are made about the N outputs y using the covariate x. The intercept and slope of the linear equation are to be estimated.
The key to getting this calculation to use reduce summation is recognizing that the statement:
y ~ bernoulli_logit(beta[1] + beta[2] * x);
can be rewritten (up to a proportionality constant) as:
for (n in 1:N) {
  target += bernoulli_logit_lpmf(y[n] | beta[1] + beta[2] * x[n]);
}
Now it is clear that the calculation is the sum of a number of conditionally independent Bernoulli log probability statements, which is the condition where reduce summation is useful. To use the reduce summation, a function must be written that can be used to compute arbitrary partial sums of the total sum. Using the interface defined in Reduce-Sum, such a function can be written like:
functions {
  real partial_sum(array[] int y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
The likelihood statement in the model can now be written:
target += partial_sum(y, 1, N, x, beta); // Sum terms 1 to N of the likelihood
In this example, y was chosen to be sliced over because there is one term in the summation per value of y. Technically x could have been sliced over instead, provided it were declared as an array, since the sliced argument of reduce_sum must be an array. Use whatever conceptually makes the most sense for a given model, e.g. slice over independent terms like conditionally independent observations or groups of observations as in hierarchical models. Because x is a shared argument, it is subset accordingly with start:end. With this function, reduce summation can be used to automatically parallelize the likelihood:
int grainsize = 1;
target += reduce_sum(partial_sum, y,
                     grainsize,
                     x, beta);
The reduce summation facility automatically breaks the sum into pieces and computes them in parallel. grainsize = 1 leaves the choice of partial sum sizes entirely to the scheduler. The final model is:
functions {
  real partial_sum(array[] int y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  array[N] int y;
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       x, beta);
}
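As noted above, x could be sliced over instead of y if it is declared as an array. A sketch of that alternative, using the hypothetical name partial_sum_x, is below. Now y and beta are the shared arguments, so it is y that gets subset with start:end, and the slice of x is converted to a vector before entering the linear predictor.

```stan
functions {
  // Hypothetical alternative to partial_sum that slices over x. The sliced
  // argument must be an array, so x is declared as array[N] real below.
  real partial_sum_x(array[] real x_slice,
                     int start, int end,
                     array[] int y,
                     vector beta) {
    // y is now a shared argument, so it is subset with start:end.
    return bernoulli_logit_lpmf(y[start:end] |
                                beta[1] + beta[2] * to_vector(x_slice));
  }
}
data {
  int N;
  array[N] int y;
  array[N] real x;
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum_x, x,
                       grainsize,
                       y, beta);
}
```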
25.1.2 Picking the grainsize
The rationale for choosing a sensible grainsize is based on balancing the overhead implied by creating many small tasks versus creating fewer large tasks, which limits the potential parallelism.
In reduce_sum, grainsize is a recommendation on how to partition the work in the partial sum into smaller pieces. A grainsize of 1 leaves this entirely up to the internal scheduler and should be chosen if no benchmarking of other grainsizes is done. Ideally this will be efficient, but there are no guarantees.
In reduce_sum_static, grainsize is an upper limit on the size of each partial sum. Work will be split until every partial sum contains no more than grainsize terms (and the split will happen the same way every time for the same inputs). For the static version it is more important to select a sensible grainsize.
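Because tuning grainsize means running the model several times, one convenient setup (an assumption of this sketch, not something Stan requires) is to pass grainsize in as data so it can be changed between runs without recompiling. The sketch below does this for the logistic regression above and switches to reduce_sum_static; apart from the function name, the call is the same as for reduce_sum.

```stan
functions {
  real partial_sum(array[] int y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  array[N] int y;
  vector[N] x;
  // Supplied as data so different values can be benchmarked without
  // recompiling the model.
  int<lower=1> grainsize;
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  // Each partial sum covers at most grainsize terms, and the partitioning
  // is the same on every evaluation for the same inputs.
  target += reduce_sum_static(partial_sum, y,
                              grainsize,
                              x, beta);
}
```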
In order to figure out an optimal grainsize, if there are N terms and M cores, run a quick test model with grainsize set roughly to N / M. Record the time, cut the grainsize in half, and run the test again. Repeat this iteratively until the model runtime begins to increase. The last grainsize before the runtime increased is a suitable grainsize for the model, because it allows the calculations to be carried out with the most parallelism without losing too much efficiency.
For instance, in a model with N = 10000 and M = 4, start with grainsize = 2500, and sequentially try grainsize = 1250, grainsize = 625, etc.
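The halving schedule can also be written as a formula. Writing g_k for the grainsize tried in the k-th run and rounding down to an integer (an assumption, since grainsize must be an integer), the schedule and its values for N = 10000 and M = 4 are:

```latex
g_k = \left\lfloor \frac{N}{M \cdot 2^{k}} \right\rfloor, \qquad k = 0, 1, 2, \ldots

N = 10000,\; M = 4: \quad g_0 = 2500,\; g_1 = 1250,\; g_2 = 625,\; g_3 = 312,\; g_4 = 156,\; g_5 = 78,\; g_6 = 39,\; \ldots
```

Only six halvings take the grainsize from 2500 down into the thirty-to-forty range discussed in the next paragraph.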
It is important to repeat this process until performance gets worse. It is possible that after many halvings nothing changes, but there might still be a smaller grainsize that performs better. Even if a sum has many tens of thousands of terms, depending on the internal calculations, a grainsize of thirty or forty or smaller might be the best, and it is difficult to predict this behavior. Without doing these halvings until performance actually gets worse, it is easy to miss this.