Continual Learning with Global Alignment

Authors: Xueying Bai, Jinghuan Shang, Yifan Sun, Niranjan Balasubramanian

NeurIPS 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Without experience replay, our model achieves SOTA performance in continual learning tasks. It also achieves advanced class-incremental performance through task-incremental training. The code is available at: https://github.com/StonyBrookNLP/global-alignment. (Section 5, Experiments)
Researcher Affiliation | Academia | Xueying Bai, Jinghuan Shang, Yifan Sun, Niranjan Balasubramanian, Department of Computer Science, Stony Brook University. {xubai, jishang, ysun, niranjan}@cs.stonybrook.edu
Pseudocode | No | The paper does not contain structured pseudocode or algorithm blocks (clearly labeled algorithm sections or code-like formatted procedures).
Open Source Code | Yes | The code is available at: https://github.com/StonyBrookNLP/global-alignment.
Open Datasets | Yes | (1) Yahoo: a split of the Yahoo dataset for news question-answer categorization [63] with 5 disjoint tasks containing 2 classes each; (2) DB: a split of DBPedia data for Wikipedia article classification [63] with 7 disjoint tasks containing 2 classes each; (3) News Series: a sequence of tasks on news-related data, including AG_news (news classification, 4 classes), MRPC (paraphrase detection, 2 classes) [12], RTE (text entailment, 2 classes) [58], and SST (sentiment analysis, 2 classes) [51]; (4) All: all tasks in the above sequences.
Dataset Splits | No | The paper describes training and testing procedures but does not explicitly provide specific validation dataset splits (percentages, sample counts, or explicit mention of a validation set).
Hardware Specification | Yes | We perform all experiments on one Nvidia RTX A6000 machine.
Software Dependencies | No | The paper mentions using a 'BERT-base encoder' and 'transformer-based models', but it does not specify any software dependencies with version numbers (e.g., Python version, PyTorch/TensorFlow versions, or specific library versions).
Experiment Setup | Yes |
Probing: We fix the encoder and only train the classifier. We train 5 epochs for each task with the learning rate 5e-4.
FT: We fine-tune the whole model, including the encoder and classifier. We train 3 epochs for each task in BERT, with the learning rate 2e-5.
Adapter: We select learning rates from {5e-5, 1e-4, 1e-3} and train {5, 20} epochs for each task. For all continual learning tasks, we train with the learning rate 5e-5 for 20 epochs. The rank of the Adapter projection r is 32, as suggested in the original paper.
LoRA (C-LoRA): We select learning rates from {5e-4, 1e-3} and train LoRA for {5, 8} epochs for each task.
Wire-Fixed and Wire-Neigh: We select learning rates from {2e-4, 5e-4, 1e-3} and train {5, 8} epochs for each task. The rank of the learnable key matrix is 8. For Wire-Neigh, the number of neighbors is 5. In practice, they are randomly sampled from a larger neighborhood ranging over {10, 20, 50, 100}. We set the mixing ratio as s = 0.1.
IDBR: We train IDBR with the learning rate 3e-5 for 3 epochs per task. We follow the k-means memory selection rule, and the replay batch size is 16 (the training batch size) times the number of tasks in the memory.
CTR: We follow the settings in the original paper, training 5 epochs for each task.
L2P: We use a prompt pool with 100 prompt tokens and select 50 of them to prepend to the input. We train the model with the learning rate 1e-3 for 20 epochs for each task.
CODA: We use a prompt component size of 20 for each task and set the prompt length to 20. We train the model with the learning rate 1e-3 for 10 epochs for each task.
ER: We apply sparse experience replay with a 1% replay ratio. At each replay step, we sample 32 samples from the memory and perform one-step gradient descent on them.
A-GEM: We store all previous data in the memory. At each gradient step, we randomly draw 32 samples from the memory and apply the A-GEM gradient projection.
MBPA++: We fine-tune the model with ER and then adapt the model at inference time, retrieving the 32 nearest samples in the memory for local adaptation.
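The A-GEM and MBPA++ baselines above rely on two standard mechanisms: a gradient projection that prevents the memory-batch loss from increasing, and nearest-neighbor retrieval from memory for local adaptation. A minimal NumPy sketch of both, assuming flattened gradient vectors and fixed-size key representations (function names and shapes are illustrative, not the authors' code):

```python
import numpy as np

def agem_project(grad, grad_ref):
    """A-GEM-style projection (illustrative): if the current-task gradient
    conflicts with the memory-batch gradient (negative dot product), project
    it so the memory loss does not increase to first order."""
    dot = float(np.dot(grad, grad_ref))
    if dot < 0:
        grad = grad - (dot / float(np.dot(grad_ref, grad_ref))) * grad_ref
    return grad

def retrieve_neighbors(query, memory_keys, k=32):
    """MBPA++-style retrieval (illustrative): indices of the k memory keys
    closest to the query representation under Euclidean distance."""
    dists = np.linalg.norm(memory_keys - query, axis=1)
    return np.argsort(dists)[:k]
```

For example, with `grad = [1, -1]` and `grad_ref = [0, 1]` the dot product is negative, so `agem_project` returns `[1, 0]`, which is orthogonal to the memory gradient; a non-conflicting gradient passes through unchanged.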