Why Do You Grok? A Theoretical Analysis on Grokking Modular Addition

Authors: Mohamad Amin Mohamadi, Zhiyuan Li, Lei Wu, Danica J. Sutherland

ICML 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | First, we show that early in gradient descent, when the kernel regime approximately holds, no permutation-equivariant model can achieve small population error on modular addition unless it sees at least a constant fraction of all possible data points. Eventually, however, models escape the kernel regime. We show that one-hidden-layer quadratic networks that achieve zero training loss with bounded ℓ∞ norm generalize well with substantially fewer training points, and further show such networks exist and can be found by gradient descent with small ℓ∞ regularization. We further provide empirical evidence that these networks leave the kernel regime only after initially overfitting. Taken together, our results strongly support the case for grokking as a consequence of the transition from kernel-like behavior to limiting behavior of gradient descent on deep networks. (A sketch of the quadratic-network architecture described here follows the table.)
Researcher Affiliation | Academia | (1) Toyota Technological Institute at Chicago; (2) Computer Science Department, University of British Columbia; (3) School of Mathematical Sciences, Peking University; (4) Alberta Machine Intelligence Institute.
Pseudocode | No | No pseudocode or algorithm blocks are explicitly provided in the paper.
Open Source Code | No | The paper does not provide an explicit statement about releasing source code or a link to a repository for the methodology described.
Open Datasets | No | The paper studies modular addition, an algorithmic task whose data is generated from the problem definition rather than drawn from a standard publicly available dataset with a direct link or citation.
Dataset Splits | No | The paper mentions 'training data' and 'test points' but does not specify validation dataset splits (percentages or counts). (A sketch of how such an algorithmic dataset is typically generated and split follows the table.)
Hardware Specification | Yes | We used the JAX framework (Bradbury et al., 2018) to implement and run the experiments on machines using NVIDIA V100 or A100 GPUs.
Software Dependencies | Yes | PyTorch 2: Faster machine learning through dynamic Python bytecode transformation and graph compilation. ... We used the JAX framework (Bradbury et al., 2018).
Experiment Setup | Yes | A.1. Regression: We have used vanilla gradient descent with squared loss for 100,000 or 200,000 steps for each experiment. In regression, our learning rate has been fixed to 1, and the regularization strength has been set to 10^-4. The network has been initialized according to He et al. (2015). The amount of data used for training in the regression task has been set to 2p^2.25. A.2. Classification: In all experiments, we have used vanilla gradient descent with cross-entropy loss, for up to 100,000 steps. ... The learning rate in the presented experiments was set to 10 and was kept constant during the training. The regularization strength of the ℓ∞ regularizer has been set to 10^-20. The network has been initialized according to He et al. (2015). The amount of data used for training in the classification task has been set to 2p^(5/3). (A training-loop sketch based on this setup follows the table.)
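
For context on the model class named in the abstract row above, here is a minimal JAX sketch of a one-hidden-layer quadratic network acting on one-hot encodings of the two operands. The modulus p, the width, and the He-style initialization scale are illustrative assumptions, not values taken from the paper.

```python
import jax
import jax.numpy as jnp

p = 97        # modulus (assumed; a common choice in grokking experiments)
width = 256   # hidden width (assumed)

def init_params(key):
    """He-style initialization of a one-hidden-layer quadratic network."""
    k1, k2 = jax.random.split(key)
    # Input is the concatenation of two one-hot vectors of length p.
    W = jax.random.normal(k1, (width, 2 * p)) * jnp.sqrt(2.0 / (2 * p))
    U = jax.random.normal(k2, (p, width)) * jnp.sqrt(2.0 / width)
    return W, U

def forward(params, a, b):
    """Logits over the p possible values of (a + b) mod p."""
    W, U = params
    x = jnp.concatenate([jax.nn.one_hot(a, p), jax.nn.one_hot(b, p)], axis=-1)
    h = (x @ W.T) ** 2   # quadratic activation
    return h @ U.T       # linear readout
```

Because the one-hot encoding and matrix products broadcast over a leading batch dimension, `forward` can be applied to whole batches of operand pairs directly.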
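
Since the task's data is generated from the problem definition (see the Open Datasets and Dataset Splits rows), a dataset and split can be reconstructed directly. The sketch below enumerates all p^2 pairs (a, b) with label (a + b) mod p and makes a random train/test split; the train fraction and seed are placeholders, not values reported in the paper.

```python
import jax
import jax.numpy as jnp

def make_modular_addition_split(p=97, train_frac=0.4, seed=0):
    """Enumerate all p^2 pairs (a, b) with label (a + b) mod p, then split at random."""
    a, b = jnp.meshgrid(jnp.arange(p), jnp.arange(p), indexing="ij")
    pairs = jnp.stack([a.ravel(), b.ravel()], axis=1)   # shape (p^2, 2)
    labels = (pairs[:, 0] + pairs[:, 1]) % p
    perm = jax.random.permutation(jax.random.PRNGKey(seed), pairs.shape[0])
    n_train = int(train_frac * pairs.shape[0])
    train_idx, test_idx = perm[:n_train], perm[n_train:]
    return (pairs[train_idx], labels[train_idx]), (pairs[test_idx], labels[test_idx])
```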
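
Combining the quoted appendix details (vanilla full-batch gradient descent, cross-entropy loss for classification, small ℓ∞ regularization, He initialization), a single training step might look like the sketch below. It reuses `forward` from the model sketch above; the ℓ∞ penalty implementation and the hyperparameter values are assumptions rather than the authors' exact setup.

```python
import jax
import jax.numpy as jnp

lr, lam = 1.0, 1e-4   # placeholder learning rate and l_inf regularization strength

def loss_fn(params, pairs, labels):
    # Cross-entropy over the p classes, using `forward` from the model sketch above.
    logits = forward(params, pairs[:, 0], pairs[:, 1])
    logp = jax.nn.log_softmax(logits)
    ce = -jnp.take_along_axis(logp, labels[:, None], axis=1).mean()
    # Small l_inf regularization: penalize the largest-magnitude weight entry.
    linf = jnp.max(jnp.stack([jnp.max(jnp.abs(w)) for w in params]))
    return ce + lam * linf

@jax.jit
def gd_step(params, pairs, labels):
    """One step of vanilla full-batch gradient descent."""
    loss, grads = jax.value_and_grad(loss_fn)(params, pairs, labels)
    params = jax.tree_util.tree_map(lambda w, g: w - lr * g, params, grads)
    return params, loss
```

A run would then apply `gd_step` to the full training set for up to the quoted 100,000 steps, monitoring train and test accuracy to observe the delayed generalization characteristic of grokking.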