Linear Mode Connectivity in Multitask and Continual Learning

Authors: Seyed Iman Mirzadeh, Mehrdad Farajtabar, Dilan Gorur, Razvan Pascanu, Hassan Ghasemzadeh

ICLR 2021

Reproducibility assessment, listing each variable, its result, and the supporting LLM response:
Research Type: Experimental
LLM Response: We empirically find that indeed such connectivity can be reliably achieved and, more interestingly, it can be done by a linear path, conditioned on having the same initialization for both. We show that our method outperforms several state of the art continual learning algorithms on various vision benchmarks.
Researcher Affiliation: Collaboration
LLM Response: Seyed Iman Mirzadeh, Washington State University, USA (seyediman.mirzadeh@wsu.edu); Mehrdad Farajtabar, DeepMind, USA (farajtabar@google.com); Dilan Gorur, DeepMind, USA (dilang@google.com); Razvan Pascanu, DeepMind, UK (razp@google.com); Hassan Ghasemzadeh, Washington State University, USA (hassan.ghasemzadeh@wsu.edu)
Pseudocode: Yes
LLM Response:

import torch

def get_grads_on_interpolation_line(w_hat, w_bar, alphas, memory):
    """Get accumulated gradients on the interpolation line between w_hat and w_bar."""
    accumulated_gradients_on_line = 0.0
    # for each point on the interpolation line
    for alpha in alphas:
        interpolation_grads = []
        # (1) calculate the parameters on the interpolation line
        interpolation_w = w_hat + (w_bar - w_hat) * alpha
        model = make_model_from_param_vector(interpolation_w)
        # (2) calculate the loss to obtain gradients
        loss = calculate_loss(model, memory)
        loss.backward()  # now the gradients are stored in the graph
        # (3) collect the gradients
        for param in model.parameters():
            # get gradients for each module/block/layer
            interpolation_grads.append(param.grad.view(-1))
        interpolation_grads = torch.cat(interpolation_grads)
        # (4) accumulate gradients
        accumulated_gradients_on_line += interpolation_grads
    return accumulated_gradients_on_line
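A minimal usage sketch for the listing above, not taken from the paper's code: the toy linear models, the fake replay batch, and the stand-in implementations of make_model_from_param_vector and calculate_loss are all illustrative assumptions made so the example runs end to end.

import torch
import torch.nn as nn
import torch.nn.functional as F

def make_model_from_param_vector(w):
    # illustrative stand-in: load a flat parameter vector into a fresh copy of the architecture
    model = nn.Linear(784, 10)
    nn.utils.vector_to_parameters(w, model.parameters())
    return model

def calculate_loss(model, memory):
    # illustrative stand-in: cross-entropy on a replay batch of (inputs, labels)
    inputs, labels = memory
    return F.cross_entropy(model(inputs), labels)

# endpoints of the line: flattened parameters of two (here untrained) copies of the model
w_hat = nn.utils.parameters_to_vector(nn.Linear(784, 10).parameters()).detach()
w_bar = nn.utils.parameters_to_vector(nn.Linear(784, 10).parameters()).detach()

memory = (torch.randn(32, 784), torch.randint(0, 10, (32,)))  # fake replay batch
alphas = [i / 10 for i in range(11)]                           # 11 points on the segment

grads = get_grads_on_interpolation_line(w_hat, w_bar, alphas, memory)
print(grads.shape)  # one accumulated gradient entry per flattened parameter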
Open Source Code: Yes
LLM Response: The code is available at: https://github.com/imirzadeh/MC-SGD
Open Datasets: Yes
LLM Response: We report on three standard continual learning benchmarks: Permuted MNIST (Goodfellow et al., 2013), Rotated MNIST, and Split CIFAR-100.
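For context, Permuted MNIST tasks are conventionally built by applying one fixed random pixel permutation per task; a minimal sketch under that assumption follows (the torchvision pipeline below is illustrative, not the authors' data-loading code).

import torch
from torchvision import datasets, transforms

def make_permuted_mnist(seed, root="./data", train=True):
    # one fixed permutation of the 28x28 = 784 pixels defines one task
    perm = torch.randperm(28 * 28, generator=torch.Generator().manual_seed(seed))
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)[perm].view(1, 28, 28)),
    ])
    return datasets.MNIST(root, train=train, download=True, transform=tfm)

# e.g. a sequence of five tasks, one permutation per seed
tasks = [make_permuted_mnist(seed=s) for s in range(5)]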
Dataset Splits: Yes
LLM Response: Average Accuracy: The average validation accuracy after the model has continually learned task t is defined as A_t = (1/t) \sum_{i=1}^{t} a_{t,i}, where a_{t,i} is the validation accuracy on dataset i after the model finished learning task t.
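A small sketch of the quoted metric, assuming a plain list holding a_{t,1}, ..., a_{t,t} after training through task t (the accuracy values are made up):

def average_accuracy(acc_row):
    """A_t = (1/t) * sum over i of a_{t,i}, given the accuracies measured after task t."""
    return sum(acc_row) / len(acc_row)

# validation accuracies on tasks 1..3 after finishing task 3 (made-up numbers)
a_3 = [0.91, 0.88, 0.95]
print(average_accuracy(a_3))  # ~0.913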
Hardware Specification: No
LLM Response: The paper does not provide specific details on the hardware used, such as CPU or GPU models, for running the experiments.
Software Dependencies: No
LLM Response: To implement MC-SGD, we use PyTorch (Paszke et al., 2019) because of its dynamic graph capability.
Experiment Setup: Yes
LLM Response: The experimental setup, such as benchmarks, network architectures, continual learning setting (e.g., number of tasks, episodic memory size, and training epochs per task), hyper-parameters, and evaluation metrics are chosen to be similar to several other studies (Chaudhry et al., 2018b; Mirzadeh et al., 2020; Chaudhry et al., 2019; Farajtabar et al., 2019; Chaudhry et al., 2019). For the experiment in Section 5.1, we have used the following grid for each model. We note that for other algorithms (e.g., A-GEM and EWC), we ensured that our grid contains the optimal values that the original papers reported. Learning rate: [0.25, 0.1, 0.01 (MNIST, CIFAR-100), 0.001]; batch size: 10 [...]
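A minimal sketch of how the quoted search grid could be written down for a sweep; only the values quoted above are included, the entries truncated by "[...]" in the excerpt are left out, and the dictionary keys and sweep loop are illustrative rather than the authors' tooling.

from itertools import product

# partial grid, as quoted above; 0.01 is the value noted for MNIST and CIFAR-100
hparam_grid = {
    "learning_rate": [0.25, 0.1, 0.01, 0.001],
    "batch_size": [10],
    # remaining hyper-parameters are truncated ("[...]") in the excerpt
}

for lr, bs in product(hparam_grid["learning_rate"], hparam_grid["batch_size"]):
    print(f"candidate run: lr={lr}, batch_size={bs}")  # train and validate here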