From 024b73516bdc5c1399a01a5c336cf30b56061b55 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Thu, 18 Mar 2021 21:25:34 +0800
Subject: [PATCH 1/2] add pretrain options

Signed-off-by: Yiheng Wang
---
 monai/networks/nets/densenet.py | 112 ++++++++++++++---------
 monai/networks/nets/senet.py    | 154 ++++++++++++++++++--------------
 2 files changed, 155 insertions(+), 111 deletions(-)

diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py
index a59ab99e68..6db5233890 100644
--- a/monai/networks/nets/densenet.py
+++ b/monai/networks/nets/densenet.py
@@ -115,6 +115,11 @@ class DenseNet(nn.Module):
         bn_size: multiplicative factor for number of bottle neck layers.
             (i.e. bn_size * k features in the bottleneck layer)
         dropout_prob: dropout rate after each dense layer.
+        pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`.
+            In order to load weights correctly, please ensure that the `block_config`
+            is consistent with the corresponding arch.
+        pretrained_arch: the arch name for pretrained weights.
+        progress: If True, displays a progress bar of the download to stderr.
     """
 
     def __init__(
@@ -127,6 +132,9 @@ def __init__(
         block_config: Sequence[int] = (6, 12, 24, 16),
         bn_size: int = 4,
         dropout_prob: float = 0.0,
+        pretrained: bool = False,
+        pretrained_arch: str = "densenet121",
+        progress: bool = True,
     ) -> None:
 
         super(DenseNet, self).__init__()
@@ -190,43 +198,48 @@ def __init__(
             elif isinstance(m, nn.Linear):
                 nn.init.constant_(torch.as_tensor(m.bias), 0)
 
+        if pretrained:
+            self._load_state_dict(pretrained_arch, progress)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.features(x)
         x = self.class_layers(x)
         return x
 
+    def _load_state_dict(self, arch, progress):
+        """
+        This function is used to load pretrained models.
+        Adapted from `PyTorch Hub 2D version
+        `_
+        """
+        model_urls = {
+            "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
+            "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
+            "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
+        }
+        if arch in model_urls.keys():
+            model_url = model_urls[arch]
+        else:
+            error_msg = "only densenet121, densenet169 and densenet201 are supported to load pretrained weights."
+            raise AssertionError(error_msg)
+        pattern = re.compile(
+            r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
+        )
 
-model_urls = {
-    "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
-    "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
-    "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
-}
-
-
-def _load_state_dict(model, model_url, progress):
-    """
-    This function is used to load pretrained models.
-    Adapted from `PyTorch Hub 2D version
-    `_
-    """
-    pattern = re.compile(
-        r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
-    )
-
-    state_dict = load_state_dict_from_url(model_url, progress=progress)
-    for key in list(state_dict.keys()):
-        res = pattern.match(key)
-        if res:
-            new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
-            state_dict[new_key] = state_dict[key]
-            del state_dict[key]
+        state_dict = load_state_dict_from_url(model_url, progress=progress)
+        for key in list(state_dict.keys()):
+            res = pattern.match(key)
+            if res:
+                new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
+                state_dict[new_key] = state_dict[key]
+                del state_dict[key]
 
-    model_dict = model.state_dict()
-    state_dict = {
-        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
-    }
-    model_dict.update(state_dict)
-    model.load_state_dict(model_dict)
+        model_dict = self.state_dict()
+        state_dict = {
+            k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
+        }
+        model_dict.update(state_dict)
+        self.load_state_dict(model_dict)
 
 
 def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet:
@@ -235,10 +248,15 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> De
     from `PyTorch Hub 2D version
     `_
     """
-    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
-    if pretrained:
-        arch = "densenet121"
-        _load_state_dict(model, model_urls[arch], progress)
+    model = DenseNet(
+        init_features=64,
+        growth_rate=32,
+        block_config=(6, 12, 24, 16),
+        pretrained=pretrained,
+        pretrained_arch="densenet121",
+        progress=progress,
+        **kwargs,
+    )
     return model
 
 
@@ -248,10 +266,15 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs) -> De
    from `PyTorch Hub 2D version
    `_
    """
-    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
-    if pretrained:
-        arch = "densenet169"
-        _load_state_dict(model, model_urls[arch], progress)
+    model = DenseNet(
+        init_features=64,
+        growth_rate=32,
+        block_config=(6, 12, 32, 32),
+        pretrained=pretrained,
+        pretrained_arch="densenet169",
+        progress=progress,
+        **kwargs,
+    )
     return model
 
 
@@ -261,10 +284,15 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs) -> De
    from `PyTorch Hub 2D version
    `_
    """
-    model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
-    if pretrained:
-        arch = "densenet201"
-        _load_state_dict(model, model_urls[arch], progress)
+    model = DenseNet(
+        init_features=64,
+        growth_rate=32,
+        block_config=(6, 12, 48, 32),
+        pretrained=pretrained,
+        pretrained_arch="densenet201",
+        progress=progress,
+        **kwargs,
+    )
     return model
 
 
diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py
index ef67f853d6..d6a657ae03 100644
--- a/monai/networks/nets/senet.py
+++ b/monai/networks/nets/senet.py
@@ -66,7 +66,11 @@ class SENet(nn.Module):
             - For SE-ResNeXt models: False
         num_classes: number of outputs in `last_linear` layer.
             for all models: 1000
-
+        pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`.
+            In order to load weights correctly, please ensure that the `layers`
+            argument is consistent with the corresponding arch.
+        pretrained_arch: the arch name for pretrained weights.
+        progress: If True, displays a progress bar of the download to stderr.
""" def __init__( @@ -83,6 +87,9 @@ def __init__( downsample_kernel_size: int = 3, input_3x3: bool = True, num_classes: int = 1000, + pretrained: bool = False, + pretrained_arch: str = "se_resnet50", + progress: bool = True, ) -> None: super(SENet, self).__init__() @@ -176,6 +183,65 @@ def __init__( elif isinstance(m, nn.Linear): nn.init.constant_(torch.as_tensor(m.bias), 0) + if pretrained: + self._load_state_dict(pretrained_arch, progress) + + def _load_state_dict(self, arch, progress): + """ + This function is used to load pretrained models. + """ + model_urls = { + "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", + "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", + "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", + "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", + "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", + "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", + } + if arch in model_urls.keys(): + model_url = model_urls[arch] + else: + error_msg = ( + "only senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d " + + "and se_resnext101_32x4d are supported to load pretrained weights." + ) + raise AssertionError(error_msg) + + pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") + pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") + pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") + pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") + pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") + pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") + + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + new_key = None + if pattern_conv.match(key): + new_key = re.sub(pattern_conv, r"\1conv.\2", key) + elif pattern_bn.match(key): + new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) + elif pattern_se.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) + elif pattern_se2.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) + elif pattern_down_conv.match(key): + new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) + elif pattern_down_bn.match(key): + new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) + if new_key: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + model_dict = self.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + self.load_state_dict(model_dict) + def _make_layer( self, block: Type[Union[SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck]], @@ -248,56 +314,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -model_urls = { - "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", - "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", - "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", - "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", - "se_resnext50_32x4d": 
"http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", - "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", -} - - -def _load_state_dict(model, model_url, progress): - """ - This function is used to load pretrained models. - """ - pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") - pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") - pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") - pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") - pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") - pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") - - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - new_key = None - if pattern_conv.match(key): - new_key = re.sub(pattern_conv, r"\1conv.\2", key) - elif pattern_bn.match(key): - new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) - elif pattern_se.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) - elif pattern_se2.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) - elif pattern_down_conv.match(key): - new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) - elif pattern_down_bn.match(key): - new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) - if new_key: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) - - def senet154( spatial_dims: int, in_channels: int, @@ -320,10 +336,10 @@ def senet154( dropout_prob=0.2, dropout_dim=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="senet154", + progress=progress, ) - if pretrained: - arch = "senet154" - _load_state_dict(model, model_urls[arch], progress) return model @@ -347,10 +363,10 @@ def se_resnet50( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnet50", + progress=progress, ) - if pretrained: - arch = "se_resnet50" - _load_state_dict(model, model_urls[arch], progress) return model @@ -375,10 +391,10 @@ def se_resnet101( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnet101", + progress=progress, ) - if pretrained: - arch = "se_resnet101" - _load_state_dict(model, model_urls[arch], progress) return model @@ -403,10 +419,10 @@ def se_resnet152( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnet152", + progress=progress, ) - if pretrained: - arch = "se_resnet152" - _load_state_dict(model, model_urls[arch], progress) return model @@ -430,10 +446,10 @@ def se_resnext50_32x4d( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnext50_32x4d", + progress=progress, ) - if pretrained: - arch = "se_resnext50_32x4d" - _load_state_dict(model, model_urls[arch], progress) return model @@ -457,8 +473,8 @@ def se_resnext101_32x4d( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnext101_32x4d", + 
+        progress=progress,
     )
-    if pretrained:
-        arch = "se_resnext101_32x4d"
-        _load_state_dict(model, model_urls[arch], progress)
     return model

From 1fe5d1def266eeef3e448fce2cfc9ac94c9812f1 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Fri, 19 Mar 2021 13:56:34 +0800
Subject: [PATCH 2/2] rewrite error message and add test cases

Signed-off-by: Yiheng Wang
---
 monai/networks/nets/densenet.py |  5 +++--
 monai/networks/nets/senet.py    |  7 +++----
 tests/test_densenet.py          | 18 +++++++++++++++++-
 tests/test_senet.py             | 26 +++++++++++++++++++++++---
 4 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py
index 6db5233890..4b4f2cc6a4 100644
--- a/monai/networks/nets/densenet.py
+++ b/monai/networks/nets/densenet.py
@@ -220,8 +220,9 @@ def _load_state_dict(self, arch, progress):
         if arch in model_urls.keys():
             model_url = model_urls[arch]
         else:
-            error_msg = "only densenet121, densenet169 and densenet201 are supported to load pretrained weights."
-            raise AssertionError(error_msg)
+            raise ValueError(
+                "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
+            )
         pattern = re.compile(
             r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
         )
diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py
index d6a657ae03..333a3b1159 100644
--- a/monai/networks/nets/senet.py
+++ b/monai/networks/nets/senet.py
@@ -201,11 +201,10 @@ def _load_state_dict(self, arch, progress):
         if arch in model_urls.keys():
             model_url = model_urls[arch]
         else:
-            error_msg = (
-                "only senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d "
-                + "and se_resnext101_32x4d are supported to load pretrained weights."
+            raise ValueError(
+                "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d' "
+                "and 'se_resnext101_32x4d' are supported to load pretrained weights."
             )
-            raise AssertionError(error_msg)
 
         pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$")
         pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$")
diff --git a/tests/test_densenet.py b/tests/test_densenet.py
index 41b5fbf7d6..5ead5f5818 100644
--- a/tests/test_densenet.py
+++ b/tests/test_densenet.py
@@ -17,7 +17,7 @@
 from parameterized import parameterized
 
 from monai.networks import eval_mode
-from monai.networks.nets import densenet121, densenet169, densenet201, densenet264
+from monai.networks.nets import DenseNet, densenet121, densenet169, densenet201, densenet264
 from monai.utils import optional_import
 from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save
 
@@ -78,6 +78,17 @@
     (1, 3, 32, 32),
 ]
 
+TEST_PRETRAINED_2D_CASE_4 = [
+    {
+        "pretrained": True,
+        "pretrained_arch": "densenet264",
+        "progress": False,
+        "spatial_dims": 2,
+        "in_channels": 3,
+        "out_channels": 1,
+    },
+]
+
 
 class TestPretrainedDENSENET(unittest.TestCase):
     @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2])
@@ -100,6 +111,11 @@ def test_pretrain_consistency(self, model, input_param, input_shape):
         expected_result = torchvision_net.features.forward(example)
         self.assertTrue(torch.all(result == expected_result))
 
+    @parameterized.expand([TEST_PRETRAINED_2D_CASE_4])
+    def test_ill_pretrain(self, input_param):
+        with self.assertRaisesRegex(ValueError, "supported to load pretrained weights"):
+            DenseNet(**input_param)
+
 
 class TestDENSENET(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
diff --git a/tests/test_senet.py b/tests/test_senet.py
index c1327ceb7d..a2d96e1f18 100644
--- a/tests/test_senet.py
+++ b/tests/test_senet.py
@@ -17,7 +17,9 @@
 from parameterized import parameterized
 
 from monai.networks import eval_mode
+from monai.networks.blocks.squeeze_and_excitation import SEBottleneck
 from monai.networks.nets import (
+    SENet,
     se_resnet50,
     se_resnet101,
     se_resnet152,
@@ -46,7 +48,20 @@
 TEST_CASE_5 = [se_resnext50_32x4d, NET_ARGS]
 TEST_CASE_6 = [se_resnext101_32x4d, NET_ARGS]
 
-TEST_CASE_PRETRAINED = [se_resnet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}]
+TEST_CASE_PRETRAINED_1 = [se_resnet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}]
+TEST_CASE_PRETRAINED_2 = [
+    {
+        "spatial_dims": 2,
+        "in_channels": 3,
+        "block": SEBottleneck,
+        "layers": [3, 8, 36, 3],
+        "groups": 64,
+        "reduction": 16,
+        "num_classes": 2,
+        "pretrained": True,
+        "pretrained_arch": "resnet50",
+    }
+]
 
 
 class TestSENET(unittest.TestCase):
@@ -67,7 +82,7 @@ def test_script(self, net, net_args):
 
 
 class TestPretrainedSENET(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_PRETRAINED])
+    @parameterized.expand([TEST_CASE_PRETRAINED_1])
     def test_senet_shape(self, model, input_param):
         net = test_pretrained_networks(model, input_param, device)
         input_data = torch.randn(3, 3, 64, 64).to(device)
@@ -77,7 +92,7 @@ def test_senet_shape(self, model, input_param):
         result = net(input_data)
         self.assertEqual(result.shape, expected_shape)
 
-    @parameterized.expand([TEST_CASE_PRETRAINED])
+    @parameterized.expand([TEST_CASE_PRETRAINED_1])
     @skipUnless(has_cadene_pretrain, "Requires `pretrainedmodels` package.")
     def test_pretrain_consistency(self, model, input_param):
         input_data = torch.randn(1, 3, 64, 64).to(device)
@@ -92,6 +107,11 @@ def test_pretrain_consistency(self, model, input_param):
         # a conv layer with kernel size equals to 1. It may bring a little difference.
         self.assertTrue(torch.allclose(result, expected_result, rtol=1e-5, atol=1e-5))
 
+    @parameterized.expand([TEST_CASE_PRETRAINED_2])
+    def test_ill_pretrain(self, input_param):
+        with self.assertRaisesRegex(ValueError, "supported to load pretrained weights"):
+            SENet(**input_param)
+
 
 if __name__ == "__main__":
     unittest.main()
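
For reference, a minimal usage sketch of the option these two patches introduce, assuming a MONAI build with both patches applied. `pretrained=True` only applies to `spatial_dims=2` networks, and the first call downloads the weights over the network:

    import torch
    from monai.networks.nets import densenet121, se_resnet50

    # 2D DenseNet-121: ImageNet weights are downloaded and remapped onto MONAI's layer names.
    model = densenet121(spatial_dims=2, in_channels=3, out_channels=2, pretrained=True)

    # 2D SE-ResNet50: weights come from the Cadene pretrained-model URLs.
    net = se_resnet50(spatial_dims=2, in_channels=3, num_classes=2, pretrained=True)

    with torch.no_grad():
        logits = model(torch.randn(1, 3, 64, 64))  # shape (1, 2)

Note that after PATCH 2/2, an unsupported `pretrained_arch` raises ValueError rather than AssertionError, while weights whose shapes do not match the constructed model are silently skipped by the state-dict filter; this is why the docstrings ask for a `block_config`/`layers` consistent with the chosen arch.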