Easy-to-Hard Generalization: Scalable Alignment Beyond Human Supervision

Authors: Zhiqing Sun, Longhui Yu, Yikang Shen, Weiyang Liu, Yiming Yang, Sean Welleck, Chuang Gan

NeurIPS 2024

Reproducibility Variable Result LLM Response
Research Type Experimental Our key insight is that an evaluator (reward model) trained on supervision for easier tasks can be effectively used to score candidate solutions for harder tasks, and hence facilitate easy-to-hard generalization across different levels of task difficulty. Based on this insight, we propose a novel approach to scalable alignment, which first trains the (process-supervised) reward models on easy problems (e.g., level 1-3), and then uses them to evaluate the performance of policy models on hard problems. We show that such easy-to-hard generalization from evaluators can enable easy-to-hard generalization in generators either through re-ranking or reinforcement learning (RL). Notably, our process-supervised 7b RL model and 34b model (reranking@1024) achieve accuracies of 34.0% and 52.5% on MATH500, respectively, despite only using human supervision on easy problems. Our approach suggests a promising path toward AI systems that advance beyond the frontier of human supervision.
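The re-ranking scheme the abstract describes (an evaluator trained only on easy problems selecting among candidate solutions to hard problems) can be sketched in a few lines. This is a minimal illustrative sketch, not the paper's implementation; `generate_candidates` and `reward_model` are hypothetical stand-ins for the policy model's sampler and the easy-trained reward model.

```python
def rerank_best_of_n(problem, generate_candidates, reward_model, n=1024):
    """Best-of-n re-ranking: sample n candidate solutions for a (hard)
    problem, then return the one the evaluator scores highest.
    The evaluator may have been trained only on easier problems."""
    candidates = generate_candidates(problem, n)
    return max(candidates, key=lambda sol: reward_model(problem, sol))
```

The paper's "reranking@1024" result corresponds to calling such a routine with n=1024 candidates per problem.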
Researcher Affiliation Collaboration Zhiqing Sun1, Longhui Yu2, Yikang Shen3, Weiyang Liu4,5, Yiming Yang1, Sean Welleck1, Chuang Gan3,6; 1Carnegie Mellon University, 2Peking University, 3MIT-IBM Watson AI Lab, 4University of Cambridge, 5Max Planck Institute for Intelligent Systems, 6UMass Amherst
Pseudocode No The paper does not contain any figures, blocks, or sections explicitly labeled "Pseudocode" or "Algorithm".
Open Source Code Yes Code: Edward-Sun/easy-to-hard
Open Datasets Yes Dataset MATH [32] is a dataset of 12,500 challenging competition mathematics problems, where 7,500 of them are training problems and 5,000 were originally used for testing. Following Lightman et al. [40] and Wang et al. [75], we use the identical subset of 500 representative problems (i.e., MATH500) as our test set, uniformly sample another 500 problems for validation across all five difficulty levels, and use the remaining 4,000 MATH test split problems combined with the original 7,500 MATH training split problems as our training set.
Dataset Splits Yes Following Lightman et al. [40] and Wang et al. [75], we use the identical subset of 500 representative problems (i.e., MATH500) as our test set, uniformly sample another 500 problems for validation across all five difficulty levels, and use the remaining 4,000 MATH test split problems combined with the original 7,500 MATH training split problems as our training set.
Hardware Specification No The paper mentions using "Llemma ... 7b and 34b variants" for experiments, which refers to model sizes. It does not provide specific details on the hardware used, such as GPU models, CPU types, or memory specifications.
Software Dependencies No The paper mentions that "Llemma is a large language model for mathematics [6], which is continually pretrained from Code Llama [58] / LLaMA-2 [72]". While these are software components, it does not specify other common software dependencies or libraries with version numbers (e.g., PyTorch 1.x, TensorFlow 2.x, Python 3.x) typically needed for full reproducibility.
Experiment Setup Yes For the PRM800K dataset [40], the SFT model is trained using steps that are labeled as correct. ... The PRMs are trained on the corresponding released dataset [40, 75]. For generating solutions to train ORMs, we sample 32 solutions for each question from the language model using top-K sampling with K=20 and a temperature of 0.7. We also ensure that the ratio between positive and negative samples for each question is between 1:3 and 3:1. See Table 4 for the training hyper-parameters. We use full fine-tuning for all SFT/RM training.

Table 4: Hyper-parameters in our SFT/RM training jobs

                 PRM800K                       MetaMath
                 SFT     PRM     ORM     OPRM  SFT     PRM
Learning rate    2e-5    2e-5    2e-5    2e-5  8e-6    2e-5
Epochs           3       2       2       2     3       2
Batch size       128     128     128     128   128     128
Max seq len      768     768     1024    1024  1024    768
Dtype            BF16    BF16    BF16    BF16  FP32    BF16
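The ORM data-balancing rule quoted above (32 sampled solutions per question, kept only when the positive:negative label ratio falls between 1:3 and 3:1) can be sketched as below. This is an illustrative sketch, not the paper's code; `sample_solutions` and `is_correct` are hypothetical placeholders for the policy sampler and the answer checker.

```python
def filter_orm_data(questions, sample_solutions, is_correct, n=32):
    """Build ORM training data: sample n solutions per question, label
    each as correct/incorrect, and keep the question only if the
    positive:negative ratio lies between 1:3 and 3:1."""
    kept = {}
    for q in questions:
        sols = sample_solutions(q, n)  # e.g. top-K=20, temperature 0.7
        labeled = [(s, is_correct(q, s)) for s in sols]
        pos = sum(1 for _, ok in labeled if ok)
        neg = len(labeled) - pos
        if pos > 0 and neg > 0 and 1 / 3 <= pos / neg <= 3:
            kept[q] = labeled
    return kept
```

Discarding questions with highly imbalanced labels keeps the ORM's per-question training signal informative: a question whose samples are almost all correct (or almost all wrong) contributes little contrast between good and bad solutions.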