Keypoint-based Progressive Chain-of-Thought Distillation for LLMs

Authors: Kaituo Feng, Changsheng Li, Xiaolu Zhang, Jun Zhou, Ye Yuan, Guoren Wang

ICML 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Extensive experiments on four reasoning benchmarks illustrate our KPOD outperforms previous methods by a large margin.
Researcher Affiliation | Collaboration | Beijing Institute of Technology; Ant Group; Hebei Province Key Laboratory of Big Data Science and Intelligent Technology.
Pseudocode | Yes | Algorithm 1 outlines the training procedure of our KPOD.
Open Source Code | No | The paper does not provide an explicit statement about releasing its source code or a link to a code repository for the methodology described.
Open Datasets | Yes | We evaluate our method on both mathematical reasoning tasks and commonsense reasoning tasks, following (Hsieh et al., 2023; Fu et al., 2023). For mathematical reasoning, we adopt three benchmark datasets for evaluation: GSM8K (Cobbe et al., 2021), ASDiv (Patel et al., 2021) and SVAMP (Miao et al., 2021). For commonsense reasoning, the Commonsense QA benchmark (Talmor et al., 2019) is employed to evaluate our method.
Dataset Splits | Yes | Table 4 (dataset statistics), train/validation/test sizes: GSM8K 7473/660/659; ASDiv 1462/313/314; SVAMP 700/150/150; Commonsense QA 8520/1221/1221.
Hardware Specification | Yes | We perform our experiments using GeForce RTX 3090 GPUs.
Software Dependencies | No | The paper mentions software components such as LoRA, Flan-T5-Large, LLaMA-7B, and the Adam optimizer, but it does not provide specific version numbers for these or for other relevant software dependencies such as Python, PyTorch/TensorFlow, or CUDA.
Experiment Setup | Yes | The rank of LoRA is set to 64 for LLaMA-7B and 128 for Flan-T5-XL. We use the Adam optimizer with a learning rate of 1e-5 for LLaMA-7B and 5e-5 for the Flan-T5 models. The batch size is set to 4. For LLaMA-7B, the number of training epochs for the student model is set to 20 for the GSM8K and Commonsense QA datasets, and 40 for the ASDiv and SVAMP datasets. For the Flan-T5 models, the epoch number is set to 100... The hyper-parameter α, which balances the answer prediction loss and the mask ratio loss, is set to 0.5.
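
The sketch below illustrates how the Experiment Setup row could be expressed as a student fine-tuning configuration, using the standard HuggingFace PEFT and PyTorch APIs. The LoRA rank, learning rate, batch size, epoch count, and α value come from the paper; the model path, LoRA target modules, dropout, and the exact way α combines the two losses are assumptions made for illustration and are not the authors' implementation.

```python
# Hypothetical configuration sketch matching the reported hyper-parameters
# (LoRA rank 64, Adam, lr 1e-5, batch size 4, alpha 0.5 for LLaMA-7B).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

MODEL_PATH = "path/to/llama-7b"  # placeholder; the paper uses LLaMA-7B

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# LoRA rank 64 for LLaMA-7B (128 for Flan-T5-XL per the paper);
# target modules and dropout below are illustrative defaults, not from the paper.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Adam with lr 1e-5 for LLaMA-7B (5e-5 for the Flan-T5 models); batch size 4.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
batch_size = 4
num_epochs = 20  # 20 for GSM8K / Commonsense QA, 40 for ASDiv / SVAMP
alpha = 0.5      # weight balancing the answer prediction and mask ratio losses

def combined_loss(answer_loss: torch.Tensor, mask_ratio_loss: torch.Tensor) -> torch.Tensor:
    # Assumed weighted-sum form; the paper only states that alpha balances the two terms.
    return answer_loss + alpha * mask_ratio_loss
```

For the Flan-T5 student models, which are encoder-decoder architectures, the analogous setup would load the model with AutoModelForSeq2SeqLM and use LoRA rank 128 with a 5e-5 learning rate, as reported in the Experiment Setup row.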