In this case study we compare a multi-level Gaussian process model to a hierarchical coregionalized Gaussian process model in terms of their predictive performance and their MCMC sampling diagnostics. The case study is inspired by Rob Trangucci’s talk at StanCon 2017 (Trangucci 2017), where he demonstrated a multi-level GP model to predict US presidential votes.
Usually I implement my notebooks in Stan, but since I have been wanting to try NumPyro for a long time, we will be using it here for a change. Feedback and comments are welcome!
import warnings
import numpy as onp
import pandas as pd
import jax
from jax import vmap
import jax.numpy as np
import jax.random as random
import numpyro
import numpyro.distributions as nd
from numpyro.infer import MCMC, NUTS
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
import palettes
palettes.set_theme()
warnings.filterwarnings('ignore')
The data set consists of counts of votes for the two major US parties between 1976 and 2016, adapted from the data of Rob’s talk. The counts are available for every US state. We can either model these counts directly or, following Trangucci (2017), model the proportion of votes one party received in a state.
year dem rep region state proportion
0 1976 44058 71555 Mountain West AK 0.381082
1 1980 41842 86112 Mountain West AK 0.327008
2 1984 62007 138377 Mountain West AK 0.309441
3 1988 72584 119251 Mountain West AK 0.378367
4 1992 78294 102000 Mountain West AK 0.434257
.. ... ... ... ... ... ...
545 2000 60481 147947 Mountain West WY 0.290177
546 2004 70776 167629 Mountain West WY 0.296873
547 2008 82868 164958 Mountain West WY 0.334380
548 2012 69286 170962 Mountain West WY 0.288394
549 2016 55973 174419 Mountain West WY 0.242947
[550 rows x 6 columns]
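For reference, the proportion column above can be derived from the raw counts; a minimal sketch, assuming the count columns are named `dem` and `rep` as in the table (toy frame with the first two rows of the data):

```python
import pandas as pd

# toy frame with the same columns as the first rows of the data set
D = pd.DataFrame({
    "year": [1976, 1980],
    "dem": [44058, 41842],
    "rep": [71555, 86112],
})

# share of the two-party vote that went to the Democratic candidate
D["proportion"] = D["dem"] / (D["dem"] + D["rep"])
print(D["proportion"].round(6).tolist())  # [0.381082, 0.327008]
```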
The geographical locations of the states suggest a grouping per region. Below we plot the proportion of votes for the Democratic Party for every state and region for the period between 1976 and 2016.
g = sns.FacetGrid(
    D,
    col="region",
    hue="state",
    palette=palettes.discrete_sequential_colors(),
    col_wrap=4,
    sharex=False,
    sharey=False
)
_ = g.map_dataframe(
    sns.lineplot, x="year", y="proportion", style="state", marker="o"
)
_ = g.set_axis_labels("Year", "Proportion")
sns.despine(left=True)
plt.show()
The data suggests a model that considers a general national trend per election year and region, and a baseline offset for democratic votes. Within regions there is also a clear correlation between states that suggests either coregionalisation or a hierarchical prior.
Before we model the data, we implement some utility functions. We use an exponentiated quadratic covariance function throughout the case study, which can be implemented in JAX like this:
def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    # pairwise squared distances after scaling the inputs by the lengthscale
    X1_e = np.expand_dims(X1, 1) / rho
    X2_e = np.expand_dims(X2, 0) / rho
    d = np.sum((X1_e - X2_e) ** 2, axis=2)
    # sigma acts as the variance; the jitter keeps the Cholesky stable
    K = sigma * np.exp(-0.5 * d) + np.eye(d.shape[0]) * jitter
    return K
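To make the kernel’s behaviour concrete, here is the same function in plain NumPy together with a quick sanity check (a standalone sketch; the case study itself uses the JAX version above):

```python
import numpy as onp

def rbf_np(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    # pairwise squared distances after scaling by the lengthscale rho
    X1_e = onp.expand_dims(X1, 1) / rho
    X2_e = onp.expand_dims(X2, 0) / rho
    d = onp.sum((X1_e - X2_e) ** 2, axis=2)
    # sigma is the variance (not the standard deviation); the jitter on the
    # diagonal stabilizes the Cholesky decomposition used later
    return sigma * onp.exp(-0.5 * d) + onp.eye(d.shape[0]) * jitter

X = onp.arange(5.0).reshape(-1, 1)
K = rbf_np(X, X, sigma=2.0, rho=1.5)
assert onp.allclose(K, K.T)                   # symmetric
assert onp.all(onp.linalg.eigvalsh(K) > 0.0)  # positive definite
```

Note that the diagonal equals `sigma + jitter`, confirming that `sigma` parameterizes the variance directly.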
To measure the sampling time our models take, we also implement a decorator that wraps a function and times it.
def timer(func):
    from timeit import default_timer

    def f(*args, **kwargs):
        start = default_timer()
        res = func(*args, **kwargs)
        stop = default_timer()
        print(f"Elapsed time: {stop - start}")
        return res

    return f
For inference, we use Numpyro’s NUTS and add our previously defined decorator to it:
@timer
def sample(model, niter=1000):
    rng_key, rng_key_predict = random.split(random.PRNGKey(23))
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=niter,
        num_samples=niter,
        num_chains=4,
        progress_bar=False,
    )
    # the data arguments are taken from the surrounding scope
    mcmc.run(
        rng_key,
        y,
        Xu,
        n_times,
        time_idxs,
        n_states,
        state_idxs,
        n_regions,
        region_idxs,
        train_idxs,
        n_states_per_region,
    )
    return mcmc
Finally, we make predictions by sampling from the posterior predictive.
def predict(mcmc, n=5):
    samples = mcmc.get_samples()
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    vmap_args = (
        random.split(rng_key_predict, samples["nu"].shape[0]),
        samples["nu"],
        samples["eta"],
    )
    # for every posterior sample, draw n replicates from the likelihood
    # and average them
    preds_map = jax.vmap(
        lambda key, nu, eta: np.mean(
            nd.Beta(eta, nu - eta).sample(key, sample_shape=(n,)), axis=0
        )
    )
    preds = preds_map(*vmap_args)
    means = np.mean(preds, axis=0)
    quantiles = np.percentile(preds, [5.0, 95.0], axis=0)
    return means, quantiles
We implement both models by regressing the proportions on the electoral year. For that, we first convert the years to numerical values and then sort the data frame by region, state and year.
years = pd.to_datetime(D.year)
years = (years - years.min()) / pd.Timedelta(1)
D.loc[:, "year_numerical"] = years
D.loc[:, "region_idxs"] = D["region"].apply(
    lambda x: list(D.region.unique()).index(x)
)
D.loc[:, "state_idxs"] = D["state"].apply(
    lambda x: list(D.state.unique()).index(x)
)
D.loc[:, "time_idxs"] = D["year_numerical"].apply(
    lambda x: list(D.year_numerical.unique()).index(x)
)
D = D.sort_values(["region", "state", "year_numerical"])
D
year dem rep ... region_idxs state_idxs time_idxs
154 1976 1014714 1183958 ... 9 14 0
155 1980 844197 1255656 ... 9 14 1
156 1984 841481 1377230 ... 9 14 2
157 1988 860643 1297763 ... 9 14 3
158 1992 848420 989375 ... 9 14 4
.. ... ... ... ... ... ... ...
512 2000 1247652 1108864 ... 3 46 6
513 2004 1510201 1304894 ... 3 46 7
514 2008 1750848 1229216 ... 3 46 8
515 2012 1755396 1290670 ... 3 46 9
516 2016 1742718 1221747 ... 3 46 10
[550 rows x 10 columns]
X = np.array(D["year_numerical"].values).reshape(-1, 1)
Xu = np.unique(X).reshape(-1, 1)
y = np.array(D["proportion"].values)
In addition, we need to compute the indexes of the time points, states and regions, so that every observation can be assigned correctly.
time_idxs = np.array(D["time_idxs"].values)
n_times = len(np.unique(time_idxs))
state_idxs = np.array(D["state_idxs"].values)
n_states = len(np.unique(state_idxs))
region_idxs = np.array(D["region_idxs"].values)
n_regions = len(np.unique(region_idxs))
n_states_per_region = np.array(
    D.groupby(["region", "state"]).size().groupby("region").size()
)
Since we also want to compare predictive performance, for every state we treat one of its 11 observations as a test point and the other 10 as training points. For a data set of 550 observations (50 states × 11 election years), this yields 50 test points.
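The construction of `train_idxs` is not shown in the original listing; a sketch of how such a split could look, under the assumption that we hold out one election year per state (here, arbitrarily, the last one):

```python
import numpy as onp

n_states, n_times = 50, 11
# observations are sorted by state and then year, so each state's rows
# occupy a consecutive block of 11 indexes
all_idxs = onp.arange(n_states * n_times).reshape(n_states, n_times)
test_idxs = all_idxs[:, -1]                        # one held-out year per state
train_idxs = onp.setdiff1d(all_idxs.ravel(), test_idxs)

print(len(train_idxs), len(test_idxs))  # 500 50
```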
To model the data, we will need to choose a suitable likelihood. A natural choice with the same support as our data \(Y\) is the Beta distribution:
\[ P(y \mid \alpha, \beta) = \frac{1}{B(\alpha, \beta)} y^{\alpha - 1} (1 - y)^{\beta - 1} \] For regression modelling, this one is a bit awkward to use, so we use its alternative parameterization following the Stan manual:
\[ P(y \mid \mu, \kappa) = \frac{1}{B\left(\mu\kappa, (1 - \mu)\kappa\right)} y^{\mu\kappa -1} (1 - y)^{(1 - \mu)\kappa -1} \] Unfortunately, Numpyro does not have this parameterization, but this is not a problem as we can easily reparameterize.
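Since NumPyro’s `Beta` takes the \(\alpha, \beta\) parameterization, the mean–concentration version reduces to setting \(\alpha = \mu\kappa\) and \(\beta = (1 - \mu)\kappa\). A minimal check of this identity (plain Python, no NumPyro needed; the numbers are arbitrary):

```python
def beta_mean_conc(mu, kappa):
    """Map the (mean, concentration) parameters to (alpha, beta)."""
    return mu * kappa, (1.0 - mu) * kappa

alpha, beta = beta_mean_conc(mu=0.25, kappa=80.0)
# the mean of a Beta(alpha, beta) is alpha / (alpha + beta), recovering mu
assert alpha / (alpha + beta) == 0.25
print(alpha, beta)  # 20.0 60.0
```

The models below use exactly this pattern with \(\alpha = \eta\) and \(\beta = \nu - \eta\), i.e. \(\kappa = \nu\): `nd.Beta(eta, nu - eta)`.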
For Gaussian likelihoods and given a posterior sample of the parameters of the covariance function, prediction boils down to Gaussian conditioning. For non-Gaussian likelihoods, we would first infer the posterior of the latent GP \(P(f \mid y, X)\) and given this compute the posterior predictive
\[ P(y^* \mid y, X, X^*) = \int \int P(y^* \mid f^*) P(f^* \mid f, X, X^*) P(f \mid y, X) \, df \, df^* \] We can directly model this via the joint distribution of the observed (training data) and unobserved (testing data) responses, i.e.,
\[ P(y, y^* \mid X, X^*) \] Specifically, the generative model is defined on the entire set of predictors \(X, X^*\), but the likelihood only considers the observed values \(y\).
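For intuition, the Gaussian-conditioning route mentioned above can be sketched in plain NumPy (a generic illustration of the textbook formulas; the models below use the joint-modelling route instead):

```python
import numpy as onp

def gp_condition(K, K_s, K_ss, y, noise=1e-2):
    """Posterior mean and covariance of f* given y under a Gaussian likelihood.

    K:    covariance of the training inputs, K(X, X)
    K_s:  cross covariance K(X*, X)
    K_ss: covariance of the test inputs, K(X*, X*)
    """
    A = K + noise * onp.eye(K.shape[0])
    alpha = onp.linalg.solve(A, y)
    mean = K_s @ alpha
    cov = K_ss - K_s @ onp.linalg.solve(A, K_s.T)
    return mean, cov

# trivial example: one training point, one test point, strong correlation
K = onp.array([[1.0]])
K_s = onp.array([[0.9]])
K_ss = onp.array([[1.0]])
mean, cov = gp_condition(K, K_s, K_ss, y=onp.array([2.0]))
```

The posterior mean is pulled toward the observation in proportion to the cross covariance, and the posterior variance shrinks below the prior variance.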
Our first model is an adaptation of Rob’s model, with the exception that we use GPs entirely. The generative model reads:
\[\begin{aligned} y_{rst} &\sim \text{Beta}\left(\eta_{rst}, \nu - \eta_{rst} \right) \\ \eta_{rst} & = \nu * \text{logit}^{-1}\left(\mu + f_{rt} + g_{st} + h_{st} \right) \\ f_r & \sim GP(0, K(\sigma^f, \rho^f)) \\ g_{s} & \sim GP(0, K(\sigma^g, \rho^g )) \\ h_{s} & \sim GP(0, K(\sigma^h, \rho^h_{s})) \\ \nu & \sim \text{Gamma}(5, 0.01) \\ \mu & \sim \text{Normal}(0, 0.5) \end{aligned} \] The other parameters (covariance variances and lengthscales) are drawn from distributions with appropriate support. Notably, we are using a sum of three GPs per state: a regional GP with common variance and lengthscale between all regions, a state-level GP with common variance and length-scale between all states, and another state-level GP with common variance between states, but individual lengthscales for every state to account for everything that is not explained by the other two. The function below implements the generative model above:
def multilevel_model(
    y,
    Xu,
    n_times,
    time_idxs,
    n_states,
    state_idxs,
    n_regions,
    region_idxs,
    train_idxs,
    n_states_per_region,
):
    # decompose a total variance into the three GP components
    n = 3
    sigma_tot = numpyro.sample("sigma_tot", nd.Gamma(3.0, 3.0))
    sigma_prop = numpyro.sample("sigma_prop", nd.Dirichlet(np.repeat(2.0, n)))
    sigmas = n * sigma_prop * sigma_tot

    # regional GP: shared variance and lengthscale across all regions
    rho_region_gp = numpyro.sample("rho_region_gp", nd.LogNormal(0.0, 1.0))
    K_region_gp = rbf(Xu, Xu, sigmas[0], rho_region_gp)
    L_region_gp = np.linalg.cholesky(K_region_gp)
    with numpyro.plate("regions", size=n_regions):
        f_reg_tilde = numpyro.sample(
            "f_reg_tilde", nd.Normal(loc=np.zeros((Xu.shape[0], 1)))
        )
        f_reg = numpyro.deterministic("f_reg", L_region_gp @ f_reg_tilde)
    # broadcast each regional GP to the states of its region
    f_reg = np.repeat(f_reg, n_states_per_region, axis=1)
    f_reg = f_reg.T.reshape(-1)

    # state-level GP: shared variance and lengthscale across all states
    rho_state_gp = numpyro.sample("rho_state_gp", nd.LogNormal(0.0, 1.0))
    K_state_gp = rbf(Xu, Xu, sigmas[1], rho_state_gp)
    L_state_gp = np.linalg.cholesky(K_state_gp)
    with numpyro.plate("states", size=n_states):
        f_stat_tilde = numpyro.sample(
            "f_stat_tilde", nd.Normal(loc=np.zeros((Xu.shape[0], 1)))
        )
        f_stat = numpyro.deterministic("f_stat", L_state_gp @ f_stat_tilde)
    f_stat = f_stat.reshape(-1)

    # state-level GP with an individual lengthscale for every state
    with numpyro.plate("states", size=n_states):
        rho = numpyro.sample("rho", nd.LogNormal(0.0, 1.0))
        K = rbf(Xu, Xu, sigmas[2], rho)
        L = np.linalg.cholesky(K)
        f_tilde = numpyro.sample(
            "f_tilde", nd.Normal(loc=np.zeros((Xu.shape[0], 1)))
        )
        f = numpyro.deterministic("f", L @ f_tilde)
    f = f.reshape(-1)

    # reparameterized Beta likelihood on the training observations only
    nu = numpyro.sample("nu", nd.Gamma(5.0, 0.01))
    mu = numpyro.sample("mu", nd.Normal(0.0, 0.5))
    eta = numpyro.deterministic(
        "eta", nu * jax.scipy.special.expit(mu + f_reg + f_stat + f)
    )
    numpyro.sample(
        "y", nd.Beta(eta[train_idxs], nu - eta[train_idxs]), obs=y[train_idxs]
    )
Having defined the model, posterior inference is fairly easy. We also compute MCMC diagnostics to check if the posteriors mix, no divergences occur, etc.
Elapsed time: 199.51950429099998
Data variables: (12/14)
eta (eta_dim_0) float64 1.053 1.053 1.053 ... 1.054 1.054 1.054
f (f_dim_0, f_dim_1) float64 1.009 1.012 1.003 ... 1.006 1.002
f_reg (f_reg_dim_0, f_reg_dim_1) float64 1.001 1.0 ... 1.001 1.001
f_reg_tilde (f_reg_tilde_dim_0, f_reg_tilde_dim_1) float64 1.001 ... 1.0
f_stat (f_stat_dim_0, f_stat_dim_1) float64 1.008 1.016 ... 1.003
f_stat_tilde (f_stat_tilde_dim_0, f_stat_tilde_dim_1) float64 1.004 ......
... ...
nu float64 1.054
rho (rho_dim_0) float64 0.9998 0.9998 1.0 ... 1.0 1.001 1.001
rho_region_gp float64 1.002
rho_state_gp float64 1.004
sigma_prop (sigma_prop_dim_0) float64 1.008 1.052 1.039
sigma_tot float64 1.004
Data variables: (12/14)
eta (eta_dim_0) float64 75.37 75.63 75.35 ... 74.75 74.73 74.6
f (f_dim_0, f_dim_1) float64 531.1 408.5 ... 812.8 4.051e+03
f_reg (f_reg_dim_0, f_reg_dim_1) float64 2.77e+03 ... 3.26e+03
f_reg_tilde (f_reg_tilde_dim_0, f_reg_tilde_dim_1) float64 2.77e+03 .....
f_stat (f_stat_dim_0, f_stat_dim_1) float64 620.7 ... 4.531e+03
f_stat_tilde (f_stat_tilde_dim_0, f_stat_tilde_dim_1) float64 2.344e+03...
... ...
nu float64 74.02
rho (rho_dim_0) float64 8.034e+03 7.163e+03 ... 8.305e+03
rho_region_gp float64 1.267e+03
rho_state_gp float64 2.108e+03
sigma_prop (sigma_prop_dim_0) float64 353.6 111.4 130.5
sigma_tot float64 809.0
Both diagnostics look fine: the R-hats (first block) should ideally be close to one, while the effective sample sizes (second block) should be as high as possible.
Next we sample from the posterior predictive and compute its mean and quantiles.
means_multilevel, quantiles_multilevel = predict(mcmc_multilevel)
Dm = D.copy()
Dm.loc[:, "y_hat"] = onp.array(means_multilevel)
Dm.loc[:, "y_hat_lower"] = onp.array(quantiles_multilevel[0])
Dm.loc[:, "y_hat_upper"] = onp.array(quantiles_multilevel[1])
Let’s overlay the mean of the posterior predictive on the actual data.
g = sns.FacetGrid(
    Dm,
    col="region",
    hue="state",
    col_wrap=4,
    palette=palettes.discrete_diverging_colors(),
    sharex=False,
    sharey=False
)
_ = g.map_dataframe(
    sns.lineplot,
    x="year",
    y="y_hat",
    style="state",
    marker="o",
    alpha=0.5
)
_ = g.map_dataframe(
    sns.lineplot,
    x="year",
    y="proportion",
    style="state",
)
_ = g.set_axis_labels("Year", "Proportion")
sns.despine(left=True)
plt.show()
The predictions also look good. Before we turn to fitting the second model, let’s compute the average absolute error on the test points:
DeviceArray(0.036768, dtype=float32)
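The error computation itself is not shown above; it is presumably along these lines (a sketch with toy values standing in for the held-out proportions and their posterior predictive means):

```python
import numpy as onp

y_test = onp.array([0.38, 0.31, 0.43])  # held-out proportions (toy values)
y_hat = onp.array([0.36, 0.33, 0.40])   # posterior predictive means (toy values)

# mean absolute error over the test points
mae = onp.mean(onp.abs(y_hat - y_test))
```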
In the model above, we used two GPs per state: one where the lengthscale of the covariance function is shared among states and one where it is allowed to vary for every state. Alternatively, we can explicitly correlate the state-level GPs using a coregionalization covariance function. The generative model has the following form:
\[\begin{aligned} y_{rst} &\sim \text{Beta}\left(\eta_{rst}, \nu - \eta_{rst} \right) \\ \eta_{rst} & = \nu * \text{logit}^{-1}\left(\mu + f_{rt} + h_{r[s]t} \right) \\ f_r & \sim GP \left(0, K(\sigma^f, \rho^f)\right) \\ h_{r} & \sim GP\left(0, K(\sigma^h_{r}, \rho^h_{r}), C(\omega_r)\right) \\ \nu & \sim \text{Gamma}(5, 0.01) \\ \mu & \sim \text{Normal}(0, 0.5) \end{aligned} \]
Here, the sampling statement \(h_{r} \sim GP\left(0, K(\sigma^h_{r}, \rho^h_{r}), C(\omega_r)\right)\) represents a “matrix” GP, i.e., a stochastic process of which any finite sample is a matrix normal random variable with a fixed number of columns. In this case the number of columns is the cardinality of \(\{i : i \in r\}\), i.e., the number of states in region \(r\). For instance, if region \(r\) consists of six states, then a sample from \(h_r\) is always a matrix with six columns.
But how do we sample from it? Fortunately, we can reparameterize the sampling statement. We first sample a \(T \times q\) matrix of standard normals, where \(T\) is the number of time points and \(q\) is the cardinality of \(\{i : i \in r\}\). We then left-multiply it with the matrix square root of the covariance matrix \(K(1, \rho^{h}_{r})\) (its Cholesky factor), setting the covariance function’s variance parameter to \(1.0\) (since the variance is not identified), and finally right-multiply the product with the Cholesky factor of the coregionalization covariance \(C(\omega_r)\). We can implement the model like this:
def coregional_model(
    y,
    Xu,
    n_times,
    time_idxs,
    n_states,
    state_idxs,
    n_regions,
    region_idxs,
    train_idxs,
    n_states_per_region,
):
    # regional GP: shared variance and lengthscale across all regions
    rho_region_gp = numpyro.sample("rho_region_gp", nd.LogNormal(0.0, 1.0))
    sigma_region_gp = numpyro.sample("sigma_region_gp", nd.LogNormal(0.0, 1.0))
    K_region_gp = rbf(Xu, Xu, sigma_region_gp, rho_region_gp)
    L_region_gp = np.linalg.cholesky(K_region_gp)
    with numpyro.plate("regions", size=n_regions):
        f_reg_tilde = numpyro.sample(
            "f_reg_tilde", nd.Normal(loc=np.zeros((Xu.shape[0], 1)))
        )
        f_reg = numpyro.deterministic("f_reg", L_region_gp @ f_reg_tilde)
    f_reg = np.repeat(f_reg, n_states_per_region, axis=1)
    f_reg = f_reg.T.reshape(-1)

    # one matrix-normal GP per region, coregionalized over its q states
    fs = []
    for i, q in enumerate(n_states_per_region):
        rho = numpyro.sample(f"rho_{i}", nd.LogNormal(0.0, 1.0))
        K = rbf(Xu, Xu, 1.0, rho)
        L = np.linalg.cholesky(K)
        sigma = numpyro.sample(f"sigma_{i}", nd.LogNormal(np.zeros(q), 1.0))
        omega = numpyro.sample(f"omega_{i}", nd.LKJCholesky(q, 2.0))
        f_tilde = numpyro.sample(
            f"f_tilde_{i}", nd.Normal(loc=np.zeros((Xu.shape[0], q)))
        )
        f = numpyro.deterministic(
            f"f_{i}", L @ f_tilde @ (np.diag(np.sqrt(sigma)) @ omega).T
        )
        fs.append(f.reshape(-1))
    f = np.concatenate(fs)

    # reparameterized Beta likelihood on the training observations only
    nu = numpyro.sample("nu", nd.Gamma(5.0, 0.01))
    mu = numpyro.sample("mu", nd.Normal(0.0, 0.5))
    eta = numpyro.deterministic(
        "eta", nu * jax.scipy.special.expit(mu + f_reg + f)
    )
    numpyro.sample(
        "y", nd.Beta(eta[train_idxs], nu - eta[train_idxs]), obs=y[train_idxs]
    )
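The matrix-normal reparameterization used in the loop above can also be sketched standalone in plain NumPy (with fixed stand-ins for the sampled lengthscale, scales, and correlation factor):

```python
import numpy as onp

rng = onp.random.default_rng(0)
T, q = 11, 6  # time points, states in the region

# temporal covariance with unit variance (the variance is not identified here)
t = onp.arange(T, dtype=float)
K = onp.exp(-0.5 * (t[:, None] - t[None, :]) ** 2) + 1e-6 * onp.eye(T)
L = onp.linalg.cholesky(K)

# column covariance: diagonal scales times a correlation Cholesky factor
# (an LKJ draw in the model; here a fixed equicorrelated stand-in)
sigma = onp.ones(q)
omega = onp.linalg.cholesky(0.5 * onp.eye(q) + 0.5 * onp.ones((q, q)))
C_half = onp.diag(onp.sqrt(sigma)) @ omega

# matrix-normal draw: Z ~ N(0, I), F = L @ Z @ C_half.T
Z = rng.standard_normal((T, q))
F = L @ Z @ C_half.T
print(F.shape)  # (11, 6)
```

Rows of `F` are correlated via the temporal kernel, columns via the coregionalization covariance `C_half @ C_half.T`.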
As above, after sampling we compute MCMC diagnostics.
Elapsed time: 666.5036768950013
Data variables: (12/57)
eta (eta_dim_0) float64 1.562 1.556 1.563 ... 1.589 1.585 1.58
f_0 (f_0_dim_0, f_0_dim_1) float64 1.007 1.023 ... 1.06 1.195
f_1 (f_1_dim_0, f_1_dim_1) float64 1.004 1.004 ... 1.018 1.034
f_2 (f_2_dim_0, f_2_dim_1) float64 1.002 1.015 ... 1.089 1.045
f_3 (f_3_dim_0, f_3_dim_1) float64 1.001 1.002 1.0 ... 1.0 1.0
f_4 (f_4_dim_0, f_4_dim_1) float64 1.006 1.008 ... 1.002 1.004
... ...
sigma_5 (sigma_5_dim_0) float64 1.001 1.001 1.01 1.003 1.01 1.013
sigma_6 (sigma_6_dim_0) float64 1.138 1.112 1.101 1.15 1.161
sigma_7 (sigma_7_dim_0) float64 1.014 1.054 1.012 1.056 1.011
sigma_8 (sigma_8_dim_0) float64 1.0 1.004 1.001 1.003
sigma_9 (sigma_9_dim_0) float64 1.004 1.014 1.013 1.009
sigma_region_gp float64 1.009
Data variables: (12/57)
eta (eta_dim_0) float64 6.941 6.986 6.958 ... 6.802 6.808 6.842
f_0 (f_0_dim_0, f_0_dim_1) float64 1.354e+03 201.2 ... 14.25
f_1 (f_1_dim_0, f_1_dim_1) float64 1.747e+03 ... 101.4
f_2 (f_2_dim_0, f_2_dim_1) float64 1.9e+03 1.134e+03 ... 58.65
f_3 (f_3_dim_0, f_3_dim_1) float64 1.214e+03 ... 1.338e+03
f_4 (f_4_dim_0, f_4_dim_1) float64 1.191e+03 ... 2.194e+03
... ...
sigma_5 (sigma_5_dim_0) float64 2.087e+03 2.324e+03 ... 1.274e+03
sigma_6 (sigma_6_dim_0) float64 18.66 23.0 24.98 17.59 16.58
sigma_7 (sigma_7_dim_0) float64 579.6 51.71 ... 53.16 1.416e+03
sigma_8 (sigma_8_dim_0) float64 3.407e+03 2.308e+03 ... 2.751e+03
sigma_9 (sigma_9_dim_0) float64 1.542e+03 459.3 1.098e+03 1.236e+03
sigma_region_gp float64 1.016e+03
The diagnostics look fine for this model as well: R-hats are roughly one, no divergences, and sufficiently high effective sample sizes. Next we compute the means and quantiles of the posterior predictive:
means_coregional, quantiles_coregional = predict(mcmc_coregional)
Dc = D.copy()
Dc.loc[:, "y_hat"] = onp.array(means_coregional)
Dc.loc[:, "y_hat_lower"] = onp.array(quantiles_coregional[0])
Dc.loc[:, "y_hat_upper"] = onp.array(quantiles_coregional[1])
We also visualize the predictions again, overlaid on the actual data.
g = sns.FacetGrid(
    Dc,
    col="region",
    hue="state",
    col_wrap=4,
    palette=palettes.discrete_diverging_colors(),
    sharex=False,
    sharey=False
)
_ = g.map_dataframe(
    sns.lineplot,
    x="year",
    y="y_hat",
    style="state",
    marker="o",
    alpha=0.5
)
_ = g.map_dataframe(
    sns.lineplot,
    x="year",
    y="proportion",
    style="state",
)
_ = g.set_axis_labels("Year", "Proportion")
sns.despine(left=True)
plt.show()
As before, the predictions look great and are hardly different from the other model’s. Hence, let’s have a look at the average absolute error on the test data:
DeviceArray(0.0300479, dtype=float32)
The prediction is a bit better, but the runtime is significantly worse (probably due to the increased dimensionality of the variance parameters). Hence, Rob’s model seems to be the clear winner if we factor in both aspects.
The notebook is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
Trangucci, Rob. 2017. “Hierarchical Gaussian Processes in Stan.” StanCon. https://mc-stan.org/events/stancon. https://mc-stan.org/events/stancon2017-notebooks/stancon2017-trangucci-hierarchical-gps.pdf.