Reducing Transformer Key-Value Cache Size with Cross-Layer Attention
Authors: William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, Jonathan Ragan-Kelley
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In experiments training 1B- and 3B-parameter models from scratch, we demonstrate that CLA provides a Pareto improvement over the memory/accuracy tradeoffs which are possible with traditional MQA, potentially enabling future models to operate at longer sequence lengths and larger batch sizes than would otherwise be possible. |
| Researcher Affiliation | Collaboration | William Brandon (MIT CSAIL, wbrandon@csail.mit.edu); Mayank Mishra (MIT-IBM Watson AI Lab); Aniruddha Nrusimha (MIT CSAIL); Rameswar Panda (MIT-IBM Watson AI Lab); Jonathan Ragan-Kelley (MIT CSAIL) |
| Pseudocode | No | The paper describes the architecture of Cross-Layer Attention diagrammatically (Figure 2) and in text, but it does not include any formal pseudocode or algorithm blocks. (A hedged sketch of the KV-sharing mechanism appears after this table.) |
| Open Source Code | No | The specific codebase we used to run our pretraining experiments is not currently ready to be released. We are still in the process of separating out the implementation logic related to the experiments in this paper from the implementation logic related to other, ongoing research projects which share the same codebase, and intend to complete that process before releasing the code. The SlimPajama dataset we use is publicly available. |
| Open Datasets | Yes | We train our models from scratch on data from the SlimPajama [Soboleva et al., 2023] dataset, tokenized with the GPT-NeoX tokenizer [Black et al., 2022] which uses Byte-Pair Encoding (BPE) [Wang et al., 2019]... SlimPajama, made available under the Apache 2.0 license. https://www.cerebras.net/blog/slimpajama-a-627b-token-cleaned-and-deduplicated-version-of-redpajama (see the tokenization sketch after the table) |
| Dataset Splits | Yes | We quantify the accuracy of models in our design space exploration using perplexity on a held-out validation set of 4M tokens drawn from our SlimPajama corpus. (A perplexity-evaluation sketch appears after the table.) |
| Hardware Specification | Yes | We perform all experiments on NVIDIA H100 GPUs using PyTorch [Paszke et al., 2019, Ansel et al., 2024]. |
| Software Dependencies | Yes | We used the following software assets in conducting the experiments for this paper: PyTorch version 2.1.2, made available under the BSD-3 license. https://pytorch.org (see the environment check after the table) |
| Experiment Setup | Yes | We train all models using the AdamW optimizer [Loshchilov and Hutter, 2019] with gradient clipping, using β1 = 0.9, β2 = 0.95, a weight decay factor of 0.1, and a clipping norm of 1.0. We use a linear learning rate warmup for the first 5% of training examples and a cosine learning rate schedule [Loshchilov and Hutter, 2017] decaying to 10% of the peak learning rate over the remainder of training. We set the sequence length to 2048 tokens and the batch size to 2048 sequences, for a total of 4M tokens per training step. All our experiments initialize the weights of linear layers from a normal distribution with mean zero and standard deviation 0.01275. (See the optimizer and schedule sketch after the table.) |
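
Since the paper itself provides no pseudocode for Cross-Layer Attention, the following is a minimal sketch of the KV-sharing idea as described in the text and Figure 2: with a sharing factor of 2, only every other layer computes key/value projections, and the intervening layers attend against the neighbouring layer's keys and values. All module and variable names here are illustrative assumptions, not the authors' implementation.

```python
# Minimal sketch of Cross-Layer Attention (CLA2); names are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CLABlock(nn.Module):
    """One attention layer. If it does not own a KV projection, it reuses the
    keys/values produced by a neighbouring layer instead of computing its own."""

    def __init__(self, d_model: int, n_heads: int, owns_kv: bool):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False) if owns_kv else None
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, shared_kv=None):
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        if self.kv_proj is not None:
            k, v = self.kv_proj(x).chunk(2, dim=-1)
            k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
            v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
            shared_kv = (k, v)          # only these K/V would enter the KV cache
        else:
            k, v = shared_kv            # reuse the neighbouring layer's keys/values
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(B, T, D)
        return self.o_proj(out), shared_kv


# CLA2: only every other layer owns a KV projection, halving the KV cache.
layers = nn.ModuleList([CLABlock(d_model=512, n_heads=8, owns_kv=(i % 2 == 0))
                        for i in range(4)])
x = torch.randn(2, 16, 512)
kv = None
for layer in layers:
    x, kv = layer(x, shared_kv=kv)
```

Because only the layers that own a `kv_proj` contribute entries to the KV cache, this configuration stores half as many keys and values as applying the same per-layer attention everywhere, which is the memory saving the paper reports for CLA2.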
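
The tokenization step quoted under Open Datasets can be reproduced with the publicly released GPT-NeoX tokenizer; the snippet below assumes the Hugging Face `transformers` distribution of that tokenizer, which the paper does not explicitly name as the implementation it used.

```python
# Tokenize text with the GPT-NeoX BPE tokenizer (Hugging Face release assumed).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

text = "SlimPajama is a cleaned and deduplicated version of RedPajama."
ids = tokenizer(text)["input_ids"]
print(len(ids), ids[:8])          # number of BPE tokens and a few token ids
print(tokenizer.decode(ids))      # round-trips back to the original text
```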
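
For the held-out perplexity evaluation described under Dataset Splits, a generic evaluation loop of the following shape would suffice; the batching, device handling, and model interface are assumptions, not the authors' code.

```python
# Sketch of held-out perplexity for a causal LM that returns per-token logits.
import math
import torch
import torch.nn.functional as F


@torch.no_grad()
def perplexity(model, val_batches, device="cuda"):
    total_nll, total_tokens = 0.0, 0
    for tokens in val_batches:                 # tokens: (batch, seq_len) long tensor
        tokens = tokens.to(device)
        logits = model(tokens)                 # (batch, seq_len, vocab), assumed interface
        nll = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),   # predict token t+1 from t
            tokens[:, 1:].reshape(-1),
            reduction="sum",
        )
        total_nll += nll.item()
        total_tokens += tokens[:, 1:].numel()
    return math.exp(total_nll / total_tokens)
```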
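
A quick way to confirm that a reproduction environment matches the reported hardware and software (NVIDIA H100 GPUs, PyTorch 2.1.2):

```python
# Environment check against the reported setup.
import torch

print(torch.__version__)                      # expected: 2.1.2
if torch.cuda.is_available():
    print(torch.cuda.device_count(), torch.cuda.get_device_name(0))  # expect an H100
```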
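
The optimizer, learning-rate schedule, and initialization quoted under Experiment Setup map onto standard PyTorch components as sketched below. The peak learning rate, total step count, and the stand-in model are placeholders, since those values vary across the paper's experiments and are not part of the quoted text.

```python
# Sketch of the reported training setup: AdamW, linear warmup + cosine decay,
# gradient clipping, and N(0, 0.01275) init for linear layers.
import math
import torch

peak_lr, total_steps = 3e-4, 5_000            # placeholders, not from the paper
warmup_steps = int(0.05 * total_steps)        # linear warmup over first 5% of training

model = torch.nn.Linear(2048, 2048)           # stand-in for the transformer

# Initialize linear layers from N(0, 0.01275), as stated in the setup.
for m in model.modules():
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.01275)

opt = torch.optim.AdamW(model.parameters(), lr=peak_lr,
                        betas=(0.9, 0.95), weight_decay=0.1)

def lr_lambda(step):
    if step < warmup_steps:                   # linear warmup to the peak LR
        return step / max(1, warmup_steps)
    # cosine decay from the peak LR down to 10% of the peak
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

# Inside the training loop, gradients are clipped to norm 1.0 before each step:
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# opt.step(); sched.step(); opt.zero_grad()
```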