$\boldsymbol{\mu}\mathbf{P^2}$: Effective Sharpness Aware Minimization Requires Layerwise Perturbation Scaling
Authors: Moritz Haas, Jin Xu, Volkan Cevher, Leena Chennuru Vankadara
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Through experiments with MLPs, ResNets and Vision Transformers, we empirically demonstrate that µP² achieves hyperparameter transfer of the joint optimum of learning rate and perturbation radius across model scales. |
| Researcher Affiliation | Collaboration | ¹University of Tübingen, Tübingen AI Center; ²University of Oxford; ³LIONS, EPFL; ⁴AGI Foundations, Amazon |
| Pseudocode | Yes | Algorithm 1: PyTorch implementation of µP² for SAM using the mup-package. Key changes from the original implementation that correct the layerwise perturbation scaling are highlighted with gray boxes. This code decouples the scalings of the numerator and denominator terms following (DP), and scales the gradient-norm contributions of all layers by group["gradnorm_scaling"] in the denominator to be width-independent. The numerator terms group["rho"] of all weight tensors are scaled to achieve effective perturbations. This scaling is equivalent to (a-µP²) together with naive perturbation and learning rate scaling. (A minimal illustrative sketch of this layerwise scaling is given after the table.) |
| Open Source Code | No | We are working on making Python code to reproduce all of our experiments publicly available. |
| Open Datasets | Yes | We train MLPs and ResNets (He et al., 2016) on CIFAR10 (Krizhevsky et al., 2009) and Vision Transformers (ViTs) (Dosovitskiy et al., 2021) on ImageNet1K (Deng et al., 2009). |
| Dataset Splits | No | We evaluate the test accuracy after every epoch and use the snapshot with the best accuracy across training. |
| Hardware Specification | Yes | We ran all of our experiments on Amazon EC2 G5 instances each containing up to 8 NVIDIA A10G GPUs. |
| Software Dependencies | No | While we directly implement bcd-parameterizations for MLPs and ResNets in PyTorch (Paszke et al., 2019), we use the mup-package (Yang et al., 2022) as a basis for ViT experiments. |
| Experiment Setup | Yes | If not mentioned otherwise, experiments use the settings specified in this section. Implementation details. For MLPs, we exactly implement our Definition 4 of bcd-parameterizations to precisely validate our theoretical results. For ResNets and ViTs, the width varies inside the network, so we implement the spectral scaling rules derived in Appendix F.7. Like the mup-package, we introduce a base width at which SP and µP are equivalent, allowing us to immediately transfer setups that perform well in SP. We use the mup-package only for ViTs, and our implementation of µP² resembles the pseudocode provided in Algorithm 1. For ResNets, we use no width-dependent last-layer multiplier. At initialization, µP differs from SP only through a smaller last-layer initialization. For MLPs we exactly implement the bcd-parameterization with $b_{L+1} = 1$, but use the large width-independent input-layer initialization variance 2 instead of the width-independent $2/d_{\mathrm{in}}$ in µP, which can be seen as a tuned initialization variance multiplier. For ResNets and ViTs, we initialize the last layer to 0 in µP, which corresponds to $b_{L+1} = \infty$ and which recovers the limit behaviour $f_0 \equiv 0$ already at finite width. We are working on making Python code to reproduce all of our experiments publicly available. (An illustrative initialization sketch follows the table.) |
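
The pseudocode row describes a SAM perturbation step whose numerator uses a layerwise radius `group["rho"]` and whose denominator rescales each layer's gradient-norm contribution by `group["gradnorm_scaling"]`. The sketch below is not the paper's Algorithm 1; it is a minimal illustration of that decoupled numerator/denominator scaling, assuming standard PyTorch parameter groups carrying those two (source-quoted) keys, with all other names chosen here for illustration.

```python
# Minimal sketch (not the authors' Algorithm 1): a SAM ascent step with
# layerwise perturbation scaling via per-parameter-group entries
# group["rho"] and group["gradnorm_scaling"]. The exact per-layer values
# would follow the paper's (DP)/(a-µP²) rules and are left to the user.
import torch


@torch.no_grad()
def layerwise_sam_perturbation(param_groups, eps_stab=1e-12):
    """Perturb each layer by rho_l * g_l / ||g||, with a width-independent ||g||."""
    # Denominator: global gradient norm whose per-layer contributions are
    # rescaled by group["gradnorm_scaling"] so the sum stays width-independent.
    contributions = [
        group["gradnorm_scaling"] * p.grad.pow(2).sum()
        for group in param_groups
        for p in group["params"]
        if p.grad is not None
    ]
    grad_norm = torch.sqrt(torch.stack(contributions).sum()) + eps_stab

    # Numerator: layerwise perturbation radius group["rho"].
    perturbations = []
    for group in param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            e = group["rho"] * p.grad / grad_norm
            p.add_(e)                     # ascend to the perturbed weights
            perturbations.append((p, e))  # kept so the perturbation can be undone
    return perturbations
```

In practice one would build one parameter group per layer (e.g. `torch.optim.SGD([{"params": layer.parameters(), "rho": ..., "gradnorm_scaling": ...}, ...], lr=...)`), call this function on `optimizer.param_groups` before the second forward/backward pass, and subtract the returned perturbations afterwards, as in standard SAM.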
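
The experiment-setup row mentions three initialization choices: a width-independent input-layer variance of 2, fan-in scaled hidden layers, and a last layer initialized to 0 (recovering $f_0 \equiv 0$ at finite width). The sketch below is an assumption-laden illustration of those choices for a plain MLP, not the authors' implementation; the exact bcd-exponents of Definition 4 and the base-width mechanism are in the paper.

```python
# Minimal sketch, assuming a plain ReLU MLP: hidden layers use fan-in (He-style)
# initialization, the input layer uses the width-independent variance 2 quoted
# above, and the last layer is zero-initialized (the µP choice used for
# ResNets/ViTs). Function and argument names are chosen here for illustration.
import math
import torch.nn as nn


def build_mup_style_mlp(d_in, width, d_out, depth=3):
    layers, dims = [], [d_in] + [width] * depth + [d_out]
    for i, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
        linear = nn.Linear(fan_in, fan_out, bias=False)
        if i == 0:
            # Input layer: width-independent init variance 2 (a tuned multiplier).
            nn.init.normal_(linear.weight, std=math.sqrt(2.0))
        elif i == depth:
            # Last layer: zero init, so the network output starts at 0.
            nn.init.zeros_(linear.weight)
        else:
            # Hidden layers: fan-in scaled initialization.
            nn.init.normal_(linear.weight, std=math.sqrt(2.0 / fan_in))
        layers += [linear, nn.ReLU()] if i < depth else [linear]
    return nn.Sequential(*layers)
```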