Finite mixture models represent a distribution as a weighted sum of a fixed number \(K\) of components. We can either find \(K\) using model selection, e.g. with AIC, BIC, WAIC, etc., or try to infer this number automatically. Nonparametric mixture models do exactly this.
Here we implement a nonparametric Bayesian mixture model using Gibbs sampling. We use a Chinese restaurant process prior and the stick-breaking construction to sample from a Dirichlet process (see for instance Hjort et al. (2010), Orbanz (2014), Murphy (2012) and Kamper (2013)).
We’ll implement the Gibbs sampler using the CRP ourselves, since Stan does not support sampling discrete parameters, and then use the stick-breaking construction in Stan with a truncated DP.
suppressMessages({
library(tidyverse)
library(rlang)
library(ggthemes)
library(colorspace)
library(rstan)
library(bayesplot)
library(MCMCpack)
library(e1071)
library(mvtnorm)
})
set.seed(23)
options(mc.cores = parallel::detectCores())
Bayesian mixture models are hierarchical models that can generally be formalized like this:
\[\begin{align*} \boldsymbol \theta_k & \sim \mathcal{G}_0\\ \boldsymbol \pi & \sim \text{Dirichlet}(\boldsymbol \alpha_0)\\ z_i & \sim \text{Discrete}(\boldsymbol \pi)\\ \mathbf{x}_i \mid z_i = k & \sim {P}(\boldsymbol \theta_k) \end{align*}\]
where \(\mathcal{G}_0\) is some base distribution for the model parameters.
The DP, in contrast, as any BNP model, puts priors on structures that accommodate infinite sizes. The resulting posteriors give a distribution on structures that grow with new observations. A mixture model using a possibly infinite number of components could look like this:
\[\begin{align*} \mathcal{G} & \sim \mathcal{DP}(\alpha, \mathcal{G}_0)\\ \boldsymbol \theta_i & \sim \mathcal{G}\\ \mathbf{x}_i& \sim {P}(\boldsymbol \theta_i) \end{align*}\]
where \(\mathcal{G}_0\) is the same base measure as above and \(\mathcal{G}\) is a sample from the DP, i.e. also a random measure.
One way, and possibly the easiest, to implement a DPMM is using a Chinese restaurant process (CRP), which is a distribution over partitions. The hierarchical model using a CRP is:
\[\begin{align*} \boldsymbol \theta_k & \sim \mathcal{G}_0 \\ z_i \mid \mathbf{z}_{1:i-1} & \sim \text{CRP} \\ \mathbf{x}_i & \sim P(\boldsymbol \theta_{z_i}) \end{align*}\]
where \(\text{CRP}\) is a prior on possibly infinitely many classes. Specifically, the CRP is defined as:
\[\begin{align*} P(z_i = k \mid \mathbf{z}_{-i}) = \left\{ \begin{array}{ll} \frac{N_k}{N - 1 + \alpha} & \text{if } k \text{ is an already occupied table}\\ \frac{\alpha}{N - 1 + \alpha} & \text{if } k \text{ is a new table}\\ \end{array} \right. \end{align*}\]
where \(N_k\) is the number of customers at table \(k\) and \(\alpha\) some hyperparameter.
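As a small illustration (not part of the model fit below), the following sketch draws a partition from the CRP prior alone; the choices \(N = 100\) and \(\alpha = 1\) are arbitrary.
# draw table assignments z_1, ..., z_N from a CRP with concentration alpha
rcrp <- function(N, alpha) {
  z <- integer(N)
  z[1] <- 1
  for (i in 2:N) {
    counts <- table(z[1:(i - 1)])                # N_k for the occupied tables
    probs <- c(counts, alpha) / (i - 1 + alpha)  # occupied tables vs. a new one
    z[i] <- sample(length(probs), 1, prob = probs)
  }
  z
}
table(rcrp(100, alpha = 1))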
For the variables of interest, \(\boldsymbol \theta_k\) and \(\boldsymbol z\), the posterior is:
\[\begin{align*} P(\boldsymbol \theta, \boldsymbol z \mid \mathbf{X}) \propto P(\mathbf{X} \mid \boldsymbol \theta, \boldsymbol z ) P(\boldsymbol \theta) P ( \boldsymbol z ) \end{align*}\]
Using a Gibbs sampler, we iterate over the following two steps:
sample \(z_i \sim P(z_i \mid \mathbf{z}_{-i}, \mathbf{X}, \boldsymbol \theta) \propto P(z_i \mid \mathbf{z}_{-i}) P(\mathbf{x}_i \mid \boldsymbol \theta_{z_i}, \mathbf{X}_{-i}, \mathbf{z})\)
sample \(\boldsymbol \theta_k \sim P(\boldsymbol \theta_k \mid \mathbf{z}, \mathbf{X})\)
So we alternate between sampling assignments of data to classes and sampling the parameters of the data distribution given the class assignments. The major difference compared to the finite case is the way of sampling \(z_i\), which we do using the CRP in the infinite case. The CRP itself is defined by \(P(z_i \mid \mathbf{z}_{-i})\), so replacing this by the usual finite-dimensional prior would give us a finite mixture. Evaluating the likelihoods in the first step is fairly straightforward, as we will see. Updating the model parameters in the second step is done conditionally on every class, and by that also not too hard to do.
With the CRP we put a prior distribution on the possibly infinite number of class assignments. An alternative approach is to use the stick-breaking construction. The advantage here is that we can use Stan with a truncated DP, so we don’t need to implement the sampler ourselves. If we, instead of putting a CRP prior on the latent labels, put a prior on the possibly infinite sequence of mixing weights \(\boldsymbol \pi\), we arrive at the stick-breaking construction. The hierarchical model now looks like this:
\[\begin{align*} \nu_k &\sim \text{Beta}(1, \alpha) \\ \pi_k & = \nu_k \prod_{j=1}^{k-1} (1 - \nu_j) \\ \boldsymbol \theta_k & \sim G_0 \\ \mathbf{x}_i & \sim \sum_k \pi_k P(\boldsymbol \theta_k) \end{align*}\]
where \(\alpha\) is again the concentration parameter of the DP. The distribution of the mixing weights is sometimes denoted as
\[ \boldsymbol \pi \sim \text{GEM}(\alpha) \]
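To get some intuition for the GEM distribution, the following sketch draws one set of mixing weights; the truncation at \(25\) sticks and \(\alpha = 2\) are arbitrary choices for illustration only.
# one draw of truncated stick-breaking weights pi ~ GEM(alpha)
stick_breaking <- function(alpha, n_sticks = 25) {
  nu <- rbeta(n_sticks, 1, alpha)           # stick proportions nu_k ~ Beta(1, alpha)
  nu * c(1, cumprod(1 - nu)[-n_sticks])     # pi_k = nu_k * prod_{j<k} (1 - nu_j)
}
round(stick_breaking(2), 3)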
In the following section, we derive a Gaussian Dirichlet process mixture using the CRP with a Gibbs sampler and the stick-breaking construction using Stan.
In the Gaussian case the hierarchical model using the CRP has the following form:
\[\begin{align*} \boldsymbol \Sigma_k & \sim \mathcal{IW}\\ \boldsymbol \mu_k & \sim \mathcal{N}(\boldsymbol \mu_0, \boldsymbol \Sigma_0) \\ z_i \mid z_{1:i-1} & \sim \text{CRP} \\ \mathbf{x}_i & \sim \mathcal{N}(\boldsymbol \mu_{z_i}, \boldsymbol \Sigma_{z_i}) \end{align*}\]
Let’s derive the Gibbs sampler for an infinite Gaussian mixture using the CRP. First we set up the data \(\mathbf{X}\) and some constants. We create a very simple data set to avoid problems with identifiability and label switching. For a treatment of the topic see Michael Betancourt’s case study. \(n\) is the number of samples, \(p\) is the dimensionality of the Gaussians, and \(\alpha\) is the concentration parameter of the CRP.
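The chunk defining these is not shown here, so the concrete values below are assumptions; any reasonable choices work.
n <- 100       # number of samples (assumed)
p <- 2         # dimensionality of the Gaussians
alpha <- .5    # concentration parameter of the CRP (assumed)
X <- matrix(0, n, p)   # container for the data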
Latent class assignments (Z), the current table index and the number of customers per table:
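A minimal sketch, consistent with the sampling loop below:
Z <- integer(n)   # latent class assignments z_1, ..., z_n
curr.tab <- 0     # index of the most recently opened table
tables <- NULL    # number of customers per table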
Parameters of the Gaussians:
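The cluster means are drawn inside the loop below, so we only need an empty container and the (assumed) within-cluster variance:
mus <- NULL    # cluster means, one row per table (drawn in the loop below)
sigma <- .1    # within-cluster variance (assumed)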
Then we create a random assignment of customers to tables with probability \(P(z_i \mid Z_{-i})\), i.e. we use the CRP to put data into classes. Note that we don’t know the number of classes that comes out!
for (i in seq(n))
{
  # CRP probabilities: existing tables vs. opening a new one
  probs <- c(tables / (i - 1 + alpha), alpha / (i - 1 + alpha))
  table <- rdiscrete(1, probs)
  if (table > curr.tab) {
    # open a new table and draw its mean from the base distribution
    curr.tab <- curr.tab + 1
    tables <- c(tables, 0)
    mu <- mvtnorm::rmvnorm(1, c(0, 0), 10 * diag(p))
    mus <- rbind(mus, mu)
  }
  # seat customer i at its table and draw an observation from that cluster
  Z[i] <- table
  X[i, ] <- mvtnorm::rmvnorm(1, mus[Z[i], ], sigma * diag(p))
  tables[table] <- tables[table] + 1
}
Let’s see how many clusters and how many data points per clusters we have.
data.frame(table(Z)) %>%
ggplot() +
geom_col(aes(Z, Freq), width = .35) +
theme_tufte() +
theme(
axis.text = element_text(colour = "black"),
axis.title = element_text(colour = "black")
) +
xlab("Cluster") +
ylab("Frequency")
data.frame(X = X, Z = as.factor(Z)) %>%
ggplot() +
geom_point(aes(X.1, X.2, color = Z)) +
theme_tufte() +
theme(axis.text = element_text(colour = "black")) +
scale_color_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
xlab(NULL) +
ylab(NULL) +
labs(color = "Cluster")
We initialize the cluster assignments by setting all customers to table 1. Hyperparameter \(\alpha\) controls the probability of opening a new table.
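The initialization chunk is not shown; a sketch consistent with the sampler below (the concentration \(\alpha\) from above is reused):
zs <- rep(1L, n)   # every customer starts at table 1
tables <- n        # so table 1 holds all n customers
K <- 1             # current number of occupied tables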
We assume the covariances to be known.
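A sketch for the fixed covariance, assuming the same \(\sigma\) as in the data generation:
sigma.prior <- sigma * diag(p)   # known covariance of each component (assumed)
q.prior <- solve(sigma.prior)    # its precision, used in the mean updates below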
Base distribution \(\mathcal{G}_0\):
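Assuming the same base measure as in the data generation, i.e. a Gaussian with zero mean and covariance \(10 \cdot I\):
mu0 <- rep(0, p)         # mean of the base distribution G_0
sigma0 <- 10 * diag(p)   # covariance of the base distribution G_0
mu.prior <- mvtnorm::rmvnorm(1, mu0, sigma0)  # mean of the single initial cluster, drawn from G_0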
To infer the posterior we use the Gibbs sampler described above. Here, I am only interested in the most likely assignment, i.e. the MAP of \(Z\), so instead of sampling each \(z_i\) we take the class with the highest conditional posterior probability.
for (iter in seq(100))
{
  for (i in seq(n))
  {
    # look at data point x_i and remove its statistics from the clustering
    zi <- zs[i]
    tables[zi] <- tables[zi] - 1
    if (tables[zi] == 0) {
      # the table became empty, so drop it and relabel the remaining tables
      K <- K - 1
      zs[zs > zi] <- zs[zs > zi] - 1
      tables <- tables[-zi]
      mu.prior <- mu.prior[-zi, , drop = FALSE]
    }

    # compute posterior probabilities P(z_i | z_-i, ...) for the existing tables
    no_i <- seq(n)[-i]
    probs <- sapply(seq(K), function(k) {
      crp <- sum(zs[no_i] == k) / (n + alpha - 1)
      lik <- mvtnorm::dmvnorm(X[i, ], mu.prior[k, ], sigma.prior)
      crp * lik
    })

    # compute the probability of opening up a new table
    crp <- alpha / (n + alpha - 1)
    lik <- mvtnorm::dmvnorm(X[i, ], mu0, sigma.prior + sigma0)
    probs <- c(probs, crp * lik)
    probs <- probs / sum(probs)

    # pick the new z_i as the maximizer of the conditional posterior above (MAP)
    z_new <- which.max(probs)
    if (z_new > K) {
      K <- K + 1
      tables <- c(tables, 0)
      mu.prior <- rbind(mu.prior, mvtnorm::rmvnorm(1, mu0, sigma0))
    }
    zs[i] <- z_new
    tables[z_new] <- tables[z_new] + 1

    # update the cluster means from their conditional posterior P(mu_k | ...)
    for (k in seq(K)) {
      Xk <- X[zs == k, , drop = FALSE]
      lambda <- solve(q.prior + tables[k] * q.prior)
      numerator <- tables[k] * q.prior %*% apply(Xk, 2, mean)
      mu.prior[k, ] <- mvtnorm::rmvnorm(1, lambda %*% numerator, lambda)
    }
  }
}
Let’s see if that worked out!
data.frame(X = X, Z = as.factor(zs)) %>%
ggplot() +
geom_point(aes(X.1, X.2, col = Z)) +
theme_tufte() +
theme(axis.text = element_text(colour = "black")) +
scale_color_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
xlab(NULL) +
ylab(NULL) +
labs(color = "Cluster")
Except for the lone guy on top, the clustering worked nicely.
In order to make the DPMM with stick-breaking work in Stan, we need to supply a maximum number of clusters \(K\) to choose from. Setting \(K = n\) would mean that we allow every data point to define its own cluster. For the sake of the exercise I’ll set the maximum number of clusters to \(10\). The hyperparameter \(\alpha\) parameterizes the Beta distribution from which we sample the stick lengths. We use the same data we already generated above.
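Only the truncation level needs to be set; \(n\), \(p\), \(\alpha\) and the data \(\mathbf{X}\) are reused from above.
K <- 10   # truncation level, i.e. the maximum number of components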
The model is a bit more verbose in comparison to the finite case. We only need to add the stick-breaking part in the transformed parameters block; the rest stays the same. We again use the LKJ prior for the correlation matrix of the individual components and set a fixed prior scale of \(1\). In order to get nice, unimodal posteriors, we also introduce an ordering of the mean values.
data {
int<lower=0> K;
int<lower=0> n;
int<lower=1> p;
row_vector[p] x[n];
real alpha;
}
parameters {
ordered[p] mu[K];
cholesky_factor_corr[p] L;
real <lower=0, upper=1> nu[K];
}
transformed parameters {
simplex[K] pi;
pi[1] = nu[1];
for(j in 2:(K-1))
{
pi[j] = nu[j] * (1 - nu[j - 1]) * pi[j - 1] / nu[j - 1];
}
pi[K] = 1 - sum(pi[1:(K - 1)]);
}
model {
real mix[K];
L ~ lkj_corr_cholesky(5);
nu ~ beta(1, alpha);
for (i in 1:K)
{
mu[i] ~ normal(0, 5);
}
for(i in 1:n)
{
for(k in 1:K)
{
mix[k] = log(pi[k]) + multi_normal_cholesky_lpdf(x[i] | mu[k], L);
}
target += log_sum_exp(mix);
}
}
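The call to stan below expects a file path stan.file. One way to create it, assuming the model code above is stored in a string called stan.model.code (a hypothetical name), is:
stan.file <- tempfile(fileext = ".stan")
writeLines(stan.model.code, stan.file)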
fit <- stan(
stan.file,
data = list(K = K, n = n, x = X, p = p, alpha = alpha),
iter = 5000,
warmup = 1000,
chains = 1
)
SAMPLING FOR MODEL 'dirichlet_process_mixture' NOW (CHAIN 1).
Chain 1:
Chain 1: Gradient evaluation took 0.000591 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 5.91 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1:
Chain 1:
Chain 1: Iteration: 1 / 5000 [ 0%] (Warmup)
Chain 1: Iteration: 500 / 5000 [ 10%] (Warmup)
Chain 1: Iteration: 1000 / 5000 [ 20%] (Warmup)
Chain 1: Iteration: 1001 / 5000 [ 20%] (Sampling)
Chain 1: Iteration: 1500 / 5000 [ 30%] (Sampling)
Chain 1: Iteration: 2000 / 5000 [ 40%] (Sampling)
Chain 1: Iteration: 2500 / 5000 [ 50%] (Sampling)
Chain 1: Iteration: 3000 / 5000 [ 60%] (Sampling)
Chain 1: Iteration: 3500 / 5000 [ 70%] (Sampling)
Chain 1: Iteration: 4000 / 5000 [ 80%] (Sampling)
Chain 1: Iteration: 4500 / 5000 [ 90%] (Sampling)
Chain 1: Iteration: 5000 / 5000 [100%] (Sampling)
Chain 1:
Chain 1: Elapsed Time: 94.7307 seconds (Warm-up)
Chain 1: 156.761 seconds (Sampling)
Chain 1: 251.491 seconds (Total)
Chain 1:
First we have a look at the posteriors of the mixing weights.
posterior <- extract(fit)
data.frame(posterior$pi) %>%
set_names(paste0("PI_", 1:10)) %>%
tidyr::gather(key, value) %>%
ggplot() +
geom_histogram(aes(x = value, y = ..density.., fill = key), bins = 50) +
scale_y_continuous(breaks = scales::breaks_pretty(n = 3)) +
facet_grid(key ~ ., scales = "free_y") +
theme_tufte() +
theme(
axis.text.x = element_text(colour = "black"),
axis.text.y = element_blank()
) +
scale_fill_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
xlab(NULL) +
ylab(NULL) +
guides(fill = FALSE)
From the plot above it looks as if Stan believes it is sufficient to use three components, as the posterior means of the mixing weights of the seven other components are fairly low or practically zero. However, let’s extract the posterior means of all component means and assign each data point to a cluster.
# evaluate every point under each component, using the posterior mean of mu
# and a unit covariance as a rough plug-in, then assign it to the most likely one
probs <- purrr::map_dfc(seq(10), function(i) {
  mu <- apply(posterior$mu[, i, ], 2, mean)
  mvtnorm::dmvnorm(X, mu, diag(2))
})
probs <- set_names(probs, paste0("Z", seq(10)))
zs.stan <- apply(probs, 1, which.max)
And the final plot:
data.frame(X = X, Z = as.factor(zs.stan)) %>%
ggplot() +
geom_point(aes(X.1, X.2, col = Z)) +
theme_tufte() +
theme(axis.text = element_text(colour = "black")) +
scale_color_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
xlab(NULL) +
ylab(NULL) +
labs(color = "Cluster")
Using a truncated DP with Stan worked even better than our CRP implementation. Here, we managed to give every point its correct label.
The notebook is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
R version 4.0.2 (2020-06-22)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.1 LTS
Matrix products: default
BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-openmp/libopenblasp-r0.3.8.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] mvtnorm_1.1-1 e1071_1.7-4 MCMCpack_1.4-9
[4] MASS_7.3-51.6 coda_0.19-3 bayesplot_1.7.2
[7] rstan_2.21.2 StanHeaders_2.21.0-5 colorspace_1.4-1
[10] ggthemes_4.2.4 rlang_0.4.7 forcats_0.5.0
[13] stringr_1.4.0 dplyr_1.0.1 purrr_0.3.4
[16] readr_1.3.1 tidyr_1.1.1 tibble_3.0.3
[19] ggplot2_3.3.2 tidyverse_1.3.0
loaded via a namespace (and not attached):
[1] mcmc_0.9-7 matrixStats_0.56.0 fs_1.5.0 lubridate_1.7.9
[5] httr_1.4.2 tools_4.0.2 backports_1.1.8 R6_2.4.1
[9] DBI_1.1.0 withr_2.2.0 tidyselect_1.1.0 gridExtra_2.3
[13] prettyunits_1.1.1 processx_3.4.3 curl_4.3 compiler_4.0.2
[17] cli_2.0.2 rvest_0.3.6 quantreg_5.61 SparseM_1.78
[21] xml2_1.3.2 labeling_0.3 scales_1.1.1 ggridges_0.5.2
[25] callr_3.4.3 digest_0.6.25 rmarkdown_2.6 pkgconfig_2.0.3
[29] htmltools_0.5.0 dbplyr_1.4.4 readxl_1.3.1 rstudioapi_0.11
[33] farver_2.0.3 generics_0.0.2 jsonlite_1.7.0 inline_0.3.15
[37] magrittr_1.5 loo_2.3.1 Matrix_1.2-18 Rcpp_1.0.5
[41] munsell_0.5.0 fansi_0.4.1 lifecycle_0.2.0 stringi_1.4.6
[45] yaml_2.2.1 pkgbuild_1.1.0 plyr_1.8.6 grid_4.0.2
[49] blob_1.2.1 parallel_4.0.2 crayon_1.3.4 lattice_0.20-41
[53] haven_2.3.1 hms_0.5.3 knitr_1.29 ps_1.3.4
[57] pillar_1.4.6 codetools_0.2-16 stats4_4.0.2 reprex_0.3.0
[61] glue_1.4.1 evaluate_0.14 V8_3.2.0 RcppParallel_5.0.2
[65] modelr_0.1.8 vctrs_0.3.2 MatrixModels_0.4-1 cellranger_1.1.0
[69] gtable_0.3.0 assertthat_0.2.1 xfun_0.16 broom_0.7.0
[73] class_7.3-17 conquer_1.0.1 ellipsis_0.3.1
Hjort, Nils Lid, Chris Holmes, Peter Müller, and Stephen G Walker. 2010. Bayesian Nonparametrics. Vol. 28. Cambridge University Press.
Kamper, Herman. 2013. “Gibbs Sampling for Fitting Finite and Infinite Gaussian Mixture Models.” https://www.kamperh.com/notes/kamper_bayesgmm13.pdf.
Murphy, Kevin P. 2012. Machine Learning: A Probabilistic Perspective. MIT press.
Orbanz, Peter. 2014. “Lecture Notes on Bayesian Nonparametrics.” http://www.gatsby.ucl.ac.uk/~porbanz/papers/porbanz_BNP_draft.pdf.