DHA: Learning Decoupled-Head Attention from Transformer Checkpoints via Adaptive Heads Fusion

Authors: Yilong Chen, Linhao Zhang, Junyuan Shang, Zhenyu Zhang, Tingwen Liu, Shuohuan Wang, Yu Sun

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our experiments show that DHA remarkably requires only 2.5% of the original model's pre-training budgets to achieve 96.1% of performance while saving 75% of KV cache.
Researcher Affiliation | Collaboration | Yilong Chen (1,2), Linhao Zhang (3), Junyuan Shang (3), Zhenyu Zhang (3), Tingwen Liu (1,2), Shuohuan Wang (3), Yu Sun (3). Affiliations: 1 Institute of Information Engineering, Chinese Academy of Sciences; 2 School of Cyber Security, University of Chinese Academy of Sciences; 3 Baidu Inc.
Pseudocode | Yes | Algorithm 1: Attention Module Initialization (an illustrative fusion sketch follows the table).
Open Source Code | No | However, the complete code is still being organized and is under consideration for open sourcing.
Open Datasets | Yes | To train DHA operators and extend pre-training, we employ the RedPajama [19], which parallels the LLaMA training data across seven domains: Common Crawl, C4, GitHub, Wikipedia, Books, ArXiv, and Stack-Exchange.
Dataset Splits | Yes | This dataset comprises a validation set with 2 million tokens, a training set containing 4 billion tokens, and an additional pre-training set totaling 50 billion tokens.
Hardware Specification | Yes | Our experimental framework utilizes the Sheared-LLaMA codebase [16] implemented on the Composer package [20], and is executed on 8 NVIDIA A100 GPUs (80GB).
Software Dependencies | No | The paper mentions 'Composer package [20]' and 'FlashAttention V1 [60]' but does not provide specific version numbers for these software dependencies.
Experiment Setup | Yes | The models are trained with a sequence length of 4096, employing a global batch size of 64 during the fusion phase and 256 during the continued pre-training phases. The learning rate was set at 1e-4 for the language modeling loss and at 1e-2 for the Lagrangian multipliers and fusion operators (a configuration sketch reflecting these values follows the table).
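
The Pseudocode row above cites Algorithm 1 (Attention Module Initialization) but the report does not reproduce it. The following is a minimal, hypothetical sketch of what adaptive heads fusion for initializing a reduced set of key/value heads could look like, assuming each fused head is a learned linear combination (a fusion operator) of a group of checkpoint heads. The function name, tensor shapes, and uniform warm-start are assumptions for illustration, not the authors' implementation.

```python
# Illustrative sketch only, not the authors' released code. It assumes DHA-style
# adaptive heads fusion groups the checkpoint's KV heads and initializes each fused
# head as a trainable linear combination of the heads in its group.
import torch
import torch.nn as nn


def fuse_kv_heads(w_kv: torch.Tensor, group_size: int) -> tuple[torch.Tensor, nn.Parameter]:
    """Fuse every `group_size` adjacent KV heads into one head.

    w_kv: checkpoint projection weights with shape (num_heads, head_dim, hidden_dim).
    Returns the fused weights with shape (num_heads // group_size, head_dim, hidden_dim)
    and the trainable fusion coefficients used to produce them.
    """
    num_heads, head_dim, hidden_dim = w_kv.shape
    assert num_heads % group_size == 0, "heads must divide evenly into groups"
    num_groups = num_heads // group_size

    # One trainable mixing weight per original head, initialized uniformly so each
    # fused head starts as the mean of its group (an assumed warm-start choice).
    fusion = nn.Parameter(torch.full((num_groups, group_size), 1.0 / group_size))

    grouped = w_kv.view(num_groups, group_size, head_dim, hidden_dim)
    fused = torch.einsum("gk,gkdh->gdh", fusion, grouped)
    return fused, fusion


if __name__ == "__main__":
    # Toy example: fusing 32 KV heads into 8 keeps 1/4 of the KV heads, which is
    # consistent with the 75% KV-cache saving quoted in the Research Type row.
    w_k = torch.randn(32, 128, 4096)
    fused_k, fusion_ops = fuse_kv_heads(w_k, group_size=4)
    print(fused_k.shape)  # torch.Size([8, 128, 4096])
```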
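
The Experiment Setup row lists the quoted training hyperparameters; the dataclass below simply collects them in one place as a hypothetical configuration sketch. The field names are illustrative and do not correspond to the Sheared-LLaMA/Composer configuration schema the authors actually used.

```python
# Hypothetical configuration sketch assembled from the values quoted in the table above.
from dataclasses import dataclass


@dataclass
class DHATrainingConfig:
    sequence_length: int = 4096
    fusion_phase_global_batch_size: int = 64
    continued_pretraining_global_batch_size: int = 256
    lm_loss_learning_rate: float = 1e-4
    lagrangian_multiplier_learning_rate: float = 1e-2
    fusion_operator_learning_rate: float = 1e-2
    num_gpus: int = 8  # 8x NVIDIA A100 (80GB), per the Hardware Specification row


if __name__ == "__main__":
    print(DHATrainingConfig())
```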