Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].

RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment

Authors: Hanze Dong, Wei Xiong, Deepanshu Goyal, Yihan Zhang, Winnie Chow, Rui Pan, Shizhe Diao, Jipeng Zhang, Kashun Shum, Tong Zhang

TMLR 2023 | Venue PDF | LLM Run Details

Reproducibility Variable Result LLM Response
Research Type Experimental Our studies show that RAFT can effectively improve the model performance in both reward learning and other automated metrics in both large language models and diffusion models. We perform the experiment with the LLaMA-7B model (Touvron et al., 2023) and the HH-RLHF (Helpful and Harmless) dataset (Bai et al., 2022a), which is collected for model alignment according to human preferences.
Researcher Affiliation Academia The Hong Kong University of Science and Technology University of Illinois Urbana-Champaign
Pseudocode No The learning process of RAFT can be divided into three steps. For each stage t + 1: Step 1: Data collection. We first sample a batch of prompts D_t = {x_1^t, ..., x_b^t} from X and generate y_1, ..., y_K ~ p_{G_t}^{1/λ}(· | w_t, x_i^t) for each x_i^t ∈ D_t, where λ is the parameter to control the output diversity. Step 2: Data ranking. In this step, we first use the reward model to compute {r(x, y_1), ..., r(x, y_K)} for each x ∈ D_t. Then, we simply take y := argmax_{y_j ∈ {y_1, ..., y_K}} r(x, y_j), go through all the b prompts, and collect a subset B of size b. Step 3: Model fine-tuning. Then, we simply fine-tune the current model on B and the next stage begins.
Open Source Code Yes The code will be publicly available on GitHub in the camera-ready version. LMFlow (Diao et al., 2023) (https://github.com/OptimalScale/LMFlow) is a public package, which aims to provide a general and easy-to-use framework for researchers and engineers to finetune/align models. To run example code of RAFT alignment in LMFlow, one may simply execute: ./scripts/run_raft_align.sh ... We also added the diffusion demos in the LMFlow package.
Open Datasets Yes We perform the experiment with the LLaMA-7B model (Touvron et al., 2023) and the HH-RLHF (Helpful and Harmless) dataset (Bai et al., 2022a; https://huggingface.co/datasets/Dahoas/full-hh-rlhf). We use the CIFAR-10 labels as our prompts (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck).
Dataset Splits Yes The dataset consists of 112K training samples and 12.5K test samples. This results in a prompt set of 82,147 samples (originally 112K). For a fair comparison, we keep the test configuration for all methods and report the metrics on a held-out test set of size 4608.
Hardware Specification Yes All the experiments are conducted using 8 A40 (48G) GPUs with 600G RAM and half-precision training (bf16). All our experiments are performed on NVIDIA A100 (40G) and A40 (48G). All the experiments of diffusion models are performed with NVIDIA 3090 GPUs with 256G RAM.
Software Dependencies No We implement the PPO algorithm with the TRL package, which requires loading multiple LLMs concurrently... Following TRL, we use Parameter-Efficient Fine-Tuning (PEFT) in our experiment with the peft library, and perform Low-Rank Adaptation (LoRA) (Hu et al., 2021) for PPO in all the experiments. LMFlow (Diao et al., 2023) (https://github.com/OptimalScale/LMFlow) is a public package, which aims to provide a general and easy-to-use framework for researchers and engineers to finetune/align models.
Experiment Setup Yes Hyper-parameters. For RAFT, we fix the batch size b as 2048 and the learning rate for SFT as 2 × 10^-5. For each SFT stage, we train for 2 epochs and use a linear decay scheduler. Other hyper-parameters will be specified for each experiment. For PPO, we adopt most of the parameter settings in the TRL package... The full list of hyper-parameters can be found in Appendix D.
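The three-step procedure quoted under Pseudocode above (collect, rank by reward, fine-tune) can be sketched in plain Python. Everything here is an illustrative stand-in, not the paper's implementation: `generate`, `reward`, and `fine_tune` are hypothetical callables for the generator model, the reward model, and one SFT step, and the paper's batch size b = 2048 and 1/λ sampling temperature are elided.

```python
import random

def raft_stage(prompts, generate, reward, fine_tune, b=4, K=8):
    """One stage of RAFT (Reward rAnked FineTuning), as a minimal sketch.

    generate(x, K) -> list of K candidate responses for prompt x
    reward(x, y)   -> scalar reward of response y to prompt x
    fine_tune(B)   -> fine-tune the current model on the ranked subset B
    """
    # Step 1: Data collection - sample a batch of b prompts.
    batch_prompts = random.sample(prompts, b)

    # Step 2: Data ranking - keep the highest-reward response per prompt.
    B = []
    for x in batch_prompts:
        candidates = generate(x, K)
        y_star = max(candidates, key=lambda y: reward(x, y))
        B.append((x, y_star))

    # Step 3: Model fine-tuning - SFT on the subset B; next stage begins.
    return fine_tune(B)
```

In a real run the loop over stages repeats these three steps, with each stage's fine-tuned model becoming the generator for the next stage's sampling.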
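The Software Dependencies row notes that LoRA (Hu et al., 2021) is applied, via the peft library, to the PPO baseline. As a reminder of what the adaptation computes, here is a minimal pure-Python sketch of a LoRA forward pass; `matvec` and `lora_forward` are illustrative names of our own, not peft APIs.

```python
def matvec(M, v):
    # Plain matrix-vector product over nested lists.
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def lora_forward(W, A, B, x, alpha=16.0, r=4):
    # LoRA replaces y = W x with y = W x + (alpha / r) * B (A x):
    # only the low-rank factors A (r x d_in) and B (d_out x r) are
    # trained, while the pretrained weight W stays frozen.
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b_i + scale * d_i for b_i, d_i in zip(base, delta)]
```

With B initialized to zeros (the usual choice), the adapted layer starts out identical to the frozen layer, which is why LoRA can be bolted onto a pretrained model without perturbing it.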
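The Experiment Setup row fixes the SFT learning rate at 2 × 10^-5 with a linear decay scheduler over each 2-epoch stage. A minimal sketch of that schedule follows; the helper name and the decay-to-zero endpoint are our assumptions (frameworks differ on the final floor), while the numeric values come from the quoted text.

```python
# Hyper-parameters quoted from the paper's RAFT setup.
RAFT_SFT_HPARAMS = {
    "batch_size_b": 2048,   # prompts collected per RAFT stage
    "learning_rate": 2e-5,  # SFT learning rate
    "epochs_per_stage": 2,  # SFT epochs per stage
}

def linear_decay_lr(step, total_steps, base_lr=2e-5):
    # Linearly anneal the learning rate from base_lr down to 0
    # over total_steps optimizer steps.
    return base_lr * max(0.0, 1.0 - step / total_steps)
```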