Linear attention is (maybe) all you need (to understand Transformer optimization)
Authors: Kwangjun Ahn, Xiang Cheng, Minhak Song, Chulhee Yun, Ali Jadbabaie, Suvrit Sra
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We make progress towards understanding the subtleties of training Transformers by carefully studying a simple yet canonical linearized shallow Transformer model. Specifically, we train linear Transformers to solve regression tasks, inspired by J. von Oswald et al. (ICML 2023) and K. Ahn et al. (NeurIPS 2023). Most importantly, we observe that our proposed linearized models can reproduce several prominent aspects of Transformer training dynamics. Consequently, the results obtained in this paper suggest that a simple linearized Transformer model could actually be a valuable, realistic abstraction for understanding Transformer optimization. |
| Researcher Affiliation | Academia | Kwangjun Ahn MIT EECS/LIDS kjahn@mit.edu Xiang Cheng MIT LIDS chengx@mit.edu Minhak Song KAIST ISysE/Math minhaksong@kaist.ac.kr Chulhee Yun KAIST AI chulhee.yun@kaist.ac.kr Ali Jadbabaie MIT CEE/LIDS jadbabai@mit.edu Suvrit Sra MIT EECS/LIDS suvrit@mit.edu |
| Pseudocode | No | The paper describes the linear Transformer architecture mathematically (e.g., Equation 1 and the recursive definition of Zℓ+1) but does not provide a formal pseudocode block or algorithm steps. (A hedged sketch of one such layer appears after the table.) |
| Open Source Code | No | The paper does not contain any explicit statement or link indicating that the source code for the described methodology is publicly available. |
| Open Datasets | No | The data distribution can be thought of as random instances of linear regression. Concretely, for i = 1, 2, . . . , n + 1, let x(i) ∈ ℝ^d be drawn i.i.d. from a distribution D_X. We then draw w ∼ D_W and then generate the scalar responses y = [⟨x(1), w⟩, . . . , ⟨x(n), w⟩] ∈ ℝ^n. Our default setup is Setting 1 of Table 1, where the context consists of 20 context demonstrations; each context covariate is sampled from the standard Gaussian, i.e., x(i) ∼ N(0, I_d), and we draw w ∼ N(0, I_d). This is consistent with previous works (Garg et al., 2022; Akyürek et al., 2022; von Oswald et al., 2023; Ahn et al., 2023b). (A hedged sketch of this data-generation recipe appears after the table.) |
| Dataset Splits | No | The paper describes the generation of synthetic data and the experimental settings (e.g., Figures 2-5 showing training loss), but it does not specify explicit training, validation, or test dataset splits in terms of percentages or sample counts for reproduction. |
| Hardware Specification | No | The paper does not provide any specific details about the hardware used to conduct the experiments, such as GPU models, CPU types, or memory configurations. |
| Software Dependencies | No | The paper mentions using "torch.nn in PyTorch" for generating an MLP in Appendix B, but it does not provide specific version numbers for PyTorch or any other software dependencies, which would be necessary for exact replication. |
| Experiment Setup | Yes | For each different setting, we pick the best learning rate from a grid search over 10 different choices. We choose the momentum parameter 0.9 for SGD, and β1 = β2 = 0.9 for Adam. We also employ (global) gradient clipping, where the threshold is chosen to be 1 for all settings (i.e., the clipped gradient direction is the same as the non-clipped direction). All the experiments are run over 6 different random seeds. See Appendix A for details; the selected learning rates are listed in Table 2 of the paper. (A hedged sketch of this optimizer setup appears after the table.) |
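
To make the Pseudocode row concrete, here is a minimal PyTorch sketch of one residual, softmax-free attention layer of the kind the paper studies. The particular parameterization Zℓ+1 = Zℓ + (1/n) · P Zℓ M (Zℓᵀ Q Zℓ), the mask M that zeroes the query column, the token dimension, and the initialization scale are all assumptions made for illustration; the paper's exact definition is its Equation 1, which is not reproduced in the table above.

```python
import torch
import torch.nn as nn

class LinearAttentionLayer(nn.Module):
    """One residual, softmax-free attention update (illustrative sketch).

    Assumed form: Z_{l+1} = Z_l + (1/n) * P @ Z_l @ M @ (Z_l^T @ Q @ Z_l),
    where Z_l has shape (dim, n_context + 1), P and Q are learned (dim, dim)
    matrices, and M zeroes out the query column so the unknown response is
    never attended to. This parameterization is an assumption, not a copy of
    the paper's Equation 1.
    """

    def __init__(self, dim: int, n_context: int):
        super().__init__()
        # Small random init; the scale 0.02 is arbitrary.
        self.P = nn.Parameter(0.02 * torch.randn(dim, dim))
        self.Q = nn.Parameter(0.02 * torch.randn(dim, dim))
        mask = torch.eye(n_context + 1)
        mask[-1, -1] = 0.0  # mask out the query token (last column)
        self.register_buffer("M", mask)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Z: (batch, dim, n_context + 1); the last column is the query token,
        # e.g., each token stacking a covariate and its (possibly zeroed) response.
        n = Z.shape[-1] - 1
        attn = self.P @ Z @ self.M @ (Z.transpose(-2, -1) @ self.Q @ Z)
        return Z + attn / n
```

Stacking a few such layers and reading a prediction off the query column would give a shallow linear Transformer of the sort described in the abstract quote above.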
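Similarly, the Open Datasets row fully specifies the synthetic data distribution, so it can be sketched directly. In the sketch below, the batch size and covariate dimension d are placeholders (the quote fixes only the 20 context demonstrations and the standard-Gaussian distributions), and the function name is my own.

```python
import torch

def sample_regression_batch(batch_size: int = 64, n_context: int = 20, d: int = 5):
    """Sample random linear-regression instances per the Open Datasets quote.

    Each instance draws n_context + 1 covariates x(i) ~ N(0, I_d) and a task
    vector w ~ N(0, I_d), and sets y(i) = <x(i), w>. The final pair acts as
    the query whose response the model must predict. d = 5 and the batch size
    are placeholders, not values from the paper.
    """
    x = torch.randn(batch_size, n_context + 1, d)            # covariates
    w = torch.randn(batch_size, d, 1)                         # task vectors
    y = (x @ w).squeeze(-1)                                   # y(i) = <x(i), w>
    context_x, query_x = x[:, :n_context], x[:, n_context]    # 20 demos + 1 query
    context_y, target_y = y[:, :n_context], y[:, n_context]
    return context_x, context_y, query_x, target_y
```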
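Finally, the Experiment Setup row pins down the optimizer hyperparameters (SGD momentum 0.9, Adam β1 = β2 = 0.9, global clipping threshold 1) but not the model, the loss, or the 10-point learning-rate grid. A hedged sketch of how such a training step could be wired in PyTorch follows; the squared-error loss and the helper names are assumptions.

```python
import torch

def make_optimizer(model: torch.nn.Module, name: str, lr: float):
    """Optimizers as quoted: SGD with momentum 0.9, or Adam with beta1 = beta2 = 0.9.

    The learning rate lr is whatever the 10-point grid search selects; the
    grid itself is listed in Table 2 of the paper and not reproduced here.
    """
    if name == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    return torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))

def training_step(model, optimizer, inputs, targets, clip_threshold: float = 1.0):
    """One step with global gradient clipping at threshold 1, as quoted.

    The MSE loss is an assumption consistent with the regression task; the
    paper's exact training objective is not reproduced in the table.
    """
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_threshold)
    optimizer.step()
    return loss.item()
```

Repeating the run over 6 random seeds, as the quote states, would then just wrap this step in an outer loop that reseeds the generator before each run.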