In-context Learning on Function Classes Unveiled for Transformers
Authors: Zhijie Wang, Bo Jiang, Shuai Li
ICML 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | A. Experiment. In Figure 1 and Figure 2, we show that a pretrained transformer can learn a quadratic function and a 3-layer neural network without any parameter update. In the pretraining stage we set the learning rate of stochastic gradient descent to 3 × 10^-4 and train for 100,000 steps with batch size 32. In the inference stage we generate prompts of (x, y) pairs according to the corresponding function (a quadratic function and a 3-layer neural network in our case), with the inputs x drawn from a standard Gaussian distribution. Experimental results. Our results on the in-context learning ability of the transformer are shown in Figure 1 and Figure 2. The transformer learns the quadratic function well: as Figure 1 shows, the test loss goes to 0 as the number of in-context examples increases, while 3-Nearest Neighbors cannot solve quadratic regression. The transformer can also learn the 3-layer neural network, where the test loss curve matches the gradient descent baseline, supporting our theoretical results. |
| Researcher Affiliation | Academia | ¹Shanghai Jiao Tong University, Shanghai, China; ²John Hopcroft Center for Computer Science, Shanghai Jiao Tong University, Shanghai, China. |
| Pseudocode | No | The paper does not contain structured pseudocode or algorithm blocks. |
| Open Source Code | No | The paper does not provide concrete access to source code for the methodology described in this paper. |
| Open Datasets | No | In the inference stage we generate prompts of (x, y) pairs according to the corresponding function (a quadratic function and a 3-layer neural network in our case), with the inputs x drawn from a standard Gaussian distribution. We consider a d-dimensional quadratic regression task with in-context examples of the form z = (x, y) ∈ R^d × R, where the x's are sampled i.i.d. from a standard Gaussian distribution and y = w^T x^2, where w ∼ N(0, Σ). We set the dimension d = 20 in Figure 1. We consider a 3-layer neural network task with in-context examples of the form z = (x, y) ∈ R^d × R, where the x's are sampled i.i.d. from a standard Gaussian distribution and y = W^(3) r(W^(2) r(W^(1) x)) as in Definition 2.4, where the entries of each W^(i) are drawn from N(0, Σ) and the activation function r is the ReLU activation. We set the width of each hidden layer to K = 50 and the input dimension to d = 20. (See the data-generation sketch after the table.) |
| Dataset Splits | No | The paper does not provide specific dataset split information needed to reproduce the data partitioning. |
| Hardware Specification | No | The paper does not provide specific hardware details used for running its experiments. |
| Software Dependencies | No | The paper does not provide specific ancillary software details with version numbers. |
| Experiment Setup | Yes | In the pretraining stage we set the learning rate of stochastic gradient descent to 3 × 10^-4 and train for 100,000 steps with batch size 32. Model architecture. Our experiment setup follows Garg et al. (2022). We train a transformer model (GPT-2 architecture) with 12 layers, 8 attention heads, and a hidden dimension of 256. (See the configuration sketch after the table.) |
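
The synthetic tasks quoted in the Open Datasets row can be reconstructed in a few lines of PyTorch. The sketch below is a minimal illustration under stated assumptions: the covariance Σ is taken to be the identity, the quadratic target is read as an element-wise square y = w^T x^2, and the prompt length and batch size are illustrative; the paper only fixes d = 20, K = 50, Gaussian inputs, and Gaussian weights.

```python
# Minimal sketch of the synthetic prompt generation quoted above.
# Assumptions: Sigma = I, element-wise-square quadratic target, and an
# illustrative prompt length / batch size; the paper fixes d = 20 and K = 50.
import torch

d, K, n_points = 20, 50, 40  # input dimension, hidden width, in-context examples per prompt

def quadratic_prompts(batch_size=32):
    """y = w^T x^2 with w ~ N(0, I_d) and x ~ N(0, I_d)."""
    x = torch.randn(batch_size, n_points, d)
    w = torch.randn(batch_size, d, 1)
    y = (x ** 2 @ w).squeeze(-1)
    return x, y

def three_layer_nn_prompts(batch_size=32):
    """y = W^(3) ReLU(W^(2) ReLU(W^(1) x)) with i.i.d. Gaussian weights."""
    x = torch.randn(batch_size, n_points, d)
    W1 = torch.randn(batch_size, d, K)
    W2 = torch.randn(batch_size, K, K)
    W3 = torch.randn(batch_size, K, 1)
    h = torch.relu(x @ W1)       # first hidden layer
    h = torch.relu(h @ W2)       # second hidden layer
    y = (h @ W3).squeeze(-1)
    return x, y
```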
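The quoted pretraining setup (a 12-layer, 8-head, 256-dimensional GPT-2 backbone trained with SGD at learning rate 3 × 10^-4, batch size 32, for 100,000 steps) can be sketched in the style of Garg et al. (2022), who wrap a Hugging Face GPT2Model with linear read-in and read-out layers and train it to predict y_i from the prompt prefix (x_1, y_1, ..., x_i). The `ICLTransformer` class, the interleaving scheme, the context length, and the quadratic-task details below are assumptions made for illustration, not the authors' released implementation.

```python
# Minimal sketch of the quoted pretraining setup, in the style of Garg et al. (2022):
# a GPT-2 backbone with linear read-in/read-out, trained to predict y_i from the
# prefix (x_1, y_1, ..., x_i). Interleaving, Sigma = I, and prompt length are assumptions.
import torch
from transformers import GPT2Config, GPT2Model

d = 20
config = GPT2Config(n_positions=128, n_embd=256, n_layer=12, n_head=8)

class ICLTransformer(torch.nn.Module):
    """GPT-2 backbone with linear read-in/read-out layers, following Garg et al. (2022)."""
    def __init__(self):
        super().__init__()
        self.read_in = torch.nn.Linear(d, config.n_embd)
        self.backbone = GPT2Model(config)
        self.read_out = torch.nn.Linear(config.n_embd, 1)

    def forward(self, x, y):
        B, n, _ = x.shape
        y_tok = torch.zeros_like(x)
        y_tok[:, :, 0] = y                                          # scalar y in the first coordinate
        seq = torch.stack([x, y_tok], dim=2).reshape(B, 2 * n, d)   # x_1, y_1, x_2, y_2, ...
        hidden = self.backbone(inputs_embeds=self.read_in(seq)).last_hidden_state
        preds = self.read_out(hidden).squeeze(-1)                   # (B, 2n)
        return preds[:, ::2]                                        # y_i predicted at the x_i position

model = ICLTransformer()
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)  # learning rate quoted in the paper

# One of the 100,000 pretraining steps (batch size 32) on the quadratic task.
x = torch.randn(32, 40, d)
w = torch.randn(32, d, 1)
y = (x ** 2 @ w).squeeze(-1)
loss = ((model(x, y) - y) ** 2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Causal masking inside the GPT-2 backbone ensures the prediction at the x_i position only attends to earlier (x, y) pairs, which is what makes this a valid in-context learning objective rather than label leakage.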