S-STE: Continuous Pruning Function for Efficient 2:4 Sparse Pre-training
Authors: Yuezhou Hu, Jun Zhu, Jianfei Chen
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Results show that our method surpasses previous 2:4 pre-training recipes and is comparable even with full parameter models. Our toolkit is available at https://github.com/huyz2023/2by4-pretrain. |
| Researcher Affiliation | Academia | Yuezhou Hu1, Jun Zhu1, Jianfei Chen1 1Dept. of Comp. Sci. & Tech., Institute for AI, BNRist Center, Tsinghua-Bosch Joint ML Center, THBI Lab, Tsinghua University. huyz21@mails.tsinghua.edu.cn, {dcszj,jianfeic}@tsinghua.edu.cn |
| Pseudocode | No | No explicitly labeled pseudocode or algorithm blocks were found. |
| Open Source Code | Yes | Our toolkit is available at https://github.com/huyz2023/2by4-pretrain. |
| Open Datasets | Yes | For Transformer, we train Transformer-base models on WMT14 En-De dataset [3] with fairseq [34] codebase and evaluate it with BLEU [36] scores. For DeiT, we pre-train DeiT-small model for ImageNet-1K [10] classification task. For GPT-2, we pre-train GPT-2 124M, 350M and 774M models on OpenWebText [16] and evaluate it on GLUE [47] and SQuAD [39] benchmarks. |
| Dataset Splits | Yes | For Transformer, we train Transformer-base models on WMT14 En-De dataset [3] and evaluate it with BLEU [36] scores. For DeiT, we pre-train DeiT-small model for ImageNet-1K [10] classification task. For GPT-2, we pre-train GPT-2 124M, 350M and 774M models on OpenWebText [16] and evaluate it on GLUE [47] and SQuAD [39] benchmarks. |
| Hardware Specification | Yes | For acceleration, we measure the acceleration ratio of a typical GPT-2 model using implementation from Hu et al. [20]. Note that on H100 GPUs, FP8 2:4-spMM kernel turns out to be unsatisfying; see Appendix A.4. Consequently, we fall back to use RTX 3090 GPUs with FP16 training. Peak throughput cited: H100 PCIe 2:4-spMM 3200 TFLOPS vs. GEMM 1600 TFLOPS; H100 SXM 2:4-spMM 4000 TFLOPS vs. GEMM 2000 TFLOPS. (A benchmarking sketch follows the table.) |
| Software Dependencies | No | The paper mentions the fairseq [34] codebase and Transformer Engine but does not specify version numbers for these or any other key libraries. |
| Experiment Setup | Yes | For all models, we replace the two linear layers in the feed-forward network of each transformer block with S-STE. We keep the rest of the networks, the optimization algorithms as well as all hyperparameters the same as their dense counterparts. ... we use FP8 E4M3 in forward pass and E5M2 in backward pass. Besides, we use per-tensor rescaling before casting to FP8 formats. (An FP8 casting sketch follows the table.) |
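The Hardware Specification row above compares 2:4-spMM against dense GEMM throughput. The following PyTorch sketch shows one way to measure such a speedup on a CUDA GPU (Ampere or newer) using `torch.sparse.to_sparse_semi_structured`; the matrix sizes, iteration counts, and magnitude-based pruning rule are illustrative assumptions, not the measurement protocol of Hu et al. [20] used in the paper.

```python
import torch
from torch.sparse import to_sparse_semi_structured

def bench(fn, iters=50, warmup=10):
    """Average CUDA execution time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

out_features, in_features, batch = 4096, 4096, 4096   # illustrative sizes
w = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
x = torch.randn(batch, in_features, dtype=torch.float16, device="cuda")

# Impose a 2:4 pattern: zero the two smallest-magnitude entries in every
# contiguous group of four along the input dimension (in place, via a view).
groups = w.view(out_features, in_features // 4, 4)
drop = groups.abs().argsort(dim=-1)[..., :2]
groups.scatter_(-1, drop, 0.0)
w_sparse = to_sparse_semi_structured(w)                # compressed 2:4 weight

dense_ms = bench(lambda: torch.nn.functional.linear(x, w))
sparse_ms = bench(lambda: torch.nn.functional.linear(x, w_sparse))
print(f"dense {dense_ms:.3f} ms | 2:4 sparse {sparse_ms:.3f} ms | "
      f"speedup {dense_ms / sparse_ms:.2f}x")
```

A micro-benchmark like this only approximates the per-layer speedup; the paper's reported acceleration ratio is measured end-to-end on a GPT-2 model with the kernels from Hu et al. [20].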
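The Experiment Setup row states that FP8 E4M3 is used in the forward pass, E5M2 in the backward pass, with per-tensor rescaling before the cast. Below is a minimal sketch of such a per-tensor cast using PyTorch's native FP8 dtypes; the function name and the max-based scaling rule are assumptions for illustration, not the implementation in the authors' 2by4-pretrain toolkit.

```python
import torch

# Maximum representable magnitudes of the two FP8 formats.
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0
E5M2_MAX = torch.finfo(torch.float8_e5m2).max     # 57344.0

def per_tensor_cast(x: torch.Tensor, fp8_dtype: torch.dtype, fp8_max: float):
    """Hypothetical per-tensor rescaling: map the tensor's largest magnitude
    onto the FP8 format's maximum, cast, and return the scale so the matmul
    output can be de-scaled afterwards."""
    scale = fp8_max / x.abs().max().clamp(min=1e-12)
    return (x * scale).to(fp8_dtype), scale

# Forward-pass tensors (weights/activations) go to E4M3,
# backward-pass gradients go to E5M2, as described in the paper.
w = torch.randn(1024, 1024)
w_fp8, w_scale = per_tensor_cast(w, torch.float8_e4m3fn, E4M3_MAX)
grad = torch.randn(1024, 1024)
g_fp8, g_scale = per_tensor_cast(grad, torch.float8_e5m2, E5M2_MAX)
```

In the actual recipe, casting of this kind would sit inside the two feed-forward linear layers that S-STE replaces, with the returned scales typically used to de-scale the matmul outputs.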