Unbalanced Low-rank Optimal Transport Solvers

Authors: Meyer Scetbon, Michal Klein, Giovanni Palla, Marco Cuturi

NeurIPS 2023

Reproducibility Variable Result LLM Response
Research Type Experimental We propose custom algorithms to implement these extensions for the linear OT problem and its fused-Gromov-Wasserstein generalization, and demonstrate their practical relevance to challenging spatial transcriptomics matching problems. These algorithms are implemented in the ott-jax toolbox [Cuturi et al., 2022]. We focus first in Exp. 1 on demonstrating the empirical benefits of the TI variant of our algorithm to solve linear ULOT, as implemented in Alg. 3 vs. Alg. 5; that algorithm is subsequently used as an inner routine to solve all quadratic ULR problems. We compare in Exp. 2 unbalanced low-rank (ULR) solvers to balanced low-rank (LR) counterparts on a spatial transcriptomics task, and follow in Exp. 3 by comparing ULR solvers to entropic (E) counterparts on a smaller task, to accommodate entropic solvers' quadratic complexity. We conclude in Exp. 4 by comparing ULR solvers to [Thual et al., 2022], which can learn a sparse transport coupling, in the unbalanced FGW setting. Datasets. We consider two real-world datasets, described in B.1, and two synthetic datasets that are large enough to showcase our solvers. Metrics. Following Klein et al. [2023], we evaluate maps by focusing on the two following metrics: (i) Pearson correlation computed between the (ground-truth) source s feature matrix F_s ∈ R^{n×d} and the barycentric projection of the target t onto the source, scaled by the target marginals b_t. Writing P as the transport matrix from source to target, this can be computed as P diag(1/b_t) F_t; (ii) F1 score when assessing class transfer (among 11 possible classes), computed between the original source vector of labels l_s, taken in {1, …, 11}^n, and the inferred labels for the same points, predicted for each i by taking argmax_j B_{i,j}, where B is a matrix of n × 11 row probabilities, each row being the barycentric projection of the target t's one-hot encoded labels L_t ∈ {0, 1}^{m×11}, B := P diag(1/b_t) L_t.
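As a concrete illustration of metric (ii), the barycentric label transfer B := P diag(1/b_t) L_t and the argmax readout can be sketched in a few lines of NumPy. All names below are illustrative, not the authors' implementation:

```python
import numpy as np

def barycentric_projection(P, b_t, F_t):
    """Project target features back onto the source through coupling P.

    Computes P @ diag(1 / b_t) @ F_t: each source row receives a weighted
    combination of target features, with the target marginals b_t rescaling
    the columns of the coupling.
    """
    return P @ (F_t / b_t[:, None])

def transfer_labels(P, b_t, L_t):
    """Class transfer: B = P diag(1/b_t) L_t, then argmax over classes.

    L_t is the one-hot encoded target label matrix; returned labels are
    taken in {1, ..., K} as in the paper (11 classes there).
    """
    B = barycentric_projection(P, b_t, L_t)  # (n, K) row scores
    return B.argmax(axis=1) + 1              # shift to 1-based class labels
```

For example, with a 2x2 coupling `P = [[0.5, 0], [0.25, 0.25]]` and one-hot target labels, `transfer_labels` assigns each source point the class receiving the most (marginal-rescaled) mass.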
Researcher Affiliation Collaboration Meyer Scetbon Microsoft Research t-mscetbon@microsoft.com Michal Klein Apple michalk@apple.com Giovanni Palla Helmholtz Center Munich giovanni.palla@helmholtz-muenchen.de Marco Cuturi Apple cuturi@apple.com
Pseudocode Yes Algorithm 1 ULOT(C, a, b, r, γ0, τ1, τ2, δ)
Inputs: C, a, b, r, γ0, τ1, τ2, δ
(Q, R, g) ← initialization as proposed in [Scetbon and Cuturi, 2022]
repeat
  (Q̃, R̃, g̃) ← (Q, R, g)
  ∇Q ← C R diag(1/g), ∇R ← Cᵀ Q diag(1/g), ω ← D(Q̃ᵀ C R̃), ∇g ← ω/g²
  γ ← γ0 / max(‖∇Q‖²∞, ‖∇R‖²∞, ‖∇g‖²∞)
  (1) Q ← Q̃ ⊙ exp(−γ ∇Q), (2) R ← R̃ ⊙ exp(−γ ∇R), (3) g ← g̃ ⊙ exp(γ ∇g)
  (Q, R, g) ← ULR-Dykstra(Q, R, g, a, b, γ, τ1, τ2, δ) (Alg. 5)
until Δ((Q, R, g), (Q̃, R̃, g̃), γ) < δ;
Result: Q, R, g
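One mirror-descent step of this outer loop can be sketched in NumPy as follows. This is only a sketch of the exponentiated-gradient update: the ULR-Dykstra projection onto the relaxed marginal constraints (Alg. 5) is deliberately omitted, and the function name and sign conventions are illustrative, following the low-rank scheme of Scetbon and Cuturi [2022] rather than the authors' ott-jax code:

```python
import numpy as np

def ulot_mirror_step(C, Q, R, g, gamma0):
    """One mirror-descent step of the ULOT outer loop (sketch only).

    C: (n, m) cost matrix; Q: (n, r); R: (m, r); g: (r,) low-rank factors.
    Returns the exponentiated-gradient updates that Alg. 1 would then pass
    to the ULR-Dykstra projection (not reproduced here).
    """
    grad_Q = C @ R @ np.diag(1.0 / g)       # gradient w.r.t. Q: C R diag(1/g)
    grad_R = C.T @ Q @ np.diag(1.0 / g)     # gradient w.r.t. R: C^T Q diag(1/g)
    omega = np.diag(Q.T @ C @ R)            # omega = D(Q^T C R)
    grad_g = omega / g**2                   # gradient term for g: omega / g^2
    # Step size: gamma0 normalized by the largest squared sup-norm of the gradients.
    gamma = gamma0 / max(
        np.abs(grad_Q).max() ** 2,
        np.abs(grad_R).max() ** 2,
        np.abs(grad_g).max() ** 2,
    )
    # Multiplicative (mirror-descent) updates.
    Q_new = Q * np.exp(-gamma * grad_Q)
    R_new = R * np.exp(-gamma * grad_R)
    g_new = g * np.exp(gamma * grad_g)
    return Q_new, R_new, g_new, gamma
```

The multiplicative form keeps all factors strictly positive, which is what makes the subsequent Dykstra-style scaling well defined.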
Open Source Code Yes These algorithms are implemented in the ott-jax toolbox [Cuturi et al., 2022].
Open Datasets Yes We use the mouse brain STARmap spatial transcriptomics data from [Shi et al., 2022] for Exp. 2 and Exp. 3. We use data from the Individual Brain Charting dataset [Pinho et al., 2018], to replicate the settings of [Thual et al., 2022], in Exp. 4.
Dataset Splits Yes We selected 10 marker genes for the validation and test sets from the HPF_CA cluster. We run an extensive grid search as reported in B.2; we pick the best hyperparameter combination using performance on the 10 validation genes as a criterion, and we report that metric on the other genes in Table 1, as well as qualitative results in Figure 1 and Figure 2. Similar to Exp. 2, the fused term C is a squared-Euclidean distance matrix on 30-D PCA space, computed on gene expressions. As done in Exp. 2, we select 10 marker genes for the validation set and 10 genes for the test set, from cluster OB_1. We run an extensive grid search, as in Exp. 2 and B.2.
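The fused cost described here, squared-Euclidean distances on a 30-D PCA embedding of gene expressions, can be sketched as follows. The PCA-via-SVD and cost helpers are illustrative stand-ins, not the authors' preprocessing code:

```python
import numpy as np

def pca_project(X, k=30):
    """Project rows of X (cells x genes) onto the top-k principal components."""
    Xc = X - X.mean(axis=0)                      # center features
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:k].T                         # (n_cells, k) embedding

def squared_euclidean_cost(X, Y):
    """Pairwise squared-Euclidean cost matrix, C[i, j] = ||x_i - y_j||^2."""
    sq_x = (X**2).sum(axis=1)[:, None]
    sq_y = (Y**2).sum(axis=1)[None, :]
    return sq_x + sq_y - 2.0 * X @ Y.T
```

The expansion `||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y` avoids materializing all pairwise differences, which matters at the dataset sizes these solvers target.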
Hardware Specification Yes In Figure 3, we compare the execution time (using our ott-jax implementation, and a single NVIDIA GeForce RTX 2080 Ti card) of unbalanced LR Sinkhorn on large and high-dimensional Gaussian distributions.
Software Dependencies Yes These algorithms are implemented in the ott-jax toolbox [Cuturi et al., 2022].
Experiment Setup Yes We use a δ = 10⁻⁹ convergence threshold and a maximum of 1000 iterations for Dykstra, in 64-bit precision. We run an extensive grid search as reported in B.2; we pick the best hyperparameter combination using performance on the 10 validation genes as a criterion, and we report that metric on the other genes in Table 1, as well as qualitative results in Figure 1 and Figure 2. We run an extensive grid search, as in Exp. 2 and B.2.
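The reported stopping rule (δ = 10⁻⁹, at most 1000 iterations, 64-bit precision) amounts to a standard fixed-point loop. A generic sketch, applied to an arbitrary iteration step, is below; `run_dykstra` is a hypothetical helper name, not the ott-jax API:

```python
import numpy as np

def run_dykstra(step, x0, delta=1e-9, max_iter=1000):
    """Generic fixed-point loop matching the reported Dykstra settings.

    `step` maps the current iterate to the next one; iteration stops once
    the sup-norm change between consecutive iterates falls below `delta`,
    or after `max_iter` iterations. Everything runs in float64, since a
    1e-9 threshold is below float32 resolution.
    """
    x = np.asarray(x0, dtype=np.float64)
    for _ in range(max_iter):
        x_next = step(x)
        if np.abs(x_next - x).max() < delta:
            return x_next
        x = x_next
    return x
```

As a toy usage example, `run_dykstra(lambda v: 0.5 * (v + 2.0 / v), np.array([1.0]))` drives the Babylonian square-root step to a fixed point at the stated tolerance.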