Reducing Transformer Key-Value Cache Size with Cross-Layer Attention

Authors: William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, Jonathan Ragan-Kelley

NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In experiments training 1B- and 3B-parameter models from scratch, we demonstrate that CLA provides a Pareto improvement over the memory/accuracy tradeoffs which are possible with traditional MQA, potentially enabling future models to operate at longer sequence lengths and larger batch sizes than would otherwise be possible. (See the KV-cache sizing sketch after the table.)
Researcher Affiliation | Collaboration | William Brandon (MIT CSAIL, wbrandon@csail.mit.edu); Mayank Mishra (MIT-IBM Watson AI Lab); Aniruddha Nrusimha (MIT CSAIL); Rameswar Panda (MIT-IBM Watson AI Lab); Jonathan Ragan-Kelley (MIT CSAIL)
Pseudocode | No | The paper describes the Cross-Layer Attention architecture diagrammatically (Figure 2) and in text, but it does not include any formal pseudocode or algorithm blocks. (An illustrative PyTorch sketch of CLA follows the table.)
Open Source Code | No | The specific codebase we used to run our pretraining experiments is not currently ready to be released. We are still in the process of separating out the implementation logic related to the experiments in this paper from the implementation logic related to other, ongoing research projects which share the same codebase, and intend to complete that process before releasing the code. The SlimPajama dataset we use is publicly available.
Open Datasets | Yes | We train our models from scratch on data from the SlimPajama [Soboleva et al., 2023] dataset, tokenized with the GPT-NeoX tokenizer [Black et al., 2022], which uses Byte-Pair Encoding (BPE) [Wang et al., 2019]... SlimPajama, made available under the Apache 2.0 license. https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama (See the tokenization sketch after the table.)
Dataset Splits | Yes | We quantify the accuracy of models in our design space exploration using perplexity on a held-out validation set of 4M tokens drawn from our SlimPajama corpus. (See the perplexity sketch after the table.)
Hardware Specification | Yes | We perform all experiments on NVIDIA H100 GPUs using PyTorch [Paszke et al., 2019, Ansel et al., 2024].
Software Dependencies | Yes | We used the following software assets in conducting the experiments for this paper: PyTorch version 2.1.2, made available under the BSD-3 license. https://pytorch.org
Experiment Setup | Yes | We train all models using the AdamW optimizer [Loshchilov and Hutter, 2019] with gradient clipping, using β1 = 0.9, β2 = 0.95, a weight decay factor of 0.1, and a clipping norm of 1.0. We use a linear learning rate warmup for the first 5% of training examples and a cosine learning rate schedule [Loshchilov and Hutter, 2017] decaying to 10% of the peak learning rate over the remainder of training. We set the sequence length to 2048 tokens and the batch size to 2048 sequences, for a total of 4M tokens per training step. All our experiments initialize the weights of linear layers from a normal distribution with mean zero and standard deviation 0.01275. (See the optimizer/schedule sketch after the table.)
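
The Research Type row's quoted claim turns on the size of the key-value cache at long sequence lengths and large batch sizes. As a rough, illustrative calculation only (the layer count, head dimension, and precision below are example values, not the paper's configurations), the sketch compares KV-cache memory for an MQA model with and without CLA2-style sharing, in which only every other layer stores fresh keys and values.

```python
# Illustrative KV-cache sizing only; all model dimensions are hypothetical.
def kv_cache_bytes(layers_storing_kv, n_kv_heads, head_dim,
                   seq_len, batch, bytes_per_elem=2):
    # Factor of 2 accounts for storing both keys and values per cached token.
    return 2 * layers_storing_kv * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

n_layers, kv_heads, head_dim = 32, 1, 128   # MQA: one KV head per layer (example sizes)
seq_len, batch = 2048, 2048                  # matches the training shapes quoted above

mqa_only = kv_cache_bytes(n_layers, kv_heads, head_dim, seq_len, batch)
# With CLA2, only every other layer projects fresh keys/values; the rest reuse them,
# so roughly half the layers contribute KV storage.
mqa_cla2 = kv_cache_bytes(n_layers // 2, kv_heads, head_dim, seq_len, batch)

print(f"MQA KV cache:        {mqa_only / 2**30:.0f} GiB")   # ~64 GiB
print(f"MQA + CLA2 KV cache: {mqa_cla2 / 2**30:.0f} GiB")   # ~32 GiB
```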
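
Since the paper presents Cross-Layer Attention only diagrammatically (Pseudocode row), the following is a minimal, non-authoritative PyTorch sketch of the core idea behind CLA2: layers alternate between projecting fresh keys/values and reusing the keys/values of the layer before them. For brevity it omits MLP blocks, residual connections, normalization, and the MQA/GQA head sharing the paper combines with CLA; the class name and dimensions are invented for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedKVAttention(nn.Module):
    """An attention layer that either produces its own keys/values or reuses
    keys/values handed down from an earlier layer (CLA-style sharing)."""

    def __init__(self, d_model, n_heads, owns_kv):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        # Only KV-owning layers carry key/value projections; sharing layers have
        # none, which is where CLA's parameter and KV-cache savings come from.
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False) if owns_kv else None

    def forward(self, x, shared_kv=None):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        if self.kv_proj is not None:
            k, v = self.kv_proj(x).chunk(2, dim=-1)
            k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
            v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        else:
            k, v = shared_kv  # reuse the KV activations of an earlier layer
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(b, t, -1)
        return self.o_proj(out), (k, v)

# CLA2: even-indexed layers compute fresh KV, odd-indexed layers reuse it.
layers = nn.ModuleList(SharedKVAttention(512, 8, owns_kv=(i % 2 == 0)) for i in range(4))
x, kv = torch.randn(2, 16, 512), None
for layer in layers:
    x, kv = layer(x, shared_kv=kv)   # residuals/MLPs omitted in this sketch
```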
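
For the Open Datasets row, the paper names the corpus (SlimPajama) and the tokenizer (GPT-NeoX BPE) but not its preprocessing code. Assuming the Hugging Face transformers tokenizer as one possible stand-in (the paper does not say which implementation was used):

```python
from transformers import AutoTokenizer

# Assumed stand-in for the GPT-NeoX BPE tokenizer; not necessarily the authors' tooling.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
ids = tokenizer("SlimPajama is a deduplicated pretraining corpus.")["input_ids"]
print(len(ids), ids[:8])
```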
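
The Dataset Splits row reports validation perplexity on a held-out 4M-token set. A conventional way to compute it is sketched below; the `model` and batch interfaces are assumptions, not the paper's code.

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def validation_perplexity(model, val_batches):
    """Perplexity = exp(mean per-token negative log-likelihood) over held-out tokens.
    Assumes `model` maps token ids (batch, seq) to logits (batch, seq, vocab)."""
    total_nll, total_tokens = 0.0, 0
    for tokens in val_batches:                 # (batch, seq) of token ids
        logits = model(tokens[:, :-1])         # predict each next token
        nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            tokens[:, 1:].reshape(-1),
            reduction="sum",
        )
        total_nll += nll.item()
        total_tokens += tokens[:, 1:].numel()
    return math.exp(total_nll / total_tokens)
```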
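
Finally, the Experiment Setup row's hyperparameters map naturally onto a PyTorch optimizer/scheduler configuration. The sketch below is a hedged illustration: the tiny stand-in model, peak learning rate, and step count are placeholders not fixed by the quoted text, while the AdamW betas, weight decay, clipping norm, warmup fraction, cosine floor, and initialization standard deviation follow the values quoted above.

```python
import math
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(2048, 2048)        # tiny stand-in for the real transformer
peak_lr, total_steps = 3e-4, 1000    # placeholder values, not taken from the paper

def init_weights(module):
    # Linear-layer weights drawn from N(0, 0.01275), as stated in the quoted setup.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.01275)
model.apply(init_weights)

def lr_lambda(step, warmup_frac=0.05, floor=0.10):
    # Linear warmup over the first 5% of steps, then cosine decay to 10% of peak LR.
    warmup_steps = int(warmup_frac * total_steps)
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return floor + (1.0 - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))

optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr,
                              betas=(0.9, 0.95), weight_decay=0.1)
scheduler = LambdaLR(optimizer, lr_lambda)

for step in range(total_steps):
    loss = model(torch.randn(4, 2048)).pow(2).mean()   # dummy loss in place of the LM loss
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```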