Wasserstein Wormhole: Scalable Optimal Transport Distance with Transformers
Authors: Doron Haviv, Russell Zhang Kunes, Thomas Dougherty, Cassandra Burdziak, Tal Nawy, Anna Gilbert, Dana Pe’er
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Empirically, distances between Wormhole embeddings closely match Wasserstein distances, enabling linear time computation of OT distances. [...] From Section 6 (Experiments): We demonstrate Wasserstein Wormhole on datasets representing multiple contexts where OT is frequently applied. |
| Researcher Affiliation | Academia | (1) Computational and Systems Biology Program, Sloan Kettering Institute, Memorial Sloan Kettering Cancer Center; (2) Tri-Institutional Training Program in Computational Biology and Medicine, Weill Cornell Medicine; (3) Department of Statistics, Columbia University; (4) Department of Mathematics and Statistics, Yale University; (5) Howard Hughes Medical Institute. |
| Pseudocode | Yes | Algorithm 1: Wasserstein Wormhole. Algorithm 2: Theoretically Optimal Embedding of Non-Euclidean Distance Matrices with Projected Gradient Descent (PGD). |
| Open Source Code | Yes | Software is available at http://wassersteinwormhole.readthedocs.io/en/latest/. |
| Open Datasets | Yes | We applied Wormhole to the MNIST dataset of hand-written digits, which consists of 70,000 samples of 28×28 pixel grayscale images. Similar to MNIST, Fashion MNIST contains 70,000 images that each measure 28×28 pixels, but the images depict articles of clothing instead of digits. The ModelNet40 dataset comprises point clouds from 3D synthetic CAD models of 40 diverse object classes. ShapeNet is a CAD-based dataset of 3D point clouds consisting of 16 different object classes. MERFISH dataset of the motor cortex (Zhang et al., 2021); scRNA-seq atlas of patient response to COVID (Stephenson et al., 2021). |
| Dataset Splits | No | The paper mentions training and testing on datasets but does not explicitly provide details for a separate validation set or specific train/validation/test split ratios or counts. |
| Hardware Specification | Yes | Straightforward computation of all $\binom{70000}{2} \approx 2.45 \times 10^{9}$ pairwise Wasserstein distances using current OT solvers is far from computationally feasible and would require a week of JAX-OTT (Cuturi et al., 2022) on an 80 GB GPU (Figure 2). |
| Software Dependencies | Yes | Our algorithm is implemented in JAX (Bradbury et al., 2021) and integrates with OT tools (OTT-JAX) (Cuturi et al., 2022) for efficient Wasserstein distance calculations. To calculate the object class accuracy of learned Wormhole embeddings, we train an MLP classifier using the default implementation from scikit-learn (Pedregosa et al., 2011). |
| Experiment Setup | Yes | Wormhole is trained with 10,000 gradient descent (GD) steps using the ADAM optimizer (Kingma & Ba, 2014) with an initial learning rate of $10^{-4}$ and an exponential decay schedule. The default batch size was set to 16... |
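
The Research Type row claims that Euclidean distances between Wormhole embeddings closely match Wasserstein distances. A minimal sketch of the kind of batch-level distance-matching objective that claim implies is shown below, under loud assumptions: the function names are hypothetical, the loss (mean squared gap between embedding distances and W2 distances) is an illustrative choice that may differ from Algorithm 1, and the 1-D sort-based W2 stands in for the general OTT-JAX solver. The Transformer encoder itself is not reproduced.

```python
import numpy as np


def w2_1d(x, y):
    """Exact 2-Wasserstein distance between two equal-size, uniformly weighted 1-D point clouds.

    In 1-D the optimal coupling matches sorted samples, so no OT solver is needed.
    """
    xs, ys = np.sort(x), np.sort(y)
    return float(np.sqrt(np.mean((xs - ys) ** 2)))


def pairwise_w2(clouds):
    """Symmetric matrix of W2 distances for a batch of point clouds (quadratic in batch size)."""
    n = len(clouds)
    W = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            W[i, j] = W[j, i] = w2_1d(clouds[i], clouds[j])
    return W


def distance_matching_loss(emb, W):
    """Hypothetical loss: mean squared gap between Euclidean embedding distances and OT distances."""
    diff = emb[:, None, :] - emb[None, :, :]
    E = np.sqrt(np.sum(diff ** 2, axis=-1))
    return float(np.mean((E - W) ** 2))


# Toy batch: shifted copies of one sample, so W2 between copies i and j is |shift_i - shift_j|.
rng = np.random.default_rng(0)
base = rng.normal(size=64)
shifts = np.array([0.0, 1.0, 2.5, 4.0])
clouds = [base + s for s in shifts]
W = pairwise_w2(clouds)

perfect = shifts[:, None]  # 1-D embedding that reproduces W exactly
noisy = perfect + rng.normal(scale=0.5, size=perfect.shape)
loss_perfect = distance_matching_loss(perfect, W)
loss_noisy = distance_matching_loss(noisy, W)
```

Once an encoder drives this kind of loss toward zero, comparing two samples costs one vector distance instead of one OT solve, which is the efficiency argument the row quotes.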
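
The Pseudocode row lists Algorithm 2, an embedding of non-Euclidean distance matrices via projected gradient descent. As a hedged illustration of that general technique (not a transcription of Algorithm 2, whose objective and parametrization may differ), the sketch below fits a positive semidefinite Gram matrix whose implied squared distances are as close as possible, in Frobenius norm, to a given squared-distance matrix, projecting onto the PSD cone after every gradient step. All names are hypothetical.

```python
import numpy as np


def gram_to_sqdist(G):
    """Squared distances implied by a Gram matrix: D_ij = G_ii + G_jj - 2 G_ij."""
    d = np.diag(G)
    return d[:, None] + d[None, :] - 2.0 * G


def project_psd(G):
    """Euclidean projection onto the PSD cone: symmetrize, then clip negative eigenvalues."""
    G = 0.5 * (G + G.T)
    w, V = np.linalg.eigh(G)
    return (V * np.clip(w, 0.0, None)) @ V.T


def sqdist_loss(G, D2):
    """0.5 * squared Frobenius gap between implied and target squared distances."""
    return 0.5 * float(np.sum((gram_to_sqdist(G) - D2) ** 2))


def pgd_euclidean_fit(D2, steps=2000):
    """Projected gradient descent on sqdist_loss over PSD Gram matrices."""
    n = D2.shape[0]
    # (2*sqrt(n) + 2)^2 upper-bounds the gradient's Lipschitz constant, so this step is safe.
    step = 1.0 / (4.0 * (np.sqrt(n) + 1.0) ** 2)
    G = np.zeros((n, n))  # feasible (PSD) starting point
    for _ in range(steps):
        R = gram_to_sqdist(G) - D2
        grad = np.diag(R.sum(axis=0) + R.sum(axis=1)) - 2.0 * R
        G = project_psd(G - step * grad)
    w, V = np.linalg.eigh(G)
    coords = V * np.sqrt(np.clip(w, 0.0, None))  # rows are embedding coordinates
    return G, coords


# Demo on a genuinely Euclidean matrix, where the optimal loss is zero.
rng = np.random.default_rng(1)
X = rng.normal(size=(6, 2))
D2 = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
G, coords = pgd_euclidean_fit(D2)
```

For a matrix that is not Euclidean-realizable (pairwise Wasserstein distances generally are not), the same loop converges to a positive loss, and that residual measures how far the matrix is from any exact Euclidean embedding.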
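
The Open Datasets row mixes image data (MNIST, Fashion MNIST) with point clouds (ModelNet40, ShapeNet). One common way to feed a grayscale image to an OT solver is to treat it as a weighted point cloud over pixel coordinates; whether Wormhole preprocesses images exactly this way is an assumption, and the helper name and `threshold` parameter below are hypothetical.

```python
import numpy as np


def image_to_point_cloud(img, threshold=0.0):
    """Turn a 2-D grayscale image into a weighted point cloud.

    Points are (row, col) coordinates of pixels brighter than `threshold`;
    weights are their intensities normalized to sum to one.
    Assumes at least one pixel exceeds the threshold.
    """
    img = np.asarray(img, dtype=float)
    rows, cols = np.nonzero(img > threshold)
    coords = np.stack([rows, cols], axis=1).astype(float)
    weights = img[rows, cols]
    return coords, weights / weights.sum()


# Tiny 3x3 example; a 28x28 MNIST digit would yield at most 784 points.
toy = [[0, 2, 0],
       [0, 0, 0],
       [6, 0, 0]]
coords, weights = image_to_point_cloud(toy)
# coords  -> [[0., 1.], [2., 0.]]
# weights -> [0.25, 0.75]
```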
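
The pair count in the Hardware Specification row can be checked directly: it is the number of unordered pairs among 70,000 samples. The second quantity is only the throughput implied if "a week" is read literally; it is derived arithmetic, not a figure reported in the paper.

```python
from math import comb

n_samples = 70_000
n_pairs = comb(n_samples, 2)  # unordered pairs i < j, i.e. n * (n - 1) / 2
print(n_pairs)                # 2449965000, i.e. about 2.45e9

seconds_per_week = 7 * 24 * 3600
implied_rate = n_pairs / seconds_per_week  # OT solves per second implied by "a week"
```

Because the count grows quadratically, doubling the dataset roughly quadruples the number of OT solves, while an embedding needs one encoder pass per sample.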
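
The Software Dependencies row states that class accuracy is measured with scikit-learn's default MLP classifier trained on Wormhole embeddings. The sketch below shows that evaluation pattern with `MLPClassifier`; the synthetic two-class data is a hypothetical stand-in for real embeddings, the 300/100 split is arbitrary, and only `random_state` departs from the defaults so the run is repeatable.

```python
import numpy as np
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
# Hypothetical stand-in for Wormhole embeddings: two well-separated Gaussian classes in 2-D.
X = np.vstack([rng.normal(-3.0, 1.0, size=(200, 2)),
               rng.normal(3.0, 1.0, size=(200, 2))])
y = np.repeat([0, 1], 200)
perm = rng.permutation(len(y))
X, y = X[perm], y[perm]
X_train, X_test = X[:300], X[300:]
y_train, y_test = y[:300], y[300:]

clf = MLPClassifier(random_state=0)  # default architecture and optimizer, fixed seed
clf.fit(X_train, y_train)
accuracy = clf.score(X_test, y_test)  # mean accuracy on the held-out split
```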
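
The Experiment Setup row specifies Adam with an initial learning rate of $10^{-4}$ and an exponential decay schedule, but the quoted text does not give the decay rate or its time scale. The sketch below writes out a smooth exponential decay over the stated 10,000 steps; `decay_rate=0.9` and `transition_steps=1000` are hypothetical placeholders, not the paper's values.

```python
def exponential_decay_lr(step, init_lr=1e-4, decay_rate=0.9, transition_steps=1000):
    """Smooth exponential decay: init_lr * decay_rate ** (step / transition_steps).

    decay_rate and transition_steps are illustrative assumptions.
    """
    return init_lr * decay_rate ** (step / transition_steps)


NUM_STEPS = 10_000  # from the Experiment Setup row
BATCH_SIZE = 16     # stated default batch size (unused here; listed for completeness)
schedule = [exponential_decay_lr(t) for t in range(NUM_STEPS)]
```

In the JAX ecosystem, one plausible equivalent (not confirmed as what Wormhole uses) is `optax.exponential_decay(init_value=1e-4, transition_steps=..., decay_rate=...)` passed as the `learning_rate` of `optax.adam`, which follows the same formula when `staircase=False`.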