Small-scale proxies for large-scale Transformer training instabilities
Authors: Mitchell Wortsman, Peter J Liu, Lechao Xiao, Katie E Everett, Alexander A Alemi, Ben Adlam, John D Co-Reyes, Izzeddin Gur, Abhishek Kumar, Roman Novak, Jeffrey Pennington, Jascha Sohl-Dickstein, Kelvin Xu, Jaehoon Lee, Justin Gilmer, Simon Kornblith
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this report we reproduce, study, and predict training instability in Transformer models. We find that measuring the relationship between learning rate and loss across scales is a useful tool to identify instability (e.g., Figure 1). This section details our experimental set-up (Section 2.1) and useful tools employed by our analysis: (i) measuring the relationship between learning rate and loss across scales (Section 2.2) and (ii) examining scaling trends for model characteristics (Section 2.3). |
| Researcher Affiliation | Industry | Google DeepMind |
| Pseudocode | No | The paper does not contain any section or figure explicitly labeled as 'Pseudocode' or 'Algorithm'. |
| Open Source Code | No | The paper mentions implementing their models in Flax, which is a third-party library, but does not provide any explicit statement about releasing their own source code for the methodology, nor a link to a code repository. |
| Open Datasets | Yes | We use rotary positional embeddings (Su et al., 2021), and for training data we use C4 (Raffel et al., 2020a). |
| Dataset Splits | No | The paper mentions 'final validation loss' and 'validaton loss' in Section 2.2, indicating the use of a validation set, but does not provide specific details on the split percentages or sample counts for training, validation, and test datasets. |
| Hardware Specification | Yes | We train on TPUs (Jouppi et al., 2017) in bfloat16 precision using Flax (Heek et al., 2023) and JAX (Bradbury et al., 2018). |
| Software Dependencies | No | The paper mentions using Flax (Heek et al., 2023), JAX (Bradbury et al., 2018), Orbax (Gaffney et al., 2023), and Grain (Google, 2023), but does not provide specific version numbers for these software dependencies (e.g., 'Flax (Heek et al., 2023)' indicates the publication year of the reference, not a software version number). |
| Experiment Setup | Yes | By default, we use AdamW (Loshchilov & Hutter, 2019) with β1 = 0.9, β2 = 0.95, ϵ = 1e-8, and gradient clipping at global norm 1. The default warmup is 5e3 steps, and the default number of total steps is 1e5. We use a linear schedule for warmup and a cosine-decay (Loshchilov & Hutter, 2016) schedule for the remainder, with minimum learning rate 1e-5. We use an independent weight decay of 1e-4 and auxiliary z-loss (Chowdhery et al., 2022) with coefficient 1e-4. We use pre-normalization (Radford et al., 2019) Transformers with qk-layernorm (Dehghani et al., 2023). We do not use any biases following Chowdhery et al. (2022), and the layernorm (Ba et al., 2016) ϵ remains at the default value in Flax (Heek et al., 2023) of 1e-6. The default batch size is 256, where each batch element has a sequence length of 512 tokens. Sequences are packed so that no padding is required. Finally, we use the vocabulary from Raffel et al. (2020b) which has size 32101 and uses a SentencePiece (Kudo & Richardson, 2018) tokenizer. |
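
The optimizer and schedule hyperparameters quoted in the Experiment Setup row map fairly directly onto Optax, the optimization library commonly paired with Flax and JAX. The sketch below is a minimal illustration of that configuration, not the authors' code: the peak learning rate is a placeholder (the paper sweeps it), the total-step count is assumed to include warmup, and treating "independent weight decay" as decay that is not multiplied by the learning rate is an assumption.

```python
# Minimal sketch of the default optimizer/schedule described above, using Optax.
# Not the authors' implementation; PEAK_LR is a placeholder since the paper
# sweeps the learning rate, and the "independent" weight decay handling is an
# assumption (decay is applied without being multiplied by the learning rate).
import optax

TOTAL_STEPS = 100_000   # default 1e5 training steps (assumed to include warmup)
WARMUP_STEPS = 5_000    # default 5e3 warmup steps
PEAK_LR = 1e-2          # placeholder; the paper varies this across a sweep
MIN_LR = 1e-5           # minimum learning rate after cosine decay

# Linear warmup followed by cosine decay, as in the quoted setup.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=PEAK_LR,
    warmup_steps=WARMUP_STEPS,
    decay_steps=TOTAL_STEPS,
    end_value=MIN_LR,
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                   # gradient clipping at global norm 1
    optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-8),   # Adam moment estimates
    optax.scale_by_learning_rate(schedule),           # multiplies the update by -lr(t)
    # Weight decay added after learning-rate scaling so it is not multiplied by
    # the schedule; the negative sign shrinks the weights because Optax applies
    # updates additively (params <- params + updates).
    optax.add_decayed_weights(weight_decay=-1e-4),
)
```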
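The auxiliary z-loss with coefficient 1e-4 is similarly compact to write down. The snippet below sketches one common formulation (a squared penalty on the log of the softmax normalizer, in the PaLM style cited above); the mean reduction and the absence of a padding mask are assumptions, since the quoted text does not specify them.

```python
# Sketch of a cross-entropy loss with the auxiliary z-loss term (coefficient 1e-4).
# The squared log-normalizer penalty follows the PaLM-style formulation cited
# above; the mean reduction and lack of a padding mask are assumptions.
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits, labels, z_loss_coef=1e-4):
    """logits: [batch, seq, vocab] floats; labels: [batch, seq] integer token ids."""
    log_z = logsumexp(logits, axis=-1)                 # log of the softmax normalizer
    log_probs = logits - log_z[..., None]
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    z_loss = z_loss_coef * jnp.square(log_z)           # pushes log Z toward zero
    return jnp.mean(nll + z_loss)
```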
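Finally, the qk-layernorm intervention (Dehghani et al., 2023) referenced in the setup amounts to normalizing queries and keys before the attention logits are computed. The Flax module below is an illustrative sketch under that reading, with per-head LayerNorm, no biases, and a 1e-6 epsilon to match the quoted defaults; the module and projection names are invented for the example, and causal masking and rotary embeddings are omitted for brevity.

```python
# Illustrative Flax sketch of self-attention with qk-layernorm: LayerNorm is
# applied to queries and keys (per head) before the attention logits.
# Module/projection names are invented; causal masking and rotary embeddings
# are omitted for brevity. Not the authors' implementation.
import flax.linen as nn
import jax
import jax.numpy as jnp

class QKNormSelfAttention(nn.Module):
    num_heads: int
    head_dim: int

    @nn.compact
    def __call__(self, x):  # x: [batch, seq, model_dim]
        batch, seq, _ = x.shape
        features = self.num_heads * self.head_dim

        def project(name):
            y = nn.Dense(features, use_bias=False, name=name)(x)
            return y.reshape(batch, seq, self.num_heads, self.head_dim)

        q, k, v = project("q_proj"), project("k_proj"), project("v_proj")

        # qk-layernorm: normalize queries and keys over the head dimension.
        q = nn.LayerNorm(epsilon=1e-6, use_bias=False, name="q_norm")(q)
        k = nn.LayerNorm(epsilon=1e-6, use_bias=False, name="k_norm")(k)

        logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(self.head_dim)
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bhqk,bkhd->bqhd", weights, v)
        out = out.reshape(batch, seq, features)
        return nn.Dense(features, use_bias=False, name="out_proj")(out)
```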