Multistep Distillation of Diffusion Models via Moment Matching

Authors: Tim Salimans, Thomas Mensink, Jonathan Heek, Emiel Hoogeboom

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | By using up to 8 sampling steps, we obtain distilled models that outperform not only their one-step versions but also their original many-step teacher models, obtaining new state-of-the-art results on the Imagenet dataset. We evaluate our proposed methods in the class-conditional generation setting on the ImageNet dataset.
Researcher Affiliation | Industry | Google DeepMind, Amsterdam
Pseudocode | Yes | Algorithm 1 Ancestral sampling algorithm used for both standard denoising diffusion models as well as our distilled models. ... Algorithm 2 Moment matching algorithm with alternating optimization of generator gη and auxiliary denoising model gφ. ... Algorithm 3 Parameter-space moment matching algorithm with instant denoising model gφ. (Hedged code sketches of these algorithms follow the table.)
Open Source Code | No | We are currently unable to share code but hope to be able to do so in the future.
Open Datasets | Yes | We evaluate our proposed methods in the class-conditional generation setting on the ImageNet dataset (Deng et al., 2009) ... In Table 3 we report zero-shot FID (Heusel et al., 2017) and CLIP Score (Radford et al., 2021) on MS-COCO (Lin et al., 2014).
Dataset Splits | Yes | We distill our models for a maximum of 200,000 steps at batch size 2048, calculating FID every 5,000 steps. We report the optimal FID seen during the distillation process, keeping evaluation data and random seeds fixed across evaluations to minimize bias.
Hardware Specification | Yes | All experiments were run on TPUv5e, using 256 chips per experiment.
Software Dependencies | No | We use the Adam optimizer (Kingma & Ba, 2014) with β1 = 0, β2 = 0.99, ϵ = 1e-12.
Experiment Setup | Yes | We distill our models for a maximum of 200,000 steps at batch size 2048... We use the Adam optimizer (Kingma & Ba, 2014) with β1 = 0, β2 = 0.99, ϵ = 1e-12. We use learning rate warmup for the first 1,000 steps and then linearly anneal the learning rate to zero over the remainder of the optimization steps. We use gradient clipping with a maximum norm of 1. (A hedged optimizer configuration sketch follows the table.)
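
The Pseudocode row lists Algorithm 1, the ancestral sampler shared by the teacher and the distilled models. Below is a minimal JAX sketch of such a sampler for an x-prediction model, assuming a variance-preserving cosine noise schedule; the denoiser, schedule, and step count are illustrative assumptions rather than the paper's exact specification.

```python
# A minimal sketch of ancestral sampling with an x-prediction denoiser,
# assuming a variance-preserving cosine noise schedule. `denoise_fn`, the
# schedule, and the step count are illustrative assumptions, not the
# paper's exact Algorithm 1.
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Return (alpha_t, sigma_t) for an assumed VP cosine noise schedule."""
    return jnp.cos(0.5 * jnp.pi * t), jnp.sin(0.5 * jnp.pi * t)


def ancestral_sample(rng, denoise_fn, shape, num_steps=8):
    """Iterate the standard ancestral update z_t -> z_s from t=1 down to t=0.

    denoise_fn(z_t, t) is assumed to predict the clean sample x directly.
    """
    rng, init_rng = jax.random.split(rng)
    z = jax.random.normal(init_rng, shape)       # z_1 ~ N(0, I)
    ts = jnp.linspace(1.0, 0.0, num_steps + 1)
    for i in range(num_steps):
        t, s = ts[i], ts[i + 1]
        alpha_t, sigma_t = cosine_schedule(t)
        alpha_s, sigma_s = cosine_schedule(s)
        x_hat = denoise_fn(z, t)                 # model's x-prediction
        # Posterior q(z_s | z_t, x_hat) for a variance-preserving diffusion.
        alpha_ts = alpha_t / alpha_s
        sigma2_ts = sigma_t**2 - alpha_ts**2 * sigma_s**2
        mean = (alpha_s * sigma2_ts / sigma_t**2) * x_hat + (
            alpha_ts * sigma_s**2 / sigma_t**2) * z
        var = jnp.maximum(sigma2_ts * sigma_s**2 / sigma_t**2, 0.0)
        rng, step_rng = jax.random.split(rng)
        z = mean + jnp.sqrt(var) * jax.random.normal(step_rng, shape)
    return z


# Example usage with a trivial placeholder denoiser.
sample = ancestral_sample(
    jax.random.PRNGKey(0), lambda z, t: jnp.zeros_like(z), (4, 32, 32, 3))
```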
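
Algorithms 2 and 3 describe moment matching with an auxiliary denoising model gφ trained alongside the generator gη. The toy sketch below only illustrates the alternating structure: fit gφ to samples from the current generator with a standard denoising loss, then nudge the generator using the difference between teacher and auxiliary x-predictions. The networks are placeholder linear maps, and the loss weighting and sign conventions are assumptions rather than the authors' exact formulation; the parameter-space variant (Algorithm 3) is not shown.

```python
# Toy JAX sketch of an alternating moment-matching-style distillation loop.
# All networks are placeholder linear maps; weighting and signs are assumed.
import jax
import jax.numpy as jnp

DIM = 16        # toy data dimensionality (assumed)
BATCH = 64      # toy batch size (assumed)


def teacher_denoise(z_t, t):
    """Stand-in for the frozen teacher's x-prediction."""
    return z_t / (1.0 + t)


def aux_denoise(phi, z_t, t):
    """Toy auxiliary denoiser g_phi: a linear map of the noisy input."""
    return z_t @ phi


def generator(eta, noise):
    """Toy one-step generator g_eta mapping noise to a sample."""
    return noise @ eta


def noisy(rng, x, t):
    """Diffuse x to time t under an assumed VP cosine schedule."""
    alpha, sigma = jnp.cos(0.5 * jnp.pi * t), jnp.sin(0.5 * jnp.pi * t)
    return alpha * x + sigma * jax.random.normal(rng, x.shape)


def aux_loss(phi, eta, rng, t):
    """Standard denoising loss for g_phi on generator samples."""
    rng_x, rng_z = jax.random.split(rng)
    x = generator(eta, jax.random.normal(rng_x, (BATCH, DIM)))
    z_t = noisy(rng_z, x, t)
    return jnp.mean((aux_denoise(phi, z_t, t) - x) ** 2)


def gen_loss(eta, phi, rng, t):
    """Move generated x along the (teacher - auxiliary) prediction gap."""
    rng_x, rng_z = jax.random.split(rng)
    x = generator(eta, jax.random.normal(rng_x, (BATCH, DIM)))
    z_t = noisy(rng_z, x, t)
    gap = teacher_denoise(z_t, t) - aux_denoise(phi, z_t, t)
    return -jnp.mean(x * jax.lax.stop_gradient(gap))


# Alternating optimization: a few auxiliary updates per generator update.
rng, eta, phi = jax.random.PRNGKey(0), jnp.eye(DIM), jnp.eye(DIM)
for step in range(100):
    t = jax.random.uniform(jax.random.fold_in(rng, step), minval=0.01)
    for _ in range(2):
        rng, sub = jax.random.split(rng)
        phi -= 1e-2 * jax.grad(aux_loss)(phi, eta, sub, t)
    rng, sub = jax.random.split(rng)
    eta -= 1e-2 * jax.grad(gen_loss)(eta, phi, sub, t)
```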
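
The Experiment Setup row reports the optimizer hyperparameters. One way to express them, using optax as an assumed library choice (the paper does not state its software stack), is sketched below; the peak learning rate is a placeholder since it is not given in the quoted text.

```python
# Sketch of the reported training hyperparameters expressed with optax
# (library choice is ours, not confirmed by the paper).
import optax

TOTAL_STEPS = 200_000    # "a maximum of 200,000 steps"
WARMUP_STEPS = 1_000     # "learning rate warmup for the first 1,000 steps"
PEAK_LR = 1e-4           # placeholder value, not reported in the quotes

# Linear warmup to the peak, then linear anneal to zero over the rest.
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, PEAK_LR, WARMUP_STEPS),
        optax.linear_schedule(PEAK_LR, 0.0, TOTAL_STEPS - WARMUP_STEPS),
    ],
    boundaries=[WARMUP_STEPS],
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # "gradient clipping with a maximum norm of 1"
    optax.adam(learning_rate=schedule, b1=0.0, b2=0.99, eps=1e-12),
)
```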