Learning Fast Samplers for Diffusion Models by Differentiating Through Sample Quality

Authors: Daniel Watson, William Chan, Jonathan Ho, Mohammad Norouzi

ICLR 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | DDSS achieves strong results on unconditional image generation across several datasets (e.g., FID on LSUN Churches 128x128 of 11.6 with only 10 inference steps and 4.82 with 20 steps, versus 51.1 and 14.9 for the strongest DDPM/DDIM baselines). From Section 5 (Experiments): "We evaluate all of our models on both FID and Inception Score (IS) (Salimans et al., 2016), comparing the samplers discovered by DDSS against DDPM and DDIM baselines with linear and quadratic strides."
Researcher Affiliation | Industry | "Daniel Watson, William Chan, Jonathan Ho & Mohammad Norouzi. Google Research, Brain Team. {watsondaniel,williamchan,jonathanho,mnorouzi}@google.com"
Pseudocode | No | The paper describes the proposed method and its mathematical formulation (e.g., Theorem 1) but includes no explicitly labeled 'Pseudocode' or 'Algorithm' block.
Open Source Code | No | The paper states "In JAX (Bradbury et al., 2018), this is trivial to implement by simply wrapping the score function calls with jax.remat," but provides no link to, or explicit statement about, a public release of the authors' own source code. (A hedged sketch of the jax.remat pattern follows this table.)
Open Datasets | Yes | "Specifically, we experiment with the DDPM trained by Ho et al. (2020) with L_simple on CIFAR-10, as well as a DDPM following the exact configuration of Nichol & Dhariwal (2021) trained on ImageNet 64x64 (Deng et al., 2009) with their L_hybrid objective... We include results for LSUN (Yu et al., 2015) bedrooms and churches at the 128x128 resolution."
Dataset Splits | No | The paper mentions computing FID from 50K model and training-data samples and IS on 5K samples, but does not report train/validation/test splits (as percentages or absolute counts), a cross-validation setup, or predefined validation splits.
Hardware Specification | No | The paper does not report hardware details such as GPU or CPU models, memory, or cloud instance types used to run the experiments.
Software Dependencies | No | JAX (Bradbury et al., 2018) is mentioned ("this is trivial to implement by simply wrapping the score function calls with jax.remat") and the Adam optimizer (Kingma & Ba, 2015) is cited, but no version is given for JAX or for any other software dependency.
Experiment Setup | Yes | "We apply gradient updates using the Adam optimizer (Kingma & Ba, 2015). We swept over the learning rate and used λ = 0.0005. We did not sweep over other Adam hyperparameters and kept β1 = 0.9, β2 = 0.999, and ϵ = 1e-8... We tried batch sizes of 128 and 512 and opted for the latter... We run all of our experiments for 50K training steps... We use the Adam optimizer with learning rate 0.0003 (linearly warmed up for the first 1000 training steps), batch size 2048, gradient clipping at norms over 1.0, dropout of 0.1, and EMA over the weights with decay rate 0.9999." (An optimizer-configuration sketch follows this table.)
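
The only implementation detail the paper offers is the jax.remat remark quoted in the Open Source Code row above. As a rough illustration of that pattern, the sketch below wraps a stand-in score function with jax.remat so that a scalar sample-quality loss can be differentiated through every step of the sampling chain without storing all intermediate activations. The score function, update rule, and loss here are placeholder assumptions, not the paper's actual model, sampler, or perceptual loss.

```python
import jax
import jax.numpy as jnp

def score_fn(params, x, t):
    # Placeholder score network (assumption); in DDSS this would be the
    # pretrained DDPM's noise-prediction network.
    return -x * params["scale"] / (1.0 + t)

# jax.remat recomputes the score call's internals on the backward pass
# instead of storing them, which is what makes backprop through a long
# sampling chain memory-feasible.
score_fn_remat = jax.remat(score_fn)

def sample(params, x, num_steps=10):
    for i in range(num_steps):
        t = jnp.float32(num_steps - i)
        eps = score_fn_remat(params, x, t)
        x = x - 0.1 * eps  # toy update rule, not the paper's sampler
    return x

def loss(params, x0):
    # Stand-in scalar surrogate for the paper's sample-quality loss.
    return jnp.mean(sample(params, x0) ** 2)

# Gradients of sample quality w.r.t. the (here trivial) sampler parameters.
grads = jax.grad(loss)({"scale": jnp.float32(1.0)}, jnp.ones((4, 8)))
```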
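
The quoted experiment setup also maps onto standard JAX tooling. Below is a minimal optimizer-configuration sketch using optax (an assumption; the paper does not name an optimizer library), covering the 0.0003 learning rate with 1000-step linear warmup, gradient clipping at global norm 1.0, Adam's quoted β1/β2/ϵ values (which are optax's defaults), and EMA over the weights with decay 0.9999. The 0.1 dropout belongs to the model definition and is not shown.

```python
import jax.numpy as jnp
import optax

# Linear warmup to the quoted learning rate of 3e-4 over the first 1000 steps.
lr_schedule = optax.linear_schedule(
    init_value=0.0, end_value=3e-4, transition_steps=1_000)

# Adam (defaults: b1=0.9, b2=0.999, eps=1e-8) with gradient clipping at
# global norm 1.0, matching the quoted configuration.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=lr_schedule),
)

params = {"w": jnp.zeros(3)}   # toy parameter tree (assumption)
ema_params = params            # shadow copy for the weight EMA
opt_state = optimizer.init(params)

grads = {"w": jnp.ones(3)}     # stand-in gradients for one step
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# EMA over the weights with decay rate 0.9999, applied after each update.
ema_params = optax.incremental_update(params, ema_params,
                                      step_size=1.0 - 0.9999)
```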