Discrete Representations Strengthen Vision Transformer Robustness

Authors: Chengzhi Mao, Lu Jiang, Mostafa Dehghani, Carl Vondrick, Rahul Sukthankar, Irfan Essa

ICLR 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Experimental results demonstrate that adding discrete representation on four architecture variants strengthens ViT robustness by up to 12% across seven ImageNet robustness benchmarks while maintaining the performance on ImageNet.
Researcher Affiliation | Collaboration | Chengzhi Mao2, Lu Jiang1, Mostafa Dehghani1, Carl Vondrick2, Rahul Sukthankar1, Irfan Essa1,3; 1 Google Research; 2 Computer Science, Columbia University; 3 School of Interactive Computing, Georgia Institute of Technology
Pseudocode | Yes | (a) Pseudo JAX code:

    # x: input image mini-batch; pixel_embed: pixel
    # embeddings of x. vqgan.encoder and codebook are
    # initialized from the pretraining.
    import jax
    import jax.numpy as jnp
    discrete_token = jax.lax.stop_gradient(vqgan.encoder(x))
    discrete_embed = jnp.dot(discrete_token, codebook)
    tokens = jnp.concatenate(
        [discrete_embed, pixel_embed], axis=2)
    predictions = TransformerEncoder(tokens)
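The pseudocode above can be exercised end to end with stand-in components. The sketch below is a minimal runnable version of the token-concatenation step only; `vqgan_encode` is a hypothetical placeholder for the pretrained VQGAN encoder, and the shapes (batch 2, 196 positions, 1024-entry codebook, 512-dim embeddings) are illustrative assumptions, not values from the paper.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

codebook = jax.random.normal(k1, (1024, 512))        # code index -> embedding vector
pixel_embed = jax.random.normal(k2, (2, 196, 512))   # standard ViT patch embeddings

def vqgan_encode(x):
    # Placeholder for the frozen VQGAN encoder: the real encoder maps
    # images to one-hot code indices; here we fake it with random logits.
    logits = jax.random.normal(k3, (2, 196, 1024))
    return jax.nn.one_hot(jnp.argmax(logits, axis=-1), 1024)

# stop_gradient keeps the pretrained encoder and codebook frozen.
discrete_token = jax.lax.stop_gradient(vqgan_encode(None))
discrete_embed = jnp.dot(discrete_token, codebook)   # (2, 196, 512)

# Concatenate along the embedding axis to form the transformer input tokens.
tokens = jnp.concatenate([discrete_embed, pixel_embed], axis=2)
```

Concatenating along `axis=2` doubles the embedding width per position (here 512 + 512 = 1024) rather than adding extra token positions, which matches the `dim=2` argument in the paper's pseudocode.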
Open Source Code | No | We will release our model and code to assist in comparisons and to support other researchers in reproducing our experiments and results.
Open Datasets | Yes | Datasets. In the experiments, we train all the models, including ours, on ImageNet 2012 or ImageNet-21K under the same training settings, where we use identical training data, batch size, and learning rate schedule, etc. ImageNet (Deng et al., 2009) is the standard validation set of ILSVRC2012.
Dataset Splits | Yes | ImageNet (Deng et al., 2009) is the standard validation set of ILSVRC2012.
Hardware Specification | Yes | Our batch size is 4096 and models are trained on 64 TPU cores.
Software Dependencies | No | We implement our model in JAX and optimize all models with Adam (Kingma & Ba, 2014).
Experiment Setup | Yes | Unless specified otherwise, the input images are resized to 224x224, trained with a batch size of 4,096, with a weight decay of 0.1. We use a linear learning rate warmup and cosine decay. On ImageNet, the models are trained for 300 epochs... We start from a learning rate of 0.001 and train for 300 epochs. We use linear warmup for 10k iterations and then cosine annealing for the learning rate schedule. We use 224x224 resolution for the input image. We use RandAug with hyper-parameter (2, 15) and mixup with α = 0.5, and we apply a dropout rate of 0.1 and stochastic block dropout of 0.1.
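The stated schedule (linear warmup for 10k iterations to a peak of 0.001, then cosine annealing) can be sketched as a plain function. The total step count below is an illustrative assumption; the paper only fixes the peak rate, warmup length, and epoch count.

```python
import math

def lr_schedule(step, peak_lr=1e-3, warmup_steps=10_000, total_steps=200_000):
    """Linear warmup to peak_lr, then cosine decay to zero.

    total_steps is a hypothetical value chosen for illustration.
    """
    if step < warmup_steps:
        # Linear ramp from 0 to peak_lr over the warmup period.
        return peak_lr * step / warmup_steps
    # Cosine annealing from peak_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
```

For example, `lr_schedule(5_000)` gives half the peak rate mid-warmup, `lr_schedule(10_000)` returns the full peak of 0.001, and the rate decays smoothly to zero by the final step.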