Transformers learn to implement preconditioned gradient descent for in-context learning

Authors: Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, Suvrit Sra

NeurIPS 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We empirically validate the critical points analyzed in Theorem 3 and Theorem 4. For a transformer with three layers, our experimental results confirm the structure of the critical points.
Researcher Affiliation | Academia | Kwangjun Ahn (MIT EECS/LIDS, kjahn@mit.edu); Xiang Cheng (MIT LIDS, chengx@mit.edu); Hadi Daneshmand (MIT LIDS/FODSI, hdanesh@mit.edu); Suvrit Sra (TU Munich / MIT, suvrit@mit.edu)
Pseudocode | No | The paper describes algorithmic steps and derivations (e.g., Eq. 9), but does not present them in a structured pseudocode or algorithm block.
Open Source Code | Yes | Code for our experiments is available at https://github.com/chengxiang/LinearTransformer.
Open Datasets | No | Data distribution: random linear regression instances. Let x^(i) ∈ R^d be the covariates drawn i.i.d. from a distribution D_X, and w ∈ R^d be drawn from D_W. (A prompt-sampling sketch of this distribution is given after the table.)
Dataset Splits | No | The paper focuses on generating random problem instances for training and evaluation. It does not provide explicit training/validation/test dataset splits with percentages or sample counts.
Hardware Specification | No | The paper states training was done using ADAM, but does not provide specific hardware details such as GPU or CPU models.
Software Dependencies | No | We optimize f for a three-layer linear transformer using ADAM... The paper mentions ADAM but does not provide version numbers for any software dependencies.
Experiment Setup | Yes | The dimension is d = 5, and the number of training samples in the prompt is n = 20. ...We optimize f for a three-layer linear transformer using ADAM, where the matrices A_0, A_1, and A_2 are initialized as i.i.d. Gaussian matrices. Each gradient step is computed from a minibatch of size 20000, and we resample the minibatch every 100 steps. We clip the gradient of each matrix to 0.01. All plots are averaged over 5 runs, with a different U (i.e., Σ) sampled each time. (A training-loop sketch reflecting these hyperparameters is given below.)
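
For concreteness, here is a minimal sketch of the prompt distribution described in the Open Datasets row, written in PyTorch. The Gaussian choices for D_X and D_W, the (d+1)-dimensional token layout, and the function name `sample_prompts` are illustrative assumptions, not the authors' released code.

```python
# Minimal sketch (not the authors' code) of the random linear regression
# prompt distribution, assuming Gaussian D_X and D_W.
import torch

def sample_prompts(batch_size, n=20, d=5, cov_sqrt=None):
    """Sample in-context linear regression prompts.

    Each prompt holds n labelled examples (x_i, <w, x_i>) followed by one
    query covariate whose label the model must predict. `cov_sqrt` plays the
    role of U with Sigma = U U^T; identity covariance is the default.
    """
    if cov_sqrt is None:
        cov_sqrt = torch.eye(d)
    x = torch.randn(batch_size, n + 1, d) @ cov_sqrt.T  # covariates ~ D_X
    w = torch.randn(batch_size, d, 1)                   # task vectors ~ D_W
    y = x @ w                                           # noiseless labels
    z = torch.cat([x, y], dim=-1)                       # (batch, n+1, d+1) tokens
    z[:, -1, -1] = 0.0                                  # hide the query label
    return z, y[:, -1, 0]                               # prompts, query labels
```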
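And here is a hedged sketch of a training loop matching the quoted Experiment Setup (three layers, ADAM, minibatch of size 20000 resampled every 100 steps, per-matrix gradient clipping at 0.01); it reuses `sample_prompts` from the sketch above. The single-matrix-per-layer linear attention, the prediction readout, the learning rate, and the initialization scale are simplifying assumptions, not the authors' exact implementation; see the linked repository for the real code.

```python
# Minimal training sketch mirroring the quoted hyperparameters; the attention
# parametrization, readout, and learning rate are assumptions.
import torch

def linear_attention(z, a, n):
    """One linear self-attention layer parametrized by a single matrix `a`.

    The mask removes the query token's contribution, so only the n labelled
    examples drive the update (assumed form, not necessarily the paper's).
    """
    mask = torch.ones(z.shape[1], device=z.device)
    mask[-1] = 0.0                                      # exclude the query token
    attn = (z @ a @ z.transpose(1, 2)) * mask           # (batch, n+1, n+1)
    return z + attn @ z / n

def train(num_steps=10_000, layers=3, d=5, n=20,
          batch_size=20_000, resample_every=100, clip=0.01):
    # One trainable matrix per layer, i.i.d. Gaussian init (scale is assumed).
    params = [torch.nn.Parameter(0.1 * torch.randn(d + 1, d + 1))
              for _ in range(layers)]
    opt = torch.optim.Adam(params, lr=1e-3)             # learning rate assumed

    for step in range(num_steps):
        if step % resample_every == 0:                  # resample every 100 steps
            z, y_query = sample_prompts(batch_size, n=n, d=d)
        out = z
        for a in params:                                # forward through the layers
            out = linear_attention(out, a, n)
        pred = out[:, -1, -1]                           # label slot of the query token
        loss = torch.mean((pred - y_query) ** 2)

        opt.zero_grad()
        loss.backward()
        for a in params:                                # clip each matrix's gradient to 0.01
            torch.nn.utils.clip_grad_norm_(a, clip)
        opt.step()
    return params
```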