diff --git a/examples/attack_cifar100_gradinversion.py b/examples/attack_cifar100_gradinversion.py
new file mode 100644
index 0000000..7d7c37b
--- /dev/null
+++ b/examples/attack_cifar100_gradinversion.py
@@ -0,0 +1,203 @@
+import os
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from torch.nn.modules.loss import CrossEntropyLoss
+
+from gradattack.attacks.gradientinversion import GradientReconstructor
+from gradattack.datamodules import CIFAR100DataModule
+from gradattack.defenses.defense_utils import DefensePack
+from gradattack.models import create_lightning_module
+from gradattack.trainingpipeline import TrainingPipeline
+from gradattack.utils import (cross_entropy_for_onehot, parse_args,
+                              patch_image, save_fig)
+
+cifar100_mean = torch.tensor(
+    [0.50705882, 0.48666667, 0.44078431])
+cifar100_std = torch.tensor(
+    [0.26745098, 0.25647059, 0.27607843])
+dm = cifar100_mean[:, None, None]
+ds = cifar100_std[:, None, None]
+
+
+def setup_attack():
+    """Setup the pipeline for the attack"""
+    args, hparams, attack_hparams = parse_args()
+    print(attack_hparams)
+
+    global ROOT_DIR, DEVICE, EPOCH, devices
+
+    DEVICE = torch.device(f"cuda:{args.gpuid}")
+    EPOCH = attack_hparams["epoch"]
+    devices = [args.gpuid]
+
+    pl.utilities.seed.seed_everything(1234 + EPOCH)
+    torch.backends.cudnn.benchmark = True
+
+    BN_str = ''
+
+    if not args.attacker_eval_mode:
+        BN_str += "-attacker_train"
+    if not args.defender_eval_mode:
+        BN_str += '-defender_train'
+    if args.BN_exact:
+        BN_str = 'BN_exact'
+        attack_hparams['attacker_eval_mode'] = False
+
+    datamodule = CIFAR100DataModule(batch_size=args.batch_size,
+                                    augment={
+                                        "hflip": False,
+                                        "color_jitter": None,
+                                        "rotation": -1,
+                                        "crop": False
+                                    },
+                                    num_workers=1,
+                                    seed=args.data_seed)
+    print("Loaded data!")
+    if args.defense_instahide or args.defense_mixup:  # Customize loss
+        loss = cross_entropy_for_onehot
+    else:
+        loss = CrossEntropyLoss(reduction="mean")
+
+    if args.defense_instahide:
+        model = create_lightning_module(args.model,
+                                        datamodule.num_classes,
+                                        training_loss_metric=loss,
+                                        pretrained=False,
+                                        # ckpt="checkpoint/InstaHide_ckpt.ckpt",
+                                        **hparams).to(DEVICE)
+    elif args.defense_mixup:
+        model = create_lightning_module(args.model,
+                                        datamodule.num_classes,
+                                        training_loss_metric=loss,
+                                        pretrained=False,
+                                        # ckpt="checkpoint/Mixup_ckpt.ckpt",
+                                        **hparams).to(DEVICE)
+    else:
+        model = create_lightning_module(
+            args.model,
+            datamodule.num_classes,
+            training_loss_metric=loss,
+            pretrained=False,
+            # ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt",
+            **hparams).to(DEVICE)
+
+    logger = TensorBoardLogger("tb_logs", name=f"{args.logname}")
+    trainer = pl.Trainer(devices=-1, accelerator="auto", benchmark=True, logger=logger)
+    pipeline = TrainingPipeline(model, datamodule, trainer)
+
+    defense_pack = DefensePack(args)
+    if attack_hparams["mini"]:
+        datamodule.setup("attack_mini")
+    elif attack_hparams["large"]:
+        datamodule.setup("attack_large")
+    else:
+        datamodule.setup("attack")
+
+    defense_pack.apply_defense(pipeline)
+
+    ROOT_DIR = f"{args.results_dir}/CIFAR100-{args.batch_size}-{str(defense_pack)}/tv={attack_hparams['total_variation']}{BN_str}-bn={attack_hparams['bn_reg']}-dataseed={args.data_seed}/Epoch_{EPOCH}"
+    try:
+        os.makedirs(ROOT_DIR, exist_ok=True)
+    except FileExistsError:
+        pass
+    print("storing in root dir", ROOT_DIR)
+
+    if "InstaHideDefense" in defense_pack.defense_params.keys():
+        cur_lams = defense_pack.instahide_defense.cur_lams.cpu().numpy()
+        cur_selects = defense_pack.instahide_defense.cur_selects.cpu().numpy()
+        np.savetxt(f"{ROOT_DIR}/epoch_lams.txt", cur_lams)
+        np.savetxt(f"{ROOT_DIR}/epoch_selects.txt", cur_selects.astype(int))
+    elif "MixupDefense" in defense_pack.defense_params.keys():
+        cur_lams = defense_pack.mixup_defense.cur_lams.cpu().numpy()
+        cur_selects = defense_pack.mixup_defense.cur_selects.cpu().numpy()
+        np.savetxt(f"{ROOT_DIR}/epoch_lams.txt", cur_lams)
+        np.savetxt(f"{ROOT_DIR}/epoch_selects.txt", cur_selects.astype(int))
+
+    return pipeline, attack_hparams
+
+
+def run_attack(pipeline, attack_hparams):
+    """Launch the real attack"""
+    trainloader = pipeline.datamodule.train_dataloader()
+    model = pipeline.model
+
+    for batch_idx, (batch_inputs, batch_targets) in enumerate(trainloader):
+        BATCH_ROOT_DIR = ROOT_DIR + f"/{batch_idx}"
+        os.makedirs(BATCH_ROOT_DIR, exist_ok=True)
+        save_fig(batch_inputs,
+                 f"{BATCH_ROOT_DIR}/original.png",
+                 save_npy=True,
+                 save_fig=False)
+        save_fig(patch_image(batch_inputs),
+                 f"{BATCH_ROOT_DIR}/original.png",
+                 save_npy=False)
+
+        batch_inputs, batch_targets = batch_inputs.to(
+            DEVICE), batch_targets.to(DEVICE)
+
+        batch_gradients, step_results = model.get_batch_gradients(
+            (batch_inputs, batch_targets),
+            batch_idx,
+            eval_mode=attack_hparams["defender_eval_mode"],
+            apply_transforms=True,
+            stop_track_bn_stats=False,
+            BN_exact=attack_hparams["BN_exact"])
+        batch_inputs_transform, batch_targets_transform = step_results[
+            "transformed_batch"]
+
+        save_fig(
+            batch_inputs_transform,
+            f"{BATCH_ROOT_DIR}/transformed.png",
+            save_npy=True,
+            save_fig=False,
+        )
+        save_fig(
+            patch_image(batch_inputs_transform),
+            f"{BATCH_ROOT_DIR}/transformed.png",
+            save_npy=False,
+        )
+
+        attack = GradientReconstructor(
+            pipeline,
+            ground_truth_inputs=batch_inputs_transform,
+            ground_truth_gradients=batch_gradients,
+            ground_truth_labels=batch_targets_transform,
+            reconstruct_labels=attack_hparams["reconstruct_labels"],
+            num_iterations=10000,
+            signed_gradients=True,
+            signed_image=attack_hparams["signed_image"],
+            boxed=True,
+            total_variation=attack_hparams["total_variation"],
+            bn_reg=attack_hparams["bn_reg"],
+            lr_scheduler=True,
+            lr=attack_hparams["attack_lr"],
+            mean_std=(dm, ds),
+            attacker_eval_mode=attack_hparams["attacker_eval_mode"],
+            BN_exact=attack_hparams["BN_exact"])
+        tb_logger = TensorBoardLogger(BATCH_ROOT_DIR, name="tb_log")
+        attack_trainer = pl.Trainer(
+            devices=1,
+            accelerator="auto",
+            logger=tb_logger,
+            max_epochs=1,
+            benchmark=True,
+            checkpoint_callback=False,
+        )
+        attack_trainer.fit(attack)
+        result = attack.best_guess.detach().to("cpu")
+
+        save_fig(result,
+                 f"{BATCH_ROOT_DIR}/reconstructed.png",
+                 save_npy=True,
+                 save_fig=False)
+        save_fig(patch_image(result),
+                 f"{BATCH_ROOT_DIR}/reconstructed.png",
+                 save_npy=False)
+
+
+if __name__ == "__main__":
+    pipeline, attack_hparams = setup_attack()
+    run_attack(pipeline, attack_hparams)
diff --git a/examples/attack_cifar10_gradinversion.py b/examples/attack_cifar10_gradinversion.py
index 590fc19..7b623dc 100644
--- a/examples/attack_cifar10_gradinversion.py
+++ b/examples/attack_cifar10_gradinversion.py
@@ -1,4 +1,5 @@
 import os
+
 import numpy as np
 import pytorch_lightning as pl
 import torch
@@ -32,7 +33,7 @@ def setup_attack():
     EPOCH = attack_hparams["epoch"]
     devices = [args.gpuid]

-    pl.utilities.seed.seed_everything(1234 + EPOCH)
+    pl.utilities.seed.seed_everything(42 + EPOCH)
     torch.backends.cudnn.benchmark = True

     BN_str = ''
@@ -65,14 +66,14
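The bn_reg term that the new CIFAR-100 example passes to GradientReconstructor weights a BatchNorm-statistics prior: the attack penalizes the distance between the batch statistics produced by the current reconstruction and the statistics stored in the defender's BatchNorm layers. The snippet below is a minimal sketch of that idea only, using a torchvision ResNet-18 as a stand-in model; it is not the repository's BNFeatureHook implementation verbatim.

import torch
import torch.nn as nn
import torchvision

class BNStatsHook:
    """Forward hook that measures how far a layer's batch statistics
    are from its running statistics (the DeepInversion-style BN prior)."""

    def __init__(self, module: nn.BatchNorm2d):
        self.r_feature = torch.tensor(0.0)
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inputs, output):
        x = inputs[0]
        mean = x.mean(dim=[0, 2, 3])
        var = x.var(dim=[0, 2, 3], unbiased=False)
        # Distance between the reconstruction's batch statistics and the
        # layer's running statistics.
        self.r_feature = (torch.norm(module.running_mean - mean, 2) +
                          torch.norm(module.running_var - var, 2))

    def close(self):
        self.hook.remove()

# Illustrative stand-in for the defender's model.
model = torchvision.models.resnet18(num_classes=100).eval()
hooks = [BNStatsHook(m) for m in model.modules()
         if isinstance(m, nn.BatchNorm2d)]

dummy = torch.randn(16, 3, 32, 32, requires_grad=True)  # stand-in reconstruction
model(dummy)
bn_prior = sum(h.r_feature for h in hooks)  # the attack adds bn_reg * bn_prior to its loss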
@@ def setup_attack(): datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/InstaHide_ckpt.ckpt", + # ckpt="checkpoint/InstaHide_ckpt.ckpt", **hparams).to(DEVICE) elif args.defense_mixup: model = create_lightning_module("ResNet18", datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/Mixup_ckpt.ckpt", + # ckpt="checkpoint/Mixup_ckpt.ckpt", **hparams).to(DEVICE) else: model = create_lightning_module( @@ -80,11 +81,11 @@ def setup_attack(): datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt", + # ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt", **hparams).to(DEVICE) logger = TensorBoardLogger("tb_logs", name=f"{args.logname}") - trainer = pl.Trainer(gpus=devices, benchmark=True, logger=logger) + trainer = pl.Trainer(devices=-1, accelerator="auto", benchmark=True, logger=logger) pipeline = TrainingPipeline(model, datamodule, trainer) defense_pack = DefensePack(args, logger) @@ -179,7 +180,7 @@ def run_attack(pipeline, attack_hparams): tb_logger = TensorBoardLogger(BATCH_ROOT_DIR, name="tb_log") attack_trainer = pl.Trainer( - gpus=devices, + devices=1, accelerator="auto", logger=tb_logger, max_epochs=1, benchmark=True, diff --git a/examples/attack_decode.py b/examples/attack_decode.py index d750f5f..cc79d32 100644 --- a/examples/attack_decode.py +++ b/examples/attack_decode.py @@ -1,5 +1,6 @@ """ -The decode step for InstaHide and MixUp is based on Carlini et al.'s attack: Is Private Learning Possible with Instance Encoding? (https://arxiv.org/pdf/2011.05315.pdf) +The decode step for InstaHide and MixUp is based on Carlini et al.'s attack: +Is Private Learning Possible with Instance Encoding? (https://arxiv.org/pdf/2011.05315.pdf) The implementation heavily relies on their code: https://github.com/carlini/privacy/commit/28b8a80924cf3766ab3230b5976388139ddef295 """ diff --git a/examples/train_cifar.py b/examples/train_cifar.py new file mode 100644 index 0000000..f5113f3 --- /dev/null +++ b/examples/train_cifar.py @@ -0,0 +1,98 @@ +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers import WandbLogger + +from gradattack.datamodules import CIFAR100DataModule, CIFAR10DataModule +from gradattack.defenses.defense_utils import DefensePack +from gradattack.models import create_lightning_module +from gradattack.trainingpipeline import TrainingPipeline +from gradattack.utils import cross_entropy_for_onehot, parse_args, parse_augmentation + +torch.set_float32_matmul_precision("high") +cifar_dm = { + "CIFAR10": CIFAR10DataModule, + "CIFAR100": CIFAR100DataModule +} + +if __name__ == "__main__": + args, hparams, _ = parse_args() + method = "" + if args.defense_mixup: + method += 'mixup_' + elif args.defense_instahide: + method += 'instahide_' + elif args.defense_gradprune: + method += 'gradprune_' + else: + method += 'vanilla' + exp_name = f"{args.data}/{method}{args.scheduler}" + logger = WandbLogger( + project='GradAttack', + name=exp_name, + ) + + early_stop_callback = EarlyStopping( + monitor="val/loss_epoch", + min_delta=0.00, + patience=50, + verbose=False, + mode="min", + ) + callback = ModelCheckpoint( + exp_name, save_last=True, save_top_k=3, monitor="val/acc", mode="max", + ) + + augment = parse_augmentation(args) + + datamodule = cifar_dm[args.data]( + augment=augment, + batch_size=args.batch_size, + tune_on_val=args.tune_on_val, + batch_sampler=None, + ) + + if 
args.defense_instahide or args.defense_mixup: + loss = cross_entropy_for_onehot + else: + loss = torch.nn.CrossEntropyLoss() + + if "multihead" in args.model: + multi_head = True + loss = torch.nn.CrossEntropyLoss() + else: + multi_head = False + + model = create_lightning_module( + args.model, + datamodule.num_classes, + pretrained=args.pretrained, + ckpt=args.ckpt, + freeze_extractor=args.freeze_extractor, + training_loss_metric=loss, + log_auc=args.log_auc, + multi_class=datamodule.multi_class, + multi_head=multi_head, + log_dir=exp_name, + **hparams, + ) + + trainer = pl.Trainer( + default_root_dir=exp_name, + devices=1, + check_val_every_n_epoch=3, + accelerator='auto', + benchmark=True, + logger=logger, + num_sanity_val_steps=0, + max_epochs=args.n_epoch, + callbacks=[early_stop_callback, callback], + accumulate_grad_batches=args.n_accumulation_steps, + ) + pipeline = TrainingPipeline(model, datamodule, trainer) + + defense_pack = DefensePack(args) + defense_pack.apply_defense(pipeline) + + pipeline.run() + pipeline.test() diff --git a/examples/train_cifar10.py b/examples/train_cifar10.py deleted file mode 100644 index ec60701..0000000 --- a/examples/train_cifar10.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytorch_lightning as pl -import torch -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.loggers import TensorBoardLogger - -from gradattack.datamodules import CIFAR10DataModule -from gradattack.defenses.defense_utils import DefensePack -from gradattack.models import create_lightning_module -from gradattack.trainingpipeline import TrainingPipeline -from gradattack.utils import cross_entropy_for_onehot, parse_args, parse_augmentation - -if __name__ == "__main__": - - args, hparams, _ = parse_args() - - logger = TensorBoardLogger( - "tb_logs", name=f"{args.logname}/{args.optimizer}/{args.scheduler}") - devices = [args.gpuid] - - if args.early_stopping: - early_stop_callback = EarlyStopping( - monitor="epoch/val_loss", - min_delta=0.00, - patience=args.patience, - verbose=False, - mode="min", - ) - - augment = parse_augmentation(args) - - assert args.data == "CIFAR10" - - datamodule = CIFAR10DataModule( - augment=augment, - batch_size=args.batch_size, - tune_on_val=args.tune_on_val, - batch_sampler=None, - ) - - if args.defense_instahide or args.defense_mixup: - loss = cross_entropy_for_onehot - else: - loss = torch.nn.CrossEntropyLoss(reduction="mean") - - if "multihead" in args.model: - multi_head = True - loss = torch.nn.CrossEntropyLoss(reduction="mean") - else: - multi_head = False - - model = create_lightning_module( - args.model, - datamodule.num_classes, - pretrained=args.pretrained, - ckpt=args.ckpt, - freeze_extractor=args.freeze_extractor, - training_loss_metric=loss, - log_auc=args.log_auc, - multi_class=datamodule.multi_class, - multi_head=multi_head, - **hparams, - ) - - trainer = pl.Trainer( - gpus=devices, - check_val_every_n_epoch=1, - logger=logger, - max_epochs=args.n_epoch, - callbacks=[early_stop_callback], - accumulate_grad_batches=args.n_accumulation_steps, - ) - pipeline = TrainingPipeline(model, datamodule, trainer) - - defense_pack = DefensePack(args, logger) - defense_pack.apply_defense(pipeline) - - pipeline.run() - pipeline.test() \ No newline at end of file diff --git a/gradattack/attacks/attack.py b/gradattack/attacks/attack.py index 7749c00..ff3ab58 100644 --- a/gradattack/attacks/attack.py +++ b/gradattack/attacks/attack.py @@ -2,23 +2,22 @@ from typing import Callable import torch -from numpy.lib.npyio import 
load -from gradattack.trainingpipeline import TrainingPipeline -from .invertinggradients.inversefed.reconstruction_algorithms import ( - GradientReconstructor, ) +from gradattack.trainingpipeline import TrainingPipeline +from .gradientinversion import GradientReconstructor class GradientInversionAttack: """Wrapper around Gradient Inversion attack""" + def __init__( - self, - pipeline: TrainingPipeline, - dm, - ds, - device: torch.device, - loss_metric: Callable, - reconstructor_args: dict = None, + self, + pipeline: TrainingPipeline, + dm, + ds, + device: torch.device, + loss_metric: Callable, + reconstructor_args: dict = None, ): self.pipeline = pipeline self.device = device @@ -58,5 +57,5 @@ def run_from_dump(self, filepath: str, dataset: torch.utils.data.Dataset): self.model.load_state_dict(loaded_data["model_state_dict"]) batch_inputs, batch_targets = ( dataset[i][0] for i in loaded_data["batch_indices"]), ( - dataset[i][1] for i in loaded_data["batch_indices"]) + dataset[i][1] for i in loaded_data["batch_indices"]) return self.run_attack_batch(batch_inputs, batch_targets) diff --git a/gradattack/attacks/gradientinversion.py b/gradattack/attacks/gradientinversion.py index 444dedc..d60747e 100644 --- a/gradattack/attacks/gradientinversion.py +++ b/gradattack/attacks/gradientinversion.py @@ -1,34 +1,18 @@ import copy -from typing import Any, Callable, Optional +from typing import Any, Callable -import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F -import torchmetrics +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset +from torchmetrics.functional import peak_signal_noise_ratio + from gradattack.metrics.gradients import CosineSimilarity, L2Diff from gradattack.metrics.pixelwise import MeanPixelwiseError -from gradattack.models import LightningWrapper from gradattack.trainingpipeline import TrainingPipeline from gradattack.utils import patch_image -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset - -# DEFAULT_HPARAMS = { -# "optimizer": "Adam", -# "lr_scheduler": True, -# "lr": 0.1, -# "total_variation": 1e-1, -# "l2": 0, -# "bn_reg": 0, -# "first_bn_multiplier": 10, -# # If true, will apply image priors on the absolute value of the recovered images -# "signed_image": False, -# "signed_gradients": True, -# "boxed": True, -# "attacker_eval_mode": True, -# "recipe": 'Geiping' -# } class BNFeatureHook: @@ -36,6 +20,7 @@ class BNFeatureHook: Implementation of the forward hook to track feature statistics and compute a loss on them. 
Will compute mean and variance, and will use l2 as a loss """ + def __init__(self, module): self.hook = module.register_forward_hook(self.hook_fn) @@ -90,32 +75,30 @@ def __getitem__(self, idx: int): class GradientReconstructor(pl.LightningModule): - def __init__( - self, - pipeline: TrainingPipeline, - ground_truth_inputs: tuple, - ground_truth_gradients: tuple, - ground_truth_labels: tuple, - intial_reconstruction: torch.tensor = None, - reconstruct_labels=False, - attack_loss_metric: Callable = CosineSimilarity(), - mean_std: tuple = (0.0, 1.0), - num_iterations=10000, - optimizer: str = "Adam", - lr_scheduler: bool = True, - lr: float = 0.1, - total_variation: float = 1e-1, - l2: float = 0, - bn_reg: float = 0, - first_bn_multiplier: float = 1, - signed_image: bool = False, - signed_gradients: bool = True, - boxed: bool = True, - attacker_eval_mode: bool = True, - recipe: str = 'Geiping', - BN_exact: bool = False, - grayscale: bool = False, - ): + def __init__(self, + pipeline: TrainingPipeline, + ground_truth_inputs: tuple, + ground_truth_gradients: tuple, + ground_truth_labels: tuple, + intial_reconstruction: torch.tensor = None, + reconstruct_labels=False, + attack_loss_metric: Callable = CosineSimilarity(), + mean_std: tuple = (0.0, 1.0), + num_iterations=10000, + optimizer: str = "Adam", + lr_scheduler: bool = True, + lr: float = 0.1, + total_variation: float = 1e-1, + l2: float = 0, + bn_reg: float = 0, + first_bn_multiplier: float = 1, + signed_image: bool = False, + signed_gradients: bool = True, + boxed: bool = True, + attacker_eval_mode: bool = True, + recipe: str = 'Geiping', + BN_exact: bool = False, + grayscale: bool = False): super().__init__() self.save_hyperparameters("optimizer", "lr_scheduler", "lr", "total_variation", "l2", "bn_reg", @@ -177,23 +160,15 @@ def __init__( class loss_fn(torch.nn.Module): def __call__(self, pred, labels): if len(labels.shape) >= 2: - labels = torch.nn.functional.softmax(labels, dim=-1) - return torch.mean( - torch.sum( - -labels * - torch.nn.functional.log_softmax(pred, dim=-1), - 1, - )) + labels = F.softmax(labels, dim=-1) + return torch.mean(torch.sum(-labels * F.log_softmax(pred, dim=-1), 1)) else: - return torch.nn.functional.cross_entropy(pred, labels) + return F.cross_entropy(pred, labels) - self._model._training_loss_metric = None self._model._training_loss_metric = loss_fn() else: self._reconstruct_labels = False - self._batch_transferred = False - self.loss_r_feature_layers = [] for module in self._model.modules(): @@ -272,17 +247,6 @@ def train_dataloader(self) -> Any: return DataLoader( DummyGradientDataset(num_values=self.num_iterations, ), ) - def transfer_batch_to_device(self, batch: Any, - device: Optional[torch.device]) -> Any: - if not self._batch_transferred: - self.ground_truth_labels = self.ground_truth_labels.detach().to( - self.device) - self.ground_truth_gradients = tuple( - x.detach().to(self.device) - for x in self.ground_truth_gradients) - self._batch_transferred = True - return (self.ground_truth_gradients, self.ground_truth_labels) - def training_step(self, batch, *args): input_gradients, labels = batch @@ -302,25 +266,19 @@ def _closure(): attacker=True, ) if self.recipe == 'Geiping': - reconstruction_loss = self._attack_loss_metric( + reconst_loss = self._attack_loss_metric( recovered_gradients, input_gradients) - reconstruction_loss += self.hparams[ - "total_variation"] * total_variation( - self.best_guess, self.hparams["signed_image"]) - reconstruction_loss += self.hparams["l2"] * l2_norm( + reconst_loss 
+= self.hparams["total_variation"] * total_variation( + self.best_guess, self.hparams["signed_image"]) + reconst_loss += self.hparams["l2"] * l2_norm( self.best_guess, self.hparams["signed_image"]) elif self.recipe == 'Zhu': ## TODO: test self._attack_loss_metric = L2Diff() - reconstruction_loss = self._attack_loss_metric( + reconst_loss = self._attack_loss_metric( recovered_gradients, input_gradients) - recon_mean = [ - mod.mean - for (idx, mod) in enumerate(self.loss_r_feature_layers) - ] - recon_var = [ - mod.var for (idx, mod) in enumerate(self.loss_r_feature_layers) - ] + recon_mean = [mod.mean for (idx, mod) in enumerate(self.loss_r_feature_layers)] + recon_var = [mod.var for (idx, mod) in enumerate(self.loss_r_feature_layers)] if self.hparams["bn_reg"] > 0: rescale = [self.hparams["first_bn_multiplier"]] + [ @@ -330,16 +288,14 @@ def _closure(): mod.r_feature * rescale[idx] for (idx, mod) in enumerate(self.loss_r_feature_layers) ]) - - reconstruction_loss += self.hparams["bn_reg"] * loss_r_feature + reconst_loss += self.hparams["bn_reg"] * loss_r_feature self.logger.experiment.add_scalar("Loss", step_results["loss"], global_step=self.global_step) self.logger.experiment.add_scalar( - "Reconstruction Metric Loss", - reconstruction_loss, - global_step=self.global_step, + "Reconstruction Metric Loss", reconst_loss, + global_step=self.global_step ) if self.global_step % 100 == 0: @@ -347,24 +303,24 @@ def _closure(): self.logger.experiment.add_scalar( f"BN_loss/layer_{i}_mean_loss", torch.sqrt( - sum((recon_mean[i] - self.mean_gt[i])**2) / + sum((recon_mean[i] - self.mean_gt[i]) ** 2) / len(recon_mean[i])), global_step=self.global_step, ) self.logger.experiment.add_scalar( f"BN_loss/layer_{i}_var_loss", torch.sqrt( - sum((recon_var[i] - self.var_gt[i])**2) / + sum((recon_var[i] - self.var_gt[i]) ** 2) / len(recon_mean[i])), global_step=self.global_step, ) - self.manual_backward(reconstruction_loss, self.optimizer) + self.manual_backward(reconst_loss) if self.hparams["signed_gradients"]: if self.grayscale: self.best_guess_grayscale.grad.sign_() else: self.best_guess.grad.sign_() - return reconstruction_loss + return reconst_loss reconstruction_loss = self.optimizer.step(closure=_closure) if self.hparams["lr_scheduler"]: @@ -398,14 +354,10 @@ def _closure(): ), global_step=self.global_step, ) - psnrs = [ - torchmetrics.functional.psnr(a, b) - for (a, - b) in zip(self.best_guess, self.ground_truth_inputs) - ] + psnrs = [peak_signal_noise_ratio(a, b) + for (a, b) in zip(self.best_guess, self.ground_truth_inputs)] avg_psnr = sum(psnrs) / self.num_images - self.logger.experiment.add_scalar("Avg. PSNR", - avg_psnr, + self.logger.experiment.add_scalar("Avg. 
PSNR", avg_psnr, global_step=self.global_step) rmses = [ @@ -437,9 +389,8 @@ def configure_optimizers(self): self.best_guess = self.best_guess.to(self.device) self.labels = self.labels.to(self.device) if self.grayscale: - parameters = ([ - self.best_guess_grayscale, self.labels - ] if self._reconstruct_labels else [self.best_guess_grayscale]) + parameters = ([self.best_guess_grayscale, self.labels + ] if self._reconstruct_labels else [self.best_guess_grayscale]) else: parameters = ([self.best_guess, self.labels] if self._reconstruct_labels else [self.best_guess]) @@ -459,7 +410,7 @@ def configure_optimizers(self): def configure_lr_scheduler(self): if self.hparams["lr_scheduler"]: - self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + self.lr_scheduler = MultiStepLR( self.optimizer, milestones=[ self.num_iterations // 2.667, diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index a01091c..96582dc 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -10,8 +10,7 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from torch.utils.data.sampler import Sampler -from torchvision.datasets.cifar import CIFAR10 -from torchvision.datasets import MNIST +from torchvision.datasets import CIFAR10, CIFAR100, MNIST DEFAULT_DATA_DIR = "./data" DEFAULT_NUM_WORKERS = 32 @@ -22,6 +21,15 @@ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ] +DATASET_NORM = { + 'mnist': transforms.Normalize((0.1307,), (0.3081,)), + 'imagenet': transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + 'cifar100': transforms.Normalize((0.50705882, 0.48666667, 0.44078431), + (0.26745098, 0.25647059, 0.27607843)), + 'cifar10': transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)) +} def train_val_split(dataset_size: int, val_train_split: float = 0.02): @@ -35,11 +43,11 @@ def train_val_split(dataset_size: int, val_train_split: float = 0.02): def extract_attack_set( - dataset: Dataset, - sample_per_class: int = 5, - multi_class=False, - total_num_samples: int = 50, - seed: int = None, + dataset: Dataset, + sample_per_class: int = 5, + multi_class=False, + total_num_samples: int = 50, + seed: int = None, ): if not multi_class: num_classes = len(dataset.classes) @@ -66,12 +74,12 @@ def extract_attack_set( class FileDataModule(LightningDataModule): def __init__( - self, - data_dir: str = DEFAULT_DATA_DIR, - transform: torch.nn.Module = transforms.Compose(TRANSFORM_IMAGENET), - batch_size: int = 32, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, + self, + data_dir: str = DEFAULT_DATA_DIR, + transform: torch.nn.Module = transforms.Compose(TRANSFORM_IMAGENET), + batch_size: int = 32, + num_workers: int = DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, ): self.data_dir = data_dir self.batch_size = batch_size @@ -96,40 +104,44 @@ def test_dataloader(self): return self.get_dataloader() -class ImageNetDataModule(LightningDataModule): +class BaseDataModule(LightningDataModule): def __init__( - self, - augment: dict = None, - data_dir: str = os.path.join(DEFAULT_DATA_DIR, "imagenet"), - batch_size: int = 32, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, - tune_on_val: bool = False, + self, + augment: dict = None, + data_dir: str = DEFAULT_DATA_DIR, + batch_size: int = 32, + num_workers: int = DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, + tune_on_val: bool = False, ): + super().__init__() self.data_dir = 
data_dir self.batch_size = batch_size self.num_workers = num_workers - self.num_classes = 1000 - self.multi_class = False self.batch_sampler = batch_sampler self.tune_on_val = tune_on_val - + self._train_transforms = self.train_transform(augment) + print(self._train_transforms) + self._test_transforms = self.init_transform print(data_dir) - imagenet_normalize = transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - self._train_transforms = [ + @property + def init_transform(self): + return [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), - imagenet_normalize, + DATASET_NORM[self.DATASET_NAME], ] + + def train_transform(self, augment): + train_transforms = self.init_transform if augment["hflip"]: - self._train_transforms.insert( + train_transforms.insert( 0, transforms.RandomHorizontalFlip(p=0.5)) if augment["color_jitter"] is not None: - self._train_transforms.insert( + train_transforms.insert( 0, transforms.ColorJitter( brightness=augment["color_jitter"][0], @@ -139,20 +151,48 @@ def __init__( ), ) if augment["rotation"] > 0: - self._train_transforms.insert( + train_transforms.insert( 0, transforms.RandomRotation(augment["rotation"])) if augment["crop"]: - self._train_transforms.insert(0, - transforms.RandomCrop(32, padding=4)) + train_transforms.insert(0, transforms.RandomCrop(32, padding=4)) + return train_transforms - print(self._train_transforms) + @property + def num_classes(self): + return None - self._test_transforms = [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - imagenet_normalize, - ] + def train_dataloader(self): + if self.batch_sampler is None: + return DataLoader(self.train_set, + batch_size=self.batch_size, + num_workers=self.num_workers) + else: + return DataLoader( + self.train_set, + batch_sampler=self.batch_sampler, + num_workers=self.num_workers, + pin_memory=True + ) + + def val_dataloader(self): + return DataLoader(self.val_set, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True) + + def test_dataloader(self): + return DataLoader(self.test_set, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True) + + +class ImageNetDataModule(BaseDataModule): + DATASET_NAME = 'imagenet' + + @property + def num_classes(self): + return 1000 def setup(self, stage: Optional[str] = None): """Initialize the dataset based on the stage option ('fit', 'test' or 'attack'): @@ -199,90 +239,36 @@ def setup(self, stage: Optional[str] = None): ori_train_set) self.train_set = Subset(ori_train_set, self.attack_indices) - def train_dataloader(self): - if self.batch_sampler is None: - return DataLoader(self.train_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - else: - return DataLoader( - self.train_set, - batch_sampler=self.batch_sampler, - num_workers=self.num_workers, - ) - def val_dataloader(self): - return DataLoader(self.val_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - - def test_dataloader(self): - return DataLoader(self.test_set, - batch_size=self.batch_size, - num_workers=self.num_workers) +class MNISTDataModule(BaseDataModule): + DATASET_NAME = 'mnist' - -class MNISTDataModule(LightningDataModule): def __init__( - self, - augment: dict = None, - batch_size: int = 32, - data_dir: str = DEFAULT_DATA_DIR, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, - tune_on_val: float = 0, + self, + augment: dict = None, + batch_size: int = 32, + data_dir: str = DEFAULT_DATA_DIR, + num_workers: int 
= DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, + tune_on_val: float = 0, ): - super().__init__() + super().__init__(augment, data_dir, batch_size, num_workers, batch_sampler, tune_on_val) self._has_setup_attack = False - - self.data_dir = data_dir - self.batch_size = batch_size - self.num_workers = num_workers self.dims = (3, 32, 32) - self.num_classes = 10 - - self.batch_sampler = batch_sampler - self.tune_on_val = tune_on_val self.multi_class = False - mnist_normalize = transforms.Normalize((0.1307, ), (0.3081, )) - - self._train_transforms = [ - transforms.Resize(32), - transforms.Grayscale(3), - transforms.ToTensor(), - mnist_normalize, - ] - if augment["hflip"]: - self._train_transforms.insert( - 0, transforms.RandomHorizontalFlip(p=0.5)) - if augment["color_jitter"] is not None: - self._train_transforms.insert( - 0, - transforms.ColorJitter( - brightness=augment["color_jitter"][0], - contrast=augment["color_jitter"][1], - saturation=augment["color_jitter"][2], - hue=augment["color_jitter"][3], - ), - ) - if augment["rotation"] > 0: - self._train_transforms.insert( - 0, transforms.RandomRotation(augment["rotation"])) - if augment["crop"]: - self._train_transforms.insert(0, - transforms.RandomCrop(32, padding=4)) - - print(self._train_transforms) - - self._test_transforms = [ + @property + def init_transform(self): + return [ transforms.Resize(32), transforms.Grayscale(3), transforms.ToTensor(), - mnist_normalize, + DATASET_NORM[self.DATASET_NAME], ] - self.prepare_data() + @property + def num_classes(self): + return 10 def prepare_data(self): MNIST(self.data_dir, train=True, download=True) @@ -359,92 +345,43 @@ def setup(self, stage: Optional[str] = None): self.train_set = Subset(ori_train_set, self.attack_indices) self.test_set = Subset(self.test_set, range(100)) - def train_dataloader(self): - if self.batch_sampler is None: - return DataLoader( - self.train_set, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - else: - return DataLoader( - self.train_set, - batch_sampler=self.batch_sampler, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self): - return DataLoader(self.val_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - - def test_dataloader(self): - return DataLoader(self.test_set, - batch_size=self.batch_size, - num_workers=self.num_workers) +class CIFAR10DataModule(BaseDataModule): + DATASET_NAME = 'cifar10' -class CIFAR10DataModule(LightningDataModule): def __init__( - self, - augment: dict = None, - batch_size: int = 32, - data_dir: str = DEFAULT_DATA_DIR, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, - tune_on_val: float = 0, - seed: int = None, + self, + augment: dict = None, + batch_size: int = 32, + data_dir: str = DEFAULT_DATA_DIR, + num_workers: int = DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, + tune_on_val: float = 0, + seed: int = None, ): - super().__init__() + super().__init__(augment, data_dir, batch_size, num_workers, batch_sampler, + tune_on_val) self._has_setup_attack = False + self.seed = seed - self.data_dir = data_dir - self.batch_size = batch_size - self.num_workers = num_workers self.dims = (3, 32, 32) - self.num_classes = 10 - - self.batch_sampler = batch_sampler - self.tune_on_val = tune_on_val self.multi_class = False - self.seed = seed - - cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)) - - self._train_transforms = [transforms.ToTensor(), cifar_normalize] - if augment["hflip"]: - 
self._train_transforms.insert( - 0, transforms.RandomHorizontalFlip(p=0.5)) - if augment["color_jitter"] is not None: - self._train_transforms.insert( - 0, - transforms.ColorJitter( - brightness=augment["color_jitter"][0], - contrast=augment["color_jitter"][1], - saturation=augment["color_jitter"][2], - hue=augment["color_jitter"][3], - ), - ) - if augment["rotation"] > 0: - self._train_transforms.insert( - 0, transforms.RandomRotation(augment["rotation"])) - if augment["crop"]: - self._train_transforms.insert(0, - transforms.RandomCrop(32, padding=4)) + self.prepare_data() - print(self._train_transforms) + @property + def init_transform(self): + return [transforms.ToTensor(), DATASET_NORM[self.DATASET_NAME]] - self._test_transforms = [transforms.ToTensor(), cifar_normalize] + def base_dataset(self, root, **kwargs): + return CIFAR10(root, **kwargs) - self.prepare_data() + @property + def num_classes(self): + return 10 def prepare_data(self): """Download the data""" - CIFAR10(self.data_dir, train=True, download=True) - CIFAR10(self.data_dir, train=False, download=True) + CIFAR10(self.data_dir, download=True) def setup(self, stage: Optional[str] = None): """Initialize the dataset based on the stage option ('fit', 'test' or 'attack'): @@ -456,13 +393,11 @@ def setup(self, stage: Optional[str] = None): stage (Optional[str], optional): stage option. Defaults to None. """ if stage == "fit" or stage is None: - self.train_set = CIFAR10( - self.data_dir, - train=True, - transform=transforms.Compose(self._train_transforms), - ) + self.train_set = self.base_dataset(self.data_dir, + train=True, + transform=transforms.Compose(self._train_transforms)) if self.tune_on_val: - self.val_set = CIFAR10( + self.val_set = self.base_dataset( self.data_dir, train=True, transform=transforms.Compose(self._test_transforms), @@ -472,7 +407,7 @@ def setup(self, stage: Optional[str] = None): self.train_set = Subset(self.train_set, train_indices) self.val_set = Subset(self.val_set, val_indices) else: - self.val_set = CIFAR10( + self.val_set = self.base_dataset( self.data_dir, train=False, transform=transforms.Compose(self._test_transforms), @@ -480,14 +415,14 @@ def setup(self, stage: Optional[str] = None): # Assign test dataset for use in dataloader(s) if stage == "test" or stage is None: - self.test_set = CIFAR10( + self.test_set = self.base_dataset( self.data_dir, train=False, transform=transforms.Compose(self._test_transforms), ) if stage == "attack": - ori_train_set = CIFAR10( + ori_train_set = self.base_dataset( self.data_dir, train=True, transform=transforms.Compose(self._train_transforms), @@ -495,9 +430,9 @@ def setup(self, stage: Optional[str] = None): self.attack_indices, self.class2attacksample = extract_attack_set( ori_train_set, seed=self.seed) self.train_set = Subset(ori_train_set, self.attack_indices) - self.test_set = Subset(self.test_set, range(100)) + self.test_set = Subset(self.test_set, range(200)) elif stage == "attack_mini": - ori_train_set = CIFAR10( + ori_train_set = self.base_dataset( self.data_dir, train=True, transform=transforms.Compose(self._train_transforms), @@ -505,9 +440,9 @@ def setup(self, stage: Optional[str] = None): self.attack_indices, self.class2attacksample = extract_attack_set( ori_train_set, sample_per_class=2) self.train_set = Subset(ori_train_set, self.attack_indices) - self.test_set = Subset(self.test_set, range(100)) + self.test_set = Subset(self.test_set, range(200)) elif stage == "attack_large": - ori_train_set = CIFAR10( + ori_train_set = self.base_dataset( 
self.data_dir, train=True, transform=transforms.Compose(self._train_transforms), @@ -515,26 +450,19 @@ def setup(self, stage: Optional[str] = None): self.attack_indices, self.class2attacksample = extract_attack_set( ori_train_set, sample_per_class=500) self.train_set = Subset(ori_train_set, self.attack_indices) - self.test_set = Subset(self.test_set, range(100)) + self.test_set = Subset(self.test_set, range(200)) - def train_dataloader(self): - if self.batch_sampler is None: - return DataLoader(self.train_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - else: - return DataLoader( - self.train_set, - batch_sampler=self.batch_sampler, - num_workers=self.num_workers, - ) - def val_dataloader(self): - return DataLoader(self.val_set, - batch_size=self.batch_size, - num_workers=self.num_workers) +class CIFAR100DataModule(CIFAR10DataModule): + DATASET_NAME = 'cifar100' - def test_dataloader(self): - return DataLoader(self.test_set, - batch_size=self.batch_size, - num_workers=self.num_workers) + def prepare_data(self): + """Download the data""" + CIFAR100(self.data_dir, download=True) + + @property + def num_classes(self): + return 100 + + def base_dataset(self, root, **kwargs): + return CIFAR100(root, **kwargs) diff --git a/gradattack/defenses/__init__.py b/gradattack/defenses/__init__.py index 1a6eca9..27c9f33 100644 --- a/gradattack/defenses/__init__.py +++ b/gradattack/defenses/__init__.py @@ -1,6 +1,6 @@ -from .defense import * -from .defense_utils import * -from .dpsgd import * -from .gradprune import * -from .mixup import * -from .instahide import * +from .defense import GradientDefense +from .defense_utils import DefensePack +from .dpsgd import DPSGDDefense +from .gradprune import GradPruneDefense +from .instahide import InstahideDefense +from .mixup import MixupDefense diff --git a/gradattack/defenses/defense.py b/gradattack/defenses/defense.py index 141af0d..60ac02b 100644 --- a/gradattack/defenses/defense.py +++ b/gradattack/defenses/defense.py @@ -5,6 +5,7 @@ class GradientDefense: """Applies a gradient defense to a given pipeline. **WARNING** This may modify the pipeline via monkey-patching of defenses! Please use with care.""" + def apply(self, pipeline: TrainingPipeline): assert (self.defense_name not in pipeline.applied_defenses ), f"Tried to apply duplicate defense {self.defense_name}!" 
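The datamodule refactor above funnels the per-dataset pieces through a small surface: DATASET_NORM supplies the normalization constants, and subclasses override num_classes, init_transform, prepare_data, and base_dataset, which is what keeps CIFAR100DataModule to a few lines. As a rough illustration of the pattern, a hypothetical SVHN module could be added the same way; the class name, the 'svhn' normalization constants, and the split-keyword translation below are illustrative assumptions, not part of this change.

from torchvision import transforms
from torchvision.datasets import SVHN

from gradattack.datamodules import CIFAR10DataModule, DATASET_NORM

# Hypothetical entry; these constants are commonly quoted SVHN statistics but
# are not defined anywhere in this repository.
DATASET_NORM['svhn'] = transforms.Normalize((0.4377, 0.4438, 0.4728),
                                            (0.1980, 0.2010, 0.1970))


class SVHNDataModule(CIFAR10DataModule):
    DATASET_NAME = 'svhn'

    @property
    def num_classes(self):
        return 10

    def prepare_data(self):
        SVHN(self.data_dir, split='train', download=True)
        SVHN(self.data_dir, split='test', download=True)

    def base_dataset(self, root, **kwargs):
        # SVHN takes `split=` rather than CIFAR's `train=` flag.
        train = kwargs.pop('train', True)
        return SVHN(root, split='train' if train else 'test', **kwargs)


# Same augment dictionary format as the attack and training examples above.
dm = SVHNDataModule(augment={"hflip": False, "color_jitter": None,
                             "rotation": -1, "crop": False},
                    batch_size=64)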
diff --git a/gradattack/defenses/defense_utils.py b/gradattack/defenses/defense_utils.py index 86bec2e..a5351e3 100644 --- a/gradattack/defenses/defense_utils.py +++ b/gradattack/defenses/defense_utils.py @@ -1,11 +1,7 @@ -from typing import Any, Dict, Iterable, List, Optional, Union - import torch -from colorama import Back, Fore, Style, init -from torch.utils.data.dataset import Dataset +from colorama import Fore, Style, init from gradattack.trainingpipeline import TrainingPipeline - from .dpsgd import DPSGDDefense from .gradprune import GradPruneDefense from .instahide import InstahideDefense @@ -16,10 +12,9 @@ class DefensePack: - def __init__(self, args, logger=None): + def __init__(self, args): self.defense_params = {} self.parse_defense_params(args) - self.logger = logger # this might be useful for logging DP prarameters in the future def apply_defense(self, pipeline: TrainingPipeline): dataset = pipeline.datamodule @@ -108,7 +103,6 @@ def parse_defense_params(self, args, verbose=True): print(Fore.MAGENTA + f"{key}: {val}", end="\t") else: print(Fore.MAGENTA + "None", end="\t") - print() def get_defensepack_str(self): def get_param_str(paramname): diff --git a/gradattack/defenses/dpsgd.py b/gradattack/defenses/dpsgd.py index 7b8327e..7fde214 100644 --- a/gradattack/defenses/dpsgd.py +++ b/gradattack/defenses/dpsgd.py @@ -1,15 +1,12 @@ # DPSGD defense. The implementation of DPSGD is based on Opacus: https://opacus.ai/ -import time -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import List -import numpy as np import pytorch_lightning as pl -import torch from opacus import PrivacyEngine -from opacus.utils.module_modification import convert_batchnorm_modules from opacus.utils.uniform_sampler import UniformWithReplacementSampler from pytorch_lightning.core.datamodule import LightningDataModule + from gradattack.defenses.defense import GradientDefense from gradattack.models import StepTracker from gradattack.trainingpipeline import TrainingPipeline @@ -31,17 +28,18 @@ class DPSGDDefense(GradientDefense): secure_rng (bool, optional): If on, it will use ``torchcsprng`` for secure random number generation. Comes with a significant performance cost. Defaults to False. freeze_extractor (bool, optional): If on, only finetune the classifier (the final fully-connected layers). Defaults to False. 
""" + def __init__( - self, - mini_batch_size: int, - sample_size: int, - n_accumulation_steps: int, - max_grad_norm: float, - noise_multiplier: float, - delta_list: List[float] = [1e-3, 1e-4, 1e-5], - max_epsilon: float = 2, - secure_rng: bool = False, - freeze_extractor: bool = False, + self, + mini_batch_size: int, + sample_size: int, + n_accumulation_steps: int, + max_grad_norm: float, + noise_multiplier: float, + delta_list: List[float] = [1e-3, 1e-4, 1e-5], + max_epsilon: float = 2, + secure_rng: bool = False, + freeze_extractor: bool = False, ): super().__init__() diff --git a/gradattack/defenses/instahide.py b/gradattack/defenses/instahide.py index 32eb15e..6ccd3a1 100644 --- a/gradattack/defenses/instahide.py +++ b/gradattack/defenses/instahide.py @@ -4,7 +4,6 @@ from torch.nn.functional import one_hot from torch.utils.data.dataset import Dataset -import torchcsprng as csprng from gradattack.defenses.defense import GradientDefense from gradattack.trainingpipeline import TrainingPipeline @@ -51,11 +50,11 @@ def __init__(self, torch.tensor(self.alpha).repeat(self.dataset_size, 1)) self.use_csprng = use_csprng - if self.use_csprng: - if cs_prng is None: - self.cs_prng = csprng.create_random_device_generator() - else: - self.cs_prng = cs_prng + # if self.use_csprng: + # if cs_prng is None: + # self.cs_prng = csprng.create_random_device_generator() + # else: + # self.cs_prng = cs_prng # @profile def generate_mapping(self, return_tensor=True): @@ -124,10 +123,10 @@ def generate_mapping(self, return_tensor=True): return np.asarray(lams), np.asarray(selects) def instahide_batch( - self, - inputs: torch.tensor, - lams_b: float, - selects_b: np.array, + self, + inputs: torch.tensor, + lams_b: float, + selects_b: np.array, ): """Generate an InstaHide batch. 
@@ -139,7 +138,7 @@ def instahide_batch( Returns: (torch.tensor): the InstaHide images and labels """ - mixed_x = torch.zeros_like(inputs) + mixed_x = torch.zeros_like(inputs, device=self.device) mixed_y = torch.zeros((len(inputs), self.num_classes), device=self.device) diff --git a/gradattack/defenses/mixup.py b/gradattack/defenses/mixup.py index 7b7a956..2f8deb9 100644 --- a/gradattack/defenses/mixup.py +++ b/gradattack/defenses/mixup.py @@ -1,10 +1,10 @@ import numpy as np import torch +# import torchcsprng as csprng from torch.distributions.dirichlet import Dirichlet from torch.nn.functional import one_hot from torch.utils.data.dataset import Dataset -import torchcsprng as csprng from gradattack.defenses.defense import GradientDefense from gradattack.trainingpipeline import TrainingPipeline @@ -50,12 +50,12 @@ def __init__(self, self.lambda_sampler_whole = Dirichlet( torch.tensor(self.alpha).repeat(self.dataset_size, 1)) self.use_csprng = use_csprng - - if self.use_csprng: - if cs_prng is None: - self.cs_prng = csprng.create_random_device_generator() - else: - self.cs_prng = cs_prng + # + # if self.use_csprng: + # if cs_prng is None: + # self.cs_prng = csprng.create_random_device_generator() + # else: + # self.cs_prng = cs_prng # @profile def generate_mapping(self, return_tensor=True): @@ -99,8 +99,7 @@ def generate_mapping(self, return_tensor=True): lams = self.lambda_sampler_whole.sample().to(self.device) selects = torch.stack([ torch.randperm(self.dataset_size, - device=self.device, - generator=self.cs_prng) + device=self.device) for _ in range(self.klam) ]) selects = torch.transpose(selects, 0, 1) @@ -123,12 +122,7 @@ def generate_mapping(self, return_tensor=True): else: return np.asarray(lams), np.asarray(selects) - def mixup_batch( - self, - inputs: torch.tensor, - lams_b: float, - selects_b: np.array, - ): + def mixup_batch(self, inputs: torch.tensor, lams_b: float, selects_b: np.array): """Generate a MixUp batch. Args: @@ -139,7 +133,7 @@ def mixup_batch( Returns: (torch.tensor): the MixUp images and labels """ - mixed_x = torch.zeros_like(inputs) + mixed_x = torch.zeros_like(inputs, device=self.device) mixed_y = torch.zeros((len(inputs), self.num_classes), device=self.device) diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index b88c6b4..26b2674 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -1,27 +1,21 @@ -# FIXME: @Samyak, could you please help add docstring to this file? Thanks! import os -import time -from typing import Any, Callable, Optional +from typing import Callable import pytorch_lightning as pl -import torch.nn.functional as F -import torchvision.models as models -from gradattack.utils import StandardizeLayer -from sklearn import metrics -from torch.nn import init from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR +from gradattack.utils import StandardizeLayer +from .LeNet import * from .covidmodel import * from .densenet import * from .googlenet import * from .mobilenet import * +from .multihead_resnet import * from .nasnet import * from .resnet import * from .resnext import * from .simple import * from .vgg import * -from .multihead_resnet import * -from .LeNet import * class StepTracker: @@ -38,23 +32,24 @@ def end(self, deduction: float = 0): class LightningWrapper(pl.LightningModule): """Wraps a torch module in a pytorch-lightning module. 
Any .""" + def __init__( - self, - model: torch.nn.Module, - training_loss_metric: Callable = F.mse_loss, - optimizer: str = "SGD", - lr_scheduler: str = "ReduceLROnPlateau", - tune_on_val: float = 0.02, - lr_factor: float = 0.5, - lr_step: int = 10, - batch_size: int = 64, - lr: float = 0.05, - momentum: float = 0.9, - weight_decay: float = 5e-4, - nesterov: bool = False, - log_auc: bool = False, - multi_class: bool = False, - multi_head: bool = False, + self, + model: torch.nn.Module, + training_loss_metric: Callable = F.mse_loss, + optimizer: str = "SGD", + lr_scheduler: str = "ReduceLROnPlateau", + tune_on_val: float = 0.02, + lr_factor: float = 0.5, + lr_step: int = 10, + batch_size: int = 64, + lr: float = 0.05, + momentum: float = 0.9, + weight_decay: float = 5e-4, + nesterov: bool = False, + log_auc: bool = False, + multi_class: bool = False, + multi_head: bool = False, ): super().__init__() # if we didn't copy here, then we would modify the default dict by accident @@ -126,7 +121,6 @@ def _compute_training_step(self, Args: batch : The batch inputs. Should be a torch tensor with outermost dimension 2, where dimension 0 corresponds to inputs and dimension 1 corresponds to labels. - Returns: dict: The results from the training step. Is a dictionary with keys "loss", "transformed_batch", and "model_outputs". """ @@ -158,14 +152,15 @@ def training_step(self, batch, batch_idx, *_) -> dict: self.manual_backward(training_step_results["loss"]) - if self.should_accumulate(): - # Special case opacus optimizers to reduce memory footprint - # see: (https://github.com/pytorch/opacus/blob/244265582bffbda956511871a907e5de2c523d86/opacus/privacy_engine.py#L393) - if hasattr(self.optimizer, "virtual_step"): - with torch.no_grad(): - self.optimizer.virtual_step() - else: - self.on_non_accumulate_step() + # if self.should_accumulate(): + # # Special case opacus optimizers to reduce memory footprint + # # see: (https://github.com/pytorch/opacus/blob/244265582bffbda956511871a907e5de2c523d86/opacus/privacy_engine.py#L393) + # if hasattr(self.optimizer, "virtual_step"): + # with torch.no_grad(): + # self.optimizer.virtual_step() + # else: + # self.on_non_accumulate_step() + self.on_non_accumulate_step() if self.log_train_acc: top1_acc = accuracy( @@ -184,19 +179,17 @@ def training_step(self, batch, batch_idx, *_) -> dict: return training_step_results - def get_batch_gradients( - self, - batch: torch.tensor, - batch_idx: int = 0, - create_graph: bool = False, - clone_gradients: bool = True, - apply_transforms=True, - eval_mode: bool = False, - stop_track_bn_stats: bool = True, - BN_exact: bool = False, - attacker: bool = False, - *args, - ): + def get_batch_gradients(self, + batch: torch.tensor, + batch_idx: int = 0, + create_graph: bool = False, + clone_gradients: bool = True, + apply_transforms=True, + eval_mode: bool = False, + stop_track_bn_stats: bool = True, + BN_exact: bool = False, + attacker: bool = False, + *args): batch = tuple(k.to(self.device) for k in batch) if eval_mode is True: self.eval() @@ -317,7 +310,7 @@ def configure_lr_scheduler(self): elif self.hparams["lr_scheduler"] == "LambdaLR": self.lr_scheduler = LambdaLR( self.optimizer, - lr_lambda=[lambda epoch: self.hparams["lr_lambda"]**epoch], + lr_lambda=[lambda epoch: self.hparams["lr_lambda"] ** epoch], verbose=True, ) elif self.hparams["lr_scheduler"] == "ReduceLROnPlateau": @@ -345,54 +338,8 @@ def validation_step(self, batch, batch_idx): else: loss = self._val_loss_metric(y_hat, y) top1_acc = accuracy(y_hat, y, 
multi_head=self.multi_head)[0] - if self.log_auc: - pred_list, true_list = auc_list(y_hat, y) - else: - pred_list, true_list = None, None - return { - "batch/val_loss": loss, - "batch/val_accuracy": top1_acc, - "batch/val_pred_list": pred_list, - "batch/val_true_list": true_list, - } - - def validation_epoch_end(self, outputs): - # outputs is whatever returned in `validation_step` - avg_loss = torch.stack([x["batch/val_loss"] for x in outputs]).mean() - avg_accuracy = torch.stack([x["batch/val_accuracy"] - for x in outputs]).mean() - if self.log_auc: - self.log_aucs(outputs, stage="val") - - self.current_val_loss = avg_loss - if self.current_epoch > 0: - if self.hparams["lr_scheduler"] == "ReduceLROnPlateau": - self.lr_scheduler.step(self.current_val_loss) - else: - self.lr_scheduler.step() - - self.cur_lr = self.optimizer.param_groups[0]["lr"] - - self.log( - "epoch/val_accuracy", - avg_accuracy, - on_epoch=True, - prog_bar=True, - logger=True, - ) - self.log("epoch/val_loss", - avg_loss, - on_epoch=True, - prog_bar=True, - logger=True) - self.log("epoch/lr", - self.cur_lr, - on_epoch=True, - prog_bar=True, - logger=True) - - for callback in self._epoch_end_callbacks: - callback(self) + self.log('val/loss', loss, on_epoch=True, logger=True) + self.log('val/acc', top1_acc, on_epoch=True, logger=True) def test_step(self, batch, batch_idx): x, y = batch @@ -405,73 +352,22 @@ def test_step(self, batch, batch_idx): else: loss = self._val_loss_metric(y_hat, y) top1_acc = accuracy(y_hat, y, multi_head=self.multi_head)[0] - if self.log_auc: - pred_list, true_list = auc_list(y_hat, y) - else: - pred_list, true_list = None, None - return { - "batch/test_loss": loss, - "batch/test_accuracy": top1_acc, - "batch/test_pred_list": pred_list, - "batch/test_true_list": true_list, - } - - def test_epoch_end(self, outputs): - avg_loss = torch.stack([x["batch/test_loss"] for x in outputs]).mean() - avg_accuracy = torch.stack([x["batch/test_accuracy"] - for x in outputs]).mean() - if self.log_auc: - self.log_aucs(outputs, stage="test") - - self.log("run/test_accuracy", - avg_accuracy, - on_epoch=True, - prog_bar=True, - logger=True) - self.log("run/test_loss", - avg_loss, - on_epoch=True, - prog_bar=True, - logger=True) - - def log_aucs(self, outputs, stage="test"): - pred_list = np.concatenate( - [x[f"batch/{stage}_pred_list"] for x in outputs]) - true_list = np.concatenate( - [x[f"batch/{stage}_true_list"] for x in outputs]) - - aucs = [] - for c in range(len(pred_list[0])): - fpr, tpr, thresholds = metrics.roc_curve(true_list[:, c], - pred_list[:, c], - pos_label=1) - auc_val = metrics.auc(fpr, tpr) - aucs.append(auc_val) - - self.log( - f"epoch/{stage}_auc/class_{c}", - auc_val, - on_epoch=True, - prog_bar=False, - logger=True, - ) - self.log( - f"epoch/{stage}_auc/avg", - np.mean(aucs), - on_epoch=True, - prog_bar=True, - logger=True, - ) + # if self.log_auc: + # pred_list, true_list = auc_list(y_hat, y) + # else: + # pred_list, true_list = None, None + self.log('test/loss', loss, logger=True, on_epoch=True) + self.log('test/acc', top1_acc, logger=True, on_epoch=True) def create_lightning_module( - model_name: str, - num_classes: int, - pretrained: bool = False, - ckpt: str = None, - freeze_extractor: bool = False, - *args, - **kwargs, + model_name: str, + num_classes: int, + pretrained: bool = False, + ckpt: str = None, + freeze_extractor: bool = False, + *args, + **kwargs, ) -> LightningWrapper: if "models" in model_name: # Official models by PyTorch model_name = model_name.replace("models.", "") @@ 
-524,12 +420,12 @@ def do_freeze_extractor(model): def multihead_accuracy(output, target): prec1 = [] for j in range(output.size(1)): - acc = accuracy(output[:, j], target[:, j], topk=(1, )) + acc = accuracy(output[:, j], target[:, j], topk=(1,)) prec1.append(acc[0]) return torch.mean(torch.Tensor(prec1)) -def accuracy(output, target, topk=(1, ), multi_head=False): +def accuracy(output, target, topk=(1,), multi_head=False): """Computes the precision@k for the specified values of k""" with torch.no_grad(): maxk = max(topk) @@ -554,11 +450,3 @@ def accuracy(output, target, topk=(1, ), multi_head=False): correct *= 100.0 / (batch_size * target.size(1)) res = [correct] return res - - -def auc_list(output, target): - assert len(target.size()) == 2 - pred_list = torch.sigmoid(output).cpu().detach().numpy() - true_list = target.cpu().detach().numpy() - - return pred_list, true_list diff --git a/gradattack/models/alldnet.py b/gradattack/models/alldnet.py index fb071bc..7fedb92 100755 --- a/gradattack/models/alldnet.py +++ b/gradattack/models/alldnet.py @@ -1,7 +1,6 @@ """LeNet in PyTorch.""" import torch.nn as nn import torch.nn.functional as F -from torch.autograd import Variable class AllDNet(nn.Module): diff --git a/gradattack/models/covidmodel.py b/gradattack/models/covidmodel.py index 9331a6f..46ea61b 100644 --- a/gradattack/models/covidmodel.py +++ b/gradattack/models/covidmodel.py @@ -1,5 +1,3 @@ -import numpy as np -import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models diff --git a/gradattack/models/densenet.py b/gradattack/models/densenet.py index 935470d..3660562 100755 --- a/gradattack/models/densenet.py +++ b/gradattack/models/densenet.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from torch.autograd import Variable @@ -17,10 +16,8 @@ def __init__(self, in_planes, growth_rate): kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(4 * growth_rate) - self.conv2 = nn.Conv2d(4 * growth_rate, - growth_rate, - kernel_size=3, - padding=1, + self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, + kernel_size=3, padding=1, bias=False) def forward(self, x): @@ -43,9 +40,7 @@ def forward(self, x): class DenseNet(nn.Module): - def __init__(self, - block, - nblocks, + def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): @@ -128,5 +123,4 @@ def test_densenet(): y = net(Variable(x)) print(y) - # test_densenet() diff --git a/gradattack/models/googlenet.py b/gradattack/models/googlenet.py index 041baab..38371b8 100755 --- a/gradattack/models/googlenet.py +++ b/gradattack/models/googlenet.py @@ -1,9 +1,6 @@ """GoogLeNet with PyTorch.""" import torch import torch.nn as nn -import torch.nn.functional as F - -from torch.autograd import Variable class Inception(nn.Module): @@ -100,7 +97,6 @@ def forward(self, x): out = self.linear(out) return out - # net = GoogLeNet() # x = torch.randn(1,3,32,32) # y = net(Variable(x)) diff --git a/gradattack/models/mobilenet.py b/gradattack/models/mobilenet.py index 96a5638..8d47cb6 100755 --- a/gradattack/models/mobilenet.py +++ b/gradattack/models/mobilenet.py @@ -12,6 +12,7 @@ class Block(nn.Module): """Depthwise conv + Pointwise conv""" + def __init__(self, in_planes, out_planes, stride=1): super(Block, self).__init__() self.conv1 = nn.Conv2d( @@ -92,5 +93,4 @@ def test(): y = net(Variable(x)) print(y.size()) - # test() diff --git a/gradattack/models/nasnet.py b/gradattack/models/nasnet.py index e204816..541aac8 100644 --- a/gradattack/models/nasnet.py +++ 
b/gradattack/models/nasnet.py @@ -10,7 +10,6 @@ class SeperableConv2d(nn.Module): def __init__(self, input_channels, output_channels, kernel_size, **kwargs): - super().__init__() self.depthwise = nn.Conv2d(input_channels, input_channels, @@ -63,6 +62,7 @@ class Fit(nn.Module): prev_filters: filter number of tensor prev, needs to be modified filters: filter number of normal cell branch output filters """ + def __init__(self, prev_filters, filters): super().__init__() self.relu = nn.ReLU() @@ -242,7 +242,8 @@ def forward(self, x): return ( torch.cat( [ - layer1block2, # https://github.com/keras-team/keras-applications/blob/master/keras_applications/nasnet.py line 739 + layer1block2, + # https://github.com/keras-team/keras-applications/blob/master/keras_applications/nasnet.py line 739 layer1block3, layer2block1, layer2block2, @@ -314,7 +315,6 @@ def _make_layers(self, repeat_cell_num, reduction_num): layers = [] for i in range(reduction_num): - layers.extend( self._make_normal(NormalCell, repeat_cell_num, self.filters)) self.filters *= 2 @@ -339,6 +339,5 @@ def forward(self, x): def nasnet(num_classes=100): - # stem filters must be 44, it's a pytorch workaround, cant change to other number return NasNetA(4, 2, 44, 44, num_classes=num_classes) diff --git a/gradattack/models/resnet.py b/gradattack/models/resnet.py index 2089512..9d91292 100755 --- a/gradattack/models/resnet.py +++ b/gradattack/models/resnet.py @@ -12,7 +12,6 @@ import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable -import torchvision def conv3x3(in_planes, out_planes, stride=1): @@ -243,5 +242,4 @@ def test(): y = net(Variable(torch.randn(1, 3, 32, 32))) print(y.size()) - # test() diff --git a/gradattack/models/resnext.py b/gradattack/models/resnext.py index 805e024..1a75a98 100755 --- a/gradattack/models/resnext.py +++ b/gradattack/models/resnext.py @@ -1,4 +1,5 @@ from __future__ import division + """ Creates a ResNeXt Model as defined in: Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 
@@ -15,6 +16,7 @@ class ResNeXtBottleneck(nn.Module): """ RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) """ + def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor): """Constructor @@ -95,6 +97,7 @@ class CifarResNeXt(nn.Module): ResNext optimized for the Cifar dataset, as specified in https://arxiv.org/pdf/1611.05431.pdf """ + def __init__(self, cardinality, depth, diff --git a/gradattack/models/simple.py b/gradattack/models/simple.py index b86341a..db54fd0 100644 --- a/gradattack/models/simple.py +++ b/gradattack/models/simple.py @@ -1,6 +1,3 @@ -import random - -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F diff --git a/gradattack/models/vgg.py b/gradattack/models/vgg.py index a0f8759..0c851a1 100755 --- a/gradattack/models/vgg.py +++ b/gradattack/models/vgg.py @@ -1,12 +1,10 @@ """VGG11/13/16/19 in Pytorch.""" -import torch import torch.nn as nn -from torch.autograd import Variable cfg = { "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG13": - [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG16": [ 64, 64, @@ -81,7 +79,6 @@ def _make_layers(self, cfg): layers += [nn.AvgPool2d(kernel_size=1, stride=1)] return nn.Sequential(*layers) - # net = VGG('VGG11') # x = torch.randn(2,3,32,32) # print(net(Variable(x)).size()) diff --git a/gradattack/trainingpipeline.py b/gradattack/trainingpipeline.py index 055b1bc..d9c8f9b 100644 --- a/gradattack/trainingpipeline.py +++ b/gradattack/trainingpipeline.py @@ -4,12 +4,9 @@ class TrainingPipeline: - def __init__( - self, - model: LightningWrapper, - datamodule: pl.LightningDataModule, - trainer: pl.Trainer, - ): + def __init__(self, model: LightningWrapper, + datamodule: pl.LightningDataModule, + trainer: pl.Trainer): self.model = model self.datamodule = datamodule self.trainer = trainer @@ -20,10 +17,6 @@ def __init__( ) # Modifications to the model architecture, trainable params ... self.datamodule.setup() - # FIXME: @Samyak, are we actually using this funciton? - def log_hparams(self): - self.trainer.logger.log_hyperparams(self.model.hparams) - def setup_pipeline(self): self.datamodule.prepare_data() @@ -38,15 +31,8 @@ def setup_pipeline(self): def run(self): self.setup_pipeline() # If we didn't call setup(), any updates to transforms (e.g. from defenses) wouldn't be applied - return self.trainer.fit(self.model, self.datamodule) + return self.trainer.fit(self.model, datamodule=self.datamodule) def test(self): return self.trainer.test( - self.model, test_dataloaders=self.datamodule.test_dataloader()) - - # FIXME: @Samyak, are we actually using this funciton? 
- def get_datamodule_batch(self): - self.datamodule.setup() - trainloader = self.datamodule.train_dataloader() - for batch in trainloader: - return batch + self.model, self.datamodule.test_dataloader()) diff --git a/gradattack/utils.py b/gradattack/utils.py index 6cdc822..0fcbb4f 100755 --- a/gradattack/utils.py +++ b/gradattack/utils.py @@ -14,43 +14,29 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.init as init from torch.nn.functional import log_softmax def parse_args(): - parser = argparse.ArgumentParser(description="gradattack training") + parser = argparse.ArgumentParser(description="GradAttack training") parser.add_argument("--gpuid", default="0", type=int, help="gpu id to use") parser.add_argument("--model", - default="ResNet18", + default="ResNet34", type=str, help="name of model") - parser.add_argument("--data", - default="CIFAR10", - type=str, - help="name of dataset") - parser.add_argument( - "--results_dir", - default="./results", - type=str, - help="directory to save attack results", - ) - parser.add_argument("--n_epoch", - default=200, - type=int, - help="number of epochs") - parser.add_argument("--batch_size", - default=128, - type=int, + parser.add_argument("--data", default="CIFAR10", + type=str, help="name of dataset") + parser.add_argument("--results_dir", default="./results", + type=str, help="directory to save attack results") + parser.add_argument("--n_epoch", default=200, type=int, help="number of epochs") + parser.add_argument("--batch_size", default=512, type=int, help="batch size") parser.add_argument("--optimizer", default="SGD", type=str, help="which optimizer") - parser.add_argument("--lr", default=0.05, type=float, help="initial lr") - parser.add_argument("--decay", - default=5e-4, - type=float, + parser.add_argument("--lr", default=0.4, type=float, help="initial lr") + parser.add_argument("--decay", default=5e-4, type=float, help="weight decay") parser.add_argument("--momentum", type=float, default=0.9) parser.add_argument("--nesterov", action="store_true") @@ -58,16 +44,14 @@ def parse_args(): default="ReduceLROnPlateau", type=str, help="which scheduler") - parser.add_argument("--lr_step", - default=30, - type=int, + parser.add_argument("--lr_step", default=30, type=int, help="reduce LR per ? epochs") parser.add_argument("--lr_lambda", default=0.95, type=float, help="lambda of LambdaLR scheduler") parser.add_argument("--lr_factor", - default=0.5, + default=0.2, type=float, help="factor of lr reduction") parser.add_argument("--disable_early_stopping", @@ -77,49 +61,33 @@ def parse_args(): default=10, type=int, help="patience for early stopping") - parser.add_argument( - "--tune_on_val", - default=0.02, - type=float, - help= - "fraction of validation data. If set to 0, use test data as the val data", - ) + parser.add_argument("--tune_on_val", + default=0.02, + type=float, + help="fraction of validation data. 
If set to 0, use test data as the val data", ) parser.add_argument("--log_auc", dest="log_auc", action="store_true") - parser.add_argument("--logname", - default="vanilla", - type=str, - help="log name") + parser.add_argument("--logname", default="vanilla", type=str, help="log name") parser.add_argument("--pretrained", dest="pretrained", action="store_true") - parser.add_argument("--ckpt", - default=None, - type=str, - help="directory for ckpt") - parser.add_argument( - "--freeze_extractor", - dest="freeze_extractor", - action="store_true", - help="Whether only training the fc layer", - ) + parser.add_argument("--ckpt", default=None, type=str, help="directory for ckpt") + parser.add_argument("--freeze_extractor", + dest="freeze_extractor", + action="store_true", + help="Whether to train only the fc layer", ) # Augmentation - parser.add_argument( - "--dis_aug_crop", - dest="aug_crop", - action="store_false", - help="Whether to apply random cropping", - ) - parser.add_argument( - "--dis_aug_hflip", - dest="aug_hflip", - action="store_false", - help="Whether to apply horizontally flipping", - ) - parser.add_argument( - "--aug_affine", - dest="aug_affine", - action="store_true", - help="Enable random affine", - ) + parser.add_argument("--dis_aug_crop", + dest="aug_crop", + action="store_false", + help="Whether to apply random cropping") + parser.add_argument("--dis_aug_hflip", + dest="aug_hflip", + action="store_false", + help="Whether to apply horizontal flipping") + parser.add_argument("--aug_affine", + dest="aug_affine", + action="store_true", + help="Enable random affine", + ) parser.add_argument( "--aug_rotation", type=float, @@ -137,17 +105,11 @@ def parse_args(): parser.add_argument("--defense_instahide", dest="defense_instahide", action="store_true") - parser.add_argument("--klam", - default=4, - type=int, + parser.add_argument("--klam", default=4, type=int, help="How many images to mix with") - parser.add_argument("--c_1", - default=0, - type=float, + parser.add_argument("--c_1", default=0, type=float, help="Lower bound of mixing coefs") - parser.add_argument("--c_2", - default=1, - type=float, + parser.add_argument("--c_2", default=1, type=float, help="Upper bound of mixing coefs") parser.add_argument("--use_csprng", dest="use_csprng", action="store_true") # GradPrune @@ -166,34 +128,20 @@ def parse_args(): type=float, help="Failure prob of DP", ) - parser.add_argument("--max_epsilon", - default=2, - type=float, + parser.add_argument("--max_epsilon", default=2, type=float, help="Privacy budget") - parser.add_argument( - "--max_grad_norm", - default=1, - type=float, - help="Clip per-sample gradients to this norm", - ) + parser.add_argument("--max_grad_norm", default=1, type=float, + help="Clip per-sample gradients to this norm", ) parser.add_argument("--noise_multiplier", default=1, type=float, help="Noise multiplier") - - parser.add_argument( - "--n_accumulation_steps", - default=1, - type=int, - help="Run optimization per ? step", - ) - parser.add_argument( - "--secure_rng", - dest="secure_rng", - action="store_true", - help= - "Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost", - ) + parser.add_argument("--n_accumulation_steps", default=1, type=int, + help="Run optimization per ? step") + parser.add_argument("--secure_rng", + dest="secure_rng", + action="store_true", + help="Enable Secure RNG to have trustworthy privacy guarantees. 
Comes at a performance cost", ) # For attack parser.add_argument("--reconstruct_labels", action="store_true") @@ -212,31 +160,18 @@ def parse_args(): default=None, type=int, help="seed to select attack subset") - parser.add_argument("--attack_epoch", - default=0, - type=int, + parser.add_argument("--attack_epoch", default=0, type=int, help="iterations for the attack") - parser.add_argument( - "--bn_reg", - default=0, - type=float, - help="coef. for batchnorm regularization term", - ) - parser.add_argument( - "--attacker_eval_mode", - action="store_true", - help="use eval model for gradients calculation for attack", - ) - parser.add_argument( - "--defender_eval_mode", - action="store_true", - help="use eval model for gradients calculation for training", - ) - parser.add_argument( - "--BN_exact", - action="store_true", - help="use training batch's mean and var", - ) + parser.add_argument("--bn_reg", default=0, type=float, + help="coef. for batchnorm regularization term") + parser.add_argument("--attacker_eval_mode", + action="store_true", + help="use eval model for gradients calculation for attack") + parser.add_argument("--defender_eval_mode", action="store_true", + help="use eval model for gradients calculation for training") + parser.add_argument("--BN_exact", + action="store_true", + help="use training batch's mean and var",) args = parser.parse_args() @@ -255,7 +190,7 @@ def parse_args(): hparams["lr_step"] = args.lr_step hparams["lr_factor"] = args.lr_factor elif args.scheduler == "MultiStepLR": - hparams["lr_step"] = [100, 150] + hparams["lr_step"] = [60, 120, 160] hparams["lr_factor"] = args.lr_factor elif args.scheduler == "LambdaLR": hparams["lr_lambda"] = args.lr_lambda @@ -286,16 +221,12 @@ def parse_args(): def parse_augmentation(args): return { - "hflip": - args.aug_hflip, - "crop": - args.aug_crop, - "rotation": - args.aug_rotation, + "hflip": args.aug_hflip, + "crop": args.aug_crop, + "rotation": args.aug_rotation, "color_jitter": [float(i) for i in args.aug_colorjitter] if args.aug_colorjitter is not None else None, - "affine": - args.aug_affine, + "affine": args.aug_affine, } @@ -386,7 +317,7 @@ def patch_image(x, dim=(32, 32)): x = np.append(x, np.zeros((pad_size, *x[0].shape)), axis=0) batch_size = len(x) x = np.transpose(x, (0, 2, 3, 1)) - if int(np.sqrt(batch_size))**2 == batch_size: + if int(np.sqrt(batch_size)) ** 2 == batch_size: s = int(np.sqrt(batch_size)) x = np.reshape(x, (s, s, dim[0], dim[1], 3)) x = np.concatenate(x, axis=2) diff --git a/requirements.txt b/requirements.txt index 915ae65..f263b7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,6 @@ opacus==0.14.0 pytorch_lightning==1.5.1 scikit_learn==1.0.1 torch==1.8.1 -torchcsprng==0.2.1 +torchcsprng==0.3.0 torchmetrics==0.6.0 torchvision==0.9.1 diff --git a/setup.cfg b/setup.cfg index b52e3ff..38e0c5e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,15 +16,4 @@ classifiers = [options] packages = find: -python_requires = >=3.8 -install_requires = - colorama==0.4.4 - matplotlib==3.4.3 - numpy==1.21.4 - opacus==0.14.0 - pytorch_lightning==1.3.1 - scikit_learn==1.0.1 - torch==1.8.1 - torchcsprng==0.2.1 - torchmetrics==0.6.0 - torchvision==0.9.1 +python_requires = >=3.7 diff --git a/test/train_config.py b/test/train_config.py index d04ea26..4afd7fe 100644 --- a/test/train_config.py +++ b/test/train_config.py @@ -1,5 +1,3 @@ -from argparse import Namespace - vanilla_hparams = { "optimizer": "SGD", "lr": 0.05, @@ -10,4 +8,4 @@ "tune_on_val": 0.02, "batch_size": 128, "lr_factor": 0.5, -} \ No newline at 
end of file +}
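
For context on the topk=(1,) calls normalized in the diff above: accuracy computes precision@k, i.e. the percentage of samples whose true label appears among the k highest-scoring predictions. Below is a minimal standalone sketch of that standard single-label computation; the helper name topk_accuracy and the demo tensors are illustrative only and not part of this patch, and the repo's own implementation additionally handles a multi-head mode.

import torch


def topk_accuracy(output, target, topk=(1,)):
    """Percentage of samples whose true label is among the top-k logits."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # (batch_size, maxk) indices of the largest logits, best first.
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()  # (maxk, batch_size)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # Count samples whose label shows up in the first k predictions.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k * (100.0 / batch_size))
        return res


logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 1, 1])
print(topk_accuracy(logits, labels, topk=(1,)))  # [tensor(66.6667)]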