Combining Axes Preconditioners through Kronecker Approximation for Deep Learning

Authors: Sai Surya Duvvuri, Fnu Devvrit, Rohan Anil, Cho-Jui Hsieh, Inderjit S. Dhillon

ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Furthermore, our experiments demonstrate that CASPR approximates the gradient second-moment matrix in full-matrix Adagrad more accurately, and shows significant improvement in training and generalization performance compared to existing practical adaptive regularization based methods such as Shampoo and Adam in a variety of tasks, including a graph neural network on OGBG-molpcba, a Transformer on a universal dependencies dataset, and auto-regressive large language modeling on the C4 dataset.
Researcher Affiliation | Collaboration | Sai Surya Duvvuri, Department of Computer Science, The University of Texas at Austin (saisurya@cs.utexas.edu); Fnu Devvrit, Department of Computer Science, The University of Texas at Austin (devvrit@cs.utexas.edu); Rohan Anil, Google DeepMind (rohananil@google.com); Cho-Jui Hsieh, CS Department, UCLA & Google (chohsieh@cs.ucla.edu); Inderjit S. Dhillon, Google (isd@google.com)
Pseudocode | Yes | Algorithm 1: CASPR Algorithm (a hedged sketch of this style of update follows the table).
Open Source Code | No | The CASPR code is adapted from the Optax (Babuschkin et al., 2020) implementation of Shampoo (Anil et al.), which is a JAX implementation. However, the paper does not provide a direct link or an explicit statement about the public availability of their specific CASPR implementation.
Open Datasets | Yes | A graph neural network on OGBG-molpcba (Hu et al., 2020), a Transformer on a universal dependencies dataset (Nivre et al., 2020), and auto-regressive large language modeling on the C4 dataset (Raffel et al., 2020).
Dataset Splits | Yes | The training set consists of 350,343 graphs and the test set has 43,793 graphs (a split-retrieval sketch follows the table).
Hardware Specification | Yes | Our data-parallel training runs required 4 TPU v2s; however, the walltime in Figure 3 is computed on TPU v4s to utilize the latest optimizations that TPU v2s do not offer. We use one NVIDIA A100 GPU for this benchmark. Training involved 16 TPU v3s for the 234M-parameter model with a batch size of 256 and 64 TPU v3s for the 14M-parameter model with a batch size of 8192, using paxml software (Google).
Software Dependencies | No | The paper mentions using JAX, Flax, Optax, Jraph, and paxml software, but it does not specify version numbers for these dependencies, making it difficult to fully reproduce the software environment.
Experiment Setup | Yes | Our training process uses binary cross-entropy loss for each of the 128 classes with a batch size of 512. We utilize a specific learning rate schedule, involving a linear warmup followed by a cosine decay (see the schedule sketch after the table). In this benchmark, we compare CASPR against Shampoo and AdamW. We use random search with up to 300 hyperparameter trials, searching over weight decay, learning rate, and the momentum parameter for all the algorithms. We fix the dropout to 0.1.
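The pseudocode row above refers to Algorithm 1 of the paper. As a reading aid only, below is a minimal, hedged sketch of a CASPR-style update for a single 2-D parameter: the two axis statistics are accumulated as in Shampoo, and their inverse-fourth-root preconditioners are combined through a squared Kronecker-sum-style average rather than a Kronecker product. The exponents, damping term, and plain accumulation rule are illustrative simplifications, not a restatement of the paper's exact Algorithm 1.

```python
# Hedged sketch of a CASPR-style preconditioned update for one 2-D
# parameter G of shape (m, n). L and R accumulate G G^T and G^T G as in
# Shampoo; the update averages the two axis preconditioners instead of
# taking their Kronecker product. Illustrative only.
import jax.numpy as jnp


def matrix_inverse_pth_root(A, p, eps=1e-6):
    """Compute A^{-1/p} for a symmetric PSD matrix via eigendecomposition."""
    w, v = jnp.linalg.eigh(A + eps * jnp.eye(A.shape[0], dtype=A.dtype))
    return (v * jnp.power(w, -1.0 / p)) @ v.T


def caspr_style_update(G, L, R):
    """Return updated axis statistics and a preconditioned gradient for G."""
    L = L + G @ G.T  # left (row-axis) statistic, shape (m, m)
    R = R + G.T @ G  # right (column-axis) statistic, shape (n, n)
    L_qr = matrix_inverse_pth_root(L, 4)  # L^{-1/4}
    R_qr = matrix_inverse_pth_root(R, 4)  # R^{-1/4}
    # Squared average of the two axis preconditioners applied to G;
    # the middle term alone would correspond to a Shampoo-like update.
    precond_G = (L_qr @ (L_qr @ G)
                 + 2.0 * (L_qr @ G @ R_qr)
                 + (G @ R_qr) @ R_qr) / 4.0
    return L, R, precond_G
```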
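For the dataset-splits row, the quoted sizes correspond to the standard scaffold split shipped with OGB. The snippet below is a hedged sketch of retrieving those splits with the `ogb` Python package; the paper builds its input pipeline with Jraph/JAX and does not name this package, so treat it only as one way to inspect the official split.

```python
# Retrieve the official OGBG-molpcba scaffold split with the `ogb` package.
from ogb.graphproppred import GraphPropPredDataset

dataset = GraphPropPredDataset(name="ogbg-molpcba")
split_idx = dataset.get_idx_split()  # dict with 'train', 'valid', 'test' indices
print({split: len(idx) for split, idx in split_idx.items()})
```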
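Finally, the experiment-setup row mentions a linear-warmup-plus-cosine-decay learning rate schedule and an AdamW baseline. The Optax sketch below shows one way to express that schedule; all concrete values (peak learning rate, warmup and decay steps, weight decay) are placeholders, since the paper tunes them by random search rather than fixing them.

```python
# Optax sketch of a linear-warmup + cosine-decay schedule driving AdamW.
# Numeric values are placeholders, not the paper's tuned settings.
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,       # start of the linear warmup
    peak_value=1e-3,      # placeholder peak learning rate
    warmup_steps=1_000,   # placeholder warmup length
    decay_steps=100_000,  # placeholder total decay horizon
    end_value=0.0,        # cosine decays to zero
)

optimizer = optax.adamw(
    learning_rate=schedule,
    weight_decay=1e-2,    # placeholder; tuned by random search in the paper
)
```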