Align Your Prompts: Test-Time Prompting with Distribution Alignment for Zero-Shot Generalization
Authors: Jameel Abdul Samadh, Mohammad Hanan Gani, Noor Hussein, Muhammad Uzair Khattak, Muhammad Muzammal Naseer, Fahad Shahbaz Khan, Salman H. Khan
NeurIPS 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Evaluating against the domain generalization benchmark, our method improves zero-shot top-1 accuracy beyond existing prompt-learning techniques, with a 3.08% improvement over the baseline MaPLe. In cross-dataset generalization with unseen categories across 10 datasets, our method improves consistently across all datasets compared to the existing state-of-the-art. |
| Researcher Affiliation | Academia | Jameel Hassan¹, Hanan Gani¹, Noor Hussein¹, Muhammad Uzair Khattak¹, Muzammal Naseer¹, Fahad Shahbaz Khan¹,², Salman Khan¹,³. ¹Mohamed Bin Zayed University of AI, ²Linköping University, ³Australian National University. {jameel.hassan, hanan.ghani, noor.hussein, uzair.khattak, muzammal.naseer, fahad.khan, salman.khan}@mbzuai.ac.ae |
| Pseudocode | No | The paper does not contain any structured pseudocode or algorithm blocks. |
| Open Source Code | Yes | Our source code and models are available at https://jameelhassan.github.io/promptalign/. |
| Open Datasets | Yes | For the domain generalization setting, we evaluate the four out-of-distribution (OOD) variants of ImageNet [7]: ImageNet-V2 [31], ImageNet-Sketch [38], ImageNet-A [16] and ImageNet-R [15]. We also evaluate the domain generalization setting on the recent challenging Photorealistic Unreal Graphics (PUG) dataset [3], comprising different textures, backgrounds, sizes and orientations. In the cross-dataset evaluation, we follow TPT [34] and evaluate the performance of methods on 10 image classification datasets covering a wide range of visual recognition tasks. This includes one generic-objects dataset, Caltech101 [10]; five fine-grained datasets, Oxford Pets [28], Stanford Cars [20], Flowers102 [27], Food101 [4] and FGVC-Aircraft [24], which contain images of animals, flowers and transportation; and four datasets of scenes, textures, satellite imagery and human actions: SUN397 [36], DTD [6], EuroSAT [13] and UCF101 [35], respectively. |
| Dataset Splits | Yes | We use the ImageNet validation set for this ablation and choose β = 100 across all experiments. |
| Hardware Specification | Yes | Our models were implemented on a single NVIDIA A100 40GB GPU using the PyTorch framework. |
| Software Dependencies | No | The paper mentions the "PyTorch framework" but does not specify its version number or any other software dependencies with their versions. |
| Experiment Setup | Yes | Following MaPLe [18], we train on ImageNet using 16-shot training data with 2 prompt tokens for a depth of 3 layers. We optimize the prompts on both the text and vision branches using a single test image. We obtain 63 augmented views using random resized crops and horizontal flip augmentations to construct a batch of 64 images, including the original image, to mimic the setting of TPT. From the 64 predictions, we obtain the top 10% confident predictions based on the lowest entropy and compute the average prediction probability. We compute the token distribution alignment loss between the tokens of all 64 images. We optimize the prompts to minimize the combined loss of average prediction entropy and the token distribution alignment loss using the AdamW optimizer. We use a learning rate of 5e-4 for the fine-grained datasets Flowers102, Oxford Pets, Food101, SUN397, FGVC-Aircraft, and EuroSAT and a learning rate of 0.04 for the rest of the datasets, and set the loss scale factor β equal to 100. (A code sketch of this adaptation step follows the table.) |
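
For readers who want to connect the quoted setup to code, below is a minimal PyTorch sketch of the per-image adaptation step it describes: a TPT-style augmented batch of 64 views, entropy-based selection of the top 10% most confident views, and a combined entropy-plus-alignment objective with β = 100. The `model` object, its `prompt_parameters()` accessor, the forward signature returning `(logits, vision_tokens)`, and the L1 form of `token_alignment_loss` are illustrative assumptions, not the authors' released implementation.

```python
import math

import torch
from torchvision import transforms

def avg_prediction_entropy(logits):
    """Entropy of the averaged softmax prediction over the selected views."""
    log_probs = logits.log_softmax(dim=-1)                         # [K, C]
    # Log of the mean probability, computed stably in log space.
    avg_log_probs = torch.logsumexp(log_probs, dim=0) - math.log(log_probs.size(0))
    return -(avg_log_probs.exp() * avg_log_probs).sum()

def token_alignment_loss(tokens, source_mean, source_var):
    """L1 distance between test-batch token statistics and offline source
    (ImageNet) statistics; the exact functional form is an assumption."""
    return ((tokens.mean(dim=0) - source_mean).abs().mean()
            + (tokens.var(dim=0) - source_var).abs().mean())

def adapt_on_single_image(model, image, source_mean, source_var,
                          beta=100.0, lr=5e-4):
    """One test-time adaptation step on a single preprocessed 3x224x224 image."""
    # 63 random-resized-crop / horizontal-flip views plus the original image
    # form a batch of 64, mimicking the TPT setting.
    augment = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ])
    batch = torch.stack([image] + [augment(image) for _ in range(63)])

    # Only the prompt vectors on the text and vision branches are updated
    # at test time; `prompt_parameters()` is a hypothetical accessor.
    optimizer = torch.optim.AdamW(model.prompt_parameters(), lr=lr)

    logits, vision_tokens = model(batch)   # assumed forward signature

    # Keep the top 10% most confident views (lowest prediction entropy).
    probs = logits.softmax(dim=-1)
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)
    k = max(1, int(0.1 * logits.size(0)))
    confident = entropy.topk(k, largest=False).indices

    # Combined objective: averaged-prediction entropy over the confident
    # views, plus the alignment loss over ALL 64 views scaled by beta = 100.
    loss = (avg_prediction_entropy(logits[confident])
            + beta * token_alignment_loss(vision_tokens, source_mean, source_var))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Final prediction for the original (un-augmented) image.
    with torch.no_grad():
        final_logits, _ = model(batch[:1])
    return final_logits.argmax(dim=-1)
```

In the paper, the source token statistics are computed offline from the ImageNet training data and the alignment is applied to vision-branch tokens at the prompted layers; the sketch collapses this to a single pair of mean/variance tensors for brevity.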