Parallelizing Linear Transformers with the Delta Rule over Sequence Length
Authors: Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, Yoon Kim
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We train a 1.3B model for 100B tokens and find that it outperforms recent linear-time baselines such as Mamba [31] and GLA [124] in terms of perplexity and zero-shot performance on downstream tasks. We also experiment with two hybrid models which combine DeltaNet layers with (1) sliding-window attention layers every other layer or (2) two global attention layers, and find that these hybrids outperform strong transformer baselines. |
| Researcher Affiliation | Collaboration | Massachusetts Institute of Technology; Soochow University; MIT-IBM Watson AI Lab |
| Pseudocode | Yes | Listing 1: PyTorch-like code snippet of the forward pass of our chunkwise algorithm for training DeltaNet. We omit the dimensions of batch size and number of heads for clarity. (A minimal recurrent reference sketch of the delta rule follows the table.) |
| Open Source Code | Yes | The parallel DeltaNet layer is made available as part of the FlashLinearAttention library [124, 123]: https://github.com/fla-org/flash-linear-attention |
| Open Datasets | Yes | We evaluate on WikiText perplexity and zero-shot common sense reasoning tasks, including LAMBADA [LMB.; 77], PIQA [12], HellaSwag [Hella.; 127], WinoGrande [Wino.; 99], ARC-easy (ARC-e) and ARC-challenge (ARC-c) [16]... All models are trained on the same subset of the SlimPajama dataset with the Mistral tokenizer. |
| Dataset Splits | No | The paper specifies training token counts and batch sizes (e.g., 'The 340M models are trained using 15 billion tokens and a batch size of 0.5M tokens'), but does not explicitly provide percentages or counts for training, validation, and test splits of the main datasets. It mentions evaluating on specific downstream tasks, which act as test sets, but no explicit validation split from the training corpus. |
| Hardware Specification | Yes | We used 8 H100 GPUs for 340M and 1.3B language modeling experiments. |
| Software Dependencies | No | The paper mentions software like PyTorch and Triton [112], but does not provide specific version numbers for these dependencies. |
| Experiment Setup | Yes | We used 8 H100 GPUs for 340M and 1.3B language modeling experiments. Each model uses AdamW for optimization, with a peak learning rate of 3×10⁻⁴. The 340M models are trained using 15 billion tokens and a batch size of 0.5M tokens, while the 1.3B models are trained with 100 billion tokens and a batch size of 2M tokens. We use a cosine learning rate schedule, starting with a warm-up phase of 0.5 billion tokens for the 340M models and 1 billion tokens for the 1.3B models. Both configurations have initial and final learning rates set at 3×10⁻⁵. We apply a weight decay of 0.01 and use gradient clipping at a maximum of 1.0. The head dimension of DeltaNet is set to 128, and the kernel size for convolution layers is set at 4. (An optimizer/schedule sketch follows the table.) |
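To make the algorithm referenced in the Pseudocode row concrete, here is a minimal sequential reference of the delta-rule state update that the paper's chunkwise algorithm parallelizes over the sequence. This is not the paper's Listing 1 (which processes the sequence in chunks with matrix multiplications); the single-head, unbatched layout and tensor shapes are simplifying assumptions for illustration.

```python
import torch

def delta_rule_recurrent(q, k, v, beta):
    """Sequential reference for the delta-rule recurrence (single head, no batch).

    q, k: (L, d_k); v: (L, d_v); beta: (L,).
    State update: S_t = S_{t-1} + beta_t * (v_t - S_{t-1} k_t) k_t^T
    Output:       o_t = S_t q_t
    """
    L, d_k = k.shape
    d_v = v.shape[-1]
    S = torch.zeros(d_v, d_k, dtype=v.dtype)
    outputs = []
    for t in range(L):
        v_old = S @ k[t]                                   # S_{t-1} k_t: what the memory currently stores for key k_t
        S = S + beta[t] * torch.outer(v[t] - v_old, k[t])  # delta-rule update toward the new value v_t
        outputs.append(S @ q[t])                           # read out with the query
    return torch.stack(outputs)
```

The chunkwise training algorithm computes the same outputs block by block so that most work becomes matrix multiplications; a sequential reference like the one above is mainly useful as a correctness check against the chunkwise kernel.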
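The optimization settings quoted in the Experiment Setup row can likewise be sketched as a standard PyTorch AdamW plus warm-up/cosine schedule. The step counts below assume the 340M configuration (15B training tokens, 0.5M-token batches, 0.5B warm-up tokens), and `model` is a stand-in module; treat this as an illustrative sketch under those assumptions, not the authors' training script.

```python
import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(16, 16)          # stand-in for the 340M DeltaNet language model

peak_lr, init_lr, final_lr = 3e-4, 3e-5, 3e-5
warmup_steps = int(0.5e9 // 0.5e6)       # 0.5B warm-up tokens / 0.5M-token batches = 1,000 steps
total_steps = int(15e9 // 0.5e6)         # 15B training tokens / 0.5M-token batches = 30,000 steps

optimizer = AdamW(model.parameters(), lr=peak_lr, weight_decay=0.01)

def lr_multiplier(step: int) -> float:
    # Linear warm-up from init_lr to peak_lr, then cosine decay from peak_lr to final_lr.
    if step < warmup_steps:
        lr = init_lr + (peak_lr - init_lr) * step / max(1, warmup_steps)
    else:
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        lr = final_lr + (peak_lr - final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr / peak_lr                  # LambdaLR expects a multiplier of the base lr

scheduler = LambdaLR(optimizer, lr_multiplier)

# Per training step (after loss.backward()):
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#   optimizer.step(); scheduler.step(); optimizer.zero_grad()
```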