Continual Learning with Global Alignment
Authors: Xueying Bai, Jinghuan Shang, Yifan Sun, Niranjan Balasubramanian
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Without experience replay, our model achieves SOTA performance in continual learning tasks. It also achieves advanced class-incremental performance through task-incremental training. The code is available at: https://github.com/StonyBrookNLP/global-alignment. |
| Researcher Affiliation | Academia | Xueying Bai Jinghuan Shang Yifan Sun Niranjan Balasubramanian Department of Computer Science Stony Brook University {xubai, jishang, ysun, niranjan}@cs.stonybrook.edu |
| Pseudocode | No | The paper does not contain structured pseudocode or algorithm blocks (clearly labeled algorithm sections or code-like formatted procedures). |
| Open Source Code | Yes | The code is available at: https://github.com/StonyBrookNLP/global-alignment. |
| Open Datasets | Yes | (1) Yahoo: a split of the Yahoo dataset for news question-answer categorization [63] with 5 disjoint tasks containing 2 classes each; (2) DB: a split of DBPedia data for Wikipedia article classification [63] with 7 disjoint tasks containing 2 classes each; (3) News Series: a sequence of tasks on news-related data, including AG_news (news classification, 4 classes), MRPC (paraphrase detection, 2 classes) [12], RTE (text entailment, 2 classes) [58], and SST (sentiment analysis, 2 classes) [51]; (4) All: all tasks in the above sequences. |
| Dataset Splits | No | The paper describes training and testing procedures but does not explicitly provide specific validation dataset splits (percentages, sample counts, or explicit mention of a validation set). |
| Hardware Specification | Yes | We perform all experiments on one Nvidia RTX A6000 machine. |
| Software Dependencies | No | The paper mentions using a 'BERT-base encoder' and 'transformer-based models', but it does not specify any software dependencies with version numbers (e.g., Python version, PyTorch/TensorFlow versions, or specific library versions). |
| Experiment Setup | Yes | Probing: We fix the encoder and only train the classifier. We train 5 epochs for each task with the learning rate 5e-4. FT: We fine-tune the whole model, including the encoder and classifier. We train 3 epochs for each task in BERT, with the learning rate 2e-5. Adapter: we select learning rates from {5e-5, 1e-4, 1e-3} and train {5, 20} epochs for each task. For all continual learning tasks, we train with the learning rate 5e-5 and 20 epochs. The rank of the Adapter projection r is 32, as suggested in the original paper. LoRA (C-LoRA): we select learning rates from {5e-4, 1e-3} and train LoRA for {5, 8} epochs for each task. Wire-Fixed and Wire-Neigh: we select learning rates from {2e-4, 5e-4, 1e-3} and train {5, 8} epochs for each task. The rank of the learnable key matrix is 8. For Wire-Neigh, the number of neighbors is 5. In practice, they are randomly sampled from a larger neighborhood ranging over {10, 20, 50, 100}. We set the mixing ratio as s = 0.1. IDBR: We train IDBR with the learning rate 3e-5 for 3 epochs per task. We follow the k-means memory selection rule, and the replay batch size is 16 (the training batch size) times the number of tasks in the memory. CTR: We follow the settings in the original paper, training 5 epochs for each task. L2P: We have a prompt pool with 100 prompt tokens and select 50 of them to prepend to the input. We train the model with the learning rate 1e-3 for 20 epochs for each task. CODA: We set the prompt component size to 20 for each task and the prompt length to 20. We train the model with the learning rate 1e-3 for 10 epochs for each task. ER: We apply sparse experience replay with a 1% replay ratio. At each replay time, we sample 32 samples from the memory and perform one-step gradient descent on them. A-GEM: We store all previous data in the memory. At each gradient step, we randomly extract 32 samples from the memory and apply the A-GEM gradient projection. MBPA++: We fine-tune the model with ER and then adapt the model at inference time, retrieving the 32 nearest samples in the memory for local adaptation. |
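The A-GEM step in the setup above (sample 32 memory examples, compute a reference gradient, then project the current gradient) reduces to a single vector projection. A minimal NumPy sketch, assuming flattened gradient vectors; `agem_project` is a hypothetical helper name, not code from the paper's repository:

```python
import numpy as np

def agem_project(grad: np.ndarray, ref_grad: np.ndarray) -> np.ndarray:
    """A-GEM-style gradient projection (illustrative sketch only).

    `ref_grad` is the gradient on the replayed memory samples. If the
    current task gradient `grad` conflicts with it (negative dot
    product), remove the interfering component so the update does not
    increase the loss on past tasks.
    """
    dot = float(np.dot(grad, ref_grad))
    if dot >= 0.0:
        return grad  # no interference with memory: keep gradient as-is
    # Project onto the half-space orthogonal to the conflict direction.
    return grad - (dot / float(np.dot(ref_grad, ref_grad))) * ref_grad

# Conflicting toy gradients: the projected result is orthogonal to ref_grad.
g = agem_project(np.array([1.0, -2.0]), np.array([1.0, 1.0]))
```

In practice the gradients span all model parameters, so the dot products are taken over the concatenation of every parameter's gradient.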