Mini-Sequence Transformers: Optimizing Intermediate Memory for Long Sequences Training

Authors: Cheng Luo, Jiawei Zhao, Zhuoming Chen, Beidi Chen, Animashree Anandkumar

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We evaluate the impact of using chunk-based MINI-SEQUENCE TRANSFORMER (MST) on Llama3 [36], a state-of-the-art model for many NLP tasks. We also evaluate Qwen [6], Mistral [24], and Gemma-2 [54] for context length improvements. We validate our claims about scaling sequence length, reporting training time and memory overhead.
Researcher Affiliation | Collaboration | Cheng Luo (California Institute of Technology) chengluo@caltech.edu; Jiawei Zhao (Meta FAIR) jwzhao@meta.com; Zhuoming Chen (Carnegie Mellon University) zhuominc@andrew.cmu.edu; Beidi Chen (Carnegie Mellon University) beidic@andrew.cmu.edu; Anima Anandkumar (California Institute of Technology) anima@caltech.edu
Pseudocode | Yes | Algorithm 1 Mini-Sequence MLP, Algorithm 2 Mini-Sequence LM-Head, Algorithm 3 Mini-Sequence MLP Backward, Algorithm 4 Mini-Sequence LM-Head Backward. (A sketch of the chunking idea appears after the table.)
Open Source Code | Yes | This work is open-source under an MIT license on https://github.com/wdlctc/mini-s. ... We made this method open-source on https://github.com/wdlctc/transformers.
Open Datasets | Yes | We train a Llama3-8B [36] MST on the Long Alpaca dataset [9].
Dataset Splits | No | The paper uses various LLMs (Llama3, Llama2, Qwen, Mistral, Gemma-2) and the Long Alpaca dataset but does not specify train/validation/test splits for these datasets.
Hardware Specification | Yes | MST trains 12-24x longer sequence lengths than existing systems on a single A100 GPU with no degradation in throughput or convergence of training. ... MST can train Llama3-8B with context length 60k and Llama2-7B with context length 84k on a single A100 GPU. ... We train Llama3-8B [36] MST and Llama2 [43] MST models by exploring the sequence length on a single A100 GPU...
Software Dependencies | No | The paper mentions implementing MST with PyTorch and integrating it with the Hugging Face library, but it does not specify version numbers for these software components or any other dependencies.
Experiment Setup | Yes | For all implementations, we use the AdamW optimizer [32]. We use a weight decay of 0.001, gradient clipping of 1.0, and a constant learning rate of 1e-4. All batch sizes equal 16, with a gradient accumulation step of 16. The bf16 precision is also deployed. (An assumed reconstruction of this configuration follows below.)
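
The Pseudocode row above points to Algorithms 1-4, the forward and backward passes for the mini-sequence MLP and LM-Head. The snippet below is only a minimal PyTorch sketch of the underlying chunking idea, not the authors' implementation: it assumes generic activation checkpointing can stand in for the dedicated backward passes of Algorithms 3-4, and the MiniSequence class and toy layer sizes are illustrative placeholders.

```python
# Minimal sketch of the mini-sequence idea; NOT the authors' MST code.
# Assumption: torch.utils.checkpoint stands in for the paper's explicit
# backward algorithms (Algorithms 3-4), which process chunks one at a time.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class MiniSequence(nn.Module):
    """Apply a position-wise block (MLP or LM-head) to sequence chunks so
    that only one chunk's intermediate activations are live at a time."""

    def __init__(self, block: nn.Module, num_chunks: int = 4):
        super().__init__()
        self.block = block
        self.num_chunks = num_chunks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, hidden). The block acts on each position
        # independently, so chunking along dim=1 is mathematically exact.
        outs = []
        for chunk in x.chunk(self.num_chunks, dim=1):
            # Recompute this chunk's intermediates during backward instead
            # of storing them, keeping peak intermediate memory to one chunk.
            outs.append(checkpoint(self.block, chunk, use_reentrant=False))
        return torch.cat(outs, dim=1)


if __name__ == "__main__":
    hidden, intermediate = 1024, 4096      # toy sizes for illustration
    mlp = nn.Sequential(nn.Linear(hidden, intermediate), nn.SiLU(),
                        nn.Linear(intermediate, hidden))
    x = torch.randn(1, 8192, hidden, requires_grad=True)
    y = MiniSequence(mlp, num_chunks=4)(x)  # same output as mlp(x)
    y.sum().backward()
```

In the paper's LM-Head algorithms, the same per-chunk processing is applied so that only one mini-sequence's logits, rather than the full sequence-by-vocabulary tensor, are materialized at once.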
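
The hyperparameters quoted in the Experiment Setup row map directly onto a Hugging Face TrainingArguments object. The configuration below is an assumed reconstruction for reproduction purposes, not the authors' training script: the output directory name is a placeholder, and treating the batch size of 16 as per-device is an assumption.

```python
# Assumed reconstruction of the reported training hyperparameters using
# Hugging Face TrainingArguments; not the authors' actual training script.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="mst-llama3-8b-longalpaca",  # placeholder output path
    optim="adamw_torch",                    # AdamW optimizer [32]
    learning_rate=1e-4,                     # constant learning rate of 1e-4
    lr_scheduler_type="constant",
    weight_decay=0.001,
    max_grad_norm=1.0,                      # gradient clipping of 1.0
    per_device_train_batch_size=16,         # batch size 16 (assumed per-device)
    gradient_accumulation_steps=16,
    bf16=True,                              # bf16 precision
)
```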