Learning Linear Causal Representations from Interventions under General Nonlinear Mixing

Authors: Simon Buchholz, Goutham Rajendran, Elan Rosenfeld, Bryon Aragam, Bernhard Schölkopf, Pradeep Ravikumar

NeurIPS 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this section, we explain our experimental methodology and the theoretical underpinning of our approach. Our main experiments for interventional causal representation learning focus on a method based on contrastive learning. We train a deep neural network to distinguish observational samples (drawn from X^(0)) from interventional samples (drawn from X^(i)). Additionally, we design the last layer of the model to parametrize the log-likelihood of a linear Gaussian SCM. Accordingly, with careful design of the last-layer parametric form, we indirectly learn the parameters of the underlying causal model. (See the contrastive-learning sketch after the table.)
Researcher Affiliation | Academia | (1) Max Planck Institute for Intelligent Systems, Tübingen, Germany; (2) Carnegie Mellon University, Pittsburgh, USA; (3) University of Chicago, Chicago, USA
Pseudocode | No | The paper describes the proposed algorithm (e.g., the contrastive algorithm) but does not provide it in a formal pseudocode block or algorithm environment.
Open Source Code | No | The paper does not contain an explicit statement about the release of source code or a direct link to a code repository for the methodology described.
Open Datasets | Yes | For all our experiments we use Erdös-Rényi graphs, i.e., we add each undirected edge with equal probability p to the graph and then orient the edges according to a random order of the nodes. We write ER(d, k) for the Erdös-Rényi graph distribution on d nodes with kd expected edges. For a given graph G we then sample edge weights from U([0.25, 1.0]) and a scale matrix D. For simplicity we assume that we have n samples from each environment i ∈ I. We only consider the setting where each node is intervened upon once, and thus the latent dimension is also known. We consider three types of mixing functions. First, we consider linear mixing functions where we sample all matrix entries i.i.d. from a Gaussian distribution. Then, we consider non-linear mixing functions that are parametrized by randomly initialized MLPs with three hidden layers and Leaky ReLU activations. Finally, we consider image data as described in [2]: pairs of latent variables (z_{2i+1}, z_{2i+2}) describe the coordinates of balls in an image and the non-linearity f is the rendering of the image. The image generation is based on pygame [77]. A sample image can be found in Figure 3. (See the data-generation sketch after the table.)
Dataset Splits | Yes | We use an 80/20 split for the train and validation/test sets respectively.
Hardware Specification | No | The experiments using synthetic data were run on 2 CPUs with 16 GB of RAM on a compute cluster. For image data we added a GPU to increase speed. The description mentions CPUs, RAM, and a GPU but lacks specific model numbers, types, or detailed specifications (e.g., 'NVIDIA A100', 'Intel Xeon').
Software Dependencies | No | We use PyTorch [60] for all our experiments. To generate the latent DAG and sample the latent variables we use the sempler package [21]. The image generation is based on pygame [77]. We use Adam [38] as the optimizer. The paper names software components but does not provide specific version numbers for them.
Experiment Setup | Yes | For training we use the hyperparameters outlined in Table 13. We use an 80/20 split for the train and validation/test sets respectively. After subsampling each dataset to the same size, we copy the observational samples for each interventional dataset in order to have an equal number of observational and interventional samples, so they can be naturally paired during the contrastive learning. We select the model with the smallest validation loss on the validation set, where we only use the cross-entropy loss for validation. For the VAE baseline we use the standard VAE validation loss for model selection. Table 13 lists specific values for τ1 (10^-5), τ2 (10^-4), the learning rate (5 × 10^-4), the batch size (512), and the number of epochs (200/100). (See the training-configuration sketch after the table.)
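The contrastive method quoted in the "Research Type" row can be made concrete with a short sketch. The PyTorch snippet below is a hypothetical illustration, not the authors' code: the class name ContrastiveSCM, the encoder architecture, and the choice of a perfect intervention that replaces the intervened node's conditional with a zero-mean Gaussian are all our assumptions. It only shows how a last layer parametrized by SCM quantities (a weighted adjacency B and per-node noise scales) can produce a classification logit as the log-likelihood ratio between the interventional and observational linear Gaussian models.

```python
# Minimal sketch (our assumptions, not the authors' code): an encoder maps
# observations to latents, and the classification logit is the log-likelihood
# ratio between the intervened and the observational linear Gaussian SCM.
import math
import torch
import torch.nn as nn

class ContrastiveSCM(nn.Module):
    def __init__(self, x_dim: int, d: int, hidden: int = 128):
        super().__init__()
        # Encoder g: observation space -> d latent dimensions (illustrative MLP).
        self.encoder = nn.Sequential(
            nn.Linear(x_dim, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, d),
        )
        # "Last layer" parameters of a linear Gaussian SCM: weighted adjacency B,
        # observational log noise scales, and post-intervention log noise scales.
        self.B = nn.Parameter(torch.zeros(d, d))
        self.log_scale = nn.Parameter(torch.zeros(d))
        self.log_scale_int = nn.Parameter(torch.zeros(d))

    def node_log_density(self, z: torch.Tensor) -> torch.Tensor:
        # log N(z_j ; (B z)_j, exp(log_scale_j)^2) for every node j, per sample.
        mean = z @ self.B.T
        var = torch.exp(2 * self.log_scale)
        return -0.5 * (z - mean) ** 2 / var - self.log_scale - 0.5 * math.log(2 * math.pi)

    def logits(self, x: torch.Tensor, env: int) -> torch.Tensor:
        # Logit = log p_env(z) - log p_obs(z). For a perfect intervention on node
        # `env`, only that node's conditional changes, so the ratio has one term.
        z = self.encoder(x)
        log_p_obs = self.node_log_density(z)                       # (batch, d)
        var_int = torch.exp(2 * self.log_scale_int[env])
        log_p_int = (-0.5 * z[:, env] ** 2 / var_int
                     - self.log_scale_int[env] - 0.5 * math.log(2 * math.pi))
        return log_p_int - log_p_obs[:, env]


# Usage sketch with placeholder data for one interventional environment i.
model = ContrastiveSCM(x_dim=20, d=5)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
bce = nn.BCEWithLogitsLoss()
x_obs, x_int = torch.randn(512, 20), torch.randn(512, 20)
i = 2
logits = torch.cat([model.logits(x_obs, i), model.logits(x_int, i)])
labels = torch.cat([torch.zeros(512), torch.ones(512)])  # 0 = observational, 1 = interventional
loss = bce(logits, labels)
loss.backward()
optimizer.step()
```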
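Similarly, the synthetic-data pipeline quoted in the "Open Datasets" row can be mirrored by a standalone sketch. The paper relies on the sempler package for the latent DAG and the latent samples; the helper functions below (sample_er_dag, sample_latents, random_mlp_mixing) are hypothetical stand-ins that only reproduce the stated choices: ER(d, k) graphs with kd expected edges, edge weights from U([0.25, 1.0]), and a randomly initialized three-hidden-layer Leaky ReLU MLP as the nonlinear mixing. The distribution of the scale matrix D is not specified in the excerpt, so the values used for it below are placeholders.

```python
# Illustrative re-creation of the described synthetic-data pipeline (not the
# authors' code; the paper itself uses the sempler package for the SCM part).
import numpy as np
import torch
import torch.nn as nn

def sample_er_dag(d: int, k: float, rng: np.random.Generator):
    # ER(d, k): include each undirected edge with probability p chosen so that
    # the expected number of edges is k*d, then orient along a random node order.
    p = min(1.0, 2.0 * k / (d - 1))
    order = rng.permutation(d)
    A = np.zeros((d, d))
    for i in range(d):
        for j in range(i + 1, d):
            if rng.random() < p:
                A[order[i], order[j]] = 1.0    # edge order[i] -> order[j]
    return A, order                             # adjacency and a topological order

def sample_latents(W: np.ndarray, scales: np.ndarray, order, n: int, rng):
    # Ancestral sampling of the linear Gaussian SCM: z_j = sum_i W[i, j] z_i + scale_j * eps_j.
    d = W.shape[0]
    z = np.zeros((n, d))
    for j in order:                             # parents are filled before children
        z[:, j] = z @ W[:, j] + scales[j] * rng.standard_normal(n)
    return z

def random_mlp_mixing(d: int, x_dim: int, hidden: int = 64) -> nn.Module:
    # Randomly initialized MLP with three hidden layers and Leaky ReLU activations.
    dims = [d, hidden, hidden, hidden, x_dim]
    layers = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers.append(nn.LeakyReLU())
    return nn.Sequential(*layers)

rng = np.random.default_rng(0)
d, k, n = 5, 1.0, 10_000
A, order = sample_er_dag(d, k, rng)
W = A * rng.uniform(0.25, 1.0, size=(d, d))     # edge weights from U([0.25, 1.0])
scales = rng.uniform(0.5, 1.5, size=d)          # placeholder for the scale matrix D
z = sample_latents(W, scales, order, n, rng)
f = random_mlp_mixing(d, x_dim=20)              # nonlinear mixing x = f(z)
with torch.no_grad():
    x = f(torch.as_tensor(z, dtype=torch.float32))
```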
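Finally, the reported hyperparameters and model-selection rule from the "Experiment Setup" row can be collected into a small configuration sketch. Only the numbers quoted above come from the paper; the dictionary keys, the select_best helper, and the placeholder model are our own illustration, and the excerpt does not explain what τ1 and τ2 control, so they are carried through as opaque values.

```python
# Reported training setup collected into a config dict (key names are ours);
# tau1/tau2 are kept as opaque hyperparameters since their role is not
# described in the excerpt above.
import torch

config = {
    "tau1": 1e-5,
    "tau2": 1e-4,
    "lr": 5e-4,
    "batch_size": 512,
    "epochs": 200,          # Table 13 reports 200 or 100 depending on the setting
    "train_frac": 0.8,      # 80/20 split into train and validation/test
}

model = torch.nn.Linear(10, 1)  # placeholder standing in for the contrastive network
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

def select_best(checkpoints):
    # Model selection: keep the checkpoint with the smallest validation
    # cross-entropy loss (the VAE baseline instead uses the VAE validation loss).
    return min(checkpoints, key=lambda pair: pair[1])[0]
```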