In-context Learning on Function Classes Unveiled for Transformers

Authors: Zhijie Wang, Bo Jiang, Shuai Li

ICML 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | A. Experiment: In Figure 1 and Figure 2, we show that a pretrained transformer can learn a quadratic function and a 3-layer neural network without any parameter update. In the pretraining stage we set the learning rate of stochastic gradient descent to 3 × 10⁻⁴ and train for 100000 steps with batch size 32. In the inference stage we generate prompts of (x, y) pairs according to the corresponding function (a quadratic function and a 3-layer neural network in our case), with the x's drawn from a standard Gaussian distribution. Experimental results: Our results on the in-context learning ability of the transformer are shown in Figure 1 and Figure 2. The transformer learns the quadratic function well: as Figure 1 shows, the test loss goes to 0 as the number of in-context examples increases, while 3-Nearest Neighbors cannot solve quadratic regression. The transformer can also learn the 3-layer neural network, where the test loss curve matches the gradient descent baseline, supporting our theoretical results.
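For concreteness, the inference-stage protocol quoted above can be summarized by the following minimal sketch: for each prompt length k, a fresh query point is scored against the k in-context examples, both for the pretrained transformer and for a 3-Nearest-Neighbors baseline. The `model` interface and the `sample_task` helper are assumptions for illustration, not the authors' released code.

```python
# Sketch of the inference-stage evaluation: test loss vs. number of in-context
# examples, with a 3-NN baseline fit on the same examples (assumed interfaces).
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

def eval_in_context(model, sample_task, max_k=40, n_tasks=64, seed=0):
    rng = np.random.default_rng(seed)
    tf_loss, knn_loss = np.zeros(max_k), np.zeros(max_k)
    for _ in range(n_tasks):
        x, y = sample_task(max_k + 1, rng)          # k in-context pairs + 1 held-out query
        for k in range(1, max_k + 1):
            pred = model(x[:k], y[:k], x[k])        # transformer prediction at the query point
            tf_loss[k - 1] += (pred - y[k]) ** 2
            knn = KNeighborsRegressor(n_neighbors=min(3, k)).fit(x[:k], y[:k])
            knn_loss[k - 1] += (knn.predict(x[k:k + 1])[0] - y[k]) ** 2
    return tf_loss / n_tasks, knn_loss / n_tasks    # averaged squared-error curves
```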
Researcher Affiliation | Academia | ¹Shanghai Jiao Tong University, Shanghai, China; ²John Hopcroft Center for Computer Science, Shanghai Jiao Tong University, Shanghai, China.
Pseudocode | No | The paper does not contain structured pseudocode or algorithm blocks.
Open Source Code | No | The paper does not provide concrete access to source code for the methodology it describes.
Open Datasets | No | In the inference stage we generate prompts of (x, y) pairs according to the corresponding function (a quadratic function and a 3-layer neural network in our case), with the x's drawn from a standard Gaussian distribution. We consider a d-dimensional quadratic regression task with in-context examples of the form z = (x, y) ∈ R^d × R, where the x's are sampled i.i.d. from the standard Gaussian distribution and y = w^T x^2 with w ~ N(0, Σ). We set the dimension d = 20 in Figure 1. We consider a 3-layer neural network task with in-context examples of the form z = (x, y) ∈ R^d × R, where the x's are sampled i.i.d. from the standard Gaussian distribution and y = W^(3) r(W^(2) r(W^(1) x)) as in Definition 2.4, where W_ij ~ N(0, Σ) and the activation function r is the ReLU activation. We set the width of each hidden layer to K = 50 and the input dimension to d = 20.
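A minimal sketch of the two synthetic prompt distributions described in that row follows. The function names are invented here, Σ is taken to be the identity for illustration, and x^2 is read as the elementwise square of x; none of this is the authors' code.

```python
# Sketch of the synthetic data generation for the two tasks (assumed details noted above).
import numpy as np

d, K = 20, 50                                   # input dimension and hidden width from the quote

def relu(z):
    return np.maximum(z, 0.0)

def sample_quadratic_task(n_examples, rng):
    """Quadratic regression: x ~ N(0, I_d), y = w^T x^2 with w ~ N(0, Sigma)."""
    w = rng.standard_normal(d)                  # Sigma = I assumed for illustration
    x = rng.standard_normal((n_examples, d))
    y = (x ** 2) @ w                            # x^2 read as the elementwise square
    return x, y

def sample_3layer_nn_task(n_examples, rng):
    """3-layer ReLU network: y = W^(3) r(W^(2) r(W^(1) x)), entries W_ij ~ N(0, Sigma)."""
    W1 = rng.standard_normal((K, d))
    W2 = rng.standard_normal((K, K))
    W3 = rng.standard_normal((1, K))
    x = rng.standard_normal((n_examples, d))
    y = (W3 @ relu(W2 @ relu(W1 @ x.T))).ravel()
    return x, y

rng = np.random.default_rng(0)
xs, ys = sample_quadratic_task(40, rng)         # 40 in-context (x, y) pairs
```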
Dataset Splits | No | The paper does not provide the specific dataset split information needed to reproduce the data partitioning.
Hardware Specification | No | The paper does not provide specific hardware details used for running its experiments.
Software Dependencies | No | The paper does not provide specific ancillary software details with version numbers.
Experiment Setup | Yes | In the pretraining stage we set the learning rate of stochastic gradient descent to 3 × 10⁻⁴ and train for 100000 steps with batch size 32. Model architecture: Our experiment setup follows Garg et al. (2022). We train a transformer model (GPT-2 structure) with 12 layers, 8 attention heads and 256 hidden dimensions.
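A minimal sketch of a matching model and optimizer configuration, in the style of Garg et al. (2022)-type setups: a GPT-2 backbone wrapped with linear read-in/read-out layers so it can consume continuous (x, y) prompt tokens. The wrapper class and prompt encoding are assumptions; only the hyperparameters (12 layers, 8 heads, 256 hidden dimensions, SGD with learning rate 3 × 10⁻⁴, batch size 32, 100000 steps) come from the quoted setup.

```python
# Sketch of the reported architecture and optimizer (wrapper design is an assumption).
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2Model

class InContextTransformer(nn.Module):
    def __init__(self, d=20, n_embd=256, n_layer=12, n_head=8, max_len=256):
        super().__init__()
        cfg = GPT2Config(n_positions=max_len, n_embd=n_embd,
                         n_layer=n_layer, n_head=n_head)
        self.backbone = GPT2Model(cfg)
        self.read_in = nn.Linear(d + 1, n_embd)     # embed each (x, y) prompt token
        self.read_out = nn.Linear(n_embd, 1)        # predict y at each position

    def forward(self, tokens):                      # tokens: (batch, seq_len, d + 1)
        h = self.backbone(inputs_embeds=self.read_in(tokens)).last_hidden_state
        return self.read_out(h).squeeze(-1)         # (batch, seq_len)

model = InContextTransformer()
# Pretraining optimizer as reported: SGD, lr 3e-4, 100000 steps, batch size 32.
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
```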