24.1 Reduce-sum

It is often necessary in probabilistic modeling to compute the sum of a number of independent function evaluations. This occurs, for instance, when evaluating a number of conditionally independent terms in a log-likelihood. If g: U -> real is the function and { x1, x2, ... } is an array of inputs, then that sum looks like:

g(x1) + g(x2) + ...

reduce_sum and reduce_sum_static are tools for parallelizing these calculations.

For efficiency reasons the reduce functions do not work with the element-wise function g, but instead with a partial sum function f: U[] -> real, where f computes the partial sum corresponding to a slice of the sequence x passed in. Due to the associativity of the sum reduction it holds that:

g(x1) + g(x2) + g(x3) = f({ x1, x2, x3 })
                      = f({ x1, x2 }) + f({ x3 })
                      = f({ x1 }) + f({ x2, x3 })
                      = f({ x1 }) + f({ x2 }) + f({ x3 })

With the partial sum function f: U[] -> real, the reduction of a large number of terms can be evaluated in parallel automatically, since the overall sum can be partitioned into arbitrary smaller partial sums. The exact partitioning into partial sums is not under the control of the user. However, since the exact numerical result depends on the order of summation, Stan provides two versions of the reduce summation facility:

  • reduce_sum: Automatically choose partial sums partitioning based on a dynamic scheduling algorithm.
  • reduce_sum_static: Compute the same sum as reduce_sum, but partition the input in the same way for a given data set (in reduce_sum this partitioning might change depending on computer load).

grainsize is the one tuning parameter. For reduce_sum, grainsize is a suggested partial sum size. A grainsize of 1 leaves the partitioning entirely up to the scheduler. This should be the default way of using reduce_sum unless time is spent carefully picking grainsize. For picking a grainsize, see details below.

For reduce_sum_static, grainsize specifies the maximal partial sum size. With reduce_sum_static it is more important to choose grainsize carefully since it entirely determines the partitioning of work. See details below.

For efficiency and convenience additional shared arguments can be passed to every term in the sum. So for the array { x1, x2, ... } and the shared arguments s1, s2, ... the effective sum (with individual terms) looks like:

g(x1, s1, s2, ...) + g(x2, s1, s2, ...) + g(x3, s1, s2, ...) + ...

which can be written equivalently with partial sums to look like:

f({ x1, x2 }, s1, s2, ...) + f({ x3 }, s1, s2, ...)

where the particular slicing of the x array can change.

Given this, the signatures are:

real reduce_sum(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)
real reduce_sum_static(F f, T[] x, int grainsize, T1 s1, T2 s2, ...)
  1. f - User defined function that computes partial sums
  2. x - Array to slice, each element corresponds to a term in the summation
  3. grainsize - Target for size of slices
  4. s1, s2, ... - Arguments shared in every term

The user-defined partial sum functions have the signature:

real f(T[] x_slice, int start, int end, T1 s1, T2 s2, ...)

and take the arguments:

  1. x_slice - The subset of x (from reduce_sum / reduce_sum_static) for which this partial sum is responsible (x_slice = x[start:end])
  2. start - An integer specifying the first term in the partial sum
  3. end - An integer specifying the last term in the partial sum (inclusive)
  4. s1, s2, ... - Arguments shared in every term (passed on without modification from the reduce_sum / reduce_sum_static call)

The user-provided function f is expected to compute the partial sum with the terms start through end of the overall sum. The user function is passed the subset x[start:end] as x_slice. start and end are passed so that f can index any of the trailing sM arguments as necessary. The trailing sM arguments are passed without modification to every call of f.
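As a minimal sketch of this indexing, assume a hypothetical element-wise term in which each x[n] is normally distributed around its own location mu[n] with a common scale sigma. The names partial_sum, mu, and sigma are illustrative and not part of the interface, and the array syntax follows the rest of this section. Because x_slice is indexed from 1 while the shared vector mu covers the full data, element i of the slice corresponds to term start + i - 1:

```stan
functions {
  // Partial sum over terms start..end of a hypothetical normal log-likelihood.
  real partial_sum(real[] x_slice,
                   int start, int end,
                   vector mu,
                   real sigma) {
    real total = 0;
    for (i in 1:size(x_slice)) {
      // x_slice[i] is x[start + i - 1] in the full array,
      // so the shared argument mu is indexed with the same offset
      total += normal_lpdf(x_slice[i] | mu[start + i - 1], sigma);
    }
    return total;
  }
}
```

The same partial sum can be written without the loop as normal_lpdf(x_slice | mu[start:end], sigma); the loop form is used here only to make the index offset explicit.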

A reduce_sum (or reduce_sum_static) call:

real sum = reduce_sum(f, x, grainsize, s1, s2, ...);

can be replaced by either:

real sum = f(x, 1, size(x), s1, s2, ...);

or the code:

real sum = 0.0;
for(i in 1:size(x)) {
  sum += f({ x[i] }, i, i, s1, s2, ...);
}

24.1.1 Example: logistic regression

Logistic regression is a useful example to clarify both the syntax and semantics of reduce summation and how it can be used to speed up a typical model. A basic logistic regression can be coded in Stan as:

data {
  int N;
  int y[N];
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  y ~ bernoulli_logit(beta[1] + beta[2] * x);
}

In this model, predictions are made about the N outputs y using the covariate x. The intercept and slope of the linear equation are to be estimated. The key to getting this calculation to use reduce summation is recognizing that the statement:

y ~ bernoulli_logit(beta[1] + beta[2] * x);

can be rewritten (up to a proportionality constant) as:

for(n in 1:N) {
  target += bernoulli_logit_lpmf(y[n] | beta[1] + beta[2] * x[n]);
}

Now it is clear that the calculation is the sum of a number of conditionally independent Bernoulli log probability statements, which is the condition where reduce summation is useful. To use the reduce summation, a function must be written that can be used to compute arbitrary partial sums of the total sum. Using the interface defined in Reduce-Sum, such a function can be written like:

functions {
  real partial_sum(int[] y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}

The likelihood statement in the model can now be written:

target += partial_sum(y, 1, N, x, beta); // Sum terms 1 to N of the likelihood

In this example, y was chosen to be sliced over because there is one term in the summation per value of y. Technically x would have worked as well. Use whatever conceptually makes the most sense for a given model, e.g. slice over independent terms like conditionally independent observations or groups of observations as in hierarchical models. Because x is a shared argument, it is subset accordingly with start:end. With this function, reduce summation can be used to automatically parallelize the likelihood:

int grainsize = 1;
target += reduce_sum(partial_sum, y,
                     grainsize,
                     x, beta);

The reduce summation facility automatically breaks the sum into pieces and computes them in parallel. grainsize = 1 specifies that the grainsize should be estimated automatically. The final model is:

functions {
  real partial_sum(int[] y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  int y[N];
  vector[N] x;
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       x, beta);
}
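The remark that x could have been sliced instead of y can be sketched as follows. Because the sliced argument must be an array (T[] x in the signatures above), this variant assumes x is declared as a real array rather than a vector, and the name partial_sum_x is hypothetical. Here y becomes a shared argument and is subset with start:end:

```stan
functions {
  real partial_sum_x(real[] x_slice,
                     int start, int end,
                     int[] y,
                     vector beta) {
    // y is shared, so select the outcomes that match this slice of x
    return bernoulli_logit_lpmf(y[start:end] |
                                beta[1] + beta[2] * to_vector(x_slice));
  }
}
data {
  int N;
  int y[N];
  real x[N]; // assumption: an array instead of a vector so it can be sliced
}
parameters {
  vector[2] beta;
}
model {
  int grainsize = 1;
  beta ~ std_normal();
  target += reduce_sum(partial_sum_x, x,
                       grainsize,
                       y, beta);
}
```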

24.1.2 Picking the grainsize

The rationale for choosing a sensible grainsize is to balance the overhead of creating many small tasks against the limited parallelism that results from creating fewer large tasks.

In reduce_sum, grainsize is a recommendation on how to partition the work in the partial sum into smaller pieces. A grainsize of 1 leaves this entirely up to the internal scheduler and should be chosen if no benchmarking of other grainsizes is done. Ideally this will be efficient, but there are no guarantees.

In reduce_sum_static, grainsize is an upper limit on the worksize. Work will be split until all partial sums are just smaller than grainsize (and the split will happen the same way every time for the same inputs). For the static version it is more important to select a sensible grainsize.

To find a good grainsize when there are N terms and M cores, run a quick test model with grainsize set roughly to N / M. Record the time, cut the grainsize in half, and run the test again. Repeat this until the model runtime begins to increase. The grainsize used just before the runtime increased is a suitable choice for the model, because it allows the calculations to be carried out with the most parallelism without losing too much efficiency.

For instance, in a model with N = 10000 and M = 4, start with grainsize = 2500, and sequentially try grainsize = 1250, grainsize = 625, etc.

It is important to repeat this process until performance gets worse. It is possible that several halvings in a row make no difference, yet a still smaller grainsize performs better. Even if a sum has many tens of thousands of terms, depending on the internal calculations, a grainsize of thirty or forty or smaller might be the best, and it is difficult to predict this behavior. Without continuing the halvings until performance actually gets worse, it is easy to miss this.
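One way to run these timing tests without recompiling the model for each candidate value is to pass grainsize in as data. The following is a sketch based on the logistic regression example above; declaring grainsize in the data block is an assumption of this sketch, not a requirement of reduce_sum:

```stan
functions {
  real partial_sum(int[] y_slice,
                   int start, int end,
                   vector x,
                   vector beta) {
    return bernoulli_logit_lpmf(y_slice | beta[1] + beta[2] * x[start:end]);
  }
}
data {
  int N;
  int y[N];
  vector[N] x;
  int<lower=1> grainsize; // supplied per run: start near N / M, then halve
}
parameters {
  vector[2] beta;
}
model {
  beta ~ std_normal();
  target += reduce_sum(partial_sum, y,
                       grainsize,
                       x, beta);
}
```

Because reduce_sum_static takes the same arguments, replacing reduce_sum with reduce_sum_static in the last statement allows the same experiment to be run for the static variant.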