
Robust-Prompt

Visual prompting has recently gained traction for its ability to transfer what a model has learned on one distribution to an out-of-distribution dataset. Moreover, one of the best-performing models on ImageNet-C, a benchmark for testing the robustness of image classifiers, relies on augmentation techniques such as DeepAugment, which distorts images through an image-to-image neural network to perform data augmentation. In this work, we present a novel (to the best of our knowledge) technique that combines visual prompting with DeepAugment to improve the robustness of large vision+language models, specifically CLIP. Beyond DeepAugment, we also compose visual prompting with techniques such as Cutout and CutMix, and find that CutMix performs especially well on CIFAR-C, which serves as a measure of model robustness to perturbations.
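
To make the overview concrete, here is a minimal sketch of a visual prompt in the style of Bahng et al. (cited below): a frame of learnable pixels added around each image while the backbone model stays frozen. The class name PadPrompter and its arguments are illustrative, not this repo's actual API.

import torch
import torch.nn as nn

class PadPrompter(nn.Module):
    """A learnable pixel frame of width pad_size added to every input image."""
    def __init__(self, pad_size=30, image_size=224):
        super().__init__()
        inner = image_size - 2 * pad_size
        self.inner = inner
        # Learnable strips around the border; the interior stays untouched.
        self.pad_up = nn.Parameter(torch.randn(1, 3, pad_size, image_size))
        self.pad_down = nn.Parameter(torch.randn(1, 3, pad_size, image_size))
        self.pad_left = nn.Parameter(torch.randn(1, 3, inner, pad_size))
        self.pad_right = nn.Parameter(torch.randn(1, 3, inner, pad_size))

    def forward(self, x):
        # Assemble the full-size prompt, then add it to the (batched) images.
        middle = torch.cat([self.pad_left,
                            torch.zeros(1, 3, self.inner, self.inner, device=x.device),
                            self.pad_right], dim=3)
        prompt = torch.cat([self.pad_up, middle, self.pad_down], dim=2)
        return x + prompt  # broadcasts over the batch dimension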


Installation

To set up the code, clone the repo and install the requirements with pip install -r requirements.txt. To download the models, run the download_models.sh script.

Structure

  • /DeepAugment contains scripts for generating deep-augmented samples for ImageNet and CIFAR-100.
  • /Evaluation contains scripts for evaluating models on the CIFAR-C dataset.
  • /ImageNetLoader contains scripts for loading and manipulating ImageNet data.
  • /model_checkpoints_version_1 contains checkpoints for our saved models.
  • /VisualPrompting contains scripts for training visual prompts.

Data augmentation

DeepAugment:

CUDA_VISIBLE_DEVICES=0 python3 DeepAugment/CAE_distort_cifar.py --total-workers=5 --worker-number=0
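
The --total-workers and --worker-number flags presumably shard image generation across parallel jobs. Conceptually, DeepAugment passes each image through an image-to-image network (here a generic autoencoder) whose weights are randomly perturbed, producing diverse but semantically faithful distortions. The sketch below illustrates the idea only; it is not the exact procedure in CAE_distort_cifar.py, and the noise scheme is an assumption.

import copy
import torch
import torch.nn as nn

def deepaugment_batch(autoencoder: nn.Module, images: torch.Tensor,
                      noise_scale: float = 0.1) -> torch.Tensor:
    # Perturb a throwaway copy so the original weights are preserved.
    distorted = copy.deepcopy(autoencoder)
    with torch.no_grad():
        for p in distorted.parameters():
            # Multiplicative weight noise; DeepAugment also uses other
            # random weight operations (e.g. zeroing or negating entries).
            p.mul_(1.0 + noise_scale * torch.randn_like(p))
        return distorted(images)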

CutMix:

python VisualPrompting/main_clip_cutmix.py 
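
For reference, CutMix pastes a randomly sized rectangle from a shuffled copy of the batch into each image and mixes the labels in proportion to the pasted area. A minimal self-contained sketch (not the repo's exact implementation):

import numpy as np
import torch

def cutmix(images, labels, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0))
    _, _, h, w = images.shape
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lambda from the actual box area (clipping can shrink it).
    lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    return images, labels, labels[perm], lam

The training loss is then lam * CE(logits, labels) + (1 - lam) * CE(logits, labels[perm]).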

Cutout:

To create the augmented images using cutout, run the following command:

python cutout/create_and_save_data.py --dataset cifar100 --data_augmentation --cutout --length 16 --batch_size 1
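
Cutout itself just zeroes a random square patch in each image; the side length of 16 matches the --length 16 flag above. A minimal sketch, independent of the script in cutout/:

import numpy as np
import torch

def cutout(image: torch.Tensor, length: int = 16) -> torch.Tensor:
    # image: (C, H, W); zero a length x length square at a random center.
    _, h, w = image.shape
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip([cy - length // 2, cy + length // 2], 0, h)
    x1, x2 = np.clip([cx - length // 2, cx + length // 2], 0, w)
    image[:, y1:y2, x1:x2] = 0.0
    return image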

Visual Prompt Tuning

Once the dataset has been created by concatenating the augmented data with the original CIFAR-100 data, we run the prompt-tuning stage, which trains the network to learn the prompt parameters. To tune the prompt for the CLIP model, run the following command:

python3 VisualPrompting/main_clip.py --dataset cifar100 --root ./data --train_folder [TRAIN_FOLDER_PATH_WITH_AUGMENTATIONS] --val_folder [VAL_FOLDER_PATHS_WITH_AUGMENTATIONS] --num_workers 1 --batch_size 64
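
For orientation, the tuning loop works roughly as follows: CLIP stays frozen, every image passes through the learnable prompt, and classification logits come from similarity against CLIP's text embeddings of the class names. This sketch reuses the illustrative PadPrompter from the overview and assumes hypothetical class_names and train_loader objects (the loader yielding CLIP-preprocessed 224x224 batches); the hyperparameters are placeholders, not the repo's settings.

import clip
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
model.float()  # fp32 so gradients flow cleanly through the frozen backbone
model.eval()   # CLIP weights are never updated

prompter = PadPrompter(pad_size=30, image_size=224).to(device)
optimizer = torch.optim.SGD(prompter.parameters(), lr=0.1, momentum=0.9)

# Pre-encode one text embedding per CIFAR-100 class name.
text = clip.tokenize([f"a photo of a {c}" for c in class_names]).to(device)
with torch.no_grad():
    text_features = F.normalize(model.encode_text(text), dim=-1)

for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    # Only the prompt receives gradient updates.
    image_features = F.normalize(model.encode_image(prompter(images)), dim=-1)
    logits = model.logit_scale.exp() * image_features @ text_features.t()
    loss = F.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()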

Acknowledgments

We would like to acknowledge the following works:

  • Hyojin Bahng, Ali Jahanian, Swami Sankaranarayanan, and Phillip Isola. Exploring visual prompts for adapting large-scale models, 2022. URL https://arxiv.org/abs/2203.17274.
  • Dan Hendrycks and Thomas Dietterich. Benchmarking neural network robustness to common corruptions and perturbations, 2019. URL https://arxiv.org/abs/1903.12261.

