Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes
Authors: Jaehyeong Jo, Sung Ju Hwang
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We experimentally validate our approach on diverse manifolds of both real-world and synthetic datasets, on which our method outperforms or is on par with the state-of-the-art baselines. We demonstrate that our method can scale to high dimensions while allowing significantly faster training compared to previous diffusion models relying on score matching. Especially on general manifolds, our method shows superior performance with dramatically reduced in-training simulation steps, using only 5% of the steps of the CNF model. |
| Researcher Affiliation | Collaboration | Jaehyeong Jo (KAIST); Sung Ju Hwang (KAIST, DeepAuto.ai) |
| Pseudocode | Yes | Algorithm 1 Two-way bridge matching (see the sketch after the table) |
| Open Source Code | Yes | Code: github.com/harryjo97/riemannian-diffusion-mixture |
| Open Datasets | Yes | We first evaluate the generative models on real-world datasets living on the 2-dimensional sphere, which consists of earth and climate science events including volcanic eruptions (NOAA, 2020b), earthquakes (NOAA, 2020a), floods (Brakenridge, 2017), and wildfires (EOSDIS, 2020). |
| Dataset Splits | Yes | We split the datasets into training, validation, and test sets with (0.8, 0.1, 0.1) proportions. (See the split sketch after the table.) |
| Hardware Specification | Yes | For all experiments, we use NVIDIA GeForce RTX 3090 and 2080 Ti and implement the source code with PyTorch (Paszke et al., 2019) and JAX. |
| Software Dependencies | No | The paper mentions PyTorch (Paszke et al., 2019) and JAX, but does not provide specific version numbers for these software dependencies. |
| Experiment Setup | Yes | For all experiments except the high-dimensional tori, we use 512 hidden units and select the number of layers from 6 to 13, using either the sinusoidal or swish activation function. All models are trained with the Adam optimizer, and we either use no learning rate scheduler or anneal the learning rate with a linear warmup followed by a cosine scheduler, as introduced in Bortoli et al. (2022). We also use an exponential moving average of the model weights (Polyak & Juditsky, 1992) with decay 0.999. For all experiments, we train our models using the time-scaled two-way bridge matching in Eq. (12), where we use 15 steps for the in-training simulation carried out by a geodesic random walk (Jørgensen, 1975; Bortoli et al., 2022). (See the sketches after the table.) |
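
The Pseudocode row points to Algorithm 1, two-way bridge matching. The paper defines the objective on manifolds (Eq. (12), with time-scaling and logarithm maps); the snippet below is only a Euclidean caricature of the two-way idea, not the authors' algorithm. The functions `model_fwd` and `model_bwd`, the stabilizing constant `eps_t`, and the toy usage are illustrative assumptions; the core idea shown is regressing two drift networks onto the Brownian-bridge drifts toward each endpoint.

```python
import torch

def two_way_bridge_matching_loss(model_fwd, model_bwd, x0, x1, eps_t=1e-5):
    """Euclidean caricature of a two-way bridge-matching objective.

    x0 is a batch of data points, x1 a batch of prior samples. A Brownian
    bridge point x_t is sampled between the endpoints, and two drift networks
    are regressed onto the bridge drifts toward each endpoint. On a manifold,
    the vector differences below would be replaced by logarithm maps.
    """
    t = torch.rand(x0.shape[0], 1)
    noise = torch.randn_like(x0)
    x_t = (1 - t) * x0 + t * x1 + torch.sqrt(t * (1 - t)) * noise
    drift_to_x1 = (x1 - x_t) / (1 - t + eps_t)  # drift of the forward bridge
    drift_to_x0 = (x0 - x_t) / (t + eps_t)      # drift of the reverse bridge
    loss_fwd = ((model_fwd(x_t, t) - drift_to_x1) ** 2).mean()
    loss_bwd = ((model_bwd(x_t, t) - drift_to_x0) ** 2).mean()
    return loss_fwd + loss_bwd

# Toy usage with a trivial drift network (hypothetical):
net = lambda x, t: torch.zeros_like(x)
x0, x1 = torch.randn(32, 3), torch.randn(32, 3)
print(two_way_bridge_matching_loss(net, net, x0, x1))
```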
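
The Dataset Splits row reports (0.8, 0.1, 0.1) proportions. A minimal NumPy sketch of such a split, assuming the data is an indexable array and that it is shuffled before splitting (the paper does not describe its splitting code):

```python
import numpy as np

def split_dataset(data, rng=None):
    """Shuffle and split into train/val/test with (0.8, 0.1, 0.1) proportions."""
    rng = np.random.default_rng(0) if rng is None else rng
    idx = rng.permutation(len(data))
    n_train, n_val = int(0.8 * len(data)), int(0.1 * len(data))
    return (data[idx[:n_train]],
            data[idx[n_train:n_train + n_val]],
            data[idx[n_train + n_val:]])
```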
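
The Experiment Setup row states that the in-training simulation uses 15 steps of a geodesic random walk. Below is a minimal NumPy sketch of a geodesic random walk on the unit 2-sphere; the function names, the drift-free step, and the default arguments are assumptions for illustration, not the authors' implementation.

```python
import numpy as np

def exp_map_sphere(x, v):
    """Exponential map on the unit sphere: walk from x along tangent vector v."""
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:
        return x
    return np.cos(norm_v) * x + np.sin(norm_v) * v / norm_v

def project_to_tangent(x, z):
    """Project an ambient vector z onto the tangent plane at x."""
    return z - np.dot(x, z) * x

def geodesic_random_walk(x0, n_steps=15, t_end=1.0, rng=None):
    """Simulate Brownian motion on S^2 via a geodesic random walk.

    Each step draws Gaussian noise in the ambient space, projects it to the
    tangent plane, scales by sqrt(dt), and moves along the geodesic.
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = t_end / n_steps
    x = x0 / np.linalg.norm(x0)
    for _ in range(n_steps):
        z = project_to_tangent(x, rng.standard_normal(3))
        x = exp_map_sphere(x, np.sqrt(dt) * z)
    return x

# Example: one 15-step walk starting from the north pole.
print(geodesic_random_walk(np.array([0.0, 0.0, 1.0]), n_steps=15))
```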
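
Finally, the reported training configuration (Adam, swish activation, 512 hidden units, EMA with decay 0.999) can be sketched in PyTorch as follows; the network depth, learning rate, and the placeholder objective are assumptions, since the table only pins down width, activation, optimizer, and EMA decay.

```python
import torch

class EMA:
    """Exponential moving average of model parameters (decay 0.999, as reported)."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)

# 512 hidden units and swish (SiLU) follow the reported setup; the depth,
# learning rate, and loss below are placeholders.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 512), torch.nn.SiLU(), torch.nn.Linear(512, 3)
)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)  # lr is an assumption
ema = EMA(model, decay=0.999)

for step in range(100):
    loss = model(torch.randn(32, 3)).pow(2).mean()  # placeholder objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)
```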