Understanding Catastrophic Forgetting in Language Models via Implicit Inference

Authors: Suhas Kotha, Jacob Mitchell Springer, Aditi Raghunathan

ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In a simplified scenario, we demonstrate that improving performance on tasks within the fine-tuning data distribution comes at the expense of capabilities on other tasks. We hypothesize that language models implicitly infer the task of the prompt and that fine-tuning skews this inference towards tasks in the fine-tuning distribution. To test this, we propose Conjugate Prompting... we find that this recovers some of the pretraining capabilities in our synthetic setup.
Researcher Affiliation | Academia | Suhas Kotha, Jacob Mitchell Springer & Aditi Raghunathan, Carnegie Mellon University, {suhask, jspringe, aditirag}@cs.cmu.edu
Pseudocode | No | The paper describes methods and strategies but does not include any formal pseudocode or algorithm blocks.
Open Source Code | Yes | Code is available at github.com/kothasuhas/understanding-forgetting.
Open Datasets | Yes | XNLI benchmark (Conneau et al., 2018), a multilingual version of MNLI (Williams et al., 2018) from GLUE (Wang et al., 2019).
Dataset Splits | No | The paper mentions evaluating on 400 samples and 2490 test samples, but it does not give explicit train/validation/test splits (as counts or percentages) for any dataset used for training or fine-tuning.
Hardware Specification | No | We thank Huan Zhang for providing compute for the linear regression experiments.
Software Dependencies | No | Our code is based on the wonderful code provided by Garg et al. (2023) at https://github.com/dtsip/in-context-learning.
Experiment Setup | Yes | Unless otherwise specified, we train with 64 tasks in the discrete distribution, σ = 1 noise, exemplar count uniformly sampled from 0 to 40, weights sampled from the Gaussian prior with parameter τ = 1, and learning rate 0.0001. For our model, we use a 22.4 million parameter GPT-2-style transformer. For more experimental details, refer to Appendix C.8.
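The Research Type row above summarizes Conjugate Prompting: transform the prompt so it looks less like the fine-tuning distribution, solve the task in the transformed space, then invert the transform. Below is a minimal sketch of that loop, assuming translation into a pivot language as the transform (as in the paper's XNLI experiments); `translate` and `query_model` are hypothetical stand-ins for whatever translation tool and model API are actually used.

```python
# Sketch of conjugate prompting: T(prompt) -> model -> T^-1(answer).
# `translate` and `query_model` are hypothetical placeholders, not the
# paper's actual implementation.

def translate(text: str, src: str, dst: str) -> str:
    """Hypothetical translation helper (e.g., an MT model or API)."""
    raise NotImplementedError

def query_model(prompt: str) -> str:
    """Hypothetical call to the fine-tuned language model."""
    raise NotImplementedError

def conjugate_prompt(prompt: str, pivot_lang: str = "fr") -> str:
    # 1. Apply the transform T: move the prompt away from the
    #    fine-tuning distribution (e.g., English instruction data).
    transformed = translate(prompt, src="en", dst=pivot_lang)
    # 2. Solve the task in the transformed space, where the model is
    #    more likely to fall back on its pretraining behavior.
    answer = query_model(transformed)
    # 3. Apply the inverse transform T^-1 to map the answer back.
    return translate(answer, src=pivot_lang, dst="en")
```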
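For the Open Datasets row, both XNLI and MNLI are publicly distributed; a hedged example of loading them through the Hugging Face `datasets` library is shown below. This is one common access path, not necessarily how the paper's repository loads them.

```python
# Loading the open datasets named in the table via Hugging Face
# `datasets` (an assumed access path, not the paper's own loader).
from datasets import load_dataset

xnli_fr = load_dataset("xnli", "fr")   # French portion of XNLI
mnli = load_dataset("glue", "mnli")    # English MNLI from GLUE
print(xnli_fr["test"][0])              # fields: premise, hypothesis, label
```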
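The Experiment Setup row describes a synthetic in-context linear regression setup in the style of Garg et al. (2023): weights drawn from a Gaussian prior with τ = 1, a discrete fine-tuning distribution of 64 tasks, noise σ = 1, and 0 to 40 exemplars per prompt. The sketch below generates such prompts under those stated hyperparameters; the input dimension and prompt formatting are assumptions (see Appendix C.8 of the paper for the actual details).

```python
# Minimal sketch of the synthetic in-context regression data generator
# implied by the Experiment Setup row. DIM is an assumption.
import numpy as np

rng = np.random.default_rng(0)

DIM = 8            # assumed input dimension
TAU = 1.0          # std of the Gaussian prior over task weights
SIGMA = 1.0        # observation noise
NUM_TASKS = 64     # size of the discrete (fine-tuning) task pool
MAX_EXEMPLARS = 40

# Discrete fine-tuning distribution: a fixed pool of 64 weight vectors.
task_pool = rng.normal(0.0, TAU, size=(NUM_TASKS, DIM))

def sample_prompt(from_discrete_pool: bool):
    """Build one in-context regression prompt: k exemplars plus a query point."""
    if from_discrete_pool:
        w = task_pool[rng.integers(NUM_TASKS)]        # fine-tuning task
    else:
        w = rng.normal(0.0, TAU, size=DIM)            # pretraining prior
    k = rng.integers(0, MAX_EXEMPLARS + 1)            # exemplar count in [0, 40]
    xs = rng.normal(size=(k + 1, DIM))                # exemplars + query input
    ys = xs @ w + rng.normal(0.0, SIGMA, size=k + 1)  # noisy labels
    return xs, ys, w
```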