diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 323e107fd7..3ec1cea37e 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -443,6 +443,8 @@ class HoVerNet(nn.Module): adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format, this value should be `True`. + pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True. + It is used to extract the expected state dict. freeze_encoder: whether to freeze the encoder of the network. """ @@ -461,6 +463,7 @@ def __init__( dropout_prob: float = 0.0, pretrained_url: str | None = None, adapt_standard_resnet: bool = False, + pretrained_state_dict_key: str | None = None, freeze_encoder: bool = False, ) -> None: super().__init__() @@ -566,7 +569,7 @@ def __init__( if pretrained_url is not None: if adapt_standard_resnet: - weights = _remap_standard_resnet_model(pretrained_url) + weights = _remap_standard_resnet_model(pretrained_url, state_dict_key=pretrained_state_dict_key) else: weights = _remap_preact_resnet_model(pretrained_url) _load_pretrained_encoder(self, weights) @@ -609,6 +612,12 @@ def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict): model_dict.update(state_dict) model.load_state_dict(model_dict) + if len(state_dict.keys()) == 0: + warnings.warn( + "no key will be updated. Please check if 'pretrained_url' or `pretrained_state_dict_key` is correct." + ) + else: + print(f"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.") def _remap_preact_resnet_model(model_url: str): @@ -619,7 +628,9 @@ def _remap_preact_resnet_model(model_url: str): # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None)["desc"] + state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[ + "desc" + ] for key in list(state_dict.keys()): new_key = None if pattern_conv0.match(key): @@ -639,7 +650,7 @@ def _remap_preact_resnet_model(model_url: str): return state_dict -def _remap_standard_resnet_model(model_url: str): +def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None): pattern_conv0 = re.compile(r"^conv1\.(.+)$") pattern_bn1 = re.compile(r"^bn1\.(.+)$") pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$") @@ -652,7 +663,9 @@ def _remap_standard_resnet_model(model_url: str): # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None) + state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu")) + if state_dict_key is not None: + state_dict = state_dict[state_dict_key] for key in list(state_dict.keys()): new_key = None