Epistemic Neural Networks

Authors: Ian Osband, Zheng Wen, Seyed Mohammad Asghari, Vikranth Dwaracherla, Morteza Ibrahimi, Xiuyuan Lu, Benjamin Van Roy

NeurIPS 2023

Reproducibility Variable | Result | LLM Response
Research Type: Experimental
Figure 2 offers a preview of the results presented in Section 6, where we compare these approaches on ImageNet. The quality of the ResNet's marginal predictions, measured by classification error or marginal log-loss, does not change much when the ResNet is supplemented with an epinet. However, the epinet-enhanced ResNet dramatically improves the quality of joint predictions, as measured by the joint log-loss, outperforming an ensemble of 100 particles while using fewer total parameters than 2 particles.
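To see why marginal and joint log-loss can disagree, here is a minimal Python sketch (not code from the enn or neural_testbed libraries) that scores a toy ENN whose predictions depend on an index z. The probability tables and the function names `marginal_log_loss` and `joint_log_loss` are hypothetical, chosen for illustration. Two agents with identical marginal predictions receive the same marginal log-loss, but only the joint metric distinguishes whether their predictions vary coherently with z.

```python
import math
from typing import List

# probs[k][i]: probability the ENN assigns to the observed label y_i of input x_i
# when run with the k-th sampled index z_k. The values below are hypothetical.


def marginal_log_loss(probs: List[List[float]]) -> float:
    """Mean over inputs of -log E_z[p(y_i | x_i, z)]: each input scored alone."""
    num_z, num_inputs = len(probs), len(probs[0])
    total = 0.0
    for i in range(num_inputs):
        mean_p = sum(probs[k][i] for k in range(num_z)) / num_z
        total -= math.log(mean_p)
    return total / num_inputs


def joint_log_loss(probs: List[List[float]]) -> float:
    """-log E_z[prod_i p(y_i | x_i, z)]: the batch of inputs scored together."""
    mean_joint = sum(math.prod(row) for row in probs) / len(probs)
    return -math.log(mean_joint)


no_index_dependence = [[0.5, 0.5], [0.5, 0.5]]
index_dependent = [[0.9, 0.9], [0.1, 0.1]]
for name, probs in [("no_index_dependence", no_index_dependence),
                    ("index_dependent", index_dependent)]:
    print(name, round(marginal_log_loss(probs), 3), round(joint_log_loss(probs), 3))
```

Both agents score 0.693 on the marginal metric, while the joint log-loss is 1.386 for `no_index_dependence` and 0.892 for `index_dependent` on this pair of labels. The index-dependent agent would score worse if the two observed labels disagreed, so the joint metric rewards dependence on z only when it reflects real correlations in the data.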
Researcher Affiliation: Industry
Ian Osband, Zheng Wen, Seyed Mohammad Asghari, Vikranth Dwaracherla, Morteza Ibrahimi, Xiuyuan Lu, and Benjamin Van Roy (Google DeepMind, Efficient Agent Team, Mountain View); {ian.osband, m.ibrahimi}@gmail.com; {zhengwen,smasghari,vikranthd,lxlu,benvanroy}@google.com
Pseudocode: Yes
Algorithm 1 (ENN training via SGD). Inputs: a dataset of training examples D = {(x_i, y_i, i)}_{i=1}^N; an ENN network f, a reference distribution P_Z, and an initialization θ_0; a loss ℓ that evaluates example (x_i, y_i, i) for index z; batch sizes n_B (data samples) and n_Z (index samples); an optimizer update rule; and the number of iterations T. Returns: θ_T, the parameter estimates for the ENN.
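The training loop of Algorithm 1 can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the enn library's API: the helper names `train_enn` and `toy_loss_grad` are hypothetical, the update is plain SGD on the gradient averaged over the n_B × n_Z (example, index) pairs, and the toy ENN f_θ(x, z) = w·x + b + c·z with a one-dimensional Gaussian reference P_Z and squared loss is invented for the example.

```python
import random
from typing import Callable, List, Sequence, Tuple

# One training example (x_i, y_i, i) as in Algorithm 1; scalar x and y for the toy.
Example = Tuple[float, float, int]
# Gradient of the loss for one example under one index z (hypothetical signature).
LossGrad = Callable[[List[float], Example, List[float]], List[float]]


def train_enn(
    dataset: Sequence[Example],
    theta0: List[float],
    sample_index: Callable[[random.Random], List[float]],
    loss_grad: LossGrad,
    n_b: int,
    n_z: int,
    learning_rate: float,
    num_steps: int,
    seed: int = 0,
) -> List[float]:
    """Run T = num_steps SGD updates and return theta_T."""
    rng = random.Random(seed)
    data = list(dataset)
    theta = list(theta0)
    for _ in range(num_steps):
        # Sample n_B data points (without replacement) and n_Z indices from P_Z.
        batch = rng.sample(data, n_b)
        indices = [sample_index(rng) for _ in range(n_z)]
        # Average the per-(example, index) gradients over the n_B x n_Z pairs.
        grad = [0.0] * len(theta)
        for z in indices:
            for example in batch:
                for j, g in enumerate(loss_grad(theta, example, z)):
                    grad[j] += g / (n_b * n_z)
        # Plain SGD step; the ImageNet runs in this report use Nesterov momentum instead.
        theta = [t - learning_rate * g for t, g in zip(theta, grad)]
    return theta


def toy_loss_grad(theta: List[float], example: Example, z: List[float]) -> List[float]:
    """Squared-error gradient for the toy ENN f(x, z) = w*x + b + c*z[0].

    The example index i is unused here; Algorithm 1 passes it so that a loss
    may depend on it (for instance, through per-example perturbations).
    """
    w, b, c = theta
    x, y, _ = example
    residual = w * x + b + c * z[0] - y
    return [2 * residual * x, 2 * residual, 2 * residual * z[0]]
```

For example, fitting 20 noiseless points on y = 2x + 1 with `n_b=4`, `n_z=2`, `learning_rate=0.05`, and 3000 steps should drive w toward 2, b toward 1, and c toward 0, since the index carries no information about this target.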
Open Source Code: Yes
Two related GitHub repositories complement this paper: (1) enn: https://anonymous.4open.science/r/enn-55BC and (2) neural_testbed: https://anonymous.4open.science/r/neural_testbed-8961. These libraries contain the code necessary to reproduce the key results in our paper, divided into repositories based on focus.
Open Datasets: Yes
ImageNet (Deng et al., 2009); qualitatively similar results for both CIFAR-10 and CIFAR-100 are presented in Appendix H. We compare our epinet agent against ensemble approaches as well as the uncertainty baselines of Nado et al. (2021). ... The Neural Testbed is an open-source benchmark that evaluates the quality of joint predictions in classification problems using synthetic data produced by neural-network-based generative models (Osband et al., 2022a).
Dataset Splits: No
We tune the learning rate, weight decay, and temperature rescaling (Wenzel et al., 2020) on ResNet-50 and apply those settings to the other ResNets. ... For ImageNet, for faster evaluation, rather than sampling anchor points from the evaluation set, we split the evaluation set into batches of size 2. We then iterate over these batches, re-sample τ = 10 points from each input pair, and evaluate the log-loss of an agent's joint predictions on these batches of size τ = 10.
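The pair-based ImageNet evaluation described above can be sketched as follows. This is a minimal illustration, not the neural_testbed implementation: the `prob_fn(x, y, z)` interface, the function name `dyadic_joint_log_loss`, the use of a fixed list of sampled indices, dropping an odd leftover example, and averaging the per-pair losses are all assumptions. Re-sampling τ = 10 points from a pair of 2 is necessarily done with replacement, and the joint likelihood is averaged over indices in log space for numerical stability.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

# prob_fn(x, y, z): probability the agent assigns to label y for input x
# under index z. This interface is hypothetical, chosen for illustration.
ProbFn = Callable[[object, int, object], float]


def dyadic_joint_log_loss(
    eval_set: Sequence[Tuple[object, int]],
    prob_fn: ProbFn,
    index_samples: Sequence[object],
    tau: int = 10,
    seed: int = 0,
) -> float:
    rng = random.Random(seed)
    losses: List[float] = []
    # Split the evaluation set into consecutive batches of size 2;
    # an odd leftover example is dropped in this sketch.
    for start in range(0, len(eval_set) - 1, 2):
        pair = list(eval_set[start:start + 2])
        # Re-sample tau points from the pair (necessarily with replacement).
        batch = [rng.choice(pair) for _ in range(tau)]
        # Log joint likelihood of the whole batch under each sampled index.
        log_joint = [sum(math.log(prob_fn(x, y, z)) for x, y in batch)
                     for z in index_samples]
        # -log of the average joint likelihood over indices, via log-mean-exp.
        peak = max(log_joint)
        mean = sum(math.exp(v - peak) for v in log_joint) / len(log_joint)
        losses.append(-(peak + math.log(mean)))
    if not losses:
        raise ValueError("need at least two evaluation examples")
    return sum(losses) / len(losses)
```

An agent that assigns probability 0.5 to every label, regardless of z, scores τ · log 2 ≈ 6.93 per pair, which gives a simple sanity check for any implementation.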
Hardware Specification: Yes
We train the ResNet agents for 90 epochs on 4×4 TPUs with a per-device batch size of 128.
Software Dependencies: No
Our implementation of the epinet agent can be found under the path /agents/factories/epinet.py in the anonymized neural_testbed GitHub repository. ... Each of these libraries is written in Python and relies heavily on JAX for scientific computing (Bradbury et al., 2018). We view this open-source effort as a major contribution of our paper. The first library, enn, focuses on the design of epistemic neural networks and their training, including all of our network definitions and loss functions. Our library is built around Haiku (Babuschkin et al., 2020).
Experiment Setup: Yes
We tune the learning rate, weight decay, and temperature rescaling (Wenzel et al., 2020) on ResNet-50 and apply those settings to the other ResNets. ... We apply L2 weight decay of strength 1e-4. We also incorporate label smoothing into the loss: instead of one-hot labels, the incorrect classes each receive a small weight of 0.1/C, where C is the number of classes. We train the ResNet agents for 90 epochs on 4×4 TPUs with a per-device batch size of 128. ... We fix the index dimension D_Z = 30. ... We use a 1-layer MLP with 50 hidden units for the learnable network (6). ... We optimize the loss using SGD with a learning rate of 0.1 and Nesterov momentum with decay 0.9. The epinet is trained on the same hardware with the same batch size for 9 epochs.
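The epinet head and the label-smoothed loss described above can be sketched in plain Python. This is a hedged illustration, not the enn library's Haiku code: it assumes that the learnable network (6) has the form σ^L(x̃, z) = g([x̃, z])^T z, where g is an MLP with one hidden layer of 50 ReLU units whose output is read as a D_Z × C matrix. The names `LearnableEpinet`, `smoothed_targets`, and `cross_entropy`, the initialization scheme, and the ReLU activation are assumptions. The fixed prior network and the stop-gradient on the base features are omitted, since this sketch has no autodiff.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def init_dense(rng: random.Random, n_in: int, n_out: int) -> Tuple[Matrix, List[float]]:
    """Gaussian weights scaled by 1/sqrt(n_in) and zero biases (an assumed init)."""
    weights = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_out)]
               for _ in range(n_in)]
    return weights, [0.0] * n_out


def dense(x: List[float], weights: Matrix, bias: List[float]) -> List[float]:
    """Affine map: out[j] = bias[j] + sum_i x[i] * weights[i][j]."""
    return [b + sum(xi * row[j] for xi, row in zip(x, weights)) for j, b in enumerate(bias)]


class LearnableEpinet:
    """Sketch of sigma_L(x_tilde, z) = g([x_tilde, z])^T z with a one-hidden-layer MLP g."""

    def __init__(self, feature_dim: int, num_classes: int,
                 index_dim: int = 30, hidden: int = 50, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.num_classes = num_classes
        self.index_dim = index_dim
        self.w1, self.b1 = init_dense(rng, feature_dim + index_dim, hidden)
        self.w2, self.b2 = init_dense(rng, hidden, index_dim * num_classes)

    def __call__(self, features: List[float], z: List[float]) -> List[float]:
        # Hidden layer on the concatenation [x_tilde, z], with ReLU.
        h = [max(0.0, v) for v in dense(features + z, self.w1, self.b1)]
        # Flat output read as an (index_dim x num_classes) matrix, row-major.
        out = dense(h, self.w2, self.b2)
        # Contract the matrix with z to get one value per class.
        return [sum(out[k * self.num_classes + c] * z[k] for k in range(self.index_dim))
                for c in range(self.num_classes)]


def smoothed_targets(label: int, num_classes: int, alpha: float = 0.1) -> List[float]:
    """Label smoothing: each incorrect class gets alpha / C, the correct class the rest."""
    off = alpha / num_classes
    return [1.0 - alpha + off if c == label else off for c in range(num_classes)]


def cross_entropy(logits: List[float], targets: List[float]) -> float:
    """Cross-entropy between soft targets and softmax(logits), computed stably."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return -sum(t * (v - log_norm) for t, v in zip(targets, logits))
```

With C = 5 and alpha = 0.1, each incorrect class receives 0.1/5 = 0.02 and the correct class receives 0.92, so the targets still sum to 1. Because this form of the head is contracted with z, a zero index yields a zero epinet output, leaving only the base network's prediction.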