Adaptive Orthogonal Projection for Batch and Online Continual Learning

Authors: Yiduo Guo, Wenpeng Hu, Dongyan Zhao, Bing Liu

AAAI 2022, pp. 6783-6791

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experimental evaluation shows that AOP not only outperforms the original OWM but also markedly outperforms other existing state-of-the-art baselines in both batch and online CL settings. To the best of our knowledge, it is the first gradient orthogonal method for online CL. Note that neither OWM nor AOP saves any training examples or builds data generators. AOP is thus more general and applicable to settings where the training data is no longer accessible after learning. The inaccessibility of the old data could be due to unrecorded legacy data, proprietary data, or data privacy, e.g., in federated learning (Zhang et al. 2020).
Researcher Affiliation | Academia | Yiduo Guo (1), Wenpeng Hu (2), Dongyan Zhao (1,*), Bing Liu (3,*). Affiliations: (1) Wangxuan Institute of Computer Technology, Peking University; (2) School of Mathematical Sciences, Peking University; (3) Department of Computer Science, University of Illinois at Chicago. Emails: yiduo@stu.pku.edu.cn, {wenpeng.hu, zhaody}@pku.edu.cn, liub@uic.edu
Pseudocode | Yes | Equations (10) and (11) give the iterative update rules for the orthogonal projector and the weight matrix, respectively, in a structured, algorithmic format akin to pseudocode:

$$
\begin{aligned}
P_l(i,j) &= P_l(i-1,j) - Q_l(i,j)\,\bar{x}_{l-1}(i,j)^T P_l(i-1,j) \\
Q_l(i,j) &= P_l(i-1,j)\,\bar{x}_{l-1}(i,j) \big/ \big[\alpha_i^j + \bar{x}_{l-1}(i,j)^T P_l(i-1,j)\,\bar{x}_{l-1}(i,j)\big] \\
\alpha_i^j &= \operatorname{average}\!\big(E\big[x_{l-1}(i,j)\,x_{l-1}(i,j)^T\big]\big) \\
&= \operatorname{average}\!\big(E[x_{l-1}(i,j)]\,E[x_{l-1}(i,j)]^T + \operatorname{Cov}(x_{l-1}(i,j))\big) \\
&\approx \operatorname{average}\!\big(\bar{x}_{l-1}(i,j)\,\bar{x}_{l-1}(i,j)^T + \operatorname{Cov}(x_{l-1}(i,j))\big) \\
P_l(0,0) &= I, \qquad P_l(0,j) = P_l(n_{j-1},\, j-1)
\end{aligned}
\tag{10}
$$

where $\bar{x}_{l-1}(i,j)$ is the output of the $(l-1)$th layer in response to the average of the inputs in the $i$th batch of task $j$, $\operatorname{average}(\cdot)$ returns the average of all entries of a matrix, and $n_{j-1}$ is the number of batches in task $j-1$. After the inputs of the $i$th batch of task $j$ go through the network, the projector is calculated and then the weight matrix of each layer is updated by

$$
\begin{aligned}
W_l(i,j) &= W_l(i-1,j) - \kappa(i,j)\,\Delta W_l^{BP}(i,j) && \text{if } j = 1 \\
W_l(i,j) &= W_l(i-1,j) - \kappa(i,j)\,P_l(n_{j-1},\, j-1)\,\Delta W_l^{BP}(i,j) && \text{if } j = 2, 3, \ldots \\
P_l(0,0) &= I_l
\end{aligned}
\tag{11}
$$

A runnable sketch of these update rules appears after this table.
Open Source Code | Yes | The code of AOP is available at https://github.com/gydpku/Official-pytorch-implementation-of-AOP-AAAI-2022-.
Open Datasets | Yes | Four image classification datasets are used in our experiments: MNIST, EMNIST-47 (Cohen et al. 2017), CIFAR10, and CIFAR100 (Krizhevsky and Hinton 2009).
Dataset Splits | Yes | For batch CL, we tune the learning rate by using 10% of randomly selected training examples of each dataset as the validation set. (A minimal sketch of such a split appears after this table.)
Hardware Specification | No | The paper does not provide specific details about the hardware used for the experiments, such as GPU models, CPU types, or memory specifications. It only discusses model architectures (e.g., 'two-layer MLP', 'CNN model', 'WRN-28 model').
Software Dependencies | No | The paper mentions using the 'SGD optimizer (momentum=0.9)' and 'Pytorch' (in the context of the provided code link), but it does not specify version numbers for any software dependency; for example, it does not state 'PyTorch 1.9' or 'Python 3.x'.
Experiment Setup | Yes | Hyper-parameters: We set the batch size to 64 for all methods in batch CL and to 10 for all methods in online CL. For replay-based online CL baselines, we also set the buffer batch size (data sampled from the replay buffer) to 10, following existing work (Aljundi et al. 2019a). The proposed AOP method does not need to save any data. The learning rates for the different datasets are as follows: 0.5 for MNIST; 0.2 for EMNIST-47, CIFAR10 (without pre-training), and CIFAR100 (without pre-training). For experiments that use pre-trained features in batch CL, we set the learning rate to 0.35 for the one-class-per-task setting and 0.04 for the multi-class-per-task setting. For all online CL experiments, the learning rates of AOP are 0.03 for MNIST and 0.02 for CIFAR10 and CIFAR100. (These settings are collected into a single config sketch after this table.)
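
The pseudocode row above translates almost line for line into code. Below is a minimal PyTorch sketch of the projector and weight updates in Equations (10) and (11), assuming a single fully connected layer with weights stored as (in_dim x out_dim) and batches of more than one example; the class and method names (AOPLayer, update_projector, update_weights) are illustrative and not taken from the official repository, and the paper's step size kappa(i, j) is replaced here by a constant learning rate.

```python
# Hedged sketch of the AOP updates (Eqs. 10-11); names are illustrative,
# not from the official AOP repository.
import torch

class AOPLayer:
    def __init__(self, in_dim, out_dim):
        self.W = 0.01 * torch.randn(in_dim, out_dim)  # weight matrix W_l
        self.P = torch.eye(in_dim)                    # projector, P_l(0,0) = I

    def adaptive_alpha(self, x):
        # Eq. (10): alpha ~ average(x_bar x_bar^T + Cov(x)), where average(.)
        # is the mean over all matrix entries; x has shape (batch, in_dim).
        x_bar = x.mean(dim=0, keepdim=True)           # batch-mean input
        cov = torch.cov(x.T)                          # covariance over the batch
        return (x_bar.T @ x_bar + cov).mean()

    def update_projector(self, x):
        # Recursive projector update of Eq. (10), driven by the batch mean.
        x_bar = x.mean(dim=0, keepdim=True).T         # column vector (in_dim, 1)
        alpha = self.adaptive_alpha(x)
        q = (self.P @ x_bar) / (alpha + x_bar.T @ self.P @ x_bar)
        self.P = self.P - q @ (x_bar.T @ self.P)

    def update_weights(self, grad_W, lr, first_task):
        # Eq. (11): a plain gradient step on task 1; from task 2 onward the
        # BP update is projected by P so that it is (approximately)
        # orthogonal to the input subspace of earlier tasks.
        dW = grad_W if first_task else self.P @ grad_W
        self.W = self.W - lr * dW
```

Consistent with Eq. (10), the projector is never reset between tasks (P_l(0, j) = P_l(n_{j-1}, j-1)): the sketch simply keeps updating self.P across batches and tasks.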
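
For the Dataset Splits row, the 10% validation split can be reproduced in a few lines. This is a hedged sketch assuming torchvision-style datasets; the function name make_val_split and the fixed seed are illustrative choices, not from the paper.

```python
# Hedged sketch of the 10% random validation split used to tune the
# learning rate in batch CL; train_set is any torch.utils.data.Dataset.
import torch
from torch.utils.data import random_split
from torchvision import datasets, transforms

train_set = datasets.MNIST("./data", train=True, download=True,
                           transform=transforms.ToTensor())

def make_val_split(dataset, val_fraction=0.1, seed=0):
    n_val = int(len(dataset) * val_fraction)          # 10% for validation
    generator = torch.Generator().manual_seed(seed)   # reproducible split
    return random_split(dataset, [len(dataset) - n_val, n_val],
                        generator=generator)

train_subset, val_subset = make_val_split(train_set)
```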
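
Finally, the hyper-parameters quoted in the Experiment Setup row are gathered below into a single Python configuration for quick reference; the values are copied from the quoted text, while the dictionary layout and the names BATCH_CL and ONLINE_CL are our own.

```python
# Hyper-parameters from the experiment-setup row; the dict structure is
# illustrative and not the official repository's configuration format.
BATCH_CL = {
    "batch_size": 64,
    "optimizer": "SGD (momentum=0.9)",
    "lr": {                          # without pre-training
        "MNIST": 0.5,
        "EMNIST-47": 0.2,
        "CIFAR10": 0.2,
        "CIFAR100": 0.2,
    },
    "lr_pretrained_features": {      # batch CL with pre-trained features
        "one_class_per_task": 0.35,
        "multi_class_per_task": 0.04,
    },
}

ONLINE_CL = {
    "batch_size": 10,
    "replay_buffer_batch_size": 10,  # replay baselines only; AOP stores no data
    "lr": {"MNIST": 0.03, "CIFAR10": 0.02, "CIFAR100": 0.02},
}
```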