Why Do You Grok? A Theoretical Analysis on Grokking Modular Addition
Authors: Mohamad Amin Mohamadi, Zhiyuan Li, Lei Wu, Danica J. Sutherland
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | First, we show that early in gradient descent, when the kernel regime approximately holds, no permutation-equivariant model can achieve small population error on modular addition unless it sees at least a constant fraction of all possible data points. Eventually, however, models escape the kernel regime. We show that one-hidden-layer quadratic networks that achieve zero training loss with bounded ℓ∞ norm generalize well with substantially fewer training points, and further show such networks exist and can be found by gradient descent with small ℓ∞ regularization. We further provide empirical evidence that these networks leave the kernel regime only after initially overfitting. Taken together, our results strongly support the case for grokking as a consequence of the transition from kernel-like behavior to limiting behavior of gradient descent on deep networks. (See the network sketch below the table.) |
| Researcher Affiliation | Academia | ¹Toyota Technological Institute at Chicago; ²Computer Science Department, University of British Columbia; ³School of Mathematical Sciences, Peking University; ⁴Alberta Machine Intelligence Institute. |
| Pseudocode | No | No pseudocode or algorithm blocks are explicitly provided in the paper. |
| Open Source Code | No | The paper does not provide an explicit statement about releasing source code or a link to a repository for the methodology described. |
| Open Datasets | No | The paper studies modular addition, an algorithmic task whose data are generated from the problem definition rather than drawn from a standard publicly available dataset with a direct link or citation. |
| Dataset Splits | No | The paper mentions 'training data' and 'test points' but does not specify train/validation/test split sizes (percentages or counts). |
| Hardware Specification | Yes | We used the JAX framework (Bradbury et al., 2018) to implement and run the experiments on machines using NVIDIA V100 or A100 GPUs. |
| Software Dependencies | Yes | PyTorch 2: Faster machine learning through dynamic Python bytecode transformation and graph compilation. ... We used the JAX framework (Bradbury et al., 2018). |
| Experiment Setup | Yes | A.1. Regression: We have used vanilla gradient descent with squared loss for 100,000 or 200,000 steps for each experiment. In regression, our learning rate has been fixed to 1, and the regularization strength has been set to 10^-4. The network has been initialized according to He et al. (2015). The amount of data used for training in the regression task has been set to 2p^(2.25). A.2. Classification: In all experiments, we have used vanilla gradient descent with cross-entropy loss, for up to 100,000 steps. ... The learning rate in the presented experiments was set to 10 and was kept constant during training. The regularization strength of the ℓ∞ regularizer has been set to 10^-20. The network has been initialized according to He et al. (2015). The amount of data used for training in the classification task has been set to 2p^(5/3). (See the training-step sketch below the table.) |
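
The abstract quoted in the Research Type row describes one-hidden-layer networks with a quadratic activation, trained on modular addition with He initialization and a small ℓ∞ penalty. The following is a minimal sketch of such a network in JAX (the framework named in the Hardware Specification row). The modulus `p`, the hidden `width`, the concatenated one-hot input encoding, the absence of biases, and the names `init_params`, `one_hot_pair`, and `forward` are illustrative assumptions, not details confirmed by the paper.

```python
import jax
import jax.numpy as jnp

p = 97        # modulus; an illustrative choice, not stated in the quoted text
width = 256   # hidden width; illustrative

def init_params(key):
    # He-style initialization, matching the appendix's reference to He et al. (2015).
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (width, 2 * p)) * jnp.sqrt(2.0 / (2 * p))
    U = jax.random.normal(k2, (p, width)) * jnp.sqrt(2.0 / width)
    return {"W": W, "U": U}

def one_hot_pair(a, b):
    # Concatenated one-hot encoding of the two operands a, b in Z_p.
    return jnp.concatenate([jax.nn.one_hot(a, p), jax.nn.one_hot(b, p)])

def forward(params, a, b):
    # One hidden layer with quadratic activation phi(z) = z^2, followed by a
    # linear readout producing p outputs (one per residue class of a + b mod p).
    h = (params["W"] @ one_hot_pair(a, b)) ** 2
    return params["U"] @ h
```

The quadratic activation is the only architectural detail taken from the abstract; the input encoding and readout shown here are one plausible instantiation.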
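The Experiment Setup row quotes vanilla gradient descent with a squared or cross-entropy loss, a fixed learning rate, and a small regularization term. The sketch below shows one full-batch training step in the regression configuration, reusing `p`, `forward`, and `init_params` from the sketch above. The learning rate, regularization strength, and step count mirror the numbers quoted for Appendix A.1; the exact form of the ℓ∞ penalty, the one-hot regression targets, and the use of all p² pairs (rather than a sampled training subset) are assumptions made only to keep the example short.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1.0   # regression learning rate quoted from Appendix A.1
REG_STRENGTH = 1e-4   # regression regularization strength quoted from Appendix A.1
NUM_STEPS = 100_000   # "100,000 or 200,000 steps" per the appendix

def loss_fn(params, pairs, targets):
    # Full-batch squared loss over (a, b) pairs plus a small penalty.
    # The max-absolute-entry penalty below is one plausible stand-in for an
    # l_inf-type regularizer, not necessarily the paper's exact formulation.
    preds = jax.vmap(lambda ab: forward(params, ab[0], ab[1]))(pairs)
    data_loss = jnp.mean((preds - targets) ** 2)
    penalty = sum(jnp.max(jnp.abs(w)) for w in jax.tree_util.tree_leaves(params))
    return data_loss + REG_STRENGTH * penalty

@jax.jit
def gd_step(params, pairs, targets):
    # One step of vanilla (full-batch) gradient descent, no momentum.
    grads = jax.grad(loss_fn)(params, pairs, targets)
    return jax.tree_util.tree_map(lambda w, g: w - LEARNING_RATE * g, params, grads)

# Illustrative driver: regression targets are taken to be one-hot encodings of
# (a + b) mod p, which is an assumption, not something stated in the quote.
key = jax.random.PRNGKey(0)
params = init_params(key)
pairs = jnp.array([(a, b) for a in range(p) for b in range(p)])
targets = jax.nn.one_hot((pairs[:, 0] + pairs[:, 1]) % p, p)
for _ in range(NUM_STEPS):
    params = gd_step(params, pairs, targets)
```

The classification runs described in Appendix A.2 would swap the squared loss for cross-entropy, raise the learning rate to 10, and shrink the regularization strength to 10^-20, per the quoted text.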