LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models
Authors: Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, Jiaya Jia
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We present LongLoRA, an efficient fine-tuning approach that extends the context sizes of pre-trained large language models (LLMs) with limited computation cost. Typically, training LLMs with long context sizes is computationally expensive, requiring extensive training hours and GPU resources. For example, training on a context length of 8192 needs 16× the computational cost in self-attention layers compared to 2048. In this paper, we speed up the context extension of LLMs in two aspects. On the one hand, although dense global attention is needed during inference, fine-tuning the model can be effectively and efficiently done by sparse local attention. The proposed shifted sparse attention (S2-Attn) effectively enables context extension, leading to non-trivial computation saving with similar performance to fine-tuning with vanilla attention. Particularly, it can be implemented with only two lines of code in training, while being optional in inference. On the other hand, we revisit the parameter-efficient fine-tuning regime for context expansion. Notably, we find that LoRA for context extension works well under the premise of trainable embedding and normalization. LongLoRA combines this improved LoRA with S2-Attn. LongLoRA demonstrates strong empirical results on various tasks on Llama2 models from 7B/13B to 70B. LongLoRA extends Llama2 7B from 4k context to 100k, or Llama2 70B to 32k, on a single 8× A100 machine. LongLoRA extends models' context while retaining their original architectures, and is compatible with most existing techniques, like Flash-Attention2. In addition, we further conduct supervised fine-tuning with LongLoRA and our long instruction-following LongAlpaca dataset. All our code, models, dataset, and demo are available at github.com/dvlab-research/LongLoRA. Figure 1: LongLoRA closes the accuracy gap between conventional LoRA and full fine-tuning, while still maintaining up to 1.8× lower memory cost than full fine-tuning. Furthermore, LongLoRA improves the training speed of LoRA by up to 1.8× with S2-Attn. Llama2-7B models are fine-tuned to various context lengths with Flash-Attention2 (Dao, 2023) and DeepSpeed (Rasley et al., 2020) stage 2, and evaluated on the proof-pile (Azerbayev et al., 2022) test set in perplexity. (A configuration sketch of the trainable-embedding-and-normalization LoRA recipe appears after this table.) |
| Researcher Affiliation | Collaboration | Yukang Chen¹, Shengju Qian¹, Haotian Tang², Xin Lai¹, Zhijian Liu², Song Han²·³, Jiaya Jia¹ (¹CUHK, ²MIT, ³NVIDIA) |
| Pseudocode | Yes | Algorithm 1: Pseudocode of S2-Attn in PyTorch-like style. (A runnable sketch of this algorithm is given after this table.) |
| Open Source Code | Yes | All our code, models, dataset, and demo are available at github.com/dvlab-research/LongLoRA. |
| Open Datasets | Yes | Datasets: We use the Redpajama (Computer, 2023) dataset for training. We evaluate the long-sequence language modeling performance of our fine-tuned models on the book corpus dataset PG19 (Rae et al., 2020) and the cleaned Arxiv Math proof-pile dataset (Azerbayev et al., 2022). All our code, models, dataset, and demo are available at github.com/dvlab-research/LongLoRA. |
| Dataset Splits | Yes | We evaluate the perplexity on the PG19 (Rae et al., 2020) validation set. |
| Hardware Specification | Yes | LongLoRA extends Llama2 7B from 4k context to 100k, or Llama2 70B to 32k, on a single 8× A100 machine. All our experiments are conducted on an 8× A100 machine. |
| Software Dependencies | No | We train all models using PyTorch (Paszke et al., 2019) with DeepSpeed (Rasley et al., 2020) and Flash-Attention2 (Dao, 2023). While these software components are mentioned, specific version numbers are not provided. |
| Experiment Setup | Yes | We follow most training hyper-parameters in Position Interpolation (Chen et al., 2023), except that our batch size is smaller, as we use a single 8× A100 GPU machine in some cases. All models are fine-tuned via the next-token prediction objective. We use AdamW (Loshchilov & Hutter, 2019) with β1 = 0.9 and β2 = 0.95. The learning rate is set to 2 × 10⁻⁵ for 7B and 13B models, and 10⁻⁵ for 70B models. We also use a linear learning rate warmup. The weight decay is zero. We set the per-device batch size to 1 and gradient accumulation steps to 8, which means the global batch size equals 64 using 8 GPUs. We train our models for 1000 steps. (A TrainingArguments sketch of these settings follows this table.) |
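
The S2-Attn pseudocode referenced in the Pseudocode row amounts to two tensor operations wrapped around a standard attention call: roll half of the attention heads by half a group along the sequence axis, fold the sequence into groups, attend within each group, and roll the shifted heads back. Below is a minimal, self-contained PyTorch sketch of that idea; the function name `s2_attn`, the use of `F.scaled_dot_product_attention` as the inner attention call, and the exact tensor layout are assumptions rather than the authors' code, and causal masking plus the wrap-around handling of the rolled tokens are omitted for brevity.

```python
import torch
import torch.nn.functional as F

def s2_attn(qkv: torch.Tensor, group_size: int) -> torch.Tensor:
    """Shifted sparse attention sketch.

    qkv: projected queries/keys/values of shape (B, N, 3, H, D);
    N must be divisible by group_size. Causal masking is omitted here.
    """
    B, N, _, H, D = qkv.shape
    G = group_size
    # Step 1: split the heads into two halves and shift the second half by G/2 tokens,
    # so its groups straddle the group boundaries of the first half.
    h1, h2 = qkv.chunk(2, dim=3)
    qkv = torch.cat((h1, h2.roll(-(G // 2), dims=1)), dim=3)
    # Fold the sequence into groups of G tokens and run plain attention inside each group.
    qkv = qkv.view(B * N // G, G, 3, H, D)
    q, k, v = qkv.unbind(dim=2)                       # each: (B*N/G, G, H, D)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (B*N/G, H, G, D)
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(B, N, H, D)
    # Step 2: roll the shifted half of the heads back into place.
    o1, o2 = out.chunk(2, dim=2)
    return torch.cat((o1, o2.roll(G // 2, dims=1)), dim=2)
```

Because half of the heads attend over shifted groups, information can flow across group boundaries even though each attention call only ever sees `group_size` tokens.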
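
The abstract also notes that LoRA alone falls short for context extension unless the embedding and normalization layers are made trainable as well. A hedged sketch of that recipe using the Hugging Face `peft` library follows; the rank, alpha, and module-name patterns are illustrative assumptions (they follow common Llama2 naming), not values quoted in this report.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative base model; any Llama2-family checkpoint with the same module names works.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

lora_config = LoraConfig(
    r=8,                       # assumption: rank not stated in the quoted text
    lora_alpha=16,             # assumption
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # LoRA on attention projections
    modules_to_save=[          # keep these modules fully trainable and saved alongside LoRA
        "embed_tokens",
        "input_layernorm",
        "post_attention_layernorm",
        "norm",
    ],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # embeddings + norms + LoRA adapters are trainable
```

`modules_to_save` keeps full (non-low-rank) copies of the listed modules trainable and stores them with the adapter weights, which is how the trainable-embedding-and-normalization premise can be expressed in a standard PEFT setup.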
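
The hyper-parameters quoted in the Experiment Setup row map naturally onto Hugging Face `TrainingArguments`. The sketch below is an assumption-laden illustration: warmup length, precision, the scheduler's decay behaviour, and the DeepSpeed config path are not specified in the quoted text and are marked as placeholders.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./longlora-llama2-7b",         # illustrative path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,             # 1 sample x 8 accumulation x 8 GPUs = global batch 64
    max_steps=1000,
    learning_rate=2e-5,                        # 1e-5 for the 70B model
    weight_decay=0.0,
    adam_beta1=0.9,
    adam_beta2=0.95,
    lr_scheduler_type="constant_with_warmup",  # linear warmup; decay behaviour is an assumption
    warmup_steps=20,                           # assumption: exact warmup length not given
    bf16=True,                                 # assumption: precision not stated
    deepspeed="ds_config_zero2.json",          # illustrative DeepSpeed ZeRO stage-2 config path
)
```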