Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture

Authors: Dan Fu, Simran Arora, Jessica Grogan, Isys Johnson, Evan Sabri Eyuboglu, Armin Thomas, Benjamin Spector, Michael Poli, Atri Rudra, Christopher Ré

NeurIPS 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | "As a proof of concept, we explore the performance of M2 in three domains: non-causal BERT-style language modeling, ViT-style image classification, and causal GPT-style language modeling. For non-causal BERT-style modeling, M2 matches BERT-base and BERT-large in downstream GLUE quality with up to 27% fewer parameters, and achieves up to 9.1× higher throughput at sequence length 4K. On ImageNet, M2 outperforms ViT-b by 1% in accuracy, with only half the parameters. ... Using this parameterization, M2 matches GPT-style Transformers at 360M parameters in pretraining perplexity on The PILE, showing for the first time that it may be possible to match Transformer quality without attention or MLPs."
Researcher Affiliation | Academia | "Department of Computer Science, Stanford University; Department of Computer Science and Engineering, University at Buffalo, SUNY; Department of Psychology, Stanford University."
Pseudocode | Yes | "Our proof-of-concept implementation of an M2 layer, written in less than 40 lines of pure PyTorch (including imports), relies only on matrix multiplication, transpose, reshape, and elementwise products (see pseudocode in Figure 1 middle) and achieves 25.6% FLOP utilization for inputs of size 64K on an A100 GPU. ... and the Appendix gives an efficient implementation of M2 in under 40 lines of pure PyTorch (including imports)." (A minimal PyTorch sketch of this pattern is given after the table.)
Open Source Code | Yes | "Code is available at https://github.com/HazyResearch/m2."
Open Datasets | Yes | "We train M2-BERT using masked language modeling over C4 [66] with the bert-base-uncased tokenizer. ... On ImageNet, M2 outperforms ViT-b by 1% in accuracy. ... We pretrain M2-GPT on the PILE, a standard dataset for causal language modeling."
Dataset Splits | Yes | "We further tune weight decay (0 or 0.1), stochastic depth rate (0 or 0.1), and base learning rate (1e-4 or 3e-4 or 1e-3) and report the test performance for the model variant that achieved the highest accuracy in a separate held-out validation dataset (randomly selected 10% of training data)." (A sketch of this selection loop follows the table.)
Hardware Specification | Yes | "Inference times are reported in tokens/ms on an A100-40GB GPU. ... On the RTX 4090, which has a larger and faster L2 cache than the A100, we can manually optimize an implementation to amortize data movement costs. ... Measurements averaged over 10 examples on a 48 vCPU, 96 GB RAM instance from the GCP n2-standard-48 series, which runs Intel Cascade Lake processors."
Software Dependencies | No | "Our proof-of-concept implementation of an M2 layer, written in less than 40 lines of pure PyTorch (including imports), relies only on matrix multiplication... Hugging Face implementations of BERT... standard cuBLAS sub-routines." The paper mentions software by name but does not provide specific version numbers for these dependencies.
Experiment Setup | Yes | "We train all these models on C4 for 70,000 steps, with sequence length 128, and global batch size 4096 sequences. For all the models, we use decoupled AdamW with learning rate 8e-4 and decoupled weight decay 1e-5. We use linear learning rate decay with a warmup of 6% of the steps, and we use MLM masking percentage of 30%. For GLUE fine-tuning, we do a small search of learning rate, weight decay, and number of epochs. ... we use sinusoidal position embeddings and global average-pooling (GAP) instead of a class token. ... we follow the training procedure of T2T-ViT [86], including augmentations such as RandAugment [12] (magnitude = 9, magnitude-std = 0.5, layers = 2), Mixup [88] (α = 0.8), CutMix [87] (α = 1.0), Random erasing [90] (rate = 0.25), and AugMix [37]." (A sketch wiring the quoted BERT pretraining settings into a PyTorch optimizer and schedule also follows the table.)
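The Pseudocode row above says that an M2 layer can be written in pure PyTorch using only matrix multiplication, transpose, reshape, and elementwise products. The authors' actual implementation is in the linked repository; what follows is only a minimal sketch of that pattern under simplifying assumptions (square Monarch factors with n = m * m, and illustrative names such as monarch_multiply and m2_sequence_mix that are not taken from the repository).

import torch

def monarch_multiply(x, blocks1, blocks2):
    """Multiply x (batch, n) by a Monarch matrix given as two block-diagonal
    factors, each stored as m blocks of size m x m with n = m * m. Only
    reshape, transpose, and batched matrix multiplication are used."""
    b, n = x.shape
    m = blocks1.shape[0]                          # number of blocks; assumes n == m * m
    x = x.reshape(b, m, m)                        # view the length-n axis as an m x m grid
    x = torch.einsum('kij,bkj->bki', blocks1, x)  # first block-diagonal multiply
    x = x.transpose(1, 2)                         # permutation: transpose the grid
    x = torch.einsum('kij,bkj->bki', blocks2, x)  # second block-diagonal multiply
    x = x.transpose(1, 2)                         # undo the permutation
    return x.reshape(b, n)

def m2_sequence_mix(x, b1, b2, b3, b4, kernel):
    """Illustrative M2-style mixing step: Monarch multiply, elementwise product
    with a learned kernel, then another Monarch multiply. The released layer
    adds gating, convolution semantics, and the dimension mixer."""
    u = monarch_multiply(x, b1, b2) * kernel      # elementwise modulation
    return monarch_multiply(u, b3, b4)

# Quick shape check: a batch of 4 inputs of length 64 (so m = 8).
batch, m = 4, 8
n = m * m
blocks = lambda: torch.randn(m, m, m) / m ** 0.5
y = m2_sequence_mix(torch.randn(batch, n), blocks(), blocks(),
                    blocks(), blocks(), torch.randn(n))
print(y.shape)  # torch.Size([4, 64])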
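The Dataset Splits row quotes a small grid search for the ImageNet experiments, scored on a randomly selected 10% of the training data. The loop below sketches that selection procedure; train_and_evaluate is a hypothetical stand-in for the actual ViT training code and only returns a placeholder value so the sketch runs.

import itertools
import random

WEIGHT_DECAYS = [0.0, 0.1]
STOCHASTIC_DEPTH_RATES = [0.0, 0.1]
BASE_LEARNING_RATES = [1e-4, 3e-4, 1e-3]

def train_and_evaluate(config, val_fraction=0.1):
    """Hypothetical: train on 90% of the training data with these
    hyperparameters and return accuracy on the held-out 10%."""
    return random.random()  # placeholder so the loop is runnable

best_val_acc, best_config = -1.0, None
for wd, sd, lr in itertools.product(WEIGHT_DECAYS, STOCHASTIC_DEPTH_RATES,
                                    BASE_LEARNING_RATES):
    config = {"weight_decay": wd, "stochastic_depth_rate": sd, "base_lr": lr}
    val_acc = train_and_evaluate(config)
    if val_acc > best_val_acc:
        best_val_acc, best_config = val_acc, config

# Only the configuration that wins on the held-out split is scored on the test set.
print("selected config:", best_config)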
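The Experiment Setup row gives concrete M2-BERT pretraining settings: 70,000 steps on C4 with sequence length 128 and global batch size 4096, decoupled AdamW at learning rate 8e-4 with weight decay 1e-5, linear decay with a 6% warmup, and 30% MLM masking. The sketch below wires those numbers into torch.optim.AdamW and a LambdaLR schedule as an approximation; the paper's decoupled AdamW variant, the data pipeline, and the model itself are not reproduced here.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

TOTAL_STEPS = 70_000                    # quoted number of pretraining steps
WARMUP_STEPS = int(0.06 * TOTAL_STEPS)  # 6% linear warmup
MLM_PROBABILITY = 0.30                  # 30% of tokens masked per sequence

model = torch.nn.Linear(768, 768)  # placeholder; the real model is M2-BERT

# torch.optim.AdamW stands in here for the paper's decoupled AdamW.
optimizer = AdamW(model.parameters(), lr=8e-4, weight_decay=1e-5)

def linear_warmup_then_decay(step):
    """Linear warmup for WARMUP_STEPS, then linear decay to zero."""
    if step < WARMUP_STEPS:
        return step / max(1, WARMUP_STEPS)
    return max(0.0, (TOTAL_STEPS - step) / max(1, TOTAL_STEPS - WARMUP_STEPS))

scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_then_decay)

for step in range(TOTAL_STEPS):
    # Forward pass with MLM_PROBABILITY of tokens masked, MLM loss, and
    # backward pass would go here; omitted in this sketch.
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()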