Accelerating Blockwise Parallel Language Models with Draft Refinement
Authors: Taehyeon Kim, Ananda Theertha Suresh, Kishore Papineni, Michael D Riley, Sanjiv Kumar, Adrian Benton
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experiments demonstrate that by refining block drafts of open-sourced Vicuna and Medusa LLMs, the mean accepted token length is increased by 5-25% relative. This results in over a 3x speedup in wall clock time compared to standard autoregressive decoding in open-source 7B and 13B LLMs. |
| Researcher Affiliation | Collaboration | Taehyeon Kim (KAIST AI); Ananda Theertha Suresh, Kishore Papineni, Michael Riley, Sanjiv Kumar, Adrian Benton (Google Research) |
| Pseudocode | Yes | Algorithm 1: Blockwise parallel decoding (BPD) and Algorithm 2: Local rescoring via neural models (hedged sketches of both follow the table) |
| Open Source Code | No | We have not provided open access to the source code used in our experiments. However, we have detailed the data access, architecture, and training processes in the supplemental material to enable replication of our results. |
| Open Datasets | Yes | We investigate the drafts produced by this 1.5B blockwise parallel LM on LAMBADA [33] (language modeling), SQuAD V1 [38] (extractive QA), along with five summarization tasks: XSUM [32], Multi-News [12], SAMSum [14], Newsroom [15] and CNN/Daily Mail [17]. |
| Dataset Splits | No | Checkpoints were selected based on heldout set model performance. Interpolation weight for all rescoring models was tuned for block efficiency on 100 randomly selected examples from the evaluation set for each task, and performance was reported on the remainder of the evaluation set. |
| Hardware Specification | Yes | Model training/inference was run on TPUv3/TPUv4 [20], and implemented in Jax [2]. Additionally, all timings were evaluated on a single NVIDIA A100 80GB GPU with batch size 1. |
| Software Dependencies | No | Model training/inference was run on TPUv3/TPUv4 [20], and implemented in Jax [2]. |
| Experiment Setup | Yes | During pretraining, we use batches of 2048 subword sequences, each 512 tokens in length, amounting to 200B input tokens in total. For downstream tasks, models were finetuned for a maximum of 100K iterations with a batch size of two examples and a maximum sequence length of 2048. The maximum learning rate was fixed to 10^-4 for all runs, with a cosine learning rate schedule (sketched below the table). Dropout was not applied. All n-gram LMs in this work are Katz backoff n-gram LMs [22] fit on the train split of the GPT3 subword-tokenized English C4 corpus with n-gram order {2, 4}. |
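
Algorithm 1 (blockwise parallel decoding) is only referenced above, not reproduced. Below is a minimal sketch of its greedy verification step, under the assumption that the k drafted tokens are checked against the base model's greedy predictions and the longest matching prefix is accepted; the names `accept_longest_prefix` and `verifier_next_token` are illustrative, not from the paper.

```python
def accept_longest_prefix(draft_tokens, verifier_next_token):
    """Greedy verification step of blockwise parallel decoding (sketch).

    draft_tokens: the k tokens proposed by the block prediction heads.
    verifier_next_token: callable taking the draft tokens accepted so far and
        returning the base model's greedy next token; the already-decoded
        context is assumed to be captured inside the callable. In practice all
        k verification predictions come from a single batched forward pass;
        they are written sequentially here for clarity.

    Returns the accepted tokens: the longest matching prefix of the draft plus
    one corrective token from the verifier, so at least one token is emitted
    per decoding iteration.
    """
    accepted = []
    for token in draft_tokens:
        prediction = verifier_next_token(accepted)
        if prediction != token:
            accepted.append(prediction)  # first mismatch: keep the verifier's token
            return accepted
        accepted.append(token)
    accepted.append(verifier_next_token(accepted))  # whole block matched
    return accepted
```

The mean accepted token length reported in the paper roughly corresponds to the average length of `accepted` per decoding iteration; for example, with a toy verifier that always predicts token 7, `accept_longest_prefix([7, 7, 3], lambda prefix: 7)` returns `[7, 7, 7]`, i.e. two drafted tokens plus one corrective token.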
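
Algorithm 2 (local rescoring) and the interpolation weight tuned for block efficiency (see the Dataset Splits row) are likewise only referenced above. The sketch below shows one simplified, greedy form of draft refinement: each head's top candidates are re-ranked by interpolating the head's log-probability with a rescoring LM (e.g., a Katz backoff n-gram LM) conditioned on the partially refined draft. The function names, candidate format, and greedy left-to-right search are assumptions for illustration; the paper's algorithm may search the draft candidates differently.

```python
import math

def refine_block_draft(head_topk, rescore_logprob, lam=0.5):
    """Greedily refine a block draft by interpolated rescoring (sketch).

    head_topk: list of length k; element i holds (token, head_logprob) pairs,
        the top candidates from prediction head i.
    rescore_logprob: callable(prefix_tokens, token) -> float, the rescoring
        LM's log-probability of `token` given the refined draft so far (the
        decoded context is assumed to be captured inside the callable).
    lam: interpolation weight between head and rescoring-LM scores; the paper
        reports tuning this for block efficiency on 100 heldout examples
        per task.
    """
    draft = []
    for candidates in head_topk:
        best_token, best_score = None, -math.inf
        for token, head_lp in candidates:
            score = (1.0 - lam) * head_lp + lam * rescore_logprob(draft, token)
            if score > best_score:
                best_token, best_score = token, score
        draft.append(best_token)
    return draft
```

The refined draft is then passed to the verification step sketched above; a better-scoring draft tends to yield a longer accepted prefix, which is the mechanism behind the reported gains in mean accepted token length.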
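
The training setup states a peak learning rate of 10^-4 with a cosine schedule but does not give the warmup length or final learning-rate floor; the sketch below fills those in with labeled assumptions (linear warmup, decay to zero).

```python
import math

def cosine_lr(step, total_steps, max_lr=1e-4, warmup_steps=0):
    """Cosine learning-rate schedule with an optional linear warmup (sketch).

    max_lr=1e-4 matches the reported peak learning rate; warmup_steps and the
    decay-to-zero floor are assumptions, not taken from the paper.
    """
    if warmup_steps and step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps  # assumed linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))
```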