10.7 Unit simplex
Variables constrained to the unit simplex show up in multivariate discrete models both as parameters (of categorical and multinomial distributions) and as variates generated by their priors (Dirichlet and multivariate logistic).
The unit \(K\)-simplex is the set of points \(x \in \mathbb{R}^K\) such that for \(1 \leq k \leq K\),
\[ x_k \geq 0, \]
and
\[ \sum_{k=1}^K x_k = 1. \]
An alternative definition is to take the convex hull of the vertices, which are the standard basis vectors. For instance, in two dimensions the simplex vertices are \((0,1)\) and \((1,0)\), and the unit 2-simplex is the line segment connecting these two points; values such as \((0.3,0.7)\) and \((0.99,0.01)\) lie on that segment. In three dimensions, the vertices are \((0,0,1)\), \((0,1,0)\), and \((1,0,0)\), and the unit 3-simplex is the boundary and interior of the triangle with these vertices. Points in the 3-simplex include \((0.5,0.5,0)\), \((0.2,0.7,0.1)\), and all other triplets of non-negative values summing to 1.
As these examples illustrate, the \(K\)-simplex is a \((K-1)\)-dimensional subset of \(\mathbb{R}^K\): it lies in the hyperplane on which the coordinates sum to 1. Therefore a point \(x\) in the \(K\)-simplex is fully determined by its first \(K-1\) elements \(x_1, x_2, \ldots, x_{K-1}\), with
\[ x_K = 1 - \sum_{k=1}^{K-1} x_k. \]
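The definition and the \(K-1\) parameterization can be checked with a minimal Python sketch. Here `is_simplex` tests membership in the closed simplex (matching the convex-hull description, which includes boundary points such as \((0.5,0.5,0)\)), and `complete_simplex` recovers \(x_K\) from the first \(K-1\) coordinates. Both function names and the floating-point tolerance `tol` are illustrative assumptions, not part of Stan.

```python
import math

def is_simplex(x, tol=1e-9):
    """True if x is in the closed unit simplex: every x_k >= 0 and sum(x) == 1,
    both checked up to the (assumed) floating-point tolerance `tol`."""
    if len(x) == 0:
        return False
    if any(v < -tol for v in x):
        return False
    return math.isclose(math.fsum(x), 1.0, rel_tol=0.0, abs_tol=tol)

def complete_simplex(head):
    """Rebuild a K-simplex point from its first K-1 coordinates using
    x_K = 1 - sum_{k<K} x_k."""
    last = 1.0 - math.fsum(head)
    if any(v < 0.0 for v in head) or last < 0.0:
        raise ValueError("coordinates must be non-negative and sum to at most 1")
    return list(head) + [last]

print(is_simplex([0.3, 0.7]))        # True
print(is_simplex([0.5, 0.5, 0.0]))   # True
print(is_simplex([0.5, 0.6]))        # False (sums to 1.1)
x = complete_simplex([0.2, 0.7])     # last entry is 0.1 up to rounding
print(is_simplex(x))                 # True
```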
Unit simplex inverse transform
Stan’s unit simplex inverse transform may be understood using the following stick-breaking metaphor.[14]
- Take a stick of unit length (i.e., length 1).
- Break a piece off, label it \(x_1\), and set it aside, keeping what’s left.
- Next, break a piece off what’s left, label it \(x_2\), and set it aside, keeping what’s left.
- Continue breaking off pieces of what’s left, labeling them, and setting them aside for pieces \(x_3,\ldots,x_{K-1}\).
- Label what’s left \(x_K\).
The resulting vector \(x = [x_1,\ldots,x_{K}]^{\top}\) is a unit simplex because each piece has non-negative length and the sum of the stick lengths is one by construction.
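The metaphor can be sketched directly in Python, assuming (as the formal definition of the break proportions \(z_k\) in this section does) that each break is given as a fraction of the stick remaining at that step. `break_stick` is a hypothetical helper for illustration, not a Stan function.

```python
def break_stick(fractions):
    """Stick-breaking metaphor: each fraction in (0, 1) is the share of the
    *remaining* stick broken off at that step; the leftover piece is x_K."""
    remaining = 1.0
    pieces = []
    for frac in fractions:
        if not 0.0 < frac < 1.0:
            raise ValueError("each break fraction must lie strictly in (0, 1)")
        piece = remaining * frac   # break a piece off what is left
        pieces.append(piece)       # set it aside
        remaining -= piece         # keep what is left
    pieces.append(remaining)       # label what is left x_K
    return pieces

print(break_stick([0.5, 0.5]))  # [0.5, 0.25, 0.25]
```

Because every piece is a positive fraction of a positive remainder, and the last piece is whatever is left, the output is positive and sums to one by construction.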
This full inverse mapping requires each break to be represented as the fraction in \((0,1)\) of the remaining stick that is broken off. These break ratios are themselves derived from unconstrained values in \((-\infty,\infty)\) using the inverse logit transform, as described above for unidimensional variables with lower and upper bounds.
More formally, an intermediate vector \(z \in (0,1)^{K-1}\), whose coordinates \(z_k\) represent the proportion of the remaining stick broken off in step \(k\), is defined elementwise for \(1 \leq k < K\) by
\[ z_k = \mathrm{logit}^{-1} \left( y_k + \log \left( \frac{1}{K - k} \right) \right). \]
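The offset \(\log\left(\frac{1}{K-k}\right)\) (i.e., \(\mathrm{logit}\left(\frac{1}{K-k+1}\right)\)) in the above definition adjusts the transform so that a zero vector \(y\) is mapped to the simplex point \(x = (1/K,\ldots,1/K)\). For instance, if \(y_1 = 0\), then \(z_1 = 1/K\); if \(y_2 = 0\), then \(z_2 = 1/(K-1)\); and if \(y_{K-1} = 0\), then \(z_{K-1} = 1/2\). In general, \(y_k = 0\) gives \(z_k = 1/(K-k+1)\); when all of \(y\) is zero, the stick remaining at step \(k\) has length \((K-k+1)/K\), so every piece, including the last, has length \(1/K\).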
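A minimal Python sketch of the map from \(y\) to the break proportions \(z\), using the offset \(\log(1/(K-k))\) with \(k\) running from 1 to \(K-1\) as in the text. The names `inv_logit` and `break_proportions` are hypothetical; this is not Stan's implementation.

```python
import math

def inv_logit(u):
    """logit^{-1}(u) = 1 / (1 + exp(-u)), arranged to avoid overflow in exp."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def break_proportions(y):
    """Map unconstrained y (length K-1) to break proportions z_k in (0, 1).
    The offset log(1 / (K - k)) makes y_k = 0 give z_k = 1 / (K - k + 1)."""
    K = len(y) + 1
    return [inv_logit(y_k + math.log(1.0 / (K - k)))
            for k, y_k in enumerate(y, start=1)]

z = break_proportions([0.0, 0.0, 0.0])   # K = 4
print(z)  # approximately [1/4, 1/3, 1/2]
```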
The break proportions \(z\) are applied to determine the stick sizes and resulting value of \(x_k\) for \(1 \leq k < K\) by
\[ x_k = \left( 1 - \sum_{k'=1}^{k-1} x_{k'} \right) z_k. \]
The term in parentheses is the length of the original stick remaining at stage \(k\); multiplying it by the break proportion \(z_k\) yields \(x_k\). Only \(K-1\) unconstrained parameters are required, with the last dimension’s value \(x_K\) set to the length of the remaining piece of the original stick,
\[ x_K = 1 - \sum_{k=1}^{K-1} x_k. \]
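Putting the pieces together, the full inverse transform from \(y \in \mathbb{R}^{K-1}\) to the \(K\)-simplex can be sketched in Python as below. This is an illustrative sketch of the equations above under hypothetical names, not Stan's C++ code; it carries the remaining stick length in a running variable rather than recomputing the sum at every step.

```python
import math

def inv_logit(u):
    """Numerically stable logit^{-1}(u)."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def simplex_inverse_transform(y):
    """Map unconstrained y in R^{K-1} to a point x on the unit K-simplex."""
    K = len(y) + 1
    x = []
    remaining = 1.0                      # 1 - sum_{k' < k} x_{k'}
    for k, y_k in enumerate(y, start=1):
        z_k = inv_logit(y_k + math.log(1.0 / (K - k)))   # break proportion
        x_k = remaining * z_k                           # piece broken off
        x.append(x_k)
        remaining -= x_k                                # stick left over
    x.append(remaining)                  # x_K is the leftover stick
    return x

print(simplex_inverse_transform([0.0, 0.0, 0.0]))  # approximately [0.25, 0.25, 0.25, 0.25]
print(simplex_inverse_transform([2.0, -1.0, 0.5]))
```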
Absolute Jacobian determinant of the unit-simplex inverse transform
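The Jacobian \(J\) of the inverse transform \(f^{-1}\), viewed as a map from \(y \in \mathbb{R}^{K-1}\) to \((x_1,\ldots,x_{K-1})\), is lower-triangular because \(x_k\) depends only on \(y_1,\ldots,y_k\). Its diagonal entries are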
\[ J_{k,k} = \frac{\partial x_k}{\partial y_k} = \frac{\partial x_k}{\partial z_k} \, \frac{\partial z_k}{\partial y_k}, \]
where
\[ \frac{\partial z_k}{\partial y_k} = \frac{\partial}{\partial y_k} \mathrm{logit}^{-1} \left( y_k + \log \left( \frac{1}{K-k} \right) \right) = z_k (1 - z_k), \]
and
\[ \frac{\partial x_k}{\partial z_k} = \left( 1 - \sum_{k' = 1}^{k-1} x_{k'} \right) . \]
This definition is recursive, defining \(x_k\) in terms of \(x_{1},\ldots,x_{k-1}\).
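The structure of \(J\) can be checked numerically. The Python sketch below, using hypothetical helper names, estimates \(\partial x_i / \partial y_j\) by central finite differences and compares the diagonal with the analytic entries \(z_k (1 - z_k) \left(1 - \sum_{k'<k} x_{k'}\right)\). The step size `h` is an assumption chosen for double precision.

```python
import math

def inv_logit(u):
    """Numerically stable logit^{-1}(u)."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def inverse_transform_head(y):
    """Return (x_1..x_{K-1}, z) from unconstrained y; x_K is omitted because
    the Jacobian is taken with respect to the first K-1 coordinates only."""
    K = len(y) + 1
    x, z = [], []
    remaining = 1.0
    for k, y_k in enumerate(y, start=1):
        z_k = inv_logit(y_k + math.log(1.0 / (K - k)))
        x_k = remaining * z_k
        x.append(x_k)
        z.append(z_k)
        remaining -= x_k
    return x, z

def analytic_diagonal(y):
    """J_kk = z_k (1 - z_k) (1 - sum_{k'<k} x_{k'})."""
    x, z = inverse_transform_head(y)
    diag, remaining = [], 1.0
    for x_k, z_k in zip(x, z):
        diag.append(z_k * (1.0 - z_k) * remaining)
        remaining -= x_k
    return diag

def finite_difference_jacobian(y, h=1e-6):
    """Central-difference estimate of jac[i][j] = d x_i / d y_j."""
    n = len(y)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        y_plus = list(y)
        y_minus = list(y)
        y_plus[j] += h
        y_minus[j] -= h
        x_plus, _ = inverse_transform_head(y_plus)
        x_minus, _ = inverse_transform_head(y_minus)
        for i in range(n):
            jac[i][j] = (x_plus[i] - x_minus[i]) / (2.0 * h)
    return jac

y = [0.3, -1.2, 0.8]
jac = finite_difference_jacobian(y)
print([jac[i][i] for i in range(len(y))])  # numerical diagonal
print(analytic_diagonal(y))                 # analytic diagonal; should agree closely
# 0.0: x_i does not depend on y_j for j > i, so the upper triangle vanishes
print(max(abs(jac[i][j]) for i in range(len(y)) for j in range(i + 1, len(y))))
```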
Because the Jacobian \(J\) of \(f^{-1}\) is lower triangular with positive diagonal entries, its absolute determinant reduces to the product of those entries,
\[ \left| \, \det J \, \right| \ = \ \prod_{k=1}^{K-1} J_{k,k} \ = \ \prod_{k=1}^{K-1} z_k \, (1 - z_k) \ \left( 1 - \sum_{k'=1}^{k-1} x_{k'} \right) . \]
Thus the transformed variable \(Y = f(X)\) has a density given by
\[ p_Y(y) = p_X(f^{-1}(y)) \, \prod_{k=1}^{K-1} z_k \, (1 - z_k) \ \left( 1 - \sum_{k'=1}^{k-1} x_{k'} \right) . \]
Although it is expressed in terms of the intermediate values \(z_k\), this expression is simpler to compute than it looks. The exponential function need only be evaluated once for each unconstrained parameter \(y_k\); everything else is basic arithmetic that can be computed incrementally along with the transform.
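The incremental computation can be sketched in Python: one pass over \(y\) produces \(x\) and accumulates \(\log |\det J|\), calling `math.exp` once per \(y_k\) inside the inverse logit. The function names and the `log_p_x` callback (assumed to return a log density with respect to \(x_1,\ldots,x_{K-1}\)) are hypothetical. For extreme \(y_k\), \(z_k(1-z_k)\) can underflow, so a production implementation would work on the log scale throughout.

```python
import math

def simplex_transform_with_log_jacobian(y):
    """Inverse transform y -> x together with log |det J|, accumulated in one pass."""
    K = len(y) + 1
    x = []
    remaining = 1.0          # stick length left before step k
    log_jac = 0.0
    for k, y_k in enumerate(y, start=1):
        u = y_k + math.log(1.0 / (K - k))
        if u >= 0:                       # numerically stable inverse logit,
            z_k = 1.0 / (1.0 + math.exp(-u))   # one exp per y_k
        else:
            e = math.exp(u)
            z_k = e / (1.0 + e)
        # log of the diagonal entry z_k (1 - z_k) (1 - sum_{k'<k} x_{k'})
        log_jac += math.log(z_k * (1.0 - z_k) * remaining)
        x_k = remaining * z_k
        x.append(x_k)
        remaining -= x_k
    x.append(remaining)
    return x, log_jac

def log_density_unconstrained(y, log_p_x):
    """log p_Y(y) = log p_X(f^{-1}(y)) + log |det J|."""
    x, log_jac = simplex_transform_with_log_jacobian(y)
    return log_p_x(x) + log_jac

# K = 2 with a flat density on the simplex (log p_X = 0):
print(log_density_unconstrained([0.0], lambda x: 0.0))  # log(0.25)
```

For \(K = 2\) and a uniform \(p_X\), the result is the log density of the standard logistic distribution, \(\log\left(\mathrm{logit}^{-1}(y)\,(1 - \mathrm{logit}^{-1}(y))\right)\), which is what the change of variables predicts for \(y = \mathrm{logit}(x_1)\) with \(x_1\) uniform on \((0,1)\).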
Unit simplex transform
The transform \(Y = f(X)\) can be derived by reversing the stages of the inverse transform. Working backwards, given the break proportions \(z\), \(y\) is defined elementwise by
\[ y_k = \mathrm{logit}(z_k) - \log\left( \frac{1}{K-k} \right) . \]
The break proportions \(z_k\) are defined to be the ratio of \(x_k\) to the length of stick left after the first \(k-1\) pieces have been broken off,
\[ z_k = \frac{x_k} {1 - \sum_{k' = 1}^{k-1} x_{k'}} . \]
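A minimal Python sketch of the forward transform, with a round-trip check against the inverse transform. The function names are hypothetical, and the sketch assumes every \(x_k\) is strictly positive so that each \(z_k\) lies strictly inside \((0,1)\). Only \(x_1,\ldots,x_{K-1}\) are read; \(x_K\) is implied by the sum-to-one constraint.

```python
import math

def logit(p):
    """logit(p) = log(p / (1 - p)) for p in (0, 1)."""
    return math.log(p / (1.0 - p))

def inv_logit(u):
    """Numerically stable logit^{-1}(u)."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def simplex_transform(x):
    """Map x on the unit K-simplex (entries strictly positive) to y in R^{K-1}."""
    K = len(x)
    y = []
    remaining = 1.0                      # 1 - sum_{k' < k} x_{k'}
    for k in range(1, K):
        z_k = x[k - 1] / remaining       # break proportion z_k
        y.append(logit(z_k) - math.log(1.0 / (K - k)))
        remaining -= x[k - 1]
    return y

def simplex_inverse_transform(y):
    """Map y in R^{K-1} back to the unit K-simplex by stick-breaking."""
    K = len(y) + 1
    x = []
    remaining = 1.0
    for k, y_k in enumerate(y, start=1):
        x_k = remaining * inv_logit(y_k + math.log(1.0 / (K - k)))
        x.append(x_k)
        remaining -= x_k
    x.append(remaining)
    return x

x = [0.1, 0.2, 0.3, 0.4]
y = simplex_transform(x)
print(y)
print(simplex_inverse_transform(y))   # recovers x up to rounding
```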
References
[14] For an alternative derivation of the same transform using hyperspherical coordinates, see Betancourt (2010).