From a230402d6e13a02def3c3a2ba749a7d55eaaadd2 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jul 2024 19:09:03 +0800 Subject: [PATCH 1/4] fix #7923 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/resnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 6e61db07ca..989c4f62b4 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -510,7 +510,7 @@ def _resnet( # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) if shortcut_type == kwargs.get("shortcut_type", "B") and ( - bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True + bool(bias_downsample) == kwargs.get("bias_downsample", True) ): # Download the MedicalNet pretrained model model_state_dict = get_pretrained_resnet_medicalnet( @@ -518,8 +518,8 @@ def _resnet( ) else: raise NotImplementedError( - f"Please set shortcut_type to {shortcut_type} and bias_downsample to" - f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}" + f"Please set shortcut_type to {shortcut_type} and bias_downsample to " + f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'} " f"when using pretrained MedicalNet resnet{resnet_depth}" ) else: From 17ae7097e4131922be3c70d050b7eda49ece316b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jul 2024 19:12:59 +0800 Subject: [PATCH 2/4] fix error message Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/resnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 989c4f62b4..03b1a82f4c 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -518,8 +518,7 @@ def _resnet( ) else: raise NotImplementedError( - f"Please set shortcut_type to {shortcut_type} and bias_downsample to " - f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'} " + f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bool(bias_downsample)} " f"when using pretrained MedicalNet resnet{resnet_depth}" ) else: From 1d5e37e7509319767c7513a233e01e1c7a71d07d Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jul 2024 22:24:50 +0800 Subject: [PATCH 3/4] address comments Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/resnet.py | 6 +++--- tests/test_resnet.py | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 03b1a82f4c..dac54741c5 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -510,7 +510,7 @@ def _resnet( # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) if shortcut_type == kwargs.get("shortcut_type", "B") and ( - bool(bias_downsample) == kwargs.get("bias_downsample", True) + bias_downsample == kwargs.get("bias_downsample", True) ): # Download the MedicalNet pretrained model model_state_dict = get_pretrained_resnet_medicalnet( @@ -518,7 +518,7 @@ def _resnet( ) else: raise NotImplementedError( - f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bool(bias_downsample)} " + f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bias_downsample} " f"when using pretrained MedicalNet resnet{resnet_depth}" ) else: @@ -680,7 +680,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): # After testing # False: 10, 50, 101, 152, 200 # Any: 18, 34 - bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 + bias_downsample = True if resnet_depth in [18, 34] else False # 18, 10, 34 shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type diff --git a/tests/test_resnet.py b/tests/test_resnet.py index e873f1238a..a55d18f5de 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -266,7 +266,7 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): @parameterized.expand(PRETRAINED_TEST_CASES) @skip_if_quick @skip_if_no_cuda - def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape): + def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_shape): net = model(**input_param).to(device) # Save ckpt torch.save(net.state_dict(), self.tmp_ckpt_filename) @@ -290,9 +290,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape and input_param.get("n_input_channels", 3) == 1 and input_param.get("feed_forward", True) is False and input_param.get("shortcut_type", "B") == shortcut_type - and ( - input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True - ) + and (input_param.get("bias_downsample", True) == bias_downsample) ): model(**cp_input_param) else: @@ -303,7 +301,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape cp_input_param["n_input_channels"] = 1 cp_input_param["feed_forward"] = False cp_input_param["shortcut_type"] = shortcut_type - cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True + cp_input_param["bias_downsample"] = bias_downsample if cp_input_param.get("spatial_dims", 3) == 3: with skip_if_downloading_fails(): pretrained_net = model(**cp_input_param).to(device) From 10fa744102db656e3987f4e8876114b322da3308 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 17 Jul 2024 23:37:39 +0800 Subject: [PATCH 4/4] Update monai/networks/nets/resnet.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index dac54741c5..d62722478e 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -680,7 +680,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): # After testing # False: 10, 50, 101, 152, 200 # Any: 18, 34 - bias_downsample = True if resnet_depth in [18, 34] else False # 18, 10, 34 + bias_downsample = resnet_depth in (18, 34) shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type