Self-Play Fine-tuning of Diffusion Models for Text-to-image Generation
Authors: Huizhuo Yuan, Zixiang Chen, Kaixuan Ji, Quanquan Gu
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Empirically, we evaluate the performance of SPIN-Diffusion on text-to-image generation tasks (Ramesh et al., 2022; Rombach et al., 2022a; Saharia et al., 2022a). Our experiments on the Pick-a-Pic dataset (Kirstain et al., 2023), with the base model being Stable Diffusion-1.5 (Rombach et al., 2022b), demonstrate that SPIN-Diffusion surpasses the existing supervised fine-tuning method in human preference alignment and visual appeal from the very first iteration. |
| Researcher Affiliation | Academia | Huizhuo Yuan, Zixiang Chen, Kaixuan Ji, Quanquan Gu; Department of Computer Science, University of California, Los Angeles, Los Angeles, CA 90095; {hzyuan,chenzx19,kaixuanji,qgu}@cs.ucla.edu |
| Pseudocode | Yes | Algorithm 1 Self-Play Diffusion (SPIN-Diffusion). Input: {(x_0, c)}_{i∈[N]}: SFT dataset; p_{θ_0}: diffusion model with parameter θ_0; K: number of iterations. For k = 0, ..., K−1 do: for i = 1, ..., N do: generate real diffusion trajectories x_{1:T} ∼ q(x_{1:T} \| x_0); generate synthetic diffusion trajectories x′_{0:T} ∼ p_{θ_k}(· \| c); end for; update θ_{k+1} = argmin_{θ∈Θ} L̂_SPIN(θ, θ_k), the empirical version of (4.8) or (4.9); end for. Output: θ_K. (A Python sketch of this loop is given below the table.) |
| Open Source Code | Yes | Codes are available at https://github.com/uclaml/SPIN-Diffusion/. |
| Open Datasets | Yes | We use the Pick-a-Pic dataset (Kirstain et al., 2023) for fine-tuning. Pick-a-Pic is a dataset with pairs of images generated by Dreamlike (a fine-tuned version of SD-1.5) and SDXL-beta (Podell et al., 2023), where each pair corresponds to a human preference label. We also train SD-1.5 with SFT and Diffusion-DPO (Wallace et al., 2023) as the baselines. For SFT, we train the model to fit the winner images in the Pick-a-Pic (Kirstain et al., 2023) training set. |
| Dataset Splits | Yes | We conduct a grid search on the learning rate, coefficient β_t, and number of training steps, and choose the hyperparameters that perform the best on the validation set. We select the best checkpoint, trained after 2000 steps, as our SFT checkpoint. |
| Hardware Specification | Yes | We train SPIN-Diffusion on 8 NVIDIA A100 GPUs with 80GB of memory. In SFT training, we use 4 NVIDIA A6000 GPUs. |
| Software Dependencies | No | The paper mentions the use of the AdamW optimizer, but does not specify version numbers for key software components or libraries (e.g., PyTorch, TensorFlow). |
| Experiment Setup | Yes | We use the AdamW optimizer with a weight decay factor of 1e-2. The images are processed at a 512×512 resolution. The batch size is set to 8 locally, alongside a gradient accumulation of 32. For the learning rate, we use a schedule starting with 200 warm-up steps, followed by linear decay. We conduct a grid search on the learning rate, coefficient β_t, and number of training steps, and choose the hyperparameters that perform the best on the validation set. We set the learning rate at 2.0e-5 for the initial two iterations, reducing it to 5.0e-8 for the third iteration. The coefficient β_t is chosen as 2000 for the first iteration, increasing to 5000 for the subsequent second and third iterations. Training steps are 50 for the first iteration, 500 for the second, and 200 for the third. In training the DPO model, we employ the same AdamW optimizer and maintain a batch size of 8 and a gradient accumulation of 32. The learning rate is set to 2.0e-5, and β_t is set to 2000. The total number of training steps for DPO is 350. In SFT training, we use 4 NVIDIA A6000 GPUs. We use the AdamW optimizer with a weight decay of 0.01. The local batch size is set to 32 and the global batch size is set to 512. Our learning rate is 1e-5, with linear warm-up for 500 steps and no learning rate decay. We save checkpoints every 500 steps and evaluate the checkpoints on the Pick-a-Pic validation set. We select the best checkpoint, trained after 2000 steps, as our SFT checkpoint. During generation, we use a guidance scale of 7.5 and fix the random seed as 5775709. (Hedged sketches of the optimizer/schedule and the generation settings appear below the table.) |
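
To make the quoted pseudocode concrete, here is a minimal Python sketch of Algorithm 1's outer self-play loop. This is not the authors' implementation: `forward_traj`, `model_traj`, and `spin_loss` are hypothetical stand-ins for the forward process q(x_{1:T}|x_0), the opponent sampler p_{θ_k}(·|c), and the empirical SPIN objective of Eq. (4.8)/(4.9), and the inner SGD loop only approximates the argmin in the algorithm.

```python
# Sketch of Algorithm 1 (SPIN-Diffusion). forward_traj, model_traj,
# and spin_loss are hypothetical callables standing in for the
# paper's q(x_{1:T}|x_0), p_{theta_k}(.|c), and empirical SPIN loss.
import copy
import torch

def spin_diffusion(model, dataset, K, forward_traj, model_traj, spin_loss,
                   lr=2.0e-5):
    """Self-play loop: at iteration k, a frozen copy of the current
    model (p_{theta_k}) acts as the opponent generating synthetic
    trajectories, while the live model is trained against it."""
    for k in range(K):
        opponent = copy.deepcopy(model).eval()      # frozen p_{theta_k}
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                      weight_decay=1e-2)
        for x0, c in dataset:
            real = forward_traj(x0)                 # x_{1:T} ~ q(.|x_0)
            with torch.no_grad():
                fake = model_traj(opponent, c)      # x'_{0:T} ~ p_{theta_k}(.|c)
            loss = spin_loss(model, opponent, real, fake, c)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()    # gradient steps approximating argmin L_SPIN
    return model                # theta_K
```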
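
The reported optimizer and schedule (AdamW with weight decay 1e-2, 200 warm-up steps, then linear decay) map directly onto standard PyTorch components. A sketch under those hyperparameters follows; `unet` and `total_steps` are our placeholders (e.g., `total_steps=50` for the first SPIN iteration), not names from the paper.

```python
# AdamW + 200-step linear warm-up followed by linear decay, as
# described in the experiment setup. `unet` is a placeholder for
# the SD-1.5 denoiser being fine-tuned.
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def make_optimizer_and_scheduler(unet, total_steps, warmup_steps=200,
                                 lr=2.0e-5, weight_decay=1e-2):
    optimizer = AdamW(unet.parameters(), lr=lr, weight_decay=weight_decay)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)   # linear warm-up to lr
        # linear decay from lr down to 0 over the remaining steps
        return max(0.0, (total_steps - step)
                   / max(1, total_steps - warmup_steps))

    return optimizer, LambdaLR(optimizer, lr_lambda)

# Local batch 8 with 32 gradient-accumulation steps gives an
# effective batch of 256 per optimizer update (per GPU).
GRAD_ACCUM_STEPS = 32
```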
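
The generation settings quoted above (guidance scale 7.5, fixed seed 5775709, 512×512 images) correspond to a standard `diffusers` text-to-image call. A sketch assuming the stock SD-1.5 checkpoint; the model ID and prompt are placeholders, since in the paper the pipeline would load the fine-tuned SPIN-Diffusion weights instead.

```python
# Reproducing the reported generation settings with diffusers:
# guidance scale 7.5, fixed seed 5775709, 512x512 resolution.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",   # placeholder: paper uses fine-tuned weights
    torch_dtype=torch.float16,
).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(5775709)
image = pipe(
    "a photo of an astronaut riding a horse",  # placeholder prompt
    guidance_scale=7.5,
    height=512,
    width=512,
    generator=generator,
).images[0]
image.save("sample.png")
```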