Bespoke Non-Stationary Solvers for Fast Sampling of Diffusion and Flow Models
Authors: Neta Shaul, Uriel Singer, Ricky T. Q. Chen, Matthew Le, Ali Thabet, Albert Pumarola, Yaron Lipman
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We evaluate BNS solvers on: (i) class conditional/unconditional image generation, (ii) text-to-image generation, and (iii) text-to-audio generation. Additionally, we compare our method with model distillation. Unless stated otherwise, conditional sampling is done using classifier-free guidance (CFG) (Ho & Salimans, 2022; Zheng et al., 2023b). All BNS solvers are trained on 520 pairs (x0, x(1)) of noise and generated image, produced with the adaptive RK45 solver (Shampine, 1986). During optimization (Algorithm 2) we log PSNR on a validation set of 1024 such pairs and report results at the best validation iteration. Further details are in Appendix D.1. (A sketch of this pair-generation step appears after the table.) |
| Researcher Affiliation | Collaboration | ¹Weizmann Institute of Science, ²GenAI, Meta, ³FAIR, Meta. |
| Pseudocode | Yes | Algorithm 1: Non-Stationary sampling. Algorithm 2: Bespoke Non-Stationary solver training. (Both are sketched in code after the table.) |
| Open Source Code | No | The paper does not contain an explicit statement offering the source code for the methodology described, nor does it provide a direct link to a code repository. |
| Open Datasets | Yes | We evaluate our method on the class conditional ImageNet-64/128 (Deng et al., 2009) dataset. As recommended by the authors, to support fairness we used the official face-blurred data; see more details in Appendix D.2. We report PSNR w.r.t. ground truth (GT) images generated with the adaptive RK45 solver (Shampine, 1986), and Fréchet Inception Distance (FID) (Heusel et al., 2017); both metrics are computed on 50k samples from the models. We train our BNS solvers for NFE ∈ {4, 6, …, 20} with an RK-Midpoint initial solver and preconditioning σ0 = 1, each taking a 0.2-1% fraction of the GPU days used to train the diffusion/flow models (i.e., 2-10 GPU days with an NVIDIA V100). We compare our results against various baselines, including generic solvers, exponential solvers: DDIM (Song et al., 2022), DPM++ (Zhang & Chen, 2023), and UniPC (Zhao et al., 2023), as well as the BST (Shaul et al., 2023) distilled solvers. Figure 4 shows our BNS solvers improve both PSNR and FID over all baselines. Specifically, in the PSNR metric we achieve a large improvement of at least 5-10 dB over the runner-up baseline, and we get within 5% of the FID of the GT solver (about 160-320 NFE) with only 16 NFE. Qualitative examples are shown in Figures 9 and 10 in Appendix D.2. Interestingly, for PSNR we see the ordering BNS > BST > DPM > RK-Midpoint/Euler, which matches the solver hierarchy proved in Theorem 3.2; see also Figure 3. In Figure 11 we also show an ablation experiment comparing the Non-Stationary and Scale-Time families, both optimized with Algorithm 2, demonstrating the benefit of the NS family of solvers over the ST family. Lastly, Figure 4 additionally shows results of a BNS solver transferred between different models (BNS-transfer); see details in Appendix D.2. Unconditional image generation. We also evaluate our method on a pretrained CIFAR10 model from (Karras et al., 2022) in Table 6, in particular improving upon DPM-Solver-v3 (Zheng et al., 2023a) and UniPC (Zhao et al., 2023) in the low-NFE regime. Text-to-Image generation. Due to considerations regarding the training data of Stable Diffusion, we opted not to experiment with that model. Hence, we use a large latent FM-OT T2I model (2.2B parameters) trained on a proprietary dataset of 330M image-text pairs. The image size is 512×512×3, while the latent space is of dimension 64×64×4; see implementation details in Appendix E. For evaluation we report PSNR w.r.t. GT images, as in the class conditional task. Additionally, we use the MS-COCO (Lin et al., 2015) validation set and report perceptual metrics including PickScore (Kirstain et al., 2023), CLIPScore (Ramesh et al., 2022), and zero-shot FID. All four metrics are computed on 30K generated and validation images and reported for guidance (CFG) scales w = 2 and w = 6.5. |
| Dataset Splits | Yes | For all tasks, the training and validation sets were generated using the adaptive RK45 solver; optimization is done with the Adam optimizer (Kingma & Ba, 2017) and results are reported at the best validation iteration. Class conditional image generation. For this task we generated 520 pairs (x0, x(1)) of noise and image for the training set, and 1024 such pairs for the validation set. Text-to-Image. For this task, we generate two training and validation sets of 520 and 1024 pairs (resp.), one for guidance scale w = 2.0 and one for w = 6.5. The text prompts for generation were taken from the training set of MS-COCO (Lin et al., 2015). For each guidance scale we train BNS solvers with NFE ∈ {12, 16, 20}, a learning rate of 1e-4, a cosine annealing learning rate scheduler, and a batch size of 8, for 20k iterations. We compute PSNR on the validation set every 200 iterations. Text-to-Audio. We generate a training set of 10k pairs and a validation set of 1024 pairs. We train BNS solvers with NFE ∈ {8, 12, 16, 20}, and optimize with a learning rate of 1e-4, a cosine annealing learning rate scheduler, and a batch size of 40, for 15k iterations. We compute SNR on the validation set every 5k iterations. |
| Hardware Specification | Yes | We train our BNS solvers for NFE ∈ {4, 6, …, 20} with an RK-Midpoint initial solver and preconditioning σ0 = 1, each taking a 0.2-1% fraction of the GPU days used to train the diffusion/flow models (i.e., 2-10 GPU days with an NVIDIA V100). |
| Software Dependencies | No | The paper mentions using "Adam optimizer" but does not specify software dependencies with version numbers (e.g., Python, PyTorch, TensorFlow versions or specific library versions). |
| Experiment Setup | Yes | Class conditional image generation. For this task we generated 520 pairs (x0, x(1)) of noise and image for the training set, and 1024 such pairs for the validation set. For each model on this task, ImageNet-64 ε-VP/FM-CS/FM-OT and ImageNet-128 FM-OT, we train BNS solvers with NFE ∈ {4, 6, 8, 10, 12, 14, 16, 18, 20}. We use a learning rate of 5e-4, a polynomial decay learning rate scheduler, and a batch size of 40, for 15k iterations. We compute PSNR on the validation set every 100 iterations. Text-to-Image. For this task, we generate two training and validation sets of 520 and 1024 pairs (resp.), one for guidance scale w = 2.0 and one for w = 6.5. The text prompts for generation were taken from the training set of MS-COCO (Lin et al., 2015). For each guidance scale we train BNS solvers with NFE ∈ {12, 16, 20}, a learning rate of 1e-4, a cosine annealing learning rate scheduler, and a batch size of 8, for 20k iterations. We compute PSNR on the validation set every 200 iterations. Text-to-Audio. We generate a training set of 10k pairs and a validation set of 1024 pairs. We train BNS solvers with NFE ∈ {8, 12, 16, 20}, and optimize with a learning rate of 1e-4, a cosine annealing learning rate scheduler, and a batch size of 40, for 15k iterations. We compute SNR on the validation set every 5k iterations. (The optimization loop and these per-task hyperparameters are sketched in code after the table.) |
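
The rows above describe generating (x0, x(1)) noise/image pairs with an adaptive RK45 solver. As a minimal sketch of that data-generation step, the snippet below uses `torchdiffeq`'s adaptive `dopri5` integrator (an RK45-type method) as a stand-in for the paper's solver; the function names and tolerances are illustrative assumptions, not the authors' code.

```python
import torch
from torchdiffeq import odeint  # adaptive dopri5 is an RK45-type solver


@torch.no_grad()
def make_pairs(u, n_pairs, shape, device="cuda"):
    """Generate (x0, x(1)) noise/image pairs by integrating the pretrained
    velocity field u(x, t) from t=0 (noise) to t=1 (data).

    A sketch only: the paper uses adaptive RK45 (Shampine, 1986); dopri5
    here is a comparable adaptive method, and the tolerances are assumed.
    """
    x0 = torch.randn(n_pairs, *shape, device=device)
    ts = torch.tensor([0.0, 1.0], device=device)
    # odeint returns the solution at every time in ts; take the endpoint.
    x1 = odeint(lambda t, x: u(x, t), x0, ts,
                method="dopri5", atol=1e-5, rtol=1e-5)[-1]
    return x0, x1
```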
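
The paper's Non-Stationary (NS) solvers take steps that are learned linear combinations of all previous states and velocity evaluations. Below is a minimal sketch of what a generic NS sampling step (Algorithm 1) might look like, assuming a velocity field `u(x, t)`, a learned time grid `t`, and learned coefficient matrices `A` and `B`; these names are hypothetical and do not follow the paper's notation.

```python
import torch


def ns_sample(u, x0, t, A, B):
    """Generic non-stationary linear sampling (a sketch of Algorithm 1).

    u  : callable velocity field u(x, t) -> tensor (the frozen model)
    x0 : initial noise batch
    t  : (n+1,) learned time discretization with t[0]=0, t[-1]=1
    A  : (n, n) learned coefficients mixing previous states
    B  : (n, n) learned coefficients mixing previous velocity evaluations
    """
    xs, us = [x0], []
    for i in range(len(t) - 1):
        us.append(u(xs[-1], t[i]))  # one model evaluation (NFE) per step
        # Each new state is a learned linear combination of ALL previous
        # states and velocities, unlike a stationary multistep scheme.
        x_next = sum(A[i, j] * xs[j] for j in range(i + 1)) \
               + sum(B[i, j] * us[j] for j in range(i + 1))
        xs.append(x_next)
    return xs[-1]
```

Because the coefficients vary per step index rather than being fixed, this family strictly contains Euler, RK-Midpoint, and the Scale-Time (ST) solvers, consistent with the hierarchy the report quotes from Theorem 3.2.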
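
Putting the quoted hyperparameters together, here is a hedged sketch of the training procedure (Algorithm 2) for the class conditional setup: Adam with learning rate 5e-4, a polynomial decay scheduler, batch size 40, 15k iterations, and PSNR logged on the validation set every 100 iterations, keeping the best-validation parameters. It reuses `ns_sample` from the previous sketch; `sample_batch` and the PSNR normalization (data assumed in [-1, 1]) are assumptions, not the paper's code.

```python
import torch


def train_bns(u, train_pairs, val_pairs, theta, iters=15_000, lr=5e-4):
    """Sketch of Algorithm 2 (Bespoke Non-Stationary solver training).

    theta = (t, A, B) are the solver parameters optimized here; the model
    u stays frozen. `sample_batch` is a hypothetical helper returning an
    (x0, x1) minibatch from the RK45-generated training pairs.
    """
    opt = torch.optim.Adam(theta, lr=lr)
    sched = torch.optim.lr_scheduler.PolynomialLR(opt, total_iters=iters)
    best_psnr, best_theta = -float("inf"), None
    for it in range(iters):
        x0, x1 = sample_batch(train_pairs, batch_size=40)
        # Regress the n-step solver output onto the RK45 ground truth.
        loss = (ns_sample(u, x0, *theta) - x1).pow(2).mean()
        opt.zero_grad(); loss.backward(); opt.step(); sched.step()
        if it % 100 == 0:  # log PSNR on the 1024-pair validation set
            with torch.no_grad():
                x0v, x1v = val_pairs
                mse = (ns_sample(u, x0v, *theta) - x1v).pow(2).mean()
                psnr = 10 * torch.log10(4.0 / mse)  # assumes range [-1, 1]
            if psnr > best_psnr:
                best_psnr = psnr.item()
                best_theta = [p.detach().clone() for p in theta]
    return best_theta, best_psnr
```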
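
For quick reference, the per-task hyperparameters quoted in the Dataset Splits and Experiment Setup rows can be collected into one informal configuration map (values transcribed from the table; the dictionary itself is just a convenience, not an artifact of the paper).

```python
# Hyperparameters as quoted in the table above, gathered in one place.
BNS_CONFIGS = {
    "class_conditional": dict(nfe=list(range(4, 21, 2)), lr=5e-4,
                              scheduler="polynomial_decay", batch_size=40,
                              iters=15_000, val_every=100),
    "text_to_image":     dict(nfe=[12, 16, 20], lr=1e-4,
                              scheduler="cosine_annealing", batch_size=8,
                              iters=20_000, val_every=200),
    "text_to_audio":     dict(nfe=[8, 12, 16, 20], lr=1e-4,
                              scheduler="cosine_annealing", batch_size=40,
                              iters=15_000, val_every=5_000),
}
```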