Universal Neural Functionals
Authors: Allan Zhou, Chelsea Finn, James Harrison
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | This work proposes an algorithm that automatically constructs permutation equivariant models, which we refer to as universal neural functionals (UNFs), for any weight space. Among other applications, we demonstrate how UNFs can be substituted into existing learned optimizer designs, and find promising improvements over prior methods when optimizing small image classifiers and language models. Our results suggest that learned optimizers can benefit from considering the (symmetry) structure of the weight space they optimize. |
| Researcher Affiliation | Collaboration | Allan Zhou (Stanford University, ayz@cs.stanford.edu); Chelsea Finn (Stanford University); James Harrison (Google DeepMind, jamesharrison@google.com) |
| Pseudocode | Yes | Algorithm 1: Basis for equivariant W^(m) → W^(ℓ) layer (see the toy sketch below the table) |
| Open Source Code | Yes | We open-source our library for constructing UNFs at https://github.com/AllanYangZhou/universal_neural_functional. |
| Open Datasets | Yes | We construct Tiny RNN Zoo, a dataset of recurrent neural networks trained to do arithmetic by completing given questions character-by-character. ... MLP on Fashion MNIST. ... CNN on CIFAR-10. ... RNN on LM1B. ... Transformer on LM1B. |
| Dataset Splits | Yes | We split the Tiny RNN Zoo into 8000/1000/1000 training, validation, and test examples. |
| Hardware Specification | Yes | Experiments were run on a mix of TPU v3 and v4 accelerators. ... We are grateful to the TPU Research Cloud (TRC) for providing compute for some of the experiments. |
| Software Dependencies | Yes | Our open-source implementation is compatible with most JAX [Bradbury et al., 2018] neural network libraries. |
| Experiment Setup | Yes | We train each predictor with binary cross entropy loss (since the target SR is in [0, 1]), using the Adam optimizer with learning rate 0.001, batch size 10, and training for up to 10 epochs. ... We use an inner training horizon T = 2,000 for the first three tasks and T = 5,000 for the Transformer task... We meta-train for 50,000 steps using Adam, estimating meta-gradients over 16 parallel training runs using persistent evolutionary strategies (PES) [Vicol et al., 2021] with a truncation length of 50 and a noise standard deviation of 0.01. The meta-training objective is training loss at the end of the inner training horizon... and we apply a gradient clipping of 1.0. (See the configuration sketch below the table.) |
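
The Pseudocode row above refers to the paper's Algorithm 1, which automatically constructs a basis for permutation-equivariant layers on a given weight space. The snippet below is not that algorithm; it is a toy JAX sketch of the underlying idea, numerically recovering the classical two-dimensional basis (identity and all-ones) of linear maps on R^n that commute with every permutation matrix. All names in it are illustrative.

```python
# Toy illustration (NOT the paper's Algorithm 1): numerically find a basis for
# the linear maps T on R^n satisfying P T P^T = T for every permutation matrix
# P. The classical answer is the span of the identity and the all-ones matrix.
import itertools
import jax.numpy as jnp

n = 4
# All n! permutation matrices on R^n.
perms = [jnp.eye(n)[list(p)] for p in itertools.permutations(range(n))]

def constraint(t_flat):
    """Linear map whose null space is exactly the set of equivariant T."""
    T = t_flat.reshape(n, n)
    residuals = [P @ T @ P.T - T for P in perms]  # equivariance <=> residual 0
    return jnp.concatenate([r.ravel() for r in residuals])

# Materialize the constraint as a matrix: column j is constraint(e_j), so that
# M @ vec(T) == constraint(vec(T)) by linearity.
M = jnp.stack([constraint(e) for e in jnp.eye(n * n)], axis=1)

# The null space of M is the space of equivariant maps; read it off the SVD.
_, s, vt = jnp.linalg.svd(M, full_matrices=False)
basis = vt[s < 1e-5]
print("number of equivariant basis elements:", basis.shape[0])  # expect 2
```

A brute-force null-space computation like this does not scale to real weight spaces; the point of the paper's algorithm is to construct such bases efficiently for any weight space and its associated permutation symmetries.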
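
The Dataset Splits and Experiment Setup rows quote concrete training settings. The sketch below collects them as an optax configuration; the parameter and data names are placeholders (assumptions, not the authors' code), and the meta learning rate is a labeled placeholder because it is not stated in the excerpt.

```python
# Minimal sketch of the quoted settings using optax. Only the numbers
# (8000/1000/1000 split, Adam lr 0.001, batch size 10, BCE loss, clip 1.0)
# come from the excerpts above; everything else is illustrative.
import jax.numpy as jnp
import optax

# Tiny RNN Zoo split sizes.
n_train, n_val, n_test = 8000, 1000, 1000

# Success-rate (SR) predictor training: Adam, lr 0.001, batch size 10,
# binary cross-entropy since the target SR lies in [0, 1].
batch_size = 10
num_epochs = 10
predictor_opt = optax.adam(learning_rate=1e-3)

def bce_loss(logits, target_sr):
    # Sigmoid BCE accepts soft targets in [0, 1].
    return optax.sigmoid_binary_cross_entropy(logits, target_sr).mean()

# Meta-training of the learned optimizer: Adam with gradient clipping at
# global norm 1.0. The meta learning rate is NOT given in the excerpt, so
# the value here is a placeholder assumption.
meta_lr = 1e-3  # placeholder
meta_opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=meta_lr),
)

# Standard optax update step with dummy parameters and gradients.
params = {"w": jnp.zeros(4)}
opt_state = predictor_opt.init(params)
grads = {"w": jnp.ones(4)}
updates, opt_state = predictor_opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

The PES meta-gradient estimator itself (16 parallel runs, truncation length 50, noise standard deviation 0.01) is omitted here; the paper cites Vicol et al. [2021] for that component.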