Code for the paper "Robustness Certificates for Sparse Adversarial Attacks by Randomized Ablation" by Alexander Levine and Soheil Feizi. Files are provided for training and evaluating classifiers that are certifiably robust to L_0 attacks on the MNIST, CIFAR-10, and ImageNet datasets.
On MNIST and CIFAR-10, there are two architectures provided, which differ in how they encode ablated (NULL) pixels: a standard (multichannel) architecture, and an architecture which encodes NULL as the mean value on the dataset. Files for the mean encoding have their names suffixed with _mean.
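The two NULL encodings can be illustrated with a minimal pure-Python sketch. It is not taken from the repository: the function names, the flattened-list image representation, and the explicit `rng` argument are all hypothetical. It also assumes that the multichannel encoding maps each retained value x to the pair (x, 1 - x) and each ablated value to (0, 0), which would match the 6-channel input mentioned for the 3-channel CIFAR and ImageNet models; check the training scripts for the exact encoding used.

```python
import random


def ablate_multichannel(pixels, keep, rng):
    """Retain `keep` randomly chosen pixels and encode the rest as NULL.

    Assumed encoding (see lead-in): a retained value x becomes (x, 1 - x),
    and an ablated (NULL) pixel becomes (0, 0), so NULL is distinguishable
    from every real pixel value.
    """
    kept = set(rng.sample(range(len(pixels)), keep))
    return [(x, 1.0 - x) if i in kept else (0.0, 0.0)
            for i, x in enumerate(pixels)]


def ablate_mean(pixels, keep, dataset_mean, rng):
    """Retain `keep` randomly chosen pixels and replace the rest with the
    dataset mean, keeping the original number of channels."""
    kept = set(rng.sample(range(len(pixels)), keep))
    return [x if i in kept else dataset_mean
            for i, x in enumerate(pixels)]


# Hypothetical flattened grayscale image with values in [0, 1].
image = [0.2] * 10
rng = random.Random(0)
two_channel = ablate_multichannel(image, keep=3, rng=rng)
mean_encoded = ablate_mean(image, keep=3, dataset_mean=0.5, rng=rng)
```

With the mean encoding, an ablated pixel cannot be told apart from a real pixel that happens to equal the mean, whereas the multichannel form keeps NULL unambiguous at the cost of doubling the number of input channels.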
Code should run with Python 3.7 and PyTorch 1.2.0.
Explanation of files (replace mnist with cifar or imagenet as appropriate):
- train_mnist.py will train the base classifier on ablated images and save the model to the checkpoints directory.
- mnist_certify.py will load a model from checkpoints, certify the robustness of images from the test set, and save the list of robustness certificates to the radii directory as a PyTorch save file (.pth).
- mnist_predict.py will load a model from checkpoints, evaluate the prediction accuracy on images from the test set, and save summary statistics to the accuracies directory as a PyTorch save file (.pth).
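As a rough illustration of what a robustness certificate means here, the sketch below follows the bound as I understand it from the paper: if k of d pixels are retained and an attacker changes at most rho pixels, the ablation touches at least one changed pixel with probability Delta = 1 - C(d - rho, k) / C(d, k), and the prediction is certified when the gap between the top class score and the runner-up exceeds 2 * Delta. The function names and the `margin` argument are hypothetical, and the actual certification scripts also involve statistical confidence bounds that are not reproduced here.

```python
def ablation_delta(d, k, rho):
    """Probability that a random choice of k retained pixels (out of d)
    includes at least one of rho attacked pixels:
    1 - C(d - rho, k) / C(d, k), computed as a running product."""
    if rho > d - k:
        return 1.0  # every choice of k pixels must hit an attacked pixel
    ratio = 1.0
    for i in range(k):
        ratio *= (d - rho - i) / (d - i)
    return 1.0 - ratio


def certified_radius(d, k, margin):
    """Largest rho with margin > 2 * Delta(rho), or -1 if there is none.

    `margin` is assumed to be a lower bound on the top class score minus
    an upper bound on the runner-up score (hypothetical input)."""
    radius = -1
    for rho in range(d - k + 1):
        if margin > 2.0 * ablation_delta(d, k, rho):
            radius = rho
        else:
            break  # Delta is non-decreasing in rho, so stop at first failure
    return radius


# MNIST has d = 28 * 28 = 784 pixels; the usage example retains k = 45.
delta_one_pixel = ablation_delta(784, 45, 1)  # equals 45 / 784
radius = certified_radius(784, 45, margin=0.12)
```

Retaining fewer pixels lowers Delta and so permits larger certified radii for a given margin, but it also leaves the base classifier with less information per ablated image.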
Example Usage: (training MNIST with 45 retained pixels)
python3 train_mnist.py --keep 45 --lr 0.01
python3 train_mnist.py --keep 45 --lr 0.001 --resume mnist_lr_0.01_keep_45_epoch_199.pth
python3 mnist_certify.py --keep 45 --model mnist_lr_0.001_keep_45_epoch_399_resume_mnist_lr_0.01_keep_45_epoch_199.pth.pth
python3 mnist_predict.py --keep 45 --model mnist_lr_0.001_keep_45_epoch_399_resume_mnist_lr_0.01_keep_45_epoch_199.pth.pth
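The checkpoint names in these commands appear to follow a fixed pattern, which would explain the doubled .pth suffix after resuming. The helper below is a hypothetical sketch that reproduces only the two filenames shown above; the pattern is inferred from those examples rather than from the training scripts, so other options may change the name in ways not covered here.

```python
def checkpoint_name(lr, keep, epoch, resume=None, dataset="mnist"):
    """Build a checkpoint filename matching the examples in this README.

    Inferred pattern (an assumption, not read from the scripts):
    <dataset>_lr_<lr>_keep_<keep>_epoch_<epoch>[_resume_<resumed file>].pth
    """
    name = "{}_lr_{}_keep_{}_epoch_{}".format(dataset, lr, keep, epoch)
    if resume is not None:
        # The resumed file's full name, including its own .pth, is embedded.
        name += "_resume_" + resume
    return name + ".pth"


# Learning rates are passed as strings so they are reproduced verbatim.
first = checkpoint_name("0.01", 45, 199)
second = checkpoint_name("0.001", 45, 399, resume=first)
```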
Caveats:
- The imagenet files expect the ILSVRC2012 training and validation sets to be in the directories imagenet-train/train and imagenet-val/val, respectively. This can be changed using the --trainpath and --valpath options.
- While train_cifar.py has the option of training a ResNet50 model, cifar_certify.py and cifar_predict.py are hardcoded to use ResNet18.
Adversarial Attack Tests: for MNIST only, there is code to attack the robust model using the Pointwise attack from FoolBox:
mnist_test_pointwise.py will load a model from checkpoints, perform the Pointwise attack on images from the MNIST test set, and save data on each attacked image as a separate PyTorch save file (.pth) in the empirical directory.
Attributions:
- Code in the pytorch_cifar directory is from https://github.com/kuangliu/pytorch-cifar, with slight modification to allow for 6-channel input.
- The file resnet_imgnt.py is modified from the PyTorch torchvision implementation of ResNet (https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py), again with slight modification to allow for 6-channel input.