Mnemosyne: Learning to Train Transformers with Transformers

Authors: Deepali Jain, Krzysztof M Choromanski, Kumar Avinava Dubey, Sumeet Singh, Vikas Sindhwani, Tingnan Zhang, Jie Tan

NeurIPS 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We conduct an extensive empirical evaluation of Mnemosyne on: (a) fine-tuning a wide range of Vision Transformers (ViTs) from medium-size architectures to massive ViT-Hs (36 layers, 16 heads), (b) pre-training BERT models and (c) soft prompt-tuning large 11B+ T5XXL models. We complement our results with a comprehensive theoretical analysis of the compact associative memory used by Mnemosyne which we believe was never done before.
Researcher Affiliation | Collaboration | Deepali Jain, Google DeepMind, jaindeepali@google.com; Krzysztof Choromanski, Google DeepMind and Columbia University, kchoro@google.com; Avinava Dubey, Google Research, avinavadubey@google.com; Sumeet Singh, Google DeepMind, ssumeet@google.com; Vikas Sindhwani, Google DeepMind, sindhwani@google.com; Tingnan Zhang, Google DeepMind, tingnan@google.com; Jie Tan, Google DeepMind, jietan@google.com
Pseudocode | No | The paper includes mathematical equations and architectural diagrams (Figure 1, Figure 2) but does not provide any pseudocode or algorithm blocks.
Open Source Code | No | The paper does not contain any explicit statements about releasing code or links to a source code repository for the methodology described.
Open Datasets | Yes | We conduct an extensive empirical evaluation of Mnemosyne on: (a) fine-tuning a wide range of Vision Transformers (ViTs) from medium-size architectures to massive ViT-Hs (36 layers, 16 heads), (b) pre-training BERT models and (c) soft prompt-tuning large 11B+ T5XXL models. In Section 5.1 and Figure 3, the paper mentions experiments on 'MNIST, CIFAR10 and CIFAR100'. Section 5.2 and Figure 4 mention 'imagenet2012, places365 and caltech-birds-2011' datasets. Section 5.3 and Table 4 mention 'Books [77]' and 'Wikipedia' for BERT pre-training. Section 5.2 and Figure 6 mention the 'SuperGLUE task'. A hypothetical dataset-loading sketch is given after the table.
Dataset Splits | No | The paper mentions training on batches of examples and reports 'Test loss curves', implying the use of test sets, but it does not explicitly provide the specific percentages or counts for training, validation, and test splits, nor does it cite standard splits with sufficient detail for reproduction.
Hardware Specification | Yes | All Mnemosyne optimizer variants were trained and tested on a TPU pod containing 4 TPU v3 chips with JAX. A short JAX device-check sketch follows the table.
Software Dependencies | No | The paper mentions the use of 'JAX' (Section B.7) and 'pytree structure... used in JAX' (footnote 2), but it does not provide specific version numbers for JAX or any other software libraries or dependencies used in the experiments.
Experiment Setup | Yes | All Mnemosyne variants are meta-trained on small-scale MLP and ViT training tasks for a short horizon of 100 steps. The optimizee task is to train the model for 100 steps on batches of 64 image-class examples. Each cell uses exponential discount factor τ = 0.1, r = 16 random projections, 16 hidden dimensions and 1 attention head. Mnemosyne's optimizer is meta-trained by gradient descent using the Adam optimizer with learning rate η = 3e-4 (Section B.1). Batch sizes of 128, 256, and 512 are also mentioned for different experiments (Sections B.8.2, B.10, B.11). A hedged configuration sketch follows the table.
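
For the Open Datasets row above: all of the named datasets are publicly available, but the paper does not describe its data pipeline. The following is only a hypothetical sketch of pulling the image datasets through tensorflow_datasets; the catalog names are assumptions, and imagenet2012 additionally requires a manual download of the ImageNet archives.

    # Hypothetical sketch (not the authors' pipeline): loading the image
    # datasets named in the paper via the public tensorflow_datasets catalog.
    import tensorflow_datasets as tfds

    mnist_train = tfds.load("mnist", split="train")
    cifar10_train = tfds.load("cifar10", split="train")
    cifar100_train = tfds.load("cifar100", split="train")
    # Assumed catalog names for the larger fine-tuning datasets:
    birds_train = tfds.load("caltech_birds2011", split="train")
    places_train = tfds.load("places365_small", split="train")
    # imagenet2012 must be downloaded manually before tfds can prepare it.
    imagenet_train = tfds.load("imagenet2012", split="train")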
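
For the Hardware Specification row: the paper reports 4 TPU v3 chips used with JAX. A minimal way to confirm a comparable accelerator setup from a JAX program is shown below; this is generic JAX usage, not code from the paper.

    import jax

    # Lists the accelerators visible to JAX; on the reported setup this
    # would show four TpuDevice entries.
    print(jax.devices())
    # Expected to print 4 on a host with four TPU v3 chips attached.
    print(jax.device_count())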
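
For the Experiment Setup row: the quoted meta-training hyperparameters can be collected in one place, with optax providing the Adam meta-optimizer. The MetaTrainConfig container and its field names below are hypothetical, not taken from the paper; only the numeric values come from the quoted text.

    from dataclasses import dataclass
    import optax

    @dataclass
    class MetaTrainConfig:
        # Values quoted from Section B.1 of the paper.
        horizon_steps: int = 100          # optimizee trained for 100 steps
        batch_size: int = 64              # 64 image-class examples per batch
        discount_tau: float = 0.1         # exponential discount factor per cell
        num_random_projections: int = 16  # r = 16
        hidden_dim: int = 16
        num_attention_heads: int = 1
        meta_learning_rate: float = 3e-4  # Adam learning rate for meta-training

    config = MetaTrainConfig()
    # Adam optimizer used to meta-train the Mnemosyne optimizer's parameters.
    meta_optimizer = optax.adam(learning_rate=config.meta_learning_rate)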