Modular Learning of Deep Causal Generative Models for High-dimensional Causal Inference
Authors: Md Musfiqur Rahman, Murat Kocaoglu
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | With extensive experiments on the Colored-MNIST dataset, we demonstrate that our algorithm outperforms the baselines. We also show our algorithm's convergence on the COVIDx dataset and its utility with a causal invariant prediction problem on CelebA-HQ. |
| Researcher Affiliation | Academia | 1School of Electrical and Computer Engineering, Purdue University, West Lafayette, IN, USA. |
| Pseudocode | Yes | Algorithm 1 Modular Training(G, D) |
| Open Source Code | Yes | We share our implementation at https://github.com/Musfiqshohan/Modular-DCM. |
| Open Datasets | Yes | With extensive experiments on the Colored-MNIST dataset, we demonstrate that Modular-DCM converges better compared to the closest baselines and can correctly generate interventional samples. We also show our convergence on COVIDx CXR-3 and solve an invariant prediction problem on CelebA-HQ. |
| Dataset Splits | No | The paper describes training and test splits for the CelebA-HQ dataset ("5380 train samples and 1280 test samples") but does not explicitly mention a validation split or its size. |
| Hardware Specification | Yes | We performed our experiments on a machine with an RTX-3090 GPU. |
| Software Dependencies | No | The paper mentions software components like Wasserstein GAN with penalized gradients, Gumbel-softmax, Batch Norm, ReLU, and ADAM optimizer, but it does not provide specific version numbers for these or other software dependencies. |
| Experiment Setup | Yes | Our datasets contained 20-40K samples, the batch size was 200, and we used the ADAM optimizer. For Wasserstein GAN with gradient penalty, we used LAMBDA_GP = 10. We had learning rate = 5 × 1e-4. We used Gumbel-softmax with a temperature starting from 1 and decreasing it until 0.1. (See the configuration sketch after this table.) |
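
The Experiment Setup row lists concrete hyperparameters: WGAN with gradient penalty, LAMBDA_GP = 10, the ADAM optimizer, learning rate 5 × 1e-4, batch size 200, and a Gumbel-softmax temperature annealed from 1.0 down to 0.1. The PyTorch sketch below shows how such a configuration is typically wired together; it is a minimal illustration under stated assumptions, not the authors' implementation. The `Generator` and `Critic` modules, their layer sizes, the ADAM betas, and the temperature decay schedule are all hypothetical, and the paper's modular training procedure (Algorithm 1) is not reproduced here.

```python
# Hedged sketch of a WGAN-GP critic update plus Gumbel-softmax relaxation,
# using the hyperparameters reported in the Experiment Setup row.
# Architectures and the decay schedule are assumptions for illustration only.
import torch
import torch.nn as nn
import torch.nn.functional as F

LAMBDA_GP = 10.0   # gradient-penalty weight reported in the paper
LR = 5e-4          # learning rate reported in the paper (5 x 1e-4)
BATCH_SIZE = 200   # batch size reported in the paper


class Generator(nn.Module):
    """Hypothetical generator: noise plus a Gumbel-softmax-relaxed discrete parent."""
    def __init__(self, noise_dim=64, n_classes=10, out_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim + n_classes, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, out_dim),
        )

    def forward(self, z, class_logits, tau):
        # Differentiable relaxation of a discrete parent variable.
        parent = F.gumbel_softmax(class_logits, tau=tau, hard=False)
        return self.net(torch.cat([z, parent], dim=1))


class Critic(nn.Module):
    """Hypothetical WGAN critic (unbounded scalar output, no sigmoid)."""
    def __init__(self, in_dim=784):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU(), nn.Linear(256, 1))

    def forward(self, x):
        return self.net(x)


def gradient_penalty(critic, real, fake):
    """Standard WGAN-GP penalty on interpolates between real and fake samples."""
    alpha = torch.rand(real.size(0), 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


def critic_step(critic, gen, opt_c, real, tau, noise_dim=64, n_classes=10):
    """One critic update with the gradient penalty weighted by LAMBDA_GP."""
    z = torch.randn(real.size(0), noise_dim)
    logits = torch.zeros(real.size(0), n_classes)  # uniform prior over the discrete parent
    fake = gen(z, logits, tau).detach()
    loss = (critic(fake).mean() - critic(real).mean()
            + LAMBDA_GP * gradient_penalty(critic, real, fake))
    opt_c.zero_grad()
    loss.backward()
    opt_c.step()
    return loss.item()


# Optimizers with the reported learning rate (ADAM betas left at defaults; an assumption).
gen, critic = Generator(), Critic()
opt_g = torch.optim.Adam(gen.parameters(), lr=LR)
opt_c = torch.optim.Adam(critic.parameters(), lr=LR)

# Gumbel-softmax temperature annealing: start at 1.0 and decay toward 0.1.
tau = 1.0
for epoch in range(5):
    real = torch.randn(BATCH_SIZE, 784)  # placeholder data batch
    critic_step(critic, gen, opt_c, real, tau)
    tau = max(0.1, tau * 0.95)           # decay rate is an assumption
```

In this standard setup the gradient penalty keeps the critic approximately 1-Lipschitz, and the Gumbel-softmax relaxation lets gradients flow through the discrete parent variable while its temperature is annealed from 1.0 toward 0.1.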