Training Chain-of-Thought via Latent-Variable Inference
Authors: Du Phan, Matthew Douglas Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous
NeurIPS 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Applying our technique to GSM8K and the tasks in BIG-Bench Hard, we find that this MCMC-EM fine-tuning technique typically improves the model's accuracy on held-out examples more than STaR or prompt-tuning with or without CoT. |
| Researcher Affiliation | Industry | Du Phan, Matthew D. Hoffman, David Dohan, Sholto Douglas, Tuan Anh Le, Aaron Parisi, Pavel Sountsov, Charles Sutton, Sharad Vikram, Rif A. Saurous (Google) |
| Pseudocode | Yes | Algorithm 1 outlines the method. (A simplified sketch of the training loop appears below the table.) |
| Open Source Code | Yes | A notebook with a reference implementation can be found at https://github.com/google-research/cascades/tree/main/cascades/examples/notebooks/trice.ipynb. |
| Open Datasets | Yes | We evaluate TRICE on the GSM8K (Cobbe et al., 2021) dataset and the 27 BIG-Bench Hard (BBH) tasks (Suzgun et al., 2022b). |
| Dataset Splits | Yes | On each BBH task, we split the examples into 60% train and 40% test sets. For GSM8K, we use the standard 7473-example training set and 1319-example test set. All methods are evaluated against the same validation sets. (A minimal split sketch appears below the table.) |
| Hardware Specification | Yes | All experiments were run on TPU v4 and v5e chips (Jouppi et al., 2023). |
| Software Dependencies | No | The paper mentions specific models (PaLM 2, Flan) and optimizers (Adam) but does not provide version numbers for any software libraries or frameworks used. |
| Experiment Setup | Yes | For all BBH tasks, we run TRICE for 500 steps with batch size M = 32 and do not use subsampling (i.e., compute L = 64 gradients per batch). We use the Adam optimizer (Kingma & Ba, 2015) with an initial learning rate of 1.0 and a cosine decay schedule (Loshchilov & Hutter, 2017) that reduces the learning rate by 10x over the first 450 steps. (A schedule sketch appears below the table.) |
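
The Pseudocode row refers to the paper's Algorithm 1. The following is only a rough, greatly simplified Python sketch of the core MCMC-EM idea, not the authors' reference implementation (see the linked trice.ipynb notebook for that): propose a fresh chain-of-thought rationale per question, accept it into a per-example memory if it yields the correct answer, and take a gradient step on the log-likelihood of the currently accepted rationales. The helpers `sample_rationale`, `rationale_is_correct`, and `log_prob_grad_step`, and the `example` record with `.id`, `.question`, and `.answer` fields, are all hypothetical stand-ins.

```python
def trice_step(model, batch, memory):
    """One greatly simplified MCMC-EM step in the spirit of TRICE.

    `memory` maps example ids to the currently accepted rationale.
    All model-facing helpers below are hypothetical stand-ins.
    """
    accepted = []
    for example in batch:
        # E-step flavor: propose a fresh chain-of-thought for the question.
        proposal = sample_rationale(model, example.question)
        # Independence-sampler accept rule: keep the proposal only if its
        # final answer matches the ground truth; otherwise retain the old
        # rationale for this example.
        if rationale_is_correct(proposal, example.answer):
            memory[example.id] = proposal
        if example.id in memory:
            accepted.append((example.question, memory[example.id]))
    # M-step flavor: one gradient step on log p(rationale | question)
    # over the currently accepted rationales.
    if accepted:
        log_prob_grad_step(model, accepted)
    return memory
```

The paper's Algorithm 1 includes further refinements (e.g., the gradient subsampling mentioned in the Experiment Setup row) that this sketch omits.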
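
The Dataset Splits row quotes a 60/40 per-task split for BBH. A minimal sketch of such a split follows; the paper does not specify the shuffling or seeding, so the `seed` parameter here is purely illustrative.

```python
import random

def split_bbh_task(examples, train_frac=0.6, seed=0):
    """Split one BBH task's examples into train/test (60/40 by default)."""
    rng = random.Random(seed)
    shuffled = list(examples)
    rng.shuffle(shuffled)
    n_train = int(len(shuffled) * train_frac)
    return shuffled[:n_train], shuffled[n_train:]
```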
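
The Experiment Setup row specifies Adam with an initial learning rate of 1.0 and a cosine decay that reduces the learning rate 10x over the first 450 of 500 steps. The paper does not name an optimizer library; assuming optax (a natural fit for the Google/JAX context), that schedule could look like this:

```python
import optax

# Cosine decay from 1.0 down to 0.1 (a 10x reduction) over 450 steps;
# `alpha` is the final-to-initial learning-rate ratio, and optax holds
# the final value for the remaining steps.
schedule = optax.cosine_decay_schedule(
    init_value=1.0,
    decay_steps=450,
    alpha=0.1,
)
optimizer = optax.adam(learning_rate=schedule)
```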