Just Train Twice: Improving Group Robustness without Training Group Information
Authors: Evan Z Liu, Behzad Haghgoo, Annie S Chen, Aditi Raghunathan, Pang Wei Koh, Shiori Sagawa, Percy Liang, Chelsea Finn
ICML 2021
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | On four image classification and natural language processing tasks with spurious correlations, we show that JTT closes 73% of the gap in worst-group accuracy between standard ERM and group DRO, while only requiring group annotations on a small validation set in order to tune hyperparameters. |
| Researcher Affiliation | Academia | Department of Computer Science, Stanford University. Correspondence to: Evan Zheran Liu <evanliu@cs.stanford.edu>. |
| Pseudocode | Yes | Algorithm 1 JTT training (see the code sketch after the table) |
| Open Source Code | Yes | Reproducibility. Our code is publicly available at https://github.com/anniesch/jtt. |
| Open Datasets | Yes | We evaluate JTT on two image classification datasets with spurious correlations, Waterbirds (Wah et al., 2011) and CelebA (Liu et al., 2015), and two natural language processing datasets, MultiNLI (Williams et al., 2018) and CivilComments-WILDS (Borkan et al., 2019; Koh et al., 2021). |
| Dataset Splits | Yes | In all the experiments, we tune the algorithm hyperparameters and model hyperparameters (such as early stopping) based on the worst-group accuracy on the validation set. |
| Hardware Specification | No | The paper does not provide specific hardware details such as GPU/CPU models or types of accelerators used for running experiments. |
| Software Dependencies | No | Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in PyTorch, 2017. |
| Experiment Setup | Yes | We tune the algorithm hyperparameters (the number of training epochs T for the identification model f̂_id, and the upweight factor λ_up) and both identification and final model hyperparameters (e.g., the learning rate and ℓ2 regularization strength) based on the worst-group error of the final model f̂_final on the validation set. In our experiments, we share the same hyperparameters and architecture between the identification and final models, outside of the early stopping T of the identification model, and we sometimes find it helpful to learn them with different optimizers. Note that setting the upweight factor λ_up to 1 recovers ERM, so JTT should perform at least as well as ERM, given a sufficiently large validation set. We describe full training details in Appendix A. |
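
The Pseudocode and Experiment Setup rows describe the two-stage procedure of Algorithm 1 (JTT training): fit an identification model f̂_id with plain ERM for T epochs, collect the training points it misclassifies, and retrain a final model f̂_final with those points upweighted by λ_up. Below is a minimal PyTorch sketch of that procedure, not the authors' released code; `make_model`, `train_set`, and the argument names are placeholder assumptions.

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def train_erm(model, loader, epochs, lr=1e-3, weight_decay=1e-4):
    """Standard ERM training with a cross-entropy loss."""
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                          weight_decay=weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model

def jtt(train_set, make_model, T, lambda_up, final_epochs, batch_size=64):
    # Stage 1: train the identification model f_id with plain ERM,
    # early-stopped after T epochs.
    f_id = train_erm(make_model(),
                     DataLoader(train_set, batch_size=batch_size, shuffle=True),
                     epochs=T)

    # Build the error set E: training points that f_id misclassifies get
    # weight lambda_up; everything else keeps weight 1.
    f_id.eval()
    weights = []
    with torch.no_grad():
        for x, y in DataLoader(train_set, batch_size=batch_size):
            wrong = (f_id(x).argmax(dim=1) != y).float()
            weights.extend((wrong * (lambda_up - 1) + 1).tolist())

    # Stage 2: train the final model f_final on the reweighted data
    # (a weighted sampler here; the paper equivalently upsamples E).
    sampler = WeightedRandomSampler(weights, num_samples=len(weights),
                                    replacement=True)
    f_final = train_erm(make_model(),
                        DataLoader(train_set, batch_size=batch_size, sampler=sampler),
                        epochs=final_epochs)
    return f_final
```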
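
The Dataset Splits and Experiment Setup rows also note that group annotations are only needed on the validation set, where T and λ_up are chosen by worst-group accuracy. The helper below illustrates that selection loop under stated assumptions: `run_jtt`, the `(x, y, group)` batch format, and the candidate grid are hypothetical stand-ins, not part of the paper's code.

```python
import torch

def worst_group_accuracy(model, loader):
    """Minimum per-group accuracy over a loader yielding (x, y, group) batches."""
    correct, total = {}, {}
    model.eval()
    with torch.no_grad():
        for x, y, g in loader:
            pred = model(x).argmax(dim=1)
            for yi, pi, gi in zip(y.tolist(), pred.tolist(), g.tolist()):
                correct[gi] = correct.get(gi, 0) + int(yi == pi)
                total[gi] = total.get(gi, 0) + 1
    return min(correct[g] / total[g] for g in total)

def select_hyperparameters(candidates, run_jtt, val_loader):
    """Return the (T, lambda_up) pair whose final model has the best worst-group accuracy."""
    best = None
    for T, lam in candidates:
        model = run_jtt(T, lam)          # trains JTT end to end for this setting
        score = worst_group_accuracy(model, val_loader)
        if best is None or score > best[0]:
            best = (score, T, lam)
    return best[1:]
```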