Finite-dimensional mixture models represent a distribution as a weighted sum of a fixed number \(K\) of components. We can either find \(K\) using model selection, e.g. with AIC, BIC, WAIC, etc., or try to infer this number automatically. Nonparametric mixture models do exactly this.

Here we implement a nonparametric Bayesian mixture model using Gibbs sampling. We use a Chinese restaurant process prior and a stick-breaking construction to sample from a Dirichlet process (see for instance Hjort et al. (2010), Orbanz (2014), Murphy (2012) and Kamper (2013)).

We’ll implement the Gibbs sampler using the CRP ourselves, since Stan cannot sample the discrete cluster assignments, and then use the stick-breaking construction in Stan via a truncated DP.

suppressMessages({
  library(tidyverse)
  library(rlang)
  library(ggthemes)
  library(colorspace)

  library(rstan)
  library(bayesplot)
  library(MCMCpack)

  library(e1071)
  library(mvtnorm)
})

set.seed(23)
options(mc.cores = parallel::detectCores())

Infinite mixtures

Bayesian mixture models are hierarchical models that can generally be formalized like this:

\[\begin{align*} \boldsymbol \theta_k & \sim \mathcal{G}_0\\ \boldsymbol \pi & \sim \text{Dirichlet}(\boldsymbol \alpha_0)\\ z_i & \sim \text{Discrete}(\boldsymbol \pi)\\ \mathbf{x}_i \mid z_i = k & \sim {P}(\boldsymbol \theta_k) \end{align*}\]

where \(\mathcal{G}_0\) is some base distribution for the model parameters.
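
To make this concrete, here is a minimal generative sketch of such a finite Gaussian mixture in R; the specific choices (three components, a symmetric Dirichlet prior, the base distribution \(\mathcal{N}(\mathbf{0}, 10\,\mathbf{I})\), unit component covariances and the names K0, pi0, theta0, z0, x0) are mine and only for illustration:

# generative sketch of a finite Bayesian mixture with K0 = 3 Gaussian components
K0 <- 3
pi0 <- MCMCpack::rdirichlet(1, rep(1, K0))                  # pi ~ Dirichlet(alpha_0)
theta0 <- mvtnorm::rmvnorm(K0, rep(0, 2), 10 * diag(2))     # theta_k ~ G_0
z0 <- sample(K0, 50, replace = TRUE, prob = as.vector(pi0)) # z_i ~ Discrete(pi)
x0 <- t(sapply(z0, function(k) mvtnorm::rmvnorm(1, theta0[k, ], diag(2)))) # x_i ~ N(theta_{z_i}, I)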

The DP, in contrast, like any BNP model, puts priors on structures that accommodate infinite sizes. The resulting posteriors give a distribution on structures that grow with new observations. A mixture model using a possibly infinite number of components could look like this:

\[\begin{align*} \mathcal{G} & \sim \mathcal{DP}(\alpha, \mathcal{G}_0)\\ \boldsymbol \theta_i & \sim \mathcal{G}\\ \mathbf{x}_i& \sim {P}(\boldsymbol \theta_i) \end{align*}\]

where \(\mathcal{G}_0\) is the same base measure as above and \(\mathcal{G}\) is a sample from the DP, i.e. also a random measure.

The Chinese restaurant process

One way, and possibly the easiest, to implement a DPMM is using a Chinese restaurant process (CRP), which is a distribution over partitions. The hierarchical model using a CRP is:

\[\begin{align*} \boldsymbol \theta_k & \sim \mathcal{G}_0 \\ z_i \mid \mathbf{z}_{1:i-1} & \sim \text{CRP} \\ \mathbf{x}_i & \sim P(\boldsymbol \theta_{z_i}) \end{align*}\]

where \(\text{CRP}\) is a prior over a possibly infinite number of classes. Specifically, the CRP is defined as:

\[\begin{align*} P(z_i = k \mid \mathbf{z}_{-i}) = \left\{ \begin{array}{ll} \frac{N_k}{N - 1 + \alpha} & \text{for an existing class } k\\ \frac{\alpha}{N - 1 + \alpha} & \text{for a new class}\\ \end{array} \right. \end{align*}\]

where \(N_k\) is the number of customers at table \(k\) and \(\alpha\) some hyperparameter.
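
To see the two cases in action, here is a minimal sketch that draws a partition from the CRP prior alone (the function name rcrp is my own; the full generative model with Gaussian likelihoods follows below):

# draw table assignments for n customers from a CRP with concentration alpha
rcrp <- function(n, alpha) {
  z <- integer(n)
  z[1] <- 1
  for (i in 2:n) {
    counts <- table(z[seq(i - 1)])                # N_k for each occupied table
    probs <- c(counts, alpha) / (i - 1 + alpha)   # existing tables, then a new table
    z[i] <- sample(length(probs), 1, prob = probs)
  }
  z
}

table(rcrp(100, alpha = .5))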

For the variables of interest, \(\boldsymbol \theta_k\) and \(\boldsymbol z\), the posterior is:

\[\begin{align*} P(\boldsymbol \theta, \boldsymbol z \mid \mathbf{X}) \propto P(\mathbf{X} \mid \boldsymbol \theta, \boldsymbol z ) P(\boldsymbol \theta) P ( \boldsymbol z ) \end{align*}\]

Using a Gibbs sampler, we iterate over the following two steps:

  1. sample \(z_i \sim P(z_i \mid \mathbf{z}_{-i}, \mathbf{X}, \boldsymbol \theta) \propto P(z_i \mid \mathbf{z}_{-i}) P(\mathbf{x}_i \mid \boldsymbol \theta_{z_i}, \mathbf{X}_{-i}, \mathbf{z})\)

  2. sample \(\boldsymbol \theta_k \sim P(\boldsymbol \theta_k \mid \mathbf{z}, \mathbf{X})\)

So we alternate between sampling assignments of data points to classes and sampling the parameters of the data distribution given the class assignments. The major difference compared to the finite case is how we sample \(z_i\), which in the infinite case we do using the CRP. The CRP enters only through the prior \(P(z_i \mid \mathbf{z}_{-i})\), so replacing this term with the usual finite-dimensional prior would give us back a finite mixture. Evaluating the likelihoods in the first step is fairly straightforward, as we will see. Updating the model parameters in the second step is done conditionally on each class and is therefore also not too hard.

Stick-breaking construction

With the CRP we put a prior distribution on the possibly infinite number of class assignments. An alternative approach is to use the stick-breaking construction. The advantage here is that we can use Stan with a truncated DP, so we don’t need to implement the sampler ourselves. If, instead of putting a CRP prior on the latent labels, we put a prior on the possibly infinite sequence of mixing weights \(\boldsymbol \pi\), we arrive at the stick-breaking construction. The hierarchical model now looks like this:

\[\begin{align*} \nu_k &\sim \text{Beta}(1, \alpha) \\ \pi_k & = \nu_k \prod_{j=1}^{k-1} (1 - \nu_j) \\ \boldsymbol \theta_k & \sim G_0 \\ \mathbf{x}_i & \sim \sum_k \pi_k P(\boldsymbol \theta_k) \end{align*}\]

where \(\alpha\) is the same concentration hyperparameter as before. The distribution of the mixing weights is sometimes denoted as

\[ \boldsymbol \pi \sim \text{GEM}(\alpha) \]
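
A few lines of R make this explicit: we repeatedly break off a Beta-distributed fraction of the remaining stick; truncating after a finite number of breaks (here 25, an arbitrary choice) already gives weights that sum to nearly one. The names nus and pis are mine:

# truncated stick-breaking draw of the mixing weights, pi ~ GEM(alpha) with alpha = 2
nus <- rbeta(25, 1, 2)                    # stick-breaking fractions nu_k ~ Beta(1, alpha)
pis <- nus * c(1, cumprod(1 - nus)[-25])  # pi_k = nu_k * prod_{j < k} (1 - nu_j)
sum(pis)                                  # close to 1 already after 25 breaks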

Gaussian DPMMs

In the following section, we derive a Gaussian Dirichlet process mixture using the CRP with a Gibbs sampler and the stick-breaking construction using Stan.

CRP

In the Gaussian case the hierarchical model using the CRP has the following form:

\[\begin{align*} \boldsymbol \Sigma_k & \sim \mathcal{IW}\\ \boldsymbol \mu_k & \sim \mathcal{N}(\boldsymbol \mu_0, \boldsymbol \Sigma_0) \\ z_i \mid z_{1:i-1} & \sim \text{CRP} \\ \mathbf{x}_i & \sim \mathcal{N}(\boldsymbol \mu_{z_i}, \boldsymbol \Sigma_{z_i}) \end{align*}\]

Let’s derive the Gibbs sampler for an infinite Gaussian mixture using the CRP. First we set up the data \(\mathbf{X}\) and some constants. We create a very simple data set to avoid problems with identifiability and label switching; for a treatment of the topic see Michael Betancourt’s case study. \(n\) is the number of samples, \(p\) is the dimensionality of the Gaussians, and \(\alpha\) is the concentration parameter of the CRP.

n <- 100
p <- 2
alpha <- .5

Latent class assignments (Z), the current table index and the number of customers per table:

Z <- integer(n)
X <- matrix(0, n, p)
curr.tab <- 0
tables <- c()

Parameters of the Gaussians:

sigma <- .1
mus <- NULL

Then we create a random assignment of customers to tables with probability \(P(z_i \mid \mathbf{z}_{-i})\), i.e. we use the CRP to put data into classes. Note that we don’t know in advance how many classes will come out!

for (i in seq(n))
{
  # CRP probabilities: the existing tables first, then the probability of opening a new one
  probs <- c(tables / (i - 1 + alpha), alpha / (i - 1 + alpha))
  table <- rdiscrete(1, probs)
  # if a new table is opened, draw its mean from the base distribution G_0
  if (table > curr.tab) {
    curr.tab <- curr.tab + 1
    tables <- c(tables, 0)
    mu <- mvtnorm::rmvnorm(1, c(0, 0), 10 * diag(p))
    mus <- rbind(mus, mu)
  }
  Z[i] <- table
  X[i, ] <- mvtnorm::rmvnorm(1, mus[Z[i], ], sigma * diag(p))
  tables[table] <- tables[table] + 1
}

Let’s see how many clusters and how many data points per cluster we have.

data.frame(table(Z)) %>%
  ggplot() +
  geom_col(aes(Z, Freq), width = .35) +
  theme_tufte() +
  theme(
    axis.text = element_text(colour = "black"),
    axis.title = element_text(colour = "black")
  ) +
  xlab("Cluster") +
  ylab("Frequency")

data.frame(X = X, Z = as.factor(Z)) %>%
  ggplot() +
  geom_point(aes(X.1, X.2, color = Z)) +
  theme_tufte() +
  theme(axis.text = element_text(colour = "black")) +
  scale_color_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
  xlab(NULL) +
  ylab(NULL) +
  labs(color = "Cluster")

Posterior inference using Gibbs sampling

We initialize the cluster assignments by setting all customers to table 1. The hyperparameter \(\alpha\) controls the probability of opening a new table.

K <- 1
zs <- rep(K, n)
alpha <- 5
tables <- n

We assume the covariances to be known.

mu.prior <- matrix(c(0, 0), ncol = 2)
sigma.prior <- diag(p)
q.prior <- solve(sigma.prior)

Base distribution \(\mathcal{G}_0\):

sigma0 <- diag(p)
prec0 <- solve(sigma0)
mu0 <- rep(0, p)

To infer the posterior we use the Gibbs sampler described above. Here, however, I am only interested in the most likely assignment, i.e. the MAP of \(Z\), so instead of sampling each \(z_i\) from its conditional posterior I take its argmax.
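
For the second step, since the component covariance is fixed and the prior on the means is Gaussian, the conditional posterior of each \(\boldsymbol \mu_k\) is a standard normal-normal update (writing \(\boldsymbol \Sigma\) for the known likelihood covariance, \(N_k\) for the number of points currently assigned to class \(k\), and \(\boldsymbol \Lambda_k\) for the posterior covariance); with \(\boldsymbol \mu_0 = \mathbf{0}\) the prior-mean term vanishes:

\[\begin{align*} \boldsymbol \mu_k \mid \mathbf{z}, \mathbf{X} & \sim \mathcal{N}\left( \boldsymbol \Lambda_k \left(\boldsymbol \Sigma_0^{-1} \boldsymbol \mu_0 + \boldsymbol \Sigma^{-1} \sum_{i: z_i = k} \mathbf{x}_i \right), \boldsymbol \Lambda_k \right), \qquad \boldsymbol \Lambda_k = \left( \boldsymbol \Sigma_0^{-1} + N_k \boldsymbol \Sigma^{-1} \right)^{-1} \end{align*}\]

This is exactly the update used at the end of the sampler below.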

for (iter in seq(100))
{
  for (i in seq(n))
  {
    # look at data point x_i and remove its statistics from the clustering
    zi <- zs[i]
    tables[zi] <- tables[zi] - 1
    if (tables[zi] == 0) {
      K <- K - 1
      zs[zs > zi] <- zs[zs > zi] - 1
      tables <- tables[-zi]
      mu.prior <- mu.prior[-zi, , drop = FALSE]
    }

    # compute posterior probabilities P(z_i \mid z_-i, ...) for the existing classes
    no_i <- seq(n)[-i]
    probs <- sapply(seq(K), function(k) {
      crp <- sum(zs[no_i] == k) / (n + alpha - 1)
      lik <- mvtnorm::dmvnorm(X[i, ], mu.prior[k, ], sigma.prior)
      crp * lik
    })

    # compute the probability of opening up a new table (marginal likelihood under the base distribution)
    crp <- alpha / (n + alpha - 1)
    lik <- mvtnorm::dmvnorm(X[i, ], mu0, sigma.prior + sigma0)
    probs <- c(probs, crp * lik)
    probs <- probs / sum(probs)

    # assign z_i to the most probable class (MAP assignment rather than sampling, see above)
    z_new <- which.max(probs)
    if (z_new > K) {
      K <- K + 1
      tables <- c(tables, 0)
      mu.prior <- rbind(mu.prior, mvtnorm::rmvnorm(1, mu0, sigma0))
    }
    zs[i] <- z_new
    tables[z_new] <- tables[z_new] + 1

    # sample from the conditional posterior P(mu_k | z, X), a normal-normal update with known covariance
    for (k in seq(K)) {
      Xk <- X[zs == k, , drop = FALSE]
      lambda <- solve(prec0 + tables[k] * q.prior)
      numerator <- q.prior %*% colSums(Xk)  # the prior mean mu0 is zero, so its term drops out
      mu.prior[k, ] <- mvtnorm::rmvnorm(1, lambda %*% numerator, lambda)
    }
  }
}

Let’s see if that worked out!

data.frame(X = X, Z = as.factor(zs)) %>%
  ggplot() +
  geom_point(aes(X.1, X.2, col = Z)) +
  theme_tufte() +
  theme(axis.text = element_text(colour = "black")) +
  scale_color_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
  xlab(NULL) +
  ylab(NULL) +
  labs(color = "Cluster")

Except for the lone guy on top, the clustering worked nicely.

Stick-breaking construction

In order to make the DPMM with stick-breaking work in Stan, we need to supply a maximum number of clusters \(K\) to choose from. Setting \(K=n\) would mean that we allow every data point to define its own cluster. For the sake of the exercise I’ll set the maximum number of clusters to \(10\). The hyperparameter \(\alpha\) parameterizes the Beta distribution from which we sample the stick lengths. We use the same data we already generated above.

K <- 10
alpha <- 2

The model is a bit more verbose in comparison to the finite case. We only need to add the stick-breaking part in the transformed parameters block; the rest stays the same. We again use the LKJ prior for the correlation matrix of the single components and set a fixed prior scale of \(1\). In order to get nice, unimodal posteriors, we also introduce an ordering of the mean values.

stan.file <- "_models/dirichlet_process_mixture.stan"
cat(readLines(stan.file), sep = "\n")
data {
    int<lower=0> K;
    int<lower=0> n;
    int<lower=1> p;
    row_vector[p] x[n];
    real alpha;
}

parameters {
    ordered[p] mu[K];
    cholesky_factor_corr[p] L;
    real<lower=0, upper=1> nu[K];
}

transformed parameters {
    simplex[K] pi;
    pi[1] = nu[1];
    for (j in 2:(K - 1))
    {
        pi[j] = nu[j] * (1 - nu[j - 1]) * pi[j - 1] / nu[j - 1];
    }

    pi[K] = 1 - sum(pi[1:(K - 1)]);
}

model {
    real mix[K];

    L ~ lkj_corr_cholesky(5);
    nu ~ beta(1, alpha);    
    for (i in 1:K) 
    {
        mu[i] ~ normal(0, 5);
    }

  
    for(i in 1:n) 
    {
        for(k in 1:K) 
        {
            mix[k] = log(pi[k]) + multi_normal_cholesky_lpdf(x[i] | mu[k], L);
        }
        target += log_sum_exp(mix);
    }
}
fit <- stan(
  stan.file,
  data = list(K = K, n = n, x = X, p = p, alpha = alpha),
  iter = 5000,
  warmup = 1000,
  chains = 1
)

SAMPLING FOR MODEL 'dirichlet_process_mixture' NOW (CHAIN 1).
Chain 1: 
Chain 1: Gradient evaluation took 0.000591 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 5.91 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1: 
Chain 1: 
Chain 1: Iteration:    1 / 5000 [  0%]  (Warmup)
Chain 1: Iteration:  500 / 5000 [ 10%]  (Warmup)
Chain 1: Iteration: 1000 / 5000 [ 20%]  (Warmup)
Chain 1: Iteration: 1001 / 5000 [ 20%]  (Sampling)
Chain 1: Iteration: 1500 / 5000 [ 30%]  (Sampling)
Chain 1: Iteration: 2000 / 5000 [ 40%]  (Sampling)
Chain 1: Iteration: 2500 / 5000 [ 50%]  (Sampling)
Chain 1: Iteration: 3000 / 5000 [ 60%]  (Sampling)
Chain 1: Iteration: 3500 / 5000 [ 70%]  (Sampling)
Chain 1: Iteration: 4000 / 5000 [ 80%]  (Sampling)
Chain 1: Iteration: 4500 / 5000 [ 90%]  (Sampling)
Chain 1: Iteration: 5000 / 5000 [100%]  (Sampling)
Chain 1: 
Chain 1:  Elapsed Time: 94.7307 seconds (Warm-up)
Chain 1:                156.761 seconds (Sampling)
Chain 1:                251.491 seconds (Total)
Chain 1: 

First we have a look at the posterior distributions of the mixing weights.

posterior <- extract(fit)

data.frame(posterior$pi) %>%
  set_names(paste0("PI_", 1:10)) %>%
  tidyr::gather(key, value) %>%
  ggplot() +
  geom_histogram(aes(x = value, y = ..density.., fill = key), bins = 50) +
  scale_y_continuous(breaks = scales::breaks_pretty(n = 3)) +
  facet_grid(key ~ ., scales = "free_y") +
  theme_tufte() +
  theme(
    axis.text.x = element_text(colour = "black"),
    axis.text.y = element_blank()
  ) +
  scale_fill_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
  xlab(NULL) +
  ylab(NULL) +
  guides(fill = FALSE)

From the plot above it looks as if Stan believes it’s sufficient to use three components, as the posterior means of the mixing weights of the seven other components are fairly low or close to zero. However, let’s extract the posterior means of all components and assign each data point to a cluster.

probs <- purrr::map_dfc(seq(10), function(i) {
  # posterior mean of component i's mean vector
  mu <- apply(posterior$mu[, i, ], 2, mean)
  # density of every data point under that component (covariance approximated by the identity)
  mvtnorm::dmvnorm(X, mu, diag(2))
})
probs <- set_names(probs, paste0("Z", seq(10)))
zs.stan <- apply(probs, 1, which.max)
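
As a small variant, one could also weight each component’s density by the posterior mean of its mixing weight before taking the argmax; the sketch below (the names pi.hat and zs.weighted are mine) is only an alternative assignment rule and is not used for the plot that follows:

# sketch: weight the component densities by the posterior means of the mixing weights
pi.hat <- apply(posterior$pi, 2, mean)
weighted <- sapply(seq(10), function(i) {
  mu <- apply(posterior$mu[, i, ], 2, mean)
  pi.hat[i] * mvtnorm::dmvnorm(X, mu, diag(2))
})
zs.weighted <- apply(weighted, 1, which.max)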

And the final plot:

data.frame(X = X, Z = as.factor(zs.stan)) %>%
  ggplot() +
  geom_point(aes(X.1, X.2, col = Z)) +
  theme_tufte() +
  theme(axis.text = element_text(colour = "black")) +
  scale_color_discrete_diverging(palette = "Blue-Red", l1 = 1, l2 = 60) +
  xlab(NULL) +
  ylab(NULL) +
  labs(color = "Cluster")

Using a truncated DP with Stan worked even better than our CRP implementation. Here, we managed to give every point its correct label.

Session info

sessionInfo()
R version 4.0.2 (2020-06-22)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.1 LTS

Matrix products: default
BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-openmp/libopenblasp-r0.3.8.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] mvtnorm_1.1-1        e1071_1.7-4          MCMCpack_1.4-9      
 [4] MASS_7.3-51.6        coda_0.19-3          bayesplot_1.7.2     
 [7] rstan_2.21.2         StanHeaders_2.21.0-5 colorspace_1.4-1    
[10] ggthemes_4.2.4       rlang_0.4.7          forcats_0.5.0       
[13] stringr_1.4.0        dplyr_1.0.1          purrr_0.3.4         
[16] readr_1.3.1          tidyr_1.1.1          tibble_3.0.3        
[19] ggplot2_3.3.2        tidyverse_1.3.0     

loaded via a namespace (and not attached):
 [1] mcmc_0.9-7         matrixStats_0.56.0 fs_1.5.0           lubridate_1.7.9   
 [5] httr_1.4.2         tools_4.0.2        backports_1.1.8    R6_2.4.1          
 [9] DBI_1.1.0          withr_2.2.0        tidyselect_1.1.0   gridExtra_2.3     
[13] prettyunits_1.1.1  processx_3.4.3     curl_4.3           compiler_4.0.2    
[17] cli_2.0.2          rvest_0.3.6        quantreg_5.61      SparseM_1.78      
[21] xml2_1.3.2         labeling_0.3       scales_1.1.1       ggridges_0.5.2    
[25] callr_3.4.3        digest_0.6.25      rmarkdown_2.6      pkgconfig_2.0.3   
[29] htmltools_0.5.0    dbplyr_1.4.4       readxl_1.3.1       rstudioapi_0.11   
[33] farver_2.0.3       generics_0.0.2     jsonlite_1.7.0     inline_0.3.15     
[37] magrittr_1.5       loo_2.3.1          Matrix_1.2-18      Rcpp_1.0.5        
[41] munsell_0.5.0      fansi_0.4.1        lifecycle_0.2.0    stringi_1.4.6     
[45] yaml_2.2.1         pkgbuild_1.1.0     plyr_1.8.6         grid_4.0.2        
[49] blob_1.2.1         parallel_4.0.2     crayon_1.3.4       lattice_0.20-41   
[53] haven_2.3.1        hms_0.5.3          knitr_1.29         ps_1.3.4          
[57] pillar_1.4.6       codetools_0.2-16   stats4_4.0.2       reprex_0.3.0      
[61] glue_1.4.1         evaluate_0.14      V8_3.2.0           RcppParallel_5.0.2
[65] modelr_0.1.8       vctrs_0.3.2        MatrixModels_0.4-1 cellranger_1.1.0  
[69] gtable_0.3.0       assertthat_0.2.1   xfun_0.16          broom_0.7.0       
[73] class_7.3-17       conquer_1.0.1      ellipsis_0.3.1    

References

Hjort, Nils Lid, Chris Holmes, Peter Müller, and Stephen G Walker. 2010. Bayesian Nonparametrics. Vol. 28. Cambridge University Press.

Kamper, Herman. 2013. “Gibbs Sampling for Fitting Finite and Infinite Gaussian Mixture Models.” https://www.kamperh.com/notes/kamper_bayesgmm13.pdf.

Murphy, Kevin P. 2012. Machine Learning: A Probabilistic Perspective. MIT press.

Orbanz, Peter. 2014. “Lecture Notes on Bayesian Nonparametrics.” http://www.gatsby.ucl.ac.uk/~porbanz/papers/porbanz_BNP_draft.pdf.