diff --git a/docs/requirements.txt b/docs/requirements.txt
index c4bd50ec2b..a9bbc384f8 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -39,3 +39,4 @@ opencv-python-headless
 onnx>=1.13.0
 onnxruntime; python_version <= '3.10'
 zarr
+huggingface_hub
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index a0c8628172..1fb0f08ccc 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -59,6 +59,8 @@
     ResNet,
     ResNetBlock,
     ResNetBottleneck,
+    get_medicalnet_pretrained_resnet_args,
+    get_pretrained_resnet_medicalnet,
     resnet10,
     resnet18,
     resnet34,
diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index e742db5ca5..fca73f4de3 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -11,8 +11,11 @@
 
 from __future__ import annotations
 
+import logging
+import re
 from collections.abc import Callable
 from functools import partial
+from pathlib import Path
 from typing import Any
 
 import torch
@@ -21,7 +24,13 @@
 from monai.networks.layers.factories import Conv, Norm, Pool
 from monai.networks.layers.utils import get_pool_layer
 from monai.utils import ensure_tuple_rep
-from monai.utils.module import look_up_option
+from monai.utils.module import look_up_option, optional_import
+
+hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download")
+EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError")
+
+MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet"
+MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_"
 
 __all__ = [
     "ResNet",
@@ -36,6 +45,8 @@
     "resnet200",
 ]
 
+logger = logging.getLogger(__name__)
+
 
 def get_inplanes():
     return [64, 128, 256, 512]
@@ -329,21 +340,54 @@ def _resnet(
     block: type[ResNetBlock | ResNetBottleneck],
     layers: list[int],
     block_inplanes: list[int],
-    pretrained: bool,
+    pretrained: bool | str,
     progress: bool,
     **kwargs: Any,
 ) -> ResNet:
     model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
     if pretrained:
-        # Author of paper zipped the state_dict on googledrive,
-        # so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
-        # Would like to load dict from url but need somewhere to save the state dicts.
-        raise NotImplementedError(
-            "Currently not implemented. You need to manually download weights provided by the paper's author"
-            " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
-            "Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args. as specified"
-            "here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730"
-        )
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if isinstance(pretrained, str):
+            if Path(pretrained).exists():
+                logger.info(f"Loading weights from {pretrained}...")
+                model_state_dict = torch.load(pretrained, map_location=device)
+            else:
+                # Throw error
+                raise FileNotFoundError("The pretrained checkpoint file is not found")
+        else:
+            # Also check bias downsample and shortcut.
+            if kwargs.get("spatial_dims", 3) == 3:
+                if kwargs.get("n_input_channels", 3) == 1 and kwargs.get("feed_forward", True) is False:
+                    search_res = re.search(r"resnet(\d+)", arch)
+                    if search_res:
+                        resnet_depth = int(search_res.group(1))
+                    else:
+                        raise ValueError("arch argument should be as 'resnet{resnet_depth}'")
+
+                    # Compare against the effective model settings; fallbacks follow the ResNet defaults ("B", True)
+                    bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
+                    if shortcut_type == kwargs.get("shortcut_type", "B") and (
+                        bool(bias_downsample) == kwargs.get("bias_downsample", True) if bias_downsample != -1 else True
+                    ):
+                        # Download the MedicalNet pretrained model
+                        model_state_dict = get_pretrained_resnet_medicalnet(
+                            resnet_depth, device=device, datasets23=True
+                        )
+                    else:
+                        raise NotImplementedError(
+                            f"Please set shortcut_type to {shortcut_type} and bias_downsample to "
+                            f"{bool(bias_downsample) if bias_downsample != -1 else 'True or False'} "
+                            f"when using pretrained MedicalNet resnet{resnet_depth}"
+                        )
+                else:
+                    raise NotImplementedError(
+                        "Please set n_input_channels to 1 "
+                        "and feed_forward to False in order to use MedicalNet pretrained weights"
+                    )
+            else:
+                raise NotImplementedError("MedicalNet pretrained weights are only available for 3D models")
+        model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
+        model.load_state_dict(model_state_dict, strict=True)
     return model
 
 
