FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

Authors: Tri Dao

Venue: ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We empirically validate that when used end-to-end to train GPT-style models, FLASHATTENTION-2 reaches training speed of up to 225 TFLOPs/s per A100 GPU (72% model FLOPs utilization). In Section 4, we empirically validate that FLASHATTENTION-2 yields significant speedup compared to even FLASHATTENTION. Benchmarks on different settings (with or without causal mask, different head dimensions) show that FLASHATTENTION-2 achieves around 2x speedup over FLASHATTENTION, reaching up to 73% of the theoretical max throughput in the forward pass, and up to 63% of the theoretical max throughput in the backward pass. (A worked check of the utilization figure follows the table.)
Researcher Affiliation | Academia | (1) Department of Computer Science, Princeton University; (2) Department of Computer Science, Stanford University; tridao@princeton.edu
Pseudocode | Yes | Algorithm 1 (FLASHATTENTION-2 forward pass) and Algorithm 2 (FLASHATTENTION-2 backward pass) are provided in the paper. (An illustrative sketch of the tiled forward recurrence follows the table.)
Open Source Code | No | The paper thanks individuals for implementing versions of FLASHATTENTION in Triton and xformers, but it does not explicitly state that the authors are releasing their own FLASHATTENTION-2 code, nor does it provide a direct link to the source code for their implementation.
Open Datasets | No | The paper mentions training "GPT-style models" but does not specify the particular public dataset(s) used for this training, nor does it provide access information or formal citations for any such dataset.
Dataset Splits | No | The paper mentions training and benchmarking but does not provide specific details on dataset splits for training, validation, or testing (e.g., percentages, sample counts, or references to predefined splits).
Hardware Specification | Yes | We measure the runtime of different attention methods on an A100 80GB SXM4 GPU ... Just running the same implementation on H100 GPUs ... H100 80GB SXM5. (A timing sketch follows the table.)
Software Dependencies | No | The paper mentions several software components, such as PyTorch, Triton, the xformers library, and CUTLASS 3.x, but it does not provide specific version numbers for its own implementation's software dependencies (e.g., the PyTorch version).
Experiment Setup | Yes | Benchmark setting: we vary the sequence length from 512, 1k, ..., 16k, and set batch size so that the total number of tokens is 16k. We set hidden dimension to 2048, and head dimension to be either 64 or 128 (i.e., 32 heads or 16 heads). (A configuration sketch follows the table.)
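
A quick arithmetic check of the utilization figure quoted in the Research Type row: model FLOPs utilization (MFU) is achieved throughput divided by the hardware peak. Assuming the A100's published 312 TFLOPs/s BF16 tensor-core peak (a spec not stated in the excerpt above), 225 TFLOPs/s works out to roughly 72%:

```python
# Worked MFU check for the figure quoted in the "Research Type" row.
# The 312 TFLOPs/s BF16 dense tensor-core peak for the A100 SXM is an
# assumption taken from NVIDIA's spec sheet, not from the excerpt above.

A100_PEAK_BF16_TFLOPS = 312.0   # assumed hardware peak
achieved_tflops = 225.0         # end-to-end GPT-style training throughput per GPU

mfu = achieved_tflops / A100_PEAK_BF16_TFLOPS
print(f"MFU = {mfu:.1%}")       # ~72.1%, consistent with the quoted 72%
```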
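The Pseudocode row refers to Algorithm 1, the FLASHATTENTION-2 forward pass. The NumPy sketch below illustrates only the tiled online-softmax recurrence that the algorithm is organized around; the block size, variable names, and end-of-loop normalization are illustrative choices, not the paper's CUDA kernel.

```python
import numpy as np

def tiled_attention_forward(Q, K, V, block_k=128):
    """Single-head attention via the online-softmax recurrence that
    Algorithm 1 is built around: process K/V in blocks while keeping a
    running row-wise max and softmax denominator, normalizing once at the end."""
    seqlen_q, d = Q.shape
    scale = 1.0 / np.sqrt(d)

    out = np.zeros_like(Q)                       # unnormalized output accumulator
    row_max = np.full(seqlen_q, -np.inf)         # running row-wise max of scores
    row_sum = np.zeros(seqlen_q)                 # running softmax denominator

    for start in range(0, K.shape[0], block_k):
        Kj = K[start:start + block_k]
        Vj = V[start:start + block_k]

        S = (Q @ Kj.T) * scale                   # scores for this K/V block
        new_max = np.maximum(row_max, S.max(axis=1))
        P = np.exp(S - new_max[:, None])         # block-local softmax numerator
        correction = np.exp(row_max - new_max)   # rescale previously accumulated stats

        row_sum = row_sum * correction + P.sum(axis=1)
        out = out * correction[:, None] + P @ Vj
        row_max = new_max

    return out / row_sum[:, None]                # normalize once after the loop

# Sanity check against the naive softmax(Q K^T / sqrt(d)) V
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
S = (Q @ K.T) / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(tiled_attention_forward(Q, K, V), ref, atol=1e-8)
```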
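The Hardware Specification row quotes runtime measurements on A100 and H100 GPUs. The sketch below shows one common way to time an attention kernel with CUDA events in PyTorch; the choice of torch.nn.functional.scaled_dot_product_attention as the kernel under test and the warmup/iteration counts are assumptions for illustration, not the paper's benchmark harness.

```python
import torch  # requires PyTorch 2.x for scaled_dot_product_attention

def time_attention_ms(batch, heads, seqlen, head_dim, iters=30, warmup=10):
    """Median wall-clock time (ms) of one attention forward pass on the current GPU."""
    q, k, v = (torch.randn(batch, heads, seqlen, head_dim,
                           device="cuda", dtype=torch.float16) for _ in range(3))

    for _ in range(warmup):                      # warm up clocks and caches
        torch.nn.functional.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()

    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.nn.functional.scaled_dot_product_attention(q, k, v)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))    # elapsed time in milliseconds
    return sorted(times)[len(times) // 2]

if torch.cuda.is_available():
    print(time_attention_ms(batch=8, heads=32, seqlen=2048, head_dim=64))
```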
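The Experiment Setup row fixes the total token count at 16k and the hidden dimension at 2048, so the batch size and head count follow directly from the chosen sequence length and head dimension. A minimal enumeration of that benchmark grid (the tuple layout is an assumption, not the paper's benchmark script):

```python
# Enumerate the benchmark grid described in the "Experiment Setup" row:
# total tokens fixed at 16k, hidden dimension 2048, head dimension 64 or 128.

TOTAL_TOKENS = 16 * 1024
HIDDEN_DIM = 2048

configs = []
for seqlen in (512, 1024, 2048, 4096, 8192, 16384):
    batch_size = TOTAL_TOKENS // seqlen          # keeps tokens per step constant
    for head_dim in (64, 128):
        n_heads = HIDDEN_DIM // head_dim         # 32 heads or 16 heads
        configs.append((batch_size, seqlen, n_heads, head_dim))

for batch_size, seqlen, n_heads, head_dim in configs:
    print(f"batch={batch_size:3d}  seqlen={seqlen:5d}  heads={n_heads:2d}  head_dim={head_dim}")
```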