RoSA: Accurate Parameter-Efficient Fine-Tuning via Robust Adaptation
Authors: Mahdi Nikdan, Soroush Tabesh, Elvir Crnčević, Dan Alistarh
ICML 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We investigate parameter-efficient fine-tuning (PEFT) methods that can provide good accuracy under limited computational and memory budgets in the context of large language models (LLMs). We present a new PEFT method called Robust Adaptation (RoSA) inspired by robust principal component analysis that jointly trains low-rank and highly-sparse components on top of a set of fixed pretrained weights to efficiently approximate the performance of a full-fine-tuning (FFT) solution. Across a series of challenging generative tasks such as grade-school math and SQL query generation, which require fine-tuning for good performance, we show that RoSA outperforms LoRA, pure sparse fine-tuning, and alternative hybrid methods at the same parameter budget, and can even recover the performance of FFT on some tasks. We provide system support for RoSA to complement the training algorithm, specifically in the form of sparse GPU kernels which enable memory- and computationally-efficient training, and show that it is also compatible with low-precision base weights, resulting in the first joint representation combining quantization, low-rank and sparse approximations. Our code is available at https://github.com/IST-DASLab/RoSA. (A hedged sketch of this low-rank-plus-sparse adapter structure follows the table.) |
| Researcher Affiliation | Collaboration | 1 IST Austria, 2 Graz University of Technology, 3 Neural Magic. Correspondence to: Mahdi Nikdan <mahdi.nikdan@ista.ac.at>, Soroush Tabesh <soroush.tabesh@ista.ac.at>, Dan Alistarh <dan.alistarh@ista.ac.at>, Elvir Crnčević <elvir.crncevic@ista.ac.at>. |
| Pseudocode | Yes | Algorithm 1 Mask Generation; Algorithm 2 Robust Adaptation (RoSA). (A hedged sketch of the mask-generation step follows the table.) |
| Open Source Code | Yes | Our code is available at https://github.com/IST-DASLab/RoSA. |
| Open Datasets | Yes | We perform fine-tuning of the LLaMA2-7B model (Touvron et al., 2023b) on three standard datasets: ViGGO (Juraska et al., 2019), GSM8k (Cobbe et al., 2021), and SQL generation (Zhong et al., 2017; Yu et al., 2018), containing 5.1k, 7.47k, and 30k training samples and 1.08k, 1.32k, and 1k test samples, respectively. Refer to Appendix F for examples of the GSM8k dataset. In the case of SQL, we follow the dataset formation strategy described in (Niederfahrenhorst et al., 2023). |
| Dataset Splits | No | The paper specifies the number of training and test samples for the datasets but does not provide explicit details about a validation set split or its size. |
| Hardware Specification | Yes | All experiments, except for FFT, comfortably run on a single NVIDIA GeForce RTX 3090 GPU with 24.3 GB memory (see Table 1). [...] Performing measurements on an NVIDIA RTX A6000 GPU, we find our current implementation of RoSA to be approximately 1.7-2x slower than LoRA on the 80M parameter budget (see Appendix B). |
| Software Dependencies | No | The paper mentions using a 'fork of the standard PEFT library (Mangrulkar et al., 2022)' and 'the MosaicML llm-foundry codebase (MosaicML, 2023a)', as well as PyTorch and Sputnik. However, specific version numbers for these software components are not provided. |
| Experiment Setup | Yes | In all experiments, we use a standard batch size of 32 (micro-batch size 1 + gradient accumulation) and a maximum context length of 512, which matches the dataset sample structure. We employ the AdamW optimizer (Loshchilov & Hutter, 2017) with parameters β1 = 0.9, β2 = 0.999, ϵ = 10⁻⁸, and a linear learning rate scheduler with 20 batches warmup. Notably, all floating-point values are stored in bfloat16 (Dean et al., 2012), popular due to low memory usage and good accuracy. Our main experiments run for a single epoch, but we demonstrate in ablation studies that extended training can further improve adaptation results. Following (Hu et al., 2021), we use α = 16 and a dropout of 0.05 for the low-rank adapter, while experimenting with various r values ranging from 4 to 64. Additionally, we set the size of the mask generation dataset to 32 samples in all experiments while tuning the gradient accumulation exponent (α in Algorithm 1) as a binary hyperparameter (1 for averaging gradients and 2 for diagonal Fisher). The sparse adapter's density ranges from 0.15% to 2.4%. [...] The best learning rates for single-epoch FFT are 4×10⁻⁵, 2×10⁻⁵, and 1×10⁻⁴ on SQL, ViGGO, and GSM8k, respectively, while for extended FFT it is 4×10⁻⁵ on ViGGO and 5×10⁻⁵ on GSM8k. For LoRA and SpA parameters, the best-performing learning rates are selected in the range [10⁻⁴, 10⁻³] and [10⁻⁴, 8×10⁻⁴], respectively. In RoSA experiments, we find it beneficial to initially fine-tune solely with LoRA for 64 batches, generate and fix the sparse masks, and restart training with both LoRA and sparse adaptation (SpA) activated. (An optimizer/scheduler sketch with these settings follows the table.) |
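The abstract quoted above describes RoSA as jointly training a low-rank adapter and a highly sparse adapter on top of frozen pretrained weights. Below is a minimal PyTorch sketch of what such a layer could look like; the class name `RoSALinear`, the dense emulation of the sparse component, and the initialization choices are illustrative assumptions rather than the authors' implementation, which relies on custom sparse GPU kernels for efficiency.

```python
import torch
import torch.nn as nn

class RoSALinear(nn.Module):
    """Illustrative sketch of a RoSA-style layer: frozen base weight plus a
    LoRA-style low-rank adapter and a fixed-support sparse adapter, trained
    jointly. Names and details are hypothetical, not the paper's code."""

    def __init__(self, base: nn.Linear, rank: int = 16, lora_alpha: float = 16.0,
                 dropout: float = 0.05, mask: torch.Tensor = None):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay fixed

        out_f, in_f = base.weight.shape
        # Low-rank component B @ A, scaled by lora_alpha / rank (LoRA convention).
        self.A = nn.Parameter(torch.randn(rank, in_f) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_f, rank))
        self.scale = lora_alpha / rank
        self.dropout = nn.Dropout(dropout)

        # Sparse component: values are trained only on a fixed binary mask
        # (generated once, e.g. as sketched in the next block). Stored densely
        # here for simplicity; the paper uses sparse GPU kernels instead.
        if mask is None:
            mask = torch.zeros(out_f, in_f, dtype=torch.bool)
        self.register_buffer("mask", mask)
        self.S = nn.Parameter(torch.zeros(out_f, in_f))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        low_rank = self.dropout(x) @ self.A.t() @ self.B.t() * self.scale
        sparse = x @ (self.S * self.mask).t()
        return self.base(x) + low_rank + sparse
```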
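Algorithm 1 (Mask Generation) is cited only by name in the table above, but the Experiment Setup row states that masks are generated from 32 calibration samples with a gradient-accumulation exponent α ∈ {1, 2} (averaged gradients vs. diagonal Fisher). A hedged sketch under those assumptions follows; the function names, the per-layer top-k selection, and the Hugging Face-style `model(**batch).loss` interface are assumptions, not the paper's Algorithm 1.

```python
import torch

def accumulate_scores(model, calib_loader, alpha: int = 2, num_batches: int = 32):
    """Accumulate |grad|**alpha per weight over a small calibration set
    (alpha=1: averaged gradients, alpha=2: diagonal Fisher). Assumes the
    scored weights have gradients enabled and the model returns a `.loss`."""
    scores = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    for i, batch in enumerate(calib_loader):
        if i >= num_batches:
            break
        model.zero_grad()
        model(**batch).loss.backward()
        for n, p in model.named_parameters():
            if n in scores and p.grad is not None:
                scores[n] += p.grad.detach().abs() ** alpha
    model.zero_grad()
    return scores

@torch.no_grad()
def build_masks(scores: dict, density: float = 0.006) -> dict:
    """Keep the top `density` fraction of entries per layer (the paper sweeps
    densities of roughly 0.15%-2.4%). Hypothetical helper."""
    masks = {}
    for name, score in scores.items():
        k = max(1, int(density * score.numel()))
        threshold = torch.topk(score.flatten(), k).values.min()
        masks[name] = score >= threshold
    return masks
```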
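Finally, the optimizer and schedule reported in the Experiment Setup row map onto standard components. The sketch below is a minimal assembly that uses the Hugging Face `transformers` scheduler helper as a stand-in; the authors actually train through a fork of the PEFT library and the MosaicML llm-foundry codebase, and `lr` and `total_steps` here are placeholders.

```python
import torch
from transformers import get_linear_schedule_with_warmup

def make_optimizer_and_scheduler(model, lr: float = 1e-4,
                                 total_steps: int = 1000,
                                 warmup_steps: int = 20):
    """AdamW with beta1=0.9, beta2=0.999, eps=1e-8 and a linear schedule with a
    20-batch warmup, as in the Experiment Setup row. The paper tunes learning
    rates per task and adapter; lr and total_steps are placeholders."""
    trainable = [p for p in model.parameters() if p.requires_grad]  # adapters only
    optimizer = torch.optim.AdamW(trainable, lr=lr,
                                  betas=(0.9, 0.999), eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    return optimizer, scheduler

# The effective batch size of 32 is realised as micro-batch size 1 with 32
# gradient-accumulation steps, and weights are kept in bfloat16
# (e.g. model.to(torch.bfloat16)) to reduce memory usage.
```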