SEA: Sparse Linear Attention with Estimated Attention Mask

Authors: Heejun Lee, Jina Kim, Jeffrey Willette, Sung Ju Hwang

ICLR 2024

Reproducibility assessment. Each entry below lists the reproducibility variable, the result, and the supporting LLM response (a quotation from the paper or a short rationale).

Research Type: Experimental
We demonstrate the efficiency and effectiveness of our method through empirical evaluations on natural language processing tasks such as text classification on GLUE and language modeling on Wikitext-2, where we maintain competitive performance with the vanilla transformer baseline, as shown in Figs. 5a and 7a.

Researcher Affiliation: Collaboration
Heejun Lee (School of Computing, KAIST), Jina Kim (School of Computing, KAIST), Jeffrey Willette (Graduate School of AI, KAIST), Sung Ju Hwang (Graduate School of AI, KAIST; DeepAuto.ai), Daejeon, South Korea. Contact: {ainl,jinakim,jwillette,sjhwang}@kaist.ac.kr

Pseudocode: No
The paper describes the method conceptually and through diagrams, but does not include any structured pseudocode or algorithm blocks.

Open Source Code: Yes
We propose and provide code for a novel CSR tensor operation, called Flat CSR, which is capable of handling non-contiguous flattening operations within a GPU kernel. We implement SEA attention modules, including self-attention and causal attention for the transformer encoder and decoder. Users can import the src.models.perlin_attention module to construct a custom transformer model. (A minimal illustration of the standard CSR layout that Flat CSR extends appears after this list.)

Open Datasets: Yes
We validate SEA by applying it to BERT for text classification (Devlin et al., 2019; Wang et al., 2019) and to OPT for causal language modeling (Zhang et al., 2022; Merity et al., 2017). Datasets used: the GLUE benchmark (Wang et al., 2019), Wikitext-2 (Merity et al., 2017), MNLI (Williams et al., 2018), and OpenWebText (Gokaslan et al., 2019). (A minimal dataset-loading sketch appears after this list.)

Dataset Splits: No
The paper mentions validation curves and training on specific datasets (MNLI, CoLA, MRPC) for a given number of epochs, but does not explicitly state the dataset splits (e.g., percentages or sample counts for the training, validation, and test sets).

Hardware Specification: Yes
Our test environment is built with a Ryzen 3950X, an RTX 2080 Ti on PCIe 3.0 x8, 64 GB of DDR4-2400 memory, and Ubuntu 22.04.

Software Dependencies: Yes
The versions of third-party libraries, including PyTorch and Triton, are listed in the supplementary file requirements.txt. A Docker environment for reproducing the experimental setup is also provided via the supplementary Dockerfile.

Experiment Setup: Yes
Batch sizes for our experiments outlined in Section 4 can be seen in Table A.3. We define different learning rates for the original parameters and the SEA attention parameters: we use a learning rate of 1e-5 for the original parameters and 1e-4 for the SEA attention parameters. For OPT models, we use a learning rate of 2e-6 for the original parameters and 1e-4 for the SEA attention parameters. Weights for loss scaling outlined in Section 3.2 can be seen in Table A.4. ... We train all methods for 20 epochs on MNLI and 50 epochs on CoLA and MRPC. (A sketch of the corresponding optimizer parameter groups appears after this list.)
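
For readers unfamiliar with the compressed sparse row (CSR) layout that the paper's Flat CSR operation builds on, the sketch below constructs a standard CSR tensor with PyTorch's built-in torch.sparse_csr_tensor. It only illustrates the row-pointer/column-index/value layout; it is not the paper's Flat CSR kernel, and the example matrix is invented for illustration.

```python
import torch

# A 4x4 attention-score matrix stored in CSR form: crow_indices[i] marks where
# row i's nonzeros begin in col_indices/values, so row i holds
# crow_indices[i+1] - crow_indices[i] entries.
crow_indices = torch.tensor([0, 1, 3, 4, 6])
col_indices = torch.tensor([0, 0, 1, 2, 1, 3])
values = torch.tensor([0.9, 0.2, 0.8, 1.0, 0.5, 0.7])

sparse_scores = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(4, 4))
print(sparse_scores.to_dense())  # densify only for inspection; a sparse method keeps it in CSR form
```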
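
The benchmark datasets named above are publicly available. A minimal loading sketch using the Hugging Face datasets library follows; it is not part of the paper's released code, and the identifiers are the standard Hub names for GLUE/MNLI and Wikitext-2.

```python
from datasets import load_dataset

# GLUE task used for BERT text classification (MNLI shown here).
mnli = load_dataset("glue", "mnli")

# Wikitext-2 for causal language modeling with OPT.
wikitext2 = load_dataset("wikitext", "wikitext-2-raw-v1")

print(mnli)        # shows the available splits and their sizes
print(wikitext2)
```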
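
The two learning rates in the experiment setup map naturally onto PyTorch optimizer parameter groups. The sketch below shows one way such a split could be wired up; it is not the repository's training code, and the name filter used to identify SEA attention parameters is a hypothetical heuristic.

```python
import torch

def build_optimizer(model, base_lr=1e-5, sea_lr=1e-4):
    """Assign the pretrained ("original") weights and the SEA attention weights
    to separate parameter groups with their own learning rates."""
    sea_params, base_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Hypothetical heuristic: treat parameters under a "perlin_attention"
        # submodule as SEA attention parameters; the released code may differ.
        if "perlin_attention" in name:
            sea_params.append(param)
        else:
            base_params.append(param)
    return torch.optim.AdamW([
        {"params": base_params, "lr": base_lr},  # 1e-5 for BERT (2e-6 for OPT)
        {"params": sea_params, "lr": sea_lr},    # 1e-4 for SEA attention parameters
    ])
```

For the OPT setting reported above, one would call build_optimizer(model, base_lr=2e-6) so that only the original parameters get the smaller learning rate.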