In a previous case study we discussed coregional GPs and their hierarchical variants for forecasting election results. Recently, several RNN-based methods have been proposed for time-series prediction (Salinas et al. 2019, 2020). Hence, this notebook implements two LSTM-based methods and compares their performance to the GP models above.
The notebook uses NumPyro for probabilistic inference as well as Haiku as a neural network library.
import logging
import sys
import pandas as pd
import jax
import jax.numpy as np
import jax.scipy as sp
import jax.random as random
import numpyro
from numpyro.contrib.module import haiku_module
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import Trace_ELBO, SVI
import optax
import haiku as hk
from haiku._src.data_structures import FlatMapping
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
import palettes
sns.set_style("ticks", {'font.family':'serif', 'font.serif':'Times New Roman'})
palettes.set_theme()
logging.basicConfig(level=logging.ERROR, stream=sys.stdout)
As a test data set we will use the same data as in the coregional GP case study: counts of votes for the two major US parties between 1976 and 2016. For more information, see the case study or Rob Trangucci’s StanCon talk, from which the data set was taken.
year dem rep region state proportion
0 1976 44058 71555 Mountain West AK 0.381082
1 1980 41842 86112 Mountain West AK 0.327008
2 1984 62007 138377 Mountain West AK 0.309441
3 1988 72584 119251 Mountain West AK 0.378367
4 1992 78294 102000 Mountain West AK 0.434257
.. ... ... ... ... ... ...
545 2000 60481 147947 Mountain West WY 0.290177
546 2004 70776 167629 Mountain West WY 0.296873
547 2008 82868 164958 Mountain West WY 0.334380
548 2012 69286 170962 Mountain West WY 0.288394
549 2016 55973 174419 Mountain West WY 0.242947
[550 rows x 6 columns]
Below we plot the proportion of votes for the democratic party for every state and region for the period between 1976 and 2016.
g = sns.FacetGrid(
D,
col="region",
hue="state",
palette=palettes.discrete_sequential_colors(),
col_wrap=4,
sharex=False,
sharey=False
)
_ = g.map_dataframe(
sns.lineplot, x="year", y="proportion", style="state", markers="o"
)
_ = g.set_axis_labels("Year", "Proportion")
sns.despine(left=True)
plt.show()
We will use the same kind of preprocessing as in the GP case study.
years = pd.to_datetime(D.year)
years = (years - years.min()) / pd.Timedelta(1)
D.loc[:, ("year_numerical")] = years
D.loc[:, ("region_idxs")] = D["region"].apply(
lambda x: list(D.region.unique()).index(x)
)
D.loc[:, ("state_idxs")] = D["state"].apply(
lambda x: list(D.state.unique()).index(x)
)
D.loc[:, ("time_idxs")] = D["year_numerical"].apply(
lambda x: list(D.year_numerical.unique()).index(x)
)
D = D.sort_values(["region", "state", "year_numerical"])
D
year dem rep ... region_idxs state_idxs time_idxs
154 1976 1014714 1183958 ... 9 14 0
155 1980 844197 1255656 ... 9 14 1
156 1984 841481 1377230 ... 9 14 2
157 1988 860643 1297763 ... 9 14 3
158 1992 848420 989375 ... 9 14 4
.. ... ... ... ... ... ... ...
512 2000 1247652 1108864 ... 3 46 6
513 2004 1510201 1304894 ... 3 46 7
514 2008 1750848 1229216 ... 3 46 8
515 2012 1755396 1290670 ... 3 46 9
516 2016 1742718 1221747 ... 3 46 10
[550 rows x 10 columns]
We want to model the proportions of votes for the democratic candidate for every state regressed on the year of the election.
Y = D[["state", "year_numerical", "proportion"]].pivot_table(
index="state", values="proportion", columns="year_numerical"
)
X = np.tile(np.array(Y.columns), (Y.shape[0], 1))
Y.head()
year_numerical 0.0 4.0 8.0 ... 32.0 36.0 40.0
state ...
AK 0.381082 0.327008 0.309441 ... 0.389352 0.426847 0.416143
AL 0.566667 0.493237 0.387366 ... 0.391091 0.387838 0.356259
AR 0.650228 0.496803 0.387724 ... 0.398283 0.378456 0.357149
AZ 0.413867 0.317879 0.328833 ... 0.456861 0.453866 0.481100
CA 0.490822 0.405291 0.417755 ... 0.622784 0.618728 0.661282
[5 rows x 11 columns]
Having no experience with RNN-based sequence models whatsoever, we should first implement an LSTM cell from scratch to understand the math behind it a bit better. Using Haiku this seems deceptively easy: Jax computes gradients automatically, so we merely need to implement the logic of an LSTM cell. The implementation below follows the documentation of an LSTM cell from Haiku and Zaremba, Sutskever, and Vinyals (2014):
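Concretely, the cell computes the following gate updates (the constant \(1\) in the forget gate is the usual forget-gate bias initialization, as in Haiku):

\[\begin{aligned} i_t &= \sigma\left(W_i x_t + U_i h_{t-1}\right) \\ f_t &= \sigma\left(W_f x_t + U_f h_{t-1} + 1\right) \\ g_t &= \tanh\left(W_g x_t + U_g h_{t-1}\right) \\ o_t &= \sigma\left(W_o x_t + U_o h_{t-1}\right) \\ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ h_t &= o_t \odot \tanh\left(c_t\right) \end{aligned} \]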
class LSTM(hk.Module):
def __init__(self, name='lstm'):
super().__init__(name=name)
self._w = hk.Linear(4, True, name="w")
self._u = hk.Linear(4, False, name="u")
def __call__(self, x):
outs = [None] * x.shape[-1]
h, c = np.zeros((x.shape[0], 1)), np.zeros((x.shape[0], 1))
for i in range(x.shape[-1]):
o, h, c = self._call(x[:, i, None], h, c)
outs[i] = o
return np.hstack(outs)
def _call(self, x_t, h_t, c_t):
iw, gw, fw, ow = np.split(self._w(x_t), indices_or_sections=4, axis=-1)
iu, gu, fu, ou = np.split(self._u(h_t), indices_or_sections=4, axis=-1)
i = jax.nn.sigmoid(iw + iu)
f = jax.nn.sigmoid(fw + fu + 1.0)
g = np.tanh(gw + gu)
o = jax.nn.sigmoid(ow + ou)
c = f * c_t + i * g
h = o * np.tanh(c)
return h, h, c
def _lstm(x):
module = LSTM()
return module(x)
Let’s test this. To use the LSTM with Haiku, we need to call transform and init the model first. We can do that by supplying the first row of the matrix of time points X:
key = jax.random.PRNGKey(42)
model = hk.without_apply_rng(hk.transform(_lstm))
params = model.init(key, X[[0], :])
params
FlatMapping({
'lstm/~/w': FlatMapping({
'w': DeviceArray([[-0.5389954, 0.8341133, -0.8763848, 1.3341686]], dtype=float32),
'b': DeviceArray([0., 0., 0., 0.], dtype=float32),
}),
'lstm/~/u': FlatMapping({
'w': DeviceArray([[ 0.6433483 , -0.11852746, 0.88966376, -0.33986157]], dtype=float32),
}),
})
The parameters are simply a mapping from the name of a layer to a matrix of weights. Since we have two linear layers, one of which includes a bias, params has a total of 12 free parameters.
If we now call the model on a subset of the data, i.e., the first row of the matrix of time points, we get predictions for all of its time points.
DeviceArray([[0.00000000e+00, 1.02649376e-01, 1.43961413e-02,
1.56547839e-03, 1.79905328e-04, 2.08155125e-05,
2.41001339e-06, 2.79051449e-07, 3.23111848e-08,
3.74129749e-09, 4.33203112e-10]], dtype=float32)
Let’s compare this to Haiku. Since we used Haiku’s documentation as a basis, the model should be the same, or at least produce the same outputs given the same parameters.
def _hk_lstm(x):
module = hk.LSTM(1)
outs, state = hk.dynamic_unroll(module, x, module.initial_state(1))
return outs
Haiku’s LSTM cell needs a tensor of a different shape. Recall that for our implementation we plugged in a \(1 \times 11\) matrix, representing one row of time points. Haiku requires us to reshape this such that the leading axis indexes the time points, i.e.:
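For instance, the first row of time points can be brought into this shape as follows (a sketch that reconstructs the row of \(X\) instead of reusing the notebook’s matrix):

```python
import jax.numpy as np

# The time points 0, 4, ..., 40 of the first row of X.
x = np.arange(0.0, 44.0, 4.0)
# Reshape to (time steps, batch size, features) as Haiku expects.
x = x[:, None, None]
x.shape
```
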
(11, 1, 1)
DeviceArray([[[ 0.]],
[[ 4.]],
[[ 8.]],
[[12.]],
[[16.]],
[[20.]],
[[24.]],
[[28.]],
[[32.]],
[[36.]],
[[40.]]], dtype=float32)
We call the Haiku implementation as above:
key = jax.random.PRNGKey(42)
hk_model = hk.without_apply_rng(hk.transform(_hk_lstm))
hk_params = hk_model.init(key, X[None, [0], :].T)
hk_params
FlatMapping({
'lstm/linear': FlatMapping({
'w': DeviceArray([[-0.51025134, 0.5771896 , -0.09637077, -0.49400634],
[ 0.454539 , 0.27723938, -1.2919832 , -0.24847537]], dtype=float32),
'b': DeviceArray([0., 0., 0., 0.], dtype=float32),
}),
})
The initial parameters are unfortunately not the same as for our implementation, so we cannot directly check whether we implemented the cell correctly. We can, however, create a mapping of parameters using the values from Haiku’s initial parameter set:
w = FlatMapping({
"w": hk_params['lstm/linear']["w"][[0], :],
'b': hk_params['lstm/linear']["b"]
})
u = FlatMapping({
"w": hk_params['lstm/linear']["w"][[1], :],
})
params = FlatMapping({
'lstm/~/w': w,
'lstm/~/u': u,
})
params
FlatMapping({
'lstm/~/w': FlatMapping({
'w': DeviceArray([[-0.51025134, 0.5771896 , -0.09637077, -0.49400634]], dtype=float32),
'b': DeviceArray([0., 0., 0., 0.], dtype=float32),
}),
'lstm/~/u': FlatMapping({
'w': DeviceArray([[ 0.454539 , 0.27723938, -1.2919832 , -0.24847537]], dtype=float32),
}),
})
Let’s compare the two implementations:
DeviceArray([[0.0000000e+00, 1.3664528e-02, 1.4810549e-03, 1.0234739e-04,
5.3378062e-06, 2.1170729e-07, 6.2559611e-09, 1.3474961e-10,
2.0783364e-12, 2.2680715e-14, 1.7471648e-16]], dtype=float32)
DeviceArray([[[0.0000000e+00, 1.3664528e-02, 1.4810549e-03,
1.0234739e-04, 5.3378062e-06, 2.1170729e-07,
6.2559602e-09, 1.3474957e-10, 2.0783360e-12,
2.2680710e-14, 1.7471644e-16]]], dtype=float32)
Great! The two implementations produce the same outputs (up to floating-point precision).
In the following we’ll use Haiku’s implementation, as it is probably better optimized and less error-prone.
Since the proportions of the electoral votes are random variables with support on the unit interval, we’ll use a Beta distribution as observation model, which will require us to rewrite the LSTM a bit. Since we also want to model a batch of \(50\) timeseries jointly without explicitly introducing some correlation structure or hierarchy as in the GP notebook, we will add an “indicator” of which timeseries we are dealing with as a covariable.
Y = D[["state", "year_numerical", "proportion"]].pivot_table(
index="state", values="proportion", columns="year_numerical"
)
E = np.repeat(np.arange(Y.shape[0]), Y.shape[1]).reshape(Y.shape)
X = np.tile(np.array(Y.columns), (1, Y.shape[0], 1))
xs = []
for i in range(X.T.shape[1]):
x = X.T[:, [i], :].flatten()
e = E[i].flatten()
xe = np.vstack([x, e]).T[:, None]
xs.append(xe)
X = np.hstack(xs)
X.shape
(11, 50, 2)
Let’s have a look at a single observation:
DeviceArray([[[ 0., 0.]],
[[ 4., 0.]],
[[ 8., 0.]],
[[12., 0.]],
[[16., 0.]],
[[20., 0.]],
[[24., 0.]],
[[28., 0.]],
[[32., 0.]],
[[36., 0.]],
[[40., 0.]]], dtype=float32)
So, the first axis indexes the time points, the second the observations, and the third the covariables. In the example above we show timeseries \(0\), hence the covariable also has value \(0\). Of course, we need to reshape the response matrix \(Y\) as well:
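The reshape can be sketched like this, with a random matrix standing in for the pivot table of proportions:

```python
import jax
import jax.numpy as np

# Stand-in for the (states x years) pivot table of proportions.
Y_pivot = jax.random.uniform(jax.random.PRNGKey(0), (50, 11))
# Transpose to time-major order and add a trailing feature axis.
Y = Y_pivot.T[:, :, None]
Y.shape
```
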
(11, 50, 1)
DeviceArray([[[0.3810817 ]],
[[0.32700816]],
[[0.30944088]],
[[0.37836683]],
[[0.4342574 ]],
[[0.39571497]],
[[0.32063052]],
[[0.36773717]],
[[0.38935214]],
[[0.4268471 ]],
[[0.41614345]]], dtype=float32)
We implement the Beta-LSTM as a Haiku module with two LSTM cells, two ReLU activations and a linear layer at the end. In addition, we add a parameter \(\nu\) to model a constant offset and a parameter \(\kappa\) for the precision of the Beta. Specifically, we implement the following model
\[\begin{aligned} y_{st} &\sim \text{Beta}\left(\mu_{st}\kappa, (1.0 - \mu_{st})\kappa \right) \\ \mu_{st} & = \text{logit}^{-1}\left(\nu + \phi_{st} \right) \\ \end{aligned} \] where \(\phi_{st}\) is the output of the LSTM for state \(s\) at time \(t\).
class BetaLSTM(hk.Module):
def __init__(self, name='beta_lstm'):
super().__init__(name=name)
self._net = hk.DeepRNN([
hk.LSTM(40), jax.nn.relu,
hk.LSTM(40), jax.nn.relu,
hk.Linear(1)
])
self._nu = hk.get_parameter('nu', [], init=np.ones)
self._kappa = hk.get_parameter('kappa', [], init=np.zeros)
def __call__(self, x, pr=False):
p = x.shape[1]
if pr:
print(x.shape)
print(x)
outs, state = hk.dynamic_unroll(self._net, x, self._net.initial_state(p))
mu = sp.special.expit(self._nu + outs)
kappa = np.exp(self._kappa)
be = dist.Beta(mu * kappa, (1.0 - mu) * kappa)
return be
def _beta_lstm(x, pr=False):
module = BetaLSTM()
return module(x, pr)
The model is initialized and called as above. To check that the data are passed in properly, I’ve added a flag that prints the data and their shape.
beta_model = hk.without_apply_rng(hk.transform(_beta_lstm))
key = jax.random.PRNGKey(42)
params = beta_model.init(key, X)
_ = beta_model.apply(x=X[:, [0], :], pr=True, params=params).sample(key=key)
(11, 1, 2)
[[[ 0. 0.]]
[[ 4. 0.]]
[[ 8. 0.]]
[[12. 0.]]
[[16. 0.]]
[[20. 0.]]
[[24. 0.]]
[[28. 0.]]
[[32. 0.]]
[[36. 0.]]
[[40. 0.]]]
The result is an \(11 \times 1 \times 1\)-dimensional tensor.
DeviceArray([[[0.9994723 ]],
[[0.5204014 ]],
[[0.9997094 ]],
[[0.8825871 ]],
[[0.97088987]],
[[0.04735272]],
[[0.38127446]],
[[0.7802068 ]],
[[0.36030933]],
[[0.16266209]],
[[0.11841452]]], dtype=float32)
As in the notebook using GPs, we use all but the 7th time point of every timeseries for training.
We fit this model with the code below. Note that we make a prediction on the entire data set and then evaluate the log probability of the observed values, i.e., the electoral vote proportions, only at the training indexes.
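The definition of the index sets is not shown in the original; a minimal definition consistent with the description above (the 7th time point, i.e. index \(6\), is held out) could be:

```python
import jax.numpy as np

time_idxs = np.arange(11)               # 11 elections between 1976 and 2016
train_idxs = time_idxs[time_idxs != 6]  # all but the 7th time point
test_idxs = time_idxs[time_idxs == 6]   # the held-out 7th time point
```
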
@jax.jit
def nll(params: hk.Params):
beta = beta_model.apply(x=X, params=params)
ll = np.sum(beta.log_prob(Y)[train_idxs, :, :])
return -ll
@jax.jit
def update(params, opt_state):
val, grads = jax.value_and_grad(nll)(params)
updates, new_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
return new_params, new_state, val
optimizer = optax.adam(0.001)
opt_state = optimizer.init(params)
nlls = []
for step in range(10000):
params, opt_state, val = update(params, opt_state)
nlls.append(float(val))
Let’s have a look at the trace of losses, the negative log likelihood.
fig, _ = plt.subplots(1, 1)
ax = sns.lineplot(
data=pd.DataFrame({"y": nlls, "x": range(len(nlls))}),
y="y", x="x",
color='black'
);
ax.set(xlabel="", ylabel="NLL");
plt.show()
We can make predictions of the entire sequence like this:
beta = beta_model.apply(x=X, params=params)
Y_hat = beta.sample(key=key, sample_shape=(100,))
Y_hat = np.mean(Y_hat, axis=0)
print(Y[:, [0], :].T)
[[[0.3810817 0.32700816 0.30944088 0.37836683 0.4342574 0.39571497
0.32063052 0.36773717 0.38935214 0.4268471 0.41614345]]]
[[[0.38854533 0.33375356 0.31334743 0.37708554 0.4431399 0.39919963
0.32724532 0.35664418 0.3853164 0.42825118 0.42239124]]]
As a measure of predictive performance, we compute the mean absolute error on the test instances:
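The error computation itself is not shown; a small helper consistent with the description (MAE restricted to the held-out time points, with `test_idxs` as sketched earlier) might look like:

```python
import jax.numpy as np

def mae_on(idxs, y_hat, y):
    """Mean absolute error restricted to the time points in idxs."""
    return np.mean(np.abs(y_hat - y)[idxs, :, :])

# In the notebook: mae_on(test_idxs, Y_hat, Y)
```
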
DeviceArray(0.05373552, dtype=float32)
In the GP notebook we modelled correlations between the timeseries using a hierarchical approach and coregionalisation, respectively. We can do a similar thing here by modelling the correlation per time point using the covariance matrix of a low-rank multivariate Gaussian. A similar approach has been taken by Salinas et al. (2019), except that they apply a copula, while here we use a latent random variable to model the correlations. Specifically, we use a Gaussian process (GP) to encode the correlation between the timeseries at every time point. A sample from the GP is then used as a latent predictor of the mean of every time point of every timeseries. The model reads like this:
\[\begin{aligned} y_{st} & \sim \text{Beta}\left(\mu_{st}\kappa, (1.0 - \mu_{st})\kappa \right) \\ \mu_{st} & = \text{logit}^{-1}\left(f_{st} \right) \\ f_{t} & \sim GP(m_t, d_t, v_t) \end{aligned} \] where \(m\), \(d\) and \(v\) are parameters estimated by an LSTM.
We will fit the model variationally using NumPyro. We use the LSTM from above to compute the mean and covariance of the GP like this:
class MultivariateLSTM(hk.Module):
def __init__(self, name='mvn_lstm'):
super().__init__(name=name)
self._net = hk.DeepRNN([
hk.LSTM(40), jax.nn.relu,
hk.LSTM(40), jax.nn.relu,
hk.Linear(1 + 1 + 5)
])
def __call__(self, x):
p = x.shape[1]
outs, _ = hk.dynamic_unroll(self._net, x, self._net.initial_state(p))
mu, d, v = np.split(outs, [1, 2], axis=-1)
d, v = np.exp(d), v[:, :, None, :]
return mu, v, d
def _mvn_lstm(x):
module = MultivariateLSTM()
return module(x)
mvn_lstm = hk.transform(_mvn_lstm)
We then use this implementation to define the NumPyro model:
def model(y, x, train_idxs):
nn = haiku_module("nn", mvn_lstm, x=x)
mu, v, d = nn(x)
f = numpyro.sample("f", dist.LowRankMultivariateNormal(mu, v, d))
mu = numpyro.deterministic("mu", sp.special.expit(f))
kappa = numpyro.param("kappa", 1.0, constraint=constraints.positive)
numpyro.sample(
"y",
dist.Beta(mu[train_idxs, :, :] * kappa, (1.0 - mu[train_idxs, :, :]) * kappa),
obs=y[train_idxs, :, :]
)
The model is fairly straightforward. We use the LSTM to predict the mean and covariance of the GP, then sample a random variable from the GP, and then use this sample to compute the mean of the Beta which we use as observation model. Finally, we compute the log probability of the data with respect to the estimated mean on the training set.
Estimating this using variational inference requires defining a variational distribution for the latent GP. We can do this using the same LSTM as defined above.
def guide(y, x, train_idxs):
nn = haiku_module("nn", mvn_lstm, x=x)
mu, v, d = nn(x)
numpyro.sample(
"f",
dist.LowRankMultivariateNormal(mu, v, d)
)
Thanks to NumPyro, this can be optimized with a couple of lines of code.
optimizer = numpyro.optim.Adam(step_size=0.001)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
nsteps = 10000
svi_result = svi.run(key, nsteps, y=Y, x=X, train_idxs=train_idxs)
As before we plot the losses. In this case we optimize the ELBO (before it was the negative log likelihood).
fig, _ = plt.subplots(1, 1)
ax = sns.lineplot(
data=pd.DataFrame({"y": svi_result.losses, "x": range(len(svi_result.losses))}),
y="y", x="x",
color='black'
);
ax.set(xlabel="", ylabel="ELBO");
plt.show()
Let’s compare the prediction to the LSTM above:
mu, _, _ = mvn_lstm.apply(x=X, params=svi_result.params['nn$params'], rng=key)
mu = sp.special.expit(mu)
kappa = svi_result.params['kappa']
Y_hat = dist.Beta(mu * kappa, (1.0 - mu) * kappa).sample(key=key, sample_shape=(100,))
Y_hat = np.mean(Y_hat, axis=0)
print(Y[:, [0], :].T)
[[[0.3810817 0.32700816 0.30944088 0.37836683 0.4342574 0.39571497
0.32063052 0.36773717 0.38935214 0.4268471 0.41614345]]]
[[[0.38623443 0.32821316 0.30966443 0.374171 0.44032004 0.39565864
0.3224485 0.33204356 0.38139075 0.42249826 0.41476214]]]
And the error on the test set:
DeviceArray(0.0374837, dtype=float32)
It looks like this model performs considerably better, but both have a worse predictive performance than the GPs. This is most likely because my experience with LSTMs is pretty much non-existent and we didn’t spend much time optimizing the hyperparameters.
The notebook is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
Salinas, David, Michael Bohlke-Schneider, Laurent Callot, Roberto Medico, and Jan Gasthaus. 2019. “High-Dimensional Multivariate Forecasting with Low-Rank Gaussian Copula Processes.” arXiv Preprint arXiv:1910.03002.
Salinas, David, Valentin Flunkert, Jan Gasthaus, and Tim Januschowski. 2020. “DeepAR: Probabilistic forecasting with autoregressive recurrent networks.” International Journal of Forecasting 36 (3): 1181–91.
Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. 2014. “Recurrent Neural Network Regularization.” arXiv Preprint arXiv:1409.2329.