The Mamba in the Llama: Distilling and Accelerating Hybrid Models
Authors: Junxiong Wang, Daniele Paliotta, Avner May, Alexander Rush, Tri Dao
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. ... Our experiments distill different large-scale open chat LLMs, Zephyr-7B [72] and Llama-3 8B [21], into linear RNN models (hybrid Mamba and Mamba2), using only 20B tokens of training. Results show that the distilled approach matches the teacher model on standard chat benchmarks [84, 43]. We also show that it performs on par with or better than all similarly sized pretrained-from-scratch Mamba models, including Mamba 7B models [52, 26] trained from scratch with 1.2T tokens and NVIDIA Hybrid Mamba2 models [74] trained from scratch with 3.5T tokens, on multiple tasks (e.g., MMLU [34], TruthfulQA [47]) in the LM evaluation harness [25]. |
| Researcher Affiliation | Collaboration | Junxiong Wang (1), Daniele Paliotta (2,3), Avner May (3), Alexander M. Rush (1), Tri Dao (3,4); 1: Cornell University, 2: University of Geneva, 3: Together AI, 4: Princeton University |
| Pseudocode | Yes | Algorithm 1 (Attention-Initialized Mamba) and Algorithm 2 (Multi-Step Linear RNN Speculation). An illustrative sketch of the Algorithm 1 weight transfer follows the table. |
| Open Source Code | Yes | Code and pre-trained checkpoints are open-sourced at https://github.com/jxiw/MambaInLlama and https://github.com/itsdaniele/speculative_mamba. |
| Open Datasets | Yes | We use UltraChat [20] and UltraFeedback [17] as seed prompts... In the second stage, we use supervised finetuning with our model on the GenQA [12], Infinity Instruct [3] and OpenHermes 2.5 [70] datasets... For the draft models, we train 2- and 4-layer Transformer draft models on the OpenHermes 2.5 dataset [70]... |
| Dataset Splits | No | No explicit mention of validation dataset splits (e.g., percentages or counts) for the datasets used in training. The paper discusses 'evaluation' on benchmarks, but not the partitioning of training data into explicit validation sets. |
| Hardware Specification | Yes | The total distillation process for each hybrid model (e.g., Mamba-Llama3 (50% att)) takes less than five days on 8×80GB A100 GPUs. The distillation phase takes eight days on 8× A100 and four days on 8× H100 GPUs. Speculative decoding experiments are run on a single NVIDIA RTX 3090 on data from OpenHermes 2.5. |
| Software Dependencies | No | The paper mentions 'LM Evaluation Harness library [25] (branch big-refactor)' but does not provide specific version numbers for Python, PyTorch, CUDA, or other key software libraries used in the experiments. |
| Experiment Setup | Yes | The student model is trained in one epoch using the loss L in Eq. 2 with α = 1 and β = 0.1. Models are trained using the AdamW optimizer with (β1, β2) = (0.9, 0.98) and a batch size of 64. We use a linear learning rate warm-up (for the first 500 steps) followed by cosine annealing. An illustrative optimizer/scheduler sketch follows the table. |
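
The Pseudocode row above references Algorithm 1 (Attention-Initialized Mamba), which reuses the attention layers' linear projection weights to initialize the Mamba layers that replace them. Below is a minimal PyTorch sketch of that weight transfer, assuming the attention and Mamba blocks expose separate linear projections; the attribute names (`q_proj`, `C_proj`, etc.) are illustrative assumptions, not the released code's actual API.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def init_mamba_from_attention(attn: nn.Module, mamba: nn.Module) -> None:
    """Sketch of the Algorithm 1 idea: copy attention projections into the
    corresponding Mamba projections (attribute names are hypothetical).

    The attention query/key/value/output projections are mapped onto the
    linear-RNN readout (C), input-dependent gate (B), input (x), and output
    projections of the new Mamba layer; SSM-specific parameters (A, dt, ...)
    keep their fresh initialization and are learned during distillation.
    """
    mamba.C_proj.weight.copy_(attn.q_proj.weight)    # queries -> C_t
    mamba.B_proj.weight.copy_(attn.k_proj.weight)    # keys    -> B_t
    mamba.x_proj.weight.copy_(attn.v_proj.weight)    # values  -> x_t
    mamba.out_proj.weight.copy_(attn.o_proj.weight)  # output projection
```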
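
The Experiment Setup row specifies AdamW with betas (0.9, 0.98), a batch size of 64, a 500-step linear warm-up, and cosine annealing. The sketch below shows one way to configure that in PyTorch; the learning rate and total step count are placeholders, since neither is reported in this table.

```python
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

def build_optimizer_and_scheduler(model, lr=2e-5, total_steps=10_000):
    """AdamW with betas=(0.9, 0.98), a 500-step linear warm-up, then cosine
    annealing. `lr` and `total_steps` are placeholder values."""
    optimizer = AdamW(model.parameters(), lr=lr, betas=(0.9, 0.98))
    warmup = LinearLR(optimizer, start_factor=1e-3, total_iters=500)
    cosine = CosineAnnealingLR(optimizer, T_max=total_steps - 500)
    scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine],
                             milestones=[500])
    return optimizer, scheduler
```

The scheduler is stepped once per optimizer step; with a batch size of 64 and roughly 20B training tokens, the actual total step count depends on the sequence length and is not recoverable from this table.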