From ad50094050451c2b1b2414bf2e3b66dfc15d34ec Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 3 Mar 2021 23:18:04 +0800 Subject: [PATCH 1/2] Update load pretrain for densenet Signed-off-by: Yiheng Wang --- monai/networks/nets/densenet.py | 4 ++-- tests/test_densenet.py | 42 ++++++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index ad1d1d6e5f..a59ab99e68 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -210,14 +210,14 @@ def _load_state_dict(model, model_url, progress): `_ """ pattern = re.compile( - r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" ) state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): res = pattern.match(key) if res: - new_key = res.group(1) + res.group(2) + new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) state_dict[new_key] = state_dict[key] del state_dict[key] diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 876689314a..b9f5132c43 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -10,14 +10,26 @@ # limitations under the License. import unittest +from typing import TYPE_CHECKING +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import densenet121, densenet169, densenet201, densenet264 +from monai.utils import optional_import from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save + +if TYPE_CHECKING: + import torchvision + + has_torchvision = True +else: + torchvision, has_torchvision = optional_import("torchvision") + + device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ # 4-channel 3D, batch 2 @@ -50,27 +62,45 @@ TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2 densenet121, {"pretrained": True, "progress": True, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64), - (2, 3), + (1, 2, 32, 64), + (1, 3), ] TEST_PRETRAINED_2D_CASE_2 = [ # 4-channel 2D, batch 2 densenet121, - {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64), - (2, 3), + {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 1}, + (1, 2, 32, 64), + (1, 1), +] + +TEST_PRETRAINED_2D_CASE_3 = [ + densenet121, + {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 3, "out_channels": 1}, + (1, 3, 32, 32), ] class TestPretrainedDENSENET(unittest.TestCase): @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2]) @skip_if_quick - def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape): + def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape): net = test_pretrained_networks(model, input_param, device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + @parameterized.expand([TEST_PRETRAINED_2D_CASE_3]) + @skipUnless(has_torchvision, "Requires `torchvision` package.") + def test_pretrain_consistency(self, model, input_param, input_shape): + example = torch.randn(input_shape).to(device) + net = test_pretrained_networks(model, input_param, device) + with eval_mode(net): + result = net.features.forward(example) + torchvision_net = torchvision.models.densenet121(pretrained=True).to(device) + with eval_mode(torchvision_net): + expected_result = torchvision_net.features.forward(example) + self.assertTrue(torch.all(result == expected_result)) + class TestDENSENET(unittest.TestCase): @parameterized.expand(TEST_CASES) From 32a0f13725f4fcb0e290252b7e744a92e45b4355 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 3 Mar 2021 23:27:49 +0800 Subject: [PATCH 2/2] Fix isort issue Signed-off-by: Yiheng Wang --- tests/test_densenet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_densenet.py b/tests/test_densenet.py index b9f5132c43..41b5fbf7d6 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -21,7 +21,6 @@ from monai.utils import optional_import from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save - if TYPE_CHECKING: import torchvision