What learning algorithm is in-context learning? Investigations with linear models
Authors: Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, Denny Zhou
ICLR 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | First, we prove by construction that transformers can implement learning algorithms for linear models based on gradient descent and closed-form ridge regression. Second, we show that trained in-context learners closely match the predictors computed by gradient descent, ridge regression, and exact least-squares regression, transitioning between different predictors as transformer depth and dataset noise vary, and converging to Bayesian estimators for large widths and depths. Third, we present preliminary evidence that in-context learners share algorithmic features with these predictors: learners' late layers non-linearly encode weight vectors and moment matrices. In Section 4, we investigate empirical properties of trained in-context learners. (A minimal sketch of the ridge-regression and gradient-descent reference predictors appears after this table.) |
| Researcher Affiliation | Collaboration | 1Google Research, 2MIT CSAIL, 3Stanford University |
| Pseudocode | No | Appendices A and B describe the computational steps for implementing these learning algorithms using primitive operations such as 'mov', 'mul', 'div', and 'aff', and show how these map onto transformer operations. However, they are presented as descriptive derivations of operator chains rather than formally structured 'Pseudocode' or 'Algorithm' blocks. |
| Open Source Code | Yes | Code and reference implementations are released at this web page. The accompanying code release contains a reference implementation of SGD defined in terms of the base primitives, provided via anonymized links: https://icl1.s3.us-east-2.amazonaws.com/theory/{primitives,sgd,ridge}.py (to preserve the anonymity we did not provide the library dependencies). |
| Open Datasets | No | For the main experiments we generate data according to p(w) = N(0, I) and p(x) = N(0, I). This indicates that the data was synthetically generated by the authors rather than being a publicly available dataset with concrete access information (link, DOI, etc.). (A minimal data-generation sketch appears after this table.) |
| Dataset Splits | No | For our main experiments, we found that L = 16, H = 512, M = 4 minimized loss on a validation set. While a validation set is mentioned, no specific details about its size or percentage split from the total data are provided. |
| Hardware Specification | Yes | We perform these experiments using the Jax framework on P100 GPUs. |
| Software Dependencies | No | The paper mentions the 'Jax framework', the 'Adam optimizer', and the 'GeLU activation function', but it does not specify exact version numbers for any of these software components, which is required for reproducibility. |
| Experiment Setup | Yes | For all experiments, we perform a hyperparameter search over depth L ∈ {1, 2, 4, 8, 12, 16}, hidden size H ∈ {16, 32, 64, 256, 512, 1024}, and heads M ∈ {1, 2, 4, 8}. Other hyper-parameters are noted in Appendix D. For our main experiments, we found that L = 16, H = 512, M = 4 minimized loss on a validation set. (A sketch of this search grid appears after this table.) |
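
The data-generation quote in the Open Datasets row can be made concrete with a short sketch. The snippet below is not the authors' released code: the dimension, context length, and noise level are illustrative placeholders, and only the distributional assumptions p(w) = N(0, I), p(x) = N(0, I), with y = wᵀx plus optional noise, come from the paper.

```python
# Minimal sketch of the synthetic linear-regression tasks described in the paper.
# Assumptions (not from the paper): dim=8, n_examples=16, configurable noise_std.
import jax
import jax.numpy as jnp

def sample_task(key, n_examples=16, dim=8, noise_std=0.0):
    """Draw one in-context regression problem: w ~ N(0, I), x ~ N(0, I), y = w.x + noise."""
    k_w, k_x, k_eps = jax.random.split(key, 3)
    w = jax.random.normal(k_w, (dim,))                # p(w) = N(0, I)
    xs = jax.random.normal(k_x, (n_examples, dim))    # p(x) = N(0, I)
    ys = xs @ w + noise_std * jax.random.normal(k_eps, (n_examples,))
    return w, xs, ys

w, xs, ys = sample_task(jax.random.PRNGKey(0))
```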
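The reference predictors mentioned in the Research Type row (closed-form ridge regression and gradient descent on the least-squares objective) can likewise be sketched generically. This is a plain comparison baseline, not the paper's construction of these algorithms inside a transformer; the regularizer, learning rate, and step count are arbitrary illustrative values.

```python
import jax.numpy as jnp

def ridge_predictor(xs, ys, x_query, lam=0.1):
    """Closed-form ridge regression: w = (X^T X + lam I)^{-1} X^T y, then predict w . x_query."""
    d = xs.shape[1]
    w = jnp.linalg.solve(xs.T @ xs + lam * jnp.eye(d), xs.T @ ys)
    return x_query @ w

def gd_predictor(xs, ys, x_query, lr=0.05, steps=200):
    """Full-batch gradient descent on the squared loss, starting from w = 0."""
    w = jnp.zeros(xs.shape[1])
    for _ in range(steps):
        grad = xs.T @ (xs @ w - ys) / xs.shape[0]
        w = w - lr * grad
    return x_query @ w
```

Setting lam = 0 (when XᵀX is invertible) recovers exact least squares, the third predictor the paper compares the trained in-context learner against.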
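Finally, the hyperparameter search reported in the Experiment Setup row amounts to a grid over depth, hidden size, and head count. The loop below only enumerates the reported grid; the training and validation code it would wrap is not shown, and the variable names are placeholders.

```python
from itertools import product

depths = [1, 2, 4, 8, 12, 16]                 # L
hidden_sizes = [16, 32, 64, 256, 512, 1024]   # H
num_heads = [1, 2, 4, 8]                      # M

for L, H, M in product(depths, hidden_sizes, num_heads):
    # Train a transformer with this configuration and record its validation loss;
    # the paper reports L = 16, H = 512, M = 4 as the best setting.
    ...
```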