Just Train Twice: Improving Group Robustness without Training Group Information

Authors: Evan Z Liu, Behzad Haghgoo, Annie S Chen, Aditi Raghunathan, Pang Wei Koh, Shiori Sagawa, Percy Liang, Chelsea Finn

ICML 2021

Reproducibility

| Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | On four image classification and natural language processing tasks with spurious correlations, we show that JTT closes 73% of the gap in worst-group accuracy between standard ERM and group DRO, while only requiring group annotations on a small validation set in order to tune hyperparameters. |
| Researcher Affiliation | Academia | Department of Computer Science, Stanford University. Correspondence to: Evan Zheran Liu <evanliu@cs.stanford.edu>. |
| Pseudocode | Yes | Algorithm 1: JTT training |
| Open Source Code | Yes | Reproducibility. Our code is publicly available at https://github.com/anniesch/jtt. |
| Open Datasets | Yes | We evaluate JTT on two image classification datasets with spurious correlations, Waterbirds (Wah et al., 2011) and CelebA (Liu et al., 2015), and two natural language processing datasets, MultiNLI (Williams et al., 2018) and CivilComments-WILDS (Borkan et al., 2019; Koh et al., 2021). |
| Dataset Splits | Yes | In all the experiments, we tune the algorithm hyperparameters and model hyperparameters (such as early stopping) based on the worst-group accuracy on the validation set. |
| Hardware Specification | No | The paper does not provide specific hardware details such as GPU/CPU models or types of accelerators used for running experiments. |
| Software Dependencies | No | Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in PyTorch, 2017. |
| Experiment Setup | Yes | We tune the algorithm hyperparameters (the number of training epochs T for the identification model f̂_id, and the upweight factor λ_up) and both identification and final model hyperparameters (e.g., the learning rate and ℓ2 regularization strength) based on the worst-group error of the final model f̂_final on the validation set. In our experiments, we share the same hyperparameters and architecture between the identification and final models, outside of the early stopping T of the identification model, and we sometimes find it helpful to learn them with different optimizers. Note that setting the upweight factor λ_up to 1 recovers ERM, so JTT should perform at least as well as ERM, given a sufficiently large validation set. We describe full training details in Appendix A. |
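The Experiment Setup row describes the core of JTT: an identification model f̂_id is first trained with plain ERM for T epochs, and the final model is then trained on a dataset in which every example the identification model misclassified is repeated λ_up times. A minimal sketch of that upweighting step is below; `build_jtt_dataset` is a hypothetical helper name for illustration, not a function from the paper's released code.

```python
def build_jtt_dataset(examples, labels, id_model_preds, lambda_up):
    """Construct the upsampled training set for JTT's second stage.

    The error set contains every example that the identification model
    (trained with plain ERM for T epochs) misclassified; each such
    example is repeated lambda_up times, while all others appear once.
    """
    upsampled = []
    for x, y, y_hat in zip(examples, labels, id_model_preds):
        copies = lambda_up if y_hat != y else 1
        upsampled.extend([(x, y)] * copies)
    return upsampled


# Toy check: the identification model gets x1 wrong, so x1 is repeated.
data = ["x0", "x1", "x2"]
labels = [0, 0, 1]
preds = [0, 1, 1]
print(build_jtt_dataset(data, labels, preds, lambda_up=3))
# → [('x0', 0), ('x1', 0), ('x1', 0), ('x1', 0), ('x2', 1)]
```

Setting `lambda_up=1` makes every example appear exactly once, which is why the paper notes that λ_up = 1 recovers standard ERM.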