Training Diffusion Models with Reinforcement Learning

Authors: Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine

ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Empirically, DDPO can adapt text-to-image diffusion models to objectives that are difficult to express via prompting, such as image compressibility, and those derived from human feedback, such as aesthetic quality.
Researcher Affiliation | Academia | 1 University of California, Berkeley; 2 Massachusetts Institute of Technology. {kvablack, janner, kostrikov, sergey.levine}@berkeley.edu, yilundu@mit.edu
Pseudocode | No | The paper describes algorithms and derivations but does not include any explicitly labeled pseudocode or algorithm blocks.
Open Source Code | Yes | The project's website can be found at http://rl-diffusion.github.io.
Open Datasets | Yes | Compressibility and incompressibility prompts are sampled uniformly from all 398 animals in the ImageNet-1000 (Deng et al., 2009) categories.
Dataset Splits | No | The paper does not explicitly describe training, validation, and test splits with specific percentages or counts. It mentions finetuning prompts and generalization to unseen prompts, implying a test set, but no distinct validation set for hyperparameter tuning or early stopping.
Hardware Specification | Yes | RWR experiments were conducted on a v3-128 TPU pod and took approximately 4 hours to reach 50k samples. DDPO experiments were conducted on a v4-64 TPU pod and took approximately 4 hours to reach 50k samples. For the VLM-based reward function, LLaVA inference was conducted on a DGX machine with 8 80GB A100 GPUs.
Software Dependencies | No | We used the following open-source libraries for this work: NumPy (Harris et al., 2020), JAX (Bradbury et al., 2018), Flax (Heek et al., 2023), optax (Babuschkin et al., 2020), h5py (Collette, 2013), transformers (Wolf et al., 2020), and diffusers (von Platen et al., 2022). The paper lists software libraries but does not specify their exact version numbers.
Experiment Setup | Yes | For all experiments, we use Stable Diffusion v1.4 (Rombach et al., 2022) as the base model and finetune only the UNet weights while keeping the text encoder and autoencoder weights frozen. We collect 256 samples per training iteration. For DDPO_IS, we use the same clipped surrogate objective as in proximal policy optimization (Schulman et al., 2017), but find that we need to use a very small clip range compared to standard RL tasks. We use a clip range of 1e-4 for all experiments.
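The Experiment Setup row quotes a PPO-style clipped surrogate objective with an unusually small clip range of 1e-4. Below is a minimal sketch of that objective written with JAX (one of the libraries the paper lists), not the authors' released implementation; the names `new_logprob`, `old_logprob`, and `advantages` are illustrative assumptions standing in for the per-denoising-step Gaussian log-probabilities under the current and previous UNet weights and the per-sample advantages.

```python
# Hedged sketch of the DDPO_IS clipped surrogate objective described above.
# This is an illustration under assumed inputs, not the authors' code.
import jax
import jax.numpy as jnp

CLIP_RANGE = 1e-4  # the paper reports a clip range far smaller than typical RL settings


def clipped_surrogate_loss(new_logprob, old_logprob, advantages):
    """PPO-style clipped surrogate (Schulman et al., 2017) over per-step
    log-probabilities of the denoising process."""
    ratio = jnp.exp(new_logprob - old_logprob)  # importance weight per denoising step
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - CLIP_RANGE, 1.0 + CLIP_RANGE) * advantages
    # PPO maximizes the pessimistic (minimum) surrogate; return its negation as a loss.
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# Toy usage with a batch of 256 samples, matching the per-iteration sample
# count quoted above; the inputs are random placeholders.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
new_lp = jax.random.normal(k1, (256,))
old_lp = new_lp + 1e-3 * jax.random.normal(k2, (256,))
adv = jax.random.normal(k3, (256,))
print(clipped_surrogate_loss(new_lp, old_lp, adv))
```

In the setup quoted above, only the UNet parameters would receive gradients of such a loss; the text encoder and autoencoder weights remain frozen.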