Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].

Wasserstein Flow Matching: Generative Modeling Over Families of Distributions

Authors: Doron Haviv, Aram-Alexandre Pooladian, Dana Pe’er, Brandon Amos

ICML 2025

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We demonstrate the effectiveness of our approach for generative modeling between distributions over Gaussian distributions and distributions over point-clouds, which are realizations of general distributions. The former task is motivated by recent directions in single-cell and spatial transcriptomics (Haviv et al., 2024b; Persad et al., 2023), where we consider matching problems over the Bures-Wasserstein space (BW), the Gaussian submanifold of the Wasserstein space. In this case, we show that WFM can be further modified, resulting in the Bures-Wasserstein FM (BW-FM) algorithm. We validate BW-FM on a variety of Gaussian-based datasets, where we observe that samples generated by our algorithm are significantly more robust than naïve approaches which do not fully exploit the underlying geometry of the data.
Researcher Affiliation | Collaboration | (1) Weill Cornell; (2) Memorial Sloan Kettering Cancer Center; (3) Center for Data Science, New York University; (4) Howard Hughes Medical Institute; (5) Meta AI. Correspondence to: Doron Haviv <EMAIL>, Aram-Alexandre Pooladian <EMAIL>, Brandon Amos <EMAIL>.
Pseudocode | Yes | Algorithm 1: Wasserstein FM Training; Algorithm 2: BW(ℝ^d) Generation; Algorithm 3: General Distribution Generation
Open Source Code | Yes | Code is available at Wasserstein Flow Matching.
Open Datasets | Yes | We validate BW-FM on a variety of Gaussian-based datasets... Applied on environments of gut-tube cells from a seqFISH dataset of mouse embryogenesis (Lohoff et al., 2022)... On a scRNA-seq atlas elucidating human immune response to COVID (Stephenson et al., 2021)... Derived from 3D CAD designs, ShapeNet & ModelNet (Wu et al., 2015; Chang et al., 2015) are touchstone shape datasets in computational geometry... This happens in the MNIST or Letters datasets... Using the 254-gene MERFISH atlas (Zhang et al., 2021)... XENIUM assay of melanoma metastasis to the brain (Haviv et al., 2024b).
Dataset Splits | No | The paper describes using a "training set" and a "test set" for the ShapeNet & ModelNet datasets, and mentions sampling `n` points from each shape, but does not provide split percentages, sample counts, or a detailed methodology for dividing the overall datasets into training, validation, and test sets. For example: "We sample 64 shapes from the training set for each class, and randomly select n = 2048 points from each." and "To evaluate generation quality, we synthesize point-clouds to match the size of the test set".
Hardware Specification | Yes | For the ShapeNet experiments (see Table 3), the model is trained for 500,000 training steps, totaling about 3 days of training on a single A100 GPU. All other experiments (see Table 4) trained for 100,000 steps, requiring around 3-4 hours of GPU use.
Software Dependencies | No | WFM relies on JAX and OTT-JAX (Bradbury et al., 2021; Cuturi et al., 2022) and enjoys seamless optimization via end-to-end just-in-time compilation. The paper names these libraries but does not specify their version numbers.
Experiment Setup | Yes | By default, all models use a 6-layer neural network with ReLU non-linearities and 1024 neurons per layer, applying skip connections and layer normalization (Ba, 2016). Training is performed for 100,000 gradient descent steps using the Adam optimizer (Kingma, 2014) with an exponential learning rate decay of 0.97 every 1000 steps and a batch size of 128. ... By default, the entropic OT map is constructed with a regularization weight of ε = 0.002 and 200 Sinkhorn iterations... The transformer network is composed of 6 multi-head attention blocks, with an embedding dimension of 512 and 4 heads. Our model is optimized with Adam (Kingma, 2014) using an exponential learning rate decay and a batch size of 64.
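The Research Type entry refers to the Bures-Wasserstein (BW) space of Gaussians. For reference, the squared 2-Wasserstein distance between N(m1, Σ1) and N(m2, Σ2) is ||m1 - m2||² + tr(Σ1 + Σ2 - 2(Σ1^{1/2} Σ2 Σ1^{1/2})^{1/2}). The sketch below evaluates it in two dimensions with plain Python. It relies on two facts: Σ1^{1/2} Σ2 Σ1^{1/2} is similar to Σ1 Σ2, so they share eigenvalues; and for a 2x2 SPD matrix with eigenvalues λ1, λ2, the trace of its square root is sqrt(λ1) + sqrt(λ2) = sqrt(λ1 + λ2 + 2 sqrt(λ1 λ2)). This shortcut is specific to 2x2 matrices and is not taken from the paper's code; higher dimensions need a genuine matrix square root.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]

def trace(m: Matrix) -> float:
    return m[0][0] + m[1][1]

def det(m: Matrix) -> float:
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]

def bures_wasserstein_sq(mean1: Sequence[float], cov1: Matrix,
                         mean2: Sequence[float], cov2: Matrix) -> float:
    """Squared 2-Wasserstein distance between two Gaussians in R^2.

    The cross term tr((S1^{1/2} S2 S1^{1/2})^{1/2}) is computed as
    sqrt(tr(S1 S2) + 2 sqrt(det S1 * det S2)), valid for 2x2 SPD matrices only.
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mean1, mean2))
    cross = math.sqrt(trace(matmul(cov1, cov2)) + 2.0 * math.sqrt(det(cov1) * det(cov2)))
    return mean_term + trace(cov1) + trace(cov2) - 2.0 * cross

identity = [[1.0, 0.0], [0.0, 1.0]]
# Equal covariances: only the squared distance between the means remains.
print(bures_wasserstein_sq((0.0, 0.0), identity, (3.0, 4.0), identity))  # 25.0
```

With diagonal (commuting) covariances the formula reduces to summing squared differences of standard deviations, which is a quick way to sanity-check it.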
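The Pseudocode entry names Algorithm 1 (Wasserstein FM Training) without reproducing it. For orientation only, here is a generic sketch of the straight-line interpolation step commonly used in OT-based flow matching, applied to a point cloud viewed as a discrete distribution. It assumes a point-to-point transport map has already been computed (the `mapped` argument is hypothetical input), and it should not be read as the paper's Algorithm 1: how WFM samples t, builds the map, and parameterizes its velocity network is not described in this entry.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]

def interpolate_point_cloud(
    source: Sequence[Point], mapped: Sequence[Point], t: float
) -> Tuple[List[Point], List[Point]]:
    """Straight-line (McCann-style) interpolation between a point cloud and its image.

    `mapped[i]` is assumed to be the image of `source[i]` under some transport map T.
    Returns the cloud at time t, i.e. ((1 - t) Id + t T)(source), together with the
    per-point velocity T(x) - x, which is constant along each straight path.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    if len(source) != len(mapped):
        raise ValueError("source and mapped must have the same number of points")
    cloud_t = [tuple((1.0 - t) * a + t * b for a, b in zip(x, y)) for x, y in zip(source, mapped)]
    velocity = [tuple(b - a for a, b in zip(x, y)) for x, y in zip(source, mapped)]
    return cloud_t, velocity

# Toy example: two points shifted upward by 2 units.
cloud, velocity = interpolate_point_cloud([(0.0, 0.0), (1.0, 0.0)], [(0.0, 2.0), (1.0, 2.0)], t=0.25)
print(cloud)     # [(0.0, 0.5), (1.0, 0.5)]
print(velocity)  # [(0.0, 2.0), (0.0, 2.0)]
```

In a flow-matching setup of this kind, a network would be regressed onto `velocity` given `cloud` and t; that training loop is deliberately left out here.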
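The Hardware Specification entry gives only approximate wall-clock times, so dividing the stated step counts by those times yields rough average throughput. The sketch assumes the quoted durations cover training alone (no separate evaluation or compilation time), and the entry does not say which GPU the non-ShapeNet experiments used.

```python
SECONDS_PER_HOUR = 3600
SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR

def steps_per_second(steps: int, seconds: float) -> float:
    """Average optimizer steps per second over a wall-clock duration."""
    return steps / seconds

# Figures taken from the Hardware Specification entry (all approximate).
shapenet = steps_per_second(500_000, 3 * SECONDS_PER_DAY)
other_slow = steps_per_second(100_000, 4 * SECONDS_PER_HOUR)  # "3-4 hours", upper end
other_fast = steps_per_second(100_000, 3 * SECONDS_PER_HOUR)  # lower end

print(f"ShapeNet: ~{shapenet:.2f} steps/s")  # ShapeNet: ~1.93 steps/s
print(f"Other experiments: ~{other_slow:.1f}-{other_fast:.1f} steps/s")  # Other experiments: ~6.9-9.3 steps/s
```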
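The Software Dependencies entry notes that the JAX and OTT-JAX versions are not reported. When re-running the released code, one low-effort remedy is to log the installed versions next to the results. A minimal standard-library sketch, assuming the PyPI distribution names `jax`, `jaxlib`, and `ott-jax` (adjust if an environment installs them under other names):

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

def installed_versions(
    distributions: Iterable[str] = ("jax", "jaxlib", "ott-jax"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Emit pip-style pins that can be saved alongside experiment outputs.
for name, version in installed_versions().items():
    print(f"{name}=={version}" if version else f"# {name} is not installed")
```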
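The Experiment Setup entry specifies an exponential learning-rate decay of 0.97 every 1000 steps but not the initial rate, and "every 1000 steps" could mean either a smooth or a stepwise (staircase) schedule. The sketch below supports both; the initial value `INIT_LR` is a hypothetical placeholder.

```python
def exponential_decay(step: int, init_lr: float, decay_rate: float = 0.97,
                      transition_steps: int = 1000, staircase: bool = False) -> float:
    """Learning rate at `step` under exponential decay.

    Smooth form:    init_lr * decay_rate ** (step / transition_steps)
    Staircase form: init_lr * decay_rate ** floor(step / transition_steps)
    """
    if staircase:
        exponent = step // transition_steps
    else:
        exponent = step / transition_steps
    return init_lr * decay_rate ** exponent

INIT_LR = 1e-4  # hypothetical value; the entry does not state the initial learning rate
for step in (0, 1_000, 50_000, 100_000):
    print(step, exponential_decay(step, INIT_LR))
```

After 100,000 steps either variant leaves 0.97^100 ≈ 0.048 of the initial rate, i.e. roughly a 21x reduction. In a JAX codebase the same schedule would typically come from `optax.exponential_decay(init_value, transition_steps, decay_rate, staircase=...)`, but the entry does not say which library or variant the authors used.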
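The Experiment Setup entry builds an entropic OT map with ε = 0.002 and 200 Sinkhorn iterations. The authors use OTT-JAX; the pure-Python sketch below instead implements textbook log-domain Sinkhorn between two uniformly weighted point clouds with squared-Euclidean cost, then sends each source point to the barycentric projection of its row of the entropic plan. Assumptions: ε is applied in absolute terms to the raw squared distances (whether the authors' configuration rescales ε relative to the cost is not stated), and the helper name `entropic_ot_map` is hypothetical, not an OTT-JAX function.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]

def sq_dist(p: Point, q: Point) -> float:
    return sum((a - b) ** 2 for a, b in zip(p, q))

def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def entropic_ot_map(xs: Sequence[Point], ys: Sequence[Point],
                    epsilon: float = 0.002, n_iters: int = 200) -> List[Point]:
    """Hypothetical helper: log-domain Sinkhorn followed by barycentric projection.

    Both clouds carry uniform weights; the cost is the squared Euclidean distance.
    """
    n, m = len(xs), len(ys)
    cost = [[sq_dist(x, y) for y in ys] for x in xs]
    log_a, log_b = -math.log(n), -math.log(m)
    f, g = [0.0] * n, [0.0] * m  # dual potentials
    for _ in range(n_iters):
        # Alternate soft-min updates so the plan matches each marginal in turn.
        f = [-epsilon * logsumexp([log_b + (g[j] - cost[i][j]) / epsilon for j in range(m)])
             for i in range(n)]
        g = [-epsilon * logsumexp([log_a + (f[i] - cost[i][j]) / epsilon for i in range(n)])
             for j in range(m)]
    mapped = []
    for i in range(n):
        # Row i of the plan is proportional to exp((f_i + g_j - C_ij) / epsilon)
        # because the target weights b_j are uniform.
        logits = [(f[i] + g[j] - cost[i][j]) / epsilon for j in range(m)]
        lse = logsumexp(logits)
        weights = [math.exp(v - lse) for v in logits]
        mapped.append(tuple(sum(w * y[k] for w, y in zip(weights, ys)) for k in range(len(ys[0]))))
    return mapped

xs = [(0.0,), (0.5,), (1.0,)]
ys = [(1.1,), (0.1,), (0.6,)]  # the source shifted by +0.1, listed in a different order
print(entropic_ot_map(xs, ys))  # each x is sent (numerically) to x + 0.1
```

With ε this small relative to the gaps between candidate matches, the plan is nearly a permutation and the map recovers the shift; larger ε would blur each image toward the target's mean. The O(n·m) loops are only suitable for toy inputs.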