Linear Mode Connectivity in Multitask and Continual Learning
Authors: Seyed Iman Mirzadeh, Mehrdad Farajtabar, Dilan Gorur, Razvan Pascanu, Hassan Ghasemzadeh
ICLR 2021
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically find that indeed such connectivity can be reliably achieved and, more interestingly, it can be done by a linear path, conditioned on having the same initialization for both. We show that our method outperforms several state of the art continual learning algorithms on various vision benchmarks. |
| Researcher Affiliation | Collaboration | Seyed Iman Mirzadeh Washington State University, USA seyediman.mirzadeh@wsu.edu; Mehrdad Farajtabar Deep Mind, USA farajtabar@google.com; Dilan Gorur Deep Mind, USA dilang@google.com; Razvan Pascanu Deep Mind, UK razp@google.com; Hassan Ghasemzadeh Washington State University, USA hassan.ghasemzadeh@wsu.edu |
| Pseudocode | Yes | def get_grads_on_interpolation_line(w_hat, w_bar, alphas, memory): """ Get accumulated gradients on interpolation line between w_hat and w_bar """ accumulated_gradients_on_line = 0.0 # for each point on the interpolation line for alpha in alphas: interpolation_grads = [] # (1) calculate the parameters on the interpolation line interpolation_w = w_hat + (w_bar - w_hat) * alpha model = make_model_from_param_vector(interpolation_w) loss = calculate_loss(model, memory) # (2) calculate the loss to obtain gradients loss.backward() # Now the gradients are stored in the graph # (3) collect the gradients for param in model.parameters(): # get gradients for each module/block/layer interpolation_grads.append(param.grad.view(-1)) interpolation_grads = torch.cat(interpolation_grads) # (4) accumulate gradients accumulated_gradients_on_line += interpolation_grads return accumulated_gradients_on_line (a runnable reconstruction of this pseudocode appears after the table) |
| Open Source Code | Yes | The code is available at: https://github.com/imirzadeh/MC-SGD |
| Open Datasets | Yes | We report on three standard continual learning benchmarks: Permuted MNIST (Goodfellow et al., 2013), Rotated MNIST, and Split CIFAR-100. |
| Dataset Splits | Yes | Average Accuracy: The average validation accuracy after the model has continually learned task t is defined by $A_t = \frac{1}{t}\sum_{i=1}^{t} a_{t,i}$, where $a_{t,i}$ is the validation accuracy on dataset i after the model finished learning task t. (A small sketch of this metric appears after the table.) |
| Hardware Specification | No | The paper does not provide specific details on the hardware used, such as CPU or GPU models, for running the experiments. |
| Software Dependencies | No | To implement MC-SGD, we use PyTorch (Paszke et al., 2019) because of its dynamic graph capability. |
| Experiment Setup | Yes | The experimental setup, such as benchmarks, network architectures, continual learning setting (e.g., number of tasks, episodic memory size, and training epochs per task), hyper-parameters, and evaluation metrics are chosen to be similar to several other studies (Chaudhry et al., 2018b; Mirzadeh et al., 2020; Chaudhry et al., 2019; Farajtabar et al., 2019; Chaudhry et al., 2019). For the experiment in Section 5.1, we have used the following grid for each model. We note that for other algorithms (e.g., A-GEM, and EWC), we ensured that our grid contains the optimal values that the original papers reported. learning rate: [0.25, 0.1, 0.01 (MNIST, CIFAR-100), 0.001] batch size: 10 [...] |
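
The pseudocode quoted in the Pseudocode row maps onto a short runnable sketch. This is a minimal reconstruction, not the authors' released implementation: `make_model_from_param_vector` and `calculate_loss` are named in the paper's pseudocode but not defined there, so the versions below (an MNIST-sized MLP and a cross-entropy loss over a replay-memory batch of inputs and labels), as well as the illustrative sizes and `alphas` values, are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters


def make_model_from_param_vector(w):
    # Assumed helper (named but not defined in the paper's pseudocode):
    # rebuild a small MLP and load the flat parameter vector w into it.
    model = nn.Sequential(nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, 10))
    vector_to_parameters(w, model.parameters())
    return model


def calculate_loss(model, memory):
    # Assumed helper: cross-entropy over a replay-memory batch (inputs, labels).
    x, y = memory
    return F.cross_entropy(model(x), y)


def get_grads_on_interpolation_line(w_hat, w_bar, alphas, memory):
    """Accumulate gradients at points on the line between w_hat and w_bar."""
    accumulated_gradients_on_line = 0.0
    for alpha in alphas:  # for each point on the interpolation line
        interpolation_grads = []
        # (1) parameters at this point on the line
        interpolation_w = w_hat + (w_bar - w_hat) * alpha
        model = make_model_from_param_vector(interpolation_w)
        # (2) loss on the episodic memory; backward() fills each param.grad
        loss = calculate_loss(model, memory)
        loss.backward()
        # (3) flatten and concatenate the per-layer gradients
        for param in model.parameters():
            interpolation_grads.append(param.grad.view(-1))
        interpolation_grads = torch.cat(interpolation_grads)
        # (4) accumulate gradients across the points on the line
        accumulated_gradients_on_line += interpolation_grads
    return accumulated_gradients_on_line


# Usage sketch: two nearby flat parameter vectors and a tiny fake memory batch.
template = nn.Sequential(nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, 10))
w_hat = parameters_to_vector(template.parameters()).detach()
w_bar = w_hat + 0.01 * torch.randn_like(w_hat)
memory = (torch.randn(32, 784), torch.randint(0, 10, (32,)))
grads = get_grads_on_interpolation_line(w_hat, w_bar, [0.0, 0.25, 0.5, 0.75, 1.0], memory)
print(grads.shape)  # flat gradient vector, same length as w_hat
```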
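
The average-accuracy metric quoted in the Dataset Splits row can also be written out directly. The sketch below assumes a 0-indexed accuracy matrix; the accuracy values in the usage example are made up purely for illustration.

```python
def average_accuracy(acc_matrix, t):
    """A_t: mean validation accuracy over tasks 1..t after learning task t.

    acc_matrix[t][i] plays the role of a_{t,i}: the validation accuracy on
    task i after the model has finished learning task t (0-indexed here).
    """
    return sum(acc_matrix[t][: t + 1]) / (t + 1)


# Usage sketch with illustrative accuracies for 3 tasks.
acc = [
    [0.95, 0.00, 0.00],
    [0.90, 0.93, 0.00],
    [0.88, 0.91, 0.94],
]
print(average_accuracy(acc, 2))  # (0.88 + 0.91 + 0.94) / 3
```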