Generalization bounds via distillation
Authors: Daniel Hsu, Ziwei Ji, Matus Telgarsky, Lan Wang
ICLR 2021
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | This paper theoretically investigates the following empirical phenomenon: given a high-complexity network with poor generalization bounds, one can distill it into a network with nearly identical predictions but low complexity and vastly smaller generalization bounds. The main contribution is an analysis showing that the original network inherits this good generalization bound from its distillation, assuming the use of well-behaved data augmentation. This bound is presented both in an abstract and in a concrete form, the latter complemented by a reduction technique to handle modern computation graphs featuring convolutional layers, fully-connected layers, and skip connections, to name a few. To round out the story, a (looser) classical uniform convergence analysis of compression is also presented, as well as a variety of experiments on cifar10 and mnist demonstrating similar generalization performance between the original network and its distillation. While this work is primarily theoretical, it is motivated by Figure 1 and related experiments: Figures 2 to 4 demonstrate that not only does distillation improve generalization upper bounds, but moreover it makes them sufficiently tight to capture intrinsic properties of the predictors, for example removing the usual bad dependence on width in these bounds (cf. Figure 3). These experiments are detailed in Section 2. |
| Researcher Affiliation | Academia | Daniel Hsu (Columbia University, New York City, djhsu@cs.columbia.edu); Ziwei Ji, Matus Telgarsky, Lan Wang (University of Illinois, Urbana-Champaign, {ziweiji2,mjt,lanwang2}@illinois.edu) |
| Pseudocode | No | The paper describes mathematical derivations and high-level steps for proofs and models, but does not include any explicitly labeled 'Pseudocode' or 'Algorithm' blocks, nor does it present structured, code-like procedural steps. |
| Open Source Code | No | The paper does not contain any statement about releasing source code for the methodology, nor does it provide any links to a code repository. |
| Open Datasets | Yes | as well as a variety of experiments on cifar10 and mnist demonstrating similar generalization performance between the original network and its distillation. |
| Dataset Splits | No | The paper mentions 'training data' and 'unseen data' and 'test errors' on CIFAR10 and MNIST, but it does not explicitly provide details about specific training, validation, or test dataset splits (e.g., percentages, sample counts, or citations to predefined splits). |
| Hardware Specification | No | The paper mentions support from 'NVIDIA under a GPU grant' in the Acknowledgments section, implying the use of GPUs, but it does not provide specific details such as the GPU model (e.g., 'NVIDIA A100', 'Tesla V100'), CPU models, or other hardware specifications used for running the experiments. |
| Software Dependencies | No | The paper mentions using 'Adam' and 'vanilla SGD' for optimization, and refers to the 'ResNet8' architecture, but it does not specify any software dependencies with version numbers (e.g., Python version, specific deep learning framework versions like TensorFlow or PyTorch, or CUDA versions). |
| Experiment Setup | Yes | Experimental setup. As sketched before, networks were trained in a standard way on either cifar10 or mnist, and then distilled by trading off between complexity and distillation distance Φγ,m. Details are as follows. 1. Training initial network f. ... the training algorithm was Adam; this and most other choices followed the scheme in (Coleman et al., 2017) to achieve a competitively low test error on cifar10. In Figures 2b, 3 and 4, a 6-layer fully connected network was used (width 8192 in Figure 2b, widths {64, 256, 1024} in Figure 3, width 256 in Figure 4), and vanilla SGD was used to optimize. 2. Training distillation network g. Given f and a regularization strength λj, each distillation gj was found via approximate minimization of the objective g ↦ Φγ,m(f, g) + λj · Complexity(g). (2.1) In more detail, first g0 was initialized to f (g and f always used the same architecture) and optimized via eq. (2.1) with λ0 set to roughly risk(f)/Complexity(f), and thereafter gj+1 was initialized to gj and found by optimizing eq. (2.1) with λj+1 := 2λj. The optimization method was the same as the one used to find f. The term Complexity(g) was some computationally reasonable approximation of Lemma 3.1: for Figures 2b, 3 and 4, it was just Σi ‖Wiᵀ‖2,1, but for Figures 1 and 2a, it also included a tractable surrogate for the product of the spectral norms, which greatly helped distillation performance with these deeper architectures. |
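
The distillation schedule quoted in the Experiment Setup row lends itself to a short sketch. The snippet below is a minimal illustration, not the authors' code: it assumes PyTorch, uses a placeholder callable `phi` for the distillation distance Φγ,m (whose exact definition is given in the paper), takes `initial_lam` as the λ0 ≈ risk(f)/Complexity(f) starting value, and approximates Complexity(g) with only the Σi ‖Wiᵀ‖2,1 term, omitting the spectral-norm surrogate used for the deeper architectures.

```python
import copy
import torch


def complexity(net):
    """Crude surrogate for the sum_i ||W_i^T||_{2,1} term: for each weight
    tensor with at least two dimensions, flatten it to a matrix and sum the
    Euclidean norms of its rows (one reading of the (2,1)-norm convention;
    the deeper-net experiments also add a spectral-norm surrogate, omitted)."""
    total = torch.zeros(())
    for p in net.parameters():
        if p.dim() >= 2:
            total = total + p.flatten(1).norm(dim=1).sum()
    return total


def distill(f, loader, phi, initial_lam, num_rounds=8, epochs=1, lr=1e-3):
    """Sketch of the schedule described above: g_0 is initialized to f, round j
    approximately minimizes phi(f, g, x) + lam_j * complexity(g), and then
    lam_{j+1} := 2 * lam_j.  `phi` is expected to detach or no_grad the
    teacher f's outputs so that only g is updated."""
    f.eval()
    g = copy.deepcopy(f)                 # g and f always use the same architecture
    lam, distilled = initial_lam, []
    for _ in range(num_rounds):
        opt = torch.optim.Adam(g.parameters(), lr=lr)
        for _ in range(epochs):
            for x, _ in loader:          # labels are unused: g only has to match f
                loss = phi(f, g, x) + lam * complexity(g)
                opt.zero_grad()
                loss.backward()
                opt.step()
        distilled.append(copy.deepcopy(g))
        lam *= 2.0                       # lambda_{j+1} := 2 * lambda_j
    return distilled
```

Each round's regularization strength doubles, so the returned list of snapshots traces out the complexity/distillation-distance trade-off that the paper's figures plot; swapping Adam for vanilla SGD reproduces the choice made for the fully connected experiments.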