REPAIR: REnormalizing Permuted Activations for Interpolation Repair

Authors: Keller Jordan, Hanie Sedghi, Olga Saukh, Rahim Entezari, Behnam Neyshabur

ICLR 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this paper we look into the conjecture of Entezari et al. (2021), which states that if the permutation invariance of neural networks is taken into account, then there is likely no loss barrier to the linear interpolation between SGD solutions. First, we observe that neuron alignment methods alone are insufficient to establish low-barrier linear connectivity between SGD solutions due to a phenomenon we call variance collapse: interpolated deep networks suffer a collapse in the variance of their activations, causing poor performance. Next, we propose REPAIR (REnormalizing Permuted Activations for Interpolation Repair), which mitigates variance collapse by rescaling the preactivations of such interpolated networks. We explore the interaction between our method and the choice of normalization layer, network width, and depth, and demonstrate that using REPAIR on top of neuron alignment methods leads to 60%-100% relative barrier reduction across a wide variety of architecture families and tasks. In particular, we report a 74% barrier reduction for ResNet50 on ImageNet and a 90% barrier reduction for ResNet18 on CIFAR-10.
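As a point of reference for the reported percentages, below is a minimal sketch of the midpoint barrier and its relative reduction, assuming the standard definition used in the linear mode connectivity literature (Entezari et al., 2021); the function names are illustrative, not from the paper's code.

    def midpoint_barrier(err_a, err_b, err_mid):
        # Excess error of the linear path's midpoint over the endpoints' average error.
        return err_mid - 0.5 * (err_a + err_b)

    def relative_barrier_reduction(barrier_before, barrier_after):
        # E.g. a drop from 20 points of excess error to 2 points is a 90% relative reduction.
        return 1.0 - barrier_after / barrier_before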
Researcher Affiliation | Collaboration | Keller Jordan (Hive AI), Hanie Sedghi (Google Research), Olga Saukh (TU Graz / CSH Vienna), Rahim Entezari (TU Graz / CSH Vienna) & Behnam Neyshabur (Google Research). Contact: keller@thehive.ai, {hsedghi, neyshabur}@google.com, {olga.saukh, rahim.entezari}@gmail.com
Pseudocode | Yes | B.2 PSEUDOCODE FOR PYTORCH MODULES. The first step of REPAIR is to measure the statistics of identified channels in the endpoint networks. There are many ways to do this, but we found that the following was efficient and had low code complexity in a PyTorch environment. For each module in the interpolated network whose outputs we wish to REPAIR, we wrap the corresponding modules in the endpoint networks with the following:

    import torch.nn as nn

    class TrackLayer(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.layer = layer
            # A BatchNorm in train mode tracks running statistics of the wrapped layer's outputs.
            self.bn = nn.BatchNorm2d(len(layer.weight))
            self.bn.train()
            self.layer.eval()

        def get_stats(self):
            # Per-channel mean and standard deviation accumulated so far.
            return (self.bn.running_mean, self.bn.running_var.sqrt())

        def forward(self, inputs):
            outputs = self.layer(inputs)
            # Apply BatchNorm so that the running mean/variance are updated; discard its output.
            self.bn(outputs)
            return outputs
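The tracked statistics can then be used to correct the interpolated network. Below is a minimal sketch of one way to do this, assuming stats_a and stats_b come from TrackLayer.get_stats() on the two endpoint networks and stats_interp from the same wrapper applied to the naively interpolated network; repair_rescale and its signature are illustrative names, not the released API.

    import torch
    import torch.nn as nn

    def repair_rescale(conv_interp, stats_a, stats_b, stats_interp, alpha=0.5):
        # Hypothetical helper: rescale one conv layer of the interpolated network so that its
        # per-channel output mean/std match the interpolation of the endpoints' statistics.
        (mu_a, std_a), (mu_b, std_b) = stats_a, stats_b
        mu_i, std_i = stats_interp
        goal_mu = (1 - alpha) * mu_a + alpha * mu_b
        goal_std = (1 - alpha) * std_a + alpha * std_b
        # Affine correction y -> goal_std * (y - mu_i) / std_i + goal_mu, folded into the conv.
        scale = goal_std / std_i
        with torch.no_grad():
            conv_interp.weight.mul_(scale.reshape(-1, 1, 1, 1))
            if conv_interp.bias is None:
                conv_interp.bias = nn.Parameter(torch.zeros_like(goal_mu))
            conv_interp.bias.copy_(goal_mu + scale * (conv_interp.bias - mu_i))

In a full pipeline, one would first forward a pass of training data through the wrapped networks to populate the running statistics, then apply the correction layer by layer.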
Open Source Code | Yes | Our code is available at https://github.com/KellerJordan/REPAIR.
Open Datasets | Yes | We report a 74% barrier reduction for ResNet50 on ImageNet and 90% barrier reduction for ResNet18 on CIFAR-10. We begin by considering ResNet20s trained on CIFAR-10. We test ResNet18, ResNet50, and a double-width variant of ResNet50 in Figure 6 (right). Without REPAIR, the interpolated midpoints between each aligned pair of networks perform at below 1% (top-1) accuracy on the ImageNet validation set. Next, we explore the impact of REPAIR on the barrier to interpolation between standard ResNet models trained from scratch on ImageNet (Deng et al., 2009). We first split the CIFAR-100 training set, consisting of 50,000 images distributed across 100 classes, into two disjoint sets of 25,000 images.
Dataset Splits | Yes | We first split the CIFAR-100 training set, consisting of 50,000 images distributed across 100 classes, into two disjoint sets of 25,000 images. The first split contains a random 80% of the images in the first 50 classes, and 20% of the second 50 classes; the second split has the proportions reversed.
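Below is a minimal sketch of how such a biased split could be constructed with torchvision; the helper name, seeding, and use of the CIFAR100 dataset class are assumptions, not the authors' code.

    import numpy as np
    from torchvision.datasets import CIFAR100

    def make_biased_splits(root="./data", seed=0):
        # Split the CIFAR-100 training set (50,000 images, 500 per class) into two disjoint
        # halves: split 1 gets 80% of classes 0-49 and 20% of classes 50-99; split 2 gets the rest.
        train = CIFAR100(root=root, train=True, download=True)
        labels = np.array(train.targets)
        rng = np.random.default_rng(seed)
        split1, split2 = [], []
        for c in range(100):
            idx = rng.permutation(np.where(labels == c)[0])
            cut = int((0.8 if c < 50 else 0.2) * len(idx))
            split1.extend(idx[:cut].tolist())
            split2.extend(idx[cut:].tolist())
        return sorted(split1), sorted(split2)

Each index list can then be wrapped with torch.utils.data.Subset to build the two 25,000-image training sets.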
Hardware Specification | No | The paper does not specify any hardware details such as GPU models, CPU types, or memory used for the experiments.
Software Dependencies | No | The paper mentions PyTorch in the pseudocode section but does not provide specific version numbers for PyTorch or any other software dependencies.
Experiment Setup | Yes | Table 1 summarizes the hyperparameters we used to train the neural networks which appear in this work. We train all networks using SGD with momentum 0.9. The weight decay and learning rates differ for each task, and are specified below. For our MLP trainings, we keep the hyperparameters below constant across varying widths, depths, and datasets. When training on MNIST and SVHN, we remove cutout and horizontal flip from the list of data augmentations used by our ResNet20 training. Otherwise, we keep the below ResNet20 hyperparameters constant across different choices of width, dataset, and normalization layer.

Table 1: Training hyperparameters

    Hyper-parameter     MLP         VGG             ResNet20                ResNet50/ImageNet
    Batch size          2000        500             500                     512
    Epochs              100         100             200                     300
    Learning rate       Linear 0.2  Cosine 0.08     Cosine 0.4              Linear 0.5
    Weight decay        0.0         0.0005          0.0001                  0.0001
    Data augmentation   Translate   Flip/Translate  Flip/Translate+Cutout   Flip/RRC
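For reference, a minimal sketch of an optimizer and schedule matching the ResNet20 column of Table 1 (SGD with momentum 0.9, 200 epochs, cosine-decayed peak learning rate 0.4, weight decay 0.0001); the helper name make_resnet20_optimizer is hypothetical and this is not the authors' training code.

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    def make_resnet20_optimizer(model, steps_per_epoch, epochs=200):
        # SGD with momentum 0.9, peak LR 0.4, weight decay 1e-4 (Table 1, ResNet20 column).
        optimizer = torch.optim.SGD(model.parameters(), lr=0.4,
                                    momentum=0.9, weight_decay=1e-4)
        # Cosine learning-rate decay over all optimizer steps of the run.
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs * steps_per_epoch)
        return optimizer, scheduler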