Transformers Learn In-Context by Gradient Descent

Authors: Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, Max Vladymyrov

ICML 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | "show empirically that when training self-attention-only Transformers on simple regression tasks either the models learned by GD and Transformers show great similarity or, remarkably, the weights found by optimization match the construction." (a minimal linear self-attention sketch follows the table)
Researcher Affiliation | Collaboration | "¹Department of Computer Science, ETH Zürich, Zürich, Switzerland; ²Google Research."
Pseudocode | No | The paper describes its methods through prose and equations but does not include any clearly labeled pseudocode or algorithm blocks.
Open Source Code | Yes | "Main experiments can be reproduced with notebooks provided under the following link: https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd"
Open Datasets | No | "We focus on solvable tasks and similarly to Garg et al. (2022) generate data for each task using a teacher model with parameters Wτ ∼ N(0, I). We then sample xτ,i ∼ U(-1, 1)^{n_I} and construct targets using the task-specific teacher model, yτ,i = Wτ xτ,i." (a data-generation sketch follows the table)
Dataset Splits | Yes | "More concretely, to compare trained and constructed LSA layers, we sample T_val = 10^4 validation tasks and record the following quantities, averaged over validation tasks."
Hardware Specification | No | The paper does not provide specific details regarding the hardware used for the experiments, such as CPU/GPU models or memory specifications.
Software Dependencies | No | The paper mentions the Adam optimizer and the Optax and Haiku libraries, but does not provide version numbers for these dependencies, which are necessary for full reproducibility.
Experiment Setup | Yes | "Optimizer: Adam (Kingma & Ba, 2014) with default parameters and learning rate of 0.001 for Transformer with depth K < 3 and 0.0005 otherwise. We use a batchsize of 2048 and applied gradient clipping to obtain gradients with global norm of 10." (an Optax sketch of this setup follows the table)
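
The Research Type and Dataset Splits rows concern self-attention-only Transformers built from linear self-attention (LSA) layers, i.e. attention without the softmax. The following is a minimal JAX sketch of such an update; the parameter names (W_K, W_Q, W_V, P), the square shapes, the residual form, and the 1/N scaling are illustrative assumptions rather than the paper's exact construction.

```python
# Minimal sketch (not the authors' code): one softmax-free linear
# self-attention (LSA) update over a token matrix E of shape (N, d).
import jax

def lsa_layer(params, E):
    K = E @ params["W_K"].T   # keys,    (N, d)
    Q = E @ params["W_Q"].T   # queries, (N, d)
    V = E @ params["W_V"].T   # values,  (N, d)
    scores = Q @ K.T          # linear attention scores, no softmax, (N, N)
    return E + (scores @ V) @ params["P"].T / E.shape[0]   # residual token update

# Toy usage with random parameters (shapes are illustrative assumptions).
d, N = 11, 10
k1, k2, k3, k4, kE = jax.random.split(jax.random.PRNGKey(0), 5)
params = {"W_K": 0.1 * jax.random.normal(k1, (d, d)),
          "W_Q": 0.1 * jax.random.normal(k2, (d, d)),
          "W_V": 0.1 * jax.random.normal(k3, (d, d)),
          "P":   0.1 * jax.random.normal(k4, (d, d))}
E = jax.random.normal(kE, (N, d))
print(lsa_layer(params, E).shape)   # (10, 11)
```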
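
The Open Datasets row quotes the paper's task-sampling procedure: a teacher Wτ drawn from N(0, I), inputs drawn uniformly from (-1, 1), and targets yτ,i = Wτ xτ,i. Below is a minimal JAX sketch of that procedure under assumed dimensions (10-dimensional inputs, scalar targets, 10 in-context examples per task; these numbers are not taken from the quote).

```python
# Sketch of the quoted task-generation procedure; dimensions are assumptions.
import jax

def sample_task(key, n_in=10, n_out=1, n_examples=10):
    """Sample one linear-regression task: teacher W ~ N(0, I), inputs x ~ U(-1, 1)."""
    k_w, k_x = jax.random.split(key)
    W = jax.random.normal(k_w, (n_out, n_in))                  # teacher parameters W_tau
    X = jax.random.uniform(k_x, (n_examples, n_in), minval=-1.0, maxval=1.0)
    Y = X @ W.T                                                # targets y = W_tau x
    return X, Y

X, Y = sample_task(jax.random.PRNGKey(42))
print(X.shape, Y.shape)   # (10, 10) (10, 1)
```

Validation as described under Dataset Splits would then amount to drawing T_val = 10^4 such tasks (for example by vmapping sample_task over split keys) and averaging the recorded quantities over them.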
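
The Experiment Setup row maps naturally onto Optax, one of the libraries noted under Software Dependencies. The sketch below is an assumed composition of the quoted settings: Adam with default parameters, a learning rate of 0.001 for depth K < 3 and 0.0005 otherwise, and gradient clipping to a global norm of 10; the ordering of the transformations and the helper name make_optimizer are our choices. The batch size of 2048 belongs to the data pipeline and is not shown.

```python
# Assumed Optax composition of the quoted optimizer settings.
import optax

def make_optimizer(depth_K: int) -> optax.GradientTransformation:
    learning_rate = 1e-3 if depth_K < 3 else 5e-4   # 0.001 for K < 3, 0.0005 otherwise
    return optax.chain(
        optax.clip_by_global_norm(10.0),            # clip gradients to global norm 10
        optax.adam(learning_rate),                  # Adam with default hyperparameters
    )

optimizer = make_optimizer(depth_K=2)
# Typical use: opt_state = optimizer.init(params)
#              updates, opt_state = optimizer.update(grads, opt_state)
#              params = optax.apply_updates(params, updates)
```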