Compact Language Models via Pruning and Knowledge Distillation
Authors: Saurav Muralidharan, Sharath Turuvekere Sreenivas, Raviraj Joshi, Marcin Chochowski, Mostofa Patwary, Mohammad Shoeybi, Bryan Catanzaro, Jan Kautz, Pavlo Molchanov
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this paper, we investigate if pruning an existing LLM and then re-training it with a fraction (<3%) of the original training data can be a suitable alternative to repeated, full retraining. To this end, we develop a set of practical and effective compression best practices for LLMs that combine depth, width, attention and MLP pruning with knowledge distillation-based retraining; we arrive at these best practices through a detailed empirical exploration of pruning strategies for each axis, methods to combine axes, distillation strategies, and search techniques for arriving at optimal compressed architectures. We use this guide to compress the Nemotron-4 family of LLMs by a factor of 2-4×, and compare their performance to similarly-sized models on a variety of language modeling tasks. (A minimal distillation sketch is given after this table.) |
| Researcher Affiliation | Industry | NVIDIA {sauravm,sharatht,ravirajj,mchochowski,mpatwary,mshoeybi, bcatanzaro,jkautz,pmolchanov}@nvidia.com |
| Pseudocode | No | The paper describes methods and processes through textual descriptions and figures (e.g., Figure 2 for overview, Figure 3 for search strategy), but does not include formal pseudocode or algorithm blocks. |
| Open Source Code | Yes | We have open-sourced MINITRON model weights on Hugging Face, with corresponding supplementary material including example code available on GitHub. |
| Open Datasets | Yes | We use the Nemotron-4 curated 8 trillion token (8T) base pretraining dataset and the continued training dataset (CT) [42, 44, 43]. |
| Dataset Splits | Yes | LM validation loss is reported on the validation set of the 8T and WikiText2 datasets. |
| Hardware Specification | Yes | All experiments were performed on 16 NVIDIA DGX A100 nodes (8× A100 80GB GPUs each) for short turnaround times. |
| Software Dependencies | No | We use the NVIDIA Megatron-LM framework [47] to implement our pruning and distillation algorithms for compression and retraining. However, no specific version numbers for Megatron-LM or other software dependencies are provided. |
| Experiment Setup | Yes | Unless otherwise specified, we use 1.8 billion tokens (400 steps) for lightweight retraining. The calibration dataset D used for importance estimation consists of 1024 samples drawn randomly from the full dataset. We use the same optimizer settings and data split as [43], with a cosine LR decay schedule from 2×10⁻⁴ to 4.5×10⁻⁷. (A cosine-schedule sketch follows this table.) |
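
The compression recipe quoted in the Research Type row pairs structured pruning with knowledge-distillation-based retraining, where the original model supervises its pruned variant. The sketch below illustrates logit-level distillation in PyTorch; it assumes HuggingFace-style model objects whose forward pass returns `.logits`, and the names `distillation_loss`, `retrain_step`, and `temperature` are illustrative rather than taken from the paper or the Megatron-LM implementation.

```python
# Minimal sketch of knowledge-distillation-based retraining for a pruned
# student model. Assumes generic PyTorch/HF-style models; loss weighting and
# any intermediate-state terms explored in the paper are not reproduced here.
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """Forward KL divergence between teacher and student token distributions."""
    t = temperature
    teacher_probs = F.softmax(teacher_logits / t, dim=-1)
    student_log_probs = F.log_softmax(student_logits / t, dim=-1)
    # KL(teacher || student), summed over vocabulary and sequence positions,
    # averaged over the batch; scaled by t^2 as is customary.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (t * t)


def retrain_step(student, teacher, batch, optimizer, temperature=1.0):
    """One distillation step: the frozen teacher supervises the pruned student."""
    with torch.no_grad():
        teacher_logits = teacher(batch["input_ids"]).logits
    student_logits = student(batch["input_ids"]).logits
    loss = distillation_loss(student_logits, teacher_logits, temperature)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```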
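
The Experiment Setup row reports a cosine learning-rate decay from 2×10⁻⁴ to 4.5×10⁻⁷ over the lightweight retraining run. The following is a minimal sketch of such a schedule, assuming 400 optimizer steps and no warmup (the paper does not detail warmup in the quoted passage).

```python
# Minimal sketch of a cosine LR decay schedule matching the reported endpoints.
import math

PEAK_LR = 2e-4       # starting learning rate
MIN_LR = 4.5e-7      # final learning rate
TOTAL_STEPS = 400    # ~1.8B tokens of lightweight retraining


def cosine_lr(step: int) -> float:
    """Cosine decay from PEAK_LR to MIN_LR over TOTAL_STEPS."""
    progress = min(step, TOTAL_STEPS) / TOTAL_STEPS
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1.0 + math.cos(math.pi * progress))


# Example: learning rate at the start, midpoint, and end of the schedule.
for s in (0, 200, 400):
    print(s, cosine_lr(s))
```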