Repairing Neural Networks by Leaving the Right Past Behind
Authors: Ryutaro Tanno, Melanie F. Pradier, Aditya Nori, Yingzhen Li
NeurIPS 2022
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experimentally, the proposed approach outperforms the baselines for both identification of detrimental training data and fixing model failures in a generalisable manner. We evaluate the efficacy of the proposed framework in a) identifying the causes of target prediction failures in Sec. 4.1, and b) repairing the original model by erasing the memories of such causes in Sec. 4.2. |
| Researcher Affiliation | Collaboration | Ryutaro Tanno¹, Melanie F. Pradier¹, Aditya Nori¹, Yingzhen Li². ¹Microsoft Health Futures, Cambridge, UK; ²Imperial College London, UK. Now at DeepMind, UK. |
| Pseudocode | Yes | Algorithm 1: Model Repairment. Input: training data D; failure cases F; approximate posterior q(θ) ≈ p(θ|D); likelihood p(z|θ). Output: failure causes C, repaired posterior q₋C(θ). Step I (Cause Identification): update the posterior by applying a continual learning method to obtain q₊F(θ) ≈ p(θ|D, F) by fitting the failure set F; compute the influence r(z) of every training example z ∈ D on F (Eq. (9)); return the examples with positive influence as the failure causes, C ← {z ∈ D : r(z) > 0}. Step II (Treatment): delete the information of C by applying a continual (un)learning method to the original posterior q(θ), obtaining the posterior on the corrected data, q₋C(θ) ≈ p(θ|D\C). (A runnable sketch of this procedure is given after the table.) |
| Open Source Code | No | The paper answers the checklist item 'Did you include the code, data, and instructions needed to reproduce the main experimental results (either in the supplemental material or as a URL)?' with [Yes], but neither the main body nor the appendix provides a concrete URL or specific instructions for accessing the source code. |
| Open Datasets | Yes | We use augmented versions of MNIST and CIFAR-10 datasets with simulated annotation and input noise. We train the base classification models on the training split of the augmented MNIST and CIFAR-10 datasets. |
| Dataset Splits | Yes | For MNIST, we use 6% (3,000 samples) of the original training set to make the task more challenging. For evaluation, we separate the test set T into the set of misclassified examples F (the 'failure set') and the correctly classified remainder T\F (the 'remaining set'). We further split the failure set into query (F_q) and holdout (F_h) sets, where we only use the former to identify failure causes C, and use the latter to quantify how generalisably the removal of C can amend the failure cases. (A sketch of this split is given after the table.) |
| Hardware Specification | No | Table 1 in Appendix B shows the total run-time of cause identification methods on a single GPU for their best sets of hyper-parameters selected based on the treatment accuracy on the failure set. The paper mentions 'single GPU' but does not specify the exact model or further hardware details. |
| Software Dependencies | No | We implement the proposed method and baselines in PyTorch. However, the paper does not specify the version number of PyTorch or any other software dependencies. |
| Experiment Setup | Yes | We train the base classification models on the training split of the augmented MNIST and CIFAR-10 datasets. The architecture and training details can be found in Appendix C, which states: 'We use Adam with a learning rate of 0.001. The model is trained for 200 epochs for MNIST and 50 epochs for CIFAR-10... For MNIST, we use a 3-layer CNN model with the following architecture: (Conv2d(1, 32, kernel_size=5, padding=2), ReLU, MaxPool2d(kernel_size=2, stride=2)) × 2, Linear(1568, 50), ReLU, Linear(50, 10). For CIFAR-10, we use a ResNet-18 model...' (A reconstruction of the MNIST model appears after the table.) |
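The Algorithm 1 row is easiest to follow with a concrete sketch. The following is a minimal, simplified rendering of the two-step procedure: MAP point estimates stand in for the approximate posteriors, plain fine-tuning stands in for the continual (un)learning updates, and the influence score r(z) is taken as the change in log-likelihood of z after fitting the failure set (a simplified reading of Eq. (9), which the report does not quote). The helper names are ours, not the authors'.

```python
# Minimal sketch of Algorithm 1 (Model Repairment). Point estimates replace
# the approximate posteriors q(theta), and fine-tuning replaces the paper's
# continual (un)learning methods; datasets are assumed to be TensorDatasets.
import copy
import torch
import torch.nn.functional as F_nn
from torch.utils.data import DataLoader, TensorDataset

def fit(model, dataset, epochs=5, lr=1e-3):
    """Fine-tune `model` on `dataset` (stand-in for a continual-learning update)."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            F_nn.cross_entropy(model(x), y).backward()
            opt.step()
    return model

@torch.no_grad()
def log_lik(model, dataset):
    """Per-example log p(y | x, theta) over a TensorDataset."""
    x, y = dataset.tensors
    logp = F_nn.log_softmax(model(x), dim=-1)
    return logp[torch.arange(len(y)), y]

def repair(model, train_set, failure_set):
    # Step I: cause identification. Fit the failure set to get a stand-in
    # for q_{+F}, then score each training example by its likelihood change.
    model_plus_f = fit(copy.deepcopy(model), failure_set)
    r = log_lik(model_plus_f, train_set) - log_lik(model, train_set)  # simplified r(z)
    causes = r > 0                                   # C = {z in D : r(z) > 0}
    # Step II: treatment. Refit on the corrected data D \ C -- a crude
    # stand-in for the paper's continual unlearning of C from q(theta).
    x, y = train_set.tensors
    cleaned = TensorDataset(x[~causes], y[~causes])
    model_repaired = fit(copy.deepcopy(model), cleaned)
    return causes, model_repaired
```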
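The evaluation split in the Dataset Splits row can likewise be sketched directly. The 50/50 query/holdout ratio below is an assumption; the report does not state how the failure set is divided.

```python
# Split the test set T into the remaining set T\F (correctly classified),
# the query failure set F_q, and the holdout failure set F_h.
import torch

@torch.no_grad()
def split_test_set(model, x_test, y_test, query_frac=0.5, seed=0):
    preds = model(x_test).argmax(dim=-1)
    wrong = preds != y_test                                   # failure set F
    f_idx = wrong.nonzero(as_tuple=True)[0]
    gen = torch.Generator().manual_seed(seed)
    perm = f_idx[torch.randperm(len(f_idx), generator=gen)]   # shuffle F
    n_q = int(query_frac * len(perm))
    return {
        "remaining": (~wrong).nonzero(as_tuple=True)[0],      # T \ F
        "query": perm[:n_q],      # F_q: used to identify the causes C
        "holdout": perm[n_q:],    # F_h: measures how generalisable the repair is
    }
```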
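Finally, the MNIST base model from the Experiment Setup row can be reconstructed from the quoted architecture string. The second conv block's in_channels=32 is inferred rather than stated: it is the only value for which two 5×5 conv blocks with 2×2 pooling on a 28×28 input yield the quoted Linear(1568, 50), since 32 channels × 7 × 7 = 1568.

```python
# The MNIST base classifier as described in Appendix C of the paper;
# the second Conv2d's input channels (32) are inferred, not quoted.
import torch.nn as nn

mnist_cnn = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),        # 28x28 -> 14x14
    nn.Conv2d(32, 32, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),        # 14x14 -> 7x7
    nn.Flatten(),                                 # 32 * 7 * 7 = 1568
    nn.Linear(1568, 50), nn.ReLU(),
    nn.Linear(50, 10),
)
```

Per the quote, this model would be trained with Adam at a learning rate of 0.001 for 200 epochs on MNIST.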