Faster Causal Attention Over Large Sequences Through Sparse Flash Attention

Authors: Matteo Pagliardini, Daniele Paliotta, Martin Jaggi, François Fleuret

NeurIPS 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experimental evaluations show that SCFA can efficiently be used for a variety of sequence modeling tasks, and that our open-source implementation in the Triton language and compiler (Tillet et al., 2019) significantly outperforms Flash Attention as we increase the sparsity and for longer sequences.
Researcher Affiliation | Academia | Matteo Pagliardini* (EPFL, matteo.pagliardini@epfl.ch); Daniele Paliotta* (University of Geneva, daniele.paliotta@unige.ch); Martin Jaggi (EPFL, martin.jaggi@epfl.ch); François Fleuret (University of Geneva, francois.fleuret@unige.ch)
Pseudocode | Yes | In Alg. 1 we detail the core of the QK-sparse algorithm from Section 3.1. This algorithm computes the softmax attention result for only one block of queries, corresponding to one head. In practice, this algorithm would be run in parallel for all blocks of queries for all heads. We find the index of the last relevant tile to compute by iterating over blocks of key indices and comparing them with the largest query index for the current block of queries. This works because qidx and kidx have a monotonic structure thanks to the stable sort used when reshaping the tensors (see the pre-processing code in App. B.2 for more details; an illustrative sketch of this pre-processing is given after the table).
Open Source Code | Yes | The code for all our experiments can be found via the following link: https://github.com/epfml/dynamic-sparse-flash-attention.
Open Datasets | Yes | Datasets. We test our hash-based sparsity scheme on MNIST (Le Cun et al., 1998) for autoregressive image generation, enwik8 (Hutter, 2012), and OpenWebText2 (Gao et al., 2020). (A hedged sketch of one way to compute such hash buckets follows the table.)
Dataset Splits | No | While the paper mentions 'Validation Perplexity' in figures and discussion, it does not explicitly provide specific dataset split information such as percentages, sample counts, or clear references to predefined splits with citations for reproduction.
Hardware Specification | Yes | All of our timing experiments with random tensors are done on NVIDIA A100 GPUs, using bfloat16. For our language modeling tasks on OpenWebText2, we trained using data parallelism on two or three A100s for experiments with sequence lengths of 8192 and 16384, respectively. Comparisons with the Reformer are performed on a single A100 or a single NVIDIA RTX 4090 GPU.
Software Dependencies | Yes | Experimental evaluations show that SCFA can efficiently be used for a variety of sequence modeling tasks, and that our open-source implementation in the Triton language and compiler (Tillet et al., 2019)... Flash Attention has already reached wide adoption, as it is now available directly in PyTorch as of version 2.0. (A minimal call to the PyTorch 2.0 fused attention API is sketched after the table.)
Experiment Setup | Yes | For our language modeling experiments on OpenWebText2, we use a base autoregressive transformer architecture with 12 layers, a hidden size of 768, and 12 heads of 64 dimensions each. For experiments on sequence length T = 8192, we use a batch size of 96 = 4 × 8 × 2 (batch size 4 with 8 accumulation steps and data parallelism over 2 nodes). When T = 16384, we use a batch size of 30 = 2 × 5 × 3. The resulting models have around 122M parameters. The goal not being to outperform state-of-the-art perplexity, we train for 15k iterations. The attention modules use either Flash Attention for the baselines or one of our sparse kernels for our methods. ... Weight decay: 0.1; depth (number of transformer blocks): 12; number of heads: 12; dropout: 0.0; learning rate: 0.001; warmup: 2% of iterations; cosine learning rate scheduler; Adam beta1: 0.9; Adam beta2: 0.95; hidden dimension: 768; dimensions per head: 64. (These hyperparameters are gathered into a single illustrative config after the table.)
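
The Pseudocode entry refers to a pre-processing step that compacts queries and keys with a stable sort so that their original indices qidx and kidx stay monotonically increasing. The PyTorch sketch below is not the authors' Triton kernel or their App. B.2 code; the names q, k, keep_q, keep_k and the block size are illustrative assumptions. It only shows why the stable sort makes the "last relevant tile" search a simple monotone scan.

    import torch

    def qk_sparse_preprocess(q, k, keep_q, keep_k):
        # q, k: (T, d) tensors for one head; keep_q, keep_k: boolean masks of shape (T,).
        # Dropped entries sort to the back; the stable sort keeps the kept entries
        # in their original order, so qidx and kidx are monotonically increasing.
        q_order = torch.sort((~keep_q).to(torch.int8), stable=True).indices
        k_order = torch.sort((~keep_k).to(torch.int8), stable=True).indices
        nq, nk = int(keep_q.sum()), int(keep_k.sum())
        q_sparse, qidx = q[q_order][:nq], q_order[:nq]
        k_sparse, kidx = k[k_order][:nk], k_order[:nk]
        return q_sparse, k_sparse, qidx, kidx

    def last_relevant_key_tile(qidx, kidx, q_block, BLOCK=128):
        # Causal masking: a key tile is relevant only if its first (smallest) original
        # key index does not exceed the largest original query index in this query block.
        max_q = int(qidx[q_block * BLOCK:(q_block + 1) * BLOCK].max())
        n_tiles = (kidx.numel() + BLOCK - 1) // BLOCK
        last = 0
        for t in range(n_tiles):
            if int(kidx[t * BLOCK]) <= max_q:
                last = t  # kidx is monotone, so every earlier tile is also relevant
            else:
                break
        return last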
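
The Open Datasets entry mentions a hash-based sparsity scheme. One standard way to produce the required buckets is Reformer-style angular LSH with random projections; the sketch below is only an illustration under that assumption (n_buckets and the projection are not taken from the paper), not the authors' implementation.

    import torch

    def lsh_buckets(x, n_buckets, generator=None):
        # x: (T, d) queries or keys for one head. Project onto n_buckets // 2 random
        # directions and take the argmax over the projections and their negations,
        # yielding one bucket id per token.
        proj = torch.randn(x.shape[-1], n_buckets // 2,
                           device=x.device, dtype=x.dtype, generator=generator)
        h = x @ proj
        return torch.argmax(torch.cat([h, -h], dim=-1), dim=-1)  # (T,) bucket ids

Queries and keys hashed with the same projection can then be grouped by bucket id (for instance with the same kind of stable sort as above), so that attention is only computed between matching buckets.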
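
Regarding the Software Dependencies entry, PyTorch 2.0 exposes fused (Flash-Attention-style) kernels through torch.nn.functional.scaled_dot_product_attention. The call below is a minimal dense causal baseline with illustrative shapes (batch 4, 12 heads of dimension 64, sequence length 8192, bfloat16 on an A100-class GPU); it uses only the stock PyTorch API, not the paper's Triton kernels.

    import torch
    import torch.nn.functional as F

    q = torch.randn(4, 12, 8192, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(4, 12, 8192, 64, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(4, 12, 8192, 64, device="cuda", dtype=torch.bfloat16)

    # Dense causal attention through PyTorch's fused kernel: the kind of baseline
    # the paper's sparse kernels are compared against.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)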
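
Finally, the hyperparameters quoted in the Experiment Setup entry can be gathered into a single configuration. The dictionary below only restates the reported values for the T = 8192 OpenWebText2 run; the key names are illustrative and do not mirror the repository's config schema.

    # Hyperparameters reported for the T = 8192 OpenWebText2 experiments.
    owt2_8192_config = {
        "n_layers": 12,            # depth: number of transformer blocks
        "n_heads": 12,
        "d_model": 768,            # hidden dimension
        "d_head": 64,              # dimensions per head
        "dropout": 0.0,
        "weight_decay": 0.1,
        "learning_rate": 1e-3,
        "lr_schedule": "cosine",
        "warmup_fraction": 0.02,   # 2% of iterations
        "adam_beta1": 0.9,
        "adam_beta2": 0.95,
        "sequence_length": 8192,
        "batch_size": 96,          # 4 micro-batch x 8 accumulation x 2 data-parallel nodes
        "iterations": 15_000,      # ~122M-parameter model
    }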