FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
Authors: Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate that our method, FLASHATTENTION-3, achieves speedup on H100 GPUs by 1.5-2.0× with BF16 reaching up to 840 TFLOPs/s (85% utilization), and with FP8 reaching 1.3 PFLOPs/s. We validate that FP8 FLASHATTENTION-3 achieves 2.6× lower numerical error than a baseline FP8 attention. |
| Researcher Affiliation | Collaboration | Jay Shah (1), Ganesh Bikshandi (1), Ying Zhang (2), Vijay Thakkar (3,4), Pradeep Ramani (3), Tri Dao (5,6); (1) Colfax Research, (2) Meta, (3) NVIDIA, (4) Georgia Institute of Technology, (5) Princeton University, (6) Together AI |
| Pseudocode | Yes | Algorithm 1: FLASHATTENTION-3 forward pass without intra-consumer overlapping (CTA view) |
| Open Source Code | Yes | We open-source FLASHATTENTION-3 with a permissive license and plan to integrate it with PyTorch to benefit the largest number of researchers and developers. FLASHATTENTION-3 is available at https://github.com/Dao-AILab/flash-attention |
| Open Datasets | No | The paper generates its own data for numerical error validation: 'we generate the entries of Q, K, V with the following distribution: N(0, 1) + N(0, 100) · Bernoulli(0.001)'. It does not use a publicly available or open dataset with concrete access information for its main experiments. (A data-generation sketch follows the table.) |
| Dataset Splits | No | The paper discusses empirical validation of its method but does not explicitly provide training/validation/test dataset splits needed to reproduce data partitioning for the experiments. |
| Hardware Specification | Yes | We measure the runtime of different attention methods on an H100 80GB SXM5 GPU for different settings (without / with causal mask, head dimension 64 or 128) for BF16 inputs. We benchmark the speed on an H100 80GB SXM5 (700W). |
| Software Dependencies | Yes | Specifically, we use: cuDNN 9.5.0.50, CUTLASS 3.6, FLASHATTENTION 2.6.3, PyTorch 2.5.0. (A version-check sketch follows the table.) |
| Experiment Setup | Yes | We vary the sequence length as 512, 1k, ..., 16k, and set the batch size so that the total number of tokens is 16k. We set the hidden dimension to 2048, and the head dimension to either 64, 128, or 256 (i.e., 32 heads, 16 heads, or 8 heads). To calculate the FLOPs of the forward pass, we use: 4 × seqlen² × head dimension × number of heads. (A FLOPs-accounting sketch follows the table.) |
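
The Open Datasets row quotes the entry-wise distribution used for the numerical-error validation. Below is a minimal PyTorch sketch of that distribution: each entry is N(0, 1) plus an N(0, 100) outlier (standard deviation 10) gated by a Bernoulli(0.001) mask. The function name and the tensor shapes are illustrative assumptions; the paper only specifies the distribution itself.

```python
import torch

def make_outlier_tensor(batch, seqlen, nheads, headdim, device="cuda"):
    """Entries ~ N(0, 1) + N(0, 100) * Bernoulli(0.001): mostly standard-normal
    values with rare large-magnitude outliers in roughly 0.1% of positions."""
    shape = (batch, seqlen, nheads, headdim)
    base = torch.randn(shape, device=device)               # N(0, 1)
    outliers = 10.0 * torch.randn(shape, device=device)    # N(0, 100), i.e. std = 10
    mask = (torch.rand(shape, device=device) < 0.001).to(base.dtype)  # Bernoulli(0.001)
    return base + outliers * mask

# Hypothetical shapes for illustration; the paper does not fix them in this quote.
q = make_outlier_tensor(4, 4096, 16, 128)
k = make_outlier_tensor(4, 4096, 16, 128)
v = make_outlier_tensor(4, 4096, 16, 128)
```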
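For the Software Dependencies row, a small version check (a sketch, not from the paper) can confirm the pinned stack before benchmarking: PyTorch 2.5.0, cuDNN 9.5.0.50, and flash-attn 2.6.3 as the FlashAttention-2 baseline. CUTLASS 3.6 is a build-time dependency of the kernels and is not queryable from Python here.

```python
import torch
import flash_attn  # pip install flash-attn

print("PyTorch:", torch.__version__)                 # expected: 2.5.0
print("cuDNN:", torch.backends.cudnn.version())      # expected: a 9.5.0 build
print("flash-attn:", flash_attn.__version__)         # expected: 2.6.3
print("GPU:", torch.cuda.get_device_name(0))         # expected: H100 80GB SXM5
```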
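Finally, a sketch of the bookkeeping described in the Experiment Setup row: batch size is chosen so that batch × seqlen stays at 16k tokens, the head count follows from the 2048 hidden dimension, and forward-pass FLOPs are counted as 4 × seqlen² × head dimension × number of heads. The sweep values and the formula are quoted from the paper; the helper name and the per-batch scaling are assumptions for illustration, not the paper's exact harness.

```python
TOTAL_TOKENS = 16 * 1024   # batch size chosen so batch * seqlen = 16k tokens
HIDDEN_DIM = 2048

def attention_fwd_flops(seqlen, headdim, nheads):
    # Quoted formula: 4 * seqlen^2 * head dimension * number of heads (per sequence)
    return 4 * seqlen**2 * headdim * nheads

for seqlen in (512, 1024, 2048, 4096, 8192, 16384):
    batch = TOTAL_TOKENS // seqlen
    for headdim in (64, 128, 256):
        nheads = HIDDEN_DIM // headdim          # 32, 16, or 8 heads
        flops = batch * attention_fwd_flops(seqlen, headdim, nheads)  # assumed: scale by batch
        # Reported throughput would be flops / measured_runtime_seconds / 1e12 (TFLOPs/s),
        # with runtime measured on the H100 80GB SXM5 described above.
        print(f"seqlen={seqlen:>5} batch={batch:>3} headdim={headdim:>3} "
              f"nheads={nheads:>2} fwd TFLOPs={flops / 1e12:.1f}")
```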