TRAK: Attributing Model Behavior at Scale

Authors: Sung Min Park, Kristian Georgiev, Andrew Ilyas, Guillaume Leclerc, Aleksander Madry

ICML 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We demonstrate the utility of TRAK across various modalities and scales: image classifiers trained on ImageNet, vision-language models (CLIP), and language models (BERT and mT5). We evaluate TRAK in a variety of vision and NLP settings. To this end, we compare TRAK with existing data attribution methods and show that it achieves significantly better tradeoffs between efficacy and computational efficiency. We perform a number of ablation studies to understand how different components of TRAK affect its performance.
Researcher Affiliation | Academia | Department of EECS, Massachusetts Institute of Technology, Cambridge, MA. Correspondence to: Sung Min Park <sp765@mit.edu>, Kristian Georgiev <krisgrg@mit.edu>, Andrew Ilyas <ailyas@mit.edu>.
Pseudocode | Yes | D.2. Pseudocode, Algorithm 1: TRAK for multi-class classifiers (as implemented)
Open Source Code | Yes | We provide code for using TRAK (and reproducing our work) at https://github.com/MadryLab/trak.
Open Datasets | Yes | We use ResNet-9 classifiers trained on CIFAR-10; ResNet-18 classifiers trained on the 1000-class ImageNet dataset; and pre-trained BERT models finetuned on the QNLI (Question-answering Natural Language Inference) task from the GLUE benchmark (Wang et al., 2018). (See Appendix C.1 for more details.) We train image-text models using the CLIP objective on MS COCO (Lin et al., 2014). Akyurek et al. (2022) develop a testbed for the fact tracing problem by way of a dataset called FTRACE-TREX.
Dataset Splits | Yes | Let {S_1, ..., S_m : S_i ⊂ S} be m randomly sampled subsets of the training set S, each of size α · n for some α ∈ (0, 1). Each subset S_j is sampled to be 50% of the size of S. Finally, we average the LDS (Definition 2.3) across 2,000 examples of interest, sampled at random from the validation set, and report this score along with the 95% bootstrap confidence intervals corresponding to the random re-sampling from the subsets S_j.
Hardware Specification | Yes | For all of our experiments, we use NVIDIA A100 GPUs each with 40GB of memory and 12 CPU cores.
Software Dependencies | No | The paper mentions software like PyTorch, Hugging Face, JAX, and functorch, but does not provide specific version numbers for these dependencies.
Experiment Setup | Yes | For CIFAR-2, we use (max) learning rate 0.4, momentum 0.9, weight decay 5e-4, and train for 100 epochs using a cyclic learning rate schedule with a single peak at epoch 5. For CIFAR-10, we replace the learning rate with 0.5 and train for 24 epochs. Models are trained from scratch for 15 epochs, cyclic learning rate with peak at epoch 2 and initial learning rate 5.2, momentum 0.8, weight decay 4e-5, and label smoothing 0.05. We use SGD (20 epochs, learning rate starting at 1e-3) instead of AdamW, and we remove the last tanh non-linearity before the classification layer. We train for 100 epochs using the Adam optimizer with batch size 600, a cosine learning rate schedule with starting learning rate 0.001, weight decay 0.1, and momentum 0.9.
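
The Pseudocode row above points to Algorithm 1 (TRAK for multi-class classifiers) in Appendix D.2. For readers without the appendix at hand, the following is a minimal sketch of the core computation as we understand it: per-example gradients of the model output are compressed with a random projection, and attribution scores come from the resulting kernel, weighted by per-example residual terms. The function name, array layout, and single-checkpoint setting are illustrative assumptions; the released code at https://github.com/MadryLab/trak implements the full algorithm, including averaging over multiple independently trained models.

```python
import numpy as np

def trak_scores(train_grads, test_grads, train_residuals, proj_dim=512, seed=0):
    """Sketch of a TRAK-style attribution computation (not the paper's exact Algorithm 1).

    train_grads:     (n, p) per-example gradients of the model output on each
                     training example, taken at a trained checkpoint.
    test_grads:      (m, p) gradients for the examples being attributed.
    train_residuals: (n,) per-example weights, e.g. 1 - p(correct class) for a
                     classifier (assumed form; see the paper for the exact term).
    """
    rng = np.random.default_rng(seed)
    n, p = train_grads.shape

    # Random projection: compress p-dimensional gradients to proj_dim features.
    proj = rng.standard_normal((p, proj_dim)) / np.sqrt(proj_dim)
    phi_train = train_grads @ proj          # (n, proj_dim)
    phi_test = test_grads @ proj            # (m, proj_dim)

    # Kernel-style scores: phi_test (Phi^T Phi)^{-1} Phi^T, then per-example weighting.
    xtx_inv = np.linalg.pinv(phi_train.T @ phi_train)
    scores = phi_test @ xtx_inv @ phi_train.T    # (m, n)
    return scores * train_residuals[None, :]
```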
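
The Dataset Splits row quotes the LDS evaluation protocol: m random 50% subsets, scores averaged over 2,000 validation examples, and 95% bootstrap confidence intervals over the subsets. The sketch below shows one plausible way to compute it, assuming the LDS is the Spearman rank correlation between attribution-predicted outputs (the sum of a target's attribution scores over each subset) and the outputs actually measured from models retrained on that subset; consult Definition 2.3 in the paper for the exact statement.

```python
import numpy as np
from scipy.stats import spearmanr

def lds(attributions, subset_masks, subset_outputs):
    """Linear datamodeling score (LDS) sketch, per the protocol quoted above.

    attributions:   (num_targets, n) attribution scores for each target example.
    subset_masks:   (m, n) boolean masks; row j marks which of the n training
                    examples belong to the random 50% subset S_j.
    subset_outputs: (m, num_targets) measured model outputs on each target when
                    (re)trained on subset S_j.
    Returns the LDS averaged over target examples.
    """
    # Attribution-based prediction under each subset: sum of the target's
    # attribution scores over the training examples present in S_j.
    predicted = subset_masks.astype(float) @ attributions.T   # (m, num_targets)

    per_target = [
        spearmanr(predicted[:, t], subset_outputs[:, t]).correlation
        for t in range(attributions.shape[0])
    ]
    return float(np.mean(per_target))

def bootstrap_ci(attributions, subset_masks, subset_outputs, n_boot=1000, seed=0):
    """95% bootstrap CI from re-sampling the subsets S_j, as described in the quote."""
    rng = np.random.default_rng(seed)
    m = subset_masks.shape[0]
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, m, size=m)   # resample subsets with replacement
        stats.append(lds(attributions, subset_masks[idx], subset_outputs[idx]))
    return np.percentile(stats, [2.5, 97.5])
```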
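
Finally, the Experiment Setup row mentions a "cyclic learning rate schedule with a single peak at epoch 5" for the CIFAR-2 models. Below is a minimal PyTorch sketch of that recipe; the triangular shape (linear ramp to the peak, then linear decay to zero) and the helper name are assumptions on our part, and only the numeric hyperparameters come from the quote.

```python
import numpy as np
import torch

def cifar2_optimizer_and_schedule(model, iters_per_epoch, epochs=100,
                                  peak_epoch=5, max_lr=0.4,
                                  momentum=0.9, weight_decay=5e-4):
    """Sketch of the quoted CIFAR-2 recipe: SGD with a single-peak cyclic LR.

    Call schedule.step() once per training iteration so the ramp is gradual.
    """
    opt = torch.optim.SGD(model.parameters(), lr=max_lr,
                          momentum=momentum, weight_decay=weight_decay)

    total_iters = epochs * iters_per_epoch
    peak_iter = peak_epoch * iters_per_epoch

    def lr_lambda(it):
        # Multiplier on max_lr: 0 -> 1 by the peak iteration, then 1 -> 0 by the end.
        return float(np.interp(it, [0, peak_iter, total_iters], [0.0, 1.0, 0.0]))

    schedule = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    return opt, schedule
```

With per-iteration stepping, the learning rate ramps up over the first five epochs and decays back toward zero by epoch 100, matching the "single peak at epoch 5" description.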