Test-Time Prompt Tuning for Zero-Shot Generalization in Vision-Language Models

Authors: Manli Shu, Weili Nie, De-An Huang, Zhiding Yu, Tom Goldstein, Anima Anandkumar, Chaowei Xiao

NeurIPS 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | In this work, we propose test-time prompt tuning (TPT), a method that can learn adaptive prompts on the fly with a single test sample. For image classification, TPT optimizes the prompt by minimizing the entropy with confidence selection so that the model has consistent predictions across different augmented views of each test sample. In evaluating generalization to natural distribution shifts, TPT improves the zero-shot top-1 accuracy of CLIP by 3.6% on average, surpassing previous prompt tuning approaches that require additional task-specific training data. In evaluating cross-dataset generalization with unseen categories, TPT performs on par with the state-of-the-art approaches that use additional training data. Project page: https://azshue.github.io/TPT/.
Researcher Affiliation | Collaboration | Manli Shu (1), Weili Nie (2), De-An Huang (2), Zhiding Yu (2), Tom Goldstein (1), Anima Anandkumar (2,3), Chaowei Xiao (2,4); (1) University of Maryland, (2) NVIDIA, (3) Caltech, (4) Arizona State University
Pseudocode | No | The paper does not contain any explicitly labeled pseudocode or algorithm blocks. The methods are described in text and mathematical formulations.
Open Source Code | Yes | Project page: https://azshue.github.io/TPT/ (the project page is typically where open-source code is provided or linked). Also, from the checklist: "Did you include the code, data, and instructions needed to reproduce the main experimental results (either in the supplemental material or as a URL)? [Yes] See implementation details in section 4 and the supplemental materials."
Open Datasets | Yes | We evaluate the model's robustness to natural distribution shifts on 4 ImageNet variants as follows [...] ImageNet-V2 [58], ImageNet-A [59], ImageNet-R [14], ImageNet-Sketch [61]. We conduct a cross-dataset evaluation on the task of image classification. We consider 10 datasets, covering fine-grained classifications including species of plants or animals (Flower102 [62], Oxford Pets [63]), scenes (SUN397 [64]), textures (DTD [65]), food (Food101 [66]), transportation (Stanford Cars [67], Aircraft [68]), human actions (UCF101 [69]), satellite images (EuroSAT [70]), and general objects (Caltech101 [71]).
Dataset Splits | No | The paper states, for the baselines, "Following their original configuration, we train both methods on ImageNet using 16-shot training data per category" and mentions the "ImageNet validation set" when describing ImageNet-Sketch, but it does not explicitly specify the train/validation splits (percentages or counts) used for its own experiments, so the splits cannot be reproduced directly from the main text.
Hardware Specification | No | The paper does not provide specific details about the hardware used for the experiments, such as GPU models, CPU types, or memory specifications. The checklist indicates this information might be in the supplemental materials, but it is not in the main paper.
Software Dependencies | No | The paper does not explicitly list software dependencies with specific version numbers (e.g., Python, PyTorch, or CUDA versions).
Experiment Setup | Yes | For TPT, we initialize the prompt as the default hand-crafted one "a photo of a", and optimize the corresponding 4 tokens in the text input embedding space based on a single test image. We augment a single test image 63 times using random resized crops and construct a batch of 64 images, including the original one. Among the 64 predictions, we select the top 10% (ρ = 0.1) confident samples (lowest 10% in self-entropy) and compute the entropy of the averaged probability of the selected predictions (i.e., marginal entropy). We optimize the prompt to minimize the marginal entropy for 1 step, using the AdamW optimizer with a learning rate of 0.005. [...] All learnable tokens are initialized in the text embedding space from a Gaussian distribution with σ = 0.02. We optimize the prompt on all support images of a test sample for 64 steps, using the AdamW optimizer with a learning rate of 0.005.
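
The marginal-entropy objective with confidence selection quoted in the experiment setup can be written compactly. The following is a sketch in LaTeX notation based on that description, with N = 64 augmented views and ρ = 0.1; the symbols are chosen here for illustration and are not copied from the paper:

\bar{p}(y \mid x) = \frac{1}{\rho N} \sum_{i \in \mathcal{S}} p\big(y \mid \mathcal{A}_i(x)\big),
\qquad
\mathcal{L}_{\mathrm{TPT}} = - \sum_{y} \bar{p}(y \mid x) \log \bar{p}(y \mid x),

where \mathcal{A}_i(x) is the i-th augmented view of the test image x, \mathcal{S} is the set of the ρN views with the lowest self-entropy, and the prompt tokens are updated by minimizing \mathcal{L}_{\mathrm{TPT}} (a single AdamW step in the classification setting above).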
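
For concreteness, below is a minimal PyTorch-style sketch of one TPT update following the setup quoted above. It assumes a hypothetical callable clip_logits(prompt_embeds, views) that returns CLIP class logits as a differentiable function of the learnable prompt token embeddings; the function name, argument names, and tensor shapes are illustrative placeholders, not the authors' released implementation.

import torch

def tpt_step(clip_logits, prompt_embeds, views, rho=0.1, lr=5e-3):
    # One test-time prompt tuning step on a batch of 64 views
    # (the original test image plus 63 random resized crops).
    # `prompt_embeds` must be a leaf tensor with requires_grad=True.
    optimizer = torch.optim.AdamW([prompt_embeds], lr=lr)    # lr = 0.005 in the paper

    logits = clip_logits(prompt_embeds, views)                # shape: (64, num_classes)
    probs = logits.softmax(dim=-1)
    self_entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)

    # Confidence selection: keep the rho fraction of views with the lowest self-entropy.
    k = max(1, int(rho * probs.size(0)))
    selected = self_entropy.topk(k, largest=False).indices

    # Marginal entropy of the averaged prediction over the selected views.
    avg_probs = probs[selected].mean(dim=0)
    marginal_entropy = -(avg_probs * avg_probs.clamp_min(1e-12).log()).sum()

    optimizer.zero_grad()
    marginal_entropy.backward()
    optimizer.step()                                          # a single optimization step
    return prompt_embeds

After this step, the tuned prompt is used to classify the original test image; since TPT adapts on the fly with a single test sample, the procedure is repeated independently for each new sample.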