Induced Model Matching: Restricted Models Help Train Full-Featured Models
Authors: Usama Muneeb, Mesrob I. Ohannessian
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experimentally, we first motivate IMM using logistic regression as a toy example. We then explore it in language modeling, the application that initially inspired it, and demonstrate it on both LSTM and transformer full models, using bigrams as restricted models. We lastly give a simple RL example, which shows that POMDP policies can help learn better MDP policies. (Section 7: Experimental Results) |
| Researcher Affiliation | Academia | Usama Muneeb, Electrical and Computer Engineering, University of Illinois Chicago (umunee2@uic.edu); Mesrob I. Ohannessian, Electrical and Computer Engineering, University of Illinois Chicago (mesrob@uic.edu) |
| Pseudocode | Yes | Algorithm 1: Sampled IMM with SGD for a Model Q with parameters W. Algorithm 2: Serialized IMM with SGD for a Model Q with parameters W. |
| Open Source Code | Yes | Code is available at https://github.com/uicdice/imm-logistic-regression. Code is available at https://github.com/uicdice/imm-language-modeling. Code is available at https://github.com/uicdice/imm-reinforce. |
| Open Datasets | Yes | Table 1: Perplexity for an LSTM Language Model using the Penn Tree Bank dataset. Fine-tuning BERT on GLUE differs from LSTM on PTB in two ways. GLUE (Wang et al., 2018) datasets (used in the original BERT paper). |
| Dataset Splits | Yes | For training and measuring validation perplexity, L = 35 unroll positions are used. To determine a good schedule for λ as a function of dataset size n, we sweep the ratio λ/(1+λ) ... for a range of dataset sizes, from a minimum of 2 to a maximum of 50. |
| Hardware Specification | Yes | In our experiments, it was problematic for the GPU (Nvidia V100 with 32 GB memory) to perform backpropagation for k > 6 for the LSTM RNN. |
| Software Dependencies | No | The paper mentions using existing codebases (e.g., Google's BERT code, Xie et al. (2017) code, POMDPs.jl package) but does not provide specific version numbers for these or other software dependencies. |
| Experiment Setup | Yes | For training and measuring validation perplexity, L = 35 unroll positions are used. We report perplexity values using k-sampled IMM with k = 10. For the LSTM experiments, there was not much gain going from 10 to 20, so we settled for k = 10 (and j = 5). For the BERT experiments, we used k = 5. We set the per-epoch observation length to 50. |
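The Pseudocode row names a sampled IMM procedure combining a full model Q with a restricted (bigram) model, and the Dataset Splits row mentions sweeping the ratio λ/(1+λ). As a rough illustration only (not the paper's actual Algorithm 1, whose details are in the PDF), the general shape of such an objective can be sketched as follows; the function name `sampled_imm_loss`, the λ/(1+λ) weighting of the mixture, and the k-sample KL(Q‖P) estimate are all assumptions made for this sketch:

```python
import numpy as np

def sampled_imm_loss(logits, y, bigram_logprobs, lam=1.0, k=10, rng=None):
    """Hypothetical sketch of a sampled IMM-style objective (illustrative
    only; see the paper's Algorithm 1 for the actual procedure).

    Combines the usual cross-entropy with a k-sample Monte Carlo estimate
    of KL(Q || P_bigram), then weights the mixture so the matching term
    carries weight lam / (1 + lam), mirroring the swept ratio.
    """
    rng = rng or np.random.default_rng(0)
    # log-softmax over the vocabulary axis
    logq = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
    ce = -logq[np.arange(len(y)), y].mean()          # cross-entropy on labels

    q = np.exp(logq)
    match = 0.0
    for i in range(len(y)):
        p = q[i] / q[i].sum()                        # renormalize for sampling
        s = rng.choice(logq.shape[1], size=k, p=p)   # k tokens drawn from Q
        # E_{t ~ Q}[log Q(t) - log P(t)]: k-sample KL(Q || P) estimate
        match += (logq[i, s] - bigram_logprobs[i, s]).mean()
    match /= len(y)

    return (ce + lam * match) / (1.0 + lam)
```

With lam = 0 this reduces to plain cross-entropy, so the sweep over λ/(1+λ) described in the table interpolates between standard training and full matching against the restricted model.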