Gated Linear Attention Transformers with Hardware-Efficient Training

Authors: Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, Yoon Kim

ICML 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our main experiments are on language modeling, where we study whether GLA can perform competitively against (i) a strong Transformer baseline with modern architectural recipes and (ii) recent linear-time models. We use the SlimPajama dataset (Soboleva et al., 2023) and tokenize it using the Mistral tokenizer (Jiang et al., 2023). The original dataset contains 627B tokens; we use a 100B subset.
Researcher Affiliation | Collaboration | 1 Massachusetts Institute of Technology; 2 MIT-IBM Watson AI Lab. Correspondence to: Songlin Yang <yangsl66@mit.edu>, Bailin Wang <bailinw@mit.edu>.
Pseudocode | Yes | Algorithm 1 FLASHLINEARATTENTION: Forward Pass; Algorithm 2 FLASHLINEARATTENTION: Backward Pass; Listing 1: PyTorch-like code snippet of our two-level chunking algorithm for training GLA. (A simplified sketch of the chunkwise computation appears after the table.)
Open Source Code | Yes | https://github.com/sustcsonglin/flash-linear-attention
Open Datasets | Yes | We use the SlimPajama dataset (Soboleva et al., 2023) and tokenize it using the Mistral tokenizer (Jiang et al., 2023).
Dataset Splits | No | The paper mentions training on specific token counts and batch sizes, but does not explicitly provide training/validation/test dataset splits (e.g., percentages or sample counts for each split).
Hardware Specification | Yes | Figure 2 shows the speed and memory footprint of our implementation. Both versions of FLASHLINEARATTENTION are substantially faster than FLASHATTENTION-2 (Dao, 2023) on a single H100 GPU with batch size 32, number of heads 16, head dimension 64, and chunk size 64.
Software Dependencies | No | The paper mentions software such as PyTorch, the Mistral tokenizer, AdamW, and the LM evaluation harness, but does not provide specific version numbers for these dependencies.
Experiment Setup | Yes | We train all models from scratch at two scales: 340M and 1.3B. All models are trained with AdamW (Loshchilov & Hutter, 2018) using a maximum learning rate of 3e-4. The 340M models are trained on 15B tokens with a batch size of 0.5M tokens, while the 1.3B models are trained on 100B tokens with a batch size of 2M tokens. We use a cosine learning rate schedule with a warmup of 0.5B/1B tokens for the 340M/1.3B settings, respectively. The initial and final learning rates are 3e-5. We use a weight decay of 0.01, and gradient clipping of 1.0. (An illustrative optimizer/schedule sketch follows the table.)
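For orientation, the sketch below shows the chunkwise form of gated linear attention for a single head in plain PyTorch, following the recurrence S_t = diag(g_t) S_{t-1} + k_t^T v_t with output o_t = q_t S_t. This is not the paper's Listing 1: the function name and loop structure are our own, and it divides by cumulative gate products directly, which is exactly the numerical issue the paper's two-level (secondary) chunking in log space is designed to avoid.

```python
import torch

def chunked_gla_forward(q, k, v, g, chunk_size=64):
    """Simplified chunkwise gated linear attention (single head, no batching).

    q, k: (T, d_k); v: (T, d_v); g: (T, d_k) forget gates in (0, 1).
    Illustrative only: numerically naive compared to the paper's algorithm.
    """
    T, d_k = q.shape
    d_v = v.shape[1]
    S = torch.zeros(d_k, d_v, dtype=q.dtype, device=q.device)  # running state
    outputs = []
    for start in range(0, T, chunk_size):
        qc, kc, vc, gc = (x[start:start + chunk_size] for x in (q, k, v, g))
        # b[i] = elementwise product of gates from the chunk start up to position i
        b = torch.cumprod(gc, dim=0)                   # (C, d_k)
        q_tilde = qc * b                               # decay folded into queries
        k_tilde = kc / b                               # decay removed from keys
        # inter-chunk part: contribution of the state carried in from earlier chunks
        o_inter = q_tilde @ S                          # (C, d_v)
        # intra-chunk part: causal attention within the chunk
        scores = torch.tril(q_tilde @ k_tilde.transpose(0, 1))  # (C, C)
        o_intra = scores @ vc                          # (C, d_v)
        outputs.append(o_inter + o_intra)
        # state update: decay the old state, then add this chunk's contribution
        S = S * b[-1].unsqueeze(-1) + (kc * (b[-1] / b)).transpose(0, 1) @ vc
    return torch.cat(outputs, dim=0)
```

Materializing only chunk-level states (rather than a per-token state) is what makes the chunkwise form amenable to matmul-heavy, hardware-efficient kernels; the released codebase at the repository above provides the optimized Triton implementation.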
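A minimal sketch of the quoted optimizer and schedule configuration in PyTorch (340M-scale numbers). The placeholder model and the exact shape of the warmup/cosine schedule are assumptions inferred from the stated hyperparameters; the paper's training code may implement the schedule differently.

```python
import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# Hyperparameters quoted from the paper (340M-scale setting).
max_lr, min_lr = 3e-4, 3e-5                   # peak LR and initial/final LR
weight_decay, grad_clip = 0.01, 1.0
tokens_per_batch = 0.5e6                      # 0.5M tokens per batch
total_steps = int(15e9 / tokens_per_batch)    # 15B training tokens
warmup_steps = int(0.5e9 / tokens_per_batch)  # 0.5B warmup tokens

# Placeholder module standing in for the 340M-parameter GLA model.
model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

def lr_lambda(step):
    # Linear warmup from min_lr to max_lr, then cosine decay back to min_lr.
    floor = min_lr / max_lr
    if step < warmup_steps:
        return floor + (1.0 - floor) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return floor + (1.0 - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda)

# In the training loop, after loss.backward():
#   torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
#   optimizer.step(); scheduler.step(); optimizer.zero_grad()
```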