Why Does Sharpness-Aware Minimization Generalize Better Than SGD?

Authors: Zixiang Chen, Junkai Zhang, Yiwen Kou, Xiangning Chen, Cho-Jui Hsieh, Quanquan Gu

NeurIPS 2023

| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | "Experiments on both synthetic and real data corroborate our theory." and "In this section, we conduct synthetic experiments to validate our theory. Additional experiments on real data sets can be found in Appendix A." |
| Researcher Affiliation | Academia | Zixiang Chen, Junkai Zhang, Yiwen Kou, Xiangning Chen, Cho-Jui Hsieh, Quanquan Gu; Department of Computer Science, University of California, Los Angeles, Los Angeles, CA 90095 |
| Pseudocode | Yes | "Algorithm 1 Minibatch Sharpness Aware Minimization" |
| Open Source Code | No | No explicit statement or link indicating the release of source code for the methodology described in the paper was found. |
| Open Datasets | Yes | "We conduct experiments on the ImageNet dataset with ResNet50." and "Here, we conduct experiments on the CIFAR dataset with WRN-16-8." |
| Dataset Splits | Yes | "The model is trained for 90 epochs with the best learning rate in grid search {0.01, 0.03, 0.1, 0.3}." |
| Hardware Specification | No | The paper does not specify the exact GPU models, CPU models, or other hardware used for the experiments. |
| Software Dependencies | No | The paper mentions PyTorch for initialization but does not provide version numbers for any software dependencies. |
| Experiment Setup | Yes | "We set training data size n = 20... The number of filters is set as m = 10... learning rate of 0.01 for 100 iterations. We consider different dimensions d ranging from 1000 to 20000, and different signal strengths ∥µ∥₂ ranging from 0 to 10." and "batch size as 1024, and the model is trained for 90 epochs with the best learning rate in grid search {0.01, 0.03, 0.1, 0.3}." and "batch size as 128 and train the model over 200 epochs using a learning rate of 0.1, a momentum of 0.9, and a weight decay of 5e-4. The SAM hyperparameter is chosen as τ = 2.0." |
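For readers unfamiliar with the optimizer under study, the core SAM update (ascend to a perturbed point within a radius, then descend using the gradient taken there) can be sketched as follows. This is a minimal full-batch illustration, not the paper's Algorithm 1 (which is a minibatch variant); the function names and the toy quadratic loss are ours, and the perturbation radius `rho` merely plays a role analogous to the paper's τ.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One full-batch SAM update.

    1. Compute the gradient at w and take a normalized ascent step of
       radius rho to reach the (approximately) sharpest nearby point.
    2. Descend from w using the gradient evaluated at that perturbed point.
    """
    g = grad_fn(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent perturbation
    g_sam = grad_fn(w + eps)                     # gradient at perturbed weights
    return w - lr * g_sam

# Toy example: quadratic loss L(w) = 0.5 * ||w||^2, whose gradient is w.
grad = lambda w: w
w = np.array([3.0, 4.0])
for _ in range(100):
    w = sam_step(w, grad, lr=0.1, rho=0.05)
```

On this convex toy problem SAM behaves much like SGD and drives the iterate toward the minimizer; the theoretical separation the paper studies arises in non-convex settings such as the two-layer convolutional networks of its synthetic experiments.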