Learning and Transferring Sparse Contextual Bigrams with Linear Transformers

Authors: Yunwei Ren, Zixuan Wang, Jason D. Lee

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this paper, we investigate the training dynamics and sample complexity of training a linear transformer to learn the SCB task using a stochastic gradient-based algorithm. Our contributions are summarized as follows: Data model: We introduce the Sparse Contextual Bigram (SCB) model, a simple task that requires the model to learn both in-context and global information. Convergence: We prove convergence guarantees for a one-layer linear transformer trained with the nonconvex ℓ1-regularized MSE loss using preconditioned projected proximal descent, given a dataset sampled from the SCB model. Sample Complexity: Under mild conditions on the data distribution, initialization, and hyperparameters, we prove that our algorithm can recover the ground truth with polynomial dependence on the sequence length T, the number of states N, and the sparsity parameter Q ≪ T. We show that the training first goes through an initial sample-intensive stage which boosts the signal with poly(T) samples, followed by a more sample-efficient stage that achieves final convergence with poly(N, Q) samples. We empirically verify that our gradient-based methods converge to the ground truth with a small batch size, while unregularized stochastic gradient descent fails due to the large variance. Transfer Learning: We prove that, when there is a nontrivial correlation between the pretraining and downstream tasks, we can transfer a pre-trained model to bypass the first sample-intensive stage, so that our algorithm converges to the ground truth of the downstream task with only poly(N, Q) samples.
Researcher Affiliation | Academia | Yunwei Ren, Zixuan Wang, Jason D. Lee, Princeton University, {yunwei.ren, wangzx, jasonlee}@princeton.edu
Pseudocode | Yes | Algorithm 1: Projected preconditioned ℓ1-proximal gradient descent
Open Source Code | Yes | We will upload the codes in the supplementary materials.
Open Datasets | No | Our experiments ... train on the synthetic data. The data distribution follows the SCB model (1) with a randomly sampled transition matrix P together with its stationary μ, and the ground truth attention pattern Q.
Dataset Splits | No | The paper does not explicitly mention standard train/validation/test splits with percentages or counts. It states that training is done on "synthetic data".
Hardware Specification | Yes | For all our experiments, we use Numpy and run on a normal laptop which takes about 20 minutes.
Software Dependencies | No | For all our experiments, we use Numpy and run on a normal laptop which takes about 20 minutes.
Experiment Setup | Yes | We choose the number of states N = 3, sparsity Q = 2, and the sequence length T = 5000 ≫ N, Q. We use a batch size B = 64 to run the online projected proximal gradient descent with λ = 1e-5 and the vanilla SGD for 1000 iterations. Through the signal-boosting stage τ ∈ [0, 400], we use η1 = 0.01 to accelerate the process. After τ > 400, we use η2 = 0.005 for further improvement. For SGD, we add another set of experiments with η2 = 0.001 to prevent potential instability.
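To make the synthetic setup quoted under Open Datasets and Experiment Setup concrete, below is a minimal NumPy sketch of one plausible reading of the SCB generative process: context tokens drawn from the stationary distribution μ of a randomly sampled transition matrix P, a context position selected by a sparse attention pattern Q, and the next token following the P-transition from that position's token. Eq. (1) of the paper is not reproduced in this report, so the sampler, the helper names, and the i.i.d.-from-μ assumption are illustrative only.

```python
import numpy as np

def sample_scb(P, mu, q, T, rng):
    """One (context, next-token) draw under an assumed SCB-style process:
    context tokens ~ mu, a position t ~ q (sparse), next token ~ P[x_t].
    Placeholder for Eq. (1) in the paper, which may differ in detail."""
    N = len(mu)
    x = rng.choice(N, size=T, p=mu)      # context tokens x_1, ..., x_T
    t = rng.choice(T, p=q)               # position selected by the sparse pattern
    y = rng.choice(N, p=P[x[t]])         # next token from the bigram transition
    return x, y

rng = np.random.default_rng(0)
N, Q, T = 3, 2, 5000                     # values reported in the experiment setup
P = rng.dirichlet(np.ones(N), size=N)    # random row-stochastic transition matrix
mu = np.ones(N) / N
for _ in range(200):                     # power iteration for the stationary distribution
    mu = mu @ P
q = np.zeros(T)
q[rng.choice(T, size=Q, replace=False)] = 1.0 / Q   # uniform sparse attention pattern
x, y = sample_scb(P, mu, q, T, rng)
```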
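The Pseudocode row names Algorithm 1 (projected preconditioned ℓ1-proximal gradient descent) but the report does not reproduce it, so the following is only a generic sketch of such an update: a preconditioned gradient step, soft-thresholding as the ℓ1 proximal operator, then a projection. The `precond` and `project` callables and the stand-in gradient are placeholders rather than the paper's exact choices; λ, the batch size, and the two-stage step sizes follow the Experiment Setup row.

```python
import numpy as np

def soft_threshold(w, thresh):
    """Proximal operator of thresh * ||w||_1: elementwise soft-thresholding."""
    return np.sign(w) * np.maximum(np.abs(w) - thresh, 0.0)

def prox_step(w, grad, eta, lam, precond, project):
    """One generic projected, preconditioned l1-proximal gradient update.
    `precond` and `project` stand in for the paper's preconditioner and
    feasible-set projection; this is not a verbatim Algorithm 1."""
    w = w - eta * precond(grad)          # preconditioned gradient step
    w = soft_threshold(w, eta * lam)     # l1 proximal step for the regularizer
    return project(w)                    # projection onto the constraint set

# Reported hyperparameters: lambda = 1e-5, batch size 64, 1000 iterations,
# eta1 = 0.01 during the signal-boosting stage (tau <= 400), eta2 = 0.005 after.
lam, batch_size, n_iters = 1e-5, 64, 1000
step_size = lambda tau: 0.01 if tau <= 400 else 0.005

# Toy usage with an identity preconditioner and a no-op projection.
w = np.zeros(10)
grad = np.random.default_rng(1).normal(size=10)   # stand-in stochastic gradient
w = prox_step(w, grad, step_size(0), lam, precond=lambda g: g, project=lambda v: v)
```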