203 changes: 203 additions & 0 deletions examples/attack_cifar100_gradinversion.py
@@ -0,0 +1,203 @@
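"""Gradient inversion attack on CIFAR-100, mirroring examples/attack_cifar10_gradinversion.py."""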
import os

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from torch.nn.modules.loss import CrossEntropyLoss

from gradattack.attacks.gradientinversion import GradientReconstructor
from gradattack.datamodules import CIFAR100DataModule
from gradattack.defenses.defense_utils import DefensePack
from gradattack.models import create_lightning_module
from gradattack.trainingpipeline import TrainingPipeline
from gradattack.utils import (cross_entropy_for_onehot, parse_args,
patch_image, save_fig)

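# Per-channel CIFAR-100 mean/std, reshaped below to (C, 1, 1) so they
# broadcast over (C, H, W) image tensors during (de)normalization.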
cifar100_mean = torch.tensor(
[0.50705882, 0.48666667, 0.44078431])
cifar100_std = torch.tensor(
[0.26745098, 0.25647059, 0.27607843])
dm = cifar100_mean[:, None, None]
ds = cifar100_std[:, None, None]


def setup_attack():
"""Setup the pipeline for the attack"""
args, hparams, attack_hparams = parse_args()
print(attack_hparams)

global ROOT_DIR, DEVICE, EPOCH, devices

DEVICE = torch.device(f"cuda:{args.gpuid}")
EPOCH = attack_hparams["epoch"]
devices = [args.gpuid]

pl.utilities.seed.seed_everything(1234 + EPOCH)
torch.backends.cudnn.benchmark = True

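    # Record the BatchNorm train/eval modes of attacker and defender in the
    # results-directory name.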
BN_str = ''

if not args.attacker_eval_mode:
BN_str += "-attacker_train"
if not args.defender_eval_mode:
BN_str += '-defender_train'
if args.BN_exact:
BN_str = 'BN_exact'
attack_hparams['attacker_eval_mode'] = False

datamodule = CIFAR100DataModule(batch_size=args.batch_size,
augment={
"hflip": False,
"color_jitter": None,
"rotation": -1,
"crop": False
},
num_workers=1,
seed=args.data_seed)
print("Loaded data!")
if args.defense_instahide or args.defense_mixup: # Customize loss
loss = cross_entropy_for_onehot
else:
loss = CrossEntropyLoss(reduction="mean")

if args.defense_instahide:
model = create_lightning_module(args.model,
datamodule.num_classes,
training_loss_metric=loss,
pretrained=False,
# ckpt="checkpoint/InstaHide_ckpt.ckpt",
**hparams).to(DEVICE)
elif args.defense_mixup:
model = create_lightning_module(args.model,
datamodule.num_classes,
training_loss_metric=loss,
pretrained=False,
# ckpt="checkpoint/Mixup_ckpt.ckpt",
**hparams).to(DEVICE)
else:
model = create_lightning_module(
args.model,
datamodule.num_classes,
training_loss_metric=loss,
pretrained=False,
# ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt",
**hparams).to(DEVICE)

logger = TensorBoardLogger("tb_logs", name=f"{args.logname}")
trainer = pl.Trainer(devices=-1, accelerator="auto", benchmark=True, logger=logger)
pipeline = TrainingPipeline(model, datamodule, trainer)

defense_pack = DefensePack(args, logger)
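    # "attack_mini" / "attack_large" presumably select smaller/larger subsets
    # of the training data to attack; "attack" is the default setup.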
if attack_hparams["mini"]:
datamodule.setup("attack_mini")
elif attack_hparams["large"]:
datamodule.setup("attack_large")
else:
datamodule.setup("attack")

defense_pack.apply_defense(pipeline)

ROOT_DIR = f"{args.results_dir}/CIFAR100-{args.batch_size}-{str(defense_pack)}/tv={attack_hparams['total_variation']}{BN_str}-bn={attack_hparams['bn_reg']}-dataseed={args.data_seed}/Epoch_{EPOCH}"
    os.makedirs(ROOT_DIR, exist_ok=True)
print("storing in root dir", ROOT_DIR)

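    # Save the mixing coefficients (lams) and partner indices (selects) used
    # this epoch, e.g. for evaluating the decode step in attack_decode.py.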
if "InstaHideDefense" in defense_pack.defense_params.keys():
cur_lams = defense_pack.instahide_defense.cur_lams.cpu().numpy()
cur_selects = defense_pack.instahide_defense.cur_selects.cpu().numpy()
np.savetxt(f"{ROOT_DIR}/epoch_lams.txt", cur_lams)
np.savetxt(f"{ROOT_DIR}/epoch_selects.txt", cur_selects.astype(int))
elif "MixupDefense" in defense_pack.defense_params.keys():
cur_lams = defense_pack.mixup_defense.cur_lams.cpu().numpy()
cur_selects = defense_pack.mixup_defense.cur_selects.cpu().numpy()
np.savetxt(f"{ROOT_DIR}/epoch_lams.txt", cur_lams)
np.savetxt(f"{ROOT_DIR}/epoch_selects.txt", cur_selects.astype(int))

return pipeline, attack_hparams


def run_attack(pipeline, attack_hparams):
"""Launch the real attack"""
trainloader = pipeline.datamodule.train_dataloader()
model = pipeline.model

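    # Reconstruct each training batch independently from its gradients.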
for batch_idx, (batch_inputs, batch_targets) in enumerate(trainloader):
BATCH_ROOT_DIR = ROOT_DIR + f"/{batch_idx}"
os.makedirs(BATCH_ROOT_DIR, exist_ok=True)
save_fig(batch_inputs,
f"{BATCH_ROOT_DIR}/original.png",
save_npy=True,
save_fig=False)
save_fig(patch_image(batch_inputs),
f"{BATCH_ROOT_DIR}/original.png",
save_npy=False)

batch_inputs, batch_targets = batch_inputs.to(
DEVICE), batch_targets.to(DEVICE)

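        # Compute the gradients the attacker observes for this step, after any
        # defense transforms (e.g., InstaHide encoding) have been applied.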
batch_gradients, step_results = model.get_batch_gradients(
(batch_inputs, batch_targets),
batch_idx,
eval_mode=attack_hparams["defender_eval_mode"],
apply_transforms=True,
stop_track_bn_stats=False,
BN_exact=attack_hparams["BN_exact"])
batch_inputs_transform, batch_targets_transform = step_results[
"transformed_batch"]

save_fig(
batch_inputs_transform,
f"{BATCH_ROOT_DIR}/transformed.png",
save_npy=True,
save_fig=False,
)
save_fig(
patch_image(batch_inputs_transform),
f"{BATCH_ROOT_DIR}/transformed.png",
save_npy=False,
)

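        # The reconstructor optimizes dummy inputs whose gradients match the
        # observed ones, regularized by total-variation and BN-statistics terms.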
attack = GradientReconstructor(
pipeline,
ground_truth_inputs=batch_inputs_transform,
ground_truth_gradients=batch_gradients,
ground_truth_labels=batch_targets_transform,
reconstruct_labels=attack_hparams["reconstruct_labels"],
num_iterations=10000,
signed_gradients=True,
signed_image=attack_hparams["signed_image"],
boxed=True,
total_variation=attack_hparams["total_variation"],
bn_reg=attack_hparams["bn_reg"],
lr_scheduler=True,
lr=attack_hparams["attack_lr"],
mean_std=(dm, ds),
attacker_eval_mode=attack_hparams["attacker_eval_mode"],
BN_exact=attack_hparams["BN_exact"])
tb_logger = TensorBoardLogger(BATCH_ROOT_DIR, name="tb_log")
attack_trainer = pl.Trainer(
devices=1,
accelerator="auto",
logger=tb_logger,
max_epochs=1,
benchmark=True,
        enable_checkpointing=False,
)
attack_trainer.fit(attack)
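        # best_guess is presumably the candidate with the lowest
        # gradient-matching loss found during optimization.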
result = attack.best_guess.detach().to("cpu")

save_fig(result,
f"{BATCH_ROOT_DIR}/reconstructed.png",
save_npy=True,
save_fig=False)
save_fig(patch_image(result),
f"{BATCH_ROOT_DIR}/reconstructed.png",
save_npy=False)


if __name__ == "__main__":
pipeline, attack_hparams = setup_attack()
run_attack(pipeline, attack_hparams)
13 changes: 7 additions & 6 deletions examples/attack_cifar10_gradinversion.py
@@ -1,4 +1,5 @@
import os

import numpy as np
import pytorch_lightning as pl
import torch
@@ -32,7 +33,7 @@ def setup_attack():
EPOCH = attack_hparams["epoch"]
devices = [args.gpuid]

pl.utilities.seed.seed_everything(1234 + EPOCH)
pl.utilities.seed.seed_everything(42 + EPOCH)
torch.backends.cudnn.benchmark = True

BN_str = ''
@@ -65,26 +66,26 @@ def setup_attack():
datamodule.num_classes,
training_loss_metric=loss,
pretrained=False,
ckpt="checkpoint/InstaHide_ckpt.ckpt",
# ckpt="checkpoint/InstaHide_ckpt.ckpt",
**hparams).to(DEVICE)
elif args.defense_mixup:
model = create_lightning_module("ResNet18",
datamodule.num_classes,
training_loss_metric=loss,
pretrained=False,
ckpt="checkpoint/Mixup_ckpt.ckpt",
# ckpt="checkpoint/Mixup_ckpt.ckpt",
**hparams).to(DEVICE)
else:
model = create_lightning_module(
"ResNet18",
datamodule.num_classes,
training_loss_metric=loss,
pretrained=False,
ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt",
# ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt",
**hparams).to(DEVICE)

logger = TensorBoardLogger("tb_logs", name=f"{args.logname}")
trainer = pl.Trainer(gpus=devices, benchmark=True, logger=logger)
trainer = pl.Trainer(devices=-1, accelerator="auto", benchmark=True, logger=logger)
pipeline = TrainingPipeline(model, datamodule, trainer)

defense_pack = DefensePack(args, logger)
@@ -179,7 +180,7 @@ def run_attack(pipeline, attack_hparams):

tb_logger = TensorBoardLogger(BATCH_ROOT_DIR, name="tb_log")
attack_trainer = pl.Trainer(
gpus=devices,
devices=1, accelerator="auto",
logger=tb_logger,
max_epochs=1,
benchmark=True,
3 changes: 2 additions & 1 deletion examples/attack_decode.py
@@ -1,5 +1,6 @@
"""
The decode step for InstaHide and MixUp is based on Carlini et al.'s attack: Is Private Learning Possible with Instance Encoding? (https://arxiv.org/pdf/2011.05315.pdf)
The decode step for InstaHide and MixUp is based on Carlini et al.'s attack:
Is Private Learning Possible with Instance Encoding? (https://arxiv.org/pdf/2011.05315.pdf)
The implementation heavily relies on their code: https://github.com/carlini/privacy/commit/28b8a80924cf3766ab3230b5976388139ddef295
"""

98 changes: 98 additions & 0 deletions examples/train_cifar.py
@@ -0,0 +1,98 @@
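"""Train a CIFAR-10/100 model, optionally with a defense (Mixup, InstaHide, or gradient pruning) applied via DefensePack."""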
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from gradattack.datamodules import CIFAR100DataModule, CIFAR10DataModule
from gradattack.defenses.defense_utils import DefensePack
from gradattack.models import create_lightning_module
from gradattack.trainingpipeline import TrainingPipeline
from gradattack.utils import cross_entropy_for_onehot, parse_args, parse_augmentation

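# Allow TF32 matmuls on supported GPUs: slightly lower precision, higher throughput.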
torch.set_float32_matmul_precision("high")
cifar_dm = {
"CIFAR10": CIFAR10DataModule,
"CIFAR100": CIFAR100DataModule
}

if __name__ == "__main__":
args, hparams, _ = parse_args()
method = ""
if args.defense_mixup:
method += 'mixup_'
elif args.defense_instahide:
method += 'instahide_'
elif args.defense_gradprune:
method += 'gradprune_'
else:
        method += 'vanilla_'
exp_name = f"{args.data}/{method}{args.scheduler}"
logger = WandbLogger(
project='GradAttack',
name=exp_name,
)

early_stop_callback = EarlyStopping(
monitor="val/loss_epoch",
min_delta=0.00,
patience=50,
verbose=False,
mode="min",
)
callback = ModelCheckpoint(
        dirpath=exp_name, save_last=True, save_top_k=3, monitor="val/acc", mode="max",
)

augment = parse_augmentation(args)

datamodule = cifar_dm[args.data](
augment=augment,
batch_size=args.batch_size,
tune_on_val=args.tune_on_val,
batch_sampler=None,
)

if args.defense_instahide or args.defense_mixup:
loss = cross_entropy_for_onehot
else:
loss = torch.nn.CrossEntropyLoss()

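    # Multi-head models always use standard cross-entropy, overriding the
    # one-hot loss a defense may have selected above.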
if "multihead" in args.model:
multi_head = True
loss = torch.nn.CrossEntropyLoss()
else:
multi_head = False

model = create_lightning_module(
args.model,
datamodule.num_classes,
pretrained=args.pretrained,
ckpt=args.ckpt,
freeze_extractor=args.freeze_extractor,
training_loss_metric=loss,
log_auc=args.log_auc,
multi_class=datamodule.multi_class,
multi_head=multi_head,
log_dir=exp_name,
**hparams,
)

trainer = pl.Trainer(
default_root_dir=exp_name,
devices=1,
check_val_every_n_epoch=3,
accelerator='auto',
benchmark=True,
logger=logger,
num_sanity_val_steps=0,
max_epochs=args.n_epoch,
callbacks=[early_stop_callback, callback],
accumulate_grad_batches=args.n_accumulation_steps,
)
pipeline = TrainingPipeline(model, datamodule, trainer)

defense_pack = DefensePack(args)
defense_pack.apply_defense(pipeline)

pipeline.run()
pipeline.test()