On amortizing convex conjugates for optimal transport

Authors: Brandon Amos

Venue: ICLR 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | I show that combining amortized approximations to the conjugate with a solver for fine-tuning significantly improves the quality of transport maps learned for the Wasserstein-2 benchmark by Korotin et al. (2021a) and is able to model many 2-dimensional couplings and flows considered in the literature. Sect. 5 shows that amortizing and fine-tuning the conjugate results in state-of-the-art performance in all of the tasks proposed in the Wasserstein-2 benchmark by Korotin et al. (2021a).
Researcher Affiliation | Industry | Brandon Amos (Meta AI)
Pseudocode | Yes | Algorithm 1: Learning Wasserstein-2 dual potentials with amortized and fine-tuned conjugation; Algorithm 2: CONJUGATE(f, y, x_init). A hedged sketch of this conjugation step follows the table.
Open Source Code | Yes | All of the baselines, methods, and solvers in this paper are available at http://github.com/facebookresearch/w2ot.
Open Datasets | Yes | Wasserstein-2 benchmark by Korotin et al. (2021a) and samples from generative models trained on CelebA (Liu et al., 2015).
Dataset Splits | No | The paper uses an existing benchmark and sampled data, but does not explicitly report percentages or counts for training, validation, and test splits.
Hardware Specification | Yes | wall-clock time for the entire training run measured on an NVIDIA Tesla V100 GPU
Software Dependencies | No | The core set of tools in Python (Van Rossum and Drake Jr, 1995; Oliphant, 2007) enabled this work, including Hydra (Yadan, 2019), JAX (Bradbury et al., 2018), Flax (Heek et al., 2020), Matplotlib (Hunter, 2007), numpy (Oliphant, 2006; Van Der Walt et al., 2011), and pandas (McKinney, 2012).
Experiment Setup | Yes | Tables 4 and 5 detail the main hyper-parameters for the Wasserstein-2 benchmark experiments. I tried to keep these consistent with the choices from Korotin et al. (2021a), e.g. using the same batch sizes, number of training iterations, and hidden layer sizes for the potential.
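
For context on the Pseudocode row above: Algorithm 2, CONJUGATE(f, y, x_init), computes the convex conjugate f*(y) = max_x <x, y> - f(x) by fine-tuning an amortized prediction x_init with a solver. Below is a minimal JAX sketch of that conjugation step, assuming plain gradient ascent in place of the paper's solvers (e.g. Adam or L-BFGS). The function name `conjugate`, the toy quadratic potential `f_quadratic`, and the `lr`/`num_steps` settings are illustrative assumptions, not the released w2ot implementation.

```python
# Minimal sketch of the conjugation step (Algorithm 2): fine-tune an amortized
# prediction x_init of the argmax in f*(y) = max_x <x, y> - f(x) with plain
# gradient ascent. Hyper-parameters here are illustrative placeholders.
import jax
import jax.numpy as jnp


def conjugate(f, y, x_init, num_steps=200, lr=0.1):
    """Approximate f*(y) and its argmax, starting from the prediction x_init."""
    # Objective to maximize over x.
    obj = lambda x: jnp.dot(x, y) - f(x)
    grad_obj = jax.grad(obj)

    def step(x, _):
        # One gradient-ascent update on the conjugate objective.
        return x + lr * grad_obj(x), None

    x_star, _ = jax.lax.scan(step, x_init, None, length=num_steps)
    return obj(x_star), x_star


# Toy check: for f(x) = 0.5 * ||x||^2 the conjugate is f*(y) = 0.5 * ||y||^2
# and the argmax is x* = y.
f_quadratic = lambda x: 0.5 * jnp.sum(x ** 2)
y = jnp.array([1.0, -2.0])
val, x_star = conjugate(f_quadratic, y, x_init=jnp.zeros(2))
print(val, x_star)  # approximately 2.5 and [1.0, -2.0]
```

In the paper's setting, x_init comes from the amortization model rather than a zero vector, and the fine-tuning solver is what the benchmark results in Sect. 5 attribute the accuracy gains to; this sketch only illustrates the shape of that inner maximization.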