Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models
Authors: Aviv Bick, Kevin Li, Eric Xing, J. Zico Kolter, Albert Gu
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Table 1 presents a comprehensive breakdown of downstream evaluation results for our models and multiple baselines on a standard set of commonsense reasoning and language understanding tasks: WinoGrande [33], HellaSwag [45], PIQA [2], ARC-Challenge and ARC-Easy [6], and LAMBADA [29]. Figure 1 shows the performance versus the training cost of Phi-Mamba compared to many open-source baselines from the literature at similar model sizes. |
| Researcher Affiliation | Collaboration | Aviv Bick¹, Kevin Y. Li¹, Eric P. Xing¹,², J. Zico Kolter¹, Albert Gu¹,³ (¹Carnegie Mellon University, ²MBZUAI, ³Cartesia.ai); {abick, kyl2}@cs.cmu.edu |
| Pseudocode | Yes | Algorithm 1: Approximation of Attention Matrix M as an SSM with State Size n (a code sketch of this idea appears after the table). |
| Open Source Code | No | Question: Does the paper provide open access to the data and code, with sufficient instructions to faithfully reproduce the main experimental results, as described in supplemental material? Answer: [No] Justification: Our proposed method is quite simple and all important details are reported. The datasets used are reported and are all standard, open source datasets. |
| Open Datasets | Yes | Our final Phi-Mamba-1.5B model is distilled on 3 billion tokens (distributed as 80M in Stage 1, 160M in Stage 2, and 2.76B tokens in Stage 3, as described in Appendix A) from the C4 dataset, with a sequence length of 2048. Table 1 presents a comprehensive breakdown of downstream evaluation results for our models and multiple baselines on a standard set of commonsense reasoning and language understanding tasks: WinoGrande [33], HellaSwag [45], PIQA [2], ARC-Challenge and ARC-Easy [6], and LAMBADA [29]. |
| Dataset Splits | No | The paper mentions training on the C4 dataset and evaluating on several benchmarks but does not explicitly state the specific train/validation/test dataset splits or cross-validation setup used for reproducibility. |
| Hardware Specification | Yes | Experiments were run across different types of hardware with frequent checkpointing and resuming, but all using commercially available standard resources (e.g., single-node training on A100 or H100 GPUs). |
| Software Dependencies | No | The paper mentions using PyTorch for code examples and the AdamW optimizer, along with bf16 mixed precision training, but does not provide specific version numbers for these or other software dependencies. |
| Experiment Setup | Yes | Hyperparameter Search: To construct Appendix A, we performed grid searches for training in Stages 1, 2, and 3 independently from scratch to find the optimal hyperparameters. We explored learning rates lr ∈ {1, 2, 5} × 10^{-3, -4} and batch sizes 2^{15, 16, 17, 18}. The AdamW optimizer was used with β = (0.9, 0.95), incorporating a weight decay of 0.1, gradient clipping at 1.0, and a Warmup-Stable-Decay (WSD) scheduler with 10% warmup and 10% decay utilizing linear warmup and cooldown functions. Automatic mixed-precision training to bf16 was used in all stages. For Stages 1 and 2, we initially fixed the batch size at 2^16, then varied the learning rates. After identifying the optimal learning rate, we adjusted the batch sizes and subsequently finalized the learning rate after fixing the batch size. Consequently, Stage 1 used bs = 2^15, lr = 5 × 10^-4 and Stage 2 used bs = 2^15, lr = 2 × 10^-3. In Stage 3, we set the batch size to 2^19 ≈ 0.5M and focused solely on varying the learning rate, resulting in lr = 5 × 10^-4. Stages 1 and 2 were trained to 200M tokens each, while Stage 3 extended to 1B tokens. (A code sketch of this optimizer and scheduler configuration appears after the table.) |
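The pseudocode row points to Algorithm 1, which approximates an attention matrix M with an SSM of state size n. The paper's procedure itself is not quoted in this report, so the snippet below is only a minimal PyTorch sketch of the underlying idea under assumed shapes and an assumed Frobenius-norm matching objective: materialize a causal softmax-attention matrix from teacher queries and keys, materialize a Mamba-2-style 1-semiseparable mixing matrix from student parameters (B, C, and per-step decays a), and fit the latter to the former. The function names, dimensions, and optimization loop are illustrative, not the authors' implementation.

```python
# Hypothetical sketch only: approximate a causal attention matrix with a
# 1-semiseparable (SSM-style) mixing matrix of state size n by minimizing
# the Frobenius distance. Shapes and the training loop are assumptions.
import torch

def attention_matrix(q, k):
    """Causal softmax attention matrix of shape (L, L)."""
    L, d = q.shape
    scores = (q @ k.T) / d**0.5
    mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    return torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

def ssm_matrix(B, C, a):
    """Materialized SSM mixing matrix with state size n.

    M[i, j] = (C_i . B_j) * prod_{t=j+1..i} a_t for j <= i, and 0 otherwise.
    B, C: (L, n) input/output projections; a: (L,) decay scalars in (0, 1).
    """
    L = a.shape[0]
    cum = torch.cumsum(torch.log(a), dim=0)             # cum[i] = sum_{t<=i} log a_t
    diff = cum[:, None] - cum[None, :]                   # sum_{t=j+1..i} log a_t when i >= j
    mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    decay = torch.exp(diff.masked_fill(mask, float("-inf")))  # zero above the diagonal
    return (C @ B.T) * decay

# Toy usage: fit a random "teacher" attention matrix with state size n = 16.
L, d, n = 64, 32, 16
target = attention_matrix(torch.randn(L, d), torch.randn(L, d))

B = torch.randn(L, n, requires_grad=True)
C = torch.randn(L, n, requires_grad=True)
a_logit = torch.zeros(L, requires_grad=True)             # sigmoid keeps decays in (0, 1)
opt = torch.optim.Adam([B, C, a_logit], lr=1e-2)

for _ in range(200):
    opt.zero_grad()
    M = ssm_matrix(B, C, torch.sigmoid(a_logit))
    loss = torch.linalg.matrix_norm(target - M)          # Frobenius norm
    loss.backward()
    opt.step()
```

Because the off-diagonal blocks of such a semiseparable matrix have rank at most n, a larger state size gives the student more capacity to match the teacher's attention pattern, which is why the approximation is parameterized by n.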
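The experiment-setup row specifies the optimizer and schedule completely, so the training configuration can be reconstructed in a few lines. In the sketch below, the AdamW betas, weight decay of 0.1, gradient clipping at 1.0, the Warmup-Stable-Decay schedule with 10% linear warmup and 10% linear cooldown, bf16 autocast, and the 5 × 10^-4 learning rate come from the quoted text; the model, loss, batch, and step budget are placeholders.

```python
# Sketch of the reported training configuration; the model, data, and step
# budget are placeholders, not the paper's distillation setup.
import torch

def wsd_factor(step, total_steps, warmup_frac=0.1, decay_frac=0.1):
    """Multiplicative LR factor for a Warmup-Stable-Decay (WSD) schedule."""
    warmup = int(total_steps * warmup_frac)
    decay_start = int(total_steps * (1.0 - decay_frac))
    if step < warmup:                          # linear warmup
        return step / max(1, warmup)
    if step < decay_start:                     # stable plateau
        return 1.0
    remaining = total_steps - decay_start      # linear cooldown to zero
    return max(0.0, (total_steps - step) / max(1, remaining))

model = torch.nn.Linear(512, 512)              # placeholder for the student model
total_steps = 1_000                            # placeholder step budget

optimizer = torch.optim.AdamW(
    model.parameters(), lr=5e-4, betas=(0.9, 0.95), weight_decay=0.1
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda s: wsd_factor(s, total_steps)
)

for step in range(total_steps):
    x = torch.randn(8, 512)                    # placeholder batch
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        loss = model(x).pow(2).mean()          # placeholder for the distillation loss
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
    optimizer.step()
    scheduler.step()
```

Note that bf16 autocast needs no gradient scaler, which is consistent with the paper's "automatic mixed precision training to bf16" description.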