@@ -429,3 +473,80 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
         progress (bool): If True, displays a progress bar of the download to stderr
     """
     return _resnet("resnet200", ResNetBottleneck, [3, 24, 36, 3], get_inplanes(), pretrained, progress, **kwargs)
+
+
+def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
+    """
+    Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet
+
+    Args:
+        resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
+        device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example.
+        datasets23: if True, get the weights trained on more datasets (23).
+            Not all depths are available. If not, standard weights are returned.
+
+    Returns:
+        Pretrained state dict (the "state_dict" entry of the downloaded checkpoint)
+
+    Raises:
+        huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub
+        NotImplementedError: if `resnet_depth` is not supported
+    """
+
+    medicalnet_huggingface_repo_basename = MEDICALNET_HUGGINGFACE_REPO_BASENAME
+    medicalnet_huggingface_files_basename = MEDICALNET_HUGGINGFACE_FILES_BASENAME
+    supported_depth = [10, 18, 34, 50, 101, 152, 200]
+
+    logger.info(
+        f"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}"
+    )
+
+    if resnet_depth in supported_depth:
+        filename = (
+            f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth"
+            if not datasets23
+            else f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth"
+        )
+        try:
+            pretrained_path = hf_hub_download(
+                repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename
+            )
+        except Exception as err:
+            # Broad on purpose (best effort): any failure on the 23-dataset file falls back to the standard one
+            if datasets23:
+                logger.info(f"{filename} not available for resnet{resnet_depth}")
+                filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth"
+                logger.info(f"Trying with {filename}")
+                pretrained_path = hf_hub_download(
+                    repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename
+                )
+            else:
+                raise EntryNotFoundError(
+                    f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
+                ) from err
+        checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
+    else:
+        raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
+    logger.info(f"{filename} downloaded")
+    return checkpoint.get("state_dict")
+
+
+def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
+    """
+    Return correct shortcut_type and bias_downsample
+    for pretrained MedicalNet weights according to resnet depth.
+
+    Args:
+        resnet_depth: depth of the pretrained model.
+
+    Returns:
+        A ``(bias_downsample, shortcut_type)`` tuple. ``bias_downsample`` is ``-1`` when either
+        setting is accepted (depths 18 and 34) and ``0`` (i.e. ``False``) otherwise;
+        ``shortcut_type`` is ``"A"`` for depths 18 and 34 and ``"B"`` otherwise.
+    """
+    # After testing
+    # False: 10, 50, 101, 152, 200
+    # Any: 18, 34
+    bias_downsample = -1 if resnet_depth in [18, 34] else 0  # -1 means "any value accepted"
+    shortcut_type = "A" if resnet_depth in [18, 34] else "B"
+    return bias_downsample, shortcut_type
diff --git a/monai/networks/utils.py b/monai/networks/utils.py
index 42e537648a..6112d33100 100644
--- a/monai/networks/utils.py
+++ b/monai/networks/utils.py
@@ -37,6 +37,7 @@
 onnxreference, _ = optional_import("onnx.reference")
 onnxruntime, _ = optional_import("onnxruntime")
 
+
 __all__ = [
     "one_hot",
     "predict_segmentation",
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d47e95a4f5..38715b8449 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -56,3 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
 zarr
 lpips==0.1.4
 nvidia-ml-py
+huggingface_hub
diff --git a/setup.cfg b/setup.cfg
index 9d5a22963c..fe5fbf3b5b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -83,6 +83,7 @@ all =
     zarr
     lpips==0.1.4
     nvidia-ml-py
+    huggingface_hub
 nibabel =
     nibabel
 ninja =
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index cc24106373..a6ea70b377 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -11,6 +11,10 @@
 
 from __future__ import annotations
 
+import copy
+import os
+import re
+import sys
 import unittest
 from typing import TYPE_CHECKING
 
@@ -18,10 +22,21 @@
 from parameterized import parameterized
 
 from monai.networks import eval_mode
-from monai.networks.nets import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
+from monai.networks.nets import (
+    ResNet,
+    get_medicalnet_pretrained_resnet_args,
+    get_pretrained_resnet_medicalnet,
+    resnet10,
+    resnet18,
+    resnet34,
+    resnet50,
+    resnet101,
+    resnet152,
+    resnet200,
+)
 from monai.networks.nets.resnet import ResNetBlock
 from monai.utils import optional_import
-from tests.utils import test_script_save
+from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save
 
 if TYPE_CHECKING:
     import torchvision
@@ -30,6 +45,10 @@
 else:
     torchvision, has_torchvision = optional_import("torchvision")
 
+has_hf_modules = "huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules
+
+# from torchvision.models import ResNet50_Weights, resnet50
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 TEST_CASE_1 = [  # 3D, batch 3, 2 input channel
@@ -159,9 +178,11 @@
 ]
 
 TEST_CASES = []
+PRETRAINED_TEST_CASES = []
 for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
     for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
         TEST_CASES.append([model, *case])
+        PRETRAINED_TEST_CASES.append([model, *case])
 for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
     TEST_CASES.append([ResNet, *case])
 
@@ -171,6 +192,16 @@
 
 
 class TestResNet(unittest.TestCase):
+    def setUp(self):
+        self.tmp_ckpt_filename = os.path.join("tests", "monai_unittest_tmp_ckpt.pth")
+
+    def tearDown(self):
+        if os.path.exists(self.tmp_ckpt_filename):
+            try:
+                os.remove(self.tmp_ckpt_filename)
+            except OSError:
+                pass
+
     @parameterized.expand(TEST_CASES)
     def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
         net = model(**input_param).to(device)
@@ -181,6 +212,56 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
         else:
             self.assertTrue(result.shape in expected_shape)
 
+    @parameterized.expand(PRETRAINED_TEST_CASES)
+    @skip_if_quick
+    @skip_if_no_cuda
+    def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
+        net = model(**input_param).to(device)
+        # Save ckpt
+        torch.save(net.state_dict(), self.tmp_ckpt_filename)
+
+        cp_input_param = copy.copy(input_param)
+        # Custom pretrained weights
+        cp_input_param["pretrained"] = self.tmp_ckpt_filename
+        pretrained_net = model(**cp_input_param)
+        self.assertTrue(equal_state_dict(net.state_dict(), pretrained_net.state_dict()))
+
+        if has_hf_modules:
+            # True flag
+            cp_input_param["pretrained"] = True
+            resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1))
+
+            bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
+
+            # With orig. test cases
+            if (
+                input_param.get("spatial_dims", 3) == 3
+                and input_param.get("n_input_channels", 3) == 1
+                and input_param.get("feed_forward", True) is False
+                and input_param.get("shortcut_type", "B") == shortcut_type
+                and (
+                    input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True
+                )
+            ):
+                model(**cp_input_param)
+            else:
+                with self.assertRaises(NotImplementedError):
+                    model(**cp_input_param)
+
+            # forcing MedicalNet pretrained download for 3D tests cases
+            cp_input_param["n_input_channels"] = 1
+            cp_input_param["feed_forward"] = False
+            cp_input_param["shortcut_type"] = shortcut_type
+            cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True
+            if cp_input_param.get("spatial_dims", 3) == 3:
+                with skip_if_downloading_fails():
+                    pretrained_net = model(**cp_input_param).to(device)
+                    medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device)
+                    medicalnet_state_dict = {
+                        key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()
+                    }
+                    self.assertTrue(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict))
+
     @parameterized.expand(TEST_SCRIPT_CASES)
     def test_script(self, model, input_param, input_shape, expected_shape):
         net = model(**input_param)
diff --git a/tests/utils.py b/tests/utils.py
index a8efbe081e..b391111fd5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -825,6 +825,18 @@ def command_line_tests(cmd, copy_env=True):
         raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e
 
 
+def equal_state_dict(st_1, st_2):
+    """
+    Compare 2 torch state dicts.
+
+    Returns True only if both dicts have exactly the same keys and every pair of tensors holds
+    equal values; tensors of `st_2` are moved to the device of their `st_1` counterpart first.
+    """
+    if st_1.keys() != st_2.keys():
+        return False
+    return all(torch.equal(val_st_1, st_2[key].to(val_st_1.device)) for key, val_st_1 in st_1.items())
+
+
 TEST_TORCH_TENSORS: tuple = (torch.as_tensor,)
 if torch.cuda.is_available():
     gpu_tensor: Callable = partial(torch.as_tensor, device="cuda")