In this notebook, we will explore stick-breaking constructions for nonparametric mixture and factor models and fit them via variational inference. Both of these models are traditionally fit via, for instance, slice sampling or Gibbs sampling, but recent developments in probabilistic programming languages allow us to fit them easily via automated variational inference. While nonparametric mixture models using the Dirichlet process (DP) as a prior are found frequently in the literature, factor models using the Indian buffet process (IBP) have received less attention.

I have long been enthusiastic about nonparametric Bayesian models but, except for GPs, have found them hard to work with in practice (at least for principled statistical data analysis). Hamiltonian Monte Carlo samplers in particular, as implemented in Stan or NumPyro, where the simulated trajectories often diverge even for “easy” data sets, seem not to be well suited for this class of models, so I am curious about the results of this study.

We implement the models and variational surrogates using NumPyro. Feedback and comments are welcome!

We load some libraries for inference and working with data first.

import pandas as pd

import jax
import jax.numpy as np
import jax.scipy as sp
import jax.random as random

import numpyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.distributions.transforms import OrderedTransform
from numpyro.infer import SVI, Trace_ELBO
import numpyro.optim as optim

import tensorflow_probability.substrates.jax.distributions as tfp_jax

import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
import palettes

palettes.set_theme()
numpyro.set_host_device_count(4)

Check that JAX recognizes the four cores we just configured.

jax.local_device_count()
4

Infinite mixture models

Nonparametric Bayesian mixture models implement an observation model that consists of infinitely many component distributions. Using the stick-breaking construction of a Dirichlet process (Ghosal and Van der Vaart (2017), Blei and Jordan (2006)), the generative model we assume here has the following form

\[\begin{align} \beta & \sim \text{Gamma}(1.0, 1.0) \\ \nu_k & \sim\text{Beta}(1.0, \beta) \\ \pi_k & = \nu_k \prod_{j=1}^{k-1} (1 - \nu_j) \\ \mu_k & \sim \text{Normal}(0.0, 1.0)\\ \sigma_k & \sim \text{Normal}^+(1.0) \\ y_i & \sim \sum_{k=1}^{\infty} \pi_k \text{Normal}(\mu_k, \sigma_k) \end{align}\]

where \(k\) indexes a component distribution and \(i\) indexes a data point.
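To make the stick-breaking step concrete, the first few weights unroll to

\[ \pi_1 = \nu_1, \qquad \pi_2 = \nu_2 (1 - \nu_1), \qquad \pi_3 = \nu_3 (1 - \nu_1)(1 - \nu_2), \]

i.e., \(\nu_k\) is the fraction that component \(k\) claims of the stick length \(\prod_{j=1}^{k-1} (1 - \nu_j)\) remaining after the first \(k-1\) breaks, so the weights sum to one over the infinite sequence.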

Data

We begin by simulating a data set consisting of three components and 1000 samples. The simulated data set should be fairly easy to fit; real-world data is usually significantly noisier.

n_samples = 1000
K = 3

means = np.linspace(-2.0, 2.0, K)
standard_deviations = np.array([0.25, 0.2, 0.3])

Z = random.randint(
    key=random.PRNGKey(23),
    minval=0,
    maxval=K, 
    shape=(n_samples,)
)

eps = dist.Normal(0.0, 1.0).sample(
    random.PRNGKey(23),
    sample_shape=(n_samples,)
)
y = means[Z] + eps * standard_deviations[Z]

The three components are centered around \(-2\), \(0\) and \(2\) with a low standard deviation.

df = pd.DataFrame(np.vstack([y, Z]).T, columns=["y", "z"])
_ = plt.figure(figsize=(15, 5))
_ = sns.histplot(
    x="y",
    hue="z",
    data=df,
    palette=palettes.discrete_sequential_colors(),
    legend=False,
    bins=50,
)
plt.show()

Model

We will truncate the stick at a sufficiently large \(K\) (the error introduced by this truncation is, if I recall correctly from a reference I cannot find anymore, negligible).

K_stick = 10
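We can make “negligible” a little more concrete: since the \(\nu_j\) are i.i.d. \(\text{Beta}(1, \beta)\), the expected stick mass left over after \(K\) breaks is \(\mathbb{E}\left[\prod_{j=1}^{K}(1 - \nu_j)\right] = \left(\beta/(1 + \beta)\right)^K\). The quick check below (added here, not part of the original analysis) evaluates this for a few values of \(\beta\):

# expected stick mass not covered by the first K_stick components
for beta_val in [0.5, 1.0, 2.0]:
    leftover = (beta_val / (1.0 + beta_val)) ** K_stick
    print(f"beta={beta_val}: expected leftover mass {leftover:.2e}")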

Next we define a routine to compute the mixing weights \(\pi\) from \(\nu\).

def sample_stick(nu):
    # remaining stick lengths after each break: 1, (1 - nu_1), (1 - nu_1)(1 - nu_2), ...
    ones = np.ones((*nu.shape[:-1], 1))
    rem = np.concatenate(
      [ones, np.cumprod(1 - nu, axis=-1)[..., :-1]],
      axis=-1
    )
    mix_probs = nu * rem
    return mix_probs
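As a quick sanity check (added here), the weights returned by sample_stick should be non-negative and sum to at most one, the small remainder being the mass of the truncated tail:

nu_test = dist.Beta(1.0, 1.0).sample(random.PRNGKey(0), sample_shape=(K_stick,))
pi_test = sample_stick(nu_test)
pi_test.min(), pi_test.sum()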

We will infer the posterior distributions over the latent variables using the marginal mixture representation above. We define the prior model first. In contrast to the generative model defined above, we order the mean variables, which avoids the label-switching non-identifiability of the mixture components.

def prior():
    beta = numpyro.sample("beta", dist.Gamma(1.0, 1.0))
    nu = numpyro.sample(
        "nu",
        dist.Beta(
          concentration1=np.ones(K_stick), 
          concentration0=beta
        )
    )
    pi = numpyro.deterministic("pi", sample_stick(nu))
    mu = numpyro.sample(
        "mu",
        dist.TransformedDistribution(
            dist.Normal(loc=np.zeros(K_stick)), 
            OrderedTransform()
        ),
    )
    sigma = numpyro.sample("sigma", dist.HalfNormal(scale=np.ones(K_stick)))

    return pi, mu, sigma

We then define the log-likelihood function:

def log_likelihood(y, pi, mu, sigma):
    lpdf_weights = np.log(pi)
    lpdf_components = dist.Normal(loc=mu, scale=sigma).log_prob(y[:, np.newaxis])

    lpdf = lpdf_weights + lpdf_components
    lpdf = sp.special.logsumexp(lpdf, axis=-1)
    return np.sum(lpdf)

To test the implementation, we can make a draw from the prior and plug it into the likelihood.

with numpyro.handlers.seed(rng_seed=23):
    pi, mu, sigma = prior()
    
log_likelihood(y, pi, mu, sigma)
DeviceArray(-4293.6284, dtype=float32)

The NumPyro model itself is then only a two-liner. We include the likelihood term using a factor in the model specification.

def model():
    pi, mu, sigma = prior()
    numpyro.factor("log_likelihood", log_likelihood(y, pi, mu, sigma))
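We can quickly confirm that the likelihood factor shows up as a site in the model's execution trace (a small check added here; it is not needed for inference):

exec_trace = numpyro.handlers.trace(
    numpyro.handlers.seed(model, rng_seed=23)
).get_trace()
list(exec_trace.keys())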

We approximate the posterior distributions using mean-field variational inference, which requires us to define surrogate distributions for each of the latent variables. Specifically, we will use the following variational surrogates, adapted from Blei and Jordan (2006)

\[\begin{align} q_{\lambda}(\beta) & = \text{Gamma}(\lambda_{\beta_0}, \lambda_{\beta_1}) \\ q_{\lambda}(\nu_k) & = \text{Beta}(\lambda_{\nu_{k0}}, \lambda_{\nu_{k1}}) \\ q_{\lambda}(\mu_k) & = \text{Normal}(\lambda_{\mu_{k0}}, \lambda_{\mu_{k1}}) \\ q_{\lambda}(\sigma_k) & = \text{Normal}^+(\lambda_{\sigma_{k}})\\ \end{align}\]

where we constrain the scale parameters to be positive and the vector \(\mu\) to be ordered.

def guide():
    q_beta_concentration = numpyro.param(
        "beta_concentration", init_value=1.0, constraint=constraints.positive
    )
    q_beta_rate = numpyro.param(
        "beta_rate", init_value=1.0, constraint=constraints.positive
    )
    q_beta = numpyro.sample(
        "beta", dist.Gamma(q_beta_concentration, q_beta_rate)
    )

    q_nu_concentration1 = numpyro.param(
        "nu_concentration1",
        init_value=np.ones(K_stick),
        constraint=constraints.positive,
    )
    q_nu_concentration0 = numpyro.param(
        "nu_concentration0",
        init_value=np.ones(K_stick) * 2.0,
        constraint=constraints.positive,
    )
    q_nu = numpyro.sample(
        "nu",
        dist.Beta(
          concentration1=q_nu_concentration1, 
          concentration0=q_nu_concentration0
        )
    )

    q_mu_mu = numpyro.param(
        "q_mu_mu", 
        init_value=np.linspace(-2.0, 0.0, K_stick)
    )
    q_mu_sd = numpyro.param(
        "q_mu_sd", 
        init_value=np.ones(K_stick), 
        constraint=constraints.positive
    )
    q_mu = numpyro.sample(
        "mu",
        dist.TransformedDistribution(
            dist.Normal(loc=q_mu_mu, scale=q_mu_sd), 
            OrderedTransform()
        ),
    )

    q_sigma_scale = numpyro.param(
        "q_sigma_scale",
        init_value=np.ones(K_stick),
        constraint=constraints.positive,
    )
    q_sigma = numpyro.sample(
      "sigma", 
      dist.HalfNormal(scale=q_sigma_scale)
    )

Inference

We optimize the variational parameters \(\lambda\) using NumPyro’s stochastic variational inference (Hoffman et al. 2013):

num_steps = 20000

adam = optim.Adam(0.01)
svi = SVI(model, guide, adam, loss=Trace_ELBO(20))
res = svi.run(random.PRNGKey(1), num_steps=num_steps, progress_bar=False)
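Before inspecting the posterior, it is worth making sure the optimization has converged. A quick look at the loss trace (added here, not shown in the original) can be obtained with:

_ = plt.figure(figsize=(15, 5))
_ = plt.plot(res.losses, color="darkgrey")
_ = plt.xlabel("Step")
_ = plt.ylabel("Negative ELBO")
plt.show()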

Let’s have a look at the posterior mixing weights. Ideally, most of the mass is concentrated on the first three weights:

nu = dist.Beta(
    concentration1=res.params["nu_concentration1"],
    concentration0=res.params["nu_concentration0"]
)

sample_stick(nu.mean)
DeviceArray([0.07260582, 0.10314478, 0.10146222, 0.06835859, 0.08807334,
             0.1260116 , 0.11079573, 0.10370086, 0.11305631, 0.11147972],
            dtype=float32)

Let’s also visualize the posterior means using a small helper, plot_means.
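The helper itself is not shown in the code above; a minimal sketch of what it could look like, assuming it overlays the (ordered) variational location parameters of \(\mu\) on a histogram of the data, is:

def plot_means(res):
    # as a proxy for the posterior means of mu, push the variational
    # locations of the base distribution through the ordered transform
    mu_hat = OrderedTransform()(res.params["q_mu_mu"])
    _ = plt.figure(figsize=(15, 5))
    _ = sns.histplot(x=y, color="darkgrey", bins=50)
    for m in mu_hat:
        _ = plt.axvline(float(m), color="black", linestyle="--")
    plt.show()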

plot_means(res)

Infinite latent feature models

In statistical analysis, we are frequently interested in decomposing a high-dimensional data set into a small number of components. Latent feature models decompose a data set \(Y\) into a binary matrix \(Z\) and a matrix of loadings \(\Psi\)

\[\begin{equation} Y = Z \Psi + \epsilon \end{equation}\]

where \(\epsilon\) is a Gaussian noise matrix. Nonparametric factor models implement an observation model that consists of infinitely many features. Using the stick-breaking construction of an Indian buffet process (Ghosal and Van der Vaart (2017), Doshi et al. (2009), Teh, Görür, and Ghahramani (2007), Paisley and Carin (2009)), we will explore fitting the following generative model using variational inference

\[\begin{align} \beta & \sim \text{Gamma}(1.0, 1.0) \\ \nu_k & \sim\text{Beta}(1.0, \beta) \\ \pi_k & = \prod_{j=1}^{k} \nu_j \\ z_{ik} & \sim \text{Bernoulli}(\pi_k)\\ \Psi & \sim \text{MatrixNormal}(0, I, I) \\ \sigma_q & \sim \text{Normal}^+(1.0) \\ y_i & \sim \text{MvNormal}\left(z_i^T \Psi, \text{diag}(\sigma)\right) \end{align}\]

where \(k\) indexes a latent feature, \(q\) indexes an observed dimension, and \(i\) indexes a data point.
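Note how this construction differs from the stick-breaking for the DP mixture above: here the weights are cumulative products of the stick fractions,

\[ \pi_1 = \nu_1, \qquad \pi_2 = \nu_1 \nu_2, \qquad \pi_3 = \nu_1 \nu_2 \nu_3, \]

so they are monotonically decreasing in \(k\) but need not sum to one; each \(\pi_k\) is the marginal probability that feature \(k\) is active for a sample.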

Data

As above, we simulate some artificial data for inference. We’ll simulate data with \(Q=10\) dimensions and a latent dimensionality of \(K=5\).

n_samples = 100
K = 5
Q = 10

Next we simulate the probabilities that a latent feature is active and the binary feature matrix itself.

nu = dist.Beta(5.0, 1.0).sample(random.PRNGKey(0), sample_shape=(K,))
pi = np.cumprod(nu)

Z = dist.Bernoulli(probs=pi).sample(
  random.PRNGKey(1),
  sample_shape=(n_samples,)
)

The binary feature matrix is shown below.

_ = plt.figure(figsize=(15, 5))
ax = sns.heatmap(
    Z.T,
    linewidths=0.1,
    cbar=False,
    cmap=["white", "black"],
    linecolor="darkgrey",
)
_ = ax.set_ylabel("Active features")
_ = ax.set_xlabel("Samples")
_ = ax.minorticks_off()
plt.show()

We finally simulate the actual data and visualize it.

psi = dist.Normal(0.0, 3.0).sample(
  random.PRNGKey(0), 
  sample_shape=(K, Q)
)
eps = dist.Normal(0.0, 0.1).sample(
  random.PRNGKey(0), 
  sample_shape=(n_samples, Q)
)
y = Z @ psi + eps
df = pd.DataFrame(y, columns=[f"y{i}" for i in range(Q)])
df = df.melt(var_name="y", value_name="Value")

_ = plt.figure(figsize=(10, 4))
g = sns.FacetGrid(
  df, 
  col="y", 
  col_wrap=5, 
  sharex=False, 
  sharey=False
)
_ = g.map_dataframe(
  sns.histplot, 
  x="Value",
  color="darkgrey"
)
plt.show()

Model

Since our model involves binary latent variables, we will use a continuous relaxation based on the concrete distribution (Maddison, Mnih, and Teh (2017), Jang, Gu, and Poole (2017)). As this distribution is still missing in NumPyro, we use TensorFlow Probability's JAX version of it. The prior model is shown below:

temperature = 0.000001
rec_temperature = np.reciprocal(temperature)

def prior():
    nu = numpyro.sample("nu", dist.Beta(np.ones(K), 1.0))
    pi = numpyro.deterministic("pi", np.cumprod(nu))

    Z = numpyro.sample(
        "Z",
        tfp_jax.RelaxedBernoulli(temperature, probs=pi),
        sample_shape=(n_samples,),
    )

    psi = numpyro.sample("psi", dist.Normal(np.zeros((K, Q)), 1.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(np.ones(Q)))

    return nu, pi, Z, psi, sigma
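With such a small temperature, samples from the relaxed Bernoulli are essentially binary. A quick illustration (added here, not part of the original analysis):

z_relaxed = tfp_jax.RelaxedBernoulli(
    temperature, probs=np.array([0.1, 0.5, 0.9])
).sample(sample_shape=(5,), seed=random.PRNGKey(2))
z_relaxed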

We again implement the log-likelihood as a separate function so that we can easily check the code:

def log_likelihood(y, pi, Z, psi, sigma):
    mean = Z @ psi
    lpdf = dist.Independent(
        dist.Normal(loc=mean, scale=sigma), reinterpreted_batch_ndims=1
    ).log_prob(y)
    
    return np.sum(lpdf)
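As before, we can check the implementation by drawing once from the prior and plugging the draw into the likelihood. We use fresh variable names so that we do not overwrite the simulated Z and psi from above:

with numpyro.handlers.seed(rng_seed=23):
    nu_p, pi_p, Z_p, psi_p, sigma_p = prior()

log_likelihood(y, pi_p, Z_p, psi_p, sigma_p)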

The actual NumPyro model consists of the prior model and a numpyro.factor term that increments the log-density by the likelihood:

def model():
    _, pi, Z, psi, sigma = prior()
    numpyro.factor("log_likelihood", log_likelihood(y, pi, Z, psi, sigma))

We again approximate the posterior distributions using mean-field variational inference, which requires us to place surrogates over each latent variable. We use the following guides, adapted from Doshi et al. (2009)

\[\begin{align} q_{\lambda}(\beta) & = \text{Gamma}(\lambda_{\beta_0}, \lambda_{\beta_1}) \\ q_{\lambda}(\nu_k) & = \text{Beta}(\lambda_{\nu_{k0}}, \lambda_{\nu_{k1}})\\ q_{\lambda}(z_{ik}) & = \text{RelaxedBernoulli}(\lambda_{z_{ik}})\\ q_{\lambda}(\Psi) & = \text{MatrixNormal}\left(\lambda_{\Psi_{0}}, I, \text{diag}(\lambda_{\Psi_{1}})\right) \\ q_{\lambda}(\sigma_q) & = \text{Normal}^+(\lambda_{\sigma_{q}})\\ \end{align}\]

where we constrain the scale parameters to be positive.

def guide():
    q_nu_concentration1 = numpyro.param(
        "nu_concentration1",
        init_value=np.ones(K),
        constraint=constraints.positive,
    )
    q_nu_concentration0 = numpyro.param(
        "nu_concentration0",
        init_value=np.ones(K) * 2.0,
        constraint=constraints.positive,
    )
    q_nu = numpyro.sample(
        "nu",
        dist.Beta(
            concentration1=q_nu_concentration1,
            concentration0=q_nu_concentration0,
        ),
    )

    z_logits = numpyro.param(
        "z_logits",
        init_value=np.tile(np.linspace(3.0, -3.0, K), (n_samples, 1))
    )
    Z = numpyro.sample(
        "Z",
        dist.TransformedDistribution(
            dist.Logistic(z_logits * rec_temperature, rec_temperature),
            dist.transforms.SigmoidTransform(),
        ),
    )

    q_psi_mu = numpyro.param("psi_mu", init_value=np.zeros((K, Q)))
    q_psi_sd = numpyro.param(
        "psi_sd",
        init_value=np.ones((K, Q)),
        constraint=constraints.positive,
    )
    psi = numpyro.sample(
      "psi", 
      dist.Normal(q_psi_mu, q_psi_sd)
    )

    q_sigma_scale = numpyro.param(
        "sigma_scale",
        init_value=np.ones(Q),
        constraint=constraints.positive
    )
    q_sigma = numpyro.sample(
      "sigma",
      dist.HalfNormal(scale=q_sigma_scale)
    )

Inference

We use NumPyro again for fitting the variational parameters.

adam = optim.Adam(0.01)
svi = SVI(model, guide, adam, loss=Trace_ELBO(20))
res = svi.run(random.PRNGKey(1), num_steps=num_steps, progress_bar=False)

Let’s have a look at the losses:

res.losses
DeviceArray([1.1303346e+09, 1.1550102e+09, 1.1401368e+09, ...,
             6.8995693e+08, 6.9592230e+08, 6.8834541e+08], dtype=float32)

Next, we check the means of the variational posterior of \(\Psi\):

res.params["psi_mu"]
DeviceArray([[-3.89895916e-01, -8.57985139e-01, -2.16530534e-04,
               1.04643315e-01, -1.38771808e+00, -5.03355622e-01,
              -8.04338396e-01,  5.97144425e-01, -2.48718232e-01,
               3.25112611e-01],
             [-2.60213166e-01, -1.28231084e+00, -3.03932279e-01,
               7.91628063e-01, -6.10843241e-01, -5.51389754e-01,
              -7.12928653e-01,  2.97082752e-01, -3.62518370e-01,
               3.93981248e-01],
             [-2.91130245e-01, -4.95102197e-01,  3.92822117e-01,
               3.85009885e-01, -7.03571975e-01, -1.01273322e+00,
              -7.60439515e-01,  7.04354525e-01, -6.18651927e-01,
               1.45355659e-02],
             [-3.21906209e-01, -7.70041287e-01, -1.81577265e-01,
              -9.48242471e-02, -9.06787217e-01, -7.64060020e-01,
              -9.44639206e-01,  5.36571085e-01, -4.71601546e-01,
               9.29950625e-02],
             [-2.62866408e-01, -4.60114062e-01,  1.75498813e-01,
               9.17578712e-02, -9.83754516e-01, -6.16102576e-01,
              -1.17928576e+00,  5.13079166e-01,  2.08586529e-02,
               1.05515592e-01]], dtype=float32)

In comparison, these are the real features:

psi
DeviceArray([[-5.8489043e-01, -3.6814418e+00, -2.7533705e+00,
               4.3562627e-01,  5.4449798e-03, -2.2487614e+00,
              -2.4079003e+00,  3.6728698e-01, -5.2767074e-01,
               3.1742225e+00],
             [-7.3477870e-01, -2.3293500e+00,  1.2220554e+00,
               1.6042180e+00, -4.0779777e+00, -6.5401325e+00,
              -8.1589353e-01,  4.0971441e+00,  9.2994606e-01,
               1.0416718e+00],
             [-1.1062758e+00,  8.2450229e-01,  3.6503370e+00,
              -8.9458096e-01, -5.2701826e+00,  4.6718516e+00,
              -1.5832694e+00,  3.8993502e-01, -3.8639338e+00,
              -4.1280193e+00],
             [ 1.7521545e+00,  1.3760606e+00,  1.2504476e-01,
              -1.7918383e-01,  9.2262292e-01,  4.9367529e-01,
              -1.5742345e-01,  3.5032740e+00,  5.1644583e+00,
               1.1037910e+00],
             [-2.8223643e+00, -3.3620679e-01,  2.1731548e+00,
               2.7324674e+00, -2.9335537e+00, -9.5999140e-01,
              -4.2618980e+00, -3.0981445e-01, -2.6359632e+00,
              -1.1806486e+00]], dtype=float32)
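Since the latent features are only identified up to relabeling, a direct row-by-row comparison of psi_mu with psi can be misleading. As an additional rough check (added here, using the sigmoid of the variational logits as an approximation to the feature activation probabilities), we can compare the reconstruction implied by the variational means to the data:

# approximate activation probabilities of the relaxed features
Z_hat = sp.special.expit(res.params["z_logits"])
# reconstruction from the variational means, compared to the data scale
y_hat = Z_hat @ res.params["psi_mu"]
np.sqrt(np.mean((y - y_hat) ** 2)), np.std(y)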


References

Blei, David M, and Michael I Jordan. 2006. “Variational Inference for Dirichlet Process Mixtures.” Bayesian Analysis 1 (1): 121–43.

Doshi, Finale, Kurt Miller, Jurgen Van Gael, and Yee Whye Teh. 2009. “Variational Inference for the Indian Buffet Process.” In Proceedings of the Twelfth International Conference on Artificial Intelligence and Statistics, AISTATS 2009, edited by David A. Van Dyk and Max Welling.

Ghosal, Subhashis, and Aad Van der Vaart. 2017. Fundamentals of Nonparametric Bayesian Inference. Vol. 44. Cambridge University Press.

Hoffman, Matthew D, David M Blei, Chong Wang, and John Paisley. 2013. “Stochastic Variational Inference.” Journal of Machine Learning Research 14 (5).

Jang, Eric, Shixiang Gu, and Ben Poole. 2017. “Categorical Reparameterization with Gumbel-Softmax.” In 5th International Conference on Learning Representations, ICLR 2017.

Maddison, Chris J., Andriy Mnih, and Yee Whye Teh. 2017. “The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables.” In 5th International Conference on Learning Representations, ICLR 2017.

Paisley, John, and Lawrence Carin. 2009. “Nonparametric Factor Analysis with Beta Process Priors.” In Proceedings of the 26th Annual International Conference on Machine Learning, 777–84.

Teh, Yee Whye, Dilan Görür, and Zoubin Ghahramani. 2007. “Stick-Breaking Construction for the Indian Buffet Process.” In Artificial Intelligence and Statistics, 556–63. PMLR.