Gated Linear Attention Transformers with Hardware-Efficient Training
Authors: Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, Yoon Kim
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our main experiments are on language modeling, where we study whether GLA can perform competitively against (i) a strong Transformer baseline with modern architectural recipes and (ii) recent linear-time models. We use the SlimPajama dataset (Soboleva et al., 2023) and tokenize it using the Mistral tokenizer (Jiang et al., 2023). The original dataset contains 627B tokens; we use a 100B subset. |
| Researcher Affiliation | Collaboration | 1Massachusetts Institute of Technology 2MIT-IBM Watson AI Lab. Correspondence to: Songlin Yang <yangsl66@mit.edu>, Bailin Wang <bailinw@mit.edu>. |
| Pseudocode | Yes | Algorithm 1 FLASHLINEARATTENTION: Forward Pass; Algorithm 2 FLASHLINEARATTENTION: Backward Pass; Listing 1: PyTorch-like code snippet of our two-level chunking algorithm for training GLA. A minimal chunk-wise linear-attention sketch is given after this table. |
| Open Source Code | Yes | https://github.com/sustcsonglin/flash-linear-attention |
| Open Datasets | Yes | We use the SlimPajama dataset (Soboleva et al., 2023) and tokenize it using the Mistral tokenizer (Jiang et al., 2023). A data-loading sketch is given after this table. |
| Dataset Splits | No | The paper mentions training on specific token counts and batch sizes, but does not explicitly provide training/validation/test dataset splits (e.g., percentages or sample counts for each split). |
| Hardware Specification | Yes | Figure 2 shows the speed and memory footprint of our implementation. Both versions of FLASHLINEARATTENTION are substantially faster than FLASHATTENTION-2 (Dao, 2023) on a single H100 GPU with batch size 32, number of heads 16, head dimension 64, and chunk size 64. |
| Software Dependencies | No | The paper mentions software such as PyTorch, the Mistral tokenizer, AdamW, and the LM evaluation harness, but does not provide specific version numbers for these dependencies. |
| Experiment Setup | Yes | We train all models from scratch at two scales: 340M and 1.3B. All models are trained with AdamW (Loshchilov & Hutter, 2018) using a maximum learning rate of 3e-4. The 340M models are trained on 15B tokens with a batch size of 0.5M tokens, while the 1.3B models are trained on 100B tokens with a batch size of 2M tokens. We use a cosine learning rate schedule with a warmup of 0.5B/1B tokens for the 340M/1.3B settings, respectively. The initial and final learning rates are 3e-5. We use a weight decay of 0.01, and gradient clipping of 1.0. An optimizer/scheduler sketch is given after this table. |
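
Algorithms 1 and 2 and Listing 1 are given in the paper and repository. As a rough illustration of the chunk-wise form they build on, the sketch below implements plain (ungated) causal linear attention with a single level of chunking in PyTorch. It is a didactic sketch, not the authors' FLASHLINEARATTENTION kernel: GLA additionally applies a data-dependent gate that decays the running state between steps, and the released implementation adds secondary-level chunking and fused Triton kernels.

```python
import torch

def chunked_linear_attention(q, k, v, chunk_size=64):
    """Chunk-wise parallel form of causal linear attention (no gating).

    q, k, v: (batch, seq_len, d); seq_len must be a multiple of chunk_size.
    Equivalent to the step-wise recurrence S_t = S_{t-1} + k_t^T v_t, o_t = q_t S_t.
    """
    b, t, d = q.shape
    n = t // chunk_size
    q = q.view(b, n, chunk_size, d)
    k = k.view(b, n, chunk_size, d)
    v = v.view(b, n, chunk_size, d)

    # Causal mask applied within each chunk.
    mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device))

    S = torch.zeros(b, d, d, device=q.device, dtype=q.dtype)  # running key-value state
    out = []
    for i in range(n):
        qi, ki, vi = q[:, i], k[:, i], v[:, i]                 # (b, C, d)
        inter = qi @ S                                         # contribution from all previous chunks
        attn = (qi @ ki.transpose(-1, -2)).masked_fill(~mask, 0.0)
        intra = attn @ vi                                      # causal intra-chunk contribution
        out.append(inter + intra)
        S = S + ki.transpose(-1, -2) @ vi                      # fold this chunk into the state
    return torch.cat(out, dim=1)
```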
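
Both the dataset and the tokenizer are publicly released, so data preparation can be sketched as below. The Hugging Face Hub identifiers are assumptions (the paper names SlimPajama and the Mistral tokenizer but not exact repository IDs), and the snippet does not reproduce the authors' selection of the 100B-token subset.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Assumed Hub IDs; access to the Mistral tokenizer may require accepting its license.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
dataset = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)

# Tokenize on the fly; packing into fixed-length training sequences is omitted here.
tokenized = dataset.map(lambda ex: tokenizer(ex["text"]))
```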
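
The quoted hyperparameters map onto a standard PyTorch optimizer/scheduler configuration. The sketch below is an assumption about how the schedule is wired up (the paper gives the hyperparameter values, not the implementation); warmup is taken to be linear from the initial learning rate of 3e-5, since only the initial and final values are stated.

```python
import math
import torch

def make_optimizer_and_scheduler(model, total_steps, warmup_steps,
                                 max_lr=3e-4, min_lr=3e-5, weight_decay=0.01):
    """AdamW plus warmup/cosine schedule matching the quoted hyperparameters.

    Step counts follow from the token budgets, e.g. the 1.3B setting gives
    100B tokens / 2M-token batches = 50,000 steps with a 500-step (1B-token) warmup.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

    def lr_lambda(step):
        if step < warmup_steps:  # linear warmup from min_lr up to max_lr
            return (min_lr + (max_lr - min_lr) * step / warmup_steps) / max_lr
        # Cosine decay from max_lr back down to min_lr.
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return (min_lr + (max_lr - min_lr) * cosine) / max_lr

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler

# In the training loop, gradients are clipped before each optimizer step:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```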