Gaussian processes (GPs) offer an elegant framework for non-parametric Bayesian regression by placing a prior distribution over a function space. In this notebook we derive the basics of how they can be used for regression and machine learning. Most of the material is based on Rasmussen and Williams (2006), but I also recommend Betancourt (2020) as a resource. Throughout this notebook, we will use Stan to fit GPs. Feedback and comments are welcome!

suppressMessages({
  library(tidyverse)
  library(ggthemes)
  library(colorspace)

  library(rstan)
  library(bayesplot)
})

set.seed(42)
color_scheme_set("darkgray")

Priors over functions

Bayesian linear regression models assume a dependency

\[\begin{align*} f_{\boldsymbol \beta}& : \ \mathcal{X} \rightarrow \mathcal{Y},\\ f_{\boldsymbol \beta}(\mathbf{x}) & = \ \mathbf{x}^T \boldsymbol \beta + \epsilon, \end{align*}\]

parametrized by a coefficient vector \(\boldsymbol \beta\). In order to quantify the uncertainty we have about the parameters, we put a prior distribution on \(\boldsymbol \beta\). When we use Gaussian processes, we instead put a prior on the function \(f\) itself:

\[\begin{align*} f(\mathbf{x}) & \sim \mathcal{GP}(m(\mathbf{x}), k(\mathbf{x}, \mathbf{x}')) \end{align*}\]

So a Gaussian process is a distribution over functions. It is parameterized by a mean function \(m\) and a covariance function (or kernel) \(k\); evaluated at \(n\) input points, the mean function yields a vector of length \(n\) and the covariance function a matrix of dimension \(n \times n\). For instance, the mean function could be a constant (which we will assume throughout the rest of this notebook), and the kernel could be an exponentiated quadratic, which is defined as:

\[\begin{align*} k(\mathbf{x}, \mathbf{x}') &= \alpha^2 \exp\left(- \frac{1}{2\rho^2} ||\mathbf{x} - \mathbf{x}' ||^2 \right) \end{align*}\]

where \(\alpha\) and \(\rho\) are hyperparameters.
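
To make this concrete, here is a minimal R sketch of the exponentiated quadratic kernel for one-dimensional inputs (the helper name cov_exp_quad_r is made up for illustration; Stan provides the equivalent cov_exp_quad built in, which we use below):

# evaluate the exponentiated quadratic kernel for all pairs of points in x and y
cov_exp_quad_r <- function(x, y = x, alpha = 1, rho = 1) {
  sq.dist <- outer(x, y, function(a, b) (a - b)^2)
  alpha^2 * exp(-sq.dist / (2 * rho^2))
}

# a 3 x 3 covariance matrix for three one-dimensional inputs
cov_exp_quad_r(c(-1, 0, 1), alpha = 1, rho = 0.1)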

Sampling from a GP prior

To sample from a GP we merely need to create points \(x_i \in \mathcal{X}\) from some domain \(\mathcal{X}\) and specify the hyperparameters of the covariance function.

n.star <- 1000L
x.star <- seq(-1, 1, length.out=n.star)

alpha <- 1
rho <- .1

We specify the covariance function using Stan.

prior.model <- "_models/gp_prior.stan"
cat(readLines(prior.model), sep = "\n")
data {
  int<lower=1> n;
  real x[n];

  real<lower=0> alpha;
  real<lower=0> rho;
}

transformed data {
  matrix[n, n] K = cov_exp_quad(x, alpha, rho)
      + diag_matrix(rep_vector(1e-10, n));
  matrix[n, n] L_K = cholesky_decompose(K);
}

parameters {}
model {}

generated quantities {
  vector[n] f = multi_normal_cholesky_rng(rep_vector(0, n), L_K);
}
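
The generated quantities block simply draws from a multivariate normal with covariance \(K\) via its Cholesky factor. As a sanity check, the same kind of draw can be produced directly in R (a sketch, assuming the illustrative cov_exp_quad_r helper from above):

# one draw from the GP prior: f = L z with z ~ N(0, I), where K = L L'
K.prior <- cov_exp_quad_r(x.star, alpha = alpha, rho = rho) + diag(1e-10, n.star)
L.prior <- t(chol(K.prior))                  # chol() returns the upper-triangular factor
f.prior <- as.vector(L.prior %*% rnorm(n.star))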

Having all components set up, we sample five realizations from the prior using Stan.

prior <- stan(
  prior.model, 
  data=list(n=n.star, x=x.star, alpha=alpha, rho=rho),
  iter=5,
  warmup=0,
  chains=1,
  algorithm="Fixed_param"
)
prior <- extract(prior)$f

Let’s plot the prior samples. Every line in the plot below represents one realization of the GP prior.

data.frame(x=x.star, f=t(prior)) %>%
  tidyr::pivot_longer(starts_with("f"), names_to="sample", values_to = "f") %>%
  ggplot() +
  geom_line(aes(x, f, color=sample)) +
  scale_color_discrete_sequential(l1 = 1, l2 = 60) +
  theme_tufte() +
  theme(
    axis.text = element_text(colour = "black", size = 15),
    strip.text = element_text(colour = "black", size = 15)
  ) +
  xlab(NULL) +
  ylab(NULL) +
  guides(color=FALSE)

Sampling from a GP posterior

To make use of GPs for regression, we model the conditional distribution of a random variable \(Y\) (for which we observe \(n\) data points) as

\[\begin{align*} Y \mid f \sim \mathcal{N}(f, \Sigma) \end{align*}\]

where \(\Sigma = \sigma^2 \mathbf{I}\). Since we assume both \(Y\) and every finite realization of \(f\) to be Gaussian, \(Y\) and \(f\) are also jointly Gaussian

\[\begin{align*} \left[ \begin{array}{c} \mathbf{y} \\ f \end{array} \right] \sim \mathcal{N} \left(\mathbf{0}, \left[ \begin{array}{cc} k(\mathbf{x}, \mathbf{x}')+ \Sigma & k(\mathbf{x}, \mathbf{x}') \\ k(\mathbf{x}, \mathbf{x}') & k(\mathbf{x}, \mathbf{x}') \end{array} \right] \right) \end{align*}\]

Conditioning on \(\mathbf{y}\) gives:

\[\begin{align*} f \mid \mathbf{y}, \mathbf{x} & \sim \mathcal{GP}\left(\tilde{m}(\mathbf{x}), \tilde{k}({\mathbf{x}}, {\mathbf{x}}')\right) \end{align*}\]

where the posterior mean function \(\tilde{m}(\mathbf{x})\) is specified as

\[\begin{align*} \tilde{m}(\mathbf{x}) & = k({\mathbf{x}}, \mathbf{x}')\left( k(\mathbf{x}, \mathbf{x}') + \Sigma \right)^{-1} \mathbf{y} \end{align*}\]

and the posterior covariance function \(\tilde{k}(\mathbf{x}, \mathbf{x}')\)

\[\begin{align*} \tilde{k}(\mathbf{x}, \mathbf{x}') & = k({\mathbf{x}}, {\mathbf{x}}') - k({\mathbf{x}}, \mathbf{x}') \left( k(\mathbf{x}, \mathbf{x}') + \Sigma \right)^{-1} k(\mathbf{x}, \mathbf{x}') \end{align*}\]

So the posterior is again a Gaussian process with modified mean and covariance functions. This would be straightforward to compute directly in R, but let’s compute it in Stan again instead.

First, we create a set of observations \(\mathbf{y}\). We can create such a set, for instance, by taking a sample from the prior and adding noise to it:

sigma <- 0.1

n <- 30L
idxs <- sort(sample(seq_along(prior[1, ]), n, replace=FALSE))

x <- x.star[idxs]
f <- prior[1, idxs]
y <- f + rnorm(n, 0, sigma)
D <- data.frame(y=y, x=x)

ggplot(D) +
  geom_point(aes(x, y), size=1) +
  theme_tufte() +
  theme(
    axis.text = element_text(colour = "black", size = 15),
    strip.text = element_text(colour = "black", size = 15)
  ) +
  xlab(NULL) +
  ylab(NULL)

The model file to sample from the posterior can be found below.

posterior.model <- "_models/gp_posterior.stan"
cat(readLines(posterior.model), sep = "\n")
data {
  int<lower=1> n;
  real x[n];
  vector[n] y;
  int<lower=1> n_star;
  real x_star[n_star];

  real<lower=0> alpha;
  real<lower=0> rho;
  real<lower=0> sigma;
}

parameters {}
model {}

generated quantities {
  vector[n_star] f_star;
  {
    matrix[n, n] K =  cov_exp_quad(x, alpha, rho)
        + diag_matrix(rep_vector(1e-10, n));
    matrix[n_star, n] K_star =  cov_exp_quad(x_star, x, alpha, rho);
    matrix[n_star, n_star] K_star_star =  cov_exp_quad(x_star, alpha, rho)
        + diag_matrix(rep_vector(1e-10, n_star));

    matrix[n, n] K_sigma = K
        + diag_matrix(rep_vector(square(sigma), n));
    matrix[n, n] K_sigma_inv = inverse(K_sigma);

    f_star = multi_normal_rng(
      K_star * K_sigma_inv * y,
      K_star_star - (K_star * K_sigma_inv * K_star')
    );
  }
}

This time we draw 1000 samples in order to get a good estimate of the posterior quantiles (we could of course compute the posterior variance analytically, but taking quantiles of the samples works just as well). More specifically, Stan computes the posterior mean and covariance and then samples from a multivariate normal.

posterior <- stan(
  posterior.model, 
  data=list(n=n, x=x, y=y, 
            n_star=n.star, x_star = x.star, 
            alpha=alpha, rho=rho, sigma=sigma),
  iter=1000,
  warmup=0,
  chains=1,
  algorithm="Fixed_param"
)

Having the posterior samples, we compute their mean and a 90% credible interval (the 5% and 95% quantiles), and plot them.

posterior <- extract(posterior, "f_star")$f_star
posterior.mean      <- apply(posterior, 2, mean)
posterior.quantiles <- apply(posterior, 2, quantile, probs=c(0.05, 0.95))
posterior.frame <- data.frame(
  x=x.star,
  m=posterior.mean, 
  q=t(posterior.quantiles)) %>%
  set_names(c("x", "mean", "lower", "upper"))

ggplot() +
  geom_point(data=D, aes(x, y), size=1) +
  geom_ribbon(data = posterior.frame, 
              aes(x=x, ymin=lower, ymax=upper), 
              fill="#A1A6C8") +
  geom_line(data = posterior.frame,  aes(x, mean), color="darkblue") +
  theme_tufte() +
  theme(
    axis.text = element_text(colour = "black", size = 15),
    strip.text = element_text(colour = "black", size = 15)
  ) +
  xlab(NULL) +
  ylab(NULL)
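
As a cross-check, the posterior mean from the formulas above is also easy to compute directly in R (a sketch, again assuming the illustrative cov_exp_quad_r helper; the result should closely match the dark blue line):

# analytic posterior mean: k(x*, x) (k(x, x) + sigma^2 I)^{-1} y
K.xx   <- cov_exp_quad_r(x, alpha = alpha, rho = rho) + diag(sigma^2, n)
K.sx   <- cov_exp_quad_r(x.star, x, alpha = alpha, rho = rho)
m.star <- as.vector(K.sx %*% solve(K.xx, y))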

Posterior predictive

We can use the same formalism as above to derive the posterior predictive distribution, i.e. the distribution of function values \(f^*\) for new inputs \(\mathbf{x}^*\). This is useful when we want to make predictions.

The posterior predictive is given by:

\[\begin{align*} p(f^* \mid \mathbf{y}, \mathbf{x}, \mathbf{x}^*) = \int p(f^* \mid f, \mathbf{x}^*) \ p(f \mid \mathbf{y}, \mathbf{x}) \ \mathrm{d}f, \end{align*}\]

(where we included \(\mathbf{x}\) for clarity). In practice we do not need to evaluate this integral: since our original data \(\mathbf{y}\) and \(f^*\) are jointly normal, we can just use Gaussian conditioning again.
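
If we are interested in new noisy observations \(y^*\) rather than latent function values \(f^*\), the only change is that the observation noise is added back in. A minimal sketch, using the f_star draws stored in posterior above and the known noise standard deviation sigma:

# posterior predictive draws: y* = f* + eps with eps ~ N(0, sigma^2)
y.star.pred      <- posterior + rnorm(length(posterior), 0, sigma)
y.star.quantiles <- apply(y.star.pred, 2, quantile, probs=c(0.05, 0.95))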

Fitting hyperparameters

Usually, the kernel hyperparameters as well as the noise variance are not given, so we need to estimate them from the data. We can do that, for instance, by endowing the hyperparameters with priors, or by optimizing them using maximum marginal likelihood. In this notebook we’ll do the former. For a detailed discussion of both, see for instance Betancourt (2020).
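
For reference, the marginal likelihood that the latter approach maximizes is obtained by integrating out \(f\); for a zero-mean GP prior with Gaussian noise it is available in closed form (see Rasmussen and Williams (2006)):

\[\begin{align*} \log p(\mathbf{y} \mid \mathbf{x}, \alpha, \rho, \sigma) & = -\frac{1}{2} \mathbf{y}^T \left(K + \sigma^2 \mathbf{I}\right)^{-1} \mathbf{y} - \frac{1}{2} \log \left|K + \sigma^2 \mathbf{I}\right| - \frac{n}{2} \log 2 \pi \end{align*}\]

where \(K = k(\mathbf{x}, \mathbf{x}')\) depends on \(\alpha\) and \(\rho\).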

Since the hyperparameters are not known and need to be fit, we need to update our posterior code a bit.

posterior.model <- "_models/gp_posterior_parameters.stan"
cat(readLines(posterior.model), sep = "\n")
data {
  int<lower=1> n;
  real x[n];
  vector[n] y;
}

parameters {
  real<lower=0> alpha;
  real<lower=0> rho;
  real<lower=0> sigma;

  vector[n] f_tilde;
}

model {
  vector[n] f;
  {
    matrix[n, n] K =  cov_exp_quad(x, alpha, rho)
        + diag_matrix(rep_vector(1e-10, n));
    matrix[n, n] L_K = cholesky_decompose(K);
    f = L_K * f_tilde;
  }

  rho ~ inv_gamma(5, 5);
  alpha ~ std_normal();
  sigma ~ std_normal();
  f_tilde ~ std_normal();

  y ~ normal(f, sigma);
}
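
Note that the model does not sample \(f\) directly: it samples standard normal variables \(\tilde{f}\) (f_tilde) and scales them with the Cholesky factor of \(K\), a so-called non-centered parameterization that typically makes sampling the latent GP easier. The transformation yields the intended prior because

\[\begin{align*} \tilde{f} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \quad K = L L^T \quad \Rightarrow \quad L \tilde{f} \sim \mathcal{N}(\mathbf{0}, L L^T) = \mathcal{N}(\mathbf{0}, K). \end{align*}\]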

We then infer the posteriors of the kernel hyperparameters and the noise level, jointly with the latent GP, as before using Stan.

posterior <- stan(
  posterior.model, 
  data=list(n=n, x=x, y=y),
  chains=4,
  iter=2000
)

Let’s have a look at the summary:

posterior
Inference for Stan model: gp_posterior_parameters.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

              mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
alpha         0.86    0.01 0.49   0.10   0.51   0.80   1.15   1.99  1795    1
rho           1.22    0.01 0.64   0.49   0.80   1.06   1.45   2.90  2364    1
sigma         1.49    0.00 0.20   1.15   1.35   1.47   1.61   1.95  3490    1
f_tilde[1]   -0.15    0.01 0.69  -1.46  -0.59  -0.17   0.24   1.37  2702    1
f_tilde[2]    0.75    0.02 0.82  -0.89   0.21   0.75   1.31   2.34  2839    1
f_tilde[3]    0.57    0.02 0.92  -1.25  -0.04   0.57   1.19   2.35  3564    1
f_tilde[4]    0.37    0.02 0.96  -1.52  -0.28   0.38   1.00   2.24  4046    1
f_tilde[5]    0.25    0.02 0.98  -1.68  -0.40   0.26   0.90   2.11  4080    1
f_tilde[6]    0.14    0.01 1.00  -1.82  -0.53   0.13   0.83   2.15  4542    1
f_tilde[7]    0.10    0.02 1.01  -1.86  -0.59   0.12   0.79   2.13  4501    1
f_tilde[8]    0.13    0.01 0.99  -1.81  -0.55   0.15   0.81   2.04  4694    1
f_tilde[9]    0.05    0.01 1.02  -1.95  -0.62   0.06   0.70   2.02  4813    1
f_tilde[10]   0.12    0.01 1.02  -1.96  -0.55   0.13   0.80   2.15  5599    1
f_tilde[11]   0.12    0.01 0.98  -1.79  -0.55   0.11   0.80   2.03  4577    1
f_tilde[12]   0.06    0.01 1.01  -1.89  -0.62   0.06   0.74   2.07  4676    1
f_tilde[13]   0.08    0.01 0.99  -1.91  -0.59   0.09   0.75   2.01  5252    1
f_tilde[14]   0.02    0.01 1.01  -2.01  -0.65   0.01   0.67   2.07  5165    1
f_tilde[15]   0.09    0.01 1.00  -1.87  -0.58   0.08   0.78   1.99  4939    1
f_tilde[16]   0.06    0.01 1.01  -1.92  -0.64   0.06   0.75   2.03  4640    1
f_tilde[17]   0.01    0.01 1.00  -1.96  -0.67   0.03   0.69   2.01  4917    1
f_tilde[18]   0.11    0.01 1.00  -1.90  -0.55   0.13   0.79   2.03  5196    1
f_tilde[19]   0.06    0.01 1.01  -1.96  -0.63   0.08   0.74   2.04  4922    1
f_tilde[20]   0.09    0.01 1.00  -1.83  -0.58   0.07   0.73   2.08  4586    1
f_tilde[21]   0.06    0.01 0.98  -1.86  -0.62   0.06   0.75   1.98  5241    1
f_tilde[22]   0.04    0.01 1.01  -1.93  -0.64   0.04   0.74   2.00  4972    1
f_tilde[23]   0.01    0.01 1.00  -1.95  -0.66   0.02   0.71   1.94  4682    1
f_tilde[24]   0.03    0.02 1.02  -1.96  -0.63   0.03   0.73   2.01  4523    1
f_tilde[25]   0.01    0.01 0.98  -1.90  -0.66   0.01   0.69   1.87  4583    1
f_tilde[26]   0.01    0.01 0.98  -1.88  -0.66   0.02   0.70   1.88  5133    1
f_tilde[27]   0.01    0.02 1.06  -2.06  -0.71   0.00   0.71   2.11  4541    1
f_tilde[28]  -0.02    0.01 0.99  -2.01  -0.70  -0.02   0.66   1.88  4756    1
f_tilde[29]   0.02    0.01 0.99  -1.88  -0.67   0.02   0.70   1.94  4625    1
f_tilde[30]   0.00    0.01 0.98  -1.91  -0.67   0.00   0.67   1.91  4532    1
lp__        -49.42    0.11 4.26 -58.53 -52.15 -49.09 -46.37 -41.93  1622    1

Samples were drawn using NUTS(diag_e) at Sat Mar 20 22:03:57 2021.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

The inferences look good: high effective sample sizes as well as \(\hat{R}\) values of one. Finally, let’s plot the traces as well as the histograms of the posteriors.

bayesplot::mcmc_trace(posterior, pars=c("sigma", "rho", "alpha")) +
  theme_tufte() +
  theme(
    axis.text = element_text(colour = "black", size = 15),
    strip.text = element_text(colour = "black", size = 15)
  ) +
  scale_color_discrete_sequential(l1 = 1, l2 = 60) +
  xlab(NULL) +
  ylab(NULL) +
  guides(color=FALSE)

bayesplot::mcmc_hist(posterior, pars=c("sigma", "rho", "alpha")) +
  theme_tufte() +
  theme(
    axis.text = element_text(colour = "black", size = 15),
    strip.text = element_text(colour = "black", size = 15)
  ) +
  xlab(NULL) +
  ylab(NULL)
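
If we wanted point estimates of the hyperparameters, for instance to plug them back into the fixed-hyperparameter posterior model from above, one option is to take the posterior means (a sketch; the exact numbers will vary between runs):

# posterior means of the kernel hyperparameters and the noise level
draws <- extract(posterior, pars=c("alpha", "rho", "sigma"))
sapply(draws, mean)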

Session info

sessionInfo()
R version 4.0.2 (2020-06-22)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.1 LTS

Matrix products: default
BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-openmp/libopenblasp-r0.3.8.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=de_CH.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=de_CH.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=de_CH.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=de_CH.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] bayesplot_1.7.2      rstan_2.21.2         StanHeaders_2.21.0-5
 [4] colorspace_1.4-1     ggthemes_4.2.4       forcats_0.5.0       
 [7] stringr_1.4.0        dplyr_1.0.1          purrr_0.3.4         
[10] readr_1.3.1          tidyr_1.1.1          tibble_3.0.3        
[13] ggplot2_3.3.2        tidyverse_1.3.0     

loaded via a namespace (and not attached):
 [1] httr_1.4.2         jsonlite_1.7.0     modelr_0.1.8       RcppParallel_5.0.2
 [5] assertthat_0.2.1   stats4_4.0.2       blob_1.2.1         cellranger_1.1.0  
 [9] yaml_2.2.1         pillar_1.4.6       backports_1.1.8    glue_1.4.1        
[13] digest_0.6.25      rvest_0.3.6        htmltools_0.5.0    plyr_1.8.6        
[17] pkgconfig_2.0.3    broom_0.7.0        haven_2.3.1        scales_1.1.1      
[21] processx_3.4.3     generics_0.0.2     farver_2.0.3       ellipsis_0.3.1    
[25] withr_2.2.0        cli_2.0.2          magrittr_1.5       crayon_1.3.4      
[29] readxl_1.3.1       evaluate_0.14      ps_1.3.4           fs_1.5.0          
[33] fansi_0.4.1        xml2_1.3.2         pkgbuild_1.1.0     tools_4.0.2       
[37] loo_2.3.1          prettyunits_1.1.1  hms_0.5.3          lifecycle_0.2.0   
[41] matrixStats_0.56.0 V8_3.2.0           munsell_0.5.0      reprex_0.3.0      
[45] callr_3.4.3        compiler_4.0.2     rlang_0.4.7        grid_4.0.2        
[49] ggridges_0.5.2     rstudioapi_0.11    labeling_0.3       rmarkdown_2.6     
[53] gtable_0.3.0       codetools_0.2-16   inline_0.3.15      DBI_1.1.0         
[57] curl_4.3           reshape2_1.4.4     R6_2.4.1           gridExtra_2.3     
[61] lubridate_1.7.9    knitr_1.29         stringi_1.4.6      parallel_4.0.2    
[65] Rcpp_1.0.5         vctrs_0.3.2        dbplyr_1.4.4       tidyselect_1.1.0  
[69] xfun_0.16         

References

Rasmussen, Carl Edward, and Christopher K. I. Williams. 2006. Gaussian Processes for Machine Learning. The MIT Press. http://www.gaussianprocess.org/gpml.