Constrained Diffusion Models via Dual Training
Authors: Shervin Khalafi, Dongsheng Ding, Alejandro Ribeiro
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically demonstrate the effectiveness of our constrained models in two constrained generation tasks: (i) we consider a dataset with one or more underrepresented classes where we train the model with constraints to ensure fairly sampling from all classes during inference; (ii) we fine-tune a pre-trained diffusion model to sample from a new dataset while avoiding overfitting. |
| Researcher Affiliation | Academia | Shervin Khalafi, Dongsheng Ding, Alejandro Ribeiro ({shervink,dongshed,aribeiro}@seas.upenn.edu), University of Pennsylvania |
| Pseudocode | Yes | Algorithm 1 Constrained Diffusion Models via Dual Training |
| Open Source Code | Yes | The source code is available here (footnote 2: https://github.com/shervinkhal/Constrained_Diffusion_Dual_Training) |
| Open Datasets | Yes | We train constrained diffusion models over three datasets: MNIST digits [44], Celeb-A faces [51], and ImageNet [63]. |
| Dataset Splits | No | The paper describes how FID scores are computed by comparing generated samples to a balanced version of the original dataset for evaluation, but it does not provide specific details on the training, validation, and test splits for the model training process itself (e.g., percentages or sample counts for each split). |
| Hardware Specification | Yes | We run all experiments on two NVIDIA RTX 3090-Ti GPUs in parallel. The amount of GPU memory used was 16 Gigabytes per GPU. |
| Software Dependencies | Yes | We use the PyTorch [59] and Diffusers [73] Python libraries for training our constrained diffusion models, and Adam with decoupled weight decay [52] as an optimizer. The accelerate library [30] is used for the parallelization of the training processes across multiple GPUs. (A hedged sketch of such a dual-training loop appears below the table.) |
| Experiment Setup | Yes | We summarize the important hyperparameters in our experiments in Table 2. |
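The Pseudocode and Software Dependencies rows name Algorithm 1 (Constrained Diffusion Models via Dual Training) and a PyTorch/AdamW stack but do not reproduce them here. Below is a minimal, hypothetical sketch of what a primal-dual (dual-training) loop of this kind could look like in PyTorch. The function name `dual_training`, the constraint-loss callables, the thresholds, and the learning rates are illustrative assumptions, not the authors' implementation; the actual code and hyperparameters are in the linked GitHub repository and Table 2 of the paper.

```python
# Hypothetical sketch of a primal-dual training loop for a constrained
# diffusion model. Constraint form, names, and hyperparameters are
# illustrative assumptions, not the paper's Algorithm 1 verbatim.
import torch
from torch.optim import AdamW


def dual_training(model, base_loss, constraint_losses, thresholds,
                  loader, epochs, lr=1e-4, dual_lr=1e-2, device="cuda"):
    """Alternate primal updates of the diffusion model with dual-ascent
    updates of one Lagrange multiplier per constraint.

    `base_loss` and each entry of `constraint_losses` are assumed to be
    callables (model, batch) -> scalar loss; `thresholds` is a tensor of
    per-constraint tolerances. The model is assumed to already be on `device`.
    """
    optimizer = AdamW(model.parameters(), lr=lr)  # Adam with decoupled weight decay
    lambdas = torch.zeros(len(constraint_losses), device=device)  # dual variables

    for _ in range(epochs):
        for batch in loader:
            batch = batch.to(device)

            # Primal step: minimize the Lagrangian over the model parameters.
            losses = torch.stack([c(model, batch) for c in constraint_losses])
            lagrangian = base_loss(model, batch) + (lambdas.detach() * losses).sum()
            optimizer.zero_grad()
            lagrangian.backward()
            optimizer.step()

            # Dual step: projected gradient ascent on the multipliers,
            # driven by the current constraint violations.
            with torch.no_grad():
                violation = losses.detach() - thresholds
                lambdas = torch.clamp(lambdas + dual_lr * violation, min=0.0)

    return model, lambdas
```

In this sketch the multipliers grow while a constraint (e.g., an underrepresented-class sampling loss) is violated and shrink back toward zero once it is satisfied, which is the standard dual-ascent behavior the table's "Pseudocode" row refers to.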