Disentangling the Predictive Variance of Deep Ensembles through the Neural Tangent Kernel

Authors: Seijin Kobayashi, Pau Vilimelis Aceituno, Johannes von Oswald

NeurIPS 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We further show theoretically and empirically that both noise sources affect the predictive variance of non-linear deep ensembles in toy models and realistic settings after training. Finally, we propose practical ways to eliminate part of these noise sources leading to significant changes and improved OOD detection in trained deep ensembles. Our contributions are the following: We conduct empirical studies validating our theoretical results, and investigate how the different variance terms influence generalization on in and out-of-distribution.
Researcher Affiliation | Academia | Seijin Kobayashi, Department of Computer Science, ETH Zürich, seijink@ethz.ch; Pau Vilimelis Aceituno, Institute of Neuroinformatics, University of Zürich & ETH Zürich, pau@ini.ethz.ch; Johannes von Oswald, Department of Computer Science, ETH Zürich, voswaldj@ethz.ch
Pseudocode | No | The paper describes mathematical formulations and processes using equations and prose but does not include any clearly labeled 'Pseudocode' or 'Algorithm' blocks with structured steps.
Open Source Code | Yes | Source code for all experiments: github.com/seijin-kobayashi/disentangle-predvar
Open Datasets | Yes | We conduct empirical studies validating our theoretical results, and investigate how the different variance terms influence generalization on in and out-of-distribution. We highlight the practical implications of our theory by proposing simple methods to isolate noise sources in realistic settings which can lead to improved OOD detection. We analyze the predictive variance of the kernel models based on MLPs and Convolutional Neural Networks (CNN) for various depths and widths and on subsets of MNIST [39] and CIFAR10. We fit a linearized ensemble on a larger subset of the standard 10-way classification MNIST and CIFAR10 datasets using MSE loss. When training our ensembles on MNIST, we test and average the OOD detection performance on Fashion MNIST (FM) [40], E-MNIST (EM) [41] and K-MNIST (KM) [42]. When training our ensembles on CIFAR10, we compute the AUROC for SVHN [43], LSUN [44], Tiny ImageNet (TIN) and CIFAR100 (C100).
Dataset Splits | No | The paper mentions 'training data' and 'test set' explicitly. However, it does not specify the exact proportions (e.g., 80/10/10) or the methods used for train/validation/test splits within the main text. It mentions dataset sizes (e.g., N=1000) for training, but not the detailed split methodology.
Hardware Specification | No | The paper does not explicitly state hardware details such as GPU models (e.g., NVIDIA A100), CPU models, or cloud computing instance types used for the experiments. The ethics checklist indicates this information is in the Appendix, but the Appendix content is not provided in this extract.
Software Dependencies | No | The paper does not provide specific version numbers for software dependencies (e.g., 'PyTorch 1.9' or 'Python 3.8'). While it mentions using libraries or frameworks implicitly (e.g., for neural networks), it does not detail their versions.
Experiment Setup | Yes | To qualitatively visualize the different terms, we construct a two-way star-shaped regression problem on a 2d-plane depicted in Figure 1. After training an ensemble we visualize its predictive variance on the input space. Our first goal is to visualize qualitative differences in the predictive variance of ensembles consisting of f^lin and the 3 disentangled models from above. We train a large ensemble of size 300 where each model is a one-layer ReLU neural network with hidden dimension 512 and 1 hidden layer. We fit a linearized ensemble on a larger subset of the standard 10-way classification MNIST and CIFAR10 datasets using MSE loss. When training our ensembles on MNIST, we test and average the OOD detection performance on Fashion MNIST (FM) [40], E-MNIST (EM) [41] and K-MNIST (KM) [42]. We train the commonly used Wide ResNet 28-10 [45] on CIFAR10 with BatchNorm [46] layers and cross-entropy (CE) loss with batch size of 128, without data augmentation (see Table 3.2.3).
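
The entries above refer to an ensemble's predictive variance and to AUROC as the OOD detection metric. As a rough illustration only, the sketch below computes a per-input variance across ensemble members and then scores OOD detection with a pairwise AUROC. It is not taken from the paper or its repository: the function names `predictive_variance` and `auroc`, the scalar-output simplification, and the toy numbers are all assumptions made for this example. For the 10-way classification setting one would additionally need to aggregate the variance over output dimensions, which is not shown here.

```python
from statistics import pvariance


def predictive_variance(member_outputs):
    """Variance across ensemble members for each input.

    member_outputs[m][i] is the scalar prediction of (hypothetical)
    ensemble member m on input i.
    """
    n_inputs = len(member_outputs[0])
    return [
        pvariance([member[i] for member in member_outputs])
        for i in range(n_inputs)
    ]


def auroc(in_scores, ood_scores):
    """Pairwise AUROC: probability that a random OOD input receives a
    higher score than a random in-distribution input (ties count half)."""
    wins = 0.0
    for o in ood_scores:
        for s in in_scores:
            if o > s:
                wins += 1.0
            elif o == s:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


# Toy data (made up): three members, two inputs per set.
# Members roughly agree on in-distribution inputs and disagree on OOD inputs.
preds_in = [[1.0, 0.0], [1.1, 0.1], [0.9, -0.1]]
preds_ood = [[0.5, 2.0], [-0.5, 0.0], [0.0, 1.0]]

var_in = predictive_variance(preds_in)
var_ood = predictive_variance(preds_ood)

# Use the predictive variance as the OOD score.
print("AUROC:", auroc(var_in, var_ood))
```

The pairwise definition is quadratic in the number of inputs, which is acceptable for a sketch; for realistic test-set sizes a rank-based implementation from a standard library would be the more practical choice.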