Simplified and Generalized Masked Diffusion for Discrete Data
Authors: Jiaxin Shi, Kehang Han, Zhe Wang, Arnaud Doucet, Michalis Titsias
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | On GPT-2 scale text modeling and pixel-level image modeling tasks, masked diffusions trained using our simple ELBO objective outperform previous proposals, leading to the best likelihood and zero-shot transfer performance among discrete diffusion models. (Section 7, Experiments) |
| Researcher Affiliation | Industry | Jiaxin Shi, Kehang Han, Zhe Wang, Arnaud Doucet, Michalis K. Titsias. Google DeepMind. Correspondence to: jiaxins@google.com. |
| Pseudocode | Yes | A single step of MD4 training algorithm is described in Alg. 1 in Appendix. A complete description of the sampling algorithm can be found in Alg. 2 in Appendix. (A generic, hedged sketch of a masked-diffusion training step follows the table.) |
| Open Source Code | Yes | Our code is available at https://github.com/google-deepmind/md4. |
| Open Datasets | Yes | text8 [55], a character-level text modeling benchmark, and OpenWebText [56], an open clone of the unreleased WebText dataset used to train GPT-2 [57]. ...train MD4 on order-agnostic image data from CIFAR-10 and downsampled ImageNet 64×64 [63]. |
| Dataset Splits | Yes | We kept 2% of the original training set for validation. |
| Hardware Specification | Yes | Our model is trained on 16 TPU-v5 lite for less than a day. Our CIFAR-10 model is trained on 32 TPU-v5 lite for 24 hours. Our ImageNet 64×64 model is trained on 256 TPU-v5 lite for 3.5 days. |
| Software Dependencies | No | The paper mentions 'JAX [45] implementation of categorical sampling' but does not specify its version or any other software dependencies with version numbers. |
| Experiment Setup | Yes | We used a cosine learning rate schedule with a linear warm-up of 2000 steps. We applied channel-wise dropout of rate 0.05 and used AdamW optimizer with learning rate 0.0003 and a weight decay factor of 0.03. We kept the training hyperparameters the same as text8 experiment except that we reduced the dropout rate to 0.02. We used AdamW optimizer and trained for 2M iterations. We used learning rate 0.0004, batch size 256, weight decay factor 0.01 for CIFAR-10 and learning rate 0.0002, batch size 512, weight decay factor 0.03 for ImageNet 64×64. (A hedged optimizer/schedule sketch follows the table.) |
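The Pseudocode row points to Alg. 1 (training) and Alg. 2 (sampling) in the paper's appendix, which are not reproduced here. The block below is only a minimal sketch of a generic continuous-time masked-diffusion training loss in JAX, assuming a cosine masking schedule, an illustrative mask/vocabulary layout, and a hypothetical `model_apply(params, z_t, t)` network; it is not the paper's Alg. 1 or its exact objective.

```python
# A minimal, generic masked-diffusion training-loss sketch in JAX. This is NOT the
# paper's Alg. 1 (which lives in the appendix); the schedule, the mask/vocabulary
# indices, and `model_apply(params, z_t, t)` are illustrative assumptions.

import jax
import jax.numpy as jnp

MASK_ID = 256       # assumed index reserved for the mask token
VOCAB_SIZE = 257    # assumed: 256 data categories + 1 mask token


def cosine_alpha(t):
    """Assumed masking schedule: alpha(0) = 1 (nothing masked), alpha(1) = 0 (all masked)."""
    return jnp.cos(0.5 * jnp.pi * t) ** 2


def masked_diffusion_loss(rng, params, model_apply, x):
    """Weighted cross-entropy on masked positions for a batch x of integer tokens, shape [B, L]."""
    rng_t, rng_mask = jax.random.split(rng)
    b = x.shape[0]

    # Sample a noise level t per example and mask each token with probability 1 - alpha(t).
    t = jax.random.uniform(rng_t, (b,))
    alpha = cosine_alpha(t)                                      # [B]
    is_masked = jax.random.bernoulli(rng_mask, 1.0 - alpha[:, None], x.shape)
    z_t = jnp.where(is_masked, MASK_ID, x)

    # Predict the clean tokens from the partially masked sequence.
    logits = model_apply(params, z_t, t)                         # [B, L, VOCAB_SIZE]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_ll = jnp.take_along_axis(log_probs, x[..., None], axis=-1)[..., 0]

    # Continuous-time ELBO-style weight -alpha'(t) / (1 - alpha(t)); for the cosine
    # schedule above, -alpha'(t) = 0.5 * pi * sin(pi * t).
    weight = 0.5 * jnp.pi * jnp.sin(jnp.pi * t) / jnp.maximum(1.0 - alpha, 1e-6)

    # Cross-entropy only on masked positions, weighted and averaged over the batch.
    per_example = (weight[:, None] * is_masked * (-token_ll)).sum(axis=-1)
    return per_example.mean()
```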
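The Experiment Setup row lists the reported optimizer and schedule hyperparameters. As an illustration only, the optax snippet below pairs one reported setting (AdamW, learning rate 0.0003, weight decay 0.03) with a cosine learning-rate schedule and a 2000-step linear warm-up; the total step count and the decay-to-zero end value are assumptions, not values the paper confirms for this particular setting.

```python
# A hedged optax sketch of the reported optimizer and schedule hyperparameters.
# The pairing of these values and the schedule horizon are assumptions.

import optax

WARMUP_STEPS = 2_000        # "linear warm-up of 2000 steps"
TOTAL_STEPS = 2_000_000     # assumed horizon; the paper reports 2M iterations for the image models

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,        # reported learning rate 0.0003
    warmup_steps=WARMUP_STEPS,
    decay_steps=TOTAL_STEPS,
    end_value=0.0,          # assumption: decay to zero by the end of training
)

optimizer = optax.adamw(
    learning_rate=schedule,
    weight_decay=0.03,      # reported weight-decay factor 0.03
)
```

The resulting `optimizer` is a standard optax `GradientTransformation`: initialize it with `optimizer.init(params)` and apply it each step with `optimizer.update(grads, opt_state, params)` (AdamW needs `params` for decoupled weight decay).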