From 75fd543e9a308c1f83a3d8a76dee6ac4b7ffaabf Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 4 Mar 2021 19:25:48 +0800 Subject: [PATCH] Fix senet pretrained weights issue Signed-off-by: Yiheng Wang --- monai/networks/nets/senet.py | 4 ++-- tests/test_senet.py | 32 +++++++++++++++++++++++++++----- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 655ff203c7..ef67f853d6 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -275,7 +275,7 @@ def _load_state_dict(model, model_url, progress): if pattern_conv.match(key): new_key = re.sub(pattern_conv, r"\1conv.\2", key) elif pattern_bn.match(key): - new_key = re.sub(pattern_bn, r"\1conv\2norm.\3", key) + new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) elif pattern_se.match(key): state_dict[key] = state_dict[key].squeeze() new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) @@ -285,7 +285,7 @@ def _load_state_dict(model, model_url, progress): elif pattern_down_conv.match(key): new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) elif pattern_down_bn.match(key): - new_key = re.sub(pattern_down_bn, r"\1project.norm.\2", key) + new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) if new_key: state_dict[new_key] = state_dict[key] del state_dict[key] diff --git a/tests/test_senet.py b/tests/test_senet.py index 883d75d62d..c1327ceb7d 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -10,6 +10,8 @@ # limitations under the License. 
import unittest +from typing import TYPE_CHECKING +from unittest import skipUnless import torch from parameterized import parameterized @@ -23,8 +25,17 @@ se_resnext101_32x4d, senet154, ) +from monai.utils import optional_import from tests.utils import test_pretrained_networks, test_script_save +if TYPE_CHECKING: + import pretrainedmodels + + has_cadene_pretrain = True +else: + pretrainedmodels, has_cadene_pretrain = optional_import("pretrainedmodels") + + device = "cuda" if torch.cuda.is_available() else "cpu" NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2} @@ -56,11 +67,7 @@ def test_script(self, net, net_args): class TestPretrainedSENET(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_PRETRAINED, - ] - ) + @parameterized.expand([TEST_CASE_PRETRAINED]) def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) input_data = torch.randn(3, 3, 64, 64).to(device) @@ -70,6 +77,21 @@ def test_senet_shape(self, model, input_param): result = net(input_data) self.assertEqual(result.shape, expected_shape) + @parameterized.expand([TEST_CASE_PRETRAINED]) + @skipUnless(has_cadene_pretrain, "Requires `pretrainedmodels` package.") + def test_pretrain_consistency(self, model, input_param): + input_data = torch.randn(1, 3, 64, 64).to(device) + net = test_pretrained_networks(model, input_param, device) + with eval_mode(net): + result = net.features(input_data) + cadene_net = pretrainedmodels.se_resnet50().to(device) + with eval_mode(cadene_net): + expected_result = cadene_net.features(input_data) + # The difference between Cadene's senet and our version is that + # we use nn.Linear as the FC layer, but Cadene's version uses + # a conv layer with kernel size equal to 1. It may introduce a small difference. + self.assertTrue(torch.allclose(result, expected_result, rtol=1e-5, atol=1e-5)) + if __name__ == "__main__": unittest.main()