Language Model Cascades: Token-Level Uncertainty And Beyond
Authors: Neha Gupta, Harikrishna Narasimhan, Wittawat Jitkrittum, Ankit Singh Rawat, Aditya Krishna Menon, Sanjiv Kumar
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We begin by examining the natural extension of predicted class uncertainty to generative LM tasks, namely, the predicted sequence uncertainty. We show that this measure suffers from the length bias problem, either over- or under-emphasizing outputs based on their lengths. This is because LMs produce a sequence of uncertainty values, one for each output token; and moreover, the number of output tokens is variable across examples. To mitigate this issue, we propose to exploit the richer token-level uncertainty information implicit in generative LMs. We argue that naïve predicted sequence uncertainty corresponds to a simple aggregation of these uncertainties. By contrast, we show that incorporating token-level uncertainty through learned post-hoc deferral rules can significantly outperform such simple aggregation strategies, via experiments on a range of natural language benchmarks with FLAN-T5 models. We further show that incorporating embeddings from the smaller model and intermediate layers of the larger model can give an additional boost in the overall cost-quality tradeoff. (A sketch of this token-level aggregation appears after the table.) |
| Researcher Affiliation | Industry | Neha Gupta, Harikrishna Narasimhan, Wittawat Jitkrittum, Ankit Singh Rawat, Aditya Krishna Menon, Sanjiv Kumar; Google Research, New York; {nehagup, hnarasimhan, wittawat, ankitsrawat, adityakmenon, sanjivk}@google.com |
| Pseudocode | No | The paper does not contain structured pseudocode or algorithm blocks. |
| Open Source Code | No | The paper does not provide any concrete access to source code for the described methodology. |
| Open Datasets | Yes | Datasets. In the body, we show deferral curves for three different NLP tasks: MNLI (Williams et al., 2018), a multi-class classification problem; TriviaQA (Joshi et al., 2017), a closed-book question answering problem; and WMT DE→FR, a translation problem. We report AUC-DF numbers for an expanded dataset pool. These span Classification (IMDb (Maas et al., 2011), SuperGLUE (Wang et al., 2019a), MNLI (Williams et al., 2018), ANLI (Nie et al., 2020)); Question answering (TriviaQA (Joshi et al., 2017), Natural QA (Kwiatkowski et al., 2019), TyDiQA {ID, SW, FI} (Clark et al., 2020)); Reading comprehension (Lambada (Paperno et al., 2016), SQuAD (Rajpurkar et al., 2016)); Translation (WMT 14: EN→FR (Bojar et al., 2014), WMT 19: DE→FR (Foundation), and WMT 14: FR→EN (Bojar et al., 2014)); and Common-sense reasoning (Winogrande (Sakaguchi et al., 2021)). |
| Dataset Splits | Yes | We split the training set into train (0.8 fraction) and validation (0.2 fraction) sets, and use the validation or test split for reporting numbers. The details of all the datasets are included in Table 2. (A sketch of this split appears after the table.) |
| Hardware Specification | No | The paper does not provide specific hardware details (e.g., CPU, GPU models, or memory) used for running its experiments. |
| Software Dependencies | No | The paper mentions 'FLAN-T5 models' but does not provide specific version numbers for any software dependencies like programming languages, libraries, or frameworks. |
| Experiment Setup | Yes | We use a Multi-Layer Perceptron (MLP) with 5 layers and 32 hidden units per layer, with batch normalization layers, for training post-hoc deferral rules on Quantile features. We train for 200 epochs and early stop based on the AUC of predicting the target label on the validation set for classification tasks, and on the validation regression loss for regression tasks. Since the embeddings are high dimensional, we use an MLP with 2 layers and 8 hidden units to prevent overfitting when training on embeddings. We train for the same number of epochs with early stopping. We use the Adam optimizer with a learning rate of 10⁻⁵ and a batch size of 16. We train over 5 random runs and report means and standard deviations over the runs. (A sketch of this training setup appears after the table.) |
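
The Research Type row contrasts a single sequence-level uncertainty score with richer token-level aggregation. Below is a minimal sketch of that contrast, assuming per-token log-probabilities are available from the smaller model; the function names, the choice of five quantiles, and the toy inputs are illustrative rather than the paper's exact feature construction.

```python
import numpy as np


def sequence_uncertainty(token_logprobs, length_normalize=False):
    """Naive sequence-level score: the (optionally length-normalized) sum of
    per-token log-probabilities of the smaller model's decoded output."""
    total = float(np.sum(token_logprobs))
    return total / len(token_logprobs) if length_normalize else total


def quantile_features(token_logprobs, num_quantiles=5):
    """Token-level summary: quantiles of the per-token log-probabilities,
    intended as input features for a learned post-hoc deferral rule rather
    than a single hand-picked scalar."""
    qs = np.linspace(0.0, 1.0, num_quantiles)
    return np.quantile(token_logprobs, qs)


# Toy comparison: a short output vs. a longer output whose tokens are all
# roughly as confident as the short one's.
short_out = np.log([0.90, 0.85])
long_out = np.log([0.90, 0.88, 0.92, 0.90, 0.91, 0.89])
print(sequence_uncertainty(short_out), sequence_uncertainty(long_out))
print(quantile_features(long_out))
```

In this toy example the longer output receives a worse summed score purely because it has more tokens, even though each of its tokens is about as confident as in the short output; the quantile summary preserves the per-token picture instead of collapsing it.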
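
For the Dataset Splits row, a minimal sketch of the reported 0.8/0.2 split of the original training set; the use of scikit-learn's train_test_split, the placeholder dataset size, and the fixed seed are assumptions, not details from the paper.

```python
from sklearn.model_selection import train_test_split

# Placeholder size of the original training set for one of the tasks.
num_examples = 10_000
indices = list(range(num_examples))

# 0.8 / 0.2 split of the training set into train and validation portions for
# fitting the deferral rule; the fixed seed is an illustrative choice.
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=0)
```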
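
For the Experiment Setup row, a sketch of a post-hoc deferral MLP and its training loop consistent with the reported hyperparameters (5 layers, 32 hidden units, batch normalization, Adam with learning rate 10⁻⁵, batch size 16, up to 200 epochs with early stopping). PyTorch is an assumed framework; the binary deferral target, the BCE loss, the patience value, and early stopping on validation loss (rather than validation AUC, which the paper uses for classification tasks) are illustrative choices.

```python
import copy

import torch
import torch.nn as nn


def make_deferral_mlp(in_dim, hidden=32, depth=5):
    """MLP with batch normalization, matching the reported 5 layers of 32
    hidden units for quantile features (the paper reports a smaller 2-layer,
    8-unit MLP when the inputs are high-dimensional embeddings)."""
    layers, d = [], in_dim
    for _ in range(depth - 1):
        layers += [nn.Linear(d, hidden), nn.BatchNorm1d(hidden), nn.ReLU()]
        d = hidden
    layers.append(nn.Linear(d, 1))  # scalar deferral score
    return nn.Sequential(*layers)


def train_deferral_rule(model, train_loader, val_loader, epochs=200, patience=10):
    """Adam with learning rate 1e-5, up to 200 epochs, early stopping on the
    validation loss (the patience value is an assumption)."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-5)
    loss_fn = nn.BCEWithLogitsLoss()  # binary "should defer" target (illustrative)
    best_val, best_state, bad = float("inf"), None, 0
    for _ in range(epochs):
        model.train()
        for x, y in train_loader:  # loaders built with batch_size=16
            opt.zero_grad()
            loss = loss_fn(model(x).squeeze(-1), y.float())
            loss.backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            val = sum(loss_fn(model(x).squeeze(-1), y.float()).item()
                      for x, y in val_loader)
        if val < best_val:
            best_val, bad = val, 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience:
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return model
```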