Neural Optimal Transport with General Cost Functionals

Authors: Arip Asadulaev, Alexander Korotin, Vage Egiazarian, Petr Mokrov, Evgeny Burnaev

ICLR 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our method achieves notable improvements in accuracy over existing algorithms. We also show the performance of our method on the supervised image-to-image translation task. The main contributions of our paper are: 1. We show that the general OT problem (M2) can be reformulated as a saddle point optimization problem... 2. We provide the error analysis... 3. We construct and test examples... We use MNIST (LeCun & Cortes, 2010), Fashion-MNIST (Xiao et al., 2017) and MNIST-M (Ganin & Lempitsky, 2015) datasets as P, Q. All the models are fitted on the train parts of the datasets; all the provided qualitative and quantitative results are exclusively for test (unseen) data. To evaluate the visual quality, we compute FID (Heusel et al., 2017) of the entire mapped source test set w.r.t. the entire target test set. To estimate the accuracy of the mapping, we use a ResNet18 (He et al., 2016) classifier (with 95%+ accuracy) pre-trained on the target data. Qualitative results are shown in Figure 3; FID and accuracies are reported in Tables 2 and 1, respectively.
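The accuracy metric is only quoted above, not spelled out in code. Below is a minimal sketch of how such a class-preservation accuracy could be computed with a pre-trained classifier; the function name, the z_dim argument, and the assumption that T is a stochastic map T(x, z) fed with Gaussian noise are illustrative choices, not taken from the released implementation.

```python
import torch

@torch.no_grad()
def class_preservation_accuracy(T, classifier, source_loader, z_dim=1, device="cuda"):
    """Share of mapped source test images whose predicted target-domain class
    matches the class label of the source image they were translated from."""
    T.eval()
    classifier.eval()
    correct, total = 0, 0
    for x, y in source_loader:                            # source test images and labels
        x, y = x.to(device), y.to(device)
        z = torch.randn(x.size(0), z_dim, device=device)  # noise for the stochastic map
        mapped = T(x, z)                                   # translated images
        pred = classifier(mapped).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / total
```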
Researcher Affiliation | Collaboration | Arip Asadulaev (1,3), Alexander Korotin (2,1), Vage Egiazarian (4,5), Petr Mokrov (2), Evgeny Burnaev (2,1). Affiliations: 1 Artificial Intelligence Research Institute, 2 Skolkovo Institute of Science and Technology, 3 Moscow Institute of Physics and Technology, 4 HSE University, 5 Yandex. Contact: aripasadulaev@airi.net, a.korotin@skoltech.ru
Pseudocode | Yes | Algorithm 1: Neural optimal transport with the class-guided cost functional F̃_G. Algorithm 2: Neural optimal transport for general cost functionals. Algorithm 3: Neural optimal transport with pair-guided cost functional and deterministic map.
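The listed algorithms are not reproduced in this report. The sketch below only illustrates the maximin structure implied by the saddle-point reformulation quoted above (outer ascent steps on the potential vω, K_T inner descent steps on the map Tθ); the sampler callables, the generic cost_functional argument, and the exact loss terms are assumptions based on the general neural OT formulation, not a verbatim copy of Algorithm 2.

```python
import torch

def train_general_neural_ot(T, v, cost_functional, sample_p, sample_q, sample_z,
                            n_outer=60_000, k_T=10, lr=1e-4, device="cuda"):
    """Hedged sketch of a maximin loop: the potential v is updated once per outer
    iteration, the stochastic map T(x, z) is updated on k_T inner iterations."""
    opt_T = torch.optim.Adam(T.parameters(), lr=lr)
    opt_v = torch.optim.Adam(v.parameters(), lr=lr)
    for _ in range(n_outer):
        # Inner loop: minimize the cost functional minus the potential over T.
        for _ in range(k_T):
            x, z = sample_p().to(device), sample_z().to(device)
            mapped = T(x, z)
            loss_T = cost_functional(x, mapped) - v(mapped).mean()
            opt_T.zero_grad()
            loss_T.backward()
            opt_T.step()
        # Outer step: maximize E_Q[v(y)] - E[v(T(x, z))] over v
        # (implemented as gradient descent on the negated objective).
        x, z, y = sample_p().to(device), sample_z().to(device), sample_q().to(device)
        with torch.no_grad():
            mapped = T(x, z)
        loss_v = v(mapped).mean() - v(y).mean()
        opt_v.zero_grad()
        loss_v.backward()
        opt_v.step()
    return T, v
```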
Open Source Code | Yes | Our implementation is available at https://github.com/machinestein/gnot
Open Datasets | Yes | We use MNIST (LeCun & Cortes, 2010), Fashion-MNIST (Xiao et al., 2017) and MNIST-M (Ganin & Lempitsky, 2015) datasets as P, Q. Each dataset has 10 (balanced) classes and a pre-defined train-test split.
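For reference, the two torchvision-bundled datasets with their pre-defined train/test splits can be loaded as sketched below; the resize and channel-expansion transforms are illustrative assumptions, and MNIST-M is not shipped with torchvision, so it has to be obtained separately.

```python
from torchvision import datasets, transforms

# Illustrative preprocessing (image size and channel count are assumptions,
# not values taken from the paper).
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

# Pre-defined train/test splits are selected via the `train` flag.
mnist_train  = datasets.MNIST("data", train=True,  download=True, transform=transform)
mnist_test   = datasets.MNIST("data", train=False, download=True, transform=transform)
fmnist_train = datasets.FashionMNIST("data", train=True,  download=True, transform=transform)
fmnist_test  = datasets.FashionMNIST("data", train=False, download=True, transform=transform)
```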
Dataset Splits | No | The paper mentions default train-test splits for the datasets but does not describe a separate validation split (no percentages, sample counts, or explicit use of a validation set). For example, it states: "All the models are fitted on the train parts of datasets; all the provided qualitative and quantitative results are exclusively for test (unseen) data." This covers only train and test, with no specification of a validation set.
Hardware Specification | Yes | On the image data, our method converges in 5–15 hours on a Tesla V100 (16 GB).
Software Dependencies | No | The code is written in the PyTorch framework and publicly available at https://github.com/machinestein/gnot. We use the Adam (Kingma & Ba, 2014) optimizer... We use WandB for babysitting the experiments (Biewald, 2020). We use the WGAN-QC discriminator's ResNet architecture (He et al., 2016) for the potential vω. We use UNet (Ronneberger et al., 2015) as the stochastic transport map Tθ(x, z). While the paper names software and libraries, it consistently lacks specific version numbers (e.g., for PyTorch, WandB, or other dependencies).
Experiment Setup | Yes | We use the Adam (Kingma & Ba, 2014) optimizer with lr = 10⁻⁴ for both Tθ and vω. The number of inner iterations for Tθ is KT = 10. The batch size is KB = 32, with KX = KY = 2 and KZ = 2 for training with z. Our method converges in 60k iterations of vω.
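Gathered into one place, the quoted hyperparameters could be expressed as the configuration below; the variable names are illustrative, and the roles of KX, KY, KZ are only echoed from the quote rather than verified against the released code.

```python
import torch

# Reported hyperparameters collected into a single config (names are illustrative).
config = dict(
    lr=1e-4,                # Adam learning rate for both T_theta and v_omega
    k_T=10,                 # inner iterations for T_theta per update of v_omega
    batch_size=32,          # K_B
    k_X=2, k_Y=2, k_Z=2,    # K_X = K_Y = 2, K_Z = 2 when training with z
    n_v_iterations=60_000,  # iterations of v_omega until convergence
)

# Optimizers as reported, assuming PyTorch modules T_theta and v_omega are defined:
# opt_T = torch.optim.Adam(T_theta.parameters(), lr=config["lr"])
# opt_v = torch.optim.Adam(v_omega.parameters(), lr=config["lr"])
```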