$\boldsymbol{\mu}\mathbf{P^2}$: Effective Sharpness Aware Minimization Requires Layerwise Perturbation Scaling

Authors: Moritz Haas, Jin Xu, Volkan Cevher, Leena Chennuru Vankadara

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Through experiments with MLPs, ResNets, and Vision Transformers, we empirically demonstrate that µP² achieves hyperparameter transfer of the joint optimum of learning rate and perturbation radius across model scales.
Researcher Affiliation | Collaboration | ¹University of Tübingen, Tübingen AI Center; ²University of Oxford; ³LIONS, EPFL; ⁴AGI Foundations, Amazon
Pseudocode | Yes | Algorithm 1: PyTorch implementation of µP² for SAM using the mup-package. Key changes from the original implementation that correct the layerwise perturbation scaling are highlighted with gray boxes. The code decouples the scalings of the numerator and denominator terms following (DP), and scales the gradient-norm contribution of every layer by group["gradnorm_scaling"] in the denominator so that it becomes width-independent. The numerator terms group["rho"] of all weight tensors are scaled to achieve effective perturbations. This scaling is equivalent to (a-µP²) together with naive perturbation and learning-rate scaling.
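The Algorithm 1 listing itself is not quoted above. As a rough, hedged sketch only, the following shows what a SAM ascent step with layerwise perturbation scaling of the kind described could look like in PyTorch; the per-group keys "rho" and "gradnorm_scaling" mirror the quoted description, while the function name `sam_perturb` and the surrounding structure are assumptions rather than the authors' code.

```python
# Hedged sketch of a SAM ascent step with layerwise perturbation scaling.
# Assumes each optimizer param_group carries the keys "rho" and
# "gradnorm_scaling" as in the quoted description; not the authors' Algorithm 1.
import torch


@torch.no_grad()
def sam_perturb(param_groups, eps=1e-12):
    """Apply w <- w + e(w), returning the perturbations so they can be undone.

    Denominator: global gradient norm, with each layer's squared-norm
    contribution rescaled by group["gradnorm_scaling"] so that it stays
    width-independent. Numerator: each group uses its own radius group["rho"].
    """
    sq_norm = 0.0
    for group in param_groups:
        for p in group["params"]:
            if p.grad is not None:
                sq_norm = sq_norm + group["gradnorm_scaling"] * p.grad.pow(2).sum()
    grad_norm = (sq_norm + eps) ** 0.5

    perturbations = {}
    for group in param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            e_w = group["rho"] * p.grad / grad_norm
            p.add_(e_w)             # ascent step on this tensor
            perturbations[p] = e_w  # stored so the descent step can subtract it
    return perturbations
```

After the perturbed forward/backward pass, the stored perturbations would be subtracted again before the usual optimizer step, as in standard SAM.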
Open Source Code | No | We are working on making Python code to reproduce all of our experiments publicly available.
Open Datasets | Yes | We train MLPs and ResNets (He et al., 2016) on CIFAR-10 (Krizhevsky et al., 2009) and Vision Transformers (ViTs) (Dosovitskiy et al., 2021) on ImageNet-1K (Deng et al., 2009).
Dataset Splits | No | We evaluate the test accuracy after every epoch and use the snapshot across training with the best accuracy.
Hardware Specification | Yes | We ran all of our experiments on Amazon EC2 G5 instances, each containing up to 8 NVIDIA A10G GPUs.
Software Dependencies | No | While we directly implement bcd-parameterizations for MLPs and ResNets in PyTorch (Paszke et al., 2019), we use the mup-package (Yang et al., 2022) as a basis for the ViT experiments.
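For context, a minimal, hedged sketch of the usual mup-package workflow (base and delta models, set_base_shapes, a µP-aware optimizer); the toy MLP, widths, and hyperparameters are illustrative stand-ins, not the authors' ViT setup.

```python
# Minimal sketch of the standard mup-package workflow; the toy MLP and the
# chosen widths/hyperparameters are illustrative assumptions only.
import torch.nn as nn
from mup import MuAdam, MuReadout, set_base_shapes


def make_mlp(width, d_in=3072, d_out=10):
    return nn.Sequential(
        nn.Linear(d_in, width), nn.ReLU(),
        nn.Linear(width, width), nn.ReLU(),
        MuReadout(width, d_out),  # µP-aware output layer
    )


base_width = 128                      # width at which SP and µP coincide
model = make_mlp(width=1024)
# Base and delta models tell mup which dimensions scale with width.
set_base_shapes(model, make_mlp(base_width), delta=make_mlp(base_width * 2))
# (In practice the weights are then re-initialized with the mup.init helpers.)
optimizer = MuAdam(model.parameters(), lr=1e-3)  # width-aware per-layer learning rates
```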
Experiment Setup | Yes | If not mentioned otherwise, experiments use the settings specified in this section. Implementation details: For MLPs, we exactly implement our Definition 4 of bcd-parameterizations to precisely validate our theoretical results. For ResNets and ViTs, the width varies inside the network, so we implement the spectral scaling rules derived in Appendix F.7. Like the mup-package, we introduce a base width at which SP and µP are equivalent, which allows setups that perform well in SP to be transferred immediately. We use the mup-package only for ViTs, and our implementation of µP² resembles the pseudocode provided in Algorithm 1. For ResNets, we use no width-dependent last-layer multiplier. At initialization, µP differs from SP only through a smaller last-layer initialization. For MLPs, we exactly implement the bcd-parameterization with $b_{L+1} = 1$, but use the large width-independent input-layer initialization variance 2 instead of $2/d_{\mathrm{in}}$ in µP, which can be seen as a tuned initialization-variance multiplier. For ResNets and ViTs, we initialize the last layer to 0 in µP, which corresponds to $b_{L+1} = \infty$ and recovers the limit behaviour $f_0 \equiv 0$ already at finite width. We are working on making Python code to reproduce all of our experiments publicly available.
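To make the base-width idea concrete, here is a hedged sketch of µP-style bookkeeping for an MLP trained with Adam, with a zero-initialized readout and input-layer initialization variance 2 as described in the quoted setup; the exact exponents, the builder function, and the fixed dimensions are simplifying assumptions, not the paper's bcd-parameterization or its Appendix F.7 spectral rules.

```python
# Hedged sketch of µP-with-base-width bookkeeping for an Adam-trained MLP.
# Exponents, widths, and the builder function are illustrative assumptions only.
import torch
import torch.nn as nn


def build_mup_mlp(width, base_width, d_in=3072, d_out=10, lr=1e-3):
    m = width / base_width  # width multiplier; m == 1 recovers the SP setup
    fc_in = nn.Linear(d_in, width, bias=False)
    fc_hidden = nn.Linear(width, width, bias=False)
    fc_out = nn.Linear(width, d_out, bias=False)

    # Initialization: width-independent variance 2 for the input layer,
    # 1/fan_in variance for the hidden layer, zero readout so f_0 = 0.
    nn.init.normal_(fc_in.weight, std=2.0 ** 0.5)
    nn.init.normal_(fc_hidden.weight, std=(1.0 / width) ** 0.5)
    nn.init.zeros_(fc_out.weight)

    # Adam learning rates: width-independent for the input layer, shrinking
    # like 1/m for hidden and output layers relative to the base width.
    param_groups = [
        {"params": fc_in.parameters(), "lr": lr},
        {"params": fc_hidden.parameters(), "lr": lr / m},
        {"params": fc_out.parameters(), "lr": lr / m},
    ]
    model = nn.Sequential(fc_in, nn.ReLU(), fc_hidden, nn.ReLU(), fc_out)
    return model, torch.optim.Adam(param_groups)
```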