Training Data Attribution via Approximate Unrolling

Authors: Juhan Bae, Wu Lin, Jonathan Lorraine, Roger B. Grosse

NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | "We evaluate SOURCE for counterfactual prediction across various tasks, including regression, image classification, and text classification. SOURCE outperforms existing TDA techniques in counterfactual prediction, especially in settings where implicit-differentiation-based approaches fall short."
Researcher Affiliation | Collaboration | "1University of Toronto, 2Vector Institute, 3NVIDIA, 4Anthropic"
Pseudocode | No | "In this section, we introduce SOURCE, a gradient-based TDA technique combining the advantages of implicit and unrolled differentiation."
Open Source Code | Yes | "The code for implementing SOURCE (as well as baseline techniques) will be provided at https://github.com/pomonam/kronfluence."
Open Datasets | Yes | "Our experiments consider diverse machine learning tasks, including: (a) regression using datasets from the UCI Repository [45], (b) image classification with datasets such as MNIST [57], Fashion MNIST [91], CIFAR-10 [53], Rotated MNIST [20], and PACS [58], and (c) text classification using the GLUE benchmark [85]."
Dataset Splits | Yes | "We conducted systematic hyperparameter optimization for all tasks. This process involved conducting grid searches to find hyperparameter configurations that achieve the best average validation performance (accuracy for classification tasks and loss for others)."
Hardware Specification | Yes | "We used CPUs to conduct UCI regression experiments, A100 (80GB) GPUs to conduct GLUE and Wiki Text-2 experiments, and A6000 (48GB) GPUs for other experiments."
Software Dependencies | Yes | "All experiments were conducted using PYTORCH version 2.1.0 [71]."
Experiment Setup | Yes | "We conducted systematic hyperparameter optimization for all tasks. This process involved conducting grid searches to find hyperparameter configurations that achieve the best average validation performance (accuracy for classification tasks and loss for others). The models were optimized using SGDm for 20 epochs with a batch size of 32 and a constant learning rate schedule." (see the sketch below)
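
The Experiment Setup and Dataset Splits rows quote the same protocol: a grid search over hyperparameters, selecting the configuration with the best average validation performance, with each model trained by SGD with momentum (SGDm) for 20 epochs at batch size 32 under a constant learning rate. The sketch below illustrates that loop in PyTorch; the model architecture, synthetic data, momentum value, and grid values are illustrative assumptions, not details reported in the paper.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_and_evaluate(lr, weight_decay, train_ds, val_ds, epochs=20, batch_size=32):
    """Train a small classifier with SGDm and a constant learning rate; return validation accuracy."""
    torch.manual_seed(0)
    # Illustrative architecture; the paper uses task-specific models (MLPs, CNNs, transformers).
    model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
    # Momentum value of 0.9 is an assumption; the quote only says "SGDm".
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    for _ in range(epochs):  # constant learning rate schedule: no scheduler is attached
        for x, y in train_loader:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()

    # Validation accuracy is the model-selection criterion for classification tasks.
    model.eval()
    with torch.no_grad():
        xv, yv = val_ds.tensors
        return (model(xv).argmax(dim=1) == yv).float().mean().item()


# Synthetic stand-in data; the paper instead uses UCI, image, and text datasets.
x = torch.randn(1200, 20)
y = (x[:, 0] + 0.5 * x[:, 1] > 0).long()
train_ds = TensorDataset(x[:1000], y[:1000])
val_ds = TensorDataset(x[1000:], y[1000:])

# Hypothetical grid; the paper reports a grid search but not these exact values.
grid = [(lr, wd) for lr in (0.03, 0.1, 0.3) for wd in (0.0, 1e-4, 1e-3)]
results = {cfg: train_and_evaluate(*cfg, train_ds=train_ds, val_ds=val_ds) for cfg in grid}
best = max(results, key=results.get)
print(f"best (lr, weight_decay): {best}, val acc: {results[best]:.3f}")
```

The constant learning rate is reflected by simply not constructing a scheduler; for regression tasks the selection criterion would be validation loss (minimized) rather than accuracy.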