Generative Conditional Distributions by Neural (Entropic) Optimal Transport
Authors: Bao Nguyen, Binh Nguyen, Hieu Trung Nguyen, Viet Anh Nguyen
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | 5. Experiments. Datasets. To benchmark our proposed method GENTLE, we use the following two datasets and the same preprocessing steps as in Athey et al. (2021). ... Baselines. We compare our method against state-of-the-art baselines, including CWGAN (Athey et al., 2021), WGAN-GP (Gulrajani et al., 2017), MGAN (Baptista et al., 2020), and CDSB (Shi et al., 2022). ... Results. We plot the qualitative results of estimated densities for different covariates X in Figures 3 and 4. Overall, synthetic observations generated by GENTLE exhibit a strong resemblance to the ground-truth distributions in both datasets, while those generated by the baselines deviate significantly from the data-generating distributions. ... This is consistent with the quantitative results in Table 1, which show that the data generated by GENTLE have not only smaller Wasserstein distance (WD) and Kolmogorov-Smirnov (KS) values but also smaller standard deviations. ... 5.1. Ablation Study on Training Loss Terms. |
| Researcher Affiliation | Collaboration | Bao Nguyen (1, 2), Binh Nguyen (3), Hieu Trung Nguyen (1, 4), Viet Anh Nguyen (1). Affiliations: (1) The Chinese University of Hong Kong; (2) VinUniversity; (3) Department of Mathematics, National University of Singapore; (4) VinAI Research. |
| Pseudocode | Yes | Algorithm 1 Empirical loss evaluation |
| Open Source Code | Yes | Our implementation can be found at https://github.com/nguyenngocbaocmt02/GENTLE. |
| Open Datasets | Yes | The LDW-CPS dataset, constructed by LaLonde (1986) and Dehejia & Wahba (1999), is widely used in studies of Average Treatment Effects. ... We also apply the frameworks to a practical simulator called the Esophageal Cancer Markov Chain Simulation Model (ECM) (footnote 1). This simulator enables us to simulate patients' quality-adjusted life years (QALYs) from their current treatment until death. The approach to generating data for training and testing follows Hong et al. (2023). ... Footnote 1: Code is publicly available at https://simopt.github.io/ECSim. |
| Dataset Splits | Yes | To achieve better testing evaluation and model selection, we choose covariates based on their frequency for the test and validation sets. Specifically, covariates with a frequency higher than 30 are included in the test set, while those with a frequency higher than 20 but less than or equal to 30 are included in the validation set. According to this split strategy, the training set contains 15,255 samples corresponding to 13,745 covariates, the validation set contains 488 samples corresponding to 19 covariates, and the test set contains 434 samples corresponding to 20 covariates. ... The validation set has 200 covariates that are different from those in the train set, while the test set also has 200 covariates that are different from those in the train and validation sets. |
| Hardware Specification | No | The paper does not provide specific hardware details such as GPU models, CPU types, or cloud instance specifications used for running the experiments. |
| Software Dependencies | No | The paper does not provide specific software dependency details, such as programming language versions or library versions (e.g., Python 3.x, PyTorch x.x, CUDA x.x). |
| Experiment Setup | Yes | We keep the architecture of Tθ(x, U) and vϕ simple: seven fully connected layers with ReLU activation. ... We heuristically fix the learning rates α = β = 0.001, a reasonable value for almost every simple multilayer perceptron model. For the other parameters, we conducted a grid search to minimize the WD metric on the validation set. We found that the best combination for our parameters on the LDW-CPS dataset is a KDE bandwidth of h = 0.3, ε = 1.0, λ = 0.4, r1 = 3.0, r2 = 2.0, γ = 0.5, δ = 0.7. For the ECM dataset, the best combination is h = 0.2, ε = 1.0, λ = 0.4, r1 = 3.0, r2 = 2.0, γ = 0.5, δ = 0.7. We fix these parameter values for all our experiments in the main paper. |
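
The Results quote in the Research Type row reports Wasserstein distance (WD) and Kolmogorov-Smirnov (KS) values. The sketch below assumes both metrics compare one-dimensional samples of generated and ground-truth outcomes for a single covariate. The quoted text does not say how values are aggregated across covariates. Both quantities are computed here from empirical CDFs: WD as the area between the two CDFs, and KS as their largest vertical gap. The function names are illustrative and do not come from the GENTLE code. `scipy.stats.wasserstein_distance` and `scipy.stats.ks_2samp` compute the same two statistics.

```python
from bisect import bisect_right


def ecdf(sorted_xs, t):
    """Empirical CDF of a sorted, non-empty sample, evaluated at t (right-continuous)."""
    return bisect_right(sorted_xs, t) / len(sorted_xs)


def wasserstein_1d(a, b):
    """W1 between two 1-D empirical distributions: the integral of |F_a - F_b|."""
    sa, sb = sorted(a), sorted(b)
    grid = sorted(set(sa) | set(sb))
    total = 0.0
    for left, right in zip(grid, grid[1:]):
        # Both ECDFs are constant on [left, right), so the integral is a sum of rectangles.
        total += abs(ecdf(sa, left) - ecdf(sb, left)) * (right - left)
    return total


def ks_statistic(a, b):
    """Two-sample KS statistic: sup over t of |F_a(t) - F_b(t)|."""
    sa, sb = sorted(a), sorted(b)
    grid = sorted(set(sa) | set(sb))
    # The supremum of a difference of step functions is attained at a jump point.
    return max(abs(ecdf(sa, t) - ecdf(sb, t)) for t in grid)
```

For example, the samples `[0.0, 1.0]` and `[1.0, 2.0]` differ by a shift of 1.0. Their WD is therefore 1.0, while their KS statistic is 0.5. KS depends only on the largest CDF gap and does not grow with the size of the shift.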
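
The first split rule quoted in the Dataset Splits row assigns each sample by how often its covariate value occurs:

- more than 30 occurrences goes to the test set;
- more than 20 but at most 30 goes to the validation set;
- everything else goes to the training set.

Below is a minimal sketch of that rule. It assumes each sample is a `(covariate, outcome)` pair with a hashable covariate, such as a tuple of feature values. `split_by_covariate_frequency` is a hypothetical helper, not a function from the GENTLE repository. The second quoted split (200 validation and 200 test covariates, each disjoint from the others) is not covered here.

```python
from collections import Counter


def split_by_covariate_frequency(samples, test_min=30, val_min=20):
    """Split (covariate, outcome) pairs into train/val/test by covariate frequency.

    Hypothetical helper: covariates occurring more than `test_min` times go to test,
    those occurring more than `val_min` but at most `test_min` times go to
    validation, and the rest go to train.
    """
    counts = Counter(x for x, _ in samples)
    train, val, test = [], [], []
    for x, y in samples:
        freq = counts[x]
        if freq > test_min:
            test.append((x, y))
        elif freq > val_min:  # val_min < freq <= test_min
            val.append((x, y))
        else:
            train.append((x, y))
    return train, val, test
```

Because the rule is applied per covariate value, all samples sharing a covariate land in the same split. This matches the quoted counts, where each split is described by both its number of samples and its number of distinct covariates.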
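
The Experiment Setup row reports a grid search that minimizes validation WD, with the learning rates fixed at α = β = 0.001. The sketch below is a plain exhaustive grid search that records the two quoted best configurations. Some details are assumptions:

- The candidate values in `HYPOTHETICAL_GRID` are invented for illustration; the quoted text gives only the selected values, not the ranges searched.
- The parameter key names are illustrative.
- `score_fn` stands in for training a model with the given parameters and returning its validation WD.

```python
from itertools import product

# Best combinations as quoted in the Experiment Setup row.
# The learning rates alpha = beta = 0.001 were fixed, not searched.
LDW_CPS_BEST = {"h": 0.3, "eps": 1.0, "lam": 0.4, "r1": 3.0, "r2": 2.0, "gamma": 0.5, "delta": 0.7}
ECM_BEST = {**LDW_CPS_BEST, "h": 0.2}  # identical except for the KDE bandwidth

# Hypothetical search ranges (not reported in the quoted text).
HYPOTHETICAL_GRID = {
    "h": [0.1, 0.2, 0.3],
    "eps": [0.5, 1.0],
    "lam": [0.2, 0.4],
    "r1": [1.0, 3.0],
    "r2": [1.0, 2.0],
    "gamma": [0.5, 1.0],
    "delta": [0.5, 0.7],
}


def grid_search(param_grid, score_fn):
    """Evaluate every combination in param_grid and return (best_params, best_score).

    score_fn is a placeholder for "train with these parameters, then compute
    the WD metric on the validation set"; lower is better.
    """
    names = list(param_grid)
    best_params, best_score = None, float("inf")
    for values in product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        score = score_fn(params)
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score
```

Exhaustive search over seven parameters grows multiplicatively: the toy grid above already has 3 × 2⁶ = 192 combinations, each requiring a full training run in practice. This cost is one reason the learning rates were fixed heuristically rather than searched.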