Induced Model Matching: Restricted Models Help Train Full-Featured Models

Authors: Usama Muneeb, Mesrob I. Ohannessian

NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experimentally, we first motivate IMM using logistic regression as a toy example. We then explore it in language modeling, the application that initially inspired it, and demonstrate it on both LSTM and transformer full models, using bigrams as restricted models. We lastly give a simple RL example, which shows that POMDP policies can help learn better MDP policies. (Section 7: Experimental Results)
Researcher Affiliation | Academia | Usama Muneeb, Electrical and Computer Engineering, University of Illinois Chicago, umunee2@uic.edu; Mesrob I. Ohannessian, Electrical and Computer Engineering, University of Illinois Chicago, mesrob@uic.edu
Pseudocode | Yes | Algorithm 1: Sampled IMM with SGD for a Model Q with parameters W. Algorithm 2: Serialized IMM with SGD for a Model Q with parameters W. (A hedged sketch of a sampled-IMM-style update appears after this table.)
Open Source Code | Yes | Code is available at https://github.com/uicdice/imm-logistic-regression, https://github.com/uicdice/imm-language-modeling, and https://github.com/uicdice/imm-reinforce.
Open Datasets | Yes | Table 1: Perplexity for an LSTM Language Model using the Penn Tree Bank dataset. For the BERT experiments, we used k = 5. IMM through reintroduced MLM loss in BERT: fine-tuning BERT on GLUE differs from LSTM on PTB in two ways. GLUE (Wang et al., 2018) datasets (used in the original BERT paper).
Dataset Splits | Yes | For training and measuring validation perplexity, L = 35 unroll positions are used. To determine a good schedule for λ as a function of dataset size n, we sweep the ratio λ/(1+λ) ... for a range of dataset sizes, from a minimum of 2 to a maximum of 50. (A short ratio-to-λ conversion sketch appears after the table.)
Hardware Specification | Yes | In our experiments, it was problematic for the GPU (NVIDIA V100 with 32 GB memory) to perform backpropagation for k > 6 for the LSTM RNN.
Software Dependencies | No | The paper mentions using existing codebases (e.g., Google's BERT code, the code of Xie et al. (2017), and the POMDPs.jl package) but does not provide specific version numbers for these or other software dependencies.
Experiment Setup | Yes | For training and measuring validation perplexity, L = 35 unroll positions are used. We report perplexity values using k-sampled IMM with k = 10. For the LSTM experiments, there was not much gain going from k = 10 to k = 20, so we settled for k = 10 (and j = 5). For the BERT experiments, we used k = 5. We set the per-epoch observation length to 50.
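The Pseudocode row above refers to Algorithm 1 (Sampled IMM with SGD). Below is a minimal, hedged sketch of what a sampled-IMM-style SGD step could look like in the logistic-regression toy setting: a cross-entropy base loss plus a λ-weighted KL matching penalty between a restricted model and an induced model approximated with k samples. The function name, the resampling scheme for the extra features, and the exact form of the penalty are illustrative assumptions, not a transcription of the paper's Algorithm 1.

```python
# Hedged sketch of a sampled-IMM-style SGD step (logistic-regression toy setting).
# Assumptions: the full model sees (x_restricted, x_extra), the restricted model sees
# only x_restricted, and the induced model is approximated by averaging the full
# model's predictions over k resampled values of x_extra.
import torch
import torch.nn.functional as F

def sampled_imm_step(model, restricted_probs, x_restricted, x_extra, y,
                     optimizer, lam=1.0, k=10):
    """One SGD step on: cross-entropy loss + lam * KL(restricted || induced)."""
    # Standard supervised loss of the full model on the observed features.
    logits = model(torch.cat([x_restricted, x_extra], dim=-1))
    ce_loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), y)

    # Approximate the induced model: average the full model's predicted probability
    # over k resampled values of the extra features (simple bootstrap resampling here).
    idx = torch.randint(0, x_extra.shape[0], (k, x_extra.shape[0]))
    sampled_extra = x_extra[idx]                                   # (k, batch, d_extra)
    tiled_restricted = x_restricted.unsqueeze(0).expand(k, -1, -1)  # (k, batch, d_r)
    induced_logits = model(torch.cat([tiled_restricted, sampled_extra], dim=-1))
    induced_p = torch.sigmoid(induced_logits.squeeze(-1)).mean(dim=0).clamp(1e-6, 1 - 1e-6)

    # Matching penalty: Bernoulli KL between restricted-model and induced-model probabilities.
    p_r = restricted_probs.clamp(1e-6, 1 - 1e-6)
    kl = (p_r * (p_r / induced_p).log()
          + (1 - p_r) * ((1 - p_r) / (1 - induced_p)).log()).mean()

    loss = ce_loss + lam * kl
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

In this sketch, `model` could be, e.g., `torch.nn.Linear(d_r + d_e, 1)`, `y` a float tensor of 0/1 labels, and `restricted_probs` the predictions of a separately trained restricted model that sees only `x_restricted`; `lam` plays the role of the IMM weight λ and `k` the number of samples.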
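The Dataset Splits row mentions sweeping the ratio λ/(1+λ) across dataset sizes to pick a λ schedule. The following short sketch only illustrates that reparameterization (the ratio grid, dataset sizes, and the training hook are placeholders, not the paper's values):

```python
# Hedged sketch of a lambda-schedule sweep parameterized by the ratio lambda / (1 + lambda).
import numpy as np

dataset_sizes = np.logspace(2, 5, num=8, dtype=int)   # assumed range of n (placeholder)
ratios = np.linspace(0.05, 0.95, num=10)              # grid over lambda / (1 + lambda)

for n in dataset_sizes:
    for r in ratios:
        lam = r / (1.0 - r)   # invert the ratio to recover lambda
        # train_and_validate(n, lam) would go here; it is a hypothetical hook.
        print(f"n={n}, ratio={r:.2f}, lambda={lam:.3f}")
```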