Transformers learn to implement preconditioned gradient descent for in-context learning
Authors: Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, Suvrit Sra
NeurIPS 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically validate the critical points analyzed in Theorem 3 and Theorem 4. For a transformer with three layers, our experimental results confirm the structure of critical points. |
| Researcher Affiliation | Academia | Kwangjun Ahn (MIT EECS/LIDS, kjahn@mit.edu); Xiang Cheng (MIT LIDS, chengx@mit.edu); Hadi Daneshmand (MIT LIDS/FODSI, hdanesh@mit.edu); Suvrit Sra (TU Munich / MIT, suvrit@mit.edu) |
| Pseudocode | No | The paper describes algorithmic steps and derivations (e.g., Eq. 9), but does not present them in a structured pseudocode or algorithm block format. |
| Open Source Code | Yes | Code for our experiments is available at https://github.com/chengxiang/LinearTransformer. |
| Open Datasets | No | Data distribution: random linear regression instances. Let x^(i) ∈ R^d be the covariates drawn i.i.d. from a distribution D_X, and w ∈ R^d be drawn from D_W (a data-generation sketch of this setup appears below the table). |
| Dataset Splits | No | The paper generates random problem instances for training and evaluation; it does not provide explicit training/validation/test dataset splits with percentages or sample counts. |
| Hardware Specification | No | The paper states training was done using ADAM, but does not provide any specific hardware details such as GPU or CPU models. |
| Software Dependencies | No | We optimize f for a three-layer linear transformer using ADAM... The paper mentions 'ADAM' but does not provide version numbers for any software dependencies. |
| Experiment Setup | Yes | The dimension is d = 5, and the number of training samples in the prompt is n = 20. ...We optimize f for a three-layer linear transformer using ADAM, where the matrices A0, A1, and A2 are initialized by i.i.d. Gaussian matrices. Each gradient step is computed from a minibatch of size 20000, and we resample the minibatch every 100 steps. We clip the gradient of each matrix to 0.01. All plots are averaged over 5 runs with different U (i.e. Σ) sampled each time (a hedged training sketch reflecting this recipe appears below the table). |
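
The data distribution quoted in the Open Datasets row lends itself to a short generation routine. The sketch below builds batches of in-context linear regression prompts with d = 5 covariate dimensions and n = 20 labelled examples plus one query point. The choice of standard Gaussians for D_X and D_W, the prompt layout Z ∈ R^{(d+1)×(n+1)}, and all names are assumptions rather than details taken from the released code.

```python
# Hypothetical data-generation sketch for the random linear regression prompts
# described in the paper. D_X and D_W are taken to be standard Gaussians here,
# which is an assumption; the released code may differ.
import torch

def sample_prompts(batch_size: int, n: int = 20, d: int = 5):
    """Return (Z, target): prompts Z of shape (batch_size, d+1, n+1) and query labels."""
    x = torch.randn(batch_size, d, n + 1)   # covariates x^(i) ~ D_X = N(0, I_d)
    w = torch.randn(batch_size, d, 1)       # task vector w ~ D_W = N(0, I_d)
    y = w.transpose(1, 2) @ x               # labels y^(i) = <w, x^(i)>
    target = y[:, 0, -1].clone()            # true label of the query point
    y[:, :, -1] = 0.0                       # query label is hidden from the model
    Z = torch.cat([x, y], dim=1)            # prompt matrix Z in R^{(d+1) x (n+1)}
    return Z, target
```

Each prompt stacks the n labelled pairs column-wise with the query point in the last column, matching the d = 5, n = 20 setting quoted in the Experiment Setup row.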
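Building on the `sample_prompts` helper above, the following sketch mirrors the quoted training recipe: a three-layer attention-only linear transformer whose l-th layer carries a single d×d matrix A_l, trained with ADAM on minibatches of 20000 prompts resampled every 100 steps, with per-matrix gradient clipping at 0.01. The masked linear-attention update, the sparse P and Q patterns, the value-clipping reading of "clip the gradient", the learning rate, the step count, and the sign convention for the prediction are all assumptions; the released repository is the authoritative reference.

```python
# Hedged training sketch for the three-layer linear transformer experiment.
# The masked linear-attention update and the reduction of each layer to a single
# d x d matrix A_l are assumptions based on the paper's description; see
# https://github.com/chengxiang/LinearTransformer for the exact architecture.
import torch
import torch.nn.functional as F

d, n, num_layers = 5, 20, 3
batch_size = 20_000        # minibatch size quoted in the paper
resample_every = 100       # minibatch is resampled every 100 steps
clip_value = 0.01          # "clip the gradient of each matrix to 0.01" (read as value clipping)

# A_0, A_1, A_2 initialized as i.i.d. Gaussian matrices.
A = [torch.nn.Parameter(torch.randn(d, d)) for _ in range(num_layers)]
optimizer = torch.optim.Adam(A, lr=1e-3)   # learning rate not reported; assumed

# M masks out the query column; P restricts the update to the label (last) row.
M = torch.eye(n + 1)
M[n, n] = 0.0
P = torch.zeros(d + 1, d + 1)
P[d, d] = 1.0

def forward(Z):
    """Apply the attention-only linear transformer; return the query prediction."""
    for A_l in A:
        Q = F.pad(A_l, (0, 1, 0, 1))                    # Q_l = [[A_l, 0], [0, 0]]
        Z = Z + P @ Z @ M @ Z.transpose(1, 2) @ Q @ Z / n
    return -Z[:, -1, -1]                                # sign convention assumed

for step in range(10_000):                              # step count not reported; assumed
    if step % resample_every == 0:
        Z_batch, target = sample_prompts(batch_size, n, d)
    loss = ((forward(Z_batch) - target) ** 2).mean()    # in-context squared loss
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(A, clip_value)      # clip each matrix's gradient
    optimizer.step()
```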