From 2102c43928a072e5185380ef45ab414cd7013c07 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sun, 2 Apr 2023 15:29:35 +0800 Subject: [PATCH 1/4] enhance hovernet Signed-off-by: Yiheng Wang --- monai/networks/nets/hovernet.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 323e107fd7..0a82b1d48b 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -443,6 +443,8 @@ class HoVerNet(nn.Module): adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format, this value should be `True`. + pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True. It is used to + extract the expected state dict. freeze_encoder: whether to freeze the encoder of the network. """ @@ -461,6 +463,7 @@ def __init__( dropout_prob: float = 0.0, pretrained_url: str | None = None, adapt_standard_resnet: bool = False, + pretrained_state_dict_key: str | None = None, freeze_encoder: bool = False, ) -> None: super().__init__() @@ -566,7 +569,7 @@ def __init__( if pretrained_url is not None: if adapt_standard_resnet: - weights = _remap_standard_resnet_model(pretrained_url) + weights = _remap_standard_resnet_model(pretrained_url, state_dict_key=pretrained_state_dict_key) else: weights = _remap_preact_resnet_model(pretrained_url) _load_pretrained_encoder(self, weights) @@ -609,6 +612,12 @@ def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict): model_dict.update(state_dict) model.load_state_dict(model_dict) + if len(state_dict.keys()) == 0: + warnings.warn( + "no key will be updated. Please double confirm if the 'pretrained_url' is correct or `pretrained_state_dict_key` is reasonably set." + ) + else: + print(f"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.") def _remap_preact_resnet_model(model_url: str): @@ -619,7 +628,9 @@ def _remap_preact_resnet_model(model_url: str): # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None)["desc"] + state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[ + "desc" + ] for key in list(state_dict.keys()): new_key = None if pattern_conv0.match(key): @@ -639,7 +650,7 @@ def _remap_preact_resnet_model(model_url: str): return state_dict -def _remap_standard_resnet_model(model_url: str): +def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None): pattern_conv0 = re.compile(r"^conv1\.(.+)$") pattern_bn1 = re.compile(r"^bn1\.(.+)$") pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$") @@ -652,7 +663,9 @@ def _remap_standard_resnet_model(model_url: str): # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None) + state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu")) + if state_dict_key is not None: + state_dict = state_dict[state_dict_key] for key in list(state_dict.keys()): new_key = None From f53c66ce5a28bf7a4e866600ed80d62e8ebeddc8 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 3 Apr 2023 17:09:00 +0800 Subject: [PATCH 2/4] reduce line length Signed-off-by: Yiheng Wang --- monai/networks/nets/hovernet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 0a82b1d48b..f5be9f28b4 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -614,7 +614,7 @@ def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict): model.load_state_dict(model_dict) if len(state_dict.keys()) == 0: warnings.warn( - "no key will be updated. Please double confirm if the 'pretrained_url' is correct or `pretrained_state_dict_key` is reasonably set." + "no key will be updated. Please check if 'pretrained_url' or `pretrained_state_dict_key` is correct." ) else: print(f"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.") From bfdce30c2a000846e4518a98af852fc5c0d4c11e Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 3 Apr 2023 22:53:31 +0800 Subject: [PATCH 3/4] fix flake8 Signed-off-by: Yiheng Wang --- monai/networks/nets/hovernet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index f5be9f28b4..3ec1cea37e 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -443,8 +443,8 @@ class HoVerNet(nn.Module): adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format, this value should be `True`. - pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True. It is used to - extract the expected state dict. + pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True. + It is used to extract the expected state dict. freeze_encoder: whether to freeze the encoder of the network. """ From 3ceeff193adce6d88a11364d0c4344d9b9cd7c85 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 3 Apr 2023 19:20:23 +0100 Subject: [PATCH 4/4] fixes mypy error Signed-off-by: Wenqi Li --- monai/data/image_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 03bffbb1e8..43086964ae 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -929,7 +929,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): for name in filenames: img = nib.load(name, **kwargs_) img = correct_nifti_header_if_necessary(img) - img_.append(img) + img_.append(img) # type: ignore return img_ if len(filenames) > 1 else img_[0] def get_data(self, img) -> tuple[np.ndarray, dict]: