From 1a0eb17be24beb7cc32135f6554f2f03f5f78633 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Aug 2021 13:17:53 +0800 Subject: [PATCH 1/8] enhance swish Signed-off-by: Yiheng Wang --- monai/networks/blocks/activation.py | 43 +++++++++++++++++++++++------ monai/networks/nets/efficientnet.py | 2 +- monai/utils/__init__.py | 1 + monai/utils/module.py | 6 ++++ 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index ef2d19b550..0afd202c92 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -12,18 +12,30 @@ import torch from torch import nn -from monai.utils import optional_import +from monai.utils import PT_BEFORE_1_7, PT_BEFORE_1_9 -if optional_import("torch.nn.functional", name="mish")[1]: +if PT_BEFORE_1_9: - def monai_mish(x): - return torch.nn.functional.mish(x, inplace=True) + def monai_mish(x, inplace: bool = False): + return x * torch.tanh(torch.nn.functional.softplus(x)) else: - def monai_mish(x): - return x * torch.tanh(torch.nn.functional.softplus(x)) + def monai_mish(x, inplace: bool = False): + return torch.nn.functional.mish(x, inplace=inplace) + + +if PT_BEFORE_1_7: + + def monai_swish(x, inplace: bool = False): + return SwishImplementation.apply(x) + + +else: + + def monai_swish(x, inplace: bool = False): + return torch.nn.functional.silu(x, inplace=inplace) class Swish(nn.Module): @@ -92,6 +104,9 @@ class MemoryEfficientSwish(nn.Module): Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941. + From Pytorch 1.7.0+, the optimized version of `Swish` named `SiLU` is implemented, + this class will utilize `torch.nn.functional.silu` to do the calculation if meets the version. + Shape: - Input: :math:`(N, *)` where `*` means, any number of additional dimensions @@ -107,8 +122,13 @@ class MemoryEfficientSwish(nn.Module): >>> output = m(input) """ + def __init__(self, inplace: bool = False): + super(MemoryEfficientSwish, self).__init__() + # inplace only works when using torch.nn.functional.silu + self.inplace = inplace + def forward(self, input: torch.Tensor): - return SwishImplementation.apply(input) + return monai_swish(input, self.inplace) class Mish(nn.Module): @@ -119,6 +139,8 @@ class Mish(nn.Module): Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681. + From Pytorch 1.9.0+, the optimized version of `Mish` is implemented, + this class will utilize `torch.nn.functional.mish` to do the calculation if meets the version. Shape: - Input: :math:`(N, *)` where `*` means, any number of additional @@ -135,5 +157,10 @@ class Mish(nn.Module): >>> output = m(input) """ + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + # inplace only works when using torch.nn.functional.mish + self.inplace = inplace + def forward(self, input: torch.Tensor): - return monai_mish(input) + return monai_mish(input, self.inplace) diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index cb8e195b04..b199c5b6f9 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -140,7 +140,7 @@ def __init__( # swish activation to use - using memory efficient swish by default # can be switched to normal swish using self.set_swish() function call - self._swish = Act["memswish"]() + self._swish = Act["memswish"](inplace=True) def forward(self, inputs: torch.Tensor): """MBConvBlock"s forward function. diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index dd300fce34..6a63e766f3 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -57,6 +57,7 @@ ) from .module import ( PT_BEFORE_1_7, + PT_BEFORE_1_9, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, diff --git a/monai/utils/module.py b/monai/utils/module.py index 33314fb0e3..b0ce1c01cd 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -34,6 +34,7 @@ "get_package_version", "get_torch_version_tuple", "PT_BEFORE_1_7", + "PT_BEFORE_1_9", "version_leq", ] @@ -406,3 +407,8 @@ def _try_cast(val): PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") except (AttributeError, TypeError): PT_BEFORE_1_7 = True + +try: + PT_BEFORE_1_9 = torch.__version__ != "1.9.0" and version_leq(torch.__version__, "1.9.0") +except (AttributeError, TypeError): + PT_BEFORE_1_9 = True From 683492f7177df633d9482b42b513dc7efdde375f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Aug 2021 13:21:07 +0800 Subject: [PATCH 2/8] add inplace to swish Signed-off-by: Yiheng Wang --- monai/networks/nets/efficientnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index b199c5b6f9..69d1d2a655 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -188,7 +188,7 @@ def set_swish(self, memory_efficient: bool = True) -> None: Args: memory_efficient (bool): Whether to use memory-efficient version of swish. """ - self._swish = Act["memswish"]() if memory_efficient else Act["swish"](alpha=1.0) + self._swish = Act["memswish"](inplace=True) if memory_efficient else Act["swish"](alpha=1.0) class EfficientNet(nn.Module): From 600b1a3f0df0791f83d49f84ecea9ea3e29a5f17 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 23 Aug 2021 16:44:16 +0800 Subject: [PATCH 3/8] add efn features Signed-off-by: Yiheng Wang --- docs/source/networks.rst | 5 + monai/networks/blocks/activation.py | 14 +-- monai/networks/nets/__init__.py | 9 +- monai/networks/nets/efficientnet.py | 175 ++++++++++++++++++++-------- monai/utils/__init__.py | 1 - monai/utils/module.py | 6 - tests/test_efficientnet.py | 40 ++++++- 7 files changed, 188 insertions(+), 62 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index a5ce86287a..54c2756535 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -358,6 +358,11 @@ Nets .. autoclass:: EfficientNetBN :members: +`EfficientNetBNFeatures` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: EfficientNetBNFeatures + :members: + `SegResNet` ~~~~~~~~~~~ .. autoclass:: SegResNet diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index 0afd202c92..a380f8e757 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -12,30 +12,30 @@ import torch from torch import nn -from monai.utils import PT_BEFORE_1_7, PT_BEFORE_1_9 +from monai.utils import optional_import -if PT_BEFORE_1_9: +if optional_import("torch.nn.functional", name="mish")[1]: def monai_mish(x, inplace: bool = False): - return x * torch.tanh(torch.nn.functional.softplus(x)) + return torch.nn.functional.mish(x, inplace=inplace) else: def monai_mish(x, inplace: bool = False): - return torch.nn.functional.mish(x, inplace=inplace) + return x * torch.tanh(torch.nn.functional.softplus(x)) -if PT_BEFORE_1_7: +if optional_import("torch.nn.functional", name="silu")[1]: def monai_swish(x, inplace: bool = False): - return SwishImplementation.apply(x) + return torch.nn.functional.silu(x, inplace=inplace) else: def monai_swish(x, inplace: bool = False): - return torch.nn.functional.silu(x, inplace=inplace) + return SwishImplementation.apply(x) class Swish(nn.Module): diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9cf6c5e07f..ad1ca2418b 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -31,7 +31,14 @@ densenet264, ) from .dynunet import DynUNet, DynUnet, Dynunet, dynunet -from .efficientnet import BlockArgs, EfficientNet, EfficientNetBN, drop_connect, get_efficientnet_image_size +from .efficientnet import ( + BlockArgs, + EfficientNet, + EfficientNetBN, + EfficientNetBNFeatures, + drop_connect, + get_efficientnet_image_size, +) from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator from .highresnet import HighResBlock, HighResNet diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index 69d1d2a655..e40f63d6c0 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -34,6 +34,7 @@ "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2), "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2), "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2), + "efficientnet-b8": (2.2, 3.6, 672, 0.5, 0.2), } @@ -208,8 +209,7 @@ def __init__( ) -> None: """ EfficientNet based on `Rethinking Model Scaling for Convolutional Neural Networks `_. - Adapted from `EfficientNet-PyTorch - `_. + Adapted from `EfficientNet-PyTorch `_. Args: blocks_args_str: block definitions. @@ -220,9 +220,10 @@ def __init__( depth_coefficient: depth multiplier coefficient (d in paper). dropout_rate: dropout rate for dropout layers. image_size: input image resolution. - norm: feature normalization type and arguments. Defaults to batch norm. + norm: feature normalization type and arguments. drop_connect_rate: dropconnect rate for drop connection (individual weights) layers. depth_divisor: depth divisor for channel rounding. + """ super().__init__() @@ -266,6 +267,8 @@ def __init__( num_blocks = 0 self._blocks = nn.Sequential() + self.extract_stacks = [] + # update baseline blocks to input/output filters and number of repeats based on width and depth multipliers. for idx, block_args in enumerate(self._blocks_args): block_args = block_args._replace( @@ -278,17 +281,23 @@ def __init__( # calculate the total number of blocks - needed for drop_connect estimation num_blocks += block_args.num_repeat + if block_args.stride > 1: + self.extract_stacks.append(idx) + + self.extract_stacks.append(len(self._blocks_args)) + # create and add MBConvBlocks to self._blocks idx = 0 # block index counter - for block_args in self._blocks_args: + for stack_idx, block_args in enumerate(self._blocks_args): blk_drop_connect_rate = self.drop_connect_rate # scale drop connect_rate if blk_drop_connect_rate: blk_drop_connect_rate *= float(idx) / num_blocks + sub_stack = nn.Sequential() # the first block needs to take care of stride and filter size increase. - self._blocks.add_module( + sub_stack.add_module( str(idx), MBConvBlock( spatial_dims=spatial_dims, @@ -319,7 +328,7 @@ def __init__( blk_drop_connect_rate *= float(idx) / num_blocks # add blocks - self._blocks.add_module( + sub_stack.add_module( str(idx), MBConvBlock( spatial_dims=spatial_dims, @@ -337,9 +346,14 @@ def __init__( ) idx += 1 # increment blocks index counter + self._blocks.add_module( + str(stack_idx), + sub_stack, + ) + # sanity check to see if len(self._blocks) equal expected num_blocks - if len(self._blocks) != num_blocks: - raise ValueError("number of blocks created != num_blocks") + if idx != num_blocks: + raise ValueError("total number of blocks created != num_blocks") # Head head_in_channels = block_args.output_filters @@ -369,8 +383,9 @@ def set_swish(self, memory_efficient: bool = True) -> None: """ self._swish = Act["memswish"]() if memory_efficient else Act["swish"](alpha=1.0) - for block in self._blocks: - block.set_swish(memory_efficient) + for sub_stack in self._blocks: + for block in sub_stack: + block.set_swish(memory_efficient) def forward(self, inputs: torch.Tensor): """ @@ -379,8 +394,7 @@ def forward(self, inputs: torch.Tensor): ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`. Returns: - A torch Tensor of classification prediction in shape - ``(Batch, num_classes)``. + a torch Tensor of classification prediction in shape ``(Batch, num_classes)``. """ # Stem x = self._conv_stem(self._conv_stem_padding(inputs)) @@ -440,17 +454,17 @@ def __init__( """ Generic wrapper around EfficientNet, used to initialize EfficientNet-B0 to EfficientNet-B7 models model_name is mandatory argument as there is no EfficientNetBN itself, - it needs the N in [0, 1, 2, 3, 4, 5, 6, 7] to be a model + it needs the N in [0, 1, 2, 3, 4, 5, 6, 7, 8] to be a model Args: - model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b7]. + model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b8]. pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2 and batch norm is used. progress: whether to show download progress for pretrained weights download. spatial_dims: number of spatial dimensions. in_channels: number of input channels. num_classes: number of output classes. - norm: feature normalization type and arguments. Defaults to batch norm. + norm: feature normalization type and arguments. Examples:: @@ -471,7 +485,7 @@ def __init__( >>> model = EfficientNetBN("efficientnet-b7", spatial_dims=2) """ - # block args for EfficientNet-B0 to EfficientNet-B7 + # block args for EfficientNet-B0 to EfficientNet-B8 blocks_args_str = [ "r1_k3_s11_e1_i32_o16_se0.25", "r2_k3_s22_e6_i16_o24_se0.25", @@ -507,16 +521,90 @@ def __init__( norm=norm, ) - # attempt to load pretrained - is_default_model = (spatial_dims == 2) and (in_channels == 3) - loadable_from_file = pretrained and is_default_model + # only pretrained for when `spatial_dims` is 2 + if pretrained and (spatial_dims == 2): + _load_state_dict(self, model_name, progress, in_channels) + + +class EfficientNetBNFeatures(EfficientNet): + def __init__( + self, + model_name: str, + pretrained: bool = True, + progress: bool = True, + spatial_dims: int = 2, + in_channels: int = 3, + num_classes: int = 1000, + norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + ) -> None: + """ + Initialize EfficientNet-B0 to EfficientNet-B7 models as a backbone, the backbone can + be used as an encoder for segmentation and objection models. + Compared with the class `EfficientNetBN`, the only different place is the forward function. + + This class refers to `PyTorch image models `_. + + """ + blocks_args_str = [ + "r1_k3_s11_e1_i32_o16_se0.25", + "r2_k3_s22_e6_i16_o24_se0.25", + "r2_k5_s22_e6_i24_o40_se0.25", + "r3_k3_s22_e6_i40_o80_se0.25", + "r3_k5_s11_e6_i80_o112_se0.25", + "r4_k5_s22_e6_i112_o192_se0.25", + "r1_k3_s11_e6_i192_o320_se0.25", + ] - if loadable_from_file: - # skip loading fc layers for transfer learning applications - load_fc = num_classes == 1000 + # check if model_name is valid model + if model_name not in efficientnet_params.keys(): + raise ValueError( + "invalid model_name {} found, must be one of {} ".format( + model_name, ", ".join(efficientnet_params.keys()) + ) + ) + + # get network parameters + weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name] + + # create model and initialize random weights + super(EfficientNetBNFeatures, self).__init__( + blocks_args_str=blocks_args_str, + spatial_dims=spatial_dims, + in_channels=in_channels, + num_classes=num_classes, + width_coefficient=weight_coeff, + depth_coefficient=depth_coeff, + dropout_rate=dropout_rate, + image_size=image_size, + drop_connect_rate=dropconnect_rate, + norm=norm, + ) + + # only pretrained for when `spatial_dims` is 2 + if pretrained and (spatial_dims == 2): + _load_state_dict(self, model_name, progress, in_channels) + + def forward(self, inputs: torch.Tensor): + """ + Args: + inputs: input should have spatially N dimensions + ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`. + + Returns: + a list of torch Tensors. + """ + # Stem + x = self._conv_stem(self._conv_stem_padding(inputs)) + x = self._swish(self._bn0(x)) - # only pretrained for when `spatial_dims` is 2 - _load_state_dict(self, model_name, progress, load_fc) + features = [] + if 0 in self.extract_stacks: + features.append(x) + for i, block in enumerate(self._blocks): + x = block(x) + if i + 1 in self.extract_stacks: + features.append(x) + return features def get_efficientnet_image_size(model_name: str) -> int: @@ -588,7 +676,7 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor return output -def _load_state_dict(model: nn.Module, model_name: str, progress: bool, load_fc: bool) -> None: +def _load_state_dict(model: nn.Module, model_name: str, progress: bool, in_channels: int) -> None: url_map = { "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", @@ -599,27 +687,22 @@ def _load_state_dict(model: nn.Module, model_name: str, progress: bool, load_fc: "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", } - # load state dict from url - model_url = url_map[model_name] - state_dict = model_zoo.load_url(model_url, progress=progress) - - # load state dict into model parameters - if load_fc: # load everything - ret = model.load_state_dict(state_dict, strict=False) - if ret.missing_keys: - raise ValueError("Found missing keys when loading pretrained weights: {}".format(ret.missing_keys)) - else: # skip final FC layers, for transfer learning cases - state_dict.pop("_fc.weight") - state_dict.pop("_fc.bias") - ret = model.load_state_dict(state_dict, strict=False) - - # check if no other keys missing except FC layer parameters - if set(ret.missing_keys) != {"_fc.weight", "_fc.bias"}: - raise ValueError("Found missing keys when loading pretrained weights: {}".format(ret.missing_keys)) - - # check for any unexpected keys - if ret.unexpected_keys: - raise ValueError("Missing keys when loading pretrained weights: {}".format(ret.unexpected_keys)) + if model_name not in url_map: + print("pretrained weights of {} is not provided".format(model_name)) + else: + # load state dict from url + model_url = url_map[model_name] + pretrain_state_dict = model_zoo.load_url(model_url, progress=progress) + model_state_dict = model.state_dict() + + pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)") + for key, value in model_state_dict.items(): + pretrain_key = re.sub(pattern, r"\1\2", key) + if pretrain_key in pretrain_state_dict: + if value.shape == pretrain_state_dict[pretrain_key].shape: + model_state_dict[key] = pretrain_state_dict[pretrain_key] + + model.load_state_dict(model_state_dict) def _get_same_padding_conv_nd( diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 6a63e766f3..dd300fce34 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -57,7 +57,6 @@ ) from .module import ( PT_BEFORE_1_7, - PT_BEFORE_1_9, InvalidPyTorchVersionError, OptionalImportError, damerau_levenshtein_distance, diff --git a/monai/utils/module.py b/monai/utils/module.py index b0ce1c01cd..33314fb0e3 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -34,7 +34,6 @@ "get_package_version", "get_torch_version_tuple", "PT_BEFORE_1_7", - "PT_BEFORE_1_9", "version_leq", ] @@ -407,8 +406,3 @@ def _try_cast(val): PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") except (AttributeError, TypeError): PT_BEFORE_1_7 = True - -try: - PT_BEFORE_1_9 = torch.__version__ != "1.9.0" and version_leq(torch.__version__, "1.9.0") -except (AttributeError, TypeError): - PT_BEFORE_1_9 = True diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index 6567e3af9a..afb87334af 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -18,7 +18,13 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import BlockArgs, EfficientNetBN, drop_connect, get_efficientnet_image_size +from monai.networks.nets import ( + BlockArgs, + EfficientNetBN, + EfficientNetBNFeatures, + drop_connect, + get_efficientnet_image_size, +) from monai.utils import optional_import from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save @@ -226,6 +232,20 @@ def make_shape_cases( ) ) +CASE_EXTRACT_FEATURES = [ + ( + { + "model_name": "efficientnet-b8", + "pretrained": False, + "progress": False, + "spatial_dims": 2, + "in_channels": 2, + }, + [1, 2, 224, 224], + ([1, 32, 112, 112], [1, 56, 56, 56], [1, 88, 28, 28], [1, 248, 14, 14], [1, 704, 7, 7]), + ), +] + class TestEFFICIENTNET(unittest.TestCase): @parameterized.expand(CASES_1D + CASES_2D + CASES_3D + CASES_VARIATIONS) @@ -355,5 +375,23 @@ def test_script(self): test_script_save(net, test_data) +class TestExtractFeatures(unittest.TestCase): + @parameterized.expand(CASE_EXTRACT_FEATURES) + def test_shape(self, input_param, input_shape, expected_shapes): + device = "cuda" if torch.cuda.is_available() else "cpu" + + # initialize model + net = EfficientNetBNFeatures(**input_param).to(device) + + # run inference with random tensor + with eval_mode(net): + features = net(torch.randn(input_shape).to(device)) + + # check output shape + self.assertEqual(len(features), len(expected_shapes)) + for feature, expected_shape in zip(features, expected_shapes): + self.assertEqual(feature.shape, torch.Size(expected_shape)) + + if __name__ == "__main__": unittest.main() From 51c44067eb5a0286ba98d5724d02e9e59c0b8c90 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 23 Aug 2021 17:43:48 +0800 Subject: [PATCH 4/8] add more pretrained weights url Signed-off-by: Yiheng Wang --- monai/networks/nets/efficientnet.py | 28 +++++++++++++++++++++++----- tests/test_efficientnet.py | 4 +++- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index e40f63d6c0..ef54244a77 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -35,6 +35,7 @@ "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2), "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2), "efficientnet-b8": (2.2, 3.6, 672, 0.5, 0.2), + "efficientnet-l2": (4.3, 5.3, 800, 0.5, 0.2), } @@ -450,6 +451,7 @@ def __init__( in_channels: int = 3, num_classes: int = 1000, norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + adv_prop: bool = False, ) -> None: """ Generic wrapper around EfficientNet, used to initialize EfficientNet-B0 to EfficientNet-B7 models @@ -457,7 +459,7 @@ def __init__( it needs the N in [0, 1, 2, 3, 4, 5, 6, 7, 8] to be a model Args: - model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b8]. + model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-l2]. pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2 and batch norm is used. progress: whether to show download progress for pretrained weights download. @@ -465,6 +467,8 @@ def __init__( in_channels: number of input channels. num_classes: number of output classes. norm: feature normalization type and arguments. + adv_prop: whether to use weights trained with adversarial examples. + This argument only works when `pretrained` is `True`. Examples:: @@ -485,7 +489,7 @@ def __init__( >>> model = EfficientNetBN("efficientnet-b7", spatial_dims=2) """ - # block args for EfficientNet-B0 to EfficientNet-B8 + # block args blocks_args_str = [ "r1_k3_s11_e1_i32_o16_se0.25", "r2_k3_s22_e6_i16_o24_se0.25", @@ -523,7 +527,7 @@ def __init__( # only pretrained for when `spatial_dims` is 2 if pretrained and (spatial_dims == 2): - _load_state_dict(self, model_name, progress, in_channels) + _load_state_dict(self, model_name, progress, in_channels, adv_prop) class EfficientNetBNFeatures(EfficientNet): @@ -536,6 +540,7 @@ def __init__( in_channels: int = 3, num_classes: int = 1000, norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + adv_prop: bool = False, ) -> None: """ Initialize EfficientNet-B0 to EfficientNet-B7 models as a backbone, the backbone can @@ -582,7 +587,7 @@ def __init__( # only pretrained for when `spatial_dims` is 2 if pretrained and (spatial_dims == 2): - _load_state_dict(self, model_name, progress, in_channels) + _load_state_dict(self, model_name, progress, in_channels, adv_prop) def forward(self, inputs: torch.Tensor): """ @@ -676,7 +681,8 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor return output -def _load_state_dict(model: nn.Module, model_name: str, progress: bool, in_channels: int) -> None: +def _load_state_dict(model: nn.Module, model_name: str, progress: bool, in_channels: int, adv_prop: bool) -> None: + url_map = { "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", @@ -686,7 +692,19 @@ def _load_state_dict(model: nn.Module, model_name: str, progress: bool, in_chann "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", + "efficientnet-b0-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth", + "efficientnet-b1-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth", + "efficientnet-b2-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth", + "efficientnet-b3-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth", + "efficientnet-b4-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth", + "efficientnet-b5-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth", + "efficientnet-b6-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth", + "efficientnet-b7-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth", + "efficientnet-b8-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth", } + + if adv_prop: + model_name += "-ap" if model_name not in url_map: print("pretrained weights of {} is not provided".format(model_name)) else: diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py index afb87334af..6befba108a 100644 --- a/tests/test_efficientnet.py +++ b/tests/test_efficientnet.py @@ -162,6 +162,7 @@ def make_shape_cases( "in_channels": 3, "num_classes": 1000, "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), + "adv_prop": False, }, os.path.join(os.path.dirname(__file__), "testing_data", "kitty_test.jpg"), 282, # ~ tiger cat @@ -236,10 +237,11 @@ def make_shape_cases( ( { "model_name": "efficientnet-b8", - "pretrained": False, + "pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, + "adv_prop": True, }, [1, 2, 224, 224], ([1, 32, 112, 112], [1, 56, 56, 56], [1, 88, 28, 28], [1, 248, 14, 14], [1, 704, 7, 7]), From e6cfa30002c5717afd4def04bd29c4090bb56b61 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 23 Aug 2021 18:30:40 +0800 Subject: [PATCH 5/8] fix flake8 errors Signed-off-by: Yiheng Wang --- monai/networks/nets/efficientnet.py | 44 ++++++++++++++--------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index ef54244a77..e4e46c5497 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -38,6 +38,27 @@ "efficientnet-l2": (4.3, 5.3, 800, 0.5, 0.2), } +url_map = { + "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", + "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", + "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", + "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", + "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", + "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", + "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", + "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", + # trained with adversarial examples, simplify the name to decrease string length + "b0-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth", + "b1-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth", + "b2-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth", + "b3-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth", + "b4-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth", + "b5-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth", + "b6-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth", + "b7-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth", + "b8-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth", +} + class MBConvBlock(nn.Module): def __init__( @@ -682,29 +703,8 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor def _load_state_dict(model: nn.Module, model_name: str, progress: bool, in_channels: int, adv_prop: bool) -> None: - - url_map = { - "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", - "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", - "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", - "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", - "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", - "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", - "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", - "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", - "efficientnet-b0-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth", - "efficientnet-b1-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth", - "efficientnet-b2-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth", - "efficientnet-b3-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth", - "efficientnet-b4-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth", - "efficientnet-b5-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth", - "efficientnet-b6-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth", - "efficientnet-b7-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth", - "efficientnet-b8-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth", - } - if adv_prop: - model_name += "-ap" + model_name = model_name.split("efficientnet-")[-1] + "-ap" if model_name not in url_map: print("pretrained weights of {} is not provided".format(model_name)) else: From b66f838b7360ae60b6f2b9bef10caa48b186d674 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 25 Aug 2021 11:54:42 +0800 Subject: [PATCH 6/8] use look up option Signed-off-by: Yiheng Wang --- monai/networks/nets/densenet.py | 45 +++++++++-------- monai/networks/nets/efficientnet.py | 14 +++--- monai/networks/nets/senet.py | 78 ++++++++++++++--------------- 3 files changed, 70 insertions(+), 67 deletions(-) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 3e30987bdc..d3f8400863 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -19,6 +19,7 @@ from monai.networks.layers.factories import Conv, Dropout, Pool from monai.networks.layers.utils import get_act_layer, get_norm_layer +from monai.utils.module import look_up_option __all__ = [ "DenseNet", @@ -249,7 +250,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def _load_state_dict(model, arch, progress): +def _load_state_dict(model: nn.Module, arch: str, progress: bool): """ This function is used to load pretrained models. Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16. @@ -260,30 +261,30 @@ def _load_state_dict(model, arch, progress): "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", } - if arch in model_urls: - model_url = model_urls[arch] - else: + model_url = look_up_option(arch, model_urls, None) + if model_url is None: raise ValueError( "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights." ) - pattern = re.compile( - r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" - ) - - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) - state_dict[new_key] = state_dict[key] - del state_dict[key] - - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) + else: + pattern = re.compile( + r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + model_dict = model.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + model.load_state_dict(model_dict) class DenseNet121(DenseNet): diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index e4e46c5497..b43f0762a0 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -21,6 +21,7 @@ from monai.networks.layers.factories import Act, Conv, Pad, Pool from monai.networks.layers.utils import get_norm_layer +from monai.utils.module import look_up_option __all__ = ["EfficientNet", "EfficientNetBN", "get_efficientnet_image_size", "drop_connect"] @@ -480,7 +481,7 @@ def __init__( it needs the N in [0, 1, 2, 3, 4, 5, 6, 7, 8] to be a model Args: - model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-l2]. + model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2]. pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2 and batch norm is used. progress: whether to show download progress for pretrained weights download. @@ -702,14 +703,15 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor return output -def _load_state_dict(model: nn.Module, model_name: str, progress: bool, in_channels: int, adv_prop: bool) -> None: +def _load_state_dict(model: nn.Module, arch: str, progress: bool, in_channels: int, adv_prop: bool) -> None: if adv_prop: - model_name = model_name.split("efficientnet-")[-1] + "-ap" - if model_name not in url_map: - print("pretrained weights of {} is not provided".format(model_name)) + arch = arch.split("efficientnet-")[-1] + "-ap" + model_url = look_up_option(arch, url_map, None) + if model_url is None: + print("pretrained weights of {} is not provided".format(arch)) else: # load state dict from url - model_url = url_map[model_name] + model_url = url_map[arch] pretrain_state_dict = model_zoo.load_url(model_url, progress=progress) model_state_dict = model.state_dict() diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 7292b2a1d5..49246f21d4 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -20,6 +20,7 @@ from monai.networks.blocks.convolutions import Convolution from monai.networks.blocks.squeeze_and_excitation import SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck from monai.networks.layers.factories import Act, Conv, Dropout, Norm, Pool +from monai.utils.module import look_up_option __all__ = ["SENet", "SENet154", "SEResNet50", "SEResNet101", "SEResNet152", "SEResNeXt50", "SEResNext101"] @@ -249,7 +250,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def _load_state_dict(model, arch, progress): +def _load_state_dict(model: nn.Module, arch: str, progress: bool): """ This function is used to load pretrained models. """ @@ -261,48 +262,47 @@ def _load_state_dict(model, arch, progress): "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", } - if arch in model_urls: - model_url = model_urls[arch] - else: + model_url = look_up_option(arch, model_urls, None) + if model_url is None: raise ValueError( "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', " + "and se_resnext101_32x4d are supported to load pretrained weights." ) - - pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") - pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") - pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") - pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") - pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") - pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") - - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - new_key = None - if pattern_conv.match(key): - new_key = re.sub(pattern_conv, r"\1conv.\2", key) - elif pattern_bn.match(key): - new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) - elif pattern_se.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) - elif pattern_se2.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) - elif pattern_down_conv.match(key): - new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) - elif pattern_down_bn.match(key): - new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) - if new_key: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) + else: + pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") + pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") + pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") + pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") + pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") + pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") + + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + new_key = None + if pattern_conv.match(key): + new_key = re.sub(pattern_conv, r"\1conv.\2", key) + elif pattern_bn.match(key): + new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) + elif pattern_se.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) + elif pattern_se2.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) + elif pattern_down_conv.match(key): + new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) + elif pattern_down_bn.match(key): + new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) + if new_key: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + model_dict = model.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + model.load_state_dict(model_dict) class SENet154(SENet): From d424c2a928c1b06241ae9e0a7dd4033965503ddc Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 25 Aug 2021 22:27:51 +0800 Subject: [PATCH 7/8] fix deepsource issues Signed-off-by: Yiheng Wang --- monai/networks/blocks/activation.py | 4 +-- monai/networks/nets/densenet.py | 36 +++++++++---------- monai/networks/nets/efficientnet.py | 11 +++--- monai/networks/nets/senet.py | 56 ++++++++++++++--------------- 4 files changed, 53 insertions(+), 54 deletions(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index a380f8e757..d40f8414d0 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -22,7 +22,7 @@ def monai_mish(x, inplace: bool = False): else: - def monai_mish(x, inplace: bool = False): + def monai_mish(x, _inplace: bool = False): return x * torch.tanh(torch.nn.functional.softplus(x)) @@ -34,7 +34,7 @@ def monai_swish(x, inplace: bool = False): else: - def monai_swish(x, inplace: bool = False): + def monai_swish(x, _inplace: bool = False): return SwishImplementation.apply(x) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index d3f8400863..e9f3b6d33e 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -266,25 +266,25 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): raise ValueError( "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights." ) - else: - pattern = re.compile( - r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" - ) - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) - state_dict[new_key] = state_dict[key] - del state_dict[key] - - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) + pattern = re.compile( + r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) + + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + model_dict = model.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + model.load_state_dict(model_dict) class DenseNet121(DenseNet): diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py index b43f0762a0..453916758a 100644 --- a/monai/networks/nets/efficientnet.py +++ b/monai/networks/nets/efficientnet.py @@ -549,7 +549,7 @@ def __init__( # only pretrained for when `spatial_dims` is 2 if pretrained and (spatial_dims == 2): - _load_state_dict(self, model_name, progress, in_channels, adv_prop) + _load_state_dict(self, model_name, progress, adv_prop) class EfficientNetBNFeatures(EfficientNet): @@ -609,7 +609,7 @@ def __init__( # only pretrained for when `spatial_dims` is 2 if pretrained and (spatial_dims == 2): - _load_state_dict(self, model_name, progress, in_channels, adv_prop) + _load_state_dict(self, model_name, progress, adv_prop) def forward(self, inputs: torch.Tensor): """ @@ -703,7 +703,7 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor return output -def _load_state_dict(model: nn.Module, arch: str, progress: bool, in_channels: int, adv_prop: bool) -> None: +def _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool) -> None: if adv_prop: arch = arch.split("efficientnet-")[-1] + "-ap" model_url = look_up_option(arch, url_map, None) @@ -718,9 +718,8 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool, in_channels: i pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)") for key, value in model_state_dict.items(): pretrain_key = re.sub(pattern, r"\1\2", key) - if pretrain_key in pretrain_state_dict: - if value.shape == pretrain_state_dict[pretrain_key].shape: - model_state_dict[key] = pretrain_state_dict[pretrain_key] + if pretrain_key in pretrain_state_dict and value.shape == pretrain_state_dict[pretrain_key].shape: + model_state_dict[key] = pretrain_state_dict[pretrain_key] model.load_state_dict(model_state_dict) diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 49246f21d4..f4159d17a6 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -268,34 +268,34 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', " + "and se_resnext101_32x4d are supported to load pretrained weights." ) - else: - pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") - pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") - pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") - pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") - pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") - pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") - - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - new_key = None - if pattern_conv.match(key): - new_key = re.sub(pattern_conv, r"\1conv.\2", key) - elif pattern_bn.match(key): - new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) - elif pattern_se.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) - elif pattern_se2.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) - elif pattern_down_conv.match(key): - new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) - elif pattern_down_bn.match(key): - new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) - if new_key: - state_dict[new_key] = state_dict[key] - del state_dict[key] + + pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") + pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") + pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") + pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") + pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") + pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") + + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + new_key = None + if pattern_conv.match(key): + new_key = re.sub(pattern_conv, r"\1conv.\2", key) + elif pattern_bn.match(key): + new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) + elif pattern_se.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) + elif pattern_se2.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) + elif pattern_down_conv.match(key): + new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) + elif pattern_down_bn.match(key): + new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) + if new_key: + state_dict[new_key] = state_dict[key] + del state_dict[key] model_dict = model.state_dict() state_dict = { From 0ce265a35eca61bea63ddfedc6de3901cde2d8eb Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 25 Aug 2021 23:17:59 +0800 Subject: [PATCH 8/8] fix type issues Signed-off-by: Yiheng Wang --- monai/networks/blocks/activation.py | 4 ++-- monai/networks/nets/senet.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index d40f8414d0..a380f8e757 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -22,7 +22,7 @@ def monai_mish(x, inplace: bool = False): else: - def monai_mish(x, _inplace: bool = False): + def monai_mish(x, inplace: bool = False): return x * torch.tanh(torch.nn.functional.softplus(x)) @@ -34,7 +34,7 @@ def monai_swish(x, inplace: bool = False): else: - def monai_swish(x, _inplace: bool = False): + def monai_swish(x, inplace: bool = False): return SwishImplementation.apply(x) diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index f4159d17a6..9b7035c259 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -297,12 +297,12 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): state_dict[new_key] = state_dict[key] del state_dict[key] - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) + model_dict = model.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + model.load_state_dict(model_dict) class SENet154(SENet):