Optimal Transport for Structure Learning Under Missing Data
Authors: Vy Vo, He Zhao, Trung Le, Edwin V. Bonilla, Dinh Phung
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our framework is shown to recover the true causal graphs more effectively than competing methods in most simulations and real-data settings. We evaluate OTM on both synthetic and real-world datasets. |
| Researcher Affiliation | Collaboration | 1Monash University, Australia; 2CSIRO's Data61, Australia; 3VinAI Research, Vietnam. |
| Pseudocode | Yes | The training procedure of OTM is summarized in Algorithm 1 (a minimal code sketch follows this table). Algorithm 1 OTM Algorithm. Input: incomplete data matrix $X = [X^1, \dots, X^n]^\top \in \mathbb{R}^{n \times d}$; missing mask $M$; regularization coefficients $\lambda, \gamma_1, \gamma_2 > 0$; loss function $l_c$; characteristic positive-definite kernel $\kappa$; and log-det (M-matrix domain) parameter $s > 0$. Output: weighted adjacency matrix $W_\theta$. Initialize the parameters $\theta, \Phi$. While $(\Phi, \theta)$ not converged: set $X_O = \big[X^j_i \text{ if } M^j_i = 0 \text{ else } 0\big]_{i \in [d],\, j \in [n]}$; sample $\widetilde{X}$ from $\Phi(X_O)$; evaluate $Y = f_\theta(\widetilde{X})$; update $\Phi, \theta$ by descending $\mathcal{L}(\Phi, \theta) = \frac{1}{n}\sum_{j=1}^{n} l_c(\widetilde{X}^j, Y^j) + \lambda\,\mathrm{MMD}(\widetilde{X}, Y, \kappa) + \gamma_1\big({-\log\det}(sI - W \circ W) + d\log s\big) + \gamma_2 \lVert W \rVert_1$. |
| Open Source Code | Yes | The code is released at https://github.com/isVy08/OTM. |
| Open Datasets | Yes | We evaluate OTM on 3 well-known biological datasets with ground-truth causal relations. The first one is the Neuropathic Pain dataset (Tu et al., 2019), containing diagnosis records of neuropathic pain patients. [...] The second causal graph, named Sachs (Sachs et al., 2005), models a network of cellular signals, consisting of 11 continuous variables and 7466 samples. Dream4 (Greenfield et al., 2010; Marbach et al., 2010) is the last dataset... |
| Dataset Splits | No | The paper mentions total sample sizes for datasets (e.g., '1000 observations', '7466 samples', '100 samples') and missing rates, but it does not specify explicit training, validation, or test dataset splits (e.g., 80/10/10 split or specific counts for each). |
| Hardware Specification | No | The paper does not provide any specific details about the hardware (e.g., CPU, GPU models, memory, or cloud instances) used to run the experiments. |
| Software Dependencies | No | The paper mentions software components such as 'DAGMA', 'ADAM optimizers', and 'Scikit-learn implementation', but it does not provide specific version numbers for any of these software dependencies. |
| Experiment Setup | Yes | We apply the default structure learning setting of DAGMA to OTM and the imputation baselines: log-MSE for the non-linear score function, number of iterations T = 4, initial central path coefficient µ = 1, decay factor α = 0.1, log-det parameter s = {1, 0.9, 0.8, 0.7}. DAGMA implements an adaptive gradient method using the Adam optimizer (Kingma & Ba, 2014) with a learning rate of 2 × 10⁻⁴ and (γ₁, γ₂) = (0.99, 0.999), where γ₂ is the coefficient of the l1 regularizer included to promote sparsity. The optimal values of λ are found to be 0.01 for MLP models (the correctly specified setting) and 1.5 for the other models (mis-specified settings). For the divergence measure D, we use MMD with an RBF kernel in our experiments. The imputation network is integrated into DAGMA, and all parameters are optimized for 8 × 10³ iterations. For the OT imputer (Muzellec et al., 2020), we set the learning rate to 0.01 and the number of iterations to 10,000. MissDAG is currently built on NOTEARS (Zheng et al., 2018). For MissDAG, we set the number of iterations of the EM procedure to 10 on synthetic datasets and 100 on real datasets, while leaving the other hyper-parameters of MissDAG the same as reported in the paper. (A configuration sketch collecting these values also follows this table.) |
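
To make the Pseudocode row concrete, below is a minimal sketch of the OTM training loop in PyTorch. It is written under simplifying assumptions that are not the authors' released implementation: a linear structural model with adjacency matrix `W` stands in for DAGMA's non-linear MLP model, a free tensor of fill-in values stands in for the imputation network Φ, plain MSE is used for the loss l_c, and the RBF bandwidth is fixed.

```python
# Minimal sketch of the OTM training loop (Algorithm 1), under the
# simplifying assumptions stated above. Not the released implementation.
import torch

def rbf_mmd(x, y, bandwidth=1.0):
    """Biased MMD^2 estimate between samples x and y with an RBF kernel (kappa)."""
    def gram(a, b):
        d2 = torch.cdist(a, b) ** 2
        return torch.exp(-d2 / (2 * bandwidth ** 2))
    return gram(x, x).mean() + gram(y, y).mean() - 2 * gram(x, y).mean()

def logdet_acyclicity(W, s=1.0):
    """DAGMA log-det penalty h(W) = -log det(sI - W∘W) + d log s."""
    d = W.shape[0]
    eye = torch.eye(d, dtype=W.dtype)
    return -torch.slogdet(s * eye - W * W)[1] + d * torch.log(torch.tensor(s, dtype=W.dtype))

def train_otm(X, M, lam=0.01, gamma1=1.0, gamma2=0.01, s=1.0, n_iters=8000, lr=2e-4):
    """X: (n, d) data with arbitrary values at missing positions.
    M: (n, d) float mask, 1 where an entry is missing, 0 where observed."""
    n, d = X.shape
    X_obs = X * (1 - M)                                 # X_O: observed entries, zeros elsewhere
    fill = torch.zeros_like(X, requires_grad=True)      # Phi: free imputation values (illustrative)
    W = torch.zeros(d, d, requires_grad=True)           # theta: linear SCM adjacency (paper uses an MLP)
    opt = torch.optim.Adam([fill, W], lr=lr)
    for _ in range(n_iters):
        X_tilde = X_obs + fill * M                      # completed sample X~ from Phi(X_O)
        Y = X_tilde @ W                                 # Y = f_theta(X~)
        loss = (
            ((X_tilde - Y) ** 2).mean()                 # reconstruction loss l_c (MSE here)
            + lam * rbf_mmd(X_tilde, Y)                 # lambda * MMD(X~, Y; kappa)
            + gamma1 * logdet_acyclicity(W, s)          # gamma1 * DAGMA acyclicity penalty
            + gamma2 * W.abs().sum()                    # gamma2 * ||W||_1 sparsity term
        )
        opt.zero_grad()
        loss.backward()
        opt.step()
    return W.detach()                                   # weighted adjacency estimate W_theta
```

For instance, `train_otm(torch.randn(100, 5), (torch.rand(100, 5) < 0.1).float())` would return a 5 × 5 weighted adjacency estimate from data with roughly 10% of entries treated as missing.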
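The Experiment Setup row lists the hyper-parameters in prose; the small configuration sketch below gathers them in one place for easier scanning. The key names are illustrative and do not mirror any particular codebase.

```python
# Hyper-parameters reported in the Experiment Setup row, collected for reference.
# Key names are illustrative assumptions, not identifiers from the released code.
dagma_otm = {
    "score": "log-MSE",                   # non-linear score function
    "T": 4,                               # number of DAGMA iterations
    "mu_init": 1.0,                       # initial central path coefficient
    "alpha": 0.1,                         # decay factor
    "s_schedule": [1.0, 0.9, 0.8, 0.7],   # log-det parameter per iteration
    "optimizer": "Adam",
    "lr": 2e-4,
    "n_iters": 8_000,                     # optimization steps for OTM / imputation baselines
    "lambda_mmd": {"mlp": 0.01, "mis_specified": 1.5},  # MMD coefficient λ per setting
    "divergence": "MMD-RBF",
}
ot_imputer = {"lr": 0.01, "n_iters": 10_000}            # Muzellec et al. (2020)
missdag = {"em_iters": {"synthetic": 10, "real": 100}}  # MissDAG, built on NOTEARS
```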