Training Data Attribution via Approximate Unrolling
Authors: Juhan Bae, Wu Lin, Jonathan Lorraine, Roger B. Grosse
NeurIPS 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We evaluate SOURCE for counterfactual prediction across various tasks, including regression, image classification, and text classification. SOURCE outperforms existing TDA techniques in counterfactual prediction, especially in settings where implicit-differentiation-based approaches fall short. |
| Researcher Affiliation | Collaboration | 1University of Toronto, 2Vector Institute, 3NVIDIA, 4Anthropic |
| Pseudocode | No | In this section, we introduce SOURCE, a gradient-based TDA technique combining the advantages of implicit and unrolled differentiation. |
| Open Source Code | Yes | The code for implementing SOURCE (as well as baseline techniques) will be provided at https://github.com/pomonam/kronfluence. |
| Open Datasets | Yes | Our experiments consider diverse machine learning tasks, including: (a) regression using datasets from the UCI Repository [45], (b) image classification with datasets such as MNIST [57], Fashion MNIST [91], CIFAR-10 [53], Rotated MNIST [20], and PACS [58], and (c) text classification using the GLUE benchmark [85]. |
| Dataset Splits | Yes | We conducted systematic hyperparameter optimization for all tasks. This process involved grid searches to find hyperparameter configurations that achieve the best average validation performance (accuracy for classification tasks and loss for others). |
| Hardware Specification | Yes | We used CPUs to conduct UCI regression experiments, A100 (80GB) GPUs to conduct GLUE and Wiki Text-2 experiments, and A6000 (48GB) GPUs for other experiments. |
| Software Dependencies | Yes | All experiments were conducted using PYTORCH version 2.1.0 [71]. |
| Experiment Setup | Yes | We conducted systematic hyperparameter optimization for all tasks. This process involved grid searches to find hyperparameter configurations that achieve the best average validation performance (accuracy for classification tasks and loss for others). The models were optimized using SGDm for 20 epochs with a batch size of 32 and a constant learning rate schedule. |
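The quoted training setup (SGD with momentum, 20 epochs, batch size 32, constant learning rate, PyTorch 2.1.0) can be sketched as follows. This is a minimal illustration, not the paper's actual code: the model architecture, learning rate, momentum value, and synthetic regression data are all assumptions standing in for the paper's real tasks and tuned hyperparameters.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic regression data as a placeholder for a UCI-style task
# (the paper's actual datasets and preprocessing are not shown here).
X = torch.randn(256, 8)
y = X @ torch.randn(8, 1) + 0.1 * torch.randn(256, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

# Placeholder model; the paper's architectures vary by task.
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
loss_fn = nn.MSELoss()

# "SGDm" with a constant learning rate: no LR scheduler is attached,
# so the optimizer's lr stays fixed for all 20 epochs. The lr and
# momentum values here are illustrative, not the paper's tuned ones.
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

init_loss = loss_fn(model(X), y).item()
for epoch in range(20):
    for xb, yb in loader:
        opt.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        opt.step()
final_loss = loss_fn(model(X), y).item()
print(init_loss, final_loss)
```

In a grid search like the one described in the table, this loop would be repeated per hyperparameter configuration and the configuration with the best average validation metric retained.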