From faa07ad9087ec4a4a793a91b4f24152753c253b3 Mon Sep 17 00:00:00 2001 From: limberc Date: Mon, 2 Jan 2023 16:57:03 +0000 Subject: [PATCH 01/17] Zehua Update. --- gradattack/attacks/attack.py | 21 +++--- gradattack/attacks/gradientinversion.py | 64 ++++++++--------- gradattack/datamodules.py | 95 +++++++++++++++---------- gradattack/defenses/defense_utils.py | 6 +- gradattack/defenses/dpsgd.py | 28 ++++---- gradattack/defenses/instahide.py | 10 +-- gradattack/defenses/mixup.py | 10 +-- gradattack/models/__init__.py | 82 ++++++++++----------- 8 files changed, 161 insertions(+), 155 deletions(-) diff --git a/gradattack/attacks/attack.py b/gradattack/attacks/attack.py index 7749c00..00ab5bc 100644 --- a/gradattack/attacks/attack.py +++ b/gradattack/attacks/attack.py @@ -2,23 +2,22 @@ from typing import Callable import torch -from numpy.lib.npyio import load from gradattack.trainingpipeline import TrainingPipeline -from .invertinggradients.inversefed.reconstruction_algorithms import ( - GradientReconstructor, ) +from .gradientinversion import GradientReconstructor class GradientInversionAttack: """Wrapper around Gradient Inversion attack""" + def __init__( - self, - pipeline: TrainingPipeline, - dm, - ds, - device: torch.device, - loss_metric: Callable, - reconstructor_args: dict = None, + self, + pipeline: TrainingPipeline, + dm, + ds, + device: torch.device, + loss_metric: Callable, + reconstructor_args: dict = None, ): self.pipeline = pipeline self.device = device @@ -58,5 +57,5 @@ def run_from_dump(self, filepath: str, dataset: torch.utils.data.Dataset): self.model.load_state_dict(loaded_data["model_state_dict"]) batch_inputs, batch_targets = ( dataset[i][0] for i in loaded_data["batch_indices"]), ( - dataset[i][1] for i in loaded_data["batch_indices"]) + dataset[i][1] for i in loaded_data["batch_indices"]) return self.run_attack_batch(batch_inputs, batch_targets) diff --git a/gradattack/attacks/gradientinversion.py b/gradattack/attacks/gradientinversion.py index 444dedc..d7a3f41 100644 --- a/gradattack/attacks/gradientinversion.py +++ b/gradattack/attacks/gradientinversion.py @@ -1,19 +1,18 @@ import copy from typing import Any, Callable, Optional -import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F import torchmetrics from gradattack.metrics.gradients import CosineSimilarity, L2Diff from gradattack.metrics.pixelwise import MeanPixelwiseError -from gradattack.models import LightningWrapper from gradattack.trainingpipeline import TrainingPipeline from gradattack.utils import patch_image from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset + # DEFAULT_HPARAMS = { # "optimizer": "Adam", # "lr_scheduler": True, @@ -36,6 +35,7 @@ class BNFeatureHook: Implementation of the forward hook to track feature statistics and compute a loss on them. 
Will compute mean and variance, and will use l2 as a loss """ + def __init__(self, module): self.hook = module.register_forward_hook(self.hook_fn) @@ -91,30 +91,30 @@ def __getitem__(self, idx: int): class GradientReconstructor(pl.LightningModule): def __init__( - self, - pipeline: TrainingPipeline, - ground_truth_inputs: tuple, - ground_truth_gradients: tuple, - ground_truth_labels: tuple, - intial_reconstruction: torch.tensor = None, - reconstruct_labels=False, - attack_loss_metric: Callable = CosineSimilarity(), - mean_std: tuple = (0.0, 1.0), - num_iterations=10000, - optimizer: str = "Adam", - lr_scheduler: bool = True, - lr: float = 0.1, - total_variation: float = 1e-1, - l2: float = 0, - bn_reg: float = 0, - first_bn_multiplier: float = 1, - signed_image: bool = False, - signed_gradients: bool = True, - boxed: bool = True, - attacker_eval_mode: bool = True, - recipe: str = 'Geiping', - BN_exact: bool = False, - grayscale: bool = False, + self, + pipeline: TrainingPipeline, + ground_truth_inputs: tuple, + ground_truth_gradients: tuple, + ground_truth_labels: tuple, + intial_reconstruction: torch.tensor = None, + reconstruct_labels=False, + attack_loss_metric: Callable = CosineSimilarity(), + mean_std: tuple = (0.0, 1.0), + num_iterations=10000, + optimizer: str = "Adam", + lr_scheduler: bool = True, + lr: float = 0.1, + total_variation: float = 1e-1, + l2: float = 0, + bn_reg: float = 0, + first_bn_multiplier: float = 1, + signed_image: bool = False, + signed_gradients: bool = True, + boxed: bool = True, + attacker_eval_mode: bool = True, + recipe: str = 'Geiping', + BN_exact: bool = False, + grayscale: bool = False, ): super().__init__() self.save_hyperparameters("optimizer", "lr_scheduler", "lr", @@ -305,8 +305,8 @@ def _closure(): reconstruction_loss = self._attack_loss_metric( recovered_gradients, input_gradients) reconstruction_loss += self.hparams[ - "total_variation"] * total_variation( - self.best_guess, self.hparams["signed_image"]) + "total_variation"] * total_variation( + self.best_guess, self.hparams["signed_image"]) reconstruction_loss += self.hparams["l2"] * l2_norm( self.best_guess, self.hparams["signed_image"]) elif self.recipe == 'Zhu': ## TODO: test @@ -347,14 +347,14 @@ def _closure(): self.logger.experiment.add_scalar( f"BN_loss/layer_{i}_mean_loss", torch.sqrt( - sum((recon_mean[i] - self.mean_gt[i])**2) / + sum((recon_mean[i] - self.mean_gt[i]) ** 2) / len(recon_mean[i])), global_step=self.global_step, ) self.logger.experiment.add_scalar( f"BN_loss/layer_{i}_var_loss", torch.sqrt( - sum((recon_var[i] - self.var_gt[i])**2) / + sum((recon_var[i] - self.var_gt[i]) ** 2) / len(recon_mean[i])), global_step=self.global_step, ) @@ -438,8 +438,8 @@ def configure_optimizers(self): self.labels = self.labels.to(self.device) if self.grayscale: parameters = ([ - self.best_guess_grayscale, self.labels - ] if self._reconstruct_labels else [self.best_guess_grayscale]) + self.best_guess_grayscale, self.labels + ] if self._reconstruct_labels else [self.best_guess_grayscale]) else: parameters = ([self.best_guess, self.labels] if self._reconstruct_labels else [self.best_guess]) diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index a01091c..b7f372a 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -10,7 +10,7 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from torch.utils.data.sampler import Sampler -from torchvision.datasets.cifar import CIFAR10 +from torchvision.datasets.cifar import 
CIFAR10, CIFAR100 from torchvision.datasets import MNIST DEFAULT_DATA_DIR = "./data" @@ -22,6 +22,15 @@ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ] +DATASET_NORM = { + 'mnist': transforms.Normalize((0.1307,), (0.3081,)), + 'imagenet': transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + 'cifar100': transforms.Normalize((0.50705882, 0.48666667, 0.44078431), + (0.26745098, 0.25647059, 0.27607843)), + 'cifar10': transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)) +} def train_val_split(dataset_size: int, val_train_split: float = 0.02): @@ -35,11 +44,11 @@ def train_val_split(dataset_size: int, val_train_split: float = 0.02): def extract_attack_set( - dataset: Dataset, - sample_per_class: int = 5, - multi_class=False, - total_num_samples: int = 50, - seed: int = None, + dataset: Dataset, + sample_per_class: int = 5, + multi_class=False, + total_num_samples: int = 50, + seed: int = None, ): if not multi_class: num_classes = len(dataset.classes) @@ -66,12 +75,12 @@ def extract_attack_set( class FileDataModule(LightningDataModule): def __init__( - self, - data_dir: str = DEFAULT_DATA_DIR, - transform: torch.nn.Module = transforms.Compose(TRANSFORM_IMAGENET), - batch_size: int = 32, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, + self, + data_dir: str = DEFAULT_DATA_DIR, + transform: torch.nn.Module = transforms.Compose(TRANSFORM_IMAGENET), + batch_size: int = 32, + num_workers: int = DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, ): self.data_dir = data_dir self.batch_size = batch_size @@ -98,13 +107,13 @@ def test_dataloader(self): class ImageNetDataModule(LightningDataModule): def __init__( - self, - augment: dict = None, - data_dir: str = os.path.join(DEFAULT_DATA_DIR, "imagenet"), - batch_size: int = 32, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, - tune_on_val: bool = False, + self, + augment: dict = None, + data_dir: str = os.path.join(DEFAULT_DATA_DIR, "imagenet"), + batch_size: int = 32, + num_workers: int = DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, + tune_on_val: bool = False, ): self.data_dir = data_dir self.batch_size = batch_size @@ -223,14 +232,16 @@ def test_dataloader(self): class MNISTDataModule(LightningDataModule): + DATASET_NAME = 'mnist' + def __init__( - self, - augment: dict = None, - batch_size: int = 32, - data_dir: str = DEFAULT_DATA_DIR, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, - tune_on_val: float = 0, + self, + augment: dict = None, + batch_size: int = 32, + data_dir: str = DEFAULT_DATA_DIR, + num_workers: int = DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, + tune_on_val: float = 0, ): super().__init__() self._has_setup_attack = False @@ -244,8 +255,7 @@ def __init__( self.batch_sampler = batch_sampler self.tune_on_val = tune_on_val self.multi_class = False - - mnist_normalize = transforms.Normalize((0.1307, ), (0.3081, )) + mnist_normalize = DATASET_NORM[self.DATASET_NAME] self._train_transforms = [ transforms.Resize(32), @@ -387,15 +397,17 @@ def test_dataloader(self): class CIFAR10DataModule(LightningDataModule): + DATASET_NAME = 'cifar10' + def __init__( - self, - augment: dict = None, - batch_size: int = 32, - data_dir: str = DEFAULT_DATA_DIR, - num_workers: int = DEFAULT_NUM_WORKERS, - batch_sampler: Sampler = None, - tune_on_val: float = 0, - seed: int = None, + self, + augment: dict = None, + batch_size: int = 32, + data_dir: str = DEFAULT_DATA_DIR, + 
num_workers: int = DEFAULT_NUM_WORKERS, + batch_sampler: Sampler = None, + tune_on_val: float = 0, + seed: int = None, ): super().__init__() self._has_setup_attack = False @@ -410,9 +422,7 @@ def __init__( self.tune_on_val = tune_on_val self.multi_class = False self.seed = seed - - cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)) + cifar_normalize = DATASET_NORM[self.DATASET_NAME] self._train_transforms = [transforms.ToTensor(), cifar_normalize] if augment["hflip"]: @@ -538,3 +548,10 @@ def test_dataloader(self): return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers) +class CIFAR100DataModule(CIFAR10DataModule): + DATASET_NAME = 'cifar100' + + def prepare_data(self): + """Download the data""" + CIFAR100(self.data_dir, train=True, download=True) + CIFAR100(self.data_dir, train=False, download=True) diff --git a/gradattack/defenses/defense_utils.py b/gradattack/defenses/defense_utils.py index 86bec2e..f62a976 100644 --- a/gradattack/defenses/defense_utils.py +++ b/gradattack/defenses/defense_utils.py @@ -1,11 +1,7 @@ -from typing import Any, Dict, Iterable, List, Optional, Union - import torch -from colorama import Back, Fore, Style, init -from torch.utils.data.dataset import Dataset +from colorama import Fore, Style, init from gradattack.trainingpipeline import TrainingPipeline - from .dpsgd import DPSGDDefense from .gradprune import GradPruneDefense from .instahide import InstahideDefense diff --git a/gradattack/defenses/dpsgd.py b/gradattack/defenses/dpsgd.py index 7b8327e..7fde214 100644 --- a/gradattack/defenses/dpsgd.py +++ b/gradattack/defenses/dpsgd.py @@ -1,15 +1,12 @@ # DPSGD defense. The implementation of DPSGD is based on Opacus: https://opacus.ai/ -import time -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import List -import numpy as np import pytorch_lightning as pl -import torch from opacus import PrivacyEngine -from opacus.utils.module_modification import convert_batchnorm_modules from opacus.utils.uniform_sampler import UniformWithReplacementSampler from pytorch_lightning.core.datamodule import LightningDataModule + from gradattack.defenses.defense import GradientDefense from gradattack.models import StepTracker from gradattack.trainingpipeline import TrainingPipeline @@ -31,17 +28,18 @@ class DPSGDDefense(GradientDefense): secure_rng (bool, optional): If on, it will use ``torchcsprng`` for secure random number generation. Comes with a significant performance cost. Defaults to False. freeze_extractor (bool, optional): If on, only finetune the classifier (the final fully-connected layers). Defaults to False. 
""" + def __init__( - self, - mini_batch_size: int, - sample_size: int, - n_accumulation_steps: int, - max_grad_norm: float, - noise_multiplier: float, - delta_list: List[float] = [1e-3, 1e-4, 1e-5], - max_epsilon: float = 2, - secure_rng: bool = False, - freeze_extractor: bool = False, + self, + mini_batch_size: int, + sample_size: int, + n_accumulation_steps: int, + max_grad_norm: float, + noise_multiplier: float, + delta_list: List[float] = [1e-3, 1e-4, 1e-5], + max_epsilon: float = 2, + secure_rng: bool = False, + freeze_extractor: bool = False, ): super().__init__() diff --git a/gradattack/defenses/instahide.py b/gradattack/defenses/instahide.py index 32eb15e..5026ee5 100644 --- a/gradattack/defenses/instahide.py +++ b/gradattack/defenses/instahide.py @@ -1,10 +1,10 @@ import numpy as np import torch +import torchcsprng as csprng from torch.distributions.dirichlet import Dirichlet from torch.nn.functional import one_hot from torch.utils.data.dataset import Dataset -import torchcsprng as csprng from gradattack.defenses.defense import GradientDefense from gradattack.trainingpipeline import TrainingPipeline @@ -124,10 +124,10 @@ def generate_mapping(self, return_tensor=True): return np.asarray(lams), np.asarray(selects) def instahide_batch( - self, - inputs: torch.tensor, - lams_b: float, - selects_b: np.array, + self, + inputs: torch.tensor, + lams_b: float, + selects_b: np.array, ): """Generate an InstaHide batch. diff --git a/gradattack/defenses/mixup.py b/gradattack/defenses/mixup.py index 7b7a956..6e7704d 100644 --- a/gradattack/defenses/mixup.py +++ b/gradattack/defenses/mixup.py @@ -1,10 +1,10 @@ import numpy as np import torch +import torchcsprng as csprng from torch.distributions.dirichlet import Dirichlet from torch.nn.functional import one_hot from torch.utils.data.dataset import Dataset -import torchcsprng as csprng from gradattack.defenses.defense import GradientDefense from gradattack.trainingpipeline import TrainingPipeline @@ -124,10 +124,10 @@ def generate_mapping(self, return_tensor=True): return np.asarray(lams), np.asarray(selects) def mixup_batch( - self, - inputs: torch.tensor, - lams_b: float, - selects_b: np.array, + self, + inputs: torch.tensor, + lams_b: float, + selects_b: np.array, ): """Generate a MixUp batch. diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index b88c6b4..d469548 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -1,14 +1,9 @@ -# FIXME: @Samyak, could you please help add docstring to this file? Thanks! import os -import time -from typing import Any, Callable, Optional +from typing import Callable import pytorch_lightning as pl -import torch.nn.functional as F -import torchvision.models as models from gradattack.utils import StandardizeLayer from sklearn import metrics -from torch.nn import init from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR from .covidmodel import * @@ -38,23 +33,24 @@ def end(self, deduction: float = 0): class LightningWrapper(pl.LightningModule): """Wraps a torch module in a pytorch-lightning module. 
Any .""" + def __init__( - self, - model: torch.nn.Module, - training_loss_metric: Callable = F.mse_loss, - optimizer: str = "SGD", - lr_scheduler: str = "ReduceLROnPlateau", - tune_on_val: float = 0.02, - lr_factor: float = 0.5, - lr_step: int = 10, - batch_size: int = 64, - lr: float = 0.05, - momentum: float = 0.9, - weight_decay: float = 5e-4, - nesterov: bool = False, - log_auc: bool = False, - multi_class: bool = False, - multi_head: bool = False, + self, + model: torch.nn.Module, + training_loss_metric: Callable = F.mse_loss, + optimizer: str = "SGD", + lr_scheduler: str = "ReduceLROnPlateau", + tune_on_val: float = 0.02, + lr_factor: float = 0.5, + lr_step: int = 10, + batch_size: int = 64, + lr: float = 0.05, + momentum: float = 0.9, + weight_decay: float = 5e-4, + nesterov: bool = False, + log_auc: bool = False, + multi_class: bool = False, + multi_head: bool = False, ): super().__init__() # if we didn't copy here, then we would modify the default dict by accident @@ -185,17 +181,17 @@ def training_step(self, batch, batch_idx, *_) -> dict: return training_step_results def get_batch_gradients( - self, - batch: torch.tensor, - batch_idx: int = 0, - create_graph: bool = False, - clone_gradients: bool = True, - apply_transforms=True, - eval_mode: bool = False, - stop_track_bn_stats: bool = True, - BN_exact: bool = False, - attacker: bool = False, - *args, + self, + batch: torch.tensor, + batch_idx: int = 0, + create_graph: bool = False, + clone_gradients: bool = True, + apply_transforms=True, + eval_mode: bool = False, + stop_track_bn_stats: bool = True, + BN_exact: bool = False, + attacker: bool = False, + *args, ): batch = tuple(k.to(self.device) for k in batch) if eval_mode is True: @@ -317,7 +313,7 @@ def configure_lr_scheduler(self): elif self.hparams["lr_scheduler"] == "LambdaLR": self.lr_scheduler = LambdaLR( self.optimizer, - lr_lambda=[lambda epoch: self.hparams["lr_lambda"]**epoch], + lr_lambda=[lambda epoch: self.hparams["lr_lambda"] ** epoch], verbose=True, ) elif self.hparams["lr_scheduler"] == "ReduceLROnPlateau": @@ -465,13 +461,13 @@ def log_aucs(self, outputs, stage="test"): def create_lightning_module( - model_name: str, - num_classes: int, - pretrained: bool = False, - ckpt: str = None, - freeze_extractor: bool = False, - *args, - **kwargs, + model_name: str, + num_classes: int, + pretrained: bool = False, + ckpt: str = None, + freeze_extractor: bool = False, + *args, + **kwargs, ) -> LightningWrapper: if "models" in model_name: # Official models by PyTorch model_name = model_name.replace("models.", "") @@ -524,12 +520,12 @@ def do_freeze_extractor(model): def multihead_accuracy(output, target): prec1 = [] for j in range(output.size(1)): - acc = accuracy(output[:, j], target[:, j], topk=(1, )) + acc = accuracy(output[:, j], target[:, j], topk=(1,)) prec1.append(acc[0]) return torch.mean(torch.Tensor(prec1)) -def accuracy(output, target, topk=(1, ), multi_head=False): +def accuracy(output, target, topk=(1,), multi_head=False): """Computes the precision@k for the specified values of k""" with torch.no_grad(): maxk = max(topk) From a0e68de6ccfaf6da443f8bac1602bb4d7c49852a Mon Sep 17 00:00:00 2001 From: limberc Date: Mon, 2 Jan 2023 16:57:18 +0000 Subject: [PATCH 02/17] CIFAR 100 exp. 
--- examples/attack_cifar100_gradinversion.py | 202 ++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 examples/attack_cifar100_gradinversion.py diff --git a/examples/attack_cifar100_gradinversion.py b/examples/attack_cifar100_gradinversion.py new file mode 100644 index 0000000..3d6dcf3 --- /dev/null +++ b/examples/attack_cifar100_gradinversion.py @@ -0,0 +1,202 @@ +import os +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from torch.nn.modules.loss import CrossEntropyLoss + +from gradattack.attacks.gradientinversion import GradientReconstructor +from gradattack.datamodules import CIFAR100DataModule +from gradattack.defenses.defense_utils import DefensePack +from gradattack.models import create_lightning_module +from gradattack.trainingpipeline import TrainingPipeline +from gradattack.utils import (cross_entropy_for_onehot, parse_args, + patch_image, save_fig) + +cifar100_mean = torch.tensor( + [0.50705882, 0.48666667, 0.44078431]) +cifar100_std = torch.tensor( + [0.26745098, 0.25647059, 0.27607843]) +dm = cifar100_mean[:, None, None] +ds = cifar100_std[:, None, None] + + +def setup_attack(): + """Setup the pipeline for the attack""" + args, hparams, attack_hparams = parse_args() + print(attack_hparams) + + global ROOT_DIR, DEVICE, EPOCH, devices + + DEVICE = torch.device(f"cuda:{args.gpuid}") + EPOCH = attack_hparams["epoch"] + devices = [args.gpuid] + + pl.utilities.seed.seed_everything(1234 + EPOCH) + torch.backends.cudnn.benchmark = True + + BN_str = '' + + if not args.attacker_eval_mode: + BN_str += "-attacker_train" + if not args.defender_eval_mode: + BN_str += '-defender_train' + if args.BN_exact: + BN_str = 'BN_exact' + attack_hparams['attacker_eval_mode'] = False + + datamodule = CIFAR100DataModule(batch_size=args.batch_size, + augment={ + "hflip": False, + "color_jitter": None, + "rotation": -1, + "crop": False + }, + num_workers=1, + seed=args.data_seed) + print("Loaded data!") + if args.defense_instahide or args.defense_mixup: # Customize loss + loss = cross_entropy_for_onehot + else: + loss = CrossEntropyLoss(reduction="mean") + + if args.defense_instahide: + model = create_lightning_module(args.model, + datamodule.num_classes, + training_loss_metric=loss, + pretrained=False, + ckpt="checkpoint/InstaHide_ckpt.ckpt", + **hparams).to(DEVICE) + elif args.defense_mixup: + model = create_lightning_module(args.model, + datamodule.num_classes, + training_loss_metric=loss, + pretrained=False, + ckpt="checkpoint/Mixup_ckpt.ckpt", + **hparams).to(DEVICE) + else: + model = create_lightning_module( + args.model, + datamodule.num_classes, + training_loss_metric=loss, + pretrained=False, + ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt", + **hparams).to(DEVICE) + + logger = TensorBoardLogger("tb_logs", name=f"{args.logname}") + trainer = pl.Trainer(gpus=devices, benchmark=True, logger=logger) + pipeline = TrainingPipeline(model, datamodule, trainer) + + defense_pack = DefensePack(args, logger) + if attack_hparams["mini"]: + datamodule.setup("attack_mini") + elif attack_hparams["large"]: + datamodule.setup("attack_large") + else: + datamodule.setup("attack") + + defense_pack.apply_defense(pipeline) + + ROOT_DIR = f"{args.results_dir}/CIFAR100-{args.batch_size}-{str(defense_pack)}/tv={attack_hparams['total_variation']}{BN_str}-bn={attack_hparams['bn_reg']}-dataseed={args.data_seed}/Epoch_{EPOCH}" + try: + os.makedirs(ROOT_DIR, exist_ok=True) + except FileExistsError: + pass + 
print("storing in root dir", ROOT_DIR) + + if "InstaHideDefense" in defense_pack.defense_params.keys(): + cur_lams = defense_pack.instahide_defense.cur_lams.cpu().numpy() + cur_selects = defense_pack.instahide_defense.cur_selects.cpu().numpy() + np.savetxt(f"{ROOT_DIR}/epoch_lams.txt", cur_lams) + np.savetxt(f"{ROOT_DIR}/epoch_selects.txt", cur_selects.astype(int)) + elif "MixupDefense" in defense_pack.defense_params.keys(): + cur_lams = defense_pack.mixup_defense.cur_lams.cpu().numpy() + cur_selects = defense_pack.mixup_defense.cur_selects.cpu().numpy() + np.savetxt(f"{ROOT_DIR}/epoch_lams.txt", cur_lams) + np.savetxt(f"{ROOT_DIR}/epoch_selects.txt", cur_selects.astype(int)) + + return pipeline, attack_hparams + + +def run_attack(pipeline, attack_hparams): + """Launch the real attack""" + trainloader = pipeline.datamodule.train_dataloader() + model = pipeline.model + + for batch_idx, (batch_inputs, batch_targets) in enumerate(trainloader): + BATCH_ROOT_DIR = ROOT_DIR + f"/{batch_idx}" + os.makedirs(BATCH_ROOT_DIR, exist_ok=True) + save_fig(batch_inputs, + f"{BATCH_ROOT_DIR}/original.png", + save_npy=True, + save_fig=False) + save_fig(patch_image(batch_inputs), + f"{BATCH_ROOT_DIR}/original.png", + save_npy=False) + + batch_inputs, batch_targets = batch_inputs.to( + DEVICE), batch_targets.to(DEVICE) + + batch_gradients, step_results = model.get_batch_gradients( + (batch_inputs, batch_targets), + batch_idx, + eval_mode=attack_hparams["defender_eval_mode"], + apply_transforms=True, + stop_track_bn_stats=False, + BN_exact=attack_hparams["BN_exact"]) + batch_inputs_transform, batch_targets_transform = step_results[ + "transformed_batch"] + + save_fig( + batch_inputs_transform, + f"{BATCH_ROOT_DIR}/transformed.png", + save_npy=True, + save_fig=False, + ) + save_fig( + patch_image(batch_inputs_transform), + f"{BATCH_ROOT_DIR}/transformed.png", + save_npy=False, + ) + + attack = GradientReconstructor( + pipeline, + ground_truth_inputs=batch_inputs_transform, + ground_truth_gradients=batch_gradients, + ground_truth_labels=batch_targets_transform, + reconstruct_labels=attack_hparams["reconstruct_labels"], + num_iterations=10000, + signed_gradients=True, + signed_image=attack_hparams["signed_image"], + boxed=True, + total_variation=attack_hparams["total_variation"], + bn_reg=attack_hparams["bn_reg"], + lr_scheduler=True, + lr=attack_hparams["attack_lr"], + mean_std=(dm, ds), + attacker_eval_mode=attack_hparams["attacker_eval_mode"], + BN_exact=attack_hparams["BN_exact"]) + + tb_logger = TensorBoardLogger(BATCH_ROOT_DIR, name="tb_log") + attack_trainer = pl.Trainer( + gpus=devices, + logger=tb_logger, + max_epochs=1, + benchmark=True, + checkpoint_callback=False, + ) + attack_trainer.fit(attack) + result = attack.best_guess.detach().to("cpu") + + save_fig(result, + f"{BATCH_ROOT_DIR}/reconstructed.png", + save_npy=True, + save_fig=False) + save_fig(patch_image(result), + f"{BATCH_ROOT_DIR}/reconstructed.png", + save_npy=False) + + +if __name__ == "__main__": + pipeline, attack_hparams = setup_attack() + run_attack(pipeline, attack_hparams) From 1df9a5bfbba96d971a448288fcdb92b8655210de Mon Sep 17 00:00:00 2001 From: limberc Date: Mon, 2 Jan 2023 16:57:58 +0000 Subject: [PATCH 03/17] Modify the setup req to python 3.7 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index b52e3ff..5d25544 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ classifiers = [options] packages = find: -python_requires = >=3.8 +python_requires = >=3.7 install_requires 
= colorama==0.4.4 matplotlib==3.4.3 From 06d17f30a0b77441cd1e12236a3493f0d872d051 Mon Sep 17 00:00:00 2001 From: limberc Date: Mon, 2 Jan 2023 17:48:07 +0000 Subject: [PATCH 04/17] Reformat the datamodules.py --- gradattack/datamodules.py | 248 ++++++++++++-------------------------- 1 file changed, 76 insertions(+), 172 deletions(-) diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index b7f372a..1f29051 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -10,8 +10,8 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from torch.utils.data.sampler import Sampler -from torchvision.datasets.cifar import CIFAR10, CIFAR100 from torchvision.datasets import MNIST +from torchvision.datasets.cifar import CIFAR10, CIFAR100 DEFAULT_DATA_DIR = "./data" DEFAULT_NUM_WORKERS = 32 @@ -105,40 +105,44 @@ def test_dataloader(self): return self.get_dataloader() -class ImageNetDataModule(LightningDataModule): +class BaseDataModule(LightningDataModule): def __init__( self, augment: dict = None, - data_dir: str = os.path.join(DEFAULT_DATA_DIR, "imagenet"), + data_dir: str = DEFAULT_DATA_DIR, batch_size: int = 32, num_workers: int = DEFAULT_NUM_WORKERS, batch_sampler: Sampler = None, tune_on_val: bool = False, ): + super().__init__() self.data_dir = data_dir self.batch_size = batch_size self.num_workers = num_workers - self.num_classes = 1000 - self.multi_class = False self.batch_sampler = batch_sampler self.tune_on_val = tune_on_val - + self._train_transforms = self.train_transform(augment) + print(self._train_transforms) + self._test_transforms = self.init_transform print(data_dir) - imagenet_normalize = transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - self._train_transforms = [ + @property + def init_transform(self): + return [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), - imagenet_normalize, + DATASET_NORM[self.DATASET_NAME], ] + + def train_transform(self, augment): + train_transforms = self.init_transform if augment["hflip"]: - self._train_transforms.insert( + train_transforms.insert( 0, transforms.RandomHorizontalFlip(p=0.5)) if augment["color_jitter"] is not None: - self._train_transforms.insert( + train_transforms.insert( 0, transforms.ColorJitter( brightness=augment["color_jitter"][0], @@ -148,20 +152,45 @@ def __init__( ), ) if augment["rotation"] > 0: - self._train_transforms.insert( + train_transforms.insert( 0, transforms.RandomRotation(augment["rotation"])) if augment["crop"]: - self._train_transforms.insert(0, - transforms.RandomCrop(32, padding=4)) + train_transforms.insert(0, transforms.RandomCrop(32, padding=4)) + return train_transforms - print(self._train_transforms) + @property + def num_classes(self): + return None - self._test_transforms = [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - imagenet_normalize, - ] + def train_dataloader(self): + if self.batch_sampler is None: + return DataLoader(self.train_set, + batch_size=self.batch_size, + num_workers=self.num_workers) + else: + return DataLoader( + self.train_set, + batch_sampler=self.batch_sampler, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return DataLoader(self.val_set, + batch_size=self.batch_size, + num_workers=self.num_workers) + + def test_dataloader(self): + return DataLoader(self.test_set, + batch_size=self.batch_size, + num_workers=self.num_workers) + + +class ImageNetDataModule(BaseDataModule): + DATASET_NAME = 'imagenet' + + 
@property + def num_classes(self): + return 1000 def setup(self, stage: Optional[str] = None): """Initialize the dataset based on the stage option ('fit', 'test' or 'attack'): @@ -208,30 +237,8 @@ def setup(self, stage: Optional[str] = None): ori_train_set) self.train_set = Subset(ori_train_set, self.attack_indices) - def train_dataloader(self): - if self.batch_sampler is None: - return DataLoader(self.train_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - else: - return DataLoader( - self.train_set, - batch_sampler=self.batch_sampler, - num_workers=self.num_workers, - ) - - def val_dataloader(self): - return DataLoader(self.val_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - - def test_dataloader(self): - return DataLoader(self.test_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - -class MNISTDataModule(LightningDataModule): +class MNISTDataModule(BaseDataModule): DATASET_NAME = 'mnist' def __init__( @@ -243,56 +250,23 @@ def __init__( batch_sampler: Sampler = None, tune_on_val: float = 0, ): - super().__init__() + super().__init__(augment, data_dir, batch_size, num_workers, batch_sampler, tune_on_val) self._has_setup_attack = False - - self.data_dir = data_dir - self.batch_size = batch_size - self.num_workers = num_workers self.dims = (3, 32, 32) - self.num_classes = 10 - - self.batch_sampler = batch_sampler - self.tune_on_val = tune_on_val self.multi_class = False - mnist_normalize = DATASET_NORM[self.DATASET_NAME] - self._train_transforms = [ + @property + def init_transform(self): + return [ transforms.Resize(32), transforms.Grayscale(3), transforms.ToTensor(), - mnist_normalize, + DATASET_NORM[self.DATASET_NAME], ] - if augment["hflip"]: - self._train_transforms.insert( - 0, transforms.RandomHorizontalFlip(p=0.5)) - if augment["color_jitter"] is not None: - self._train_transforms.insert( - 0, - transforms.ColorJitter( - brightness=augment["color_jitter"][0], - contrast=augment["color_jitter"][1], - saturation=augment["color_jitter"][2], - hue=augment["color_jitter"][3], - ), - ) - if augment["rotation"] > 0: - self._train_transforms.insert( - 0, transforms.RandomRotation(augment["rotation"])) - if augment["crop"]: - self._train_transforms.insert(0, - transforms.RandomCrop(32, padding=4)) - print(self._train_transforms) - - self._test_transforms = [ - transforms.Resize(32), - transforms.Grayscale(3), - transforms.ToTensor(), - mnist_normalize, - ] - - self.prepare_data() + @property + def num_classes(self): + return 10 def prepare_data(self): MNIST(self.data_dir, train=True, download=True) @@ -369,34 +343,8 @@ def setup(self, stage: Optional[str] = None): self.train_set = Subset(ori_train_set, self.attack_indices) self.test_set = Subset(self.test_set, range(100)) - def train_dataloader(self): - if self.batch_sampler is None: - return DataLoader( - self.train_set, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - else: - return DataLoader( - self.train_set, - batch_sampler=self.batch_sampler, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self): - return DataLoader(self.val_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - - def test_dataloader(self): - return DataLoader(self.test_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - -class CIFAR10DataModule(LightningDataModule): +class CIFAR10DataModule(BaseDataModule): DATASET_NAME = 'cifar10' def __init__( @@ -409,52 +357,26 @@ def __init__( tune_on_val: float = 0, seed: int = 
None, ): - super().__init__() + super().__init__(augment, data_dir, batch_size, num_workers, batch_sampler, + tune_on_val) self._has_setup_attack = False + self.seed = seed - self.data_dir = data_dir - self.batch_size = batch_size - self.num_workers = num_workers self.dims = (3, 32, 32) - self.num_classes = 10 - - self.batch_sampler = batch_sampler - self.tune_on_val = tune_on_val self.multi_class = False - self.seed = seed - cifar_normalize = DATASET_NORM[self.DATASET_NAME] - - self._train_transforms = [transforms.ToTensor(), cifar_normalize] - if augment["hflip"]: - self._train_transforms.insert( - 0, transforms.RandomHorizontalFlip(p=0.5)) - if augment["color_jitter"] is not None: - self._train_transforms.insert( - 0, - transforms.ColorJitter( - brightness=augment["color_jitter"][0], - contrast=augment["color_jitter"][1], - saturation=augment["color_jitter"][2], - hue=augment["color_jitter"][3], - ), - ) - if augment["rotation"] > 0: - self._train_transforms.insert( - 0, transforms.RandomRotation(augment["rotation"])) - if augment["crop"]: - self._train_transforms.insert(0, - transforms.RandomCrop(32, padding=4)) + self.prepare_data() - print(self._train_transforms) + @property + def init_transform(self): + return [transforms.ToTensor(), DATASET_NORM[self.DATASET_NAME]] - self._test_transforms = [transforms.ToTensor(), cifar_normalize] - - self.prepare_data() + @property + def num_classes(self): + return 10 def prepare_data(self): """Download the data""" - CIFAR10(self.data_dir, train=True, download=True) - CIFAR10(self.data_dir, train=False, download=True) + CIFAR10(self.data_dir, download=True) def setup(self, stage: Optional[str] = None): """Initialize the dataset based on the stage option ('fit', 'test' or 'attack'): @@ -527,31 +449,13 @@ def setup(self, stage: Optional[str] = None): self.train_set = Subset(ori_train_set, self.attack_indices) self.test_set = Subset(self.test_set, range(100)) - def train_dataloader(self): - if self.batch_sampler is None: - return DataLoader(self.train_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - else: - return DataLoader( - self.train_set, - batch_sampler=self.batch_sampler, - num_workers=self.num_workers, - ) - - def val_dataloader(self): - return DataLoader(self.val_set, - batch_size=self.batch_size, - num_workers=self.num_workers) - def test_dataloader(self): - return DataLoader(self.test_set, - batch_size=self.batch_size, - num_workers=self.num_workers) class CIFAR100DataModule(CIFAR10DataModule): DATASET_NAME = 'cifar100' + def __init__(self): + self.num_classes = 100 + def prepare_data(self): """Download the data""" - CIFAR100(self.data_dir, train=True, download=True) - CIFAR100(self.data_dir, train=False, download=True) + CIFAR100(self.data_dir, download=True) From 26cb79ef285a91ec7dbd846170b8ebc3434f8177 Mon Sep 17 00:00:00 2001 From: limberc Date: Mon, 2 Jan 2023 17:52:50 +0000 Subject: [PATCH 05/17] Implement for CIFAR100. 
--- gradattack/datamodules.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index 1f29051..cf0a893 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -370,6 +370,9 @@ def __init__( def init_transform(self): return [transforms.ToTensor(), DATASET_NORM[self.DATASET_NAME]] + def base_dataset(self, root, **kwargs): + return CIFAR10(root, **kwargs) + @property def num_classes(self): return 10 @@ -388,13 +391,11 @@ def setup(self, stage: Optional[str] = None): stage (Optional[str], optional): stage option. Defaults to None. """ if stage == "fit" or stage is None: - self.train_set = CIFAR10( - self.data_dir, - train=True, - transform=transforms.Compose(self._train_transforms), - ) + self.train_set = self.base_dataset(self.data_dir, + train=True, + transform=transforms.Compose(self._train_transforms)) if self.tune_on_val: - self.val_set = CIFAR10( + self.val_set = self.base_dataset( self.data_dir, train=True, transform=transforms.Compose(self._test_transforms), @@ -404,7 +405,7 @@ def setup(self, stage: Optional[str] = None): self.train_set = Subset(self.train_set, train_indices) self.val_set = Subset(self.val_set, val_indices) else: - self.val_set = CIFAR10( + self.val_set = self.base_dataset( self.data_dir, train=False, transform=transforms.Compose(self._test_transforms), @@ -412,14 +413,14 @@ def setup(self, stage: Optional[str] = None): # Assign test dataset for use in dataloader(s) if stage == "test" or stage is None: - self.test_set = CIFAR10( + self.test_set = self.base_dataset( self.data_dir, train=False, transform=transforms.Compose(self._test_transforms), ) if stage == "attack": - ori_train_set = CIFAR10( + ori_train_set = self.base_dataset( self.data_dir, train=True, transform=transforms.Compose(self._train_transforms), @@ -429,7 +430,7 @@ def setup(self, stage: Optional[str] = None): self.train_set = Subset(ori_train_set, self.attack_indices) self.test_set = Subset(self.test_set, range(100)) elif stage == "attack_mini": - ori_train_set = CIFAR10( + ori_train_set = self.base_dataset( self.data_dir, train=True, transform=transforms.Compose(self._train_transforms), @@ -439,7 +440,7 @@ def setup(self, stage: Optional[str] = None): self.train_set = Subset(ori_train_set, self.attack_indices) self.test_set = Subset(self.test_set, range(100)) elif stage == "attack_large": - ori_train_set = CIFAR10( + ori_train_set = self.base_dataset( self.data_dir, train=True, transform=transforms.Compose(self._train_transforms), @@ -459,3 +460,6 @@ def __init__(self): def prepare_data(self): """Download the data""" CIFAR100(self.data_dir, download=True) + + def base_dataset(self, root, **kwargs): + return CIFAR100(root, **kwargs) From bfe1c19e414bbe4ff7d15b61ba0fcbac9cc46a1b Mon Sep 17 00:00:00 2001 From: limberc Date: Mon, 2 Jan 2023 17:56:26 +0000 Subject: [PATCH 06/17] Implement for CIFAR100. 
--- gradattack/datamodules.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index cf0a893..1df63a0 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -454,12 +454,13 @@ def setup(self, stage: Optional[str] = None): class CIFAR100DataModule(CIFAR10DataModule): DATASET_NAME = 'cifar100' - def __init__(self): - self.num_classes = 100 - def prepare_data(self): """Download the data""" CIFAR100(self.data_dir, download=True) + @property + def num_classes(self): + return 100 + def base_dataset(self, root, **kwargs): return CIFAR100(root, **kwargs) From b98d882e0d1f4833cb3af533a8c7cbee194641fd Mon Sep 17 00:00:00 2001 From: limberc Date: Sat, 7 Jan 2023 00:36:55 +0000 Subject: [PATCH 07/17] Modify to support latest PyTorch Lightning. --- gradattack/trainingpipeline.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/gradattack/trainingpipeline.py b/gradattack/trainingpipeline.py index 055b1bc..200d978 100644 --- a/gradattack/trainingpipeline.py +++ b/gradattack/trainingpipeline.py @@ -4,12 +4,10 @@ class TrainingPipeline: - def __init__( - self, - model: LightningWrapper, - datamodule: pl.LightningDataModule, - trainer: pl.Trainer, - ): + def __init__(self, + model: LightningWrapper, + datamodule: pl.LightningDataModule, + trainer: pl.Trainer): self.model = model self.datamodule = datamodule self.trainer = trainer @@ -20,10 +18,6 @@ def __init__( ) # Modifications to the model architecture, trainable params ... self.datamodule.setup() - # FIXME: @Samyak, are we actually using this funciton? - def log_hparams(self): - self.trainer.logger.log_hyperparams(self.model.hparams) - def setup_pipeline(self): self.datamodule.prepare_data() @@ -42,7 +36,7 @@ def run(self): def test(self): return self.trainer.test( - self.model, test_dataloaders=self.datamodule.test_dataloader()) + self.model, self.datamodule.test_dataloader()) # FIXME: @Samyak, are we actually using this funciton? def get_datamodule_batch(self): From 7250c0cf9c5fe6aec746928a9ae42e6594ecd216 Mon Sep 17 00:00:00 2001 From: limberc Date: Sat, 7 Jan 2023 23:51:46 +0000 Subject: [PATCH 08/17] We don't really call get_datamodule_batch --- gradattack/trainingpipeline.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/gradattack/trainingpipeline.py b/gradattack/trainingpipeline.py index 200d978..66ccc7e 100644 --- a/gradattack/trainingpipeline.py +++ b/gradattack/trainingpipeline.py @@ -37,10 +37,3 @@ def run(self): def test(self): return self.trainer.test( self.model, self.datamodule.test_dataloader()) - - # FIXME: @Samyak, are we actually using this funciton? 
- def get_datamodule_batch(self): - self.datamodule.setup() - trainloader = self.datamodule.train_dataloader() - for batch in trainloader: - return batch From ee56f03b49b15e6531562a58705a7b1cd05fb600 Mon Sep 17 00:00:00 2001 From: limberc Date: Sun, 8 Jan 2023 02:32:53 +0000 Subject: [PATCH 09/17] Update --- gradattack/attacks/attack.py | 2 +- gradattack/defenses/__init__.py | 12 +-- gradattack/defenses/defense_utils.py | 2 - gradattack/models/__init__.py | 131 ++++++++++++--------------- 4 files changed, 65 insertions(+), 82 deletions(-) diff --git a/gradattack/attacks/attack.py b/gradattack/attacks/attack.py index 00ab5bc..ff3ab58 100644 --- a/gradattack/attacks/attack.py +++ b/gradattack/attacks/attack.py @@ -2,8 +2,8 @@ from typing import Callable import torch -from gradattack.trainingpipeline import TrainingPipeline +from gradattack.trainingpipeline import TrainingPipeline from .gradientinversion import GradientReconstructor diff --git a/gradattack/defenses/__init__.py b/gradattack/defenses/__init__.py index 1a6eca9..27c9f33 100644 --- a/gradattack/defenses/__init__.py +++ b/gradattack/defenses/__init__.py @@ -1,6 +1,6 @@ -from .defense import * -from .defense_utils import * -from .dpsgd import * -from .gradprune import * -from .mixup import * -from .instahide import * +from .defense import GradientDefense +from .defense_utils import DefensePack +from .dpsgd import DPSGDDefense +from .gradprune import GradPruneDefense +from .instahide import InstahideDefense +from .mixup import MixupDefense diff --git a/gradattack/defenses/defense_utils.py b/gradattack/defenses/defense_utils.py index f62a976..850f141 100644 --- a/gradattack/defenses/defense_utils.py +++ b/gradattack/defenses/defense_utils.py @@ -15,7 +15,6 @@ class DefensePack: def __init__(self, args, logger=None): self.defense_params = {} self.parse_defense_params(args) - self.logger = logger # this might be useful for logging DP prarameters in the future def apply_defense(self, pipeline: TrainingPipeline): dataset = pipeline.datamodule @@ -104,7 +103,6 @@ def parse_defense_params(self, args, verbose=True): print(Fore.MAGENTA + f"{key}: {val}", end="\t") else: print(Fore.MAGENTA + "None", end="\t") - print() def get_defensepack_str(self): def get_param_str(paramname): diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index d469548..e41c742 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -1,22 +1,23 @@ import os +from pathlib import Path from typing import Callable import pytorch_lightning as pl -from gradattack.utils import StandardizeLayer from sklearn import metrics from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR +from gradattack.utils import StandardizeLayer +from .LeNet import * from .covidmodel import * from .densenet import * from .googlenet import * from .mobilenet import * +from .multihead_resnet import * from .nasnet import * from .resnet import * from .resnext import * from .simple import * from .vgg import * -from .multihead_resnet import * -from .LeNet import * class StepTracker: @@ -51,6 +52,7 @@ def __init__( log_auc: bool = False, multi_class: bool = False, multi_head: bool = False, + save_log: str = 'save_ckpts' ): super().__init__() # if we didn't copy here, then we would modify the default dict by accident @@ -71,6 +73,8 @@ def __init__( self._training_loss_metric = training_loss_metric self._val_loss_metric = training_loss_metric + self._optimizer = optimizer + self._batch_transformations = [] 
self._grad_transformations = [] self._opt_transformations = [] @@ -89,6 +93,8 @@ def __init__( self.multi_class = multi_class self.multi_head = multi_head + self.save_log = save_log + def forward(self, x): if self.multi_head: output = self._model(x) @@ -98,13 +104,6 @@ def forward(self, x): else: return self._model(x) - def should_accumulate(self): - return self.trainer.train_loop.should_accumulate() - - def on_train_epoch_start(self) -> None: - for callback in self._on_train_epoch_start_callbacks: - callback(self) - def _transform_batch(self, batch, batch_idx, *args): for transform in self._batch_transformations: batch = transform(batch, batch_idx, *args) @@ -154,14 +153,15 @@ def training_step(self, batch, batch_idx, *_) -> dict: self.manual_backward(training_step_results["loss"]) - if self.should_accumulate(): - # Special case opacus optimizers to reduce memory footprint - # see: (https://github.com/pytorch/opacus/blob/244265582bffbda956511871a907e5de2c523d86/opacus/privacy_engine.py#L393) - if hasattr(self.optimizer, "virtual_step"): - with torch.no_grad(): - self.optimizer.virtual_step() - else: - self.on_non_accumulate_step() + # if self.should_accumulate(): + # # Special case opacus optimizers to reduce memory footprint + # # see: (https://github.com/pytorch/opacus/blob/244265582bffbda956511871a907e5de2c523d86/opacus/privacy_engine.py#L393) + # if hasattr(self.optimizer, "virtual_step"): + # with torch.no_grad(): + # self.optimizer.virtual_step() + # else: + # self.on_non_accumulate_step() + self.on_non_accumulate_step() if self.log_train_acc: top1_acc = accuracy( @@ -169,30 +169,23 @@ def training_step(self, batch, batch_idx, *_) -> dict: training_step_results["transformed_batch"][1], multi_head=self.multi_head, )[0] - self.log( - "step/train_acc", - top1_acc, - on_step=True, - on_epoch=False, - prog_bar=True, - logger=True, - ) + self.log("step/train_acc", top1_acc, + on_step=True, on_epoch=False, + prog_bar=True, logger=True) return training_step_results - def get_batch_gradients( - self, - batch: torch.tensor, - batch_idx: int = 0, - create_graph: bool = False, - clone_gradients: bool = True, - apply_transforms=True, - eval_mode: bool = False, - stop_track_bn_stats: bool = True, - BN_exact: bool = False, - attacker: bool = False, - *args, - ): + def get_batch_gradients(self, + batch: torch.tensor, + batch_idx: int = 0, + create_graph: bool = False, + clone_gradients: bool = True, + apply_transforms=True, + eval_mode: bool = False, + stop_track_bn_stats: bool = True, + BN_exact: bool = False, + attacker: bool = False, + *args): batch = tuple(k.to(self.device) for k in batch) if eval_mode is True: self.eval() @@ -248,14 +241,11 @@ def on_non_accumulate_step(self) -> None: if self._log_gradients: grad_norm_dict = self.grad_norm(1) for k, v in grad_norm_dict.items(): - self.log( - f"gradients/{k}", - v, - on_step=True, - on_epoch=True, - prog_bar=False, - logger=True, - ) + self.log(f"gradients/{k}", v, + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True) self.optimizer.step() self.optimizer.zero_grad() @@ -266,14 +256,12 @@ def on_non_accumulate_step(self) -> None: self.trainer.should_stop = True self.step_tracker.end() - self.log( - "step/train_loss", - self.step_tracker.cur_loss, - on_step=True, - on_epoch=False, - prog_bar=True, - logger=True, - ) + self.log("step/train_loss", + self.step_tracker.cur_loss, + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True) def configure_optimizers(self): if self.hparams["optimizer"] == "Adam": @@ -369,26 +357,23 @@ 
def validation_epoch_end(self, outputs): self.cur_lr = self.optimizer.param_groups[0]["lr"] - self.log( - "epoch/val_accuracy", - avg_accuracy, - on_epoch=True, - prog_bar=True, - logger=True, - ) - self.log("epoch/val_loss", - avg_loss, - on_epoch=True, - prog_bar=True, + self.log("epoch/val_accuracy", avg_accuracy, + on_epoch=True, prog_bar=True, logger=True) - self.log("epoch/lr", - self.cur_lr, - on_epoch=True, - prog_bar=True, + self.log("epoch/val_loss", avg_loss, + on_epoch=True, prog_bar=True, logger=True) - - for callback in self._epoch_end_callbacks: - callback(self) + self.log("epoch/lr", self.cur_lr, + on_epoch=True, prog_bar=True, + logger=True) + # Mannuly save model here. + path = Path(f"{self.save_log}/checkpoints/") + path.mkdir(parents=True, exist_ok=True) + if os.listdir(f"{self.save_log}/checkpoints/"): + old_ckpt = os.listdir(f"{self.save_log}/checkpoints/")[0] + os.remove(f"{self.save_log}/checkpoints/{old_ckpt}") + torch.save(self._model.state_dict(), + f"{self.save_log}/checkpoints/epoch={self.current_epoch}-val_acc_{avg_accuracy}.ckpt") def test_step(self, batch, batch_idx): x, y = batch From 31915efeb87e0dfa57954313478989d0a103bf3a Mon Sep 17 00:00:00 2001 From: limberc Date: Sun, 19 Mar 2023 12:45:16 +0800 Subject: [PATCH 10/17] Use latest PL support. --- examples/attack_cifar100_gradinversion.py | 13 ++++--- examples/attack_cifar10_gradinversion.py | 13 ++++--- examples/attack_decode.py | 3 +- examples/train_cifar10.py | 27 ++++++++----- gradattack/attacks/gradientinversion.py | 10 ++--- gradattack/datamodules.py | 6 +-- gradattack/defenses/defense.py | 1 + gradattack/models/__init__.py | 47 ++--------------------- gradattack/models/alldnet.py | 1 - gradattack/models/covidmodel.py | 2 - gradattack/models/densenet.py | 2 - gradattack/models/googlenet.py | 4 -- gradattack/models/mobilenet.py | 2 +- gradattack/models/nasnet.py | 7 ++-- gradattack/models/resnet.py | 2 - gradattack/models/resnext.py | 3 ++ gradattack/models/simple.py | 3 -- gradattack/models/vgg.py | 5 +-- gradattack/utils.py | 21 +++++----- requirements.txt | 2 +- test/train_config.py | 4 +- 21 files changed, 66 insertions(+), 112 deletions(-) diff --git a/examples/attack_cifar100_gradinversion.py b/examples/attack_cifar100_gradinversion.py index 3d6dcf3..7d7c37b 100644 --- a/examples/attack_cifar100_gradinversion.py +++ b/examples/attack_cifar100_gradinversion.py @@ -1,4 +1,5 @@ import os + import numpy as np import pytorch_lightning as pl import torch @@ -65,14 +66,14 @@ def setup_attack(): datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/InstaHide_ckpt.ckpt", + # ckpt="checkpoint/InstaHide_ckpt.ckpt", **hparams).to(DEVICE) elif args.defense_mixup: model = create_lightning_module(args.model, datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/Mixup_ckpt.ckpt", + # ckpt="checkpoint/Mixup_ckpt.ckpt", **hparams).to(DEVICE) else: model = create_lightning_module( @@ -80,11 +81,11 @@ def setup_attack(): datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt", + # ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt", **hparams).to(DEVICE) logger = TensorBoardLogger("tb_logs", name=f"{args.logname}") - trainer = pl.Trainer(gpus=devices, benchmark=True, logger=logger) + trainer = pl.Trainer(devices=-1, accelerator="auto", benchmark=True, logger=logger) pipeline = TrainingPipeline(model, datamodule, trainer) defense_pack = DefensePack(args, logger) @@ -176,10 +177,10 @@ def 
run_attack(pipeline, attack_hparams): mean_std=(dm, ds), attacker_eval_mode=attack_hparams["attacker_eval_mode"], BN_exact=attack_hparams["BN_exact"]) - tb_logger = TensorBoardLogger(BATCH_ROOT_DIR, name="tb_log") attack_trainer = pl.Trainer( - gpus=devices, + devices=1, + accelerator="auto", logger=tb_logger, max_epochs=1, benchmark=True, diff --git a/examples/attack_cifar10_gradinversion.py b/examples/attack_cifar10_gradinversion.py index 590fc19..7b623dc 100644 --- a/examples/attack_cifar10_gradinversion.py +++ b/examples/attack_cifar10_gradinversion.py @@ -1,4 +1,5 @@ import os + import numpy as np import pytorch_lightning as pl import torch @@ -32,7 +33,7 @@ def setup_attack(): EPOCH = attack_hparams["epoch"] devices = [args.gpuid] - pl.utilities.seed.seed_everything(1234 + EPOCH) + pl.utilities.seed.seed_everything(42 + EPOCH) torch.backends.cudnn.benchmark = True BN_str = '' @@ -65,14 +66,14 @@ def setup_attack(): datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/InstaHide_ckpt.ckpt", + # ckpt="checkpoint/InstaHide_ckpt.ckpt", **hparams).to(DEVICE) elif args.defense_mixup: model = create_lightning_module("ResNet18", datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/Mixup_ckpt.ckpt", + # ckpt="checkpoint/Mixup_ckpt.ckpt", **hparams).to(DEVICE) else: model = create_lightning_module( @@ -80,11 +81,11 @@ def setup_attack(): datamodule.num_classes, training_loss_metric=loss, pretrained=False, - ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt", + # ckpt="checkpoint/vanilla_epoch=1-step=1531.ckpt", **hparams).to(DEVICE) logger = TensorBoardLogger("tb_logs", name=f"{args.logname}") - trainer = pl.Trainer(gpus=devices, benchmark=True, logger=logger) + trainer = pl.Trainer(devices=-1, accelerator="auto", benchmark=True, logger=logger) pipeline = TrainingPipeline(model, datamodule, trainer) defense_pack = DefensePack(args, logger) @@ -179,7 +180,7 @@ def run_attack(pipeline, attack_hparams): tb_logger = TensorBoardLogger(BATCH_ROOT_DIR, name="tb_log") attack_trainer = pl.Trainer( - gpus=devices, + devices=1, accelerator="auto", logger=tb_logger, max_epochs=1, benchmark=True, diff --git a/examples/attack_decode.py b/examples/attack_decode.py index d750f5f..cc79d32 100644 --- a/examples/attack_decode.py +++ b/examples/attack_decode.py @@ -1,5 +1,6 @@ """ -The decode step for InstaHide and MixUp is based on Carlini et al.'s attack: Is Private Learning Possible with Instance Encoding? (https://arxiv.org/pdf/2011.05315.pdf) +The decode step for InstaHide and MixUp is based on Carlini et al.'s attack: +Is Private Learning Possible with Instance Encoding? 
(https://arxiv.org/pdf/2011.05315.pdf) The implementation heavily relies on their code: https://github.com/carlini/privacy/commit/28b8a80924cf3766ab3230b5976388139ddef295 """ diff --git a/examples/train_cifar10.py b/examples/train_cifar10.py index ec60701..46b6acd 100644 --- a/examples/train_cifar10.py +++ b/examples/train_cifar10.py @@ -1,7 +1,7 @@ import pytorch_lightning as pl import torch from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import WandbLogger from gradattack.datamodules import CIFAR10DataModule from gradattack.defenses.defense_utils import DefensePack @@ -10,18 +10,25 @@ from gradattack.utils import cross_entropy_for_onehot, parse_args, parse_augmentation if __name__ == "__main__": - args, hparams, _ = parse_args() - - logger = TensorBoardLogger( - "tb_logs", name=f"{args.logname}/{args.optimizer}/{args.scheduler}") - devices = [args.gpuid] + method = "" + if args.defense_mixup: + method += 'mixup_' + elif args.defense_instahide: + method += 'instahide_' + elif args.defense_gradprune: + method += 'gradprune_' + logger = WandbLogger( + project='FLock_GradAttack', + name=f"CIFAR10/{method}/{args.scheduler}", + log_model=True + ) if args.early_stopping: early_stop_callback = EarlyStopping( monitor="epoch/val_loss", min_delta=0.00, - patience=args.patience, + patience=20, verbose=False, mode="min", ) @@ -62,8 +69,8 @@ ) trainer = pl.Trainer( - gpus=devices, - check_val_every_n_epoch=1, + devices=1, + check_val_every_n_epoch=3, logger=logger, max_epochs=args.n_epoch, callbacks=[early_stop_callback], @@ -75,4 +82,4 @@ defense_pack.apply_defense(pipeline) pipeline.run() - pipeline.test() \ No newline at end of file + pipeline.test() diff --git a/gradattack/attacks/gradientinversion.py b/gradattack/attacks/gradientinversion.py index d7a3f41..171ac80 100644 --- a/gradattack/attacks/gradientinversion.py +++ b/gradattack/attacks/gradientinversion.py @@ -3,14 +3,14 @@ import pytorch_lightning as pl import torch -import torch.nn.functional as F import torchmetrics +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset + from gradattack.metrics.gradients import CosineSimilarity, L2Diff from gradattack.metrics.pixelwise import MeanPixelwiseError from gradattack.trainingpipeline import TrainingPipeline from gradattack.utils import patch_image -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset # DEFAULT_HPARAMS = { @@ -358,7 +358,7 @@ def _closure(): len(recon_mean[i])), global_step=self.global_step, ) - self.manual_backward(reconstruction_loss, self.optimizer) + self.manual_backward(reconstruction_loss) if self.hparams["signed_gradients"]: if self.grayscale: self.best_guess_grayscale.grad.sign_() @@ -399,7 +399,7 @@ def _closure(): global_step=self.global_step, ) psnrs = [ - torchmetrics.functional.psnr(a, b) + torchmetrics.functional.peak_signal_noise_ratio(a, b) for (a, b) in zip(self.best_guess, self.ground_truth_inputs) ] diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index 1df63a0..ff8e62c 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -428,7 +428,7 @@ def setup(self, stage: Optional[str] = None): self.attack_indices, self.class2attacksample = extract_attack_set( ori_train_set, seed=self.seed) self.train_set = Subset(ori_train_set, self.attack_indices) - self.test_set = Subset(self.test_set, range(100)) + self.test_set = Subset(self.test_set, 
range(200)) elif stage == "attack_mini": ori_train_set = self.base_dataset( self.data_dir, @@ -438,7 +438,7 @@ def setup(self, stage: Optional[str] = None): self.attack_indices, self.class2attacksample = extract_attack_set( ori_train_set, sample_per_class=2) self.train_set = Subset(ori_train_set, self.attack_indices) - self.test_set = Subset(self.test_set, range(100)) + self.test_set = Subset(self.test_set, range(200)) elif stage == "attack_large": ori_train_set = self.base_dataset( self.data_dir, @@ -448,7 +448,7 @@ def setup(self, stage: Optional[str] = None): self.attack_indices, self.class2attacksample = extract_attack_set( ori_train_set, sample_per_class=500) self.train_set = Subset(ori_train_set, self.attack_indices) - self.test_set = Subset(self.test_set, range(100)) + self.test_set = Subset(self.test_set, range(200)) class CIFAR100DataModule(CIFAR10DataModule): diff --git a/gradattack/defenses/defense.py b/gradattack/defenses/defense.py index 141af0d..60ac02b 100644 --- a/gradattack/defenses/defense.py +++ b/gradattack/defenses/defense.py @@ -5,6 +5,7 @@ class GradientDefense: """Applies a gradient defense to a given pipeline. **WARNING** This may modify the pipeline via monkey-patching of defenses! Please use with care.""" + def apply(self, pipeline: TrainingPipeline): assert (self.defense_name not in pipeline.applied_defenses ), f"Tried to apply duplicate defense {self.defense_name}!" diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index e41c742..d4cf515 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -1,7 +1,7 @@ import os -from pathlib import Path from typing import Callable +import numpy as np import pytorch_lightning as pl from sklearn import metrics from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR @@ -169,7 +169,7 @@ def training_step(self, batch, batch_idx, *_) -> dict: training_step_results["transformed_batch"][1], multi_head=self.multi_head, )[0] - self.log("step/train_acc", top1_acc, + self.log("train/acc", top1_acc, on_step=True, on_epoch=False, prog_bar=True, logger=True) @@ -333,47 +333,8 @@ def validation_step(self, batch, batch_idx): pred_list, true_list = auc_list(y_hat, y) else: pred_list, true_list = None, None - return { - "batch/val_loss": loss, - "batch/val_accuracy": top1_acc, - "batch/val_pred_list": pred_list, - "batch/val_true_list": true_list, - } - - def validation_epoch_end(self, outputs): - # outputs is whatever returned in `validation_step` - avg_loss = torch.stack([x["batch/val_loss"] for x in outputs]).mean() - avg_accuracy = torch.stack([x["batch/val_accuracy"] - for x in outputs]).mean() - if self.log_auc: - self.log_aucs(outputs, stage="val") - - self.current_val_loss = avg_loss - if self.current_epoch > 0: - if self.hparams["lr_scheduler"] == "ReduceLROnPlateau": - self.lr_scheduler.step(self.current_val_loss) - else: - self.lr_scheduler.step() - - self.cur_lr = self.optimizer.param_groups[0]["lr"] - - self.log("epoch/val_accuracy", avg_accuracy, - on_epoch=True, prog_bar=True, - logger=True) - self.log("epoch/val_loss", avg_loss, - on_epoch=True, prog_bar=True, - logger=True) - self.log("epoch/lr", self.cur_lr, - on_epoch=True, prog_bar=True, - logger=True) - # Mannuly save model here. 
- path = Path(f"{self.save_log}/checkpoints/") - path.mkdir(parents=True, exist_ok=True) - if os.listdir(f"{self.save_log}/checkpoints/"): - old_ckpt = os.listdir(f"{self.save_log}/checkpoints/")[0] - os.remove(f"{self.save_log}/checkpoints/{old_ckpt}") - torch.save(self._model.state_dict(), - f"{self.save_log}/checkpoints/epoch={self.current_epoch}-val_acc_{avg_accuracy}.ckpt") + self.log('val/loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + self.log('val/acc', top1_acc, on_epoch=True, prog_bar=True, logger=True) def test_step(self, batch, batch_idx): x, y = batch diff --git a/gradattack/models/alldnet.py b/gradattack/models/alldnet.py index fb071bc..7fedb92 100755 --- a/gradattack/models/alldnet.py +++ b/gradattack/models/alldnet.py @@ -1,7 +1,6 @@ """LeNet in PyTorch.""" import torch.nn as nn import torch.nn.functional as F -from torch.autograd import Variable class AllDNet(nn.Module): diff --git a/gradattack/models/covidmodel.py b/gradattack/models/covidmodel.py index 9331a6f..46ea61b 100644 --- a/gradattack/models/covidmodel.py +++ b/gradattack/models/covidmodel.py @@ -1,5 +1,3 @@ -import numpy as np -import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models diff --git a/gradattack/models/densenet.py b/gradattack/models/densenet.py index 935470d..9157815 100755 --- a/gradattack/models/densenet.py +++ b/gradattack/models/densenet.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from torch.autograd import Variable @@ -128,5 +127,4 @@ def test_densenet(): y = net(Variable(x)) print(y) - # test_densenet() diff --git a/gradattack/models/googlenet.py b/gradattack/models/googlenet.py index 041baab..38371b8 100755 --- a/gradattack/models/googlenet.py +++ b/gradattack/models/googlenet.py @@ -1,9 +1,6 @@ """GoogLeNet with PyTorch.""" import torch import torch.nn as nn -import torch.nn.functional as F - -from torch.autograd import Variable class Inception(nn.Module): @@ -100,7 +97,6 @@ def forward(self, x): out = self.linear(out) return out - # net = GoogLeNet() # x = torch.randn(1,3,32,32) # y = net(Variable(x)) diff --git a/gradattack/models/mobilenet.py b/gradattack/models/mobilenet.py index 96a5638..8d47cb6 100755 --- a/gradattack/models/mobilenet.py +++ b/gradattack/models/mobilenet.py @@ -12,6 +12,7 @@ class Block(nn.Module): """Depthwise conv + Pointwise conv""" + def __init__(self, in_planes, out_planes, stride=1): super(Block, self).__init__() self.conv1 = nn.Conv2d( @@ -92,5 +93,4 @@ def test(): y = net(Variable(x)) print(y.size()) - # test() diff --git a/gradattack/models/nasnet.py b/gradattack/models/nasnet.py index e204816..541aac8 100644 --- a/gradattack/models/nasnet.py +++ b/gradattack/models/nasnet.py @@ -10,7 +10,6 @@ class SeperableConv2d(nn.Module): def __init__(self, input_channels, output_channels, kernel_size, **kwargs): - super().__init__() self.depthwise = nn.Conv2d(input_channels, input_channels, @@ -63,6 +62,7 @@ class Fit(nn.Module): prev_filters: filter number of tensor prev, needs to be modified filters: filter number of normal cell branch output filters """ + def __init__(self, prev_filters, filters): super().__init__() self.relu = nn.ReLU() @@ -242,7 +242,8 @@ def forward(self, x): return ( torch.cat( [ - layer1block2, # https://github.com/keras-team/keras-applications/blob/master/keras_applications/nasnet.py line 739 + layer1block2, + # https://github.com/keras-team/keras-applications/blob/master/keras_applications/nasnet.py line 739 layer1block3, layer2block1, 
layer2block2, @@ -314,7 +315,6 @@ def _make_layers(self, repeat_cell_num, reduction_num): layers = [] for i in range(reduction_num): - layers.extend( self._make_normal(NormalCell, repeat_cell_num, self.filters)) self.filters *= 2 @@ -339,6 +339,5 @@ def forward(self, x): def nasnet(num_classes=100): - # stem filters must be 44, it's a pytorch workaround, cant change to other number return NasNetA(4, 2, 44, 44, num_classes=num_classes) diff --git a/gradattack/models/resnet.py b/gradattack/models/resnet.py index 2089512..9d91292 100755 --- a/gradattack/models/resnet.py +++ b/gradattack/models/resnet.py @@ -12,7 +12,6 @@ import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable -import torchvision def conv3x3(in_planes, out_planes, stride=1): @@ -243,5 +242,4 @@ def test(): y = net(Variable(torch.randn(1, 3, 32, 32))) print(y.size()) - # test() diff --git a/gradattack/models/resnext.py b/gradattack/models/resnext.py index 805e024..1a75a98 100755 --- a/gradattack/models/resnext.py +++ b/gradattack/models/resnext.py @@ -1,4 +1,5 @@ from __future__ import division + """ Creates a ResNeXt Model as defined in: Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). @@ -15,6 +16,7 @@ class ResNeXtBottleneck(nn.Module): """ RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) """ + def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor): """Constructor @@ -95,6 +97,7 @@ class CifarResNeXt(nn.Module): ResNext optimized for the Cifar dataset, as specified in https://arxiv.org/pdf/1611.05431.pdf """ + def __init__(self, cardinality, depth, diff --git a/gradattack/models/simple.py b/gradattack/models/simple.py index b86341a..db54fd0 100644 --- a/gradattack/models/simple.py +++ b/gradattack/models/simple.py @@ -1,6 +1,3 @@ -import random - -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F diff --git a/gradattack/models/vgg.py b/gradattack/models/vgg.py index a0f8759..0c851a1 100755 --- a/gradattack/models/vgg.py +++ b/gradattack/models/vgg.py @@ -1,12 +1,10 @@ """VGG11/13/16/19 in Pytorch.""" -import torch import torch.nn as nn -from torch.autograd import Variable cfg = { "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG13": - [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG16": [ 64, 64, @@ -81,7 +79,6 @@ def _make_layers(self, cfg): layers += [nn.AvgPool2d(kernel_size=1, stride=1)] return nn.Sequential(*layers) - # net = VGG('VGG11') # x = torch.randn(2,3,32,32) # print(net(Variable(x)).size()) diff --git a/gradattack/utils.py b/gradattack/utils.py index 6cdc822..bae2a47 100755 --- a/gradattack/utils.py +++ b/gradattack/utils.py @@ -14,15 +14,14 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.init as init from torch.nn.functional import log_softmax def parse_args(): - parser = argparse.ArgumentParser(description="gradattack training") + parser = argparse.ArgumentParser(description="GradAttack training") parser.add_argument("--gpuid", default="0", type=int, help="gpu id to use") parser.add_argument("--model", - default="ResNet18", + default="ResNet34", type=str, help="name of model") parser.add_argument("--data", @@ -47,7 +46,7 @@ def parse_args(): default="SGD", type=str, help="which optimizer") - parser.add_argument("--lr", default=0.05, type=float, help="initial lr") + 
parser.add_argument("--lr", default=0.4, type=float, help="initial lr") parser.add_argument("--decay", default=5e-4, type=float, @@ -67,7 +66,7 @@ def parse_args(): type=float, help="lambda of LambdaLR scheduler") parser.add_argument("--lr_factor", - default=0.5, + default=0.2, type=float, help="factor of lr reduction") parser.add_argument("--disable_early_stopping", @@ -255,7 +254,7 @@ def parse_args(): hparams["lr_step"] = args.lr_step hparams["lr_factor"] = args.lr_factor elif args.scheduler == "MultiStepLR": - hparams["lr_step"] = [100, 150] + hparams["lr_step"] = [5, 60, 120, 160] hparams["lr_factor"] = args.lr_factor elif args.scheduler == "LambdaLR": hparams["lr_lambda"] = args.lr_lambda @@ -287,15 +286,15 @@ def parse_args(): def parse_augmentation(args): return { "hflip": - args.aug_hflip, + args.aug_hflip, "crop": - args.aug_crop, + args.aug_crop, "rotation": - args.aug_rotation, + args.aug_rotation, "color_jitter": [float(i) for i in args.aug_colorjitter] if args.aug_colorjitter is not None else None, "affine": - args.aug_affine, + args.aug_affine, } @@ -386,7 +385,7 @@ def patch_image(x, dim=(32, 32)): x = np.append(x, np.zeros((pad_size, *x[0].shape)), axis=0) batch_size = len(x) x = np.transpose(x, (0, 2, 3, 1)) - if int(np.sqrt(batch_size))**2 == batch_size: + if int(np.sqrt(batch_size)) ** 2 == batch_size: s = int(np.sqrt(batch_size)) x = np.reshape(x, (s, s, dim[0], dim[1], 3)) x = np.concatenate(x, axis=2) diff --git a/requirements.txt b/requirements.txt index 915ae65..f263b7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,6 @@ opacus==0.14.0 pytorch_lightning==1.5.1 scikit_learn==1.0.1 torch==1.8.1 -torchcsprng==0.2.1 +torchcsprng==0.3.0 torchmetrics==0.6.0 torchvision==0.9.1 diff --git a/test/train_config.py b/test/train_config.py index d04ea26..4afd7fe 100644 --- a/test/train_config.py +++ b/test/train_config.py @@ -1,5 +1,3 @@ -from argparse import Namespace - vanilla_hparams = { "optimizer": "SGD", "lr": 0.05, @@ -10,4 +8,4 @@ "tune_on_val": 0.02, "batch_size": 128, "lr_factor": 0.5, -} \ No newline at end of file +} From 6b05089d5634df5f64025f330f8348e5228f698a Mon Sep 17 00:00:00 2001 From: limberc Date: Sun, 19 Mar 2023 12:55:39 +0800 Subject: [PATCH 11/17] Disable torchcsprng. 
--- examples/train_cifar10.py | 12 ++++++++---- gradattack/defenses/instahide.py | 11 +++++------ gradattack/defenses/mixup.py | 14 +++++++------- setup.cfg | 11 ----------- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/examples/train_cifar10.py b/examples/train_cifar10.py index 46b6acd..44c9979 100644 --- a/examples/train_cifar10.py +++ b/examples/train_cifar10.py @@ -1,5 +1,6 @@ import pytorch_lightning as pl import torch +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.loggers import WandbLogger @@ -19,20 +20,23 @@ elif args.defense_gradprune: method += 'gradprune_' logger = WandbLogger( - project='FLock_GradAttack', + project='GradAttack', name=f"CIFAR10/{method}/{args.scheduler}", log_model=True ) if args.early_stopping: early_stop_callback = EarlyStopping( - monitor="epoch/val_loss", + monitor="val/loss", min_delta=0.00, patience=20, verbose=False, mode="min", ) + checkpoint_callback = ModelCheckpoint( + ) + augment = parse_augmentation(args) assert args.data == "CIFAR10" @@ -47,11 +51,11 @@ if args.defense_instahide or args.defense_mixup: loss = cross_entropy_for_onehot else: - loss = torch.nn.CrossEntropyLoss(reduction="mean") + loss = torch.nn.CrossEntropyLoss() if "multihead" in args.model: multi_head = True - loss = torch.nn.CrossEntropyLoss(reduction="mean") + loss = torch.nn.CrossEntropyLoss() else: multi_head = False diff --git a/gradattack/defenses/instahide.py b/gradattack/defenses/instahide.py index 5026ee5..0c2e6e6 100644 --- a/gradattack/defenses/instahide.py +++ b/gradattack/defenses/instahide.py @@ -1,6 +1,5 @@ import numpy as np import torch -import torchcsprng as csprng from torch.distributions.dirichlet import Dirichlet from torch.nn.functional import one_hot from torch.utils.data.dataset import Dataset @@ -51,11 +50,11 @@ def __init__(self, torch.tensor(self.alpha).repeat(self.dataset_size, 1)) self.use_csprng = use_csprng - if self.use_csprng: - if cs_prng is None: - self.cs_prng = csprng.create_random_device_generator() - else: - self.cs_prng = cs_prng + # if self.use_csprng: + # if cs_prng is None: + # self.cs_prng = csprng.create_random_device_generator() + # else: + # self.cs_prng = cs_prng # @profile def generate_mapping(self, return_tensor=True): diff --git a/gradattack/defenses/mixup.py b/gradattack/defenses/mixup.py index 6e7704d..0c85edf 100644 --- a/gradattack/defenses/mixup.py +++ b/gradattack/defenses/mixup.py @@ -1,6 +1,6 @@ import numpy as np import torch -import torchcsprng as csprng +# import torchcsprng as csprng from torch.distributions.dirichlet import Dirichlet from torch.nn.functional import one_hot from torch.utils.data.dataset import Dataset @@ -50,12 +50,12 @@ def __init__(self, self.lambda_sampler_whole = Dirichlet( torch.tensor(self.alpha).repeat(self.dataset_size, 1)) self.use_csprng = use_csprng - - if self.use_csprng: - if cs_prng is None: - self.cs_prng = csprng.create_random_device_generator() - else: - self.cs_prng = cs_prng + # + # if self.use_csprng: + # if cs_prng is None: + # self.cs_prng = csprng.create_random_device_generator() + # else: + # self.cs_prng = cs_prng # @profile def generate_mapping(self, return_tensor=True): diff --git a/setup.cfg b/setup.cfg index 5d25544..38e0c5e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,14 +17,3 @@ classifiers = [options] packages = find: python_requires = >=3.7 -install_requires = - colorama==0.4.4 - matplotlib==3.4.3 - numpy==1.21.4 - opacus==0.14.0 - 
pytorch_lightning==1.3.1 - scikit_learn==1.0.1 - torch==1.8.1 - torchcsprng==0.2.1 - torchmetrics==0.6.0 - torchvision==0.9.1 From 70dbf0d0eecdb11898481a0aaf8dcc156f1a6a84 Mon Sep 17 00:00:00 2001 From: limberc Date: Sun, 19 Mar 2023 13:05:26 +0800 Subject: [PATCH 12/17] Fix typos. --- gradattack/defenses/instahide.py | 2 +- gradattack/defenses/mixup.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/gradattack/defenses/instahide.py b/gradattack/defenses/instahide.py index 0c2e6e6..6ccd3a1 100644 --- a/gradattack/defenses/instahide.py +++ b/gradattack/defenses/instahide.py @@ -138,7 +138,7 @@ def instahide_batch( Returns: (torch.tensor): the InstaHide images and labels """ - mixed_x = torch.zeros_like(inputs) + mixed_x = torch.zeros_like(inputs, device=self.device) mixed_y = torch.zeros((len(inputs), self.num_classes), device=self.device) diff --git a/gradattack/defenses/mixup.py b/gradattack/defenses/mixup.py index 0c85edf..b3f76ce 100644 --- a/gradattack/defenses/mixup.py +++ b/gradattack/defenses/mixup.py @@ -123,12 +123,7 @@ def generate_mapping(self, return_tensor=True): else: return np.asarray(lams), np.asarray(selects) - def mixup_batch( - self, - inputs: torch.tensor, - lams_b: float, - selects_b: np.array, - ): + def mixup_batch(self, inputs: torch.tensor, lams_b: float, selects_b: np.array): """Generate a MixUp batch. Args: @@ -139,7 +134,7 @@ def mixup_batch( Returns: (torch.tensor): the MixUp images and labels """ - mixed_x = torch.zeros_like(inputs) + mixed_x = torch.zeros_like(inputs, device=self.device) mixed_y = torch.zeros((len(inputs), self.num_classes), device=self.device) From 9832eb893a47cee08b38d784410e3afe30d42bd5 Mon Sep 17 00:00:00 2001 From: limberc Date: Sun, 19 Mar 2023 13:46:28 +0800 Subject: [PATCH 13/17] Update. 
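This generalizes the CIFAR-10 example to CIFAR-100 by dispatching --data through a small datamodule registry and moves checkpointing and early stopping to standard Lightning callbacks. The monitored key is "val/loss_epoch" because val/loss is logged with both on_step=True and on_epoch=True, so Lightning publishes the epoch-averaged copy under the "_epoch" suffix. Sketch of the callback wiring (directory name and values are illustrative, following the script's exp_name pattern):

    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    exp_name = "CIFAR10/vanilla_ReduceLROnPlateau"  # f"{args.data}/{method}{args.scheduler}"
    early_stop = EarlyStopping(monitor="val/loss_epoch", min_delta=0.0,
                               patience=20, mode="min")
    checkpoint = ModelCheckpoint(dirpath=exp_name, save_last=True,
                                 save_top_k=3, monitor="val/acc", mode="max")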
--- examples/{train_cifar10.py => train_cifar.py} | 44 +++++++++++-------- gradattack/datamodules.py | 7 ++- gradattack/models/__init__.py | 16 +++---- gradattack/trainingpipeline.py | 2 +- 4 files changed, 39 insertions(+), 30 deletions(-) rename examples/{train_cifar10.py => train_cifar.py} (69%) diff --git a/examples/train_cifar10.py b/examples/train_cifar.py similarity index 69% rename from examples/train_cifar10.py rename to examples/train_cifar.py index 44c9979..5620518 100644 --- a/examples/train_cifar10.py +++ b/examples/train_cifar.py @@ -1,15 +1,20 @@ import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from gradattack.datamodules import CIFAR10DataModule +from gradattack.datamodules import CIFAR100DataModule, CIFAR10DataModule from gradattack.defenses.defense_utils import DefensePack from gradattack.models import create_lightning_module from gradattack.trainingpipeline import TrainingPipeline from gradattack.utils import cross_entropy_for_onehot, parse_args, parse_augmentation +torch.set_float32_matmul_precision("high") +cifar_dm = { + "CIFAR10": CIFAR10DataModule, + "CIFAR100": CIFAR100DataModule +} + if __name__ == "__main__": args, hparams, _ = parse_args() method = "" @@ -19,29 +24,29 @@ method += 'instahide_' elif args.defense_gradprune: method += 'gradprune_' + else: + method += 'vanilla' + exp_name = f"{args.data}/{method}{args.scheduler}" logger = WandbLogger( project='GradAttack', - name=f"CIFAR10/{method}/{args.scheduler}", + name=exp_name, log_model=True ) - if args.early_stopping: - early_stop_callback = EarlyStopping( - monitor="val/loss", - min_delta=0.00, - patience=20, - verbose=False, - mode="min", - ) - - checkpoint_callback = ModelCheckpoint( + early_stop_callback = EarlyStopping( + monitor="val/loss_epoch", + min_delta=0.00, + patience=20, + verbose=False, + mode="min", + ) + callback = ModelCheckpoint( + exp_name, save_last=True, save_top_k=3, monitor="val/acc", mode="max", ) augment = parse_augmentation(args) - assert args.data == "CIFAR10" - - datamodule = CIFAR10DataModule( + datamodule = cifar_dm[args.data]( augment=augment, batch_size=args.batch_size, tune_on_val=args.tune_on_val, @@ -75,9 +80,12 @@ trainer = pl.Trainer( devices=1, check_val_every_n_epoch=3, + accelerator='auto', + benchmark=True, logger=logger, + num_sanity_val_steps=0, max_epochs=args.n_epoch, - callbacks=[early_stop_callback], + callbacks=[early_stop_callback, callback], accumulate_grad_batches=args.n_accumulation_steps, ) pipeline = TrainingPipeline(model, datamodule, trainer) diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index ff8e62c..3385f17 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -172,17 +172,20 @@ def train_dataloader(self): self.train_set, batch_sampler=self.batch_sampler, num_workers=self.num_workers, + pin_memory=True ) def val_dataloader(self): return DataLoader(self.val_set, batch_size=self.batch_size, - num_workers=self.num_workers) + num_workers=self.num_workers, + pin_memory=True) def test_dataloader(self): return DataLoader(self.test_set, batch_size=self.batch_size, - num_workers=self.num_workers) + num_workers=self.num_workers, + pin_memory=True) class ImageNetDataModule(BaseDataModule): diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index 
d4cf515..c1b438e 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -406,15 +406,13 @@ def log_aucs(self, outputs, stage="test"): ) -def create_lightning_module( - model_name: str, - num_classes: int, - pretrained: bool = False, - ckpt: str = None, - freeze_extractor: bool = False, - *args, - **kwargs, -) -> LightningWrapper: +def create_lightning_module(model_name: str, + num_classes: int, + pretrained: bool = False, + ckpt: str = None, + freeze_extractor: bool = False, + *args, + **kwargs) -> LightningWrapper: if "models" in model_name: # Official models by PyTorch model_name = model_name.replace("models.", "") if pretrained is False: diff --git a/gradattack/trainingpipeline.py b/gradattack/trainingpipeline.py index 66ccc7e..bb521ef 100644 --- a/gradattack/trainingpipeline.py +++ b/gradattack/trainingpipeline.py @@ -32,7 +32,7 @@ def setup_pipeline(self): def run(self): self.setup_pipeline() # If we didn't call setup(), any updates to transforms (e.g. from defenses) wouldn't be applied - return self.trainer.fit(self.model, self.datamodule) + return self.trainer.fit(self.model, datamodule=self.datamodule) def test(self): return self.trainer.test( From 2a1d1b9738fc4498f56a6b301be6598133cd69db Mon Sep 17 00:00:00 2001 From: limberc Date: Sun, 19 Mar 2023 13:53:21 +0800 Subject: [PATCH 14/17] Update. --- gradattack/models/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index c1b438e..a385eec 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -93,8 +93,6 @@ def __init__( self.multi_class = multi_class self.multi_head = multi_head - self.save_log = save_log - def forward(self, x): if self.multi_head: output = self._model(x) @@ -329,13 +327,14 @@ def validation_step(self, batch, batch_idx): else: loss = self._val_loss_metric(y_hat, y) top1_acc = accuracy(y_hat, y, multi_head=self.multi_head)[0] - if self.log_auc: - pred_list, true_list = auc_list(y_hat, y) - else: - pred_list, true_list = None, None + self.log('val/loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) self.log('val/acc', top1_acc, on_epoch=True, prog_bar=True, logger=True) + def validation_epoch_end(self, outputs): + for callback in self._epoch_end_callbacks: + callback(self) + def test_step(self, batch, batch_idx): x, y = batch y_hat = self.forward(x) From b385ee3a93bca7a327633994e9b12dcff11bdbbe Mon Sep 17 00:00:00 2001 From: limberc Date: Sun, 19 Mar 2023 15:06:06 +0800 Subject: [PATCH 15/17] Currently, there is no way to save the model automatically --- examples/train_cifar.py | 5 +- gradattack/defenses/defense_utils.py | 2 +- gradattack/defenses/mixup.py | 3 +- gradattack/models/__init__.py | 144 ++++----------------------- 4 files changed, 23 insertions(+), 131 deletions(-) diff --git a/examples/train_cifar.py b/examples/train_cifar.py index 5620518..e1d713f 100644 --- a/examples/train_cifar.py +++ b/examples/train_cifar.py @@ -30,7 +30,6 @@ logger = WandbLogger( project='GradAttack', name=exp_name, - log_model=True ) early_stop_callback = EarlyStopping( @@ -74,10 +73,12 @@ log_auc=args.log_auc, multi_class=datamodule.multi_class, multi_head=multi_head, + log_dir=exp_name, **hparams, ) trainer = pl.Trainer( + default_root_dir=exp_name, devices=1, check_val_every_n_epoch=3, accelerator='auto', @@ -90,7 +91,7 @@ ) pipeline = TrainingPipeline(model, datamodule, trainer) - defense_pack = DefensePack(args, logger) + defense_pack = DefensePack(args) 
defense_pack.apply_defense(pipeline) pipeline.run() diff --git a/gradattack/defenses/defense_utils.py b/gradattack/defenses/defense_utils.py index 850f141..a5351e3 100644 --- a/gradattack/defenses/defense_utils.py +++ b/gradattack/defenses/defense_utils.py @@ -12,7 +12,7 @@ class DefensePack: - def __init__(self, args, logger=None): + def __init__(self, args): self.defense_params = {} self.parse_defense_params(args) diff --git a/gradattack/defenses/mixup.py b/gradattack/defenses/mixup.py index b3f76ce..2f8deb9 100644 --- a/gradattack/defenses/mixup.py +++ b/gradattack/defenses/mixup.py @@ -99,8 +99,7 @@ def generate_mapping(self, return_tensor=True): lams = self.lambda_sampler_whole.sample().to(self.device) selects = torch.stack([ torch.randperm(self.dataset_size, - device=self.device, - generator=self.cs_prng) + device=self.device) for _ in range(self.klam) ]) selects = torch.transpose(selects, 0, 1) diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index a385eec..fadfe84 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -1,9 +1,7 @@ import os from typing import Callable -import numpy as np import pytorch_lightning as pl -from sklearn import metrics from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR from gradattack.utils import StandardizeLayer @@ -52,7 +50,7 @@ def __init__( log_auc: bool = False, multi_class: bool = False, multi_head: bool = False, - save_log: str = 'save_ckpts' + log_dir: str = None ): super().__init__() # if we didn't copy here, then we would modify the default dict by accident @@ -92,6 +90,7 @@ def __init__( self.log_auc = log_auc self.multi_class = multi_class self.multi_head = multi_head + self.log_dir = log_dir def forward(self, x): if self.multi_head: @@ -107,6 +106,10 @@ def _transform_batch(self, batch, batch_idx, *args): batch = transform(batch, batch_idx, *args) return batch + def on_train_epoch_start(self) -> None: + for callback in self._on_train_epoch_start_callbacks: + callback(self) + def _transform_gradients(self): for transform in self._grad_transformations: self._model = transform(self._model) @@ -161,77 +164,15 @@ def training_step(self, batch, batch_idx, *_) -> dict: # self.on_non_accumulate_step() self.on_non_accumulate_step() - if self.log_train_acc: - top1_acc = accuracy( - training_step_results["model_outputs"], - training_step_results["transformed_batch"][1], - multi_head=self.multi_head, - )[0] - self.log("train/acc", top1_acc, - on_step=True, on_epoch=False, - prog_bar=True, logger=True) + top1_acc = accuracy( + training_step_results["model_outputs"], + training_step_results["transformed_batch"][1], + multi_head=self.multi_head, + )[0] + self.log("train/acc", top1_acc, on_epoch=True, logger=True) return training_step_results - def get_batch_gradients(self, - batch: torch.tensor, - batch_idx: int = 0, - create_graph: bool = False, - clone_gradients: bool = True, - apply_transforms=True, - eval_mode: bool = False, - stop_track_bn_stats: bool = True, - BN_exact: bool = False, - attacker: bool = False, - *args): - batch = tuple(k.to(self.device) for k in batch) - if eval_mode is True: - self.eval() - else: - self.train() - - if BN_exact: - for module in self._model.modules(): - if isinstance(module, torch.nn.BatchNorm2d): - if not attacker: - module.reset_running_stats() # reset BN statistics - module.momentum = ( - 1 # save current BN statistics as running statistics - ) - if attacker: - self.training = False # set BN module to eval mode - module.momentum 
= 0 # stop tracking BN statistics - if hasattr(module, "weight"): - module.weight.requires_grad_(True) - if hasattr(module, "bias"): - module.bias.requires_grad_(True) - - if stop_track_bn_stats: - for module in self._model.modules(): - if isinstance(module, torch.nn.BatchNorm2d): - module.momentum = 0 # Stop tracking running mean and std any more - - self.zero_grad() - training_step_results = self._compute_training_step( - batch, batch_idx, apply_batch_transforms=apply_transforms, *args) - - # Make sure to apply transformations to gradients - if apply_transforms: - training_step_results["loss"].backward() - self._transform_gradients() - # Clone to prevent the gradients from being changed by training - batch_gradients = tuple( - p.grad.clone() if clone_gradients is True else p.grad - for p in self.parameters()) - else: - batch_gradients = torch.autograd.grad( - training_step_results["loss"], - self._model.parameters(), - create_graph=create_graph, - ) - - return batch_gradients, training_step_results - def on_non_accumulate_step(self) -> None: # This hook runs only after accumulation self._transform_gradients() @@ -240,9 +181,7 @@ def on_non_accumulate_step(self) -> None: grad_norm_dict = self.grad_norm(1) for k, v in grad_norm_dict.items(): self.log(f"gradients/{k}", v, - on_step=True, on_epoch=True, - prog_bar=False, logger=True) self.optimizer.step() @@ -254,11 +193,9 @@ def on_non_accumulate_step(self) -> None: self.trainer.should_stop = True self.step_tracker.end() - self.log("step/train_loss", + self.log("train/loss", self.step_tracker.cur_loss, - on_step=True, - on_epoch=False, - prog_bar=True, + on_epoch=True, logger=True) def configure_optimizers(self): @@ -332,8 +269,10 @@ def validation_step(self, batch, batch_idx): self.log('val/acc', top1_acc, on_epoch=True, prog_bar=True, logger=True) def validation_epoch_end(self, outputs): - for callback in self._epoch_end_callbacks: - callback(self) + path = self.log_dir + '/last.ckpt' + if os.path.exists(self.log_dir + '/last.ckpt'): + os.remove(self.log_dir + '/last.ckpt') + self.trainer.save_checkpoint(self.log_dir + '/last.ckpt') def test_step(self, batch, batch_idx): x, y = batch @@ -357,53 +296,6 @@ def test_step(self, batch, batch_idx): "batch/test_true_list": true_list, } - def test_epoch_end(self, outputs): - avg_loss = torch.stack([x["batch/test_loss"] for x in outputs]).mean() - avg_accuracy = torch.stack([x["batch/test_accuracy"] - for x in outputs]).mean() - if self.log_auc: - self.log_aucs(outputs, stage="test") - - self.log("run/test_accuracy", - avg_accuracy, - on_epoch=True, - prog_bar=True, - logger=True) - self.log("run/test_loss", - avg_loss, - on_epoch=True, - prog_bar=True, - logger=True) - - def log_aucs(self, outputs, stage="test"): - pred_list = np.concatenate( - [x[f"batch/{stage}_pred_list"] for x in outputs]) - true_list = np.concatenate( - [x[f"batch/{stage}_true_list"] for x in outputs]) - - aucs = [] - for c in range(len(pred_list[0])): - fpr, tpr, thresholds = metrics.roc_curve(true_list[:, c], - pred_list[:, c], - pos_label=1) - auc_val = metrics.auc(fpr, tpr) - aucs.append(auc_val) - - self.log( - f"epoch/{stage}_auc/class_{c}", - auc_val, - on_epoch=True, - prog_bar=False, - logger=True, - ) - self.log( - f"epoch/{stage}_auc/avg", - np.mean(aucs), - on_epoch=True, - prog_bar=True, - logger=True, - ) - def create_lightning_module(model_name: str, num_classes: int, From 1920b88911e877f4efe77cc82776040220b7b269 Mon Sep 17 00:00:00 2001 From: limberc Date: Wed, 29 Mar 2023 22:57:38 +0800 Subject: [PATCH 
16/17] Update. --- examples/train_cifar.py | 2 +- gradattack/attacks/gradientinversion.py | 145 +++++---------- gradattack/datamodules.py | 3 +- gradattack/models/__init__.py | 235 ++++++++++++++++++++---- gradattack/models/densenet.py | 10 +- gradattack/trainingpipeline.py | 3 +- gradattack/utils.py | 186 ++++++------------- 7 files changed, 312 insertions(+), 272 deletions(-) diff --git a/examples/train_cifar.py b/examples/train_cifar.py index e1d713f..f5113f3 100644 --- a/examples/train_cifar.py +++ b/examples/train_cifar.py @@ -35,7 +35,7 @@ early_stop_callback = EarlyStopping( monitor="val/loss_epoch", min_delta=0.00, - patience=20, + patience=50, verbose=False, mode="min", ) diff --git a/gradattack/attacks/gradientinversion.py b/gradattack/attacks/gradientinversion.py index 171ac80..d60747e 100644 --- a/gradattack/attacks/gradientinversion.py +++ b/gradattack/attacks/gradientinversion.py @@ -1,11 +1,13 @@ import copy -from typing import Any, Callable, Optional +from typing import Any, Callable import pytorch_lightning as pl import torch -import torchmetrics +import torch.nn.functional as F +from torch.optim.lr_scheduler import MultiStepLR from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset +from torchmetrics.functional import peak_signal_noise_ratio from gradattack.metrics.gradients import CosineSimilarity, L2Diff from gradattack.metrics.pixelwise import MeanPixelwiseError @@ -13,23 +15,6 @@ from gradattack.utils import patch_image -# DEFAULT_HPARAMS = { -# "optimizer": "Adam", -# "lr_scheduler": True, -# "lr": 0.1, -# "total_variation": 1e-1, -# "l2": 0, -# "bn_reg": 0, -# "first_bn_multiplier": 10, -# # If true, will apply image priors on the absolute value of the recovered images -# "signed_image": False, -# "signed_gradients": True, -# "boxed": True, -# "attacker_eval_mode": True, -# "recipe": 'Geiping' -# } - - class BNFeatureHook: """ Implementation of the forward hook to track feature statistics and compute a loss on them. 
@@ -90,32 +75,30 @@ def __getitem__(self, idx: int): class GradientReconstructor(pl.LightningModule): - def __init__( - self, - pipeline: TrainingPipeline, - ground_truth_inputs: tuple, - ground_truth_gradients: tuple, - ground_truth_labels: tuple, - intial_reconstruction: torch.tensor = None, - reconstruct_labels=False, - attack_loss_metric: Callable = CosineSimilarity(), - mean_std: tuple = (0.0, 1.0), - num_iterations=10000, - optimizer: str = "Adam", - lr_scheduler: bool = True, - lr: float = 0.1, - total_variation: float = 1e-1, - l2: float = 0, - bn_reg: float = 0, - first_bn_multiplier: float = 1, - signed_image: bool = False, - signed_gradients: bool = True, - boxed: bool = True, - attacker_eval_mode: bool = True, - recipe: str = 'Geiping', - BN_exact: bool = False, - grayscale: bool = False, - ): + def __init__(self, + pipeline: TrainingPipeline, + ground_truth_inputs: tuple, + ground_truth_gradients: tuple, + ground_truth_labels: tuple, + intial_reconstruction: torch.tensor = None, + reconstruct_labels=False, + attack_loss_metric: Callable = CosineSimilarity(), + mean_std: tuple = (0.0, 1.0), + num_iterations=10000, + optimizer: str = "Adam", + lr_scheduler: bool = True, + lr: float = 0.1, + total_variation: float = 1e-1, + l2: float = 0, + bn_reg: float = 0, + first_bn_multiplier: float = 1, + signed_image: bool = False, + signed_gradients: bool = True, + boxed: bool = True, + attacker_eval_mode: bool = True, + recipe: str = 'Geiping', + BN_exact: bool = False, + grayscale: bool = False): super().__init__() self.save_hyperparameters("optimizer", "lr_scheduler", "lr", "total_variation", "l2", "bn_reg", @@ -177,23 +160,15 @@ def __init__( class loss_fn(torch.nn.Module): def __call__(self, pred, labels): if len(labels.shape) >= 2: - labels = torch.nn.functional.softmax(labels, dim=-1) - return torch.mean( - torch.sum( - -labels * - torch.nn.functional.log_softmax(pred, dim=-1), - 1, - )) + labels = F.softmax(labels, dim=-1) + return torch.mean(torch.sum(-labels * F.log_softmax(pred, dim=-1), 1)) else: - return torch.nn.functional.cross_entropy(pred, labels) + return F.cross_entropy(pred, labels) - self._model._training_loss_metric = None self._model._training_loss_metric = loss_fn() else: self._reconstruct_labels = False - self._batch_transferred = False - self.loss_r_feature_layers = [] for module in self._model.modules(): @@ -272,17 +247,6 @@ def train_dataloader(self) -> Any: return DataLoader( DummyGradientDataset(num_values=self.num_iterations, ), ) - def transfer_batch_to_device(self, batch: Any, - device: Optional[torch.device]) -> Any: - if not self._batch_transferred: - self.ground_truth_labels = self.ground_truth_labels.detach().to( - self.device) - self.ground_truth_gradients = tuple( - x.detach().to(self.device) - for x in self.ground_truth_gradients) - self._batch_transferred = True - return (self.ground_truth_gradients, self.ground_truth_labels) - def training_step(self, batch, *args): input_gradients, labels = batch @@ -302,25 +266,19 @@ def _closure(): attacker=True, ) if self.recipe == 'Geiping': - reconstruction_loss = self._attack_loss_metric( + reconst_loss = self._attack_loss_metric( recovered_gradients, input_gradients) - reconstruction_loss += self.hparams[ - "total_variation"] * total_variation( + reconst_loss += self.hparams["total_variation"] * total_variation( self.best_guess, self.hparams["signed_image"]) - reconstruction_loss += self.hparams["l2"] * l2_norm( + reconst_loss += self.hparams["l2"] * l2_norm( self.best_guess, self.hparams["signed_image"]) 
elif self.recipe == 'Zhu': ## TODO: test self._attack_loss_metric = L2Diff() - reconstruction_loss = self._attack_loss_metric( + reconst_loss = self._attack_loss_metric( recovered_gradients, input_gradients) - recon_mean = [ - mod.mean - for (idx, mod) in enumerate(self.loss_r_feature_layers) - ] - recon_var = [ - mod.var for (idx, mod) in enumerate(self.loss_r_feature_layers) - ] + recon_mean = [mod.mean for (idx, mod) in enumerate(self.loss_r_feature_layers)] + recon_var = [mod.var for (idx, mod) in enumerate(self.loss_r_feature_layers)] if self.hparams["bn_reg"] > 0: rescale = [self.hparams["first_bn_multiplier"]] + [ @@ -330,16 +288,14 @@ def _closure(): mod.r_feature * rescale[idx] for (idx, mod) in enumerate(self.loss_r_feature_layers) ]) - - reconstruction_loss += self.hparams["bn_reg"] * loss_r_feature + reconst_loss += self.hparams["bn_reg"] * loss_r_feature self.logger.experiment.add_scalar("Loss", step_results["loss"], global_step=self.global_step) self.logger.experiment.add_scalar( - "Reconstruction Metric Loss", - reconstruction_loss, - global_step=self.global_step, + "Reconstruction Metric Loss", reconst_loss, + global_step=self.global_step ) if self.global_step % 100 == 0: @@ -358,13 +314,13 @@ def _closure(): len(recon_mean[i])), global_step=self.global_step, ) - self.manual_backward(reconstruction_loss) + self.manual_backward(reconst_loss) if self.hparams["signed_gradients"]: if self.grayscale: self.best_guess_grayscale.grad.sign_() else: self.best_guess.grad.sign_() - return reconstruction_loss + return reconst_loss reconstruction_loss = self.optimizer.step(closure=_closure) if self.hparams["lr_scheduler"]: @@ -398,14 +354,10 @@ def _closure(): ), global_step=self.global_step, ) - psnrs = [ - torchmetrics.functional.peak_signal_noise_ratio(a, b) - for (a, - b) in zip(self.best_guess, self.ground_truth_inputs) - ] + psnrs = [peak_signal_noise_ratio(a, b) + for (a, b) in zip(self.best_guess, self.ground_truth_inputs)] avg_psnr = sum(psnrs) / self.num_images - self.logger.experiment.add_scalar("Avg. PSNR", - avg_psnr, + self.logger.experiment.add_scalar("Avg. 
PSNR", avg_psnr, global_step=self.global_step) rmses = [ @@ -437,9 +389,8 @@ def configure_optimizers(self): self.best_guess = self.best_guess.to(self.device) self.labels = self.labels.to(self.device) if self.grayscale: - parameters = ([ - self.best_guess_grayscale, self.labels - ] if self._reconstruct_labels else [self.best_guess_grayscale]) + parameters = ([self.best_guess_grayscale, self.labels + ] if self._reconstruct_labels else [self.best_guess_grayscale]) else: parameters = ([self.best_guess, self.labels] if self._reconstruct_labels else [self.best_guess]) @@ -459,7 +410,7 @@ def configure_optimizers(self): def configure_lr_scheduler(self): if self.hparams["lr_scheduler"]: - self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + self.lr_scheduler = MultiStepLR( self.optimizer, milestones=[ self.num_iterations // 2.667, diff --git a/gradattack/datamodules.py b/gradattack/datamodules.py index 3385f17..96582dc 100644 --- a/gradattack/datamodules.py +++ b/gradattack/datamodules.py @@ -10,8 +10,7 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from torch.utils.data.sampler import Sampler -from torchvision.datasets import MNIST -from torchvision.datasets.cifar import CIFAR10, CIFAR100 +from torchvision.datasets import CIFAR10, CIFAR100, MNIST DEFAULT_DATA_DIR = "./data" DEFAULT_NUM_WORKERS = 32 diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py index fadfe84..634b414 100755 --- a/gradattack/models/__init__.py +++ b/gradattack/models/__init__.py @@ -2,6 +2,7 @@ from typing import Callable import pytorch_lightning as pl +from sklearn import metrics from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR from gradattack.utils import StandardizeLayer @@ -50,7 +51,6 @@ def __init__( log_auc: bool = False, multi_class: bool = False, multi_head: bool = False, - log_dir: str = None ): super().__init__() # if we didn't copy here, then we would modify the default dict by accident @@ -71,8 +71,6 @@ def __init__( self._training_loss_metric = training_loss_metric self._val_loss_metric = training_loss_metric - self._optimizer = optimizer - self._batch_transformations = [] self._grad_transformations = [] self._opt_transformations = [] @@ -90,7 +88,6 @@ def __init__( self.log_auc = log_auc self.multi_class = multi_class self.multi_head = multi_head - self.log_dir = log_dir def forward(self, x): if self.multi_head: @@ -101,15 +98,18 @@ def forward(self, x): else: return self._model(x) - def _transform_batch(self, batch, batch_idx, *args): - for transform in self._batch_transformations: - batch = transform(batch, batch_idx, *args) - return batch + def should_accumulate(self): + return self.trainer.train_loop.should_accumulate() def on_train_epoch_start(self) -> None: for callback in self._on_train_epoch_start_callbacks: callback(self) + def _transform_batch(self, batch, batch_idx, *args): + for transform in self._batch_transformations: + batch = transform(batch, batch_idx, *args) + return batch + def _transform_gradients(self): for transform in self._grad_transformations: self._model = transform(self._model) @@ -122,7 +122,6 @@ def _compute_training_step(self, Args: batch : The batch inputs. Should be a torch tensor with outermost dimension 2, where dimension 0 corresponds to inputs and dimension 1 corresponds to labels. - Returns: dict: The results from the training step. Is a dictionary with keys "loss", "transformed_batch", and "model_outputs". 
""" @@ -164,15 +163,82 @@ def training_step(self, batch, batch_idx, *_) -> dict: # self.on_non_accumulate_step() self.on_non_accumulate_step() - top1_acc = accuracy( - training_step_results["model_outputs"], - training_step_results["transformed_batch"][1], - multi_head=self.multi_head, - )[0] - self.log("train/acc", top1_acc, on_epoch=True, logger=True) + if self.log_train_acc: + top1_acc = accuracy( + training_step_results["model_outputs"], + training_step_results["transformed_batch"][1], + multi_head=self.multi_head, + )[0] + self.log( + "step/train_acc", + top1_acc, + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) return training_step_results + def get_batch_gradients(self, + batch: torch.tensor, + batch_idx: int = 0, + create_graph: bool = False, + clone_gradients: bool = True, + apply_transforms=True, + eval_mode: bool = False, + stop_track_bn_stats: bool = True, + BN_exact: bool = False, + attacker: bool = False, + *args): + batch = tuple(k.to(self.device) for k in batch) + if eval_mode is True: + self.eval() + else: + self.train() + + if BN_exact: + for module in self._model.modules(): + if isinstance(module, torch.nn.BatchNorm2d): + if not attacker: + module.reset_running_stats() # reset BN statistics + module.momentum = ( + 1 # save current BN statistics as running statistics + ) + if attacker: + self.training = False # set BN module to eval mode + module.momentum = 0 # stop tracking BN statistics + if hasattr(module, "weight"): + module.weight.requires_grad_(True) + if hasattr(module, "bias"): + module.bias.requires_grad_(True) + + if stop_track_bn_stats: + for module in self._model.modules(): + if isinstance(module, torch.nn.BatchNorm2d): + module.momentum = 0 # Stop tracking running mean and std any more + + self.zero_grad() + training_step_results = self._compute_training_step( + batch, batch_idx, apply_batch_transforms=apply_transforms, *args) + + # Make sure to apply transformations to gradients + if apply_transforms: + training_step_results["loss"].backward() + self._transform_gradients() + # Clone to prevent the gradients from being changed by training + batch_gradients = tuple( + p.grad.clone() if clone_gradients is True else p.grad + for p in self.parameters()) + else: + batch_gradients = torch.autograd.grad( + training_step_results["loss"], + self._model.parameters(), + create_graph=create_graph, + ) + + return batch_gradients, training_step_results + def on_non_accumulate_step(self) -> None: # This hook runs only after accumulation self._transform_gradients() @@ -180,9 +246,14 @@ def on_non_accumulate_step(self) -> None: if self._log_gradients: grad_norm_dict = self.grad_norm(1) for k, v in grad_norm_dict.items(): - self.log(f"gradients/{k}", v, - on_epoch=True, - logger=True) + self.log( + f"gradients/{k}", + v, + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + ) self.optimizer.step() self.optimizer.zero_grad() @@ -193,10 +264,14 @@ def on_non_accumulate_step(self) -> None: self.trainer.should_stop = True self.step_tracker.end() - self.log("train/loss", - self.step_tracker.cur_loss, - on_epoch=True, - logger=True) + self.log( + "step/train_loss", + self.step_tracker.cur_loss, + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) def configure_optimizers(self): if self.hparams["optimizer"] == "Adam": @@ -264,15 +339,54 @@ def validation_step(self, batch, batch_idx): else: loss = self._val_loss_metric(y_hat, y) top1_acc = accuracy(y_hat, y, multi_head=self.multi_head)[0] - - self.log('val/loss', loss, on_step=True, 
on_epoch=True, prog_bar=True, logger=True) - self.log('val/acc', top1_acc, on_epoch=True, prog_bar=True, logger=True) + if self.log_auc: + pred_list, true_list = auc_list(y_hat, y) + else: + pred_list, true_list = None, None + return { + "batch/val_loss": loss, + "batch/val_accuracy": top1_acc, + "batch/val_pred_list": pred_list, + "batch/val_true_list": true_list, + } def validation_epoch_end(self, outputs): - path = self.log_dir + '/last.ckpt' - if os.path.exists(self.log_dir + '/last.ckpt'): - os.remove(self.log_dir + '/last.ckpt') - self.trainer.save_checkpoint(self.log_dir + '/last.ckpt') + # outputs is whatever returned in `validation_step` + avg_loss = torch.stack([x["batch/val_loss"] for x in outputs]).mean() + avg_accuracy = torch.stack([x["batch/val_accuracy"] + for x in outputs]).mean() + if self.log_auc: + self.log_aucs(outputs, stage="val") + + self.current_val_loss = avg_loss + if self.current_epoch > 0: + if self.hparams["lr_scheduler"] == "ReduceLROnPlateau": + self.lr_scheduler.step(self.current_val_loss) + else: + self.lr_scheduler.step() + + self.cur_lr = self.optimizer.param_groups[0]["lr"] + + self.log( + "epoch/val_accuracy", + avg_accuracy, + on_epoch=True, + prog_bar=True, + logger=True, + ) + self.log("epoch/val_loss", + avg_loss, + on_epoch=True, + prog_bar=True, + logger=True) + self.log("epoch/lr", + self.cur_lr, + on_epoch=True, + prog_bar=True, + logger=True) + + for callback in self._epoch_end_callbacks: + callback(self) def test_step(self, batch, batch_idx): x, y = batch @@ -296,14 +410,63 @@ def test_step(self, batch, batch_idx): "batch/test_true_list": true_list, } + def test_epoch_end(self, outputs): + avg_loss = torch.stack([x["batch/test_loss"] for x in outputs]).mean() + avg_accuracy = torch.stack([x["batch/test_accuracy"] + for x in outputs]).mean() + if self.log_auc: + self.log_aucs(outputs, stage="test") + + self.log("run/test_accuracy", + avg_accuracy, + on_epoch=True, + prog_bar=True, + logger=True) + self.log("run/test_loss", + avg_loss, + on_epoch=True, + prog_bar=True, + logger=True) + + def log_aucs(self, outputs, stage="test"): + pred_list = np.concatenate( + [x[f"batch/{stage}_pred_list"] for x in outputs]) + true_list = np.concatenate( + [x[f"batch/{stage}_true_list"] for x in outputs]) + + aucs = [] + for c in range(len(pred_list[0])): + fpr, tpr, thresholds = metrics.roc_curve(true_list[:, c], + pred_list[:, c], + pos_label=1) + auc_val = metrics.auc(fpr, tpr) + aucs.append(auc_val) + + self.log( + f"epoch/{stage}_auc/class_{c}", + auc_val, + on_epoch=True, + prog_bar=False, + logger=True, + ) + self.log( + f"epoch/{stage}_auc/avg", + np.mean(aucs), + on_epoch=True, + prog_bar=True, + logger=True, + ) + -def create_lightning_module(model_name: str, - num_classes: int, - pretrained: bool = False, - ckpt: str = None, - freeze_extractor: bool = False, - *args, - **kwargs) -> LightningWrapper: +def create_lightning_module( + model_name: str, + num_classes: int, + pretrained: bool = False, + ckpt: str = None, + freeze_extractor: bool = False, + *args, + **kwargs, +) -> LightningWrapper: if "models" in model_name: # Official models by PyTorch model_name = model_name.replace("models.", "") if pretrained is False: diff --git a/gradattack/models/densenet.py b/gradattack/models/densenet.py index 9157815..3660562 100755 --- a/gradattack/models/densenet.py +++ b/gradattack/models/densenet.py @@ -16,10 +16,8 @@ def __init__(self, in_planes, growth_rate): kernel_size=1, bias=False) self.bn2 = nn.BatchNorm2d(4 * growth_rate) - self.conv2 = nn.Conv2d(4 * 
growth_rate, - growth_rate, - kernel_size=3, - padding=1, + self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, + kernel_size=3, padding=1, bias=False) def forward(self, x): @@ -42,9 +40,7 @@ def forward(self, x): class DenseNet(nn.Module): - def __init__(self, - block, - nblocks, + def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): diff --git a/gradattack/trainingpipeline.py b/gradattack/trainingpipeline.py index bb521ef..d9c8f9b 100644 --- a/gradattack/trainingpipeline.py +++ b/gradattack/trainingpipeline.py @@ -4,8 +4,7 @@ class TrainingPipeline: - def __init__(self, - model: LightningWrapper, + def __init__(self, model: LightningWrapper, datamodule: pl.LightningDataModule, trainer: pl.Trainer): self.model = model diff --git a/gradattack/utils.py b/gradattack/utils.py index bae2a47..0fcbb4f 100755 --- a/gradattack/utils.py +++ b/gradattack/utils.py @@ -24,32 +24,19 @@ def parse_args(): default="ResNet34", type=str, help="name of model") - parser.add_argument("--data", - default="CIFAR10", - type=str, - help="name of dataset") - parser.add_argument( - "--results_dir", - default="./results", - type=str, - help="directory to save attack results", - ) - parser.add_argument("--n_epoch", - default=200, - type=int, - help="number of epochs") - parser.add_argument("--batch_size", - default=128, - type=int, + parser.add_argument("--data", default="CIFAR10", + type=str, help="name of dataset") + parser.add_argument("--results_dir", default="./results", + type=str, help="directory to save attack results") + parser.add_argument("--n_epoch", default=200, type=int, help="number of epochs") + parser.add_argument("--batch_size", default=512, type=int, help="batch size") parser.add_argument("--optimizer", default="SGD", type=str, help="which optimizer") parser.add_argument("--lr", default=0.4, type=float, help="initial lr") - parser.add_argument("--decay", - default=5e-4, - type=float, + parser.add_argument("--decay", default=5e-4, type=float, help="weight decay") parser.add_argument("--momentum", type=float, default=0.9) parser.add_argument("--nesterov", action="store_true") @@ -57,9 +44,7 @@ def parse_args(): default="ReduceLROnPlateau", type=str, help="which scheduler") - parser.add_argument("--lr_step", - default=30, - type=int, + parser.add_argument("--lr_step", default=30, type=int, help="reduce LR per ? epochs") parser.add_argument("--lr_lambda", default=0.95, @@ -76,49 +61,33 @@ def parse_args(): default=10, type=int, help="patience for early stopping") - parser.add_argument( - "--tune_on_val", - default=0.02, - type=float, - help= - "fraction of validation data. If set to 0, use test data as the val data", - ) + parser.add_argument("--tune_on_val", + default=0.02, + type=float, + help="fraction of validation data. 
If set to 0, use test data as the val data", ) parser.add_argument("--log_auc", dest="log_auc", action="store_true") - parser.add_argument("--logname", - default="vanilla", - type=str, - help="log name") + parser.add_argument("--logname", default="vanilla", type=str, help="log name") parser.add_argument("--pretrained", dest="pretrained", action="store_true") - parser.add_argument("--ckpt", - default=None, - type=str, - help="directory for ckpt") + parser.add_argument("--ckpt", default=None, type=str, help="directory for ckpt") - parser.add_argument( - "--freeze_extractor", - dest="freeze_extractor", - action="store_true", - help="Whether only training the fc layer", - ) + parser.add_argument("--freeze_extractor", + dest="freeze_extractor", + action="store_true", + help="Whether only training the fc layer", ) # Augmentation - parser.add_argument( - "--dis_aug_crop", - dest="aug_crop", - action="store_false", - help="Whether to apply random cropping", - ) - parser.add_argument( - "--dis_aug_hflip", - dest="aug_hflip", - action="store_false", - help="Whether to apply horizontally flipping", - ) - parser.add_argument( - "--aug_affine", - dest="aug_affine", - action="store_true", - help="Enable random affine", - ) + parser.add_argument("--dis_aug_crop", + dest="aug_crop", + action="store_false", + help="Whether to apply random cropping") + parser.add_argument("--dis_aug_hflip", + dest="aug_hflip", + action="store_false", + help="Whether to apply horizontally flipping") + parser.add_argument("--aug_affine", + dest="aug_affine", + action="store_true", + help="Enable random affine", + ) parser.add_argument( "--aug_rotation", type=float, @@ -136,17 +105,11 @@ def parse_args(): parser.add_argument("--defense_instahide", dest="defense_instahide", action="store_true") - parser.add_argument("--klam", - default=4, - type=int, + parser.add_argument("--klam", default=4, type=int, help="How many images to mix with") - parser.add_argument("--c_1", - default=0, - type=float, + parser.add_argument("--c_1", default=0, type=float, help="Lower bound of mixing coefs") - parser.add_argument("--c_2", - default=1, - type=float, + parser.add_argument("--c_2", default=1, type=float, help="Upper bound of mixing coefs") parser.add_argument("--use_csprng", dest="use_csprng", action="store_true") # GradPrune @@ -165,34 +128,20 @@ def parse_args(): type=float, help="Failure prob of DP", ) - parser.add_argument("--max_epsilon", - default=2, - type=float, + parser.add_argument("--max_epsilon", default=2, type=float, help="Privacy budget") - parser.add_argument( - "--max_grad_norm", - default=1, - type=float, - help="Clip per-sample gradients to this norm", - ) + parser.add_argument("--max_grad_norm", default=1, type=float, + help="Clip per-sample gradients to this norm", ) parser.add_argument("--noise_multiplier", default=1, type=float, help="Noise multiplier") - - parser.add_argument( - "--n_accumulation_steps", - default=1, - type=int, - help="Run optimization per ? step", - ) - parser.add_argument( - "--secure_rng", - dest="secure_rng", - action="store_true", - help= - "Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost", - ) + parser.add_argument("--n_accumulation_steps", default=1, type=int, + help="Run optimization per ? step") + parser.add_argument("--secure_rng", + dest="secure_rng", + action="store_true", + help="Enable Secure RNG to have trustworthy privacy guarantees.
Comes at a performance cost", ) # For attack parser.add_argument("--reconstruct_labels", action="store_true") @@ -211,31 +160,18 @@ def parse_args(): default=None, type=int, help="seed to select attack subset") - parser.add_argument("--attack_epoch", - default=0, - type=int, + parser.add_argument("--attack_epoch", default=0, type=int, help="iterations for the attack") - parser.add_argument( - "--bn_reg", - default=0, - type=float, - help="coef. for batchnorm regularization term", - ) - parser.add_argument( - "--attacker_eval_mode", - action="store_true", - help="use eval model for gradients calculation for attack", - ) - parser.add_argument( - "--defender_eval_mode", - action="store_true", - help="use eval model for gradients calculation for training", - ) - parser.add_argument( - "--BN_exact", - action="store_true", - help="use training batch's mean and var", - ) + parser.add_argument("--bn_reg", default=0, type=float, + help="coef. for batchnorm regularization term") + parser.add_argument("--attacker_eval_mode", + action="store_true", + help="use eval model for gradients calculation for attack") + parser.add_argument("--defender_eval_mode", action="store_true", + help="use eval model for gradients calculation for training") + parser.add_argument("--BN_exact", + action="store_true", + help="use training batch's mean and var",) args = parser.parse_args() @@ -254,7 +190,7 @@ def parse_args(): hparams["lr_step"] = args.lr_step hparams["lr_factor"] = args.lr_factor elif args.scheduler == "MultiStepLR": - hparams["lr_step"] = [5, 60, 120, 160] + hparams["lr_step"] = [60, 120, 160] hparams["lr_factor"] = args.lr_factor elif args.scheduler == "LambdaLR": hparams["lr_lambda"] = args.lr_lambda @@ -285,16 +221,12 @@ def parse_args(): def parse_augmentation(args): return { - "hflip": - args.aug_hflip, - "crop": - args.aug_crop, - "rotation": - args.aug_rotation, + "hflip": args.aug_hflip, + "crop": args.aug_crop, + "rotation": args.aug_rotation, "color_jitter": [float(i) for i in args.aug_colorjitter] if args.aug_colorjitter is not None else None, - "affine": - args.aug_affine, + "affine": args.aug_affine, } From 6389ab7a1f0e0b64c5c669e5d7432d01d57abf69 Mon Sep 17 00:00:00 2001 From: limberc Date: Wed, 29 Mar 2023 23:00:45 +0800 Subject: [PATCH 17/17] Update. 
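Drop the hand-written validation/test epoch aggregation and the sklearn AUC reporting that lived in it: a value logged from validation_step with on_epoch=True is already mean-reduced over the epoch by Lightning, so a separate validation_epoch_end pass over the outputs adds nothing for loss and accuracy. A self-contained toy illustration of the pattern (the module, shapes and optimizer below are made up for the example, not code from this repo):

    import pytorch_lightning as pl
    import torch
    import torch.nn.functional as F

    class ToyClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(32, 10)

        def forward(self, x):
            return self.net(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.cross_entropy(self(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            # Logged once per epoch as the mean over all validation batches;
            # no validation_epoch_end hook is needed for this.
            self.log("val/loss", loss, on_epoch=True, logger=True)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)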
---
 gradattack/models/__init__.py | 122 +++-------------------------------
 1 file changed, 8 insertions(+), 114 deletions(-)

diff --git a/gradattack/models/__init__.py b/gradattack/models/__init__.py
index 634b414..26b2674 100755
--- a/gradattack/models/__init__.py
+++ b/gradattack/models/__init__.py
@@ -2,7 +2,6 @@ from typing import Callable

 import pytorch_lightning as pl
-from sklearn import metrics
 from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR

 from gradattack.utils import StandardizeLayer
@@ -339,54 +338,8 @@ def validation_step(self, batch, batch_idx):
         else:
             loss = self._val_loss_metric(y_hat, y)
         top1_acc = accuracy(y_hat, y, multi_head=self.multi_head)[0]
-        if self.log_auc:
-            pred_list, true_list = auc_list(y_hat, y)
-        else:
-            pred_list, true_list = None, None
-        return {
-            "batch/val_loss": loss,
-            "batch/val_accuracy": top1_acc,
-            "batch/val_pred_list": pred_list,
-            "batch/val_true_list": true_list,
-        }
-
-    def validation_epoch_end(self, outputs):
-        # outputs is whatever returned in `validation_step`
-        avg_loss = torch.stack([x["batch/val_loss"] for x in outputs]).mean()
-        avg_accuracy = torch.stack([x["batch/val_accuracy"]
-                                    for x in outputs]).mean()
-        if self.log_auc:
-            self.log_aucs(outputs, stage="val")
-
-        self.current_val_loss = avg_loss
-        if self.current_epoch > 0:
-            if self.hparams["lr_scheduler"] == "ReduceLROnPlateau":
-                self.lr_scheduler.step(self.current_val_loss)
-            else:
-                self.lr_scheduler.step()
-
-        self.cur_lr = self.optimizer.param_groups[0]["lr"]
-
-        self.log(
-            "epoch/val_accuracy",
-            avg_accuracy,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-        self.log("epoch/val_loss",
-                 avg_loss,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
-        self.log("epoch/lr",
-                 self.cur_lr,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
-
-        for callback in self._epoch_end_callbacks:
-            callback(self)
+        self.log('val/loss', loss, on_epoch=True, logger=True)
+        self.log('val/acc', top1_acc, on_epoch=True, logger=True)

     def test_step(self, batch, batch_idx):
         x, y = batch
@@ -399,63 +352,12 @@ def test_step(self, batch, batch_idx):
         else:
             loss = self._val_loss_metric(y_hat, y)
         top1_acc = accuracy(y_hat, y, multi_head=self.multi_head)[0]
-        if self.log_auc:
-            pred_list, true_list = auc_list(y_hat, y)
-        else:
-            pred_list, true_list = None, None
-        return {
-            "batch/test_loss": loss,
-            "batch/test_accuracy": top1_acc,
-            "batch/test_pred_list": pred_list,
-            "batch/test_true_list": true_list,
-        }
-
-    def test_epoch_end(self, outputs):
-        avg_loss = torch.stack([x["batch/test_loss"] for x in outputs]).mean()
-        avg_accuracy = torch.stack([x["batch/test_accuracy"]
-                                    for x in outputs]).mean()
-        if self.log_auc:
-            self.log_aucs(outputs, stage="test")
-
-        self.log("run/test_accuracy",
-                 avg_accuracy,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
-        self.log("run/test_loss",
-                 avg_loss,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
-
-    def log_aucs(self, outputs, stage="test"):
-        pred_list = np.concatenate(
-            [x[f"batch/{stage}_pred_list"] for x in outputs])
-        true_list = np.concatenate(
-            [x[f"batch/{stage}_true_list"] for x in outputs])
-
-        aucs = []
-        for c in range(len(pred_list[0])):
-            fpr, tpr, thresholds = metrics.roc_curve(true_list[:, c],
-                                                     pred_list[:, c],
-                                                     pos_label=1)
-            auc_val = metrics.auc(fpr, tpr)
-            aucs.append(auc_val)
-
-            self.log(
-                f"epoch/{stage}_auc/class_{c}",
-                auc_val,
-                on_epoch=True,
-                prog_bar=False,
-                logger=True,
-            )
-        self.log(
-            f"epoch/{stage}_auc/avg",
-            np.mean(aucs),
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
+        # if self.log_auc:
+        #     pred_list, true_list = auc_list(y_hat, y)
+        # else:
+        #     pred_list, true_list = None, None
+        self.log('test/loss', loss, logger=True, on_epoch=True)
+        self.log('test/acc', top1_acc, logger=True, on_epoch=True)


 def create_lightning_module(
@@ -548,11 +450,3 @@ def accuracy(output, target, topk=(1,), multi_head=False):
         correct *= 100.0 / (batch_size * target.size(1))
         res = [correct]
         return res
-
-
-def auc_list(output, target):
-    assert len(target.size()) == 2
-    pred_list = torch.sigmoid(output).cpu().detach().numpy()
-    true_list = target.cpu().detach().numpy()
-
-    return pred_list, true_list
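
Note: this last patch removes the sklearn-based per-class AUC logging (log_aucs, auc_list, and the sklearn import) from the LightningModule, leaving only scalar loss and accuracy logging. If per-class AUC is still wanted for multi-label runs, it can be computed after training or testing with a small standalone helper. The sketch below is one possible way to do that; the function name per_class_auc and the use of sklearn.metrics.roc_auc_score are illustrative assumptions, not part of either patch.

# Post-hoc per-class AUC for a multi-label model; an illustrative sketch,
# not part of the gradattack patches above.
import torch
from sklearn.metrics import roc_auc_score


@torch.no_grad()
def per_class_auc(model, dataloader, device="cpu"):
    """Collect sigmoid scores over a dataloader and return one AUC per class."""
    model.eval().to(device)
    scores, targets = [], []
    for x, y in dataloader:  # y is expected to be a multi-hot label matrix
        scores.append(torch.sigmoid(model(x.to(device))).cpu())
        targets.append(y.cpu())
    scores = torch.cat(scores).numpy()
    targets = torch.cat(targets).numpy()
    # average=None mirrors the removed log_aucs loop: an array of per-class AUCs.
    return roc_auc_score(targets, scores, average=None)

A caller would run it on the test dataloader once trainer.test(...) has finished, then average or log the returned array however the experiment requires.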