Why Warmup the Learning Rate? Underlying Mechanisms and Improvements
Authors: Dayal Singh Kalra, Maissam Barkeshli
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Here we perform extensive studies on the effect of learning rate warmup across architectures (FCNs, ResNets, and Transformers), initializations and parameterizations, datasets (CIFAR-10, CIFAR-100, Tiny ImageNet, WikiText), and for both SGD and Adam. |
| Researcher Affiliation | Academia | (1) Department of Physics and Joint Quantum Institute, University of Maryland, College Park; (2) Institute for Physical Science and Technology, University of Maryland, College Park |
| Pseudocode | Yes | Algorithm 1 Exponential Search; Algorithm 2 Binary Search; Algorithm 3 Persistent Catapult Warmup |
| Open Source Code | Yes | The key results can be reproduced using the GitHub repo: https://github.com/dayal-kalra/why-warmup. |
| Open Datasets | Yes | We consider standard image classification datasets such as CIFAR-10, CIFAR-100 [21], and Tiny ImageNet [1]. The images are normalized to have zero mean and unit variance. For MSE loss, we use one-hot encoding for the labels. We consider the next token prediction task on the WikiText-2 and WikiText-103 datasets [25]. |
| Dataset Splits | Yes | We consider standard image classification datasets such as CIFAR-10, CIFAR-100 [21], and Tiny ImageNet [1]... For various image classification tasks, we employ data augmentation techniques, applied in the following order: random horizontal flips, random cropping, and mixup [41]. (An augmentation sketch follows the table.) |
| Hardware Specification | Yes | The phase diagram experiments typically required about an hour per run on an A100 GPU. |
| Software Dependencies | No | All models were implemented using the JAX [3] and Flax [15] libraries. While JAX and Flax are mentioned with citations, specific version numbers for these libraries are not provided in the text. |
| Experiment Setup | Yes | Experimental Setup: We consider Wide ResNets (WRNs) and Transformers (LM) parameterized in either SP or µP. WRNs are trained on CIFAR-10, CIFAR-100, and Tiny-ImageNet, employing data augmentation. Transformers are trained on the next token prediction task using the WikiText dataset. These models are trained with MSE or cross-entropy (xent) loss functions using SGD or Adam optimizers for a fixed training budget of T = 10^5 steps unless otherwise specified. Training begins with a linear warmup phase from η_init = 0 to η_trgt over T_wrm steps. After the warmup period, training continues at η_trgt for the remaining training budget. In some cases, following the warmup period, we decrease the learning rate using cosine decay [24]. Further details are provided in Appendix D. (A schedule sketch follows the table.) |
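
The Experiment Setup row describes a linear warmup from η_init = 0 to η_trgt over T_wrm steps, followed by either a constant learning rate or cosine decay for the remaining budget. Below is a minimal sketch of those two schedules using optax; the paper's code is in JAX/Flax, but this particular optax usage and the example values of `T_wrm` and `eta_trgt` are assumptions for illustration, not taken from the paper.

```python
# Hedged sketch of the warmup schedules described in the Experiment Setup row.
import optax

T = 100_000       # total training budget (10^5 steps, per the quoted setup)
T_wrm = 1_000     # warmup steps (illustrative value, not from the text)
eta_trgt = 0.1    # target learning rate (illustrative value, not from the text)

# Linear warmup from eta_init = 0 to eta_trgt over T_wrm steps,
# then constant at eta_trgt for the rest of training.
warmup_then_constant = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=0.0, end_value=eta_trgt,
                              transition_steps=T_wrm),
        optax.constant_schedule(eta_trgt),
    ],
    boundaries=[T_wrm],
)

# Variant with cosine decay after warmup, as used in some runs.
warmup_then_cosine = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=eta_trgt,
    warmup_steps=T_wrm, decay_steps=T, end_value=0.0,
)

# Either schedule can be passed directly as the learning rate of the optimizer.
optimizer = optax.sgd(learning_rate=warmup_then_constant)  # or optax.adam(...)
```

Passing the schedule object to `optax.sgd`/`optax.adam` lets optax evaluate the learning rate at the current step count automatically, so no manual schedule bookkeeping is needed in the training loop.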
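
The Open Datasets and Dataset Splits rows quote the input pipeline: images normalized to zero mean and unit variance, one-hot labels for MSE loss, and augmentation applied as random horizontal flips, then random crops, then mixup. The JAX sketch below follows that order; `crop_pad`, `mixup_alpha`, and the use of a single mixing coefficient per batch are illustrative assumptions not specified in the quoted text.

```python
# Hedged sketch of the preprocessing and augmentation order quoted above.
import jax
import jax.numpy as jnp

def normalize(images, mean, std):
    # Standardize images to zero mean and unit variance using dataset statistics.
    return (images - mean) / std

def augment_batch(key, images, labels, num_classes, crop_pad=4, mixup_alpha=1.0):
    # images: (B, H, W, C) float array; labels: (B,) integer class array.
    k_flip, k_crop, k_lam, k_perm = jax.random.split(key, 4)
    B, H, W, C = images.shape

    # 1) Random horizontal flips.
    flip = jax.random.bernoulli(k_flip, 0.5, (B, 1, 1, 1))
    images = jnp.where(flip, images[:, :, ::-1, :], images)

    # 2) Random crops after zero-padding by `crop_pad` pixels on each side.
    padded = jnp.pad(images, ((0, 0), (crop_pad, crop_pad),
                              (crop_pad, crop_pad), (0, 0)))
    offsets = jax.random.randint(k_crop, (B, 2), 0, 2 * crop_pad + 1)
    crop_one = lambda img, off: jax.lax.dynamic_slice(
        img, (off[0], off[1], 0), (H, W, C))
    images = jax.vmap(crop_one)(padded, offsets)

    # 3) Mixup: convex combination of examples and their one-hot labels.
    targets = jax.nn.one_hot(labels, num_classes)
    lam = jax.random.beta(k_lam, mixup_alpha, mixup_alpha)
    perm = jax.random.permutation(k_perm, B)
    images = lam * images + (1.0 - lam) * images[perm]
    targets = lam * targets + (1.0 - lam) * targets[perm]
    return images, targets
```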