Understanding and Minimising Outlier Features in Transformer Training

Authors: Bobby He, Lorenzo Noci, Daniele Paliotta, Imanol Schlag, Thomas Hofmann

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our work focuses on the above questions, first identifying several quantitative metrics, such as the kurtosis over neuron activation norms, to measure OFs. With these metrics, we study how architectural and optimisation choices influence OFs, and provide practical insights to minimise OFs during training. As highlights, we introduce a novel unnormalised transformer block, the Outlier Protected block, and present a previously unknown benefit of non-diagonal preconditioning optimisers, finding both approaches to significantly reduce OFs and improve quantisation without compromising convergence speed, at scales of up to 7B parameters. Notably, our combination of OP block and non-diagonal preconditioner (SOAP) achieves 14.87 weight-and-activation int8 perplexity (from 14.71 in standard precision), compared to 63.4 int8 perplexity (from 16.00) with a default OF-prone combination of Pre-Norm model and Adam, when quantising OPT-125m models post-training.
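The kurtosis metric is only described verbally in this excerpt. The sketch below shows one plausible way to compute a kurtosis-style ratio over per-neuron activation norms in PyTorch; the function name `activation_kurtosis` and the exact normalisation are assumptions for illustration, not the paper's definition.

```python
import torch

def activation_kurtosis(x: torch.Tensor) -> torch.Tensor:
    """Kurtosis-style ratio over per-neuron activation norms.

    x: activations of shape (num_tokens, width). This follows the common
    m4 / m2^2 convention, with moments taken across the `width` neurons
    of each neuron's mean squared activation; the paper's exact
    normalisation may differ.
    """
    # per-neuron second moment, averaged over tokens -> shape (width,)
    sq_norms = x.pow(2).mean(dim=0)
    m2 = sq_norms.mean()
    m4 = sq_norms.pow(2).mean()
    return m4 / (m2.pow(2) + 1e-12)

# Roughly Gaussian activations give a ratio close to 1; a few dominant
# "outlier" neurons push it towards the layer width.
x = torch.randn(4096, 768)
print(activation_kurtosis(x))   # ~1
x[:, 0] *= 30.0                 # make one neuron an outlier feature
print(activation_kurtosis(x))   # substantially larger
```

For this ratio, the maximum value is the layer width (reached when a single neuron carries all the activation norm), which matches the qualitative behaviour the paper uses to diagnose OFs.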
Researcher Affiliation | Academia | Bobby He¹, Lorenzo Noci¹, Daniele Paliotta², Imanol Schlag¹, Thomas Hofmann¹; ¹Department of Computer Science, ETH Zürich; ²Machine Learning Group, University of Geneva
Pseudocode | No | Insufficient information. The paper describes methods and operations mathematically and textually but does not include any explicitly labeled 'Pseudocode' or 'Algorithm' blocks.
Open Source Code | Yes | Our code for experiments on the Code Parrot dataset can be found at https://github.com/bobby-he/simplified_transformers.
Open Datasets | Yes | Throughout this work, we train transformers on the next-token language modelling task, and study OFs, on a range of datasets, including: 1) Code Parrot, 2) Languini Books [39], 3) Book Corpus [40] and English Wikipedia, and 4) FineWeb-Edu [41]. Unless stated otherwise, our experimental results are conducted on Code Parrot, but importantly our conclusions regarding OFs are consistent across language modelling datasets. In App E.1, we also explore OFs in image classification settings with other architectures like Vision Transformers [42] and MLPs. (Code Parrot: https://huggingface.co/datasets/transformersbook/codeparrot-train; English Wikipedia: https://huggingface.co/datasets/google/wiki40b)
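As a minimal sketch of pulling the publicly hosted corpora named above, the snippet below assumes the Hugging Face `datasets` library and uses the repository IDs from the quoted URLs; the wiki40b "en" config and the streaming flag are illustrative choices, not details taken from the paper.

```python
from datasets import load_dataset

# Repository IDs taken from the URLs quoted above; streaming avoids
# downloading the full corpora up front.
codeparrot = load_dataset("transformersbook/codeparrot-train",
                          split="train", streaming=True)
wiki40b_en = load_dataset("google/wiki40b", "en",
                          split="train", streaming=True)

sample = next(iter(codeparrot))
print(sample.keys())
```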
Dataset Splits | Yes | In Tab 2, we take the OPT-125m [6] setting of Bondarenko et al. [15], training models using AdamW in standard mixed FP16/FP32 precision on Book Corpus+Wikipedia for around 12B tokens. Post-training, we quantise (PTQ) to int8 weights and activations, using the same quantisation recipe as [15].
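The exact PTQ recipe is that of Bondarenko et al. [15] and is not reproduced in this excerpt. The sketch below is a generic per-tensor symmetric int8 quantiser, included only to illustrate why a single outlier feature degrades int8 accuracy; the function names and absmax scaling are assumptions, not the paper's recipe.

```python
import torch

def quantize_int8(t: torch.Tensor):
    """Per-tensor symmetric (absmax) int8 quantisation: a simplified
    stand-in for the PTQ recipe of Bondarenko et al. [15]."""
    scale = t.abs().max() / 127.0
    q = torch.clamp(torch.round(t / scale), -128, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale

# A single outlier value inflates the scale and washes out the rest of
# the tensor, which is how OFs translate into int8 perplexity gaps.
x = torch.randn(1024)
x_out = x.clone(); x_out[0] = 100.0
for name, t in [("no outlier", x), ("with outlier", x_out)]:
    q, s = quantize_int8(t)
    err = (dequantize(q, s) - t).abs().mean()
    print(name, float(err))
```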
Hardware Specification | Yes | We train for 4.2B tokens at 1.2B scale as this took 24 hours on 4 A100 80GB GPUs; we were unable to train for longer due to compute constraints. Scales under 1B were trained on a single A5000 or RTX-2080Ti GPU, taking around 2 days for 3.3B tokens (or equivalently, 50K steps at batch size 128 and sequence length 512). For our 6B token 7B model runs, we train for 20K steps with global batch size 72 (node batch size 6 on 12 nodes of 4 GH200 GPUs, with tensor parallelism) and context length 4096.
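As a quick sanity check of the quoted token budgets, the snippet below multiplies steps, batch size, and sequence length for the two runs where all three are stated; the grouping into a dictionary is purely illustrative.

```python
# steps x batch size x sequence length, for the runs quoted above
runs = {
    "sub-1B models": (50_000, 128, 512),   # quoted as ~3.3B tokens
    "7B model":      (20_000, 72, 4096),   # quoted as ~6B tokens
}
for name, (steps, batch, seq_len) in runs.items():
    print(f"{name}: {steps * batch * seq_len / 1e9:.2f}B tokens")
# -> 3.28B and 5.90B tokens, consistent with the rounded figures
```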
Software Dependencies | No | Insufficient information. The paper mentions software like 'nanotron' and 'SOAP [69]' but does not provide specific version numbers for these or other key software components.
Experiment Setup | Yes | Unless otherwise stated, our experimental results are conducted on Code Parrot... Our default architecture scale has width d = 768 and 6 layers... We train with the AdamW optimiser [43] and weight decay 0.1, (β1, β2) = (0.9, 0.999), and ϵ = 1e-8 unless otherwise stated, and clip the maximum gradient norm to 1. ...Unless otherwise stated, we train with sequence length 128 and batch size 32 for 80K steps, with linear warmup to maximum learning rate 1e-3 for 5% of the steps, before linear decay. ...For our 6B token 7B model runs, we train for 20K steps with global batch size 72... and context length 4096.
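The optimisation settings above translate directly into a standard PyTorch configuration. The sketch below assumes PyTorch's `AdamW` and `LambdaLR`; `model` is a placeholder standing in for the paper's 6-layer, d = 768 transformer, and the schedule shape (linear warmup for 5% of 80K steps to 1e-3, then linear decay) follows the quoted description.

```python
import torch

max_lr, total_steps, warmup_frac = 1e-3, 80_000, 0.05
warmup_steps = int(warmup_frac * total_steps)

model = torch.nn.Linear(768, 768)  # placeholder for the 6-layer, d=768 transformer

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr,
                              betas=(0.9, 0.999), eps=1e-8, weight_decay=0.1)

def lr_lambda(step: int) -> float:
    # linear warmup for 5% of steps, then linear decay to zero
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / (total_steps - warmup_steps))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# inside the training loop (sketch):
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # max grad norm 1
#   optimizer.step(); scheduler.step(); optimizer.zero_grad()
```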