Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].
Fast unsupervised ground metric learning with tree-Wasserstein distance
Authors: Kira Michaela Düsterwald, Samo Hromadka, Makoto Yamada
ICLR 2025 | Venue PDF | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate theoretically and empirically that the algorithm converges to a better approximation of the standard WSV approach than the best known alternatives, and does so with $O(n^3 + m^3 + mn)$ complexity. In addition, we prove that the initial tree structure can be chosen flexibly, since tree geometry does not constrain the richness of the approximation up to the number of edge weights. This proof suggests a fast and recursive algorithm for computing the tree parameter basis set, which we find crucial to realising the efficiency gains at scale. Finally, we employ the tree WSV algorithm to several single-cell RNA sequencing genomics datasets, demonstrating its scalability and utility for unsupervised cell-type clustering problems. These results poise unsupervised ground metric learning with TWD as a low-rank approximation of WSV with the potential for widespread application. |
| Researcher Affiliation | Academia | 1 Gatsby Computational Neuroscience Unit, University College London, United Kingdom 2 Machine Learning and Data Science Unit, Okinawa Institute of Science and Technology, Japan |
| Pseudocode | Yes | Algorithm 1: Unsupervised ground metric learning with trees (Tree-WSV); Algorithm 2: Tree basis set recursion |
| Open Source Code | Yes | We provide source code at: https://github.com/kiradust/tree-wsv/. |
| Open Datasets | Yes | We first compare performance on the PBMC 3k dataset from 10x Genomics (Wolf et al., 2018), which consists of 2043 peripheral blood mononuclear cells with 6 broad types and 1030 genes. We also ran Tree-WSV on Neurons V1, a dataset of scRNA-seq neuronal visual area 1 (V1) cells in mice by Tasic et al. (2018), consisting of 1468 cells with 7 broad types and 1000 genes, and a subset of the Human Lung Cell Atlas (HLCA) scRNA-seq consortium dataset of human respiratory tissue (Sikkema et al., 2023). |
| Dataset Splits | No | The paper describes preprocessing steps for genomic datasets (e.g., CPM normalisation, log1p transformation, gene selection, confidence filtering, removal of low-gene-count cells) and mentions using the average silhouette width (ASW) metric for clustering tasks, but it does not specify explicit training/validation/test splits with percentages, sample counts, or predefined split references for reproducibility. |
| Hardware Specification | Yes | Experiments were run on an NVIDIA A100 Tensor Core GPU with JAX (Bradbury et al., 2018). Computations on CPU were done using an Apple M2 MacBook Air with 16 GB RAM. GPU computations were performed on an NVIDIA A100 node with 16 GB requested memory. |
| Software Dependencies | No | The paper mentions software like JAX and Scanpy, and notes that the code is in Python+NumPy, but it does not provide specific version numbers for these dependencies. |
| Experiment Setup | Yes | Genomics experiments using Tree-WSV were run for 20 iterations of the internal linear system of equations singular vector loop (finding w) and 15 meta-iterations of the entire algorithm (starting with constructing new trees based on the 2 previous weight matrices, through to computing 2 new TWD-based distance matrices). A JAX random seed of 0 was used. The ASW and total runtime after the 4th meta-iteration and the best score overall (usually the 12th to 15th meta-iteration) were reported. Genomics experiments using SSV were run for 15 iterations of the singular vectors loop with τ = 0.001, ϵ = 0.1; these follow the settings and replicate the results in Huizing et al. (2022). |
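
The experiment setup above refers to computing TWD-based distance matrices. As a rough illustration only, the sketch below evaluates the standard closed form of the tree-Wasserstein distance between two histograms on a weighted tree: the sum over edges of the edge weight times the absolute net mass in the subtree below that edge. It is not taken from the authors' repository. The function name `tree_wasserstein`, the parent-array encoding, and the requirement that each node's parent has a smaller index are all assumptions made for this example.

```python
def tree_wasserstein(parent, weight, mu, nu):
    """Tree-Wasserstein distance between histograms mu and nu on a rooted tree.

    Assumed (hypothetical) encoding: node 0 is the root with parent[0] == -1,
    and parent[i] < i for every other node i. weight[i] is the weight of the
    edge joining node i to parent[i]; weight[0] is unused.
    """
    # Net mass difference held at each node.
    diff = [m - v for m, v in zip(mu, nu)]
    total = 0.0
    # Visit children before parents so each subtree sum is complete when used.
    for i in range(len(parent) - 1, 0, -1):
        total += weight[i] * abs(diff[i])  # cost of moving mass across edge i
        diff[parent[i]] += diff[i]         # push the subtree's net mass upward
    return total


# Toy tree: root 0 with two leaves, edge weights 1.0 and 2.0.
parent = [-1, 0, 0]
weight = [0.0, 1.0, 2.0]
print(tree_wasserstein(parent, weight, [0, 1, 0], [0, 0, 1]))
```

In the toy example all mass moves from leaf 1 to leaf 2, so the cost is the length of the path between them. A distance matrix would be obtained by calling the function on every pair of histograms; a vectorised version is likely preferable at the dataset sizes listed above.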