Mixture modeling is a powerful technique for integrating multiple data generating processes into a single model. Unfortunately, when those data generating processes are degenerate the resulting mixture model suffers from inherent combinatorial non-identifiabilities that frustrate accurate computation. Consequently, in order to utilize mixture models reliably in practice we need strong and principled prior information to ameliorate these frustrations.
In this case study I will first introduce how mixture models are implemented in Bayesian inference. I will then discuss the non-identifiability inherent to that construction as well as how the non-identifiability can be tempered with principled prior information. Lastly I will demonstrate how these issues manifest in a simple example, with a final tangent to consider an additional pathology that can arise in Bayesian mixture models.
In a mixture model we assume that a given measurement, \(y\), can be drawn from one of \(K\) data generating processes, \(\pi_{k} ( y \mid \alpha_{k} )\), each with its own set of parameters, \(\alpha_{k}\). To implement such a model we need to construct the corresponding likelihood and then the subsequent posterior distribution.
Let \(z \in \{1 ,\ldots, K\}\) be an assignment that indicates from which data generating process our measurement was generated. Conditioned on this assignment, the mixture likelihood is just \[ \pi (y \mid \boldsymbol{\alpha}, z) = \pi_{z} ( y \mid \alpha_{z} ), \] where \(\boldsymbol{\alpha} = (\alpha_1, \ldots, \alpha_K)\).
By combining assignments with a set of data generating processes we admit an extremely expressive class of models that encompass many different inferential and decision problems. For example, if multiple measurements \(y_n\) are given but the corresponding assignments \(z_n\) are unknown then inference over the mixture model is equivalent to clustering the measurements across the component data generating processes. Similarly, if both the measurements and the assignments are given then inference over the mixture model admits classification of future measurements. Finally, semi-supervised learning corresponds to inference over a mixture model where only some of the assignments are known.
In practice discrete assignments are difficult to fit accurately and efficiently, but we can facilitate inference by marginalizing the assignments out of the model entirely. If each component in the mixture occurs with probability \(\theta_k\), \[ \boldsymbol{\theta} = (\theta_1, \ldots, \theta_K), \, 0 \le \theta_{k} \le 1, \, \sum_{k = 1}^{K} \theta_{k} = 1, \] then the assignments follow a multinomial distribution, \[ \pi (z \mid \boldsymbol{\theta} ) = \theta_{z}, \] and the joint likelihood over the measurement and its assignment is given by \[ \pi (y, z \mid \boldsymbol{\alpha}, \boldsymbol{\theta}) = \pi (y \mid \boldsymbol{\alpha}, z) \, \pi (z \mid \boldsymbol{\theta} ) = \pi_{z} ( y \mid \alpha_{z} ) \, \theta_z. \] Marginalizing over all of the possible assignments then gives \[ \begin{align*} \pi (y \mid \boldsymbol{\alpha}, \boldsymbol{\theta}) &= \sum_{z} \pi (y, z \mid \boldsymbol{\alpha}, \boldsymbol{\theta}) \\ &= \sum_{z} \pi_{z} ( y \mid \alpha_{z} ) \, \theta_z \\ &= \sum_{k = 1}^{K} \pi_{k} ( y \mid \alpha_{k} ) \, \theta_k \\ &= \sum_{k = 1}^{K} \theta_k \, \pi_{k} ( y \mid \alpha_{k} ). \end{align*} \] In words, after marginalizing out the assignments the mixture likelihood reduces to a convex combination of the component data generating processes.
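As a quick illustration, here is a minimal R sketch of this marginalized mixture density, assuming hypothetical Gaussian components and working on the log scale with a log-sum-exp for numerical stability,

log_mix_lpdf <- function(y, theta, mu, sigma) {
  # log of each weighted component density, log(theta_k) + log normal(y | mu_k, sigma_k)
  lp <- log(theta) + dnorm(y, mu, sigma, log = TRUE)
  # log-sum-exp marginalizes over the discrete assignment
  m <- max(lp)
  m + log(sum(exp(lp - m)))
}

log_mix_lpdf(1.2, theta = c(0.6, 0.4), mu = c(-2.75, 2.75), sigma = c(1, 1))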
Marginalizing out the discrete assignments yields a likelihood that depends on only continuous parameters, making it amenable to state-of-the-art tools like Stan. Moreover, modeling the latent mixture probabilities instead of the discrete assignments admits more precise inferences as a consequence of the Rao-Blackwell theorem. From any perspective the marginalized mixture likelihood is the ideal basis for inference.
In order to perform Bayesian inference over a mixture model we need to complement the mixture likelihood with prior distributions for both the component parameters, \(\boldsymbol{\alpha}\), and the mixture probabilities, \(\boldsymbol{\theta}\). Assuming that these distributions are independent a priori, \[ \pi(\boldsymbol{\alpha}, \boldsymbol{\theta}) = \pi(\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}), \] the subsequent posterior for a single measurement takes the form \[ \pi(\boldsymbol{\alpha}, \boldsymbol{\theta} \mid y) \propto \pi(\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \sum_{k = 1}^{K} \theta_k \, \pi_{k} ( y \mid \alpha_k ). \]
Similarly, the posterior for multiple measurements becomes \[ \pi(\boldsymbol{\alpha}, \boldsymbol{\theta} \mid \mathbf{y}) \propto \pi(\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \prod_{n = 1}^{N} \sum_{k = 1}^{K} \theta_k \, \pi_{k} ( y_n \mid \alpha_k ). \] Additional measurements, however, do not impact the non-identifiability inherent to mixture models. Consequently we will consider only a single measurement in the following section, returning to multiple measurements in the example.
When making inferences with a mixture model we need to learn each of the component weights, \(\theta_k\), and the component parameters, \(\alpha_k\). This introduces a subtle challenge because if the measurement cannot discriminate between the components then it cannot discriminate between the component parameters.
If the individual component distributions \(\pi_{k} (y \mid \alpha_{k})\) are distinct then the unique characteristics of each can be enough to inform the corresponding parameters individually and the mixture model is straightforward to fit. Circumstances become much more dire, however, in the degenerate case when the components are identical, \(\pi_{k} (y \mid \alpha_{k}) = \pi (y \mid \alpha_{k})\). In this case there is a fundamental ambiguity as to which parameters \(\alpha_{k}\) are associated with each component in the mixture.
To see this, let \(\sigma\) denote a permutation of the indices in our mixture, \[ \sigma (1, \ldots, K) \mapsto ( \sigma(1), \ldots, \sigma(K)), \] with \[ \sigma (\boldsymbol{\alpha}) = \sigma( \alpha_1, \ldots, \alpha_K) \mapsto ( \alpha_{\sigma(1)}, \ldots, \alpha_{\sigma(K)}). \] When the component distributions are identical the mixture likelihood is invariant to any permutation of the indices, \[ \begin{align*} \pi(y \mid \sigma(\boldsymbol{\alpha}), \sigma(\boldsymbol{\theta})) &= \sum_{k = 1}^{K} \theta_{\sigma(k)} \, \pi_{\sigma(k)} ( y \mid \alpha_{\sigma(k)} ) \\ &= \sum_{k' = 1}^{K} \theta_{k'} \, \pi_{k'} ( y \mid \alpha_{k'} ) \\ &= \pi(y \mid \boldsymbol{\alpha}, \boldsymbol{\theta}). \end{align*} \]
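We can also check this invariance numerically; a small R sketch, assuming identical Gaussian components with unit standard deviation, confirms that permuting the component indices leaves the mixture density unchanged,

mix_pdf <- function(y, theta, alpha) sum(theta * dnorm(y, alpha, 1))

theta <- c(0.3, 0.7)
alpha <- c(-1, 2)
perm <- c(2, 1)  # a permutation of the component indices

all.equal(mix_pdf(0.5, theta, alpha), mix_pdf(0.5, theta[perm], alpha[perm]))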
Moreover, when the priors are exchangeable, \(\pi (\sigma(\boldsymbol{\alpha})) = \pi(\boldsymbol{\alpha})\) and \(\pi (\sigma(\boldsymbol{\theta})) = \pi(\boldsymbol{\theta})\), then the posterior will inherit the permutation invariance of the mixture likelihood, \[ \begin{align*} \pi(\sigma(\boldsymbol{\alpha}), \sigma(\boldsymbol{\theta}) \mid y) &\propto \pi(\sigma(\boldsymbol{\alpha})) \, \pi(\sigma(\boldsymbol{\theta})) \sum_{k = 1}^{K} \theta_{\sigma(k)} \, \pi_{\sigma(k)} ( y \mid \alpha_{\sigma(k)} ) \\ &\propto \pi(\sigma(\boldsymbol{\alpha})) \, \pi(\sigma(\boldsymbol{\theta})) \sum_{k' = 1}^{K} \theta_{k'} \, \pi_{k'} ( y \mid \alpha_{k'} ) \\ &\propto \pi(\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \sum_{k' = 1}^{K} \theta_{k'} \, \pi_{k'} ( y \mid \alpha_{k'} ) \\ &= \pi(\boldsymbol{\alpha}, \boldsymbol{\theta} \mid y). \end{align*} \] In this case all of our inferences will be the same regardless of how we label the mixture components with explicit indices.
Because of this labeling degeneracy the posterior distribution will be non-identified. In particular, it will manifest multimodality, with one mode for each of the possible labelings. For a mixture with \(K\) identical components there are \(K!\) possible labelings and hence any degenerate mixture model will exhibit at least \(K!\) modes.
Hence even for a relatively small number of components the posterior distribution will have too many modes for any statistical algorithm to quantify accurately unless the modes collapse into each other. For example, if we applied Markov chain Monte Carlo then any chain would be able to explore one of the modes but it would not be able to transition between the modes, at least not within any finite running time.
Even if we had a statistical algorithm that could transition between the degenerate modes and explore the entire mixture posterior, typically there will be too many modes to complete that exploration in any reasonable time. Consequently if we want to accurately fit these models in practice then we need to break the labeling degeneracy and remove the extreme multimodality altogether.
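To see just how quickly this becomes hopeless, we can tabulate the factorial growth of the number of degenerate modes in R,

# number of degenerate modes for K identical components
data.frame(K = 1:8, modes = factorial(1:8))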
Exactly how we break the labeling degeneracy depends on what prior information we can exploit. In particular, our strategy will be different depending on whether our prior information is exchangeable or not.
Because the posterior distribution inherits the permutation-invariance of the mixture likelihood only if the priors are exchangeable, one way to immediately obstruct the labeling degeneracy of the mixture posterior is to employ non-exchangeable priors. This approach is especially useful when each component of the likelihood is meant to be responsible for a specific purpose, for example when each component models a known subpopulation with distinct behaviors about which we have prior information. If this principled prior information is strong enough then the prior can suppress all but the one labeling consistent with these responsibilities, ensuring a unimodal mixture posterior distribution.
When our prior information is exchangeable there is nothing preventing the mixture posterior from becoming multimodal and impractical to fit. When our inferences are also exchangeable, however, we can exploit the symmetry of the labeling degeneracies to simplify the computational problem dramatically.
In this section we’ll study the symmetric geometry induced by labeling degeneracies and show how that symmetry can be used to reduce the multimodal mixture posterior into a unimodal distribution that yields exactly the same inferences while being far easier to fit.
Each labeling is characterized by the unique assignment of indices to the components in our mixture. Permuting the indices yields a new assignment and hence a new labeling of our mixture model. Consequently a natural way to identify each labeling is to choose a standard set of indices, \(\alpha_1, \ldots, \alpha_K\), and distinguish each labeling by the permutation that maps to the appropriate indices, \(\alpha_{\sigma(1)}, \ldots, \alpha_{\sigma(K)}\). The standard indices themselves identify a labeling with the trivial permutation that leaves the indices unchanged.
In general it is difficult to utilize these permutations, but if the component parameters, \(\alpha_k\), are scalar then we can exploit their unique ordering to readily identify permutations and hence labelings. For example, if we choose the standard labeling to be the one where the parameter values are ordered, \(\alpha_1 \le \ldots \le \alpha_K\), then any permutation will yield a new ordering of the parameters, \(\alpha_{\sigma(1)} \le \ldots \le \alpha_{\sigma(K)}\), which then identifies another labeling. In other words, we can identify each labeling by the ordering of its parameter values.
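In R, for example, sorting maps any labeling of scalar component parameters back to the standard ordered labeling, with order() recovering the corresponding permutation; a small sketch with hypothetical values,

alpha <- c(2.75, -2.75, 0.5)  # one labeling of K = 3 scalar component parameters
perm <- order(alpha)          # permutation mapping this labeling to the standard one
alpha[perm]                   # the standard, ordered labeling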
This identification also has a welcome geometric interpretation. The region of parameter space satisfying a given ordering constraint, such as \(\alpha_1 \le \ldots \le \alpha_K\), defines a square pyramid with the apex point at zero. The \(K\)-dimensional parameter space neatly decomposes into \(K!\) of these pyramids, each with a distinct ordering and hence association with a unique labeling.
When the priors are exchangeable the mixture posterior aliases across each of these pyramids in parameter space: if we were given the mixture posterior restricted to one of these pyramids then we could reconstruct the entire mixture distribution by simply rotating that restricted distribution into each of the other \(K! - 1\) pyramids. As we do this we also map the mode in the restricted distribution into each pyramid, creating exactly the expected \(K!\)-fold multimodality. Moreover, those rotations are exactly given by permuting the parameter indices and reordering the corresponding parameter values.
From a Bayesian perspective, all well-defined inferences are given by expectations of certain functions with respect to our posterior distribution.
Hence if we want to limit ourselves to only those inferences insensitive to the labeling then we have to consider expectations only of those functions that are permutation invariant, \(f(\sigma(\boldsymbol{\alpha})) = f(\boldsymbol{\alpha})\).
Importantly, under this class of functions the symmetry of the degenerate mixture posterior carries over to the expectation values themselves: the expectation taken over each pyramid will yield exactly the same value.
Consequently we should be able to avoid the combinatorial cost of fitting the full mixture model by simply restricting our exploration to a single ordering of the parameters.
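As a concrete check, a permutation-invariant summary, such as the largest component parameter, evaluates identically under any relabeling, so its expectation cannot depend on which pyramid we explore; a small R sketch with hypothetical values,

f <- function(alpha) max(alpha)     # a permutation-invariant function
alpha <- c(-2.75, 2.75)
all.equal(f(alpha), f(rev(alpha)))  # relabeling leaves f unchanged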
Imposing an ordering on parameter space can be taken as a computational trick, but it can also be interpreted as a method of making an exchangeable prior non-exchangeable without affecting the resulting inferences. Given the exchangeable prior \(\pi (\boldsymbol{\alpha})\) we define the non-exchangeable prior \[ \pi' (\boldsymbol{\alpha}) = \left\{ \begin{array}{ll} \pi (\boldsymbol{\alpha}), & \alpha_1 \le \ldots \le \alpha_K \\ 0, & \mathrm{else} \end{array} \right. , \] which limits the mixture posterior, and hence any expectations, to a single ordering. From this perspective all of our strategies for breaking the labeling degeneracy reduce to imposing a non-exchangeable prior.
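A minimal R sketch makes this truncated prior explicit, here assuming for illustration an exchangeable prior of independent unit normals over the component parameters; the \(K!\) factor simply restores the normalization lost to the truncation,

ordered_prior_lpdf <- function(alpha, exch_lpdf) {
  # zero density outside the standard ordering alpha_1 <= ... <= alpha_K
  if (is.unsorted(alpha)) return(-Inf)
  # inside the ordering, K! times the original exchangeable prior density
  lfactorial(length(alpha)) + exch_lpdf(alpha)
}

unit_normal_lpdf <- function(alpha) sum(dnorm(alpha, 0, 1, log = TRUE))
ordered_prior_lpdf(c(-1, 0, 2), unit_normal_lpdf)  # inside the standard ordering
ordered_prior_lpdf(c(2, 0, -1), unit_normal_lpdf)  # outside, so zero density (-Inf)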
In this section I will formalize the utility of ordering by proving that the resulting inferences are indeed correct. In order to outline the proof we will first consider a two-component mixture before moving on to the general case.
In the two-component case we have two parameters, \(\alpha_1\) and \(\alpha_2\), and two mixture weights, \(\theta_1\) and \(\theta_2 = 1 - \theta_1\).
We begin with the desired expectation and decompose it over the two pyramids that arise in the two-dimensional parameter space, \[ \begin{align*} \mathbb{E}_{\pi} [ f ] &= \int \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2, \theta_1 ,\theta_2 \mid y) \\ &\propto \int \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ) \\ &\propto \quad \int_{\alpha_1 < \alpha_2} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ) \\ &\quad + \int_{\alpha_2 < \alpha_1} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ). \end{align*} \]
We want to manipulate the second term into something that looks like the first, which we can accomplish with a permutation of the parameters, \((\alpha_1, \alpha_2) \rightarrow (\beta_2, \beta_1)\) and \((\theta_1, \theta_2) \rightarrow (\lambda_2, \lambda_1)\), that rotates the second pyramid into the first. This gives \[ \begin{align*} \mathbb{E}_{\pi} [ f ] &\propto \quad \int_{\alpha_1 < \alpha_2} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ) \\ &\quad + \int_{\beta_1 < \beta_2} \mathrm{d} \lambda_2 \mathrm{d} \lambda_1 \mathrm{d} \beta_2 \mathrm{d} \beta_1 \cdot f (\beta_2, \beta_1) \cdot \pi (\beta_2, \beta_1) \, \pi (\lambda_2 ,\lambda_1) \, ( \lambda_2 \pi (y \mid \beta_2) + \lambda_1 \pi (y \mid \beta_1) ) \\ &\propto \quad \int_{\alpha_1 < \alpha_2} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ) \\ &\quad + \int_{\beta_1 < \beta_2} \mathrm{d} \lambda_1 \mathrm{d} \lambda_2 \mathrm{d} \beta_1 \mathrm{d} \beta_2 \cdot f (\beta_2, \beta_1) \cdot \pi (\beta_2, \beta_1) \, \pi (\lambda_2 ,\lambda_1) \, ( \lambda_1 \pi (y \mid \beta_1) + \lambda_2 \pi (y \mid \beta_2) ). \end{align*} \]
Now we exploit the permutation-invariance of \(f\) and the exchangeability of the priors to massage the second term to be equivalent to the first, \[ \begin{align*} \mathbb{E}_{\pi} [ f ] &\propto \quad \int_{\alpha_1 < \alpha_2} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ) \\ &\quad + \int_{\beta_1 < \beta_2} \mathrm{d} \lambda_1 \mathrm{d} \lambda_2 \mathrm{d} \beta_1 \mathrm{d} \beta_2 \cdot f (\beta_1, \beta_2) \cdot \pi (\beta_1, \beta_2) \, \pi (\lambda_1 ,\lambda_2) \, ( \lambda_1 \pi (y \mid \beta_1) + \lambda_2 \pi (y \mid \beta_2) ) \\ &\propto 2 \int_{\alpha_1 < \alpha_2} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ) \\ &\propto \int_{\alpha_1 < \alpha_2} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot 2 \, \pi (\alpha_1 ,\alpha_2) \, \pi (\theta_1 ,\theta_2) \, ( \theta_1 \pi (y \mid \alpha_1) + \theta_2 \pi (y \mid \alpha_2) ) \\ &= \int_{\alpha_1 < \alpha_2} \mathrm{d} \theta_1 \mathrm{d} \theta_2 \mathrm{d} \alpha_1 \mathrm{d} \alpha_2 \cdot f (\alpha_1, \alpha_2) \cdot \pi' (\alpha_1, \alpha_2, \theta_1, \theta_2 \mid y). \end{align*} \]
\(\pi' (\alpha_1, \alpha_2, \theta_1, \theta_2 \mid y)\), however, is exactly the mixture posterior density restricted to the pyramid defined by the standard ordering, so we finally have \[ \mathbb{E}_{\pi} [ f ] = \mathbb{E}_{\pi'} [ f ]. \] In words, taking an expectation over the pyramid defined by the standard ordering yields the same value as the expectation taken over the entire parameter space. Only now the distribution over that one pyramid is no longer multimodal!
The general case follows almost exactly once we use permutations of the standard ordering to identify each of the \(K!\) pyramids. Writing \(\Sigma'\) as the set of all \(K! - 1\) label permutations except for the trivial permutation, the desired expectation decomposes as \[ \begin{align*} \mathbb{E}_{\pi} [ f ] &= \int \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}, \boldsymbol{\theta} \mid y) \\ &\propto \int \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \\ &\propto \quad \int_{\alpha_1 < \ldots < \alpha_K} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \\ &\quad + \sum_{\sigma \in \Sigma'} \int_{\alpha_{\sigma(1)} < \ldots < \alpha_{\sigma(K)}} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k). \end{align*} \]
In the permuted terms we apply the transformations \[ \sigma( \boldsymbol{\alpha} ) \mapsto \boldsymbol{\beta}, \, \sigma( \boldsymbol{\theta} ) \mapsto \boldsymbol{\lambda} \] to give \[ \begin{align*} \mathbb{E}_{\pi} [ f ] &\propto \quad \int_{\alpha_1 < \ldots < \alpha_K} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \\ &\quad + \sum_{\sigma \in \Sigma'} \int_{\beta_1 < \ldots < \beta_K} \prod_{k = 1}^{K} \mathrm{d} \lambda_{\sigma^{-1}(k)} \mathrm{d} \beta_{\sigma^{-1}(k)} \cdot f ( \sigma^{-1}(\boldsymbol{\beta}) ) \cdot \pi ( \sigma^{-1}(\boldsymbol{\beta}) ) \, \pi ( \sigma^{-1}(\boldsymbol{\lambda}) ) \, \sum_{k = 1}^{K} \lambda_{\sigma^{-1}(k)} \, \pi (y \mid \beta_{\sigma^{-1}(k)}) \\ &\propto \quad \int_{\alpha_1 < \ldots < \alpha_K} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \\ &\quad + \sum_{\sigma \in \Sigma'} \int_{\beta_1 < \ldots < \beta_K} \prod_{k' = 1}^{K} \mathrm{d} \lambda_{k'} \mathrm{d} \beta_{k'} \cdot f ( \sigma^{-1}(\boldsymbol{\beta}) ) \cdot \pi ( \sigma^{-1}(\boldsymbol{\beta}) ) \, \pi ( \sigma^{-1}(\boldsymbol{\lambda}) ) \, \sum_{k' = 1}^{K} \lambda_{k'} \, \pi (y \mid \beta_{k'}). \end{align*} \] We now exploit the permutation-invariance of \(f\) and the exchangeability of the priors to give \[ \begin{align*} \mathbb{E}_{\pi} [ f ] &\propto \quad \int_{\alpha_1 < \ldots < \alpha_K} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \\ &\quad + \sum_{\sigma \in \Sigma'} \int_{\beta_1 < \ldots < \beta_K} \prod_{k' = 1}^{K} \mathrm{d} \lambda_{k'} \mathrm{d} \beta_{k'} \cdot f ( \boldsymbol{\beta} ) \cdot \pi ( \boldsymbol{\beta} ) \, \pi ( \boldsymbol{\lambda} ) \, \sum_{k' = 1}^{K} \lambda_{k'} \, \pi (y \mid \beta_{k'}) \\ &\propto K! \int_{\alpha_1 < \ldots < \alpha_K} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \\ &\propto \int_{\alpha_1 < \ldots < \alpha_K} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot K! \, \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \\ &= \int_{\alpha_1 < \ldots < \alpha_K} \prod_{k = 1}^{K} \mathrm{d} \theta_k \mathrm{d} \alpha_k \cdot f (\boldsymbol{\alpha}) \cdot \pi' (\boldsymbol{\alpha}, \boldsymbol{\theta} \mid y). \end{align*} \]
Once again \[ \pi' (\boldsymbol{\alpha}, \boldsymbol{\theta} \mid y) \propto K! \, \pi (\boldsymbol{\alpha}) \, \pi(\boldsymbol{\theta}) \, \sum_{k = 1}^{K} \theta_k \, \pi (y \mid \alpha_k) \] is exactly the mixture posterior density confined to the pyramid given by the standard ordering. Consequently, in general we can exactly recover the desired expectation by integrating over only the region of parameter space satisfying the chosen ordering constraint.
While this ordering strategy is limited to scalar parameters, it can still prove useful when the component distributions are multivariate. Although we cannot order the multivariate parameters themselves, ordering any single scalar parameter across the components is sufficient to break the labeling degeneracy for the entire mixture.
To illustrate the pathologies of Bayesian mixture models, and their potential resolutions, let’s consider a relatively simple example where the likelihood is given by a mixture of two Gaussians, \[ \pi(y_1, \ldots, y_N \mid \mu_1, \sigma_1, \mu_2, \sigma_2, \theta_1, \theta_2) = \prod_{n = 1}^{N} \left( \theta_1 \mathcal{N} (y_n \mid \mu_1, \sigma_1) + \theta_2 \mathcal{N} (y_n \mid \mu_2, \sigma_2) \right). \] Note that the mixture is applied to each datum individually – our model assumes that each measurement is drawn from one of the components independently, as opposed to the entire dataset being drawn from one of the components as a whole.
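To emphasize the distinction, a small R sketch, written naively on the natural scale for clarity, contrasts the per-datum mixing used here with mixing the entire dataset as a whole,

# per-datum mixing: each measurement is assigned to a component independently
per_datum_lpdf <- function(y, theta, mu, sigma)
  sum(log(theta[1] * dnorm(y, mu[1], sigma[1]) + theta[2] * dnorm(y, mu[2], sigma[2])))

# whole-dataset mixing: the entire dataset is assigned to a single component
whole_dataset_lpdf <- function(y, theta, mu, sigma)
  log(theta[1] * prod(dnorm(y, mu[1], sigma[1])) + theta[2] * prod(dnorm(y, mu[2], sigma[2])))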
We first define the component data generating processes to be well-separated relative to their standard deviations,
mu <- c(-2.75, 2.75);  # component means
sigma <- c(1, 1);      # component standard deviations
lambda <- 0.4          # probability that a measurement comes from the second component
Then we simulate some data from the mixture likelihood by following its generative structure, first drawing assignments for each measurement and then drawing the measurements themselves from the corresponding Gaussian,
set.seed(689934)
N <- 1000
z <- rbinom(N, 1, lambda) + 1;  # latent assignments: 2 with probability lambda, 1 otherwise
y <- rnorm(N, mu[z], sigma[z]); # measurements drawn from the assigned Gaussian
library(rstan)
Loading required package: ggplot2
Loading required package: StanHeaders
rstan (Version 2.14.1, packaged: 2016-12-28 14:55:41 UTC, GitRev: 5fa1e80eb817)
For execution on a local, multicore CPU with excess RAM we recommend calling
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)
stan_rdump(c("N", "y"), file="mix.data.R")
Let’s now consider a Bayesian fit of this model, first with labeling degeneracies and then with various attempted resolutions.
As discussed above, in order to ensure that the labeling degeneracies persist in the posterior distribution we need exchangeable priors. We can, for example, accomplish this by assigning identical priors to the Gaussian parameters, \[ \mu_1, \mu_2 \sim \mathcal{N} (0, 2), \, \sigma_1, \sigma_2 \sim \text{Half-}\mathcal{N} (0, 2), \] and a symmetric Beta distribution to the mixture weight, \[ \theta_1 \sim \text{Beta} (5, 5). \]
writeLines(readLines("gauss_mix.stan"))
data {
int<lower = 0> N;
vector[N] y;
}
parameters {
vector[2] mu;
real<lower=0> sigma[2];
real<lower=0, upper=1> theta;
}
model {
sigma ~ normal(0, 2);
mu ~ normal(0, 2);
theta ~ beta(5, 5);
for (n in 1:N)
target += log_mix(theta,
normal_lpdf(y[n] | mu[1], sigma[1]),
normal_lpdf(y[n] | mu[2], sigma[2]));
}
Equivalently, we could also have defined \(\theta\) as a two-dimensional simplex with a \(\text{Dirichlet}(5, 5)\) prior, which would yield exactly the same model.
Aware of the labeling degeneracy, let’s go ahead and fit this Bayesian mixture model in Stan,
input_data <- read_rdump("mix.data.R")
degenerate_fit <- stan(file='gauss_mix.stan', data=input_data,
chains=4, seed=483892929, refresh=2000)
SAMPLING FOR MODEL 'gauss_mix' NOW (CHAIN 1).
Chain 1, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 1, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 1, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.38577 seconds (Warm-up)
0.973725 seconds (Sampling)
2.3595 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix' NOW (CHAIN 2).
Chain 2, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 2, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 2, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.30787 seconds (Warm-up)
0.909254 seconds (Sampling)
2.21712 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix' NOW (CHAIN 3).
Chain 3, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 3, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 3, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.33762 seconds (Warm-up)
1.2385 seconds (Sampling)
2.57612 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix' NOW (CHAIN 4).
Chain 4, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 4, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 4, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.43549 seconds (Warm-up)
0.895336 seconds (Sampling)
2.33083 seconds (Total)
The split Rhat is atrocious, indicating that the chains are not exploring the same regions of parameter space.
print(degenerate_fit)
Inference for Stan model: gauss_mix.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 25% 50% 75%
mu[1] -1.33 1.72 2.43 -2.81 -2.75 -2.72 -1.26
mu[2] 1.47 1.72 2.43 -2.79 1.36 2.85 2.89
sigma[1] 1.03 0.00 0.03 0.96 1.00 1.03 1.05
sigma[2] 1.02 0.00 0.04 0.95 1.00 1.02 1.05
theta 0.56 0.07 0.11 0.36 0.53 0.62 0.63
lp__ -2108.57 0.03 1.55 -2112.46 -2109.43 -2108.28 -2107.43
97.5% n_eff Rhat
mu[1] 2.94 2 56.48
mu[2] 2.96 2 51.11
sigma[1] 1.09 4000 1.00
sigma[2] 1.10 4000 1.00
theta 0.65 2 7.52
lp__ -2106.51 2050 1.00
Samples were drawn using NUTS(diag_e) at Thu Mar 2 15:36:24 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
Indeed this is to be expected, as each chain finds and then explores one of the two degenerate modes independently of the others,
c_light_trans <- c("#DCBCBCBF")
c_light_highlight_trans <- c("#C79999BF")
c_mid_trans <- c("#B97C7CBF")
c_mid_highlight_trans <- c("#A25050BF")
c_dark_trans <- c("#8F2727BF")
c_dark_highlight_trans <- c("#7C0000BF")
params1 <- as.data.frame(extract(degenerate_fit, permuted=FALSE)[,1,])
params2 <- as.data.frame(extract(degenerate_fit, permuted=FALSE)[,2,])
params3 <- as.data.frame(extract(degenerate_fit, permuted=FALSE)[,3,])
params4 <- as.data.frame(extract(degenerate_fit, permuted=FALSE)[,4,])
par(mar = c(4, 4, 0.5, 0.5))
plot(params1$"mu[1]", params1$"mu[2]", col=c_dark_highlight_trans, pch=16, cex=0.8,
xlab="mu1", xlim=c(-3, 3), ylab="mu2", ylim=c(-3, 3))
points(params2$"mu[1]", params2$"mu[2]", col=c_dark_trans, pch=16, cex=0.8)
points(params3$"mu[1]", params3$"mu[2]", col=c_mid_highlight_trans, pch=16, cex=0.8)
points(params4$"mu[1]", params4$"mu[2]", col=c_mid_trans, pch=16, cex=0.8)
lines(0.08*(1:100) - 4, 0.08*(1:100) - 4, col="grey", lwd=2)
legend("topright", c("Chain 1", "Chain 2", "Chain 3", "Chain 4"),
fill=c(c_dark_highlight_trans, c_dark_trans,
c_mid_highlight_trans, c_mid_trans), box.lty=0, inset=0.0005)
This degenerate example is a particularly nice demonstration of the importance of running multiple chains in any MCMC analysis. If we had run just one chain then we would have had no indication of the multimodality in the posterior and the incompleteness of our fits!
Our first potential resolution of the labeling degeneracy is to tweak the priors to no longer be exchangeable. With no reason to expect that the standard deviations will vary between the two components, we will instead adjust the priors for the means to strongly favor \(\mu_1\) positive and \(\mu_2\) negative,
writeLines(readLines("gauss_mix_asym_prior.stan"))
data {
int<lower = 0> N;
vector[N] y;
}
parameters {
vector[2] mu;
real<lower=0> sigma[2];
real<lower=0, upper=1> theta;
}
model {
sigma ~ normal(0, 2);
mu[1] ~ normal(4, 0.5);
mu[2] ~ normal(-4, 0.5);
theta ~ beta(5, 5);
for (n in 1:N)
target += log_mix(theta,
normal_lpdf(y[n] | mu[1], sigma[1]),
normal_lpdf(y[n] | mu[2], sigma[2]));
}
Running this model in Stan,
asym_fit <- stan(file='gauss_mix_asym_prior.stan', data=input_data,
chains=4, seed=483892929, refresh=2000)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 1).
Chain 1, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 1, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 1, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.32051 seconds (Warm-up)
0.892697 seconds (Sampling)
2.21321 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 2).
Chain 2, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 2, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 2, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.37814 seconds (Warm-up)
1.60082 seconds (Sampling)
2.97896 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 3).
Chain 3, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 3, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 3, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.34297 seconds (Warm-up)
0.927972 seconds (Sampling)
2.27095 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 4).
Chain 4, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 4, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 4, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 1.40858 seconds (Warm-up)
0.901724 seconds (Sampling)
2.3103 seconds (Total)
we see that split Rhat still looks terrible,
print(asym_fit)
Inference for Stan model: gauss_mix_asym_prior.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 25% 50% 75%
mu[1] 1.49 1.71 2.42 -2.75 1.39 2.86 2.91
mu[2] -1.36 1.69 2.40 -2.82 -2.76 -2.72 -1.29
sigma[1] 1.02 0.00 0.04 0.95 1.00 1.02 1.05
sigma[2] 1.03 0.00 0.03 0.97 1.01 1.03 1.05
theta 0.44 0.07 0.11 0.35 0.37 0.39 0.47
lp__ -2156.80 54.45 77.03 -2292.21 -2163.47 -2112.71 -2111.44
97.5% n_eff Rhat
mu[1] 2.98 2 49.94
mu[2] 2.86 2 57.85
sigma[1] 1.10 4000 1.00
sigma[2] 1.10 4000 1.00
theta 0.64 2 7.29
lp__ -2110.31 2 51.79
Samples were drawn using NUTS(diag_e) at Thu Mar 2 15:36:35 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
Investigating the fit output we see that despite our strong, asymmetric prior the posterior distribution is still multimodal and the disfavored mode is still capturing chains,
params1 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,1,])
params2 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,2,])
params3 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,3,])
params4 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,4,])
par(mar = c(4, 4, 0.5, 0.5))
plot(params1$"mu[1]", params1$"mu[2]", col=c_dark_highlight_trans, pch=16, cex=0.8,
xlab="mu1", xlim=c(-3, 3), ylab="mu2", ylim=c(-3, 3))
points(params2$"mu[1]", params2$"mu[2]", col=c_dark_trans, pch=16, cex=0.8)
points(params3$"mu[1]", params3$"mu[2]", col=c_mid_highlight_trans, pch=16, cex=0.8)
points(params4$"mu[1]", params4$"mu[2]", col=c_mid_trans, pch=16, cex=0.8)
lines(0.08*(1:100) - 4, 0.08*(1:100) - 4, col="grey", lwd=2)
legend("topright", c("Chain 1", "Chain 2", "Chain 3", "Chain 4"),
fill=c(c_dark_highlight_trans, c_dark_trans,
c_mid_highlight_trans, c_mid_trans), box.lty=0, inset=0.0005)
This example clearly demonstrates the subtle challenge of trying to resolve the labeling degeneracy with most non-exchangeable prior distributions. When there are many data the mixture likelihood becomes very informative and can easily overwhelm even strong prior information. The posterior will no longer be symmetric between the degenerate modes, but the disfavored modes will still carry enough posterior mass that they must be explored in the fit.
Reducing the amount of data, however, reduces the influence of the likelihood and makes it easier to corral the mixture components in principled directions. With a smaller data set,
N <- 100
z <- rbinom(N, 1, lambda) + 1;
y <- rnorm(N, mu[z], sigma[z]);
stan_rdump(c("N", "y"), file="mix_low.data.R")
input_low_data <- read_rdump("mix_low.data.R")
asym_fit <- stan(file='gauss_mix_asym_prior.stan', data=input_low_data,
chains=4, seed=483892929, refresh=2000)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 1).
Chain 1, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 1, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 1, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 0.163203 seconds (Warm-up)
0.256404 seconds (Sampling)
0.419607 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 2).
Chain 2, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 2, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 2, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 0.161927 seconds (Warm-up)
0.107424 seconds (Sampling)
0.269351 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 3).
Chain 3, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 3, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 3, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 0.163485 seconds (Warm-up)
0.169538 seconds (Sampling)
0.333023 seconds (Total)
SAMPLING FOR MODEL 'gauss_mix_asym_prior' NOW (CHAIN 4).
Chain 4, Iteration: 1 / 2000 [ 0%] (Warmup)
Chain 4, Iteration: 1001 / 2000 [ 50%] (Sampling)
Chain 4, Iteration: 2000 / 2000 [100%] (Sampling)
Elapsed Time: 0.158586 seconds (Warm-up)
0.123968 seconds (Sampling)
0.282554 seconds (Total)
we achieve a much better fit,
print(asym_fit)
Inference for Stan model: gauss_mix_asym_prior.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 25% 50% 75% 97.5%
mu[1] 2.80 0.00 0.18 2.45 2.67 2.79 2.91 3.16
mu[2] -2.69 0.00 0.12 -2.93 -2.77 -2.69 -2.60 -2.45
sigma[1] 1.01 0.00 0.16 0.76 0.90 0.99 1.10 1.37
sigma[2] 1.03 0.00 0.10 0.86 0.96 1.02 1.09 1.25
theta 0.33 0.00 0.04 0.25 0.30 0.33 0.36 0.42
lp__ -218.35 0.04 1.63 -222.57 -219.15 -218.02 -217.16 -216.22
n_eff Rhat
mu[1] 4000 1
mu[2] 4000 1
sigma[1] 4000 1
sigma[2] 4000 1
theta 4000 1
lp__ 2106 1
Samples were drawn using NUTS(diag_e) at Thu Mar 2 15:36:37 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
All of our chains converge to the favored mode and our inferences become well-behaved,
params1 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,1,])
params2 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,2,])
params3 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,3,])
params4 <- as.data.frame(extract(asym_fit, permuted=FALSE)[,4,])
par(mar = c(4, 4, 0.5, 0.5))
plot(params1$"mu[1]", params1$"mu[2]", col=c_dark_highlight_trans, pch=16, cex=0.8,
xlab="mu1", xlim=c(-3, 3), ylab="mu2", ylim=c(-3, 3))
points(params2$"mu[1]", params2$"mu[2]", col=c_dark_trans, pch=16, cex=0.8)
points(params3$"mu[1]", params3$"mu[2]", col=c_mid_highlight_trans, pch=16, cex=0.8)
points(params4$"mu[1]", params4$"mu[2]", col=c_mid_trans, pch=16, cex=0.8)
lines(0.08*(1:100) - 4, 0.08*(1:100) - 4, col="grey", lwd=2)
legend("topright", c("Chain 1", "Chain 2", "Chain 3", "Chain 4"),
fill=c(c_dark_highlight_trans, c_dark_trans,
c_mid_highlight_trans, c_mid_trans), box.lty=0, inset=0.0005)