Understanding Reconstruction Attacks with the Neural Tangent Kernel and Dataset Distillation
Authors: Noel Loo, Ramin Hasani, Mathias Lechner, Alexander Amini, Daniela Rus
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this work, we first build a stronger version of the dataset reconstruction attack and show how it can provably recover the entire training set in the infinite width regime. We then empirically study the characteristics of this attack on two-layer networks and reveal that its success heavily depends on deviations from the frozen infinite-width Neural Tangent Kernel limit. Next, we study the nature of easily-reconstructed images. We show that both theoretically and empirically, reconstructed images tend to be outliers in the dataset, and that these reconstruction attacks can be used for dataset distillation, that is, we can retrain on reconstructed images and obtain high predictive accuracy. |
| Researcher Affiliation | Academia | Noel Loo, Ramin Hasani, Mathias Lechner, Alexander Amini, and Daniela Rus, MIT CSAIL, Cambridge, Massachusetts, USA {loo, rhasani, mlechner, amini, rus}@mit.edu |
| Pseudocode | Yes | Algorithm 1: Standard Reconstruction Attack; Algorithm 2: Batched Reconstruction Attack (a hedged sketch of the attack objective appears after the table). |
| Open Source Code | Yes | Code is available at https://github.com/yolky/understanding_reconstruction |
| Open Datasets | Yes | To answer these questions, we follow the experimental protocol of Haim et al. (2022), where we try to recover images from the MNIST and CIFAR10 datasets on the task of odd/even digit or animal/vehicle classification for MNIST and CIFAR-10, respectively. |
| Dataset Splits | No | The paper does not explicitly specify train/validation/test splits with percentages or sample counts for general reproduction, focusing more on training set sizes and test accuracy. |
| Hardware Specification | Yes | All experiments were run on Nvidia Titan RTX graphics cards with 24 GB of VRAM. |
| Software Dependencies | No | We use the JAX, Optax, Flax, and neural-tangents libraries (Bradbury et al., 2018; Babuschkin et al., 2020; Heek et al., 2020; Novak et al., 2020; 2022). The libraries are named, but no version numbers are given (a minimal neural-tangents usage sketch appears after the table). |
| Experiment Setup | Yes | Unless otherwise stated, networks trained on real data are trained for 10^6 iterations of full-batch gradient descent, using SGD with momentum 0.9. For the learning rate, we set η = N · 2e-7, where N is the number of training images. For distilled data, we use a learning rate of η = N · 6e-6, where N is now the distilled dataset size. We did not find that results were heavily dependent on the learning rates used during training. Additionally, if the training loss was less than 1e-10, we terminated training early. Every reconstruction curve in the main text is the average of 3 unique networks trained on 3 unique splits of training data. For binary classification, we use labels in {+1, -1}, and for 10-way multiclass classification, we use labels of 0.9 for the selected class and -0.1 for the other classes. To create reconstructions, we initialize reconstruction images with a standard deviation of 0.2, and dual parameters to be uniform random within [-0.5, 0.5]. We use the Adam optimizer (Kingma & Ba, 2015) with a learning rate of 0.02 for all reconstructions. We optimize the images for 80k iterations. We annealed the softplus temperature from 10 to 200 over the course of training. (A hedged training-setup sketch appears after the table.) |
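
The Pseudocode row cites Algorithm 1 (Standard Reconstruction Attack) and Algorithm 2 (Batched Reconstruction Attack). The attack optimizes reconstructed images and dual parameters so that a weighted sum of per-image gradients accounts for the trained network's parameter change. The following is a minimal, hypothetical sketch of such a parameter-matching objective, reusing the initialization and optimizer settings quoted in the Experiment Setup row; the architecture, data shapes, and helper names are assumptions, not the paper's released code, and the exact objective in Algorithms 1 and 2 may differ.

```python
# Hypothetical sketch of a parameter-matching reconstruction attack in the
# spirit of Algorithm 1. Everything below is illustrative, not the paper's code.
import jax
import jax.numpy as jnp
import optax
from neural_tangents import stax

# Small finite-width two-layer network (illustrative architecture).
init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key = jax.random.PRNGKey(0)
_, params_0 = init_fn(key, (-1, 784))   # parameters at initialisation
# params_T would be the trained target network's parameters; we reuse the
# initialisation here only so the sketch runs end to end.
params_T = params_0

def weighted_grad_sum(recon_x, duals):
    """sum_i lambda_i * grad_theta f(x_i; theta_0)."""
    def one(x, lam):
        g = jax.grad(lambda p: apply_fn(p, x[None]).sum())(params_0)
        return jax.tree_util.tree_map(lambda gi: lam * gi, g)
    grads = jax.vmap(one)(recon_x, duals)
    return jax.tree_util.tree_map(lambda g: g.sum(axis=0), grads)

def attack_loss(variables):
    recon_x, duals = variables
    summed = weighted_grad_sum(recon_x, duals)
    residual = jax.tree_util.tree_map(
        lambda t, o, s: t - o - s, params_T, params_0, summed)
    return sum(jnp.sum(r ** 2) for r in jax.tree_util.tree_leaves(residual))

# Initialisation from the Experiment Setup row: images with std 0.2,
# dual parameters uniform in [-0.5, 0.5], Adam with learning rate 0.02.
n_recon = 32
k1, k2 = jax.random.split(key)
variables = (0.2 * jax.random.normal(k1, (n_recon, 784)),
             jax.random.uniform(k2, (n_recon,), minval=-0.5, maxval=0.5))
opt = optax.adam(2e-2)
opt_state = opt.init(variables)

@jax.jit
def step(variables, opt_state):
    loss, grads = jax.value_and_grad(attack_loss)(variables)
    updates, opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(variables, updates), opt_state, loss

for _ in range(100):   # the paper reports 80k iterations; kept short here
    variables, opt_state, loss = step(variables, opt_state)
```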
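
The Software Dependencies row names JAX, Optax, Flax, and neural-tangents. As a hedged illustration of how the infinite-width NTK predictions discussed in the abstract can be computed with the neural-tangents library, here is a minimal sketch with an illustrative two-layer architecture and toy data; the model and inputs are assumptions, not the paper's exact configuration.

```python
# Minimal sketch: infinite-width NTK regression with neural-tangents.
# The architecture and toy data below are illustrative assumptions.
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Two-layer fully connected network in the infinite-width limit.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x_train = jax.random.normal(k1, (64, 784))            # e.g. flattened MNIST-sized inputs
y_train = jnp.sign(jax.random.normal(k2, (64, 1)))    # binary labels in {+1, -1}
x_test = jax.random.normal(k3, (16, 784))

# Closed-form NTK regression: f(x*) = K(x*, X) K(X, X)^{-1} y at t -> infinity.
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
y_test_ntk = predict_fn(x_test=x_test, get='ntk')
print(y_test_ntk.shape)  # (16, 1)
```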
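
The Experiment Setup row describes training on real data with full-batch SGD (momentum 0.9), a learning rate of N · 2e-7, up to 10^6 iterations, and early termination once the training loss drops below 1e-10. Below is a minimal sketch of that configuration with Optax; the model, toy data, and the use of an MSE loss are assumptions made to keep the example self-contained.

```python
# Hedged sketch of the training configuration quoted in the Experiment Setup
# row. The network, data, and MSE loss are illustrative assumptions.
import jax
import jax.numpy as jnp
import optax
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x_train = jax.random.normal(k1, (500, 784))            # N = 500 training images
y_train = jnp.sign(jax.random.normal(k2, (500, 1)))    # binary labels in {+1, -1}
N = x_train.shape[0]

_, params = init_fn(k3, (-1, 784))
opt = optax.sgd(learning_rate=N * 2e-7, momentum=0.9)  # lr scales with N
opt_state = opt.init(params)

def mse_loss(params):
    return jnp.mean((apply_fn(params, x_train) - y_train) ** 2)

@jax.jit
def train_step(params, opt_state):
    loss, grads = jax.value_and_grad(mse_loss)(params)
    updates, opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss

for it in range(1_000_000):          # full-batch gradient descent
    params, opt_state, loss = train_step(params, opt_state)
    if loss < 1e-10:                 # early-termination criterion from the paper
        break
```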