Stepping on the Edge: Curvature Aware Learning Rate Tuners
Authors: Vincent Roulet, Atish Agarwala, Jean-Bastien Grill, Grzegorz Swirszcz, Mathieu Blondel, Fabian Pedregosa
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically observe that classical learning rate tuners qualitatively underperform their constant learning rate counterparts across several deep learning benchmarks, in the full batch regime, for which these methods were originally designed. Our empirical analysis of curvature dynamics reveals that classical learning rate tuners generally undershoot the edge of stability. We observe empirically that the proposed learning rate tuner can outperform fine-tuned constant learning rate counterparts in a full batch regime. |
| Researcher Affiliation | Industry | Vincent Roulet* (Google DeepMind, vroulet@google.com); Atish Agarwala* (Google DeepMind, thetish@google.com); Jean-Bastien Grill (Google DeepMind, jbgrill@google.com); Grzegorz Swirszcz (Google DeepMind, swirszcz@google.com); Mathieu Blondel (Google DeepMind, mblondel@google.com); Fabian Pedregosa (Google DeepMind, pedregosa@google.com) |
| Pseudocode | No | The paper does not contain clearly labeled pseudocode or algorithm blocks. While Equation (8) describes the CDAT rule, it is presented as a mathematical formula, not a structured algorithm. (A hedged code sketch of such a rule is given below the table.) |
| Open Source Code | No | We are not able to open source the code yet. We hope to release it with the arXiv version of the manuscript later. |
| Open Datasets | Yes | MNIST is an image classification dataset of handwritten digits (LeCun et al., 2010). CIFAR10 is an image classification dataset of colored 32×32 images with 10 classes (Krizhevsky et al., 2009). Tiny Shakespeare consists of 40,000 lines from a variety of Shakespeare's plays (Karpathy, 2015). ImageNet is an image classification dataset of various images from the web (Deng et al., 2009). |
| Dataset Splits | Yes | In the mini-batch regime, we considered the full dataset of 50,000 samples for training (dropping the remainder batch) and tested on the 10,000 validation samples. In the full-batch experiments, we consider the Imagenette (Howard, 2019) subset, which contains only 10 classes, and took 1,024 samples out of it. In the mini-batch regime we consider the complete training dataset of 1.2 million images (ImageNet-1K), dropped the remainder batch, and reported test error on the 50,000 validation images. (A loading sketch is given below the table.) |
| Hardware Specification | Yes | Experiments have mostly been run on Tensor Processing Units (TPUs) v2 (180 Tera Floating-Point Operations per Second (TFLOPS), 64 GB High Bandwidth Memory (HBM)). Experiments on MLP-Mixers required TPUs v3 (420 TFLOPS, 128 GB HBM). Very small-scale experiments on MNIST with MLPs were run on CPUs. |
| Software Dependencies | No | All experiments are done in the open-source JAX ecosystem (DeepMind et al., 2020): architectures are taken from Scenic (Dehghani et al., 2022), datasets from TensorFlow Datasets, algorithms from Optax. The paper lists software names but does not provide specific version numbers for these components. |
| Experiment Setup | Yes | In all experiments, we fix all hyperparameters of the base optimizer ((S)GD, (S)GD with Momentum, RMSProp, Adam) to their default values: 0.9 for the momentum of (S)GD with Momentum, 0.999 for the EMA parameter of the second moment of RMSProp and Adam, 0.9 for the EMA parameter of the first moment of Adam. We fine-tune the learning rate on a logarithmic grid of base 2 around a base learning rate such as 10⁻³ or 10⁻⁴ (depending on the algorithm, the architecture, and the mini-batch size in the stochastic regime, as detailed in Appendix C.5), while making sure that the grid is sufficiently large that the best learning rate is found inside the grid and not at its smallest or largest value. The scaling factor σ of CDAT was searched on the grid {0.4, 0.6, …, 2.8}, and we also tuned the EMA parameter β_cdat used in the numerators and denominators of the edge estimate over {0, 0.9, 0.99}. (The grids are sketched in code below the table.) |
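
Equation (8) of the paper defines the CDAT rule but is not reproduced here. The following is a minimal JAX sketch of a CDAT-style tuner under stated assumptions: the edge of stability along the update direction d is estimated as 2·⟨d, d⟩ / ⟨d, Hd⟩ via a Hessian-vector product, the numerator and denominator are smoothed with EMAs (parameter β_cdat), and the step size is σ times that estimate. The function names (`hvp`, `cdat_step`) and the exact edge quantities are illustrative assumptions, not the authors' implementation of Eq. (8).

```python
import jax
import jax.numpy as jnp


def hvp(loss_fn, w, v):
    """Hessian-vector product H(w) v via forward-over-reverse autodiff."""
    return jax.jvp(jax.grad(loss_fn), (w,), (v,))[1]


def cdat_step(loss_fn, w, state, sigma=1.0, beta_cdat=0.9):
    """One gradient step with a CDAT-style learning rate (illustrative only)."""
    g = jax.grad(loss_fn)(w)
    d = -g  # plain GD direction; a base optimizer's update could be used instead
    num = jnp.vdot(d, d)                    # numerator of the edge estimate
    den = jnp.vdot(d, hvp(loss_fn, w, d))   # denominator: curvature along d
    # EMAs of numerator and denominator (beta_cdat in {0, 0.9, 0.99} in the paper's grid).
    ema_num = beta_cdat * state["num"] + (1.0 - beta_cdat) * num
    ema_den = beta_cdat * state["den"] + (1.0 - beta_cdat) * den
    eta = sigma * 2.0 * ema_num / jnp.maximum(ema_den, 1e-12)
    return w + eta * d, {"num": ema_num, "den": ema_den}


# Toy usage on a 2-D quadratic (illustrative only).
loss = lambda w: 0.5 * jnp.vdot(w, w * jnp.array([1.0, 10.0]))
w, state = jnp.array([1.0, 1.0]), {"num": jnp.array(0.0), "den": jnp.array(0.0)}
for _ in range(20):
    w, state = cdat_step(loss, w, state, sigma=0.8, beta_cdat=0.9)
```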
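
The splits quoted above can be approximated with TensorFlow Datasets. The sketch below is a hypothetical reconstruction, not the authors' pipeline; the batch size, the TFDS dataset name `imagenette`, and the choice of split for the 1,024-sample subset are assumptions.

```python
import tensorflow_datasets as tfds

# CIFAR-10 mini-batch regime: full 50,000-sample training split, 10,000-sample test split,
# dropping the remainder batch. The batch size of 512 is illustrative.
train_ds = tfds.load("cifar10", split="train").batch(512, drop_remainder=True)
test_ds = tfds.load("cifar10", split="test").batch(512)

# Full-batch regime: a fixed subset of 1,024 samples. The TFDS name "imagenette"
# and drawing them from the training split are assumptions.
full_batch_ds = tfds.load("imagenette", split="train").take(1024).batch(1024)
```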
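
The hyperparameter grids quoted in the experiment setup can be written out explicitly. In this sketch the σ and β_cdat grids are as quoted; the base learning rate value and the width of the base-2 grid are illustrative assumptions.

```python
# Hypothetical sketch of the tuning grids described in the setup.
base_lr = 1e-3  # or 1e-4, depending on optimizer, architecture, and batch size
lr_grid = [base_lr * 2.0**k for k in range(-4, 5)]          # log base-2 grid around base_lr (width assumed)
sigma_grid = [round(0.4 + 0.2 * i, 1) for i in range(13)]   # {0.4, 0.6, ..., 2.8}
beta_cdat_grid = [0.0, 0.9, 0.99]
```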