In a previous case study we discussed coregional GPs and their hierarchical variants for forecasting election results. Recently, several RNN-based methods have been proposed for time-series prediction (Salinas et al. 2019; Salinas et al. 2020). Hence, this notebook implements two methods based on LSTMs and compares their performance to the GP models from that case study.

The notebook uses NumPyro for probabilistic inference and Haiku as a neural network library.

import logging
import sys
import pandas as pd

import jax
import jax.numpy as np
import jax.scipy as sp
import jax.random as random

import numpyro
from numpyro.contrib.module import haiku_module
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import Trace_ELBO, SVI

import optax
import haiku as hk
from haiku._src.data_structures import FlatMapping

import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
import palettes

sns.set_style("ticks", {'font.family':'serif', 'font.serif':'Times New Roman'})
palettes.set_theme()
logging.basicConfig(level=logging.ERROR, stream=sys.stdout)

Presidential elections

As a test data set we will use the same data as in the coregional GP case study: counts of votes for the Democratic and Republican parties in US presidential elections between 1976 and 2016. For more information, please see that case study or Rob Trangucci’s StanCon talk, from which the data were taken.

D = pd.read_csv("./data/elections.csv")
D.loc[:, "proportion"] = D.dem / (D.dem + D.rep)
D
     year    dem     rep         region state  proportion
0    1976  44058   71555  Mountain West    AK    0.381082
1    1980  41842   86112  Mountain West    AK    0.327008
2    1984  62007  138377  Mountain West    AK    0.309441
3    1988  72584  119251  Mountain West    AK    0.378367
4    1992  78294  102000  Mountain West    AK    0.434257
..    ...    ...     ...            ...   ...         ...
545  2000  60481  147947  Mountain West    WY    0.290177
546  2004  70776  167629  Mountain West    WY    0.296873
547  2008  82868  164958  Mountain West    WY    0.334380
548  2012  69286  170962  Mountain West    WY    0.288394
549  2016  55973  174419  Mountain West    WY    0.242947

[550 rows x 6 columns]

Below we plot the proportion of votes for the Democratic Party for every state and region between 1976 and 2016.

g = sns.FacetGrid(
    D, 
    col="region",
    hue="state",
    palette=palettes.discrete_sequential_colors(),
    col_wrap=4,
    sharex=False, 
    sharey=False
)
_ = g.map_dataframe(
    sns.lineplot, x="year", y="proportion", style="state", markers="o"
)
_ = g.set_axis_labels("Year", "Proportion of votes")
sns.despine(left=True)
plt.show()

We will use the same kind of preprocessing as in the GP case study.

years = pd.to_datetime(D.year)
years = (years - years.min()) / pd.Timedelta(1)

D.loc[:, ("year_numerical")] = years
D.loc[:, ("region_idxs")] = D["region"].apply(
    lambda x: list(D.region.unique()).index(x)
)
D.loc[:, ("state_idxs")] = D["state"].apply(
    lambda x: list(D.state.unique()).index(x)
)
D.loc[:, ("time_idxs")] = D["year_numerical"].apply(
    lambda x: list(D.year_numerical.unique()).index(x)
)
D = D.sort_values(["region", "state", "year_numerical"])
D
     year      dem      rep  ... region_idxs state_idxs  time_idxs
154  1976  1014714  1183958  ...           9         14          0
155  1980   844197  1255656  ...           9         14          1
156  1984   841481  1377230  ...           9         14          2
157  1988   860643  1297763  ...           9         14          3
158  1992   848420   989375  ...           9         14          4
..    ...      ...      ...  ...         ...        ...        ...
512  2000  1247652  1108864  ...           3         46          6
513  2004  1510201  1304894  ...           3         46          7
514  2008  1750848  1229216  ...           3         46          8
515  2012  1755396  1290670  ...           3         46          9
516  2016  1742718  1221747  ...           3         46         10

[550 rows x 10 columns]

We want to model the proportions of votes for the Democratic candidate for every state, regressed on the year of the election.

Y = D[["state", "year_numerical", "proportion"]].pivot_table(
    index="state", values="proportion", columns="year_numerical"
)
X = np.tile(np.array(Y.columns), (Y.shape[0], 1))

Y.head()
year_numerical      0.0       4.0       8.0   ...      32.0      36.0      40.0
state                                         ...                              
AK              0.381082  0.327008  0.309441  ...  0.389352  0.426847  0.416143
AL              0.566667  0.493237  0.387366  ...  0.391091  0.387838  0.356259
AR              0.650228  0.496803  0.387724  ...  0.398283  0.378456  0.357149
AZ              0.413867  0.317879  0.328833  ...  0.456861  0.453866  0.481100
CA              0.490822  0.405291  0.417755  ...  0.622784  0.618728  0.661282

[5 rows x 11 columns]
Y = np.array(Y.values)

Implementing an LSTM

Having no experience with RNN-based sequence models whatsoever, we should first try to implement an LSTM cell from scratch to understand the math behind it a bit better. Using Haiku this is deceptively easy, since Jax computes the gradients automatically and we merely need to implement the logic of an LSTM cell. The implementation below follows the documentation of Haiku's LSTM cell and Zaremba, Sutskever, and Vinyals (2014):
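
For reference, these are the update equations implemented below. The input transformation carries a bias, the hidden-state transformation does not, and a constant \(+1\) is added to the forget-gate pre-activation, as in Haiku's LSTM:

\[\begin{aligned} i_t &= \sigma\left(W_i x_t + b_i + U_i h_{t-1}\right) \\ f_t &= \sigma\left(W_f x_t + b_f + U_f h_{t-1} + 1\right) \\ g_t &= \tanh\left(W_g x_t + b_g + U_g h_{t-1}\right) \\ o_t &= \sigma\left(W_o x_t + b_o + U_o h_{t-1}\right) \\ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ h_t &= o_t \odot \tanh\left(c_t\right) \end{aligned} \]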

class LSTM(hk.Module):
    def __init__(self, name='lstm'):
        super().__init__(name=name)
        # One linear map for the input (with bias) and one for the hidden state
        # (without bias); each produces the four gate pre-activations at once.
        self._w = hk.Linear(4, True, name="w")
        self._u = hk.Linear(4, False, name="u")

    def __call__(self, x):
        outs = [None] * x.shape[-1]
        # Initialise hidden and cell states with zeros.
        h, c = np.zeros((x.shape[0], 1)), np.zeros((x.shape[0], 1))
        for i in range(x.shape[-1]):
            o, h, c = self._call(x[:, i, None], h, c)
            outs[i] = o
        return np.hstack(outs)

    def _call(self, x_t, h_t, c_t):
        # Split the pre-activations into input, candidate, forget and output gates.
        iw, gw, fw, ow = np.split(self._w(x_t), indices_or_sections=4, axis=-1)
        iu, gu, fu, ou = np.split(self._u(h_t), indices_or_sections=4, axis=-1)

        i = jax.nn.sigmoid(iw + iu)
        # The constant 1.0 biases the forget gate towards remembering, as in Haiku.
        f = jax.nn.sigmoid(fw + fu + 1.0)
        g = np.tanh(gw + gu)
        o = jax.nn.sigmoid(ow + ou)
        c = f * c_t + i * g
        h = o * np.tanh(c)
        return h, h, c


def _lstm(x):
    module = LSTM()
    return module(x)

Let’s test this. To use the LSTM with Haiku, we need to call transform and init the model first. We can do that by supplying the first row of the matrix of time points X:

key = jax.random.PRNGKey(42)

model = hk.without_apply_rng(hk.transform(_lstm))
params = model.init(key, X[[0], :])
params
FlatMapping({
  'lstm/~/w': FlatMapping({
                'w': DeviceArray([[-0.5389954,  0.8341133, -0.8763848,  1.3341686]], dtype=float32),
                'b': DeviceArray([0., 0., 0., 0.], dtype=float32),
              }),
  'lstm/~/u': FlatMapping({
                'w': DeviceArray([[ 0.6433483 , -0.11852746,  0.88966376, -0.33986157]], dtype=float32),
              }),
})

The parameters are simply a mapping from the name of each layer to its weights. Since we have two linear layers, one of which includes a bias, params has a total of 12 free parameters.
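
As a quick sanity check, we can count the leaves of the parameter tree: two \(1 \times 4\) weight matrices plus one length-4 bias gives 12.

sum(p.size for p in jax.tree_util.tree_leaves(params))
12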

If we now call the model on a subset of the data, i.e., the first row of the matrix of time points, we get predictions for all of its time points.

model.apply(x=X[[0], :], params=params)
DeviceArray([[0.00000000e+00, 1.02649376e-01, 1.43961413e-02,
              1.56547839e-03, 1.79905328e-04, 2.08155125e-05,
              2.41001339e-06, 2.79051449e-07, 3.23111848e-08,
              3.74129749e-09, 4.33203112e-10]], dtype=float32)

Let’s compare this to Haiku. Since we used Haiku’s documentation as a basis, the model should be the same, or at least produce the same outputs given the parameters.

def _hk_lstm(x):
    module = hk.LSTM(1)    
    outs, state = hk.dynamic_unroll(module, x, module.initial_state(1))
    return outs

Haiku’s LSTM cell needs a tensor of a different shape. Recall that for our implementation we plugged in a \(1 \times 11\) vector representing one row of time points. Haiku requires us to reshape this such that the leading axis indexes the time points, i.e.:

print(X[None, [0], :].T.shape)
(11, 1, 1)
X[None, [0], :].T
DeviceArray([[[ 0.]],

             [[ 4.]],

             [[ 8.]],

             [[12.]],

             [[16.]],

             [[20.]],

             [[24.]],

             [[28.]],

             [[32.]],

             [[36.]],

             [[40.]]], dtype=float32)

We call the Haiku implementation as above:

key = jax.random.PRNGKey(42)

hk_model = hk.without_apply_rng(hk.transform(_hk_lstm))
hk_params = hk_model.init(key, X[None, [0], :].T)
hk_params
FlatMapping({
  'lstm/linear': FlatMapping({
                   'w': DeviceArray([[-0.51025134,  0.5771896 , -0.09637077, -0.49400634],
                                     [ 0.454539  ,  0.27723938, -1.2919832 , -0.24847537]],            dtype=float32),
                   'b': DeviceArray([0., 0., 0., 0.], dtype=float32),
                 }),
})

The initial parameters are unfortunately not the same as for our implementation, so we cannot directly check whether we implemented the cell correctly. We can, however, create a mapping of parameters using the values from Haiku’s initial parameter set:

w = FlatMapping({
    "w": hk_params['lstm/linear']["w"][[0], :],
    "b": hk_params['lstm/linear']["b"],
})
u = FlatMapping({
    "w": hk_params['lstm/linear']["w"][[1], :],
})

params = FlatMapping({
    'lstm/~/w': w,
    'lstm/~/u': u,
})
params
FlatMapping({
  'lstm/~/w': FlatMapping({
                'w': DeviceArray([[-0.51025134,  0.5771896 , -0.09637077, -0.49400634]], dtype=float32),
                'b': DeviceArray([0., 0., 0., 0.], dtype=float32),
              }),
  'lstm/~/u': FlatMapping({
                'w': DeviceArray([[ 0.454539  ,  0.27723938, -1.2919832 , -0.24847537]], dtype=float32),
              }),
})

Let’s compare the two implementations:

model.apply(x=X[[0], :], params=params)
DeviceArray([[0.0000000e+00, 1.3664528e-02, 1.4810549e-03, 1.0234739e-04,
              5.3378062e-06, 2.1170729e-07, 6.2559611e-09, 1.3474961e-10,
              2.0783364e-12, 2.2680715e-14, 1.7471648e-16]],            dtype=float32)
hk_model.apply(x=X[ None, [0], :].T, params=hk_params).T
DeviceArray([[[0.0000000e+00, 1.3664528e-02, 1.4810549e-03,
               1.0234739e-04, 5.3378062e-06, 2.1170729e-07,
               6.2559602e-09, 1.3474957e-10, 2.0783360e-12,
               2.2680710e-14, 1.7471644e-16]]], dtype=float32)

Great! The two implementations do the same thing.
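
To make this explicit, we can also check numerically that the two outputs above agree up to floating-point precision (an extra check, not part of the original comparison):

ours = model.apply(x=X[[0], :], params=params)
theirs = hk_model.apply(x=X[None, [0], :].T, params=hk_params).T
np.allclose(ours, theirs, atol=1e-6)  # expected to be True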

BetaLSTMs

In the following we’ll use Haiku’s implementation, as it is probably better optimized and less error-prone.

Since the proportions of the electoral votes are random variables with support on the unit interval, we’ll use a Beta distribution as observation model, which will require us to rewrite the LSTM a bit. Since we also want to model the batch of \(50\) timeseries jointly without explicitly introducing a correlation structure or hierarchy as in the GP notebook, we add an “indicator” of which timeseries we are dealing with as a covariable.

Y = D[["state", "year_numerical", "proportion"]].pivot_table(
    index="state", values="proportion", columns="year_numerical"
)

# State indicator for every observation: row s of E contains the index s.
E = np.repeat(np.arange(Y.shape[0]), Y.shape[1]).reshape(Y.shape)
X = np.tile(np.array(Y.columns), (1, Y.shape[0], 1))

# Stack years and state indicators into a (time points, states, covariables) tensor.
xs = []
for i in range(X.T.shape[1]):
    x = X.T[:, [i], :].flatten()
    e = E[i].flatten()
    xe = np.vstack([x, e]).T[:, None]
    xs.append(xe)
X = np.hstack(xs)
X.shape
(11, 50, 2)
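
As an aside, the same tensor can be built without the explicit loop by stacking the tiled years and the state indicator along the last axis (an equivalent sketch, not used below):

years_ax = np.tile(np.array(Y.columns), (Y.shape[0], 1)).T[:, :, None]
states_ax = np.tile(np.arange(Y.shape[0]), (Y.shape[1], 1))[:, :, None]
X_alt = np.concatenate([years_ax, states_ax], axis=-1)
np.allclose(X, X_alt)  # expected to be True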

Let’s have a look at a single observation:

X[:, [0], :]
DeviceArray([[[ 0.,  0.]],

             [[ 4.,  0.]],

             [[ 8.,  0.]],

             [[12.,  0.]],

             [[16.,  0.]],

             [[20.,  0.]],

             [[24.,  0.]],

             [[28.,  0.]],

             [[32.,  0.]],

             [[36.,  0.]],

             [[40.,  0.]]], dtype=float32)

So, the first axis indexes the time points, the second the observations, and the third the covariables. In the example above we show timeseries \(0\), hence the covariable also has value \(0\). Of course, we need to reshape the response matrix \(Y\) as well:

Y = np.array(Y.values)
Y = Y[None, :, :].T
Y.shape
(11, 50, 1)
Y[:, [0], :]
DeviceArray([[[0.3810817 ]],

             [[0.32700816]],

             [[0.30944088]],

             [[0.37836683]],

             [[0.4342574 ]],

             [[0.39571497]],

             [[0.32063052]],

             [[0.36773717]],

             [[0.38935214]],

             [[0.4268471 ]],

             [[0.41614345]]], dtype=float32)

We implement the Beta-LSTM as a Haiku module with two LSTM cells, two ReLUs and a linear layer at the end. In addition, we add a parameter \(\nu\) to model a constant offset and a parameter \(\kappa\) for the precision of the Beta. Specifically, we implement the following model

\[\begin{aligned} y_{st} &\sim \text{Beta}\left(\mu_{st}\kappa, (1.0 - \mu_{st})\kappa \right) \\ \mu_{st} & = \text{logit}^{-1}\left(\nu + \phi_{st} \right) \\ \end{aligned} \] where \(\phi_{st}\) is the output of the LSTM for state \(s\) at time \(t\).

class BetaLSTM(hk.Module):
    def __init__(self, name='beta_lstm'):
        super().__init__(name=name)
        self._net = hk.DeepRNN([
            hk.LSTM(40), jax.nn.relu,
            hk.LSTM(40), jax.nn.relu,
            hk.Linear(1)
        ])
        # Constant offset of the mean (on the logit scale) and log-precision of the Beta.
        self._nu = hk.get_parameter('nu', [], init=np.ones)
        self._kappa = hk.get_parameter('kappa', [], init=np.zeros)

    def __call__(self, x, pr=False):
        p = x.shape[1]
        if pr:
            print(x.shape)
            print(x)
        outs, state = hk.dynamic_unroll(self._net, x, self._net.initial_state(p))
        # Map the network output to the unit interval and parameterize the Beta.
        mu = sp.special.expit(self._nu + outs)
        kappa = np.exp(self._kappa)
        be = dist.Beta(mu * kappa, (1.0 - mu) * kappa)
        return be


def _beta_lstm(x, pr=False):
    module = BetaLSTM()
    return module(x, pr)

The model is initialized and called as above. To check if the data are provided properly, I’ve added a flag that prints the data and the shape.

beta_model = hk.without_apply_rng(hk.transform(_beta_lstm))

key = jax.random.PRNGKey(42)
params = beta_model.init(key,  X)
_ = beta_model.apply(x=X[:, [0], :], pr=True, params=params).sample(key=key)
(11, 1, 2)
[[[ 0.  0.]]

 [[ 4.  0.]]

 [[ 8.  0.]]

 [[12.  0.]]

 [[16.  0.]]

 [[20.  0.]]

 [[24.  0.]]

 [[28.  0.]]

 [[32.  0.]]

 [[36.  0.]]

 [[40.  0.]]]

The result of sampling from this distribution is an \(11 \times 1 \times 1\)-dimensional tensor.

beta_model.apply(x=X[:, [0], :], params=params).sample(key=key)
DeviceArray([[[0.9994723 ]],

             [[0.5204014 ]],

             [[0.9997094 ]],

             [[0.8825871 ]],

             [[0.97088987]],

             [[0.04735272]],

             [[0.38127446]],

             [[0.7802068 ]],

             [[0.36030933]],

             [[0.16266209]],

             [[0.11841452]]], dtype=float32)

As in the notebook using GPs, we hold out the time point with index 7 (the 2004 election) of every timeseries and use the remaining time points for training.

train_idxs = np.arange(11) != 7
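
Since the columns of the pivoted matrix are ordered chronologically, we can quickly double-check which election is held out:

sorted(D.year.unique())[7]
2004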

We fit this model with the code below. Notice that we can restrict the fit to the training set by making a prediction for the entire data and then evaluating the log probability of the observed values, i.e. the electoral vote proportions, only at the training indexes.

@jax.jit
def nll(params: hk.Params):
    # Negative log likelihood of the training time points only.
    beta = beta_model.apply(x=X, params=params)
    ll = np.sum(beta.log_prob(Y)[train_idxs, :, :])
    return -ll

@jax.jit
def update(params, opt_state):
    # One gradient step on the negative log likelihood.
    val, grads = jax.value_and_grad(nll)(params)
    updates, new_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state, val

optimizer = optax.adam(0.001)
opt_state = optimizer.init(params)

nlls = []
for step in range(10000):
    params, opt_state, val = update(params, opt_state)
    nlls.append(float(val))

Let’s have a look at the trace of losses, the negative log likelihood.

fig, _ = plt.subplots(1, 1)
ax = sns.lineplot(
  data=pd.DataFrame({"y": nlls, "x": range(len(nlls))}),
  y="y", x="x",
  color='black'
);
ax.set(xlabel="", ylabel="NLL");  
plt.show()

We can make predictions of the entire sequence like this:

beta = beta_model.apply(x=X, params=params)

Y_hat = beta.sample(key=key, sample_shape=(100,))
Y_hat = np.mean(Y_hat, axis=0)

print(Y[:, [0], :].T)
[[[0.3810817  0.32700816 0.30944088 0.37836683 0.4342574  0.39571497
   0.32063052 0.36773717 0.38935214 0.4268471  0.41614345]]]
print(Y_hat[:, [0], :].T)
[[[0.38854533 0.33375356 0.31334743 0.37708554 0.4431399  0.39919963
   0.32724532 0.35664418 0.3853164  0.42825118 0.42239124]]]

As a measure of predictive performance, we compute the mean absolute error on the test instances:

np.mean(np.abs(Y[~train_idxs, :, :] - Y_hat[~train_idxs, :, :]))
DeviceArray(0.05373552, dtype=float32)

Modelling correlations

In the GP notebook we modelled correlations between the timeseries using a hierarchical approach and coregionalisation, respectively. We can do a similar thing here by modelling the correlation between the timeseries at every time point via the covariance matrix of a low-rank multivariate Gaussian. A similar approach has been applied by Salinas et al. (2019), except that they use a copula, while here we use a latent random variable to model the correlations. Specifically, we use a Gaussian process (GP) to encode the correlation between the timeseries for every time point. A sample from the GP is then used as a latent predictor of the mean of every time point of every timeseries. The model reads like this:

\[\begin{aligned} y_{st} & \sim \text{Beta}\left(\mu_{st}\kappa, (1.0 - \mu_{st})\kappa \right) \\ \mu_{st} & = \text{logit}^{-1}\left(f_{st} \right) \\ f_{t} & \sim \mathcal{N}\left(m_t, v_t v_t^\top + \text{diag}\left(d_t\right)\right) \end{aligned} \] where the means \(m_t\), low-rank factors \(v_t\) and diagonals \(d_t\) are estimated by an LSTM.

We will fit the model variationally using NumPyro. We use an LSTM as above to compute the mean and the low-rank covariance parameters of the latent Gaussian like this:

class MultivariateLSTM(hk.Module):
    def __init__(self, name='beta_lstm'):
        super().__init__(name=name)
        self._net = hk.DeepRNN([
            hk.LSTM(40), jax.nn.relu,
            hk.LSTM(40), jax.nn.relu,
            hk.Linear(1 + 1 + 5)
        ])

    def __call__(self, x):
        p = x.shape[1]
        outs, _ = hk.dynamic_unroll(self._net, x, self._net.initial_state(p))
        # Split the outputs into the mean (1), the diagonal (1) and the rank-5 factor (5).
        mu, d, v = np.split(outs, [1, 2], axis=-1)
        d, v = np.exp(d), v[:, :, None, :]
        return mu, v, d


def _mvn_lstm(x):
    module = MultivariateLSTM()
    return module(x)


mvn_lstm = hk.transform(_mvn_lstm)

We then use this implementation to define the NumPyro model:

def model(y, x, train_idxs):    
    nn = haiku_module("nn", mvn_lstm, x=x)
    mu, v, d =  nn(x)
    f = numpyro.sample("f", dist.LowRankMultivariateNormal(mu, v, d))
    mu = numpyro.deterministic("mu", sp.special.expit(f))
    kappa = numpyro.param("kappa", 1.0, constraint=constraints.positive)
    numpyro.sample(
        "y", 
        dist.Beta(mu[train_idxs, :, :] * kappa, (1.0 - mu[train_idxs, :, :]) * kappa),
        obs=y[train_idxs, :, :]
    )

The model is fairly straightforward. We use the LSTM to predict the mean and covariance of the latent Gaussian, sample a latent variable \(f\) from it, and use this sample to compute the mean of the Beta observation model. Finally, we evaluate the log probability of the data with respect to the estimated mean on the training set.
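
As a side note on the parameterisation: dist.LowRankMultivariateNormal(mu, v, d) has covariance \(v v^\top + \text{diag}(d)\). A small standalone check with made-up shapes (not part of the model) confirms this against a dense MultivariateNormal:

check_key = jax.random.PRNGKey(0)
loc = np.zeros(3)
v_check = jax.random.normal(check_key, (3, 2))  # rank-2 covariance factor
d_check = np.ones(3)                            # diagonal component
lr = dist.LowRankMultivariateNormal(loc, v_check, d_check)
full = dist.MultivariateNormal(loc, covariance_matrix=v_check @ v_check.T + np.diag(d_check))
x_check = jax.random.normal(jax.random.PRNGKey(1), (3,))
np.allclose(lr.log_prob(x_check), full.log_prob(x_check), atol=1e-5)  # expected to be True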

Estimating this model using variational inference requires the definition of a variational distribution for the latent Gaussian \(f\). We can do this using the same LSTM as defined above.

def guide(y, x, train_idxs):    
    nn = haiku_module("nn", mvn_lstm, x=x)
    mu, v, d =  nn(x)
    numpyro.sample(
        "f", 
        dist.LowRankMultivariateNormal(mu, v, d)
    )

Thanks to NumPyro, optimizing this only takes a couple of lines of code.

optimizer = numpyro.optim.Adam(step_size=0.001)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
nsteps = 10000
svi_result = svi.run(key, nsteps, y=Y, x=X, train_idxs=train_idxs)

As before, we plot the losses. In this case the loss is the negative ELBO (before, it was the negative log likelihood).

fig, _ = plt.subplots(1, 1)
ax = sns.lineplot(
  data=pd.DataFrame({"y": svi_result.losses, "x": range(len(svi_result.losses))}),
  y="y", x="x",
  color='black'
);
ax.set(xlabel="", ylabel="Negative ELBO");
plt.show()

Let’s compare the prediction to the LSTM above:

mu, _, _ = mvn_lstm.apply(x=X, params=svi_result.params['nn$params'], rng=key)
mu = sp.special.expit(mu)
kappa = svi_result.params['kappa']

Y_hat = dist.Beta(mu * kappa, (1.0 - mu) * kappa).sample(key=key, sample_shape=(100,))
Y_hat = np.mean(Y_hat, axis=0)

print(Y[:, [0], :].T)
[[[0.3810817  0.32700816 0.30944088 0.37836683 0.4342574  0.39571497
   0.32063052 0.36773717 0.38935214 0.4268471  0.41614345]]]
print(Y_hat[:, [0], :].T)
[[[0.38623443 0.32821316 0.30966443 0.374171   0.44032004 0.39565864
   0.3224485  0.33204356 0.38139075 0.42249826 0.41476214]]]

And the error on the test set:

np.mean(np.abs(Y[~train_idxs, :, :] - Y_hat[~train_idxs, :, :]))
DeviceArray(0.0374837, dtype=float32)

It looks like this model predicts noticeably better, but both LSTM models have a worse predictive performance than the GPs. This is most likely because my experience with LSTMs is pretty much non-existent and we didn’t spend much time optimizing the hyperparameters.
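
For a quick side-by-side view, the two test MAEs reported above can be collected in a small table (the values are copied from the outputs above):

pd.DataFrame({
    "model": ["BetaLSTM", "BetaLSTM + latent low-rank Gaussian"],
    "test MAE": [0.0537, 0.0375],
})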

References

Salinas, David, Michael Bohlke-Schneider, Laurent Callot, Roberto Medico, and Jan Gasthaus. 2019. “High-Dimensional Multivariate Forecasting with Low-Rank Gaussian Copula Processes.” arXiv Preprint arXiv:1910.03002.

Salinas, David, Valentin Flunkert, Jan Gasthaus, and Tim Januschowski. 2020. “DeepAR: Probabilistic forecasting with autoregressive recurrent networks.” International Journal of Forecasting 36 (3): 1181–91.

Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. 2014. “Recurrent Neural Network Regularization.” arXiv Preprint arXiv:1409.2329.