Causal Contrastive Learning for Counterfactual Regression Over Time
Authors: Mouad El Bouchattaoui, Myriam Tami, Benoit Lepetit, Paul-Henry Cournède
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our method achieves state-of-the-art counterfactual estimation results using both synthetic and real-world data, marking the pioneering incorporation of Contrastive Predictive Coding in causal inference. and, from Section 6 (Experiments): We compare Causal CPC with SOTA baselines: MSMs [64], RMSN [41], CRN [7], G-Net [39], and CT [48]. |
| Researcher Affiliation | Collaboration | Mouad El Bouchattaoui (1,2), Myriam Tami (1), Benoit Lepetit (2), and Paul-Henry Cournède (1); (1) Paris-Saclay University, CentraleSupélec, MICS Lab, Gif-sur-Yvette, France; (2) Saint-Gobain, Paris, France |
| Pseudocode | Yes | Appendix H (Causal CPC pseudo-algorithm): In this section, we present a detailed overview of the training procedure for Causal CPC. Initially, we train the encoder using only the contrastive terms, as outlined in Algorithm 1. Our primary objective is to ensure that, for each time step t, the process history H_t is predictive of future local features Z_t. However, computing the InfoNCE loss for a batch across all possible time steps t = 0, ..., t_max can be computationally demanding. To address this, we adopt a more efficient approach by uniformly sampling a single time step t per batch; the corresponding process history H_t is then contrasted. The sampled H_t is also used as input for the InfoMax objective and randomly partitioned into future H_t^f and past H_t^h sub-processes. The decoder is trained on top of the encoder (Algorithm 2), with the encoder using a lower learning rate than the untrained part of the decoder; it is trained autoregressively and without teacher forcing. (See the contrastive-step sketch after the table.) |
| Open Source Code | Yes | 5. Open access to data and code Question: Does the paper provide open access to the data and code, with sufficient instructions to faithfully reproduce the main experimental results, as described in supplemental material? Answer: [Yes] Justification: Code is provided in the supplementary material at submission. |
| Open Datasets | Yes | Section 6.1 (Experiments with Synthetic Data, Tumor Growth): We use the Pharmacokinetic-Pharmacodynamic (PK-PD) model [24] to simulate responses of non-small cell lung cancer patients, following previous works [41, 7, 48]. and Section 6.2 (Experiments with semi-synthetic and real data, Semi-synthetic MIMIC-III): We used a semi-synthetic dataset constructed by [48] based on the MIMIC-III dataset [35], incorporating both endogenous temporal dependencies and exogenous dependencies from observational patient trajectories, as detailed in Appendix F.1. |
| Dataset Splits | Yes | Model selection is based on mean squared error (MSE) on factual outcomes from a validation set, and the same criterion is used for early stopping. and Similar to the cancer simulation, the training data consisted of relatively few sequences (500 for training, 100 for validation, and 400 for testing). |
| Hardware Specification | Yes | Table 2: Model complexity and running time averaged over five seeds. Results are reported for the tumor growth simulation (γ = 1). Hardware: 1x NVIDIA Tesla M60 GPU. and Table 12: The number of parameters to train for each model after hyper-parameter fine-tuning and the corresponding running time averaged over five seeds. Results are reported for the semi-synthetic MIMIC-III data; the processing unit is one NVIDIA Tesla M60 GPU. |
| Software Dependencies | No | All models were implemented using PyTorch [53] and PyTorch Lightning [21]. |
| Experiment Setup | Yes | All models were implemented using PyTorch [53] and PyTorch Lightning [21]. In contrast to the approach in [48], we employed early stopping for all models. The stopping criterion was defined as the mean squared error over factual outcomes on a dedicated validation dataset. Specifically, for the Causal CPC encoder, the stopping criterion was determined by the validation loss of the encoder. While all models in the benchmark were trained using the Adam optimizer [36], we opted for training Causal CPC (encoder plus decoder without the treatment subnetwork) with AdamW [46] due to its observed stability during training (see the optimizer sketch below). and Appendix J (Model hyperparameters): In this section, we report the range of all hyperparameters to be fine-tuned, as well as fixed hyperparameters for all models and across the different datasets used in experiments. Best hyperparameter values are reported in the configuration files in the code repository. |
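
To make the sampled-time-step InfoNCE procedure quoted in the Pseudocode row concrete, here is a minimal PyTorch sketch. It is not the authors' implementation: the encoder, the step-specific prediction heads, and all dimensions are illustrative assumptions, and the future targets are taken from the same recurrent output purely for brevity (in the paper, Z_t denotes local features).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveStep(nn.Module):
    """Toy encoder plus step-specific heads for the sampled-t InfoNCE loss."""
    def __init__(self, d_in=8, d=32, k=3):
        super().__init__()
        self.k = k
        self.rnn = nn.GRU(d_in, d, batch_first=True)   # summarizes the process history
        self.heads = nn.ModuleList(nn.Linear(d, d) for _ in range(k))

    def forward(self, x):
        batch, T, _ = x.shape
        h, _ = self.rnn(x)                             # (batch, T, d): histories H_0..H_{T-1}
        t = torch.randint(0, T - self.k, (1,)).item()  # one uniformly sampled t per batch
        context = h[:, t]                              # process history H_t
        loss = x.new_zeros(())
        for step, head in enumerate(self.heads, start=1):
            z_future = h[:, t + step]                  # stand-in for local features Z_{t+step}
            logits = head(context) @ z_future.T        # (batch, batch) similarity matrix
            labels = torch.arange(batch, device=x.device)  # positive pairs on the diagonal
            loss = loss + F.cross_entropy(logits, labels)
        return loss / self.k

x = torch.randn(16, 20, 8)          # (batch, time, features)
loss = ContrastiveStep()(x)
loss.backward()
```

Sampling a single t per batch, as the excerpt describes, replaces a sum over all t = 0, ..., t_max with one term, trading a noisier gradient for a much cheaper step.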
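Similarly, the AdamW and early-stopping choices quoted in the Experiment Setup row could look like the following PyTorch Lightning sketch. `TinyDecoder`, both learning rates, the `val_factual_mse` metric name, and the patience value are hypothetical placeholders, not values taken from the paper.

```python
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import EarlyStopping

class TinyDecoder(nn.Module):
    """Stand-in for the Causal CPC decoder: a pretrained encoder plus an untrained head."""
    def __init__(self, d=32):
        super().__init__()
        self.encoder = nn.GRU(d, d, batch_first=True)  # pretrained contrastive encoder
        self.head = nn.Linear(d, 1)                    # untrained outcome head

model = TinyDecoder()

# Two parameter groups: the pretrained encoder is fine-tuned with a lower
# learning rate than the untrained part of the decoder, as the excerpt states.
optimizer = torch.optim.AdamW([
    {"params": model.encoder.parameters(), "lr": 1e-4},
    {"params": model.head.parameters(), "lr": 1e-3},
])

# Early stopping monitors MSE on factual outcomes from a validation set;
# the callback would be passed to pl.Trainer(callbacks=[early_stop]).
early_stop = EarlyStopping(monitor="val_factual_mse", mode="min", patience=10)
```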