Massively Scaling Heteroscedastic Classifiers
Authors: Mark Collier, Rodolphe Jenatton, Basil Mustafa, Neil Houlsby, Jesse Berent, Effrosyni Kokiopoulou
ICLR 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | On large image classification datasets with up to 4B images and 30k classes our method requires 14× fewer additional parameters, does not require tuning the temperature on a held-out set and performs consistently better than the baseline heteroscedastic classifier. |
| Researcher Affiliation | Industry | Google AI {markcollier,rjenatton,basilm,neilhoulsby,jberent,kokiopou}@google.com |
| Pseudocode | No | The paper describes the proposed method (HET-XL) and related algorithms in prose, but does not provide any structured pseudocode or algorithm blocks. |
| Open Source Code | Yes | The code to implement HET-XL as a drop-in classifier last-layer, and the scripts to replicate our ImageNet-21k results are publicly available on GitHub (https://github.com/google/uncertainty-baselines). |
| Open Datasets | Yes | We evaluate HET-XL on three image classification benchmarks: (i) ImageNet-21k, which is an expanded version of the ILSVRC2012 ImageNet benchmark (Deng et al., 2009; Beyer et al., 2020)... |
| Dataset Splits | Yes | To define our 3-fold split of the dataset, we take the standard validation set as the test set (containing 102,400 points), and extract from the training set a validation set (also with 102,400 points). |
| Hardware Specification | Yes | All image classification experiments are trained on 64 TPU v3 cells with 128 cores. |
| Software Dependencies | No | The paper mentions using 'Adam optimizer' and 'default JAX hyperparameters' but does not specify version numbers for JAX or any other software libraries or dependencies. |
| Experiment Setup | Yes | All methods and models are trained for 7 epochs with the Adam optimizer with β1 = 0.9, β2 = 0.999 and weight decay of 0.1 and otherwise default JAX hyperparameters. The learning rate undergoes a 10,000-step linear warm-up phase starting at 10⁻⁵ and reaching 6 × 10⁻⁴. A batch size of 4096 is used. |
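
The dataset-splits row above spells out the 3-fold split only in prose. Below is a minimal sketch of how such a split could be expressed with TensorFlow Datasets slicing; the builder name `imagenet21k`, the use of TFDS at all, and the choice to carve the new validation set from the end of the training split are assumptions for illustration, not details taken from the paper.

```python
# Minimal sketch of the 3-fold split described above (assumptions noted in the lead-in).
import tensorflow_datasets as tfds

VAL_SIZE = 102_400  # both the extracted validation set and the test set contain 102,400 points

splits = {
    "test": "validation",                  # standard validation set reused as the test set
    "validation": f"train[-{VAL_SIZE}:]",  # 102,400 points extracted from the training set (end slice assumed)
    "train": f"train[:-{VAL_SIZE}]",       # remaining training points
}

# "imagenet21k" is a placeholder builder name; substitute whichever TFDS builder is actually used.
datasets = {name: tfds.load("imagenet21k", split=split) for name, split in splits.items()}
```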
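
The experiment-setup row gives the optimizer hyperparameters but no code. The sketch below expresses that configuration with optax (a natural choice given the paper's JAX setup, but not confirmed by the quote); holding the learning rate constant after warm-up and applying weight decay in decoupled AdamW form are likewise assumptions, since the quote specifies neither.

```python
# Hedged sketch of the training configuration quoted above; see lead-in for assumptions.
import optax

INIT_LR = 1e-5         # warm-up starting learning rate
PEAK_LR = 6e-4         # learning rate reached after warm-up
WARMUP_STEPS = 10_000  # linear warm-up length in steps

# Linear warm-up from 1e-5 to 6e-4, then held constant (post-warm-up shape is an assumption).
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=INIT_LR, end_value=PEAK_LR, transition_steps=WARMUP_STEPS),
        optax.constant_schedule(PEAK_LR),
    ],
    boundaries=[WARMUP_STEPS],
)

# Adam with beta1 = 0.9, beta2 = 0.999 and weight decay 0.1 (decoupled AdamW form assumed).
optimizer = optax.adamw(learning_rate=schedule, b1=0.9, b2=0.999, weight_decay=0.1)
```

The quoted batch size of 4096 and the 7-epoch budget would then fix the total number of optimizer steps for a given dataset size.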