Set Learning for Accurate and Calibrated Models
Authors: Lukas Muttenthaler, Robert A. Vandermeulen, Qiuyi Zhang, Thomas Unterthiner, Klaus-Robert Müller
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this work, we propose a novel method to alleviate these problems that we call odd-k-out learning (OKO), which minimizes the cross-entropy error for sets rather than for single examples. This naturally allows the model to capture correlations across data examples and achieves both better accuracy and calibration, especially in limited training data and class-imbalanced regimes. Perhaps surprisingly, OKO often yields better calibration even when training with hard labels and dropping any additional calibration parameter tuning, such as temperature scaling. We demonstrate this in extensive experimental analyses and provide a mathematical theory to interpret our findings. We emphasize that OKO is a general framework that can be easily adapted to many settings and a trained model can be applied to single examples at inference time, without significant run-time overhead or architecture changes. |
| Researcher Affiliation | Collaboration | Lukas Muttenthaler (1, 2, 3), Robert A. Vandermeulen (1, 2, *), Qiuyi (Richard) Zhang (3), Thomas Unterthiner (3), and Klaus-Robert Müller (1, 2, 3, 4, 5). Affiliations: (1) Machine Learning Group, Technische Universität Berlin, Germany; (2) Berlin Institute for the Foundations of Learning and Data, Berlin, Germany; (3) Google DeepMind; (4) Department of Artificial Intelligence, Korea University, Seoul; (5) Max Planck Institute for Informatics, Saarbrücken, Germany |
| Pseudocode | Yes | Algorithm 1: OKO set sampling. Input: D, C, k, where C is the number of classes and k is the number of odd classes. Output: S_x, S_y, y |
| Open Source Code | Yes | A JAX implementation of OKO is publicly available at: https://github.com/LukasMut/OKO |
| Open Datasets | Yes | For every experiment we present in this section, we use a simple randomly initialized CNN for MNIST and Fashion-MNIST, and ResNet-18 and ResNet-34 architectures (He et al., 2016) for CIFAR-10 and CIFAR-100, respectively. We use standard SGD with momentum and schedule the learning rate via cosine annealing. |
| Dataset Splits | Yes | We select hyperparameters and train every model until convergence on a held-out validation set. To examine generalization performance and model calibration in low data regimes, we vary the number of training data points while holding the number of test data points fixed. |
| Hardware Specification | Yes | We used a compute time of approximately 50 hours on a single Nvidia A100 GPU with 40GB VRAM for all CIFAR-10 and CIFAR-100 experiments using a ResNet-18 or ResNet-34, respectively, and approximately 100 CPU-hours of 2.90 GHz Intel Xeon Gold 6326 CPUs for MNIST and Fashion-MNIST experiments using the custom convolutional neural network architecture. The computations were performed on a standard, large-scale academic SLURM cluster. |
| Software Dependencies | No | The paper mentions "JAX implementation" and provides code, but it does not specify version numbers for JAX or other software libraries. |
| Experiment Setup | Yes | For every experiment we present in this section, we use a simple randomly initialized CNN for MNIST and Fashion-MNIST, and ResNet-18 and ResNet-34 architectures (He et al., 2016) for CIFAR-10 and CIFAR-100, respectively. We use standard SGD with momentum and schedule the learning rate via cosine annealing. We select hyperparameters and train every model until convergence on a held-out validation set. ... We set k in OKO to 1 for all experiments. |
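The Pseudocode row names Algorithm 1 (OKO set sampling), but only its inputs and outputs survive here. Below is a minimal Python sketch of one plausible reading. It is not the authors' JAX implementation.

The sketch works as follows:

- A pair class y is drawn uniformly from the C classes.
- k odd classes are drawn from the remaining C - 1 classes.
- The set holds two examples of the pair class plus one example of each odd class.

Several details are our own assumptions:

- The odd classes are taken to be distinct.
- The set is returned as dataset indices instead of raw inputs.
- The helper names `index_by_class` and `sample_oko_set` are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def index_by_class(labels: Sequence[int]) -> Dict[int, List[int]]:
    """Map each class label to the dataset indices of its examples."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    return dict(by_class)


def sample_oko_set(
    by_class: Dict[int, List[int]],
    num_classes: int,
    k: int,
    rng: random.Random,
) -> Tuple[List[int], List[int], int]:
    """Draw one OKO set; returns (S_x indices, S_y labels, pair class y).

    Assumption (hypothetical helper): every class has at least two examples,
    and the k odd classes are distinct from each other and from the pair class.
    """
    if not 1 <= k <= num_classes - 1:
        raise ValueError("k must satisfy 1 <= k <= num_classes - 1")
    # Pair (majority) class, drawn uniformly from the C classes.
    y = rng.randrange(num_classes)
    # k distinct odd classes, drawn from the other C - 1 classes.
    odd_classes = rng.sample([c for c in range(num_classes) if c != y], k)
    # Two different examples of the pair class ...
    pair_idx = rng.sample(by_class[y], 2)
    # ... plus one example from each odd class.
    odd_idx = [rng.choice(by_class[c]) for c in odd_classes]
    s_x = pair_idx + odd_idx
    s_y = [y, y] + odd_classes
    return s_x, s_y, y


# Usage with toy labels and k = 1 (sets of three examples).
labels = [0, 0, 0, 1, 1, 2, 2, 2]
rng = random.Random(0)
s_x, s_y, y = sample_oko_set(index_by_class(labels), num_classes=3, k=1, rng=rng)
```

With k = 1, the setting reported for all experiments, every set contains three examples: two from the pair class and one odd example.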
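The Research Type row states that OKO minimizes the cross-entropy error for sets rather than for single examples. It also states that the trained model still scores single examples at inference, with no architecture change.

Below is a minimal sketch of that idea. It rests on our reading of the paper, not on the reference code, and assumes two things:

- The set's logits are the class-wise sum of the per-example logits.
- The hard-label target is the pair class.

The function names are hypothetical, and the per-example logits would come from any ordinary classifier f(x).

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a vector of class logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def oko_set_logits(per_example_logits: Sequence[Sequence[float]]) -> List[float]:
    """Aggregate a set by summing its members' logits class by class (assumed)."""
    return [sum(class_col) for class_col in zip(*per_example_logits)]


def oko_hard_loss(per_example_logits: Sequence[Sequence[float]],
                  pair_class: int) -> float:
    """Cross-entropy between the summed set logits and the pair-class label."""
    return -log_softmax(oko_set_logits(per_example_logits))[pair_class]


def predict_single(logits: Sequence[float]) -> int:
    """Inference on one example: an ordinary argmax over its own logits."""
    return max(range(len(logits)), key=lambda c: logits[c])


# Usage: hypothetical logits for a k = 1 set over C = 3 classes,
# two pair-class examples (class 0) followed by one odd example (class 1).
set_member_logits = [[2.0, 0.0, 0.0], [1.5, 0.5, 0.0], [0.0, 3.0, 0.0]]
loss = oko_hard_loss(set_member_logits, pair_class=0)
```

Summing makes the set prediction independent of the order of set members. Because the per-example network is untouched, single-example inference is just a normal forward pass followed by `predict_single`.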
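The Experiment Setup row specifies standard SGD with momentum and a cosine-annealed learning rate, but gives no constants. The sketch below shows two pieces:

- The common cosine schedule without restarts: lr(t) = lr_min + 0.5 (lr_0 - lr_min)(1 + cos(pi t / T)).
- A heavy-ball momentum update for a scalar parameter.

The base learning rate, the floor `min_lr`, the momentum coefficient 0.9, and the absence of warmup are illustrative assumptions. They are not values reported by the authors.

```python
import math
from typing import List, Tuple


def cosine_annealing_lr(step: int, total_steps: int, base_lr: float,
                        min_lr: float = 0.0) -> float:
    """Cosine decay from base_lr (step 0) to min_lr (step total_steps), no restarts."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Training progress in [0, 1]; steps outside the range are clamped.
    t = min(max(step, 0), total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))


def sgd_momentum_step(param: float, velocity: float, grad: float, lr: float,
                      momentum: float = 0.9) -> Tuple[float, float]:
    """One heavy-ball SGD-with-momentum update for a scalar parameter."""
    # Accumulate an exponentially decaying sum of past gradients.
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity


# Usage: learning rates over a hypothetical 4-step schedule (steps 0..4).
schedule: List[float] = [cosine_annealing_lr(s, total_steps=4, base_lr=0.1)
                         for s in range(5)]
```

The velocity form used here (v <- mu v + g, theta <- theta - lr v) matches the default, non-Nesterov update of PyTorch's `torch.optim.SGD`. Whether the authors used Nesterov momentum, weight decay, or warmup is not stated in this report.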