Meta-learning to Improve Pre-training
Authors: Aniruddh Raghu, Jonathan Lorraine, Simon Kornblith, Matthew McDermott, David K. Duvenaud
NeurIPS 2021
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate that our method improves predictive performance on two real-world domains. First, we optimize high-dimensional task weighting hyperparameters for multitask pre-training on protein-protein interaction graphs and improve AUROC by up to 3.9%. Second, we optimize a data augmentation neural network for self-supervised PT with SimCLR on electrocardiography data and improve AUROC by up to 1.9%. |
| Researcher Affiliation | Collaboration | Aniruddh Raghu (Massachusetts Institute of Technology, araghu@mit.edu); Jonathan Lorraine (University of Toronto); Simon Kornblith (Google Research); Matthew McDermott (Massachusetts Institute of Technology); David Duvenaud (Google Research & University of Toronto) |
| Pseudocode | Yes | Algorithm 1: Gradient-based algorithm to learn meta-parameters. Notation defined in Appendix B, Table 3. (A hedged sketch of this loop appears below the table.) |
| Open Source Code | No | The paper does not contain an explicit statement about releasing code or a link to a code repository for the methodology described in this paper. |
| Open Datasets | Yes | We consider the transfer learning benchmark introduced in [28], where the prediction problem at both PT and FT is multitask binary classification: predicting the presence/absence of specific protein functions ($y$) given a Protein-Protein Interaction (PPI) network as input (represented as a graph $x$). The PT dataset has pairs $\mathcal{D}_{\text{PT}} = \{(x_i, y_i)\}_{i=1}^{\lvert \mathcal{D}_{\text{PT}} \rvert}$, where $y \in \{0, 1\}^{5000}$ characterizes the presence/absence of 5000 particular protein functions. The FT dataset has pairs $\mathcal{D}_{\text{FT}} = \{(x_i, y_i)\}_{i=1}^{\lvert \mathcal{D}_{\text{FT}} \rvert}$, where $y \in \{0, 1\}^{40}$ now characterizes the presence/absence of 40 different protein functions. Further dataset details in Appendix F. |
| Dataset Splits | Yes | Following standard convention, we split $\mathcal{D}_{\text{FT}}$ into two subsets for meta-learning: $\mathcal{D}^{(\text{tr})}_{\text{FT}}$ and $\mathcal{D}^{(\text{val})}_{\text{FT}}$ (independent of any held-out $\mathcal{D}_{\text{FT}}$ testing split), and define the FT data available at meta-PT time as $\mathcal{D}^{(\text{Meta})}_{\text{FT}} = \mathcal{D}^{(\text{tr})}_{\text{FT}} \cup \mathcal{D}^{(\text{val})}_{\text{FT}}$. We use $\mathcal{D}^{(\text{tr})}_{\text{FT}}$ for the computation of $\text{Alg}_{\text{FT}}\big(\theta^{(P)}_{\text{PT}}, \psi^{(0)}_{\text{FT}}; \phi^{(n-1)}\big)$ and $\mathcal{D}^{(\text{val})}_{\text{FT}}$ for the computation of $\partial \mathcal{L}_{\text{FT}}[\theta_{\text{FT}}, \psi_{\text{FT}}] \big/ \partial\,(\theta_{\text{FT}}, \psi_{\text{FT}})$ in Algorithm 1. We use the training and validation splits of the FT dataset $\mathcal{D}_{\text{FT}}$ proposed by the dataset creators [28] for computing the relevant gradient terms. We use the training and validation splits of the FT dataset $\mathcal{D}_{\text{FT}}$ proposed by the dataset creators [64] for computing the relevant gradient terms. |
| Hardware Specification | No | The paper does not provide specific hardware details (exact GPU/CPU models, processor types with speeds, memory amounts, or detailed computer specifications) used for running its experiments. |
| Software Dependencies | No | The paper does not provide specific ancillary software details (e.g., library or solver names with version numbers like Python 3.8, CPLEX 12.4) needed to replicate the experiment. |
| Experiment Setup | Yes | Experimental Details. We use a standardized setup to facilitate comparisons. Following [28], all methods use the Graph Isomorphism Network architecture [69], undergo PT for 100 epochs and FT for 50 epochs, over 5 random seeds, using early stopping based on validation set performance. During FT, we initialize a new FT network head and either FT the whole network or freeze the PT feature extractor and learn the FT head alone (Linear Evaluation [50]); a hedged sketch of this choice follows the table. We report results for the strategy that performed best (full results in the appendix). We consider two experimental scenarios: (1) Full FT Access: provide methods full access to $\mathcal{D}_{\text{PT}}$ and $\mathcal{D}_{\text{FT}}$ at PT time ($\mathcal{D}^{(\text{Meta})}_{\text{FT}} = \mathcal{D}_{\text{FT}}$) and evaluate on the full set of 40 FT tasks; (2) Partial FT Access: limit the number of FT tasks seen at PT time by letting $\mathcal{D}^{(\text{Meta})}_{\text{FT}}$ include only 30 of the 40 FT tasks. At FT time, models are fine-tuned on the held-out 10 tasks not in $\mathcal{D}^{(\text{Meta})}_{\text{FT}}$. We use a 4-fold approach where we leave out 10 of the 40 FT tasks in turn, and examine performance across these 10 held-out tasks, over the folds. |
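
The Pseudocode row points to Algorithm 1, which learns meta-parameters by differentiating an FT validation loss through truncated pre-training and fine-tuning updates, and the Dataset Splits row describes how $\mathcal{D}^{(\text{tr})}_{\text{FT}}$ and $\mathcal{D}^{(\text{val})}_{\text{FT}}$ feed the two sides of that computation. The paper releases no code and names no framework, so the following is only a minimal PyTorch sketch of that loop under simplifying assumptions: a toy linear encoder instead of the GIN, one PT step and one FT step per meta-iteration instead of the paper's longer truncation windows, and sigmoid-squashed per-task loss weights standing in for the meta-parameters $\phi$; all names, shapes, and learning rates are illustrative.

```python
# Hedged sketch of Algorithm 1 (gradient-based meta-parameter learning), not the
# authors' implementation. PyTorch is an assumption; the paper does not state it.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes standing in for the paper's 5000 PT tasks / 40 FT tasks.
d, n_pt_tasks, n_ft_tasks = 16, 50, 4
theta = torch.randn(d, d, requires_grad=True)           # shared encoder weights (theta_PT)
w_pt = torch.randn(d, n_pt_tasks, requires_grad=True)   # PT head
phi = torch.zeros(n_pt_tasks, requires_grad=True)       # meta-parameters: per-task PT loss weights
meta_opt = torch.optim.Adam([phi], lr=1e-2)
inner_lr = 0.1

def encode(x, th):
    # Toy encoder standing in for the Graph Isomorphism Network used in the paper.
    return torch.tanh(x @ th)

for meta_step in range(100):
    # Random minibatches standing in for D_PT, D_FT^(tr), and D_FT^(val).
    x_pt, y_pt = torch.randn(32, d), torch.randint(0, 2, (32, n_pt_tasks)).float()
    x_tr, y_tr = torch.randn(32, d), torch.randint(0, 2, (32, n_ft_tasks)).float()
    x_val, y_val = torch.randn(32, d), torch.randint(0, 2, (32, n_ft_tasks)).float()

    # One PT step on the phi-weighted multitask loss, keeping the graph so the
    # updated encoder weights stay a differentiable function of phi.
    per_task = F.binary_cross_entropy_with_logits(
        encode(x_pt, theta) @ w_pt, y_pt, reduction="none").mean(0)
    pt_loss = (torch.sigmoid(phi) * per_task).sum()
    g_theta, g_wpt = torch.autograd.grad(pt_loss, [theta, w_pt], create_graph=True)
    theta_pt = theta - inner_lr * g_theta                # theta_PT^(P), truncated to P = 1

    # One approximate FT step (Alg_FT) for a fresh FT head psi, using D_FT^(tr).
    psi = torch.zeros(d, n_ft_tasks, requires_grad=True)
    ft_tr_loss = F.binary_cross_entropy_with_logits(encode(x_tr, theta_pt) @ psi, y_tr)
    g_psi = torch.autograd.grad(ft_tr_loss, psi, create_graph=True)[0]
    psi_ft = psi - inner_lr * g_psi

    # FT validation loss on D_FT^(val); its gradient reaches phi by flowing back
    # through the (truncated) FT and PT updates above.
    ft_val_loss = F.binary_cross_entropy_with_logits(encode(x_val, theta_pt) @ psi_ft, y_val)
    meta_opt.zero_grad()
    ft_val_loss.backward()
    meta_opt.step()

    # Commit the PT step with phi treated as a constant, mirroring the
    # alternating PT / meta-parameter updates of Algorithm 1.
    with torch.no_grad():
        theta -= inner_lr * g_theta.detach()
        w_pt -= inner_lr * g_wpt.detach()
```

In the paper's actual settings, the PT objective is the 5000-task protein-function loss (or the SimCLR loss with a learned augmentation network for ECG data), and the FT computation runs over the 40-task or held-out-task problem; the sketch only illustrates where the $\mathcal{D}^{(\text{tr})}_{\text{FT}}$ and $\mathcal{D}^{(\text{val})}_{\text{FT}}$ batches and the meta-gradient enter.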
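
The Experiment Setup row distinguishes two FT strategies: fine-tune the whole network, or freeze the pre-trained feature extractor and train only a newly initialized head (Linear Evaluation). A minimal sketch of that choice, again assuming PyTorch and with an illustrative optimizer and learning rate (neither is specified in the paper):

```python
# Hedged sketch of the two fine-tuning strategies; names and hyperparameters are
# illustrative assumptions, not the paper's configuration.
import torch
import torch.nn as nn

def build_finetuner(encoder: nn.Module, feat_dim: int, n_ft_tasks: int,
                    linear_eval: bool):
    """Return (model, optimizer) for one FT run: a fresh head on top of the
    pre-trained encoder, trained end-to-end or with the encoder frozen."""
    head = nn.Linear(feat_dim, n_ft_tasks)               # new FT network head
    if linear_eval:                                       # Linear Evaluation [50]
        for p in encoder.parameters():
            p.requires_grad_(False)                       # freeze PT feature extractor
        trainable = list(head.parameters())
    else:                                                 # full fine-tuning
        trainable = list(encoder.parameters()) + list(head.parameters())
    model = nn.Sequential(encoder, head)
    return model, torch.optim.Adam(trainable, lr=1e-3)
```

Per the setup described above, both strategies would be run over 5 random seeds with early stopping on validation performance, and the better-performing strategy is reported.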