Learning cause-effect mechanisms among a set of random variables is not only of great epistemological interest, but also a fascinating statistical problem. In this notebook we implement a graph variational autoencoder (DAG-GNN by Yu et al. (2019)) to learn the DAG of a structural equation model and compare it to greedy equivalence search (GES, Chickering (2002)), which is one of the state-of-the-art methods for causal discovery.

Some useful references are Peters, Janzing, and Schölkopf (2017), Drton and Maathuis (2017) or Heinze-Deml, Maathuis, and Meinshausen (2018). Feedback and comments are welcome!

Causal structure learning

The goal of causal structure learning is to find the causal ordering of \(p\) random variables \({X} = (X_1, \dots, X_p)^T\). For instance, for a pair of random variables \(X_1\) and \(X_2\), if \(X_1\) causes \(X_2\), the causal ordering is \(X_1 < X_2\). For a set of \(p\) random variables the ordering induces a directed acyclic graph \(\mathcal{G}\) over \({X}\) (in general, if no topological ordering exists, the underlying graph can be cyclic, too, but we limit ourselves to the acyclic case here). Causal structure learning faces some major difficulties though. Some of these are:

  • the number of DAGs grows super-exponentially in \(p\),
  • there may be latent confounding,
  • there may be selection bias,
  • there may be feedback.

Score-based methods, such as DAG-GNN or GES, approach the structure learning problem by optimizing a score that is defined relative to a data set. Drawbacks of score-based methods are that they often need to assume causal sufficiency (i.e., no latent confounding), that they can get stuck in local optima, and that they need to make several parametric assumptions.

To learn the causal structure we will need some definitions and assumptions first. A Bayesian network represents a joint probability distribution \({P}\) over random variables \({X}\) that factorizes over a DAG \(\mathcal{G}\) as

\[\begin{align} P(X) = \prod_{i=1}^p P(X_i \mid \text{pa}_{X_i}) \end{align}\]

The set of conditional independence relations imposed by \(\mathcal{G}\) via the above factorization is also called the Markov factorization property. It entails that if \({X}\) and \({Y}\) are d-separated by \({Z}\), then \({X}\) is statistically independent of \({Y}\) given \({Z}\). D-separation means that for each of the paths between a variable in \({X}\) and a variable in \({Y}\), there is either a chain \(\cdot \rightarrow Z_i \rightarrow \cdot\) or a fork \(\cdot \leftarrow Z_i \rightarrow \cdot\) such that \(Z_i\) is in \({Z}\), or there is a collider \(\cdot \rightarrow W \leftarrow \cdot\) such that neither \(W\) nor any of its descendants are in \({Z}\). In general, there can be multiple DAGs that encode the same conditional independence relations, which means multiple DAGs can have the same score if the score is (locally) consistent. As a consequence, one can identify the correct DAG only up to its equivalence class (the completed partially directed acyclic graph, CPDAG).

In order to learn the graph, we will assume that we have i.i.d. data and a causally sufficient system without cycles. Even with these two assumptions, however, the problem is still NP-hard.

Linear Gaussian SEMs

In the linear Gaussian case, the data generating process of the random variables \(X\) reads as

\[\begin{align} \mathbf{x} & \leftarrow \mathbf{A}^T \mathbf{x} + \mathbf{z} \end{align}\]

where \(\mathbf{A}\) is a \(p \times p\)-dimensional adjacency matrix that defines the strength of the associations between the variables, and implicitly the causal ordering of the random variables, and \(\mathbf{z}\) is a vector of independent error terms. For a data set \(\mathbf{X} = \{ \mathbf{x}_i\}_{i=1}^n\) of \(n\) realizations of \(X\), the linear structural equation model reads analogously:

\[\begin{align} \mathbf{X} & \leftarrow \mathbf{X} \mathbf{A} + \mathbf{Z} \end{align}\]

which we can rearrange to

\[\begin{align} \mathbf{X} & \leftarrow \mathbf{Z} \left(\mathbf{I} - \mathbf{A} \right)^{-1}\\ \mathbf{Z} & \leftarrow \mathbf{X} \left(\mathbf{I} - \mathbf{A} \right) \end{align}\]

where \(\mathbf{I}\) is the identity matrix of appropriate dimensions.
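
As a quick numerical sanity check of this rearrangement, we can verify for a small made-up adjacency matrix that data generated via \(\mathbf{Z} \left(\mathbf{I} - \mathbf{A} \right)^{-1}\) indeed satisfies the structural equation:

# toy example: a chain X1 -> X2 -> X3 with made-up edge weights
A_toy <- matrix(0, 3, 3)
A_toy[1, 2] <- 0.8
A_toy[2, 3] <- -0.5

Z_toy <- matrix(rnorm(5 * 3), 5, 3)
X_toy <- Z_toy %*% solve(diag(3) - A_toy)

# the simulated X indeed satisfies X = X A + Z
all.equal(X_toy, X_toy %*% A_toy + Z_toy)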

Methods

To learn the adjacency matrix \(\mathbf{A}\) we will use two different approaches:

  • a score-based method, DAG-GNN, which is based on graph variational autoencoders and has been introduced by Yu et al. (2019),
  • another score-based method, the GES algorithm, which has been introduced by Chickering (2002) and is one of the most frequently used methods for learning causal DAGs in statistics.

I’ll introduce each briefly below and then apply both to data.

DAG-GNN

Yu et al. (2019) make use of variational autoencoders for learning the causal structure, i.e., the weighted adjacency matrix \(\mathbf{A}\), and for that introduce a new network architecture based on graph neural networks:

\[\begin{align} \mathbf{X} & \leftarrow f_1 \left( f_2 \left( \mathbf{Z}\right) \left(\mathbf{I} - \mathbf{A} \right)^{-1} \right) \\ \mathbf{Z} & \leftarrow g_1 \left( g_2 \left( \mathbf{X}\right) \left(\mathbf{I} - \mathbf{A} \right) \right) \end{align}\]

where \(f_i\) and \(g_j\) are multilayer perceptrons (MLPs). Note that if \(f_1\) and \(f_2\) are invertible, this is just a rearrangement of the formulae as for the linear SEM:

\[\begin{align} \mathbf{X} & \leftarrow f_1 \left( f_2 \left( \mathbf{Z}\right) \left(\mathbf{I} - \mathbf{A} \right)^{-1} \right) \\ f_1^{-1}(\mathbf{X} ) & \leftarrow f_2 \left( \mathbf{Z}\right) \left(\mathbf{I} - \mathbf{A} \right)^{-1} \\ f_2 \left( \mathbf{Z}\right) & \leftarrow f_1^{-1}(\mathbf{X} ) \left(\mathbf{I} - \mathbf{A} \right) \\ \mathbf{Z} & \leftarrow f_2^{-1} \left( f_1^{-1}(\mathbf{X} ) \left(\mathbf{I} - \mathbf{A} \right) \right) \end{align}\]

For their model Yu et al. (2019) set both \(f_2\) and \(g_1\) to be identity mappings and \(f_1\) and \(g_2\) to be MLPs. Hence, the encoder of the VAE which computes the variational posterior \(q(Z \mid X)\) has the following form:

\[\begin{align} q(Z \mid X) &= \mathcal{MN}(\mu_Z, \sigma^2_Z) \\ \left[\mu_Z, \sigma^2_Z\right] & = g_2 \left( \mathbf{X}\right) \left(\mathbf{I} - \mathbf{A} \right)\\ &= \text{MLP}\left(\mathbf{X}, \mathbf{W}_1, \mathbf{W}_2 \right) \left(\mathbf{I} - \mathbf{A} \right) \end{align}\]

where \(\mathcal{MN}\) is the matrix normal distribution, \(\text{MLP}\left(\mathbf{X}, \mathbf{W}_1, \mathbf{W}_2 \right) = \text{ReLU}\left(\mathbf{X}\mathbf{W}_1 \right)\mathbf{W}_2\) and \(\mathbf{W}_i\) are parameter matrices of appropriate sizes. The decoder, which computes the parameters of the likelihood \(P(X \mid Z)\), looks like this:

\[\begin{align} p(X \mid Z) &= \mathcal{MN}(\mu_X, \sigma^2_X) \\ \left[\mu_X, \sigma^2_X\right] & = f_1 \left( \mathbf{Z} \left(\mathbf{I} - \mathbf{A} \right)^{-1} \right) \\ & = \text{MLP}\left(\mathbf{Z}\left(\mathbf{I} - \mathbf{A} \right)^{-1}, \mathbf{W}_3, \mathbf{W}_4 \right) \end{align}\]

The model is straightforward to implement in R. For convenience, we specify a helper function to construct a dense layer first.
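
A minimal version of such a helper could look like this (the name dense_layer is arbitrary, and I drop the bias term to stay close to the MLP definition above):

library(tensorflow)
library(keras)
library(tfprobability)

# a standalone dense layer with a given width and optional activation;
# no bias term, to match the MLP definition ReLU(X W1) W2 above
dense_layer <- function(units, activation = NULL) {
  layer_dense(units = units, activation = activation, use_bias = FALSE)
}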

The model itself uses keras_model_custom:
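
Below is a rough sketch of how the model could be set up; the constructor name dag_gnn, the hidden width n_hidden and the exact parameterization of the scales are choices of mine rather than the definitive implementation. The call function returns a list with the variational posterior as first and the decoder distribution as second element.

dag_gnn <- function(p, n_hidden = 16L, name = NULL) {
  p <- as.integer(p)
  keras_model_custom(name = name, function(self) {
    # the weighted adjacency matrix A is a free, trainable parameter
    self$A <- tf$Variable(tf$zeros(shape(p, p)), name = "adjacency")

    # encoder MLP: ReLU(X W1) W2, producing location and log-scale of Z
    self$enc_hidden <- dense_layer(n_hidden, activation = "relu")
    self$enc_loc    <- dense_layer(p)
    self$enc_scale  <- dense_layer(p)

    # decoder MLP, applied to Z (I - A)^{-1}
    self$dec_hidden <- dense_layer(n_hidden, activation = "relu")
    self$dec_loc    <- dense_layer(p)
    self$dec_scale  <- dense_layer(p)

    # compute q(Z | X), sample from it, and return both q(Z | X) and the
    # decoder distribution p(X | Z); x is expected to be a float32 matrix/tensor
    function(x, mask = NULL) {
      ima <- tf$eye(p) - self$A

      h_enc <- self$enc_hidden(x)
      posterior <- tfd_independent(
        tfd_normal(
          loc   = tf$matmul(self$enc_loc(h_enc), ima),
          scale = tf$exp(tf$matmul(self$enc_scale(h_enc), ima))
        ),
        reinterpreted_batch_ndims = 1L
      )

      z     <- tfd_sample(posterior)
      h_dec <- self$dec_hidden(tf$matmul(z, tf$linalg$inv(ima)))
      likelihood <- tfd_independent(
        tfd_normal(
          loc   = self$dec_loc(h_dec),
          scale = tf$exp(self$dec_scale(h_dec))
        ),
        reinterpreted_batch_ndims = 1L
      )

      list(posterior, likelihood)
    }
  })
}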

The last part of the model is equivalent to Python’s __call__ member function, i.e., if we construct a model m and call m(X) on a data set \(\mathbf{X}\), it computes the variational posterior, samples from it and then computes the distribution of the decoder.
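
For illustration, we can construct a model for five variables and call it on some made-up standard-normal data; printing the second element of the result shows the decoder distribution:

m <- dag_gnn(p = 5L)
out <- m(tf$constant(matrix(rnorm(100 * 5), 100, 5), dtype = tf$float32))
out[[2]]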

tfp.distributions.Independent("IndependentNormal/", batch_shape=[100], event_shape=[5], dtype=float32)

The weights of the network are trained by maximizing the evidence lower bound (ELBO)

\[\begin{align} \text{ELBO}\left(\mathbf{X}\right) = \mathbb{E}_{q\left(Z \mid X\right)} \left[ \log p(X \mid Z) \right] - \text{D}_{\text{KL}} \left( q\left(Z \mid X\right) || p\left(Z\right) \right) \end{align}\]

where we take \(p(Z) = \mathcal{MN}(\mathbf{0}, \mathbf{I})\), the standard matrix normal, as prior. In practice we approximate the expected value of the likelihood using a Monte Carlo sample. The ELBO above, however, does not ensure acyclicity of \(\mathbf{A}\), hence the authors add an acyclicity constraint to the ELBO:

\[\begin{align} h(\mathbf{A}) = \text{tr}\left[ \left( \mathbf{I} + \alpha \mathbf{A} \circ \mathbf{A} \right)^p \right] - p = 0 \end{align}\]

for some \(\alpha > 0\). Adding this constraint to the ELBO using an augmented Lagrangian method gives us the objective function for the GVAE. The augmented Lagrangian consists of a Lagrange multiplier term \(\lambda h(\mathbf{A})\) and a quadratic penalty \(\frac{c}{2} |h(\mathbf{A})|^2\). The objective we want to optimize is then:

\[\begin{align} \min \; - \text{ELBO}(\mathbf{X}) + \lambda h(\mathbf{A}) + \frac{c}{2} |h(\mathbf{A})|^2 \end{align}\]

The penalty can be coded like this:
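
What follows is a direct transcription of \(h(\mathbf{A})\) into a small function (the name acyclicity_penalty is arbitrary, and the matrix power is computed with a plain loop of matrix products):

# h(A) = tr[(I + alpha * A * A)^p] - p, which is zero exactly when A is acyclic
acyclicity_penalty <- function(A, p, alpha = 1) {
  A <- tf$cast(A, tf$float32)                    # accepts R matrices and tensors alike
  m <- tf$eye(as.integer(p)) + alpha * (A * A)   # I + alpha * (A * A), elementwise square
  m_pow <- m
  for (i in seq_len(p - 1))
    m_pow <- tf$matmul(m_pow, m)                 # (I + alpha * A * A)^p
  tf$linalg$trace(m_pow) - p
}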

The R code for the entire penalized ELBO looks like this:
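
Again only a sketch: the loss below combines the Monte Carlo ELBO with the augmented Lagrangian terms, using the model and the penalty function from above and a standard-normal prior on \(Z\).

# negative ELBO plus the augmented-Lagrangian terms for h(A);
# `model` is the custom model sketched above, `x` a float32 tensor
penalized_elbo <- function(model, x, p, lambda, c, alpha = 1) {
  out        <- model(x)
  posterior  <- out[[1]]
  likelihood <- out[[2]]

  # standard (matrix) normal prior on Z
  prior <- tfd_independent(
    tfd_normal(loc = tf$zeros_like(x), scale = 1),
    reinterpreted_batch_ndims = 1L
  )

  log_lik <- tf$reduce_mean(tfd_log_prob(likelihood, x))
  kl      <- tf$reduce_mean(tfd_kl_divergence(posterior, prior))
  elbo    <- log_lik - kl

  h <- acyclicity_penalty(model$A, p = p, alpha = alpha)
  -elbo + lambda * h + 0.5 * c * tf$square(h)
}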

That is pretty much it. We now only need to define a method to train the parameters, which, with TensorFlow 2, can be done like this:
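
The routine below is a bare-bones version that keeps \(\lambda\) and \(c\) fixed (a full augmented Lagrangian scheme would update them over outer iterations); the function name and all default values are placeholders.

train_dag_gnn <- function(model, x, epochs = 100, lambda = 1, c = 1,
                          alpha = 1, learning_rate = 0.01) {
  p <- ncol(x)
  x <- tf$cast(x, tf$float32)
  optimizer <- optimizer_adam(lr = learning_rate)

  for (epoch in seq_len(epochs)) {
    # record the forward pass and compute the penalized negative ELBO
    with(tf$GradientTape() %as% tape, {
      loss <- penalized_elbo(model, x, p = p, lambda = lambda, c = c, alpha = alpha)
    })
    gradients <- tape$gradient(loss, model$trainable_variables)
    optimizer$apply_gradients(
      purrr::transpose(list(gradients, model$trainable_variables))
    )
  }

  # return the learned weighted adjacency matrix as an R matrix
  model$A$numpy()
}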

In my opinion the paper would have benefitted from a couple of additions. Firstly, even though concepts like consistency, score equivalence or faithfulness are mentioned in the introduction, the paper does not really discuss them w.r.t. their method. Secondly, while popular methods like the PC algorithm or GES are cited, the authors don't include them in their benchmark. This is especially unfortunate, because both are (pointwise/locally) consistent and should serve as great baselines for large sample sizes. Thirdly, it would probably have made sense to discuss identifiability and which assumptions have to be made to achieve it. For instance, in the linear Gaussian case, returning the CPDAG instead of any DAG in the equivalence class is reasonable, since the true DAG cannot be identified. On the other hand, under specific conditions, such as non-linear or non-Gaussian settings, the true DAG can be identified and returning a DAG makes sense. Finally, since the relationship between the data \(\mathbf{X}\) and \(\mathbf{A}\) is nonlinear (and nonparametric) in DAG-GNN, interpretation of the coefficients of the weighted adjacency matrix is difficult.

Use case

We simulate data from a SEM with \(p=7\) variables. We can use pcalg to first create a DAG and then sample data from it.
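
For instance with randomDAG; the seed, the edge probability and the weight bounds below are arbitrary choices.

library(pcalg)

set.seed(23)

p <- 7

# random DAG over p nodes: every possible edge is included with probability 0.5,
# edge weights are drawn uniformly from [0.1, 1]
dag <- randomDAG(p, prob = 0.5, lB = 0.1, uB = 1, V = paste0("X", seq_len(p)))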

The dag object is a graphNEL, so we convert it to a matrix first.
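
The coercion to the weighted adjacency matrix is a one-liner:

# weighted adjacency matrix of the true DAG (rows are parents, columns children)
A <- as(dag, "matrix")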

Since the methods above assume acyclicity, we check for it using igraph.
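
A small helper does the check (the function name is my own):

library(igraph)

# TRUE if the (weighted) adjacency matrix encodes a directed acyclic graph
is_acyclic <- function(adj) {
  is_dag(graph_from_adjacency_matrix(1 * (adj != 0), mode = "directed"))
}

is_acyclic(A)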

[1] TRUE

We can also check this via the acyclicity constraint, which should be zero if \(\mathbf{A}\) has no cycles.
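
Reusing the acyclicity_penalty function sketched above:

acyclicity_penalty(A, p = ncol(A))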

tf.Tensor(0.0, shape=(), dtype=float32)

As a helper we also define a method to plot a matrix as graph. I personally like ggraph for plotting graphs the best.
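
A sketch of such a helper; the layout and styling are a matter of taste.

library(ggraph)

# plot a (weighted) adjacency matrix as a directed graph
plot_graph <- function(adj) {
  g <- graph_from_adjacency_matrix(1 * (adj != 0), mode = "directed")
  ggraph(g, layout = "sugiyama") +
    geom_edge_link(
      arrow = grid::arrow(length = grid::unit(3, "mm")),
      end_cap = circle(3, "mm")
    ) +
    geom_node_label(aes(label = name)) +
    theme_graph()
}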

The graph below is the causal structure that we want to learn, i.e., an edge \(X_i \rightarrow X_j\) encodes a cause-effect relationship.
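
Plotting the true adjacency matrix with the helper:

plot_graph(A)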

We simulate data using the SEM equation above.
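
A sketch of this step (the sample size below is an arbitrary choice):

n <- 1000

# independent standard-normal noise, then X = Z (I - A)^{-1}
Z <- matrix(rnorm(n * p), n, p)
X <- Z %*% solve(diag(p) - A)
colnames(X) <- colnames(A)

head(X)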

              X1          X2         X3          X4          X5         X6
[1,]  0.01106827  0.04690320  0.1728460  0.08985497  0.06879523  0.1197701
[2,] -0.05558411 -0.02112469 -0.0027891 -0.02377340  0.00699006 -0.1163466
[3,]  0.17869131  0.01870511  0.2489155  0.14634076  0.31859957  0.6472214
[4,]  0.04978505  0.02275427 -0.1624462  0.01243783 -0.21683510 -0.3356445
[5,] -0.19666172 -0.12619005 -0.1629389  0.14537408 -0.24879381 -0.5624779
[6,]  0.07013559  0.02855896 -0.1076570  0.03502270 -0.08467841  0.1235339
              X7
[1,]  0.05077993
[2,]  0.04613310
[3,]  0.11794496
[4,] -0.06793963
[5,] -0.03186116
[6,] -0.06580193

To use GES, we first need to define a score object and provide the data as argument. In order to get the original GES algorithm by Chickering, we need to pass the phases (forward and backward) explicitly, because the current implementation in the pcalg package uses an additional third phase (turning of edges).
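
With pcalg this looks roughly as follows; setting iterate = FALSE keeps GES from alternating through the phases repeatedly, which, as far as I can tell, matches the original forward-backward scheme.

# BIC-type score for observational Gaussian data
score <- new("GaussL0penObsScore", X)

# forward and backward phases only, run once each, as in Chickering (2002)
ges_fit <- ges(score, phase = c("forward", "backward"), iterate = FALSE)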

From the fitted object we extract the graph first, turn it into a CPDAG and plot it. GES already computes a CPDAG, but it's just good practice to call the function regardless of the method we are using, since many methods merely return an element of the equivalence class and not the equivalence class itself. If one isn't aware of this and expects a CPDAG, it is easy to make incorrect causal claims (recall that we generally cannot distinguish the direction of edges in the linear Gaussian case unless a triple is a v-structure or would become one if we reversed an edge).
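
One way to do this is via the DAG representative that ges() returns alongside the equivalence class; the variable names below are my own.

# weight matrix of a DAG representative of the estimated equivalence class
# (depending on the pcalg version, this matrix may need to be transposed)
W_ges <- ges_fit$repr$weight.mat()
dimnames(W_ges) <- list(colnames(X), colnames(X))

# representative DAG -> graphNEL -> CPDAG, then plot its adjacency matrix
ges_dag   <- as_graphnel(graph_from_adjacency_matrix(1 * (W_ges != 0), mode = "directed"))
ges_cpdag <- dag2cpdag(ges_dag)

plot_graph(as(ges_cpdag, "matrix"))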

That looks good. GES managed to correctly infer all the edges of the true DAG (including the v-structures). This was expected, since the sample size is fairly high. Next we fit the autoencoder. DAG-GNN has several tuning parameters. When testing the method, high values for the augmented Lagrangian parameters \(\lambda\) and \(c\) yielded the best results. The choice of \(\alpha\) did not significantly change the output of the method.
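
Using the helpers sketched above, fitting could look roughly like this; the number of epochs, the values of \(\lambda\) and \(c\), and the threshold for pruning small weights are placeholders rather than tuned values.

model <- dag_gnn(p = p)
A_hat <- train_dag_gnn(model, X, epochs = 300, lambda = 10, c = 10)

# prune negligibly small weights so that the result can be treated as a graph
A_hat[abs(A_hat) < 0.3] <- 0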

Let’s check if this is really a DAG first.
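
Reusing the helper from above:

is_acyclic(A_hat)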

[1] TRUE

As before, we compute the CPDAG from the DAG and plot it. The result of the training is a weighted adjacency matrix, which we first convert to a graphNEL, then to the CPDAG, and then back to an adjacency matrix.
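
Mirroring the steps used for GES above:

dimnames(A_hat) <- dimnames(A)

# learned DAG -> graphNEL -> CPDAG, then plot its adjacency matrix
dag_gnn_dag   <- as_graphnel(graph_from_adjacency_matrix(1 * (A_hat != 0), mode = "directed"))
dag_gnn_cpdag <- dag2cpdag(dag_gnn_dag)

plot_graph(as(dag_gnn_cpdag, "matrix"))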

DAG-GNN, in this example, unfortunately wasn't able to recover the true DAG from the data and was easily outperformed by GES. I am not sure if this is due to a bug in the implementation, the data, the parameterization, or something else (feedback on bugs is welcome).

Regardless, it impressively shows that classics with theoretical/asymptotic guarantees, like GES, still have their place in modern machine learning and should be considered in applications and benchmarks.

Session info

R version 4.0.0 (2020-04-24)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 19.10

Matrix products: default
BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/libopenblasp-r0.3.7.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=de_CH.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=de_CH.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=de_CH.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=de_CH.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] pcalg_2.6-10           forcats_0.5.0          stringr_1.4.0         
 [4] dplyr_0.8.5            purrr_0.3.4            readr_1.3.1           
 [7] tidyr_1.0.2            tibble_3.0.1           tidyverse_1.3.0       
[10] ggraph_2.0.2           ggplot2_3.3.0          igraph_1.2.5          
[13] keras_2.2.5.0          tfprobability_0.10.0.0 tensorflow_2.2.0      

loaded via a namespace (and not attached):
 [1] nlme_3.1-147        fs_1.4.1            lubridate_1.7.8    
 [4] httr_1.4.1          tools_4.0.0         backports_1.1.6    
 [7] R6_2.4.1            DBI_1.1.0           BiocGenerics_0.34.0
[10] colorspace_1.4-1    withr_2.2.0         tidyselect_1.0.0   
[13] gridExtra_2.3       compiler_4.0.0      graph_1.66.0       
[16] cli_2.0.2           rvest_0.3.5         xml2_1.3.2         
[19] labeling_0.3        sfsmisc_1.1-7       scales_1.1.0       
[22] DEoptimR_1.0-8      robustbase_0.93-6   RBGL_1.64.0        
[25] rappdirs_0.3.1      tfruns_1.4          digest_0.6.25      
[28] rmarkdown_2.1       base64enc_0.1-3     pkgconfig_2.0.3    
[31] htmltools_0.4.0     dbplyr_1.4.3        rlang_0.4.5        
[34] readxl_1.3.1        rstudioapi_0.11     farver_2.0.3       
[37] generics_0.0.2      jsonlite_1.6.1      magrittr_1.5       
[40] Matrix_1.2-18       Rcpp_1.0.4.6        munsell_0.5.0      
[43] fansi_0.4.1         abind_1.4-5         reticulate_1.15    
[46] viridis_0.5.1       lifecycle_0.2.0     stringi_1.4.6      
[49] whisker_0.4         yaml_2.2.1          MASS_7.3-51.5      
[52] grid_4.0.0          parallel_4.0.0      ggrepel_0.8.2      
[55] bdsmatrix_1.3-4     crayon_1.3.4        lattice_0.20-41    
[58] graphlayouts_0.7.0  haven_2.2.0         hms_0.5.3          
[61] zeallot_0.1.0       knitr_1.28          pillar_1.4.3       
[64] corpcor_1.6.9       stats4_4.0.0        reprex_0.3.0       
[67] glue_1.4.0          evaluate_0.14       modelr_0.1.6       
[70] vctrs_0.2.4         tweenr_1.0.1        cellranger_1.1.0   
[73] gtable_0.3.0        polyclip_1.10-0     clue_0.3-57        
[76] assertthat_0.2.1    xfun_0.13           ggforce_0.3.1      
[79] broom_0.5.6         tidygraph_1.2.0     viridisLite_0.3.0  
[82] ggm_2.5             cluster_2.1.0       fastICA_1.2-2      
[85] ellipsis_0.3.0     

References

Chickering, David Maxwell. 2002. “Optimal Structure Identification with Greedy Search.” Journal of Machine Learning Research 3: 507–54. http://www.jmlr.org/papers/v3/chickering02b.html.

Drton, Mathias, and Marloes H Maathuis. 2017. “Structure Learning in Graphical Modeling.” Annual Review of Statistics and Its Application 4. Annual Reviews: 365–93. https://doi.org/10.1146/annurev-statistics-060116-053803.

Heinze-Deml, Christina, Marloes H Maathuis, and Nicolai Meinshausen. 2018. “Causal Structure Learning.” Annual Review of Statistics and Its Application 5. Annual Reviews: 371–91. https://doi.org/10.1146/annurev-statistics-031017-100630.

Peters, Jonas, Dominik Janzing, and Bernhard Schölkopf. 2017. Elements of Causal Inference. The MIT Press. http://library.oapen.org/handle/20.500.12657/26040.

Yu, Yue, Jie Chen, Tian Gao, and Mo Yu. 2019. “DAG-GNN: DAG Structure Learning with Graph Neural Networks.” ICML. https://arxiv.org/abs/1904.10098.