FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

Authors: Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We demonstrate that our method, FLASHATTENTION-3, achieves speedup on H100 GPUs by 1.5-2.0× with BF16 reaching up to 840 TFLOPs/s (85% utilization), and with FP8 reaching 1.3 PFLOPs/s. We validate that FP8 FLASHATTENTION-3 achieves 2.6× lower numerical error than a baseline FP8 attention.
Researcher Affiliation | Collaboration | Jay Shah (Colfax Research), Ganesh Bikshandi (Colfax Research), Ying Zhang (Meta), Vijay Thakkar (NVIDIA, Georgia Institute of Technology), Pradeep Ramani (NVIDIA), Tri Dao (Princeton University, Together AI)
Pseudocode | Yes | Algorithm 1: FLASHATTENTION-3 forward pass without intra-consumer overlapping (CTA view)
Open Source Code | Yes | We open-source FLASHATTENTION-3 with a permissive license and plan to integrate it with PyTorch to benefit the largest number of researchers and developers. FLASHATTENTION-3 is available at https://github.com/Dao-AILab/flash-attention. (See the usage sketch after the table.)
Open Datasets | No | The paper generates its own data for numerical error validation: "we generate the entries of Q, K, V with the following distribution: N(0,1) + N(0,100) · Bernoulli(0.001)". It does not use a publicly available or open dataset with concrete access information for its main experiments. (See the data-generation sketch after the table.)
Dataset Splits | No | The paper discusses empirical validation of its method but does not provide the training/validation/test dataset splits needed to reproduce data partitioning for the experiments.
Hardware Specification | Yes | We measure the runtime of different attention methods on an H100 80GB SXM5 GPU for different settings (without/with causal mask, head dimension 64 or 128) for BF16 inputs. We benchmark the speed on an H100 80GB SXM5 (700W).
Software Dependencies | Yes | Specifically, we use: cuDNN 9.5.0.50, CUTLASS 3.6, FLASHATTENTION 2.6.3, PyTorch 2.5.0. (See the environment-check sketch after the table.)
Experiment Setup | Yes | We vary the sequence length as 512, 1k, ..., 16k, and set the batch size so that the total number of tokens is 16k. We set the hidden dimension to 2048, and the head dimension to be either 64, 128, or 256 (i.e., 32 heads, 16 heads, or 8 heads). To calculate the FLOPs of the forward pass, we use: 4 · seqlen² · head dimension · number of heads. (See the benchmark-grid sketch after the table.)
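
Usage sketch for the Open Source Code row: the linked repository ships the kernels as the flash-attn Python package, and the call below uses that package's flash_attn_func interface. Treating this interface as the entry point for the FlashAttention-3 (Hopper) kernels is an assumption about packaging, not a detail quoted in the paper.

```python
# Minimal sketch, assuming the flash-attn package from the linked repo is
# installed and a CUDA GPU is available. flash_attn_func expects tensors of
# shape (batch, seqlen, heads, head_dim) in fp16/bf16 on the GPU.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # output has the same shape as q
```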
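The entry distribution quoted in the Open Datasets row, N(0,1) + N(0,100) · Bernoulli(0.001), mixes standard-normal entries with rare large outliers. A minimal sketch, assuming N(0,100) denotes variance 100 (standard deviation 10) and using illustrative tensor shapes not taken from the paper:

```python
import torch

def outlier_entries(*shape, device="cpu", dtype=torch.float32):
    """Draw entries from N(0,1) + N(0,100) * Bernoulli(0.001): mostly unit-normal
    values, with ~0.1% of positions perturbed by a large (std-10) outlier."""
    base = torch.randn(*shape, device=device, dtype=dtype)            # N(0, 1)
    outlier = torch.randn(*shape, device=device, dtype=dtype) * 10.0  # N(0, 100), i.e. std 10
    mask = (torch.rand(*shape, device=device) < 0.001).to(dtype)      # Bernoulli(0.001)
    return base + outlier * mask

# Illustrative shapes only: (batch, seqlen, heads, head_dim)
q = outlier_entries(4, 1024, 16, 64)
k = outlier_entries(4, 1024, 16, 64)
v = outlier_entries(4, 1024, 16, 64)
```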
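A quick runtime check against the versions in the Software Dependencies row. This is a sketch under the assumption that PyTorch and the flash-attn package are importable; CUTLASS 3.6 is a build-time, header-only dependency and is not queried here.

```python
import torch
import flash_attn  # baseline FlashAttention package, version 2.6.3 in the paper's setup

print("PyTorch:", torch.__version__)             # paper uses 2.5.0
print("cuDNN:", torch.backends.cudnn.version())  # paper uses 9.5.0.50 (reported as an int, e.g. 90500)
print("flash-attn:", flash_attn.__version__)     # paper uses 2.6.3
```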
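The Experiment Setup row pins down a benchmark grid and a FLOPs formula. A minimal sketch enumerating that grid; the quoted formula is taken to be per batch element, and scaling the total by batch size is an assumption about how the counts are aggregated.

```python
# Sketch of the benchmark grid: sequence lengths 512..16k with batch size chosen
# so batch * seqlen = 16k tokens, hidden dim 2048, and head dims 64/128/256
# (i.e. 32/16/8 heads). Forward FLOPs per batch element follow the quoted
# formula 4 * seqlen^2 * head_dim * num_heads.
TOTAL_TOKENS = 16 * 1024
HIDDEN_DIM = 2048

for head_dim in (64, 128, 256):
    num_heads = HIDDEN_DIM // head_dim                  # 32, 16, or 8
    for seqlen in (512, 1024, 2048, 4096, 8192, 16384):
        batch = TOTAL_TOKENS // seqlen
        flops_per_elem = 4 * seqlen**2 * head_dim * num_heads
        total_flops = flops_per_elem * batch            # assumption: totals scale with batch
        print(f"head_dim={head_dim:3d} seqlen={seqlen:5d} batch={batch:3d} "
              f"fwd FLOPs={total_flops:.3e}")
```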