From 0613c2f8c9c32a096f3ccd825360f8f622c8010a Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Wed, 4 Oct 2023 14:21:49 +0000 Subject: [PATCH 01/31] Simplify resnet pretrained flag Fixes: #7047 Original behaviour did not support True pretrained flag. Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index e742db5ca5..0ff2091347 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -329,21 +329,39 @@ def _resnet( block: type[ResNetBlock | ResNetBottleneck], layers: list[int], block_inplanes: list[int], - pretrained: bool, + pretrained: bool | str, progress: bool, **kwargs: Any, ) -> ResNet: model: ResNet = ResNet(block, layers, block_inplanes, **kwargs) if pretrained: + if isinstance(pretrained, str): + if Path(pretrained).exists(): + logger.info(f"Loading weights from {weights_path}...") + checkpoint = torch.load(pretrained, map_location=device) + else: + ### Throw error + raise FileNotFoundError("The pretrained checkpoint file is not found") + else: + ### Throw error # Author of paper zipped the state_dict on googledrive, # so would need to download, unzip and read (2.8gb file for a ~150mb state dict). # Would like to load dict from url but need somewhere to save the state dicts. raise NotImplementedError( - "Currently not implemented. You need to manually download weights provided by the paper's author" - " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet" - "Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args. as specified" - "here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730" + "Provide the pretrained checkpoint string path" ) + + if "state_dict" in checkpoint: + model_state_dict = checkpoint["state_dict"] + model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} + else: + ### Throw error + raise KeyError( + "The checkpoint should contain the pretrained model state dict with the following key: 'state_dict'" + ) + + model.load_state_dict(model_state_dict, strict=True) + return model From 1005fe8bb4344b1c6635f13375df6d4aa0813d6a Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Fri, 6 Oct 2023 14:35:15 +0000 Subject: [PATCH 02/31] add tests + typos Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 16 +++++++++------- tests/test_resnet.py | 27 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 0ff2091347..9759a16e9b 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -14,6 +14,8 @@ from collections.abc import Callable from functools import partial from typing import Any +from pathlib import Path +import logging import torch import torch.nn as nn @@ -36,6 +38,9 @@ "resnet200", ] +logger = logging.getLogger(__name__) + +device = "cuda" if torch.cuda.is_available() else "cpu" def get_inplanes(): return [64, 128, 256, 512] @@ -337,19 +342,16 @@ def _resnet( if pretrained: if isinstance(pretrained, str): if Path(pretrained).exists(): - logger.info(f"Loading weights from {weights_path}...") + logger.info(f"Loading weights from {pretrained}...") checkpoint = torch.load(pretrained, map_location=device) else: ### Throw error raise FileNotFoundError("The pretrained checkpoint file is not found") else: ### Throw error - # Author of paper zipped the state_dict on googledrive, - # so would need to download, unzip and read (2.8gb file for a ~150mb state dict). - # Would like to load dict from url but need somewhere to save the state dicts. - raise NotImplementedError( - "Provide the pretrained checkpoint string path" - ) + raise NotImplementedError( + "Provide the pretrained checkpoint string path" + ) if "state_dict" in checkpoint: model_state_dict = checkpoint["state_dict"] diff --git a/tests/test_resnet.py b/tests/test_resnet.py index cc24106373..6151b13c45 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -12,6 +12,8 @@ from __future__ import annotations import unittest +import copy +import os from typing import TYPE_CHECKING import torch @@ -30,6 +32,8 @@ else: torchvision, has_torchvision = optional_import("torchvision") +# from torchvision.models import ResNet50_Weights, resnet50 + device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ # 3D, batch 3, 2 input channel @@ -159,9 +163,11 @@ ] TEST_CASES = [] +PRETRAINED_TEST_CASES = [] for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) + PRETRAINED_TEST_CASES.append([model, *case]) for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]: TEST_CASES.append([ResNet, *case]) @@ -181,6 +187,27 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): else: self.assertTrue(result.shape in expected_shape) + @parameterized.expand(PRETRAINED_TEST_CASES) + def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape): + net = model(**input_param).to(device) + tmp_ckpt_filename = "monai_unittest_tmp_ckpt.pth" + # Save ckpt + torch.save({ + "state_dict": net.state_dict() + }, + tmp_ckpt_filename) + + cp_input_param = copy.copy(input_param) + cp_input_param["pretrained"] = tmp_ckpt_filename + pretrained_net = model(**cp_input_param) + assert str(net.state_dict()) == str(pretrained_net.state_dict()) + + with self.assertRaises(NotImplementedError): + cp_input_param["pretrained"] = True + bool_pretrained_net = model(**cp_input_param) + + os.remove(tmp_ckpt_filename) + @parameterized.expand(TEST_SCRIPT_CASES) def test_script(self, model, input_param, input_shape, expected_shape): net = model(**input_param) From 8b750955e4e9a44cbdd5dee3a376871482c7baa8 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Fri, 6 Oct 2023 16:26:01 +0000 Subject: [PATCH 03/31] add MedicalNet resnet 3D pretrained models support Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 9759a16e9b..4c4c86fb61 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -16,6 +16,8 @@ from typing import Any from pathlib import Path import logging +import re +from huggingface_hub import hf_hub_download import torch import torch.nn as nn @@ -25,6 +27,9 @@ from monai.utils import ensure_tuple_rep from monai.utils.module import look_up_option +MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" +MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" + __all__ = [ "ResNet", "ResNetBlock", @@ -348,10 +353,14 @@ def _resnet( ### Throw error raise FileNotFoundError("The pretrained checkpoint file is not found") else: - ### Throw error - raise NotImplementedError( - "Provide the pretrained checkpoint string path" - ) + resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1)) + # Download the MedicalNet pretrained model + logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") + pretrained_path = hf_hub_download( + repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", + filename=f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth") + + checkpoint = torch.load(pretrained_path, map_location=device) if "state_dict" in checkpoint: model_state_dict = checkpoint["state_dict"] From 00ec022f8faee9622978ac5d66b07a5838e34913 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 16:27:07 +0000 Subject: [PATCH 04/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/resnet.py | 6 +++--- tests/test_resnet.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 4c4c86fb61..296133912a 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -357,9 +357,9 @@ def _resnet( # Download the MedicalNet pretrained model logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") pretrained_path = hf_hub_download( - repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", + repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", filename=f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth") - + checkpoint = torch.load(pretrained_path, map_location=device) if "state_dict" in checkpoint: @@ -370,7 +370,7 @@ def _resnet( raise KeyError( "The checkpoint should contain the pretrained model state dict with the following key: 'state_dict'" ) - + model.load_state_dict(model_state_dict, strict=True) return model diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 6151b13c45..f02da422c8 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -196,18 +196,18 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape "state_dict": net.state_dict() }, tmp_ckpt_filename) - + cp_input_param = copy.copy(input_param) cp_input_param["pretrained"] = tmp_ckpt_filename pretrained_net = model(**cp_input_param) assert str(net.state_dict()) == str(pretrained_net.state_dict()) - + with self.assertRaises(NotImplementedError): cp_input_param["pretrained"] = True - bool_pretrained_net = model(**cp_input_param) - + model(**cp_input_param) + os.remove(tmp_ckpt_filename) - + @parameterized.expand(TEST_SCRIPT_CASES) def test_script(self, model, input_param, input_shape, expected_shape): net = model(**input_param) From f5e09b1f90259b03c63a325083e18bdb84987453 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 7 Oct 2023 13:31:18 +0000 Subject: [PATCH 05/31] add optional import Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 10 ++++++---- tests/test_resnet.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 296133912a..abd62097a8 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -14,19 +14,21 @@ from collections.abc import Callable from functools import partial from typing import Any -from pathlib import Path import logging -import re -from huggingface_hub import hf_hub_download + import torch import torch.nn as nn from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer -from monai.utils import ensure_tuple_rep +from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option +Path, _ = optional_import("pathlib", name = "Path") +re, _ = optional_import("re") +hf_hub_download, _ = optional_import("huggingface_hub.hf_hub_download") + MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" diff --git a/tests/test_resnet.py b/tests/test_resnet.py index f02da422c8..e2785591b1 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -12,8 +12,6 @@ from __future__ import annotations import unittest -import copy -import os from typing import TYPE_CHECKING import torch @@ -25,6 +23,9 @@ from monai.utils import optional_import from tests.utils import test_script_save +copy, _ = optional_import("copy") +os, _ = optional_import("os") + if TYPE_CHECKING: import torchvision From c7a827be7c25694f6305690a9c324f410f630242 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 7 Oct 2023 13:50:07 +0000 Subject: [PATCH 06/31] simplify user pretrained weights loading Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index abd62097a8..b79bb709de 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -368,10 +368,7 @@ def _resnet( model_state_dict = checkpoint["state_dict"] model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} else: - ### Throw error - raise KeyError( - "The checkpoint should contain the pretrained model state dict with the following key: 'state_dict'" - ) + model_state_dict = checkpoint model.load_state_dict(model_state_dict, strict=True) From fa60fadacfe39032bae3e91dd9579f0ee5af904f Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 7 Oct 2023 22:49:35 +0000 Subject: [PATCH 07/31] Manage MedicalNet resnet model validation with pretrained flag Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 51 +++++++++++++++++++++-------------- monai/networks/utils.py | 47 ++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 20 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index b79bb709de..3265852773 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -24,10 +24,10 @@ from monai.networks.layers.utils import get_pool_layer from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option +from monai.networks.utils import get_pretrained_resnet_medicalnet Path, _ = optional_import("pathlib", name = "Path") re, _ = optional_import("re") -hf_hub_download, _ = optional_import("huggingface_hub.hf_hub_download") MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" @@ -350,31 +350,42 @@ def _resnet( if isinstance(pretrained, str): if Path(pretrained).exists(): logger.info(f"Loading weights from {pretrained}...") - checkpoint = torch.load(pretrained, map_location=device) + model_state_dict = torch.load(pretrained, map_location=device) else: ### Throw error raise FileNotFoundError("The pretrained checkpoint file is not found") - else: - resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1)) - # Download the MedicalNet pretrained model - logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") - pretrained_path = hf_hub_download( - repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", - filename=f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth") - - checkpoint = torch.load(pretrained_path, map_location=device) - - if "state_dict" in checkpoint: - model_state_dict = checkpoint["state_dict"] - model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} - else: - model_state_dict = checkpoint - + else: + # Also check bias downsample and shortcut. + if kwargs.get("spatial_dims", 3) == 3: + if kwargs.get("n_input_channels", 3)==1 and kwargs.get("feed_forward", True)==False: + resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1)) + # get shortcut_type and bias_downsample. + def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : + """ + Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth + """ + # After testing + # False: 10, 50, 101, 152, 200 + # Any: 18, 34 + bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 + shortcut_type = "A" if resnet_depth in [18, 34] else "B" + return bias_downsample, shortcut_type + + # Check model bias_downsample and shortcut_type + bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) + if shortcut_type == kwargs.get("shortcut_type", "B") and (bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True): + # Download the MedicalNet pretrained model + model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device, datasets23=True) + else: + raise NotImplementedError(f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bool(bias_downsample) if bias_downsample!=-1 else 'True or False'} when using pretrained MedicalNet resnet{resnet_depth}") + else: + raise NotImplementedError("Please set n_input_channels to 1 and feed_forward to False in order to use MedicalNet pretrained weights") + else: + raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models") + model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} model.load_state_dict(model_state_dict, strict=True) - return model - def resnet10(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: """ResNet-10 with optional pretrained support when `spatial_dims` is 3. diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 12533183b1..b3e0c696d4 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -36,6 +36,8 @@ onnx, _ = optional_import("onnx") onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") +hf_hub_download, _ = optional_import("huggingface_hub", name = "hf_hub_download") +EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name = "EntryNotFoundError") __all__ = [ "one_hot", @@ -1164,3 +1166,48 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.") logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") + +def get_pretrained_resnet_medicalnet(resnet_depth: int, device: torch.device = torch.device("cpu"), datasets23: bool = True): + """ + Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet + + Args: + resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 + device: device on which the returned state dict will be loaded. + datasets23: if True, get the weights trained on more datasets (23). Not all depths are available. If not, standard weights are returned. + + Returns: + Pretrained state dict + + Raises: + huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub + NotImplementedError: if `resnet_depth` is not supported + """ + + MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" + MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" + SUPPORTED_DEPTH = [10, 18, 34, 50, 101, 152, 200] + + logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") + + if resnet_depth in SUPPORTED_DEPTH: + filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" if not datasets23 else f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}_23dataset.pth" + try: + pretrained_path = hf_hub_download( + repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", + filename=filename) + except Exception: + if datasets23: + logger.info(f"{filename} not available for resnet{resnet_depth}") + filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" + logger.info(f"Trying with {filename}") + pretrained_path = hf_hub_download( + repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", + filename=filename) + else: + raise EntryNotFoundError(f"{filename} not found on {MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") + checkpoint = torch.load(pretrained_path, map_location=device) + else: + raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") + logger.info(f"{filename} downloaded") + return checkpoint.get("state_dict") From e9ba99dd6d68934584e0c74aecaaed9ec31048ac Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sat, 7 Oct 2023 23:13:59 +0000 Subject: [PATCH 08/31] update resnet tests Signed-off-by: vgrau98 --- tests/test_resnet.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index e2785591b1..bde2f0c438 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -25,6 +25,7 @@ copy, _ = optional_import("copy") os, _ = optional_import("os") +re, _ = optional_import("re") if TYPE_CHECKING: import torchvision @@ -193,20 +194,24 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape net = model(**input_param).to(device) tmp_ckpt_filename = "monai_unittest_tmp_ckpt.pth" # Save ckpt - torch.save({ - "state_dict": net.state_dict() - }, - tmp_ckpt_filename) + torch.save(net.state_dict(), tmp_ckpt_filename) cp_input_param = copy.copy(input_param) + # Custom pretrained weights cp_input_param["pretrained"] = tmp_ckpt_filename pretrained_net = model(**cp_input_param) assert str(net.state_dict()) == str(pretrained_net.state_dict()) - with self.assertRaises(NotImplementedError): - cp_input_param["pretrained"] = True + # True flag + cp_input_param["pretrained"] = True + resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) + model(**cp_input_param) + if input_param.get("spatial_dims", 3) == 3: model(**cp_input_param) - + else: + with self.assertRaises(NotImplementedError): + model(**cp_input_param) + os.remove(tmp_ckpt_filename) @parameterized.expand(TEST_SCRIPT_CASES) From 0ddfb045329dd3276069bb5730d581e48078797f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Oct 2023 23:15:09 +0000 Subject: [PATCH 09/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/resnet.py | 14 +++++++------- monai/networks/utils.py | 18 +++++++++--------- tests/test_resnet.py | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 3265852773..581fda4c86 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -27,7 +27,7 @@ from monai.networks.utils import get_pretrained_resnet_medicalnet Path, _ = optional_import("pathlib", name = "Path") -re, _ = optional_import("re") +re, _ = optional_import("re") MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" @@ -354,12 +354,12 @@ def _resnet( else: ### Throw error raise FileNotFoundError("The pretrained checkpoint file is not found") - else: - # Also check bias downsample and shortcut. + else: + # Also check bias downsample and shortcut. if kwargs.get("spatial_dims", 3) == 3: - if kwargs.get("n_input_channels", 3)==1 and kwargs.get("feed_forward", True)==False: + if kwargs.get("n_input_channels", 3)==1 and kwargs.get("feed_forward", True) is False: resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1)) - # get shortcut_type and bias_downsample. + # get shortcut_type and bias_downsample. def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : """ Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth @@ -370,7 +370,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type - + # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) if shortcut_type == kwargs.get("shortcut_type", "B") and (bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True): @@ -378,7 +378,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device, datasets23=True) else: raise NotImplementedError(f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bool(bias_downsample) if bias_downsample!=-1 else 'True or False'} when using pretrained MedicalNet resnet{resnet_depth}") - else: + else: raise NotImplementedError("Please set n_input_channels to 1 and feed_forward to False in order to use MedicalNet pretrained weights") else: raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models") diff --git a/monai/networks/utils.py b/monai/networks/utils.py index b3e0c696d4..5c0acaf25b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1166,32 +1166,32 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.") logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") - + def get_pretrained_resnet_medicalnet(resnet_depth: int, device: torch.device = torch.device("cpu"), datasets23: bool = True): """ Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet - + Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 device: device on which the returned state dict will be loaded. - datasets23: if True, get the weights trained on more datasets (23). Not all depths are available. If not, standard weights are returned. - + datasets23: if True, get the weights trained on more datasets (23). Not all depths are available. If not, standard weights are returned. + Returns: Pretrained state dict - + Raises: huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub NotImplementedError: if `resnet_depth` is not supported """ - + MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" SUPPORTED_DEPTH = [10, 18, 34, 50, 101, 152, 200] - + logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") - + if resnet_depth in SUPPORTED_DEPTH: - filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" if not datasets23 else f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}_23dataset.pth" + filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" if not datasets23 else f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}_23dataset.pth" try: pretrained_path = hf_hub_download( repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", diff --git a/tests/test_resnet.py b/tests/test_resnet.py index bde2f0c438..ffc9005477 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -204,14 +204,14 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape # True flag cp_input_param["pretrained"] = True - resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) + int(re.search(r"resnet(\d+)", model.__name__).group(1)) model(**cp_input_param) if input_param.get("spatial_dims", 3) == 3: model(**cp_input_param) else: with self.assertRaises(NotImplementedError): model(**cp_input_param) - + os.remove(tmp_ckpt_filename) @parameterized.expand(TEST_SCRIPT_CASES) From 06ed8b065970ad810fa7fcb0ef5101a05920e191 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 8 Oct 2023 17:57:48 +0000 Subject: [PATCH 10/31] update resnet unit tests Signed-off-by: vgrau98 --- tests/test_resnet.py | 46 ++++++++++++++++++++++++++++++++++++++------ tests/utils.py | 17 ++++++++++++++++ 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index ffc9005477..feda89ec7c 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -21,7 +21,9 @@ from monai.networks.nets import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 from monai.networks.nets.resnet import ResNetBlock from monai.utils import optional_import -from tests.utils import test_script_save +from monai.networks.utils import get_pretrained_resnet_medicalnet +from tests.utils import test_script_save, equal_state_dict + copy, _ = optional_import("copy") os, _ = optional_import("os") @@ -200,18 +202,50 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape # Custom pretrained weights cp_input_param["pretrained"] = tmp_ckpt_filename pretrained_net = model(**cp_input_param) - assert str(net.state_dict()) == str(pretrained_net.state_dict()) + assert (equal_state_dict(net.state_dict(), pretrained_net.state_dict())) # True flag cp_input_param["pretrained"] = True - int(re.search(r"resnet(\d+)", model.__name__).group(1)) - model(**cp_input_param) - if input_param.get("spatial_dims", 3) == 3: + resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) + + # Duplicate. see monai/networks/nets/resnet.py + def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : + """ + Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth + """ + # After testing + # False: 10, 50, 101, 152, 200 + # Any: 18, 34 + bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 + shortcut_type = "A" if resnet_depth in [18, 34] else "B" + return bias_downsample, shortcut_type + + bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) + + # With orig. test cases + if (input_param.get("spatial_dims", 3) == 3 and + input_param.get("n_input_channels", 3)==1 and + input_param.get("feed_forward", True)==False and + input_param.get("shortcut_type", "B") == shortcut_type and + (input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True) + ): model(**cp_input_param) else: with self.assertRaises(NotImplementedError): model(**cp_input_param) - + + # forcing MedicalNet pretrained download for 3D tests cases + cp_input_param["n_input_channels"] = 1 + cp_input_param["feed_forward"] = False + cp_input_param["shortcut_type"] = shortcut_type + cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample!=-1 else True + if cp_input_param.get("spatial_dims", 3)==3: + pretrained_net = model(**cp_input_param).to(device) + medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device = device) + medicalnet_state_dict = {key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()} + assert(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict)) + + # clean os.remove(tmp_ckpt_filename) @parameterized.expand(TEST_SCRIPT_CASES) diff --git a/tests/utils.py b/tests/utils.py index dc6e8fac44..87bdc7c8d0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -822,6 +822,23 @@ def command_line_tests(cmd, copy_env=True): errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t") raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e +def equal_state_dict(st_1, st_2): + """ + Compare 2 torch state dicts. + """ + r = True + for key_st_1, val_st_1 in st_1.items(): + if key_st_1 in st_2: + val_st_2 = st_2.get(key_st_1) + if not torch.equal(val_st_1, val_st_2): + r = False + break + else: + r = False + break + return r + + TEST_TORCH_TENSORS: tuple = (torch.as_tensor,) if torch.cuda.is_available(): From 955dcf1c6b6cfa2b93d118199f05dd6fa5c5d19d Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 8 Oct 2023 18:02:19 +0000 Subject: [PATCH 11/31] fix incorrect optional import Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 4 ++-- tests/test_resnet.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 581fda4c86..cef1f589d1 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -14,7 +14,9 @@ from collections.abc import Callable from functools import partial from typing import Any +from pathlib import Path import logging +import re import torch @@ -26,8 +28,6 @@ from monai.utils.module import look_up_option from monai.networks.utils import get_pretrained_resnet_medicalnet -Path, _ = optional_import("pathlib", name = "Path") -re, _ = optional_import("re") MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" diff --git a/tests/test_resnet.py b/tests/test_resnet.py index feda89ec7c..db65417d86 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -13,6 +13,9 @@ import unittest from typing import TYPE_CHECKING +import os +import re +import copy import torch from parameterized import parameterized @@ -24,11 +27,6 @@ from monai.networks.utils import get_pretrained_resnet_medicalnet from tests.utils import test_script_save, equal_state_dict - -copy, _ = optional_import("copy") -os, _ = optional_import("os") -re, _ = optional_import("re") - if TYPE_CHECKING: import torchvision From ff7f6d37b68ddb97bcc2400d9ebbc81d195616e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Oct 2023 18:03:27 +0000 Subject: [PATCH 12/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/resnet.py | 2 +- tests/test_resnet.py | 16 ++++++++-------- tests/utils.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index cef1f589d1..4f9f3463e6 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -24,7 +24,7 @@ from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer -from monai.utils import ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple_rep from monai.utils.module import look_up_option from monai.networks.utils import get_pretrained_resnet_medicalnet diff --git a/tests/test_resnet.py b/tests/test_resnet.py index db65417d86..dfd11e5980 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -205,7 +205,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape # True flag cp_input_param["pretrained"] = True resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) - + # Duplicate. see monai/networks/nets/resnet.py def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : """ @@ -219,19 +219,19 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : return bias_downsample, shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) - + # With orig. test cases - if (input_param.get("spatial_dims", 3) == 3 and - input_param.get("n_input_channels", 3)==1 and - input_param.get("feed_forward", True)==False and + if (input_param.get("spatial_dims", 3) == 3 and + input_param.get("n_input_channels", 3)==1 and + input_param.get("feed_forward", True) is False and input_param.get("shortcut_type", "B") == shortcut_type and - (input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True) + (input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True) ): model(**cp_input_param) else: with self.assertRaises(NotImplementedError): model(**cp_input_param) - + # forcing MedicalNet pretrained download for 3D tests cases cp_input_param["n_input_channels"] = 1 cp_input_param["feed_forward"] = False @@ -242,7 +242,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device = device) medicalnet_state_dict = {key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()} assert(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict)) - + # clean os.remove(tmp_ckpt_filename) diff --git a/tests/utils.py b/tests/utils.py index 87bdc7c8d0..96abc80f1c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -836,8 +836,8 @@ def equal_state_dict(st_1, st_2): else: r = False break - return r - + return r + TEST_TORCH_TENSORS: tuple = (torch.as_tensor,) From bb6830f114f9e6167c77468a9fa763af4e877730 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Sun, 8 Oct 2023 21:41:57 +0000 Subject: [PATCH 13/31] Line shortening Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 11 ++++++++--- monai/networks/utils.py | 6 ++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 4f9f3463e6..53c8c5aba7 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -373,13 +373,18 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) - if shortcut_type == kwargs.get("shortcut_type", "B") and (bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True): + if shortcut_type == kwargs.get("shortcut_type", "B") and \ + (bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True): # Download the MedicalNet pretrained model model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device, datasets23=True) else: - raise NotImplementedError(f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bool(bias_downsample) if bias_downsample!=-1 else 'True or False'} when using pretrained MedicalNet resnet{resnet_depth}") + raise NotImplementedError(f"Please set shortcut_type to {shortcut_type} and bias_downsample to" \ + f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}" \ + f"when using pretrained MedicalNet resnet{resnet_depth}") else: - raise NotImplementedError("Please set n_input_channels to 1 and feed_forward to False in order to use MedicalNet pretrained weights") + raise NotImplementedError( + "Please set n_input_channels to 1" \ + "and feed_forward to False in order to use MedicalNet pretrained weights") else: raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models") model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 6a015cf241..901508f569 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1174,7 +1174,8 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: torch.device = t Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 device: device on which the returned state dict will be loaded. - datasets23: if True, get the weights trained on more datasets (23). Not all depths are available. If not, standard weights are returned. + datasets23: if True, get the weights trained on more datasets (23). + Not all depths are available. If not, standard weights are returned. Returns: Pretrained state dict @@ -1191,7 +1192,8 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: torch.device = t logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") if resnet_depth in SUPPORTED_DEPTH: - filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" if not datasets23 else f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}_23dataset.pth" + filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" if not datasets23 \ + else f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}_23dataset.pth" try: pretrained_path = hf_hub_download( repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", From d21a02277d8cb9f2461df69769fd963cbad95b5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Oct 2023 21:45:42 +0000 Subject: [PATCH 14/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 901508f569..ed7848d03c 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1174,7 +1174,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: torch.device = t Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 device: device on which the returned state dict will be loaded. - datasets23: if True, get the weights trained on more datasets (23). + datasets23: if True, get the weights trained on more datasets (23). Not all depths are available. If not, standard weights are returned. Returns: From 373573323cbac814d06795ee4a6c39750c3b9bbd Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Mon, 9 Oct 2023 21:39:55 +0000 Subject: [PATCH 15/31] update resnet tests and deployment files Signed-off-by: vgrau98 --- docs/requirements.txt | 1 + requirements-dev.txt | 1 + setup.cfg | 1 + tests/test_resnet.py | 82 +++++++++++++++++++++++-------------------- 4 files changed, 46 insertions(+), 39 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index c4bd50ec2b..ac3891c420 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -39,3 +39,4 @@ opencv-python-headless onnx>=1.13.0 onnxruntime; python_version <= '3.10' zarr +huggingface_hub \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index d47e95a4f5..bba069cf64 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -56,3 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 nvidia-ml-py +huggingface_hub \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 9d5a22963c..fe5fbf3b5b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -83,6 +83,7 @@ all = zarr lpips==0.1.4 nvidia-ml-py + huggingface_hub nibabel = nibabel ninja = diff --git a/tests/test_resnet.py b/tests/test_resnet.py index dfd11e5980..4dd3fd762d 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -14,6 +14,7 @@ import unittest from typing import TYPE_CHECKING import os +import sys import re import copy @@ -33,6 +34,8 @@ has_torchvision = True else: torchvision, has_torchvision = optional_import("torchvision") + +has_hf_modules = ("huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules) # from torchvision.models import ResNet50_Weights, resnet50 @@ -202,46 +205,47 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape pretrained_net = model(**cp_input_param) assert (equal_state_dict(net.state_dict(), pretrained_net.state_dict())) - # True flag - cp_input_param["pretrained"] = True - resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) - - # Duplicate. see monai/networks/nets/resnet.py - def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : - """ - Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth - """ - # After testing - # False: 10, 50, 101, 152, 200 - # Any: 18, 34 - bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 - shortcut_type = "A" if resnet_depth in [18, 34] else "B" - return bias_downsample, shortcut_type - - bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) - - # With orig. test cases - if (input_param.get("spatial_dims", 3) == 3 and - input_param.get("n_input_channels", 3)==1 and - input_param.get("feed_forward", True) is False and - input_param.get("shortcut_type", "B") == shortcut_type and - (input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True) - ): - model(**cp_input_param) - else: - with self.assertRaises(NotImplementedError): + if has_hf_modules: + # True flag + cp_input_param["pretrained"] = True + resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) + + # Duplicate. see monai/networks/nets/resnet.py + def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : + """ + Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth + """ + # After testing + # False: 10, 50, 101, 152, 200 + # Any: 18, 34 + bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 + shortcut_type = "A" if resnet_depth in [18, 34] else "B" + return bias_downsample, shortcut_type + + bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) + + # With orig. test cases + if (input_param.get("spatial_dims", 3) == 3 and + input_param.get("n_input_channels", 3)==1 and + input_param.get("feed_forward", True) is False and + input_param.get("shortcut_type", "B") == shortcut_type and + (input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True) + ): model(**cp_input_param) - - # forcing MedicalNet pretrained download for 3D tests cases - cp_input_param["n_input_channels"] = 1 - cp_input_param["feed_forward"] = False - cp_input_param["shortcut_type"] = shortcut_type - cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample!=-1 else True - if cp_input_param.get("spatial_dims", 3)==3: - pretrained_net = model(**cp_input_param).to(device) - medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device = device) - medicalnet_state_dict = {key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()} - assert(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict)) + else: + with self.assertRaises(NotImplementedError): + model(**cp_input_param) + + # forcing MedicalNet pretrained download for 3D tests cases + cp_input_param["n_input_channels"] = 1 + cp_input_param["feed_forward"] = False + cp_input_param["shortcut_type"] = shortcut_type + cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample!=-1 else True + if cp_input_param.get("spatial_dims", 3)==3: + pretrained_net = model(**cp_input_param).to(device) + medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device = device) + medicalnet_state_dict = {key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()} + assert(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict)) # clean os.remove(tmp_ckpt_filename) From 8b6782af8f4e35d3b572bc6c27bc2416d7de4291 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 21:40:32 +0000 Subject: [PATCH 16/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/requirements.txt | 2 +- requirements-dev.txt | 2 +- tests/test_resnet.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index ac3891c420..a9bbc384f8 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -39,4 +39,4 @@ opencv-python-headless onnx>=1.13.0 onnxruntime; python_version <= '3.10' zarr -huggingface_hub \ No newline at end of file +huggingface_hub diff --git a/requirements-dev.txt b/requirements-dev.txt index bba069cf64..38715b8449 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -56,4 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 nvidia-ml-py -huggingface_hub \ No newline at end of file +huggingface_hub diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 4dd3fd762d..9bd980accf 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -34,7 +34,7 @@ has_torchvision = True else: torchvision, has_torchvision = optional_import("torchvision") - + has_hf_modules = ("huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules) # from torchvision.models import ResNet50_Weights, resnet50 From 02a360a128e82e04288a422a80a81d14d307f077 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 10 Oct 2023 10:02:38 +0000 Subject: [PATCH 17/31] [MONAI] code formatting Signed-off-by: monai-bot --- monai/networks/nets/resnet.py | 41 ++++++++++++++++++------------- monai/networks/utils.py | 30 +++++++++++++++-------- tests/test_resnet.py | 45 +++++++++++++++++++---------------- tests/utils.py | 2 +- 4 files changed, 70 insertions(+), 48 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 53c8c5aba7..801c9aeb16 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -11,23 +11,21 @@ from __future__ import annotations +import logging +import re from collections.abc import Callable from functools import partial -from typing import Any from pathlib import Path -import logging -import re - +from typing import Any import torch import torch.nn as nn from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer +from monai.networks.utils import get_pretrained_resnet_medicalnet from monai.utils import ensure_tuple_rep from monai.utils.module import look_up_option -from monai.networks.utils import get_pretrained_resnet_medicalnet - MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" @@ -49,6 +47,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" + def get_inplanes(): return [64, 128, 256, 512] @@ -352,15 +351,16 @@ def _resnet( logger.info(f"Loading weights from {pretrained}...") model_state_dict = torch.load(pretrained, map_location=device) else: - ### Throw error + # Throw error raise FileNotFoundError("The pretrained checkpoint file is not found") else: # Also check bias downsample and shortcut. if kwargs.get("spatial_dims", 3) == 3: - if kwargs.get("n_input_channels", 3)==1 and kwargs.get("feed_forward", True) is False: + if kwargs.get("n_input_channels", 3) == 1 and kwargs.get("feed_forward", True) is False: resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1)) # get shortcut_type and bias_downsample. - def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : + + def get_medicalnet_pretrained_resnet_args(resnet_depth: int): """ Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth """ @@ -373,24 +373,31 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) - if shortcut_type == kwargs.get("shortcut_type", "B") and \ - (bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True): + if shortcut_type == kwargs.get("shortcut_type", "B") and ( + bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True + ): # Download the MedicalNet pretrained model - model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device, datasets23=True) + model_state_dict = get_pretrained_resnet_medicalnet( + resnet_depth, device=device, datasets23=True + ) else: - raise NotImplementedError(f"Please set shortcut_type to {shortcut_type} and bias_downsample to" \ - f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}" \ - f"when using pretrained MedicalNet resnet{resnet_depth}") + raise NotImplementedError( + f"Please set shortcut_type to {shortcut_type} and bias_downsample to" + f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}" + f"when using pretrained MedicalNet resnet{resnet_depth}" + ) else: raise NotImplementedError( - "Please set n_input_channels to 1" \ - "and feed_forward to False in order to use MedicalNet pretrained weights") + "Please set n_input_channels to 1" + "and feed_forward to False in order to use MedicalNet pretrained weights" + ) else: raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models") model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} model.load_state_dict(model_state_dict, strict=True) return model + def resnet10(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: """ResNet-10 with optional pretrained support when `spatial_dims` is 3. diff --git a/monai/networks/utils.py b/monai/networks/utils.py index ed7848d03c..d8775a4b73 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -36,8 +36,8 @@ onnx, _ = optional_import("onnx") onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") -hf_hub_download, _ = optional_import("huggingface_hub", name = "hf_hub_download") -EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name = "EntryNotFoundError") +hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download") +EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError") __all__ = [ "one_hot", @@ -1167,7 +1167,10 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") -def get_pretrained_resnet_medicalnet(resnet_depth: int, device: torch.device = torch.device("cpu"), datasets23: bool = True): + +def get_pretrained_resnet_medicalnet( + resnet_depth: int, device: torch.device = torch.device("cpu"), datasets23: bool = True +): """ Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet @@ -1189,25 +1192,32 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: torch.device = t MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" SUPPORTED_DEPTH = [10, 18, 34, 50, 101, 152, 200] - logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") + logger.info( + f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}" + ) if resnet_depth in SUPPORTED_DEPTH: - filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" if not datasets23 \ + filename = ( + f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" + if not datasets23 else f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}_23dataset.pth" + ) try: pretrained_path = hf_hub_download( - repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", - filename=filename) + repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", filename=filename + ) except Exception: if datasets23: logger.info(f"{filename} not available for resnet{resnet_depth}") filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" logger.info(f"Trying with {filename}") pretrained_path = hf_hub_download( - repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", - filename=filename) + repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", filename=filename + ) else: - raise EntryNotFoundError(f"{filename} not found on {MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}") + raise EntryNotFoundError( + f"{filename} not found on {MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}" + ) checkpoint = torch.load(pretrained_path, map_location=device) else: raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 9bd980accf..1d2d33920b 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -11,12 +11,12 @@ from __future__ import annotations -import unittest -from typing import TYPE_CHECKING +import copy import os -import sys import re -import copy +import sys +import unittest +from typing import TYPE_CHECKING import torch from parameterized import parameterized @@ -24,9 +24,9 @@ from monai.networks import eval_mode from monai.networks.nets import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 from monai.networks.nets.resnet import ResNetBlock -from monai.utils import optional_import from monai.networks.utils import get_pretrained_resnet_medicalnet -from tests.utils import test_script_save, equal_state_dict +from monai.utils import optional_import +from tests.utils import equal_state_dict, test_script_save if TYPE_CHECKING: import torchvision @@ -35,7 +35,7 @@ else: torchvision, has_torchvision = optional_import("torchvision") -has_hf_modules = ("huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules) +has_hf_modules = "huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules # from torchvision.models import ResNet50_Weights, resnet50 @@ -203,7 +203,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape # Custom pretrained weights cp_input_param["pretrained"] = tmp_ckpt_filename pretrained_net = model(**cp_input_param) - assert (equal_state_dict(net.state_dict(), pretrained_net.state_dict())) + assert equal_state_dict(net.state_dict(), pretrained_net.state_dict()) if has_hf_modules: # True flag @@ -211,7 +211,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) # Duplicate. see monai/networks/nets/resnet.py - def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : + def get_medicalnet_pretrained_resnet_args(resnet_depth: int): """ Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth """ @@ -225,12 +225,15 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) # With orig. test cases - if (input_param.get("spatial_dims", 3) == 3 and - input_param.get("n_input_channels", 3)==1 and - input_param.get("feed_forward", True) is False and - input_param.get("shortcut_type", "B") == shortcut_type and - (input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True) - ): + if ( + input_param.get("spatial_dims", 3) == 3 + and input_param.get("n_input_channels", 3) == 1 + and input_param.get("feed_forward", True) is False + and input_param.get("shortcut_type", "B") == shortcut_type + and ( + input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True + ) + ): model(**cp_input_param) else: with self.assertRaises(NotImplementedError): @@ -240,12 +243,14 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int) : cp_input_param["n_input_channels"] = 1 cp_input_param["feed_forward"] = False cp_input_param["shortcut_type"] = shortcut_type - cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample!=-1 else True - if cp_input_param.get("spatial_dims", 3)==3: + cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True + if cp_input_param.get("spatial_dims", 3) == 3: pretrained_net = model(**cp_input_param).to(device) - medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device = device) - medicalnet_state_dict = {key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()} - assert(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict)) + medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device) + medicalnet_state_dict = { + key.replace("module.", ""): value for key, value in medicalnet_state_dict.items() + } + assert equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict) # clean os.remove(tmp_ckpt_filename) diff --git a/tests/utils.py b/tests/utils.py index c556205342..3213e66d68 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -823,6 +823,7 @@ def command_line_tests(cmd, copy_env=True): errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t") raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e + def equal_state_dict(st_1, st_2): """ Compare 2 torch state dicts. @@ -840,7 +841,6 @@ def equal_state_dict(st_1, st_2): return r - TEST_TORCH_TENSORS: tuple = (torch.as_tensor,) if torch.cuda.is_available(): gpu_tensor: Callable = partial(torch.as_tensor, device="cuda") From 1993f0466aa8ea598c426ec9a84f3dd4a61ddeb2 Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Tue, 10 Oct 2023 23:53:56 +0200 Subject: [PATCH 18/31] Update utils.py Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> --- monai/networks/utils.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d8775a4b73..7ff1dc0554 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1188,35 +1188,35 @@ def get_pretrained_resnet_medicalnet( NotImplementedError: if `resnet_depth` is not supported """ - MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" - MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" - SUPPORTED_DEPTH = [10, 18, 34, 50, 101, 152, 200] + medicalnet_huggingface_repo_basename = "TencentMedicalNet/MedicalNet-Resnet" + medicalnet_huggingface_files_basename = "resnet_" + supported_depth = [10, 18, 34, 50, 101, 152, 200] logger.info( - f"Loading MedicalNet pretrained model from https://huggingface.co/{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}" + f"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}" ) - if resnet_depth in SUPPORTED_DEPTH: + if resnet_depth in supported_depth: filename = ( - f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" + f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" if not datasets23 - else f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}_23dataset.pth" + else f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth" ) try: pretrained_path = hf_hub_download( - repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", filename=filename + repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename ) except Exception: if datasets23: logger.info(f"{filename} not available for resnet{resnet_depth}") - filename = f"{MEDICALNET_HUGGINGFACE_FILES_BASENAME}{resnet_depth}.pth" + filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" logger.info(f"Trying with {filename}") pretrained_path = hf_hub_download( - repo_id=f"{MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}", filename=filename + repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename ) else: raise EntryNotFoundError( - f"{filename} not found on {MEDICALNET_HUGGINGFACE_REPO_BASENAME}{resnet_depth}" + f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}" ) checkpoint = torch.load(pretrained_path, map_location=device) else: From e34477137fde88e11c9b2b7dde612545663838bf Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Tue, 10 Oct 2023 23:56:19 +0200 Subject: [PATCH 19/31] Update utils.py Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 7ff1dc0554..ddff087f73 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1217,7 +1217,7 @@ def get_pretrained_resnet_medicalnet( else: raise EntryNotFoundError( f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}" - ) + ) from None checkpoint = torch.load(pretrained_path, map_location=device) else: raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") From be516c042125a2a9954ff239aab6ee4f4c5f0c45 Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Wed, 11 Oct 2023 00:09:21 +0200 Subject: [PATCH 20/31] Update resnet.py Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> --- monai/networks/nets/resnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 801c9aeb16..3eb14574cc 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -362,7 +362,8 @@ def _resnet( def get_medicalnet_pretrained_resnet_args(resnet_depth: int): """ - Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth + Return correct shortcut_type and bias_downsample + for pretrained MedicalNet weights according to rensnet depth """ # After testing # False: 10, 50, 101, 152, 200 From 3dd89de0a473ccec8bcaf84668356e3e38c5f859 Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Wed, 11 Oct 2023 00:16:06 +0200 Subject: [PATCH 21/31] Update utils.py Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> --- monai/networks/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index ddff087f73..bc90b3c8fe 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1169,14 +1169,14 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): def get_pretrained_resnet_medicalnet( - resnet_depth: int, device: torch.device = torch.device("cpu"), datasets23: bool = True + resnet_depth: int, device: str = "cpu", datasets23: bool = True ): """ Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 - device: device on which the returned state dict will be loaded. + device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example. datasets23: if True, get the weights trained on more datasets (23). Not all depths are available. If not, standard weights are returned. @@ -1218,7 +1218,7 @@ def get_pretrained_resnet_medicalnet( raise EntryNotFoundError( f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}" ) from None - checkpoint = torch.load(pretrained_path, map_location=device) + checkpoint = torch.load(pretrained_path, map_location=torch.device(device)) else: raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") logger.info(f"{filename} downloaded") From b74d48a208c84f7eeeaeced14d91f7695bdcaaf3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Oct 2023 22:16:29 +0000 Subject: [PATCH 22/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index bc90b3c8fe..b9215403f1 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1176,7 +1176,7 @@ def get_pretrained_resnet_medicalnet( Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 - device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example. + device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example. datasets23: if True, get the weights trained on more datasets (23). Not all depths are available. If not, standard weights are returned. From 68962162d69c6a8f44f5277631bf81e1324016d9 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Mon, 16 Oct 2023 09:28:04 +0000 Subject: [PATCH 23/31] fix lint error Signed-off-by: vgrau98 --- monai/networks/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index b9215403f1..0d73ed970a 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1168,9 +1168,7 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") -def get_pretrained_resnet_medicalnet( - resnet_depth: int, device: str = "cpu", datasets23: bool = True -): +def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): """ Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet From b30ea731d766ee3fcbfcd5139d95ef37c8c9b081 Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Mon, 16 Oct 2023 10:02:37 +0000 Subject: [PATCH 24/31] minor refactos Signed-off-by: vgrau98 --- monai/networks/nets/__init__.py | 2 + monai/networks/nets/resnet.py | 86 +++++++++++++++++++++++++++------ monai/networks/utils.py | 58 +--------------------- tests/test_resnet.py | 26 +++++----- 4 files changed, 87 insertions(+), 85 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index a0c8628172..3c51336584 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -66,6 +66,8 @@ resnet101, resnet152, resnet200, + get_medicalnet_pretrained_resnet_args, + get_pretrained_resnet_medicalnet ) from .segresnet import SegResNet, SegResNetVAE from .segresnet_ds import SegResNetDS diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 3eb14574cc..7a243f8eef 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -23,9 +23,11 @@ from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer -from monai.networks.utils import get_pretrained_resnet_medicalnet from monai.utils import ensure_tuple_rep -from monai.utils.module import look_up_option +from monai.utils.module import look_up_option, optional_import + +hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download") +EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError") MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" @@ -360,18 +362,6 @@ def _resnet( resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1)) # get shortcut_type and bias_downsample. - def get_medicalnet_pretrained_resnet_args(resnet_depth: int): - """ - Return correct shortcut_type and bias_downsample - for pretrained MedicalNet weights according to rensnet depth - """ - # After testing - # False: 10, 50, 101, 152, 200 - # Any: 18, 34 - bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 - shortcut_type = "A" if resnet_depth in [18, 34] else "B" - return bias_downsample, shortcut_type - # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) if shortcut_type == kwargs.get("shortcut_type", "B") and ( @@ -481,3 +471,71 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet("resnet200", ResNetBottleneck, [3, 24, 36, 3], get_inplanes(), pretrained, progress, **kwargs) + + +def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): + """ + Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet + + Args: + resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 + device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example. + datasets23: if True, get the weights trained on more datasets (23). + Not all depths are available. If not, standard weights are returned. + + Returns: + Pretrained state dict + + Raises: + huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub + NotImplementedError: if `resnet_depth` is not supported + """ + + medicalnet_huggingface_repo_basename = "TencentMedicalNet/MedicalNet-Resnet" + medicalnet_huggingface_files_basename = "resnet_" + supported_depth = [10, 18, 34, 50, 101, 152, 200] + + logger.info( + f"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}" + ) + + if resnet_depth in supported_depth: + filename = ( + f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" + if not datasets23 + else f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth" + ) + try: + pretrained_path = hf_hub_download( + repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename + ) + except Exception: + if datasets23: + logger.info(f"{filename} not available for resnet{resnet_depth}") + filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" + logger.info(f"Trying with {filename}") + pretrained_path = hf_hub_download( + repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename + ) + else: + raise EntryNotFoundError( + f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}" + ) from None + checkpoint = torch.load(pretrained_path, map_location=torch.device(device)) + else: + raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") + logger.info(f"{filename} downloaded") + return checkpoint.get("state_dict") + + +def get_medicalnet_pretrained_resnet_args(resnet_depth: int): + """ + Return correct shortcut_type and bias_downsample + for pretrained MedicalNet weights according to resnet depth + """ + # After testing + # False: 10, 50, 101, 152, 200 + # Any: 18, 34 + bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 + shortcut_type = "A" if resnet_depth in [18, 34] else "B" + return bias_downsample, shortcut_type diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0d73ed970a..6112d33100 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -36,8 +36,7 @@ onnx, _ = optional_import("onnx") onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") -hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download") -EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError") + __all__ = [ "one_hot", @@ -1166,58 +1165,3 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.") logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") - - -def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): - """ - Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet - - Args: - resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 - device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example. - datasets23: if True, get the weights trained on more datasets (23). - Not all depths are available. If not, standard weights are returned. - - Returns: - Pretrained state dict - - Raises: - huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub - NotImplementedError: if `resnet_depth` is not supported - """ - - medicalnet_huggingface_repo_basename = "TencentMedicalNet/MedicalNet-Resnet" - medicalnet_huggingface_files_basename = "resnet_" - supported_depth = [10, 18, 34, 50, 101, 152, 200] - - logger.info( - f"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}" - ) - - if resnet_depth in supported_depth: - filename = ( - f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" - if not datasets23 - else f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth" - ) - try: - pretrained_path = hf_hub_download( - repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename - ) - except Exception: - if datasets23: - logger.info(f"{filename} not available for resnet{resnet_depth}") - filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" - logger.info(f"Trying with {filename}") - pretrained_path = hf_hub_download( - repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename - ) - else: - raise EntryNotFoundError( - f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}" - ) from None - checkpoint = torch.load(pretrained_path, map_location=torch.device(device)) - else: - raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") - logger.info(f"{filename} downloaded") - return checkpoint.get("state_dict") diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 1d2d33920b..2e75aed6ca 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -22,9 +22,19 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 +from monai.networks.nets import ( + ResNet, + get_medicalnet_pretrained_resnet_args, + get_pretrained_resnet_medicalnet, + resnet10, + resnet18, + resnet34, + resnet50, + resnet101, + resnet152, + resnet200, +) from monai.networks.nets.resnet import ResNetBlock -from monai.networks.utils import get_pretrained_resnet_medicalnet from monai.utils import optional_import from tests.utils import equal_state_dict, test_script_save @@ -210,18 +220,6 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape cp_input_param["pretrained"] = True resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1)) - # Duplicate. see monai/networks/nets/resnet.py - def get_medicalnet_pretrained_resnet_args(resnet_depth: int): - """ - Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth - """ - # After testing - # False: 10, 50, 101, 152, 200 - # Any: 18, 34 - bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 - shortcut_type = "A" if resnet_depth in [18, 34] else "B" - return bias_downsample, shortcut_type - bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) # With orig. test cases From 2945704b3b9a1cef03a52cab642d0289f7d78e02 Mon Sep 17 00:00:00 2001 From: vgrau98 <35843843+vgrau98@users.noreply.github.com> Date: Mon, 16 Oct 2023 12:17:07 +0200 Subject: [PATCH 25/31] fix lint error Signed-off-by: vgrau98 <35843843+vgrau98@users.noreply.github.com> --- monai/networks/nets/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 3c51336584..5428904980 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -59,6 +59,8 @@ ResNet, ResNetBlock, ResNetBottleneck, + get_medicalnet_pretrained_resnet_args, + get_pretrained_resnet_medicalnet resnet10, resnet18, resnet34, @@ -66,8 +68,6 @@ resnet101, resnet152, resnet200, - get_medicalnet_pretrained_resnet_args, - get_pretrained_resnet_medicalnet ) from .segresnet import SegResNet, SegResNetVAE from .segresnet_ds import SegResNetDS From 7a01bb575d5ea82c9abb039ba186872de8cf01ce Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Mon, 16 Oct 2023 10:29:35 +0000 Subject: [PATCH 26/31] fix typo Signed-off-by: vgrau98 --- monai/networks/nets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 5428904980..1fb0f08ccc 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -60,7 +60,7 @@ ResNetBlock, ResNetBottleneck, get_medicalnet_pretrained_resnet_args, - get_pretrained_resnet_medicalnet + get_pretrained_resnet_medicalnet, resnet10, resnet18, resnet34, From 8198203cda52041dad69c5fa294bd813427bf11e Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Tue, 17 Oct 2023 08:25:42 +0000 Subject: [PATCH 27/31] fix mypy error Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 7a243f8eef..ebcad092e3 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -359,8 +359,11 @@ def _resnet( # Also check bias downsample and shortcut. if kwargs.get("spatial_dims", 3) == 3: if kwargs.get("n_input_channels", 3) == 1 and kwargs.get("feed_forward", True) is False: - resnet_depth = int(re.search(r"resnet(\d+)", arch).group(1)) - # get shortcut_type and bias_downsample. + search_res = re.search(r"resnet(\d+)", arch) + if search_res: + resnet_depth = int(search_res.group(1)) + else: + raise ValueError("arch argument should be as 'resnet_\{resnet_depth\}") # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) From 2930f97b083073b64fc5e278f1caf130bd1d2ac6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 08:26:17 +0000 Subject: [PATCH 28/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index ebcad092e3..feecfebc5d 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -363,7 +363,7 @@ def _resnet( if search_res: resnet_depth = int(search_res.group(1)) else: - raise ValueError("arch argument should be as 'resnet_\{resnet_depth\}") + raise ValueError(r"arch argument should be as 'resnet_\{resnet_depth\}") # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) From 7a0a2c37d1dbd268661410a789dc787edc13f5de Mon Sep 17 00:00:00 2001 From: vgrau98 Date: Tue, 17 Oct 2023 08:44:20 +0000 Subject: [PATCH 29/31] fix lint error Signed-off-by: vgrau98 --- monai/networks/nets/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index feecfebc5d..029229f484 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -363,7 +363,7 @@ def _resnet( if search_res: resnet_depth = int(search_res.group(1)) else: - raise ValueError(r"arch argument should be as 'resnet_\{resnet_depth\}") + raise ValueError("arch argument should be as 'resnet_{resnet_depth}") # Check model bias_downsample and shortcut_type bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) From 741970d359998169fcfebd68e5262e1a00e86768 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 18 Oct 2023 18:37:15 +0100 Subject: [PATCH 30/31] update unit test Signed-off-by: Wenqi Li --- tests/test_resnet.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 2e75aed6ca..a6ea70b377 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -36,7 +36,7 @@ ) from monai.networks.nets.resnet import ResNetBlock from monai.utils import optional_import -from tests.utils import equal_state_dict, test_script_save +from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save if TYPE_CHECKING: import torchvision @@ -192,6 +192,16 @@ class TestResNet(unittest.TestCase): + def setUp(self): + self.tmp_ckpt_filename = os.path.join("tests", "monai_unittest_tmp_ckpt.pth") + + def tearDown(self): + if os.path.exists(self.tmp_ckpt_filename): + try: + os.remove(self.tmp_ckpt_filename) + except BaseException: + pass + @parameterized.expand(TEST_CASES) def test_resnet_shape(self, model, input_param, input_shape, expected_shape): net = model(**input_param).to(device) @@ -203,17 +213,18 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): self.assertTrue(result.shape in expected_shape) @parameterized.expand(PRETRAINED_TEST_CASES) + @skip_if_quick + @skip_if_no_cuda def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape): net = model(**input_param).to(device) - tmp_ckpt_filename = "monai_unittest_tmp_ckpt.pth" # Save ckpt - torch.save(net.state_dict(), tmp_ckpt_filename) + torch.save(net.state_dict(), self.tmp_ckpt_filename) cp_input_param = copy.copy(input_param) # Custom pretrained weights - cp_input_param["pretrained"] = tmp_ckpt_filename + cp_input_param["pretrained"] = self.tmp_ckpt_filename pretrained_net = model(**cp_input_param) - assert equal_state_dict(net.state_dict(), pretrained_net.state_dict()) + self.assertTrue(equal_state_dict(net.state_dict(), pretrained_net.state_dict())) if has_hf_modules: # True flag @@ -243,15 +254,13 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape cp_input_param["shortcut_type"] = shortcut_type cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True if cp_input_param.get("spatial_dims", 3) == 3: - pretrained_net = model(**cp_input_param).to(device) - medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device) - medicalnet_state_dict = { - key.replace("module.", ""): value for key, value in medicalnet_state_dict.items() - } - assert equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict) - - # clean - os.remove(tmp_ckpt_filename) + with skip_if_downloading_fails(): + pretrained_net = model(**cp_input_param).to(device) + medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device) + medicalnet_state_dict = { + key.replace("module.", ""): value for key, value in medicalnet_state_dict.items() + } + self.assertTrue(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict)) @parameterized.expand(TEST_SCRIPT_CASES) def test_script(self, model, input_param, input_shape, expected_shape): From 89eda2c208994d8886f4ec040c5f59173753573f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 18 Oct 2023 20:44:04 +0100 Subject: [PATCH 31/31] local torch.cuda check Signed-off-by: Wenqi Li --- monai/networks/nets/resnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 029229f484..fca73f4de3 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -47,8 +47,6 @@ logger = logging.getLogger(__name__) -device = "cuda" if torch.cuda.is_available() else "cpu" - def get_inplanes(): return [64, 128, 256, 512] @@ -348,6 +346,7 @@ def _resnet( ) -> ResNet: model: ResNet = ResNet(block, layers, block_inplanes, **kwargs) if pretrained: + device = "cuda" if torch.cuda.is_available() else "cpu" if isinstance(pretrained, str): if Path(pretrained).exists(): logger.info(f"Loading weights from {pretrained}...")