DHA: Learning Decoupled-Head Attention from Transformer Checkpoints via Adaptive Heads Fusion
Authors: Yilong Chen, Linhao Zhang, Junyuan Shang, Zhenyu Zhang, Tingwen Liu, Shuohuan Wang, Yu Sun
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Our experiments show that DHA remarkably requires only 2.5% of the original model's pre-training budget to achieve 96.1% of performance while saving 75% of KV cache. |
| Researcher Affiliation | Collaboration | Yilong Chen (1,2), Linhao Zhang (3), Junyuan Shang (3), Zhenyu Zhang (3), Tingwen Liu (1,2), Shuohuan Wang (3), Yu Sun (3). 1: Institute of Information Engineering, Chinese Academy of Sciences; 2: School of Cyber Security, University of Chinese Academy of Sciences; 3: Baidu Inc. |
| Pseudocode | Yes | Algorithm 1 Attention Module Initialization |
| Open Source Code | No | However, the complete code is still being organized and is under consideration for open sourcing. |
| Open Datasets | Yes | To train DHA operators and extend pre-training, we employ the RedPajama [19], which parallels the LLaMA training data across seven domains: Common Crawl, C4, GitHub, Wikipedia, Books, ArXiv, and Stack-Exchange. |
| Dataset Splits | Yes | This dataset comprises a validation set with 2 million tokens, a training set containing 4 billion tokens and an additional pre-training set totaling 50 billion tokens. |
| Hardware Specification | Yes | Our experimental framework utilizes the Sheared-LLaMA codebase [16] implemented on the Composer package [20], and is executed on 8 NVIDIA A100 GPUs (80GB). |
| Software Dependencies | No | The paper mentions the 'Composer package [20]' and 'FlashAttention V1 [60]' but does not provide specific version numbers for these software dependencies. |
| Experiment Setup | Yes | The models are trained with a sequence length of 4096, using a global batch size of 64 during the fusion phase and 256 during the continued pre-training phase. The learning rate was set to 1e-4 for the language modeling loss and to 1e-2 for both the Lagrangian multipliers and the fusion operators. |
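
For convenience, the training details quoted in the Dataset Splits, Hardware Specification, and Experiment Setup rows can be gathered in one place. The sketch below only restates those reported values; the dictionary layout and key names are our own assumptions, and it is not a runnable configuration file for the Composer or Sheared-LLaMA codebases.

```python
# Hypothetical summary of the reported setup; key names and structure are our
# assumptions, not a config file from the Composer / Sheared-LLaMA codebase.
DHA_TRAINING_SETUP = {
    "hardware": "8x NVIDIA A100 (80GB)",
    "codebase": "Sheared-LLaMA [16] on the Composer package [20]",
    "sequence_length": 4096,
    "global_batch_size": {
        "fusion_phase": 64,
        "continued_pretraining_phase": 256,
    },
    "learning_rates": {
        "language_modeling_loss": 1e-4,
        "lagrangian_multipliers": 1e-2,
        "fusion_operators": 1e-2,
    },
    "data": {
        "corpus": "RedPajama [19]",
        "validation_tokens": 2_000_000,
        "training_tokens": 4_000_000_000,
        "additional_pretraining_tokens": 50_000_000_000,
    },
}
```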
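
The Pseudocode row refers to Algorithm 1 (Attention Module Initialization), whose listing is not reproduced in this excerpt. Purely as an illustration of what adaptive head fusion from a checkpoint could look like, the sketch below fuses a group of key or value head projections into one shared head through a learned convex combination; the class name `HeadFusionSketch`, the parameter `fusion_logits`, and the softmax-weighted combination are assumptions, not the paper's Algorithm 1. Note also that the reported 75% KV-cache saving is consistent with keeping one fused key/value head per group of four original heads, though the exact grouping is not stated in this excerpt.

```python
import torch
import torch.nn as nn


class HeadFusionSketch(nn.Module):
    """Illustrative only (not the paper's Algorithm 1): fuse a group of key or
    value head projections from a pretrained MHA checkpoint into one shared
    head via a learned convex combination of the original head weights."""

    def __init__(self, head_weights: torch.Tensor):
        # head_weights: (group_size, head_dim, hidden_dim), one projection
        # matrix per original head in the group, taken from a checkpoint.
        super().__init__()
        self.register_buffer("head_weights", head_weights)
        # Learnable fusion operator; softmax keeps the combination convex.
        self.fusion_logits = nn.Parameter(torch.zeros(head_weights.shape[0]))

    def fused_weight(self) -> torch.Tensor:
        alpha = torch.softmax(self.fusion_logits, dim=0)            # (group_size,)
        return torch.einsum("g,ghd->hd", alpha, self.head_weights)  # (head_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden_dim) -> shared head states (batch, seq_len, head_dim)
        return x @ self.fused_weight().T


# Toy usage: initialize one shared KV head from a group of 4 original heads.
if __name__ == "__main__":
    checkpoint_heads = torch.randn(4, 128, 4096)   # 4 heads, head_dim=128, hidden=4096
    fusion = HeadFusionSketch(checkpoint_heads)
    hidden_states = torch.randn(2, 16, 4096)       # (batch, seq_len, hidden_dim)
    print(fusion(hidden_states).shape)             # torch.Size([2, 16, 128])
```

A convex combination keeps the fused projection in the span of the original checkpoint weights, which matches the general motivation of initializing the decoupled attention module from a trained model rather than from scratch.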