Progressive Distillation for Fast Sampling of Diffusion Models

Authors: Tim Salimans, Jonathan Ho

ICLR 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally, we show that the full progressive distillation procedure does not take more time than it takes to train the original model, thus representing an efficient solution for generative modeling using diffusion at both train and test time.
Researcher Affiliation | Industry | Tim Salimans & Jonathan Ho, Google Research, Brain team, {salimans,jonathanho}@google.com
Pseudocode | Yes | Algorithm 1: Standard diffusion training; Algorithm 2: Progressive distillation (a hedged sketch of the distillation step follows the table).
Open Source Code | Yes | In Algorithm 2 we provide fairly detailed pseudocode that closely matches our actual implementation, which is available in open source at https://github.com/google-research/google-research/tree/master/diffusion_distillation.
Open Datasets | Yes | On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN... We evaluate our proposed progressive distillation algorithm on 4 data sets: CIFAR-10, 64x64 downsampled ImageNet, 128x128 LSUN bedrooms, and 128x128 LSUN Church-Outdoor.
Dataset Splits | No | The paper does not provide explicit details about training/validation/test dataset splits (e.g., percentages or sample counts for each split). While standard datasets are used and training procedures are described, the specific partitioning is not detailed.
Hardware Specification | Yes | We run our experiments on TPUv4, using 8 TPU chips for CIFAR-10, and 64 chips for the other data sets.
Software Dependencies | No | The paper mentions using the Adam optimizer and specific hyperparameters but does not provide version numbers for any software components (e.g., Python, PyTorch, TensorFlow, or specific libraries).
Experiment Setup | Yes | For CIFAR-10 we use an architecture with a fixed number of channels at all resolutions of 256. The model consists of a UNet that internally downsamples the data twice, to 16 × 16 and to 8 × 8. At each resolution we apply 3 residual blocks... We use dropout of 0.2 when training the original model. No dropout is used during distillation. ... All original models are trained with Adam with standard settings (learning rate of 3 × 10⁻⁴), using a parameter moving average with constant 0.9999 and very slight decoupled weight decay... We clip the norm of gradients to a global norm of 1... For CIFAR-10 we train for 800k parameter updates, for ImageNet we use 550k updates, and for LSUN we use 400k updates. During distillation we train for 50k updates per iteration... We linearly anneal the learning rate from 10⁻⁴ to zero during each iteration. (A sketch of these optimization settings also follows the table.)
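
Sketch for the Pseudocode row: a minimal, hedged rendering of one progressive-distillation training step in the spirit of the paper's Algorithm 2. It is written in PyTorch with an illustrative cosine alpha/sigma schedule and a generic student(z, t) / teacher(z, t) denoiser interface; the authors' released implementation is in JAX, so all names and signatures below are assumptions, not their API.

import torch

def alpha_sigma(t):
    # Variance-preserving schedule with alpha_t^2 + sigma_t^2 = 1; the cosine
    # form is an assumption for this sketch, not necessarily the paper's exact schedule.
    return torch.cos(0.5 * torch.pi * t), torch.sin(0.5 * torch.pi * t)

def ddim_step(x_pred, z, t_from, t_to):
    # Deterministic DDIM update from t_from to t_to given an x-space prediction.
    a_from, s_from = alpha_sigma(t_from)
    a_to, s_to = alpha_sigma(t_to)
    return a_to * x_pred + (s_to / s_from) * (z - a_from * x_pred)

def distillation_loss(student, teacher, x, N):
    # x: image batch of shape (B, C, H, W); N: number of sampling steps the student is trained for.
    B = x.shape[0]
    i = torch.randint(1, N + 1, (B, 1, 1, 1), device=x.device).float()
    t = i / N
    a_t, s_t = alpha_sigma(t)
    z_t = a_t * x + s_t * torch.randn_like(x)

    with torch.no_grad():
        # Two DDIM steps of the frozen teacher define the target for one student step.
        t_mid, t_end = t - 0.5 / N, t - 1.0 / N
        z_mid = ddim_step(teacher(z_t, t), z_t, t, t_mid)
        z_end = ddim_step(teacher(z_mid, t_mid), z_mid, t_mid, t_end)
        # x-target chosen so that a single student DDIM step from z_t lands exactly on z_end.
        a_end, s_end = alpha_sigma(t_end)
        ratio = s_end / s_t
        x_target = (z_end - ratio * z_t) / (a_end - ratio * a_t)

    # Truncated-SNR weighting, one of the stable loss weightings proposed in the paper.
    weight = torch.clamp(a_t ** 2 / s_t ** 2, min=1.0)
    return (weight * (student(z_t, t) - x_target) ** 2).mean()

The key step is solving for the x-target that makes one student DDIM step reproduce two teacher DDIM steps, which is what lets each distillation iteration halve the number of sampling steps.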
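
Sketch for the Experiment Setup row: a minimal PyTorch approximation of the quoted optimization settings (Adam, a parameter moving average with constant 0.9999, gradient clipping to global norm 1, and the linear learning-rate anneal used during distillation). The original code is JAX-based and also applies a very slight decoupled weight decay whose value is not quoted above, so it is omitted here; all function names are illustrative.

import copy
import torch

def make_optimizer_and_ema(model, distill=False):
    # Original models: Adam with learning rate 3e-4; during distillation the
    # rate is annealed linearly from 1e-4 to zero within each iteration.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4 if distill else 3e-4)
    ema_model = copy.deepcopy(model)  # parameter moving average, decay 0.9999
    return optimizer, ema_model

def training_update(model, ema_model, optimizer, loss, step, steps_per_iter, distill=False):
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to a global norm of 1, as quoted above.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    if distill:
        # Linear anneal from 1e-4 to zero over the quoted 50k updates per iteration.
        for group in optimizer.param_groups:
            group["lr"] = 1e-4 * max(0.0, 1.0 - step / steps_per_iter)
    optimizer.step()
    # Update the evaluation weights with an exponential moving average (constant 0.9999).
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(0.9999).add_(p, alpha=1.0 - 0.9999)

Under these assumptions, a distillation iteration would call distillation_loss from the sketch above followed by training_update for each of the 50k updates, then halve the number of sampling steps and repeat with the student as the new teacher.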