Learning and Transferring Sparse Contextual Bigrams with Linear Transformers
Authors: Yunwei Ren, Zixuan Wang, Jason D. Lee
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this paper, we investigate the training dynamics and sample complexity of training a linear transformer to learn the SCB task using a stochastic gradient-based algorithm. Our contributions are summarized as follows. Data model: We introduce the Sparse Contextual Bigram (SCB) model, a simple task that requires the model to learn both in-context and global information. Convergence: We prove convergence guarantees for a one-layer linear transformer trained with the nonconvex ℓ1-regularized MSE loss using preconditioned projected proximal descent, given a dataset sampled from the SCB model. Sample Complexity: Under mild conditions on the data distribution, initialization, and hyperparameters, we prove that our algorithm recovers the ground truth with polynomial dependence on the sequence length, the number of states, and the sparsity parameter. We show that training first goes through an initial sample-intensive stage that boosts the signal using a number of samples polynomial in the sequence length, followed by a more sample-efficient stage that reaches final convergence with a number of samples polynomial in the number of states and the sparsity. We empirically verify that our gradient-based method converges to the ground truth with a small batch size, while unregularized stochastic gradient descent fails due to its large variance. Transfer Learning: We prove that, when there is a nontrivial correlation between the pretraining and downstream tasks, a pretrained model can be transferred to bypass the first, sample-intensive stage, so that our algorithm converges to the ground truth of the downstream task with a number of samples polynomial only in the number of states and the sparsity. (See the data-model sketch after the table.) |
| Researcher Affiliation | Academia | Yunwei Ren, Zixuan Wang, Jason D. Lee (Princeton University); {yunwei.ren, wangzx, jasonlee}@princeton.edu |
| Pseudocode | Yes | Algorithm 1: Projected preconditioned ℓ1-proximal gradient descent |
| Open Source Code | Yes | We will upload the codes in the supplementary materials. |
| Open Datasets | No | Our experiments ... train on synthetic data. The data distribution follows the SCB model (1) with a randomly sampled transition matrix P together with its stationary distribution π, and the ground-truth attention pattern Q. |
| Dataset Splits | No | The paper does not explicitly mention standard train/validation/test splits with percentages or counts. It states that training is done on "synthetic data". |
| Hardware Specification | Yes | For all our experiments, we use Numpy and run on a normal laptop which takes about 20 minutes. |
| Software Dependencies | No | For all our experiments, we use Numpy and run on a normal laptop which takes about 20 minutes. |
| Experiment Setup | Yes | We choose the number of states to be 3, the sparsity to be 2, and the sequence length to be 5000 (much larger than the number of states and the sparsity). We use a batch size of 64 to run the online projected proximal gradient descent with regularization strength 1e-5, and vanilla SGD, for T = 1000 iterations. During the signal-boosting stage (iterations 0 to 400), we use a learning rate of 0.01 to accelerate the process; after iteration 400, we use 0.005 for further improvement. For SGD, we add another set of experiments with a second-stage learning rate of 0.001 to prevent potential instability. (See the update-rule sketch after the table.) |
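
The Research Type and Open Datasets rows describe the SCB data model only verbally: a transition matrix P, its stationary distribution π, and a sparse ground-truth attention pattern Q that determines, via the last token, which earlier positions the next token depends on. The NumPy sketch below is one possible reading of that description, not the paper's exact model (1); the context-sampling scheme, the shape of Q, and all function names are assumptions introduced here for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_transition_matrix(n_states, rng):
    """Row-stochastic transition matrix P with uniform random entries, row-normalized."""
    P = rng.random((n_states, n_states))
    return P / P.sum(axis=1, keepdims=True)

def stationary_distribution(P):
    """Stationary distribution pi of P (left eigenvector for eigenvalue 1)."""
    vals, vecs = np.linalg.eig(P.T)
    pi = np.abs(np.real(vecs[:, np.argmax(np.real(vals))]))
    return pi / pi.sum()

def sample_scb(P, pi, Q, rng):
    """Sample one (context, next-token) pair under an assumed reading of SCB:
    the context follows the Markov chain P started from pi, and the next token
    is a bigram transition from a single context position drawn according to
    the sparse row of the attention pattern Q indexed by the last token."""
    n_states, seq_len = Q.shape
    x = np.empty(seq_len, dtype=int)
    x[0] = rng.choice(n_states, p=pi)
    for t in range(1, seq_len):
        x[t] = rng.choice(n_states, p=P[x[t - 1]])
    j = rng.choice(seq_len, p=Q[x[-1]])   # sparse position chosen by the last token
    y = rng.choice(n_states, p=P[x[j]])   # bigram transition from that position
    return x, y

# Hypothetical instantiation with the reported settings: 3 states, sparsity 2, length 5000.
n_states, sparsity, seq_len = 3, 2, 5000
P = random_transition_matrix(n_states, rng)
pi = stationary_distribution(P)
Q = np.zeros((n_states, seq_len))
for s in range(n_states):                 # sparsity-many equally weighted positions per state
    Q[s, rng.choice(seq_len, size=sparsity, replace=False)] = 1.0 / sparsity
x, y = sample_scb(P, pi, Q, rng)
```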
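
The Pseudocode row cites "Algorithm 1: Projected preconditioned ℓ1-proximal gradient descent," but the table does not reproduce the preconditioner or the projection set. The sketch below therefore only illustrates the generic building blocks, with a nonnegativity clamp standing in for the paper's projection and no preconditioning; the values in the commented usage come from the Experiment Setup row and the placeholder `estimate_gradient` is hypothetical.

```python
import numpy as np

def soft_threshold(W, tau):
    """Proximal operator of tau * ||W||_1 (entrywise soft-thresholding)."""
    return np.sign(W) * np.maximum(np.abs(W) - tau, 0.0)

def proximal_step(W, grad, lr, lam, project=lambda V: np.maximum(V, 0.0)):
    """One projected ell_1-proximal gradient update.

    The default projection (clamp to the nonnegative orthant) and the absence of
    a preconditioner are simplifications, not the paper's Algorithm 1."""
    return project(soft_threshold(W - lr * grad, lr * lam))

def learning_rate(t):
    """Assumed two-phase schedule from the reported setup:
    0.01 during the signal-boosting stage (t < 400), then 0.005."""
    return 0.01 if t < 400 else 0.005

# Sketch of the training loop with the reported hyperparameters (batch size 64,
# 1000 iterations, regularization 1e-5):
# for t in range(1000):
#     grad = estimate_gradient(W, batch_size=64)   # stochastic MSE gradient on a fresh batch
#     W = proximal_step(W, grad, learning_rate(t), lam=1e-5)
```

Soft-thresholding is the standard closed-form proximal map for an ℓ1 penalty, which is why the ℓ1-regularized objective still admits cheap per-iteration updates despite being nonsmooth.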