From 99215abcd44308c91b7c9d595384b4bf20974cdb Mon Sep 17 00:00:00 2001 From: Konstantin Sukharev Date: Sun, 24 Mar 2024 18:33:50 +0500 Subject: [PATCH 1/7] Add ResNet backbones for FlexibleUNet Signed-off-by: Konstantin Sukharev --- monai/networks/nets/__init__.py | 2 + monai/networks/nets/flexible_unet.py | 8 +- monai/networks/nets/resnet.py | 340 +++++++++++++++++++++++++++ tests/test_flexible_unet.py | 143 +++-------- tests/test_resnet.py | 45 ++++ 5 files changed, 430 insertions(+), 108 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9247aaee85..de5d1adc7e 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -59,6 +59,8 @@ ResNet, ResNetBlock, ResNetBottleneck, + ResNetEncoder, + ResNetFeatures, get_medicalnet_pretrained_resnet_args, get_pretrained_resnet_medicalnet, resnet10, diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index ac2124b5f9..747ee76163 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -24,6 +24,7 @@ from monai.networks.layers.utils import get_act_layer from monai.networks.nets import EfficientNetEncoder from monai.networks.nets.basic_unet import UpCat +from monai.networks.nets.resnet import ResNetEncoder from monai.utils import InterpolateMode, optional_import __all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"] @@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str): FLEXUNET_BACKBONE = FlexUNetEncoderRegister() FLEXUNET_BACKBONE.register_class(EfficientNetEncoder) +FLEXUNET_BACKBONE.register_class(ResNetEncoder) class UNetDecoder(nn.Module): @@ -238,7 +240,7 @@ def __init__( ) -> None: """ A flexible implement of UNet, in which the backbone/encoder can be replaced with - any efficient network. Currently the input must have a 2 or 3 spatial dimension + any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension and the spatial size of each dimension must be a multiple of 32 if is_pad parameter is False. Please notice each output of backbone must be 2x downsample in spatial dimension @@ -248,8 +250,8 @@ def __init__( Args: in_channels: number of input channels. out_channels: number of output channels. - backbone: name of backbones to initialize, only support efficientnet right now, - can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2]. + backbone: name of backbones to initialize, only support efficientnet and resnet right now, + can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200]. pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2 and batch norm is used, default to False. decoder_channels: number of output channels for all feature maps in decoder. diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 34a4b7057e..6aeb869b4c 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from monai.networks.blocks.encoder import BaseEncoder from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer from monai.utils import ensure_tuple_rep @@ -335,6 +336,332 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class ResNetFeatures(ResNet): + + def __init__( + self, + model_name: str, + pretrained: bool = True, + block: type[ResNetBlock | ResNetBottleneck] | str = ResNetBlock, + layers: list[int] = (1, 1, 1, 1), + block_inplanes: list[int] = (64, 128, 256, 512), + spatial_dims: int = 3, + in_channels: int = 1, + conv1_t_size: tuple[int] | int = 7, + conv1_t_stride: tuple[int] | int = 2, + no_max_pool: bool = False, + shortcut_type: str = "B", + widen_factor: float = 1.0, + num_classes: int = 400, + feed_forward: bool = False, + bias_downsample: bool = False, + ) -> None: + """Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for + segmentation and objection models. + + Compared with the class `ResNet`, the only different place is the forward function. + """ + if model_name not in ResNetEncoder.backbone_names: + model_name_string = ", ".join(ResNetEncoder.backbone_names) + raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ") + + super().__init__( + block=block, + layers=layers, + block_inplanes=block_inplanes, + spatial_dims=spatial_dims, + n_input_channels=in_channels, + conv1_t_size=conv1_t_size, + conv1_t_stride=conv1_t_stride, + no_max_pool=no_max_pool, + shortcut_type=shortcut_type, + widen_factor=widen_factor, + num_classes=num_classes, + feed_forward=feed_forward, + bias_downsample=bias_downsample, + ) + if pretrained: + if spatial_dims == 3 and in_channels == 1: + _load_state_dict(self, model_name) + else: + raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.") + + def forward(self, inputs: torch.Tensor) -> list[torch.Tensor]: + """ + Args: + inputs: input should have spatially N dimensions + ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`. + + Returns: + a list of torch Tensors. + """ + x = self.conv1(inputs) + x = self.bn1(x) + x = self.relu(x) + + features = [] + features.append(x) + + if not self.no_max_pool: + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + + x = self.layer2(x) + features.append(x) + + x = self.layer3(x) + features.append(x) + + x = self.layer4(x) + features.append(x) + + return features + + +class ResNetEncoder(ResNetFeatures, BaseEncoder): + """Wrap the original resnet to an encoder for flexible-unet.""" + + backbone_names = [ + "resnet10", + "resnet10_23datasets", + "resnet18", + "resnet18_23datasets", + "resnet34", + "resnet34_23datasets", + "resnet50", + "resnet50_23datasets", + "resnet101", + "resnet152", + "resnet200", + ] + + @classmethod + def get_encoder_parameters(cls) -> list[dict]: + """Get the initialization parameter for resnet backbones.""" + parameter_list = [ + { + "model_name": "resnet10", + "pretrained": True, + "block": ResNetBlock, + "layers": [1, 1, 1, 1], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + { + "model_name": "resnet10_23datasets", + "pretrained": True, + "block": ResNetBlock, + "layers": [1, 1, 1, 1], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + { + "model_name": "resnet18", + "pretrained": True, + "block": ResNetBlock, + "layers": [2, 2, 2, 2], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "A", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": True, + }, + { + "model_name": "resnet18_23datasets", + "pretrained": True, + "block": ResNetBlock, + "layers": [2, 2, 2, 2], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "A", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": True, + }, + { + "model_name": "resnet34", + "pretrained": True, + "block": ResNetBlock, + "layers": [3, 4, 6, 3], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "A", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": True, + }, + { + "model_name": "resnet34_23datasets", + "pretrained": True, + "block": ResNetBlock, + "layers": [3, 4, 6, 3], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "A", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": True, + }, + { + "model_name": "resnet50", + "pretrained": True, + "block": ResNetBottleneck, + "layers": [3, 4, 6, 3], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + { + "model_name": "resnet50_23datasets", + "pretrained": True, + "block": ResNetBottleneck, + "layers": [3, 4, 6, 3], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + { + "model_name": "resnet101", + "pretrained": True, + "block": ResNetBottleneck, + "layers": [3, 4, 23, 3], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + { + "model_name": "resnet152", + "pretrained": True, + "block": ResNetBottleneck, + "layers": [3, 8, 36, 3], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + { + "model_name": "resnet200", + "pretrained": True, + "block": ResNetBottleneck, + "layers": [3, 24, 36, 3], + "block_inplanes": get_inplanes(), + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + ] + return parameter_list + + @classmethod + def num_channels_per_output(cls) -> list[tuple[int, ...]]: + """Get number of resnet backbone output feature maps channel.""" + return [ + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + ] + + @classmethod + def num_outputs(cls) -> list[int]: + """Get number of resnet backbone output feature maps. + + Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`. + """ + return [5] * 11 + + @classmethod + def get_encoder_names(cls) -> list[str]: + """Get names of resnet backbones.""" + return cls.backbone_names + + def _resnet( arch: str, block: type[ResNetBlock | ResNetBottleneck], @@ -541,3 +868,16 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type + + +def _load_state_dict(model: nn.Module, model_name: str) -> None: + search_res = re.search(r"resnet(\d+)", model_name) + if search_res: + resnet_depth = int(search_res.group(1)) + datasets23 = model_name.endswith("_23datasets") + else: + raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.") + + model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23) + model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} + model.load_state_dict(model_state_dict) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 404855c9a8..82e7c4e0ab 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -23,9 +23,9 @@ EfficientNetBNFeatures, FlexibleUNet, FlexUNetEncoderRegister, - ResNet, ResNetBlock, - ResNetBottleneck, + ResNetEncoder, + ResNetFeatures, ) from monai.utils import optional_import from tests.utils import skip_if_downloading_fails, skip_if_quick @@ -59,101 +59,6 @@ def get_encoder_names(cls): return ["encoder_wrong_channels", "encoder_no_param1", "encoder_no_param2", "encoder_no_param3"] -class ResNetEncoder(ResNet, BaseEncoder): - backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] - output_feature_channels = [(64, 128, 256, 512)] * 3 + [(256, 512, 1024, 2048)] * 4 - parameter_layers = [ - [1, 1, 1, 1], - [2, 2, 2, 2], - [3, 4, 6, 3], - [3, 4, 6, 3], - [3, 4, 23, 3], - [3, 8, 36, 3], - [3, 24, 36, 3], - ] - - def __init__(self, in_channels, pretrained, **kargs): - super().__init__(**kargs, n_input_channels=in_channels) - if pretrained: - # Author of paper zipped the state_dict on googledrive, - # so would need to download, unzip and read (2.8gb file for a ~150mb state dict). - # Would like to load dict from url but need somewhere to save the state dicts. - raise NotImplementedError( - "Currently not implemented. You need to manually download weights provided by the paper's author" - " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet" - ) - - @staticmethod - def get_inplanes(): - return [64, 128, 256, 512] - - @classmethod - def get_encoder_parameters(cls) -> list[dict]: - """ - Get parameter list to initialize encoder networks. - Each parameter dict must have `spatial_dims`, `in_channels` - and `pretrained` parameters. - """ - parameter_list = [] - res_type: type[ResNetBlock] | type[ResNetBottleneck] - for backbone in range(len(cls.backbone_names)): - if backbone < 3: - res_type = ResNetBlock - else: - res_type = ResNetBottleneck - parameter_list.append( - { - "block": res_type, - "layers": cls.parameter_layers[backbone], - "block_inplanes": ResNetEncoder.get_inplanes(), - "spatial_dims": 2, - "in_channels": 3, - "pretrained": False, - } - ) - return parameter_list - - @classmethod - def num_channels_per_output(cls): - """ - Get number of output features' channel. - """ - return cls.output_feature_channels - - @classmethod - def num_outputs(cls): - """ - Get number of output feature. - """ - return [4] * 7 - - @classmethod - def get_encoder_names(cls): - """ - Get the name string of backbones which will be used to initialize flexible unet. - """ - return cls.backbone_names - - def forward(self, x: torch.Tensor): - feature_list = [] - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - if not self.no_max_pool: - x = self.maxpool(x) - x = self.layer1(x) - feature_list.append(x) - x = self.layer2(x) - feature_list.append(x) - x = self.layer3(x) - feature_list.append(x) - x = self.layer4(x) - feature_list.append(x) - - return feature_list - - -FLEXUNET_BACKBONE.register_class(ResNetEncoder) FLEXUNET_BACKBONE.register_class(DummyEncoder) @@ -204,9 +109,7 @@ def make_shape_cases( def make_error_case(): - error_dummy_backbones = DummyEncoder.get_encoder_names() - error_resnet_backbones = ResNetEncoder.get_encoder_names() - error_backbones = error_dummy_backbones + error_resnet_backbones + error_backbones = DummyEncoder.get_encoder_names() error_param_list = [] for backbone in error_backbones: error_param_list.append( @@ -232,7 +135,7 @@ def make_error_case(): norm="instance", ) CASES_3D = make_shape_cases( - models=[SEL_MODELS[0]], + models=[SEL_MODELS[0], SEL_MODELS[2]], spatial_dims=[3], batches=[1], pretrained=[False], @@ -345,6 +248,7 @@ def make_error_case(): "spatial_dims": 2, "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), }, + EfficientNetBNFeatures, { "in_channels": 3, "num_classes": 10, @@ -354,7 +258,36 @@ def make_error_case(): "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), }, ["_conv_stem.weight"], - ) + ), + ( + { + "in_channels": 1, + "out_channels": 10, + "backbone": SEL_MODELS[2], + "pretrained": True, + "spatial_dims": 3, + "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), + }, + ResNetFeatures, + { + "model_name": SEL_MODELS[2], + "pretrained": True, + "block": ResNetBlock, + "layers": [1, 1, 1, 1], + "block_inplanes": [64, 128, 256, 512], + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + ["conv1.weight"], + ), ] CASE_ERRORS = make_error_case() @@ -381,14 +314,14 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) @parameterized.expand(CASES_PRETRAIN) - def test_pretrain(self, input_param, efficient_input_param, weight_list): + def test_pretrain(self, flexunet_input_param, feature_extractor_class, feature_extractor_input_param, weight_list): device = "cuda" if torch.cuda.is_available() else "cpu" with skip_if_downloading_fails(): - net = FlexibleUNet(**input_param).to(device) + net = FlexibleUNet(**flexunet_input_param).to(device) with skip_if_downloading_fails(): - eff_net = EfficientNetBNFeatures(**efficient_input_param).to(device) + eff_net = feature_extractor_class(**feature_extractor_input_param).to(device) for weight_name in weight_list: if weight_name in net.encoder.state_dict() and weight_name in eff_net.state_dict(): diff --git a/tests/test_resnet.py b/tests/test_resnet.py index ad1aad8fc6..e069e95894 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -24,6 +24,7 @@ from monai.networks import eval_mode from monai.networks.nets import ( ResNet, + ResNetFeatures, get_medicalnet_pretrained_resnet_args, get_pretrained_resnet_medicalnet, resnet10, @@ -191,6 +192,31 @@ ] +CASE_EXTRACT_FEATURES = [ + ( + { + "model_name": "resnet10", + "pretrained": True, + "block": ResNetBlock, + "layers": [1, 1, 1, 1], + "block_inplanes": [64, 128, 256, 512], + "spatial_dims": 3, + "in_channels": 1, + "conv1_t_size": 7, + "conv1_t_stride": 2, + "no_max_pool": False, + "shortcut_type": "B", + "widen_factor": 1.0, + "num_classes": 400, + "feed_forward": False, + "bias_downsample": False, + }, + [1, 1, 64, 64, 64], + ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]), + ) +] + + class TestResNet(unittest.TestCase): def setUp(self): @@ -270,5 +296,24 @@ def test_script(self, model, input_param, input_shape, expected_shape): test_script_save(net, test_data) +class TestExtractFeatures(unittest.TestCase): + + @parameterized.expand(CASE_EXTRACT_FEATURES) + def test_shape(self, input_param, input_shape, expected_shapes): + device = "cuda" if torch.cuda.is_available() else "cpu" + + with skip_if_downloading_fails(): + net = ResNetFeatures(**input_param).to(device) + + # run inference with random tensor + with eval_mode(net): + features = net(torch.randn(input_shape).to(device)) + + # check output shape + self.assertEqual(len(features), len(expected_shapes)) + for feature, expected_shape in zip(features, expected_shapes): + self.assertEqual(feature.shape, torch.Size(expected_shape)) + + if __name__ == "__main__": unittest.main() From 3531433087cfc1f31c981c7d734053e03107dc02 Mon Sep 17 00:00:00 2001 From: Konstantin Sukharev Date: Sun, 24 Mar 2024 18:35:18 +0500 Subject: [PATCH 2/7] Add ResNetFeatures to docs Signed-off-by: Konstantin Sukharev --- docs/source/networks.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index b59c8af5fc..249375dfc1 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -491,6 +491,11 @@ Nets .. autoclass:: ResNet :members: +`ResNetFeatures` +~~~~~~~~~~~~~~~~ +.. autoclass:: ResNetFeatures + :members: + `SENet` ~~~~~~~ .. autoclass:: SENet From 59eab7a44f1a28a7ded03c4257fdbb86dab7cffd Mon Sep 17 00:00:00 2001 From: Konstantin Sukharev Date: Sun, 24 Mar 2024 18:55:44 +0500 Subject: [PATCH 3/7] Fix types in type hints for ResNetFeatures Signed-off-by: Konstantin Sukharev --- monai/networks/nets/resnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 6aeb869b4c..1e5bcb96a3 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -343,8 +343,8 @@ def __init__( model_name: str, pretrained: bool = True, block: type[ResNetBlock | ResNetBottleneck] | str = ResNetBlock, - layers: list[int] = (1, 1, 1, 1), - block_inplanes: list[int] = (64, 128, 256, 512), + layers: tuple[int] = (1, 1, 1, 1), + block_inplanes: tuple[int] = (64, 128, 256, 512), spatial_dims: int = 3, in_channels: int = 1, conv1_t_size: tuple[int] | int = 7, @@ -386,7 +386,7 @@ def __init__( else: raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.") - def forward(self, inputs: torch.Tensor) -> list[torch.Tensor]: + def forward(self, inputs: torch.Tensor): """ Args: inputs: input should have spatially N dimensions From 0181beef55d9332817c8b4e7705777f5d47eb636 Mon Sep 17 00:00:00 2001 From: Konstantin Sukharev Date: Sun, 24 Mar 2024 21:00:22 +0500 Subject: [PATCH 4/7] Remove useless arguments from ResNetFeatures __init__ Signed-off-by: Konstantin Sukharev --- monai/networks/nets/resnet.py | 273 +++++----------------------------- tests/test_flexible_unet.py | 19 +-- tests/test_resnet.py | 18 +-- 3 files changed, 39 insertions(+), 271 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 1e5bcb96a3..51dfd39b5f 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -46,6 +46,19 @@ "resnet200", ] + +resnet_params = { + # model_name: (block, layers, shortcut_type, bias_downsample, datasets23) + "resnet10": ("basic", [1, 1, 1, 1], "B", False, True), + "resnet18": ("basic", [2, 2, 2, 2], "A", True, True), + "resnet34": ("basic", [3, 4, 6, 3], "A", True, True), + "resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True), + "resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False), + "resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False), + "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False), +} + + logger = logging.getLogger(__name__) @@ -338,51 +351,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ResNetFeatures(ResNet): - def __init__( - self, - model_name: str, - pretrained: bool = True, - block: type[ResNetBlock | ResNetBottleneck] | str = ResNetBlock, - layers: tuple[int] = (1, 1, 1, 1), - block_inplanes: tuple[int] = (64, 128, 256, 512), - spatial_dims: int = 3, - in_channels: int = 1, - conv1_t_size: tuple[int] | int = 7, - conv1_t_stride: tuple[int] | int = 2, - no_max_pool: bool = False, - shortcut_type: str = "B", - widen_factor: float = 1.0, - num_classes: int = 400, - feed_forward: bool = False, - bias_downsample: bool = False, - ) -> None: + def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None: """Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for segmentation and objection models. Compared with the class `ResNet`, the only different place is the forward function. + + Args: + model_name: name of model to initialize, can be from [resnet10, ..., resnet200]. + pretrained: whether to initialize pretrained Med3D weights, + only available for spatial_dims=3 and in_channels=1. + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of input channels for first convolutional layer. """ - if model_name not in ResNetEncoder.backbone_names: - model_name_string = ", ".join(ResNetEncoder.backbone_names) + if model_name not in resnet_params: + model_name_string = ", ".join(resnet_params.keys()) raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ") + block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name] + super().__init__( block=block, layers=layers, - block_inplanes=block_inplanes, + block_inplanes=get_inplanes(), spatial_dims=spatial_dims, n_input_channels=in_channels, - conv1_t_size=conv1_t_size, - conv1_t_stride=conv1_t_stride, - no_max_pool=no_max_pool, + conv1_t_stride=2, shortcut_type=shortcut_type, - widen_factor=widen_factor, - num_classes=num_classes, - feed_forward=feed_forward, + feed_forward=False, bias_downsample=bias_downsample, ) if pretrained: if spatial_dims == 3 and in_channels == 1: - _load_state_dict(self, model_name) + _load_state_dict(self, model_name, datasets23=datasets23) else: raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.") @@ -423,212 +424,16 @@ def forward(self, inputs: torch.Tensor): class ResNetEncoder(ResNetFeatures, BaseEncoder): """Wrap the original resnet to an encoder for flexible-unet.""" - backbone_names = [ - "resnet10", - "resnet10_23datasets", - "resnet18", - "resnet18_23datasets", - "resnet34", - "resnet34_23datasets", - "resnet50", - "resnet50_23datasets", - "resnet101", - "resnet152", - "resnet200", - ] + backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] @classmethod def get_encoder_parameters(cls) -> list[dict]: """Get the initialization parameter for resnet backbones.""" - parameter_list = [ - { - "model_name": "resnet10", - "pretrained": True, - "block": ResNetBlock, - "layers": [1, 1, 1, 1], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, - { - "model_name": "resnet10_23datasets", - "pretrained": True, - "block": ResNetBlock, - "layers": [1, 1, 1, 1], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, - { - "model_name": "resnet18", - "pretrained": True, - "block": ResNetBlock, - "layers": [2, 2, 2, 2], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "A", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": True, - }, - { - "model_name": "resnet18_23datasets", - "pretrained": True, - "block": ResNetBlock, - "layers": [2, 2, 2, 2], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "A", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": True, - }, - { - "model_name": "resnet34", - "pretrained": True, - "block": ResNetBlock, - "layers": [3, 4, 6, 3], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "A", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": True, - }, - { - "model_name": "resnet34_23datasets", - "pretrained": True, - "block": ResNetBlock, - "layers": [3, 4, 6, 3], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "A", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": True, - }, - { - "model_name": "resnet50", - "pretrained": True, - "block": ResNetBottleneck, - "layers": [3, 4, 6, 3], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, - { - "model_name": "resnet50_23datasets", - "pretrained": True, - "block": ResNetBottleneck, - "layers": [3, 4, 6, 3], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, - { - "model_name": "resnet101", - "pretrained": True, - "block": ResNetBottleneck, - "layers": [3, 4, 23, 3], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, - { - "model_name": "resnet152", - "pretrained": True, - "block": ResNetBottleneck, - "layers": [3, 8, 36, 3], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, - { - "model_name": "resnet200", - "pretrained": True, - "block": ResNetBottleneck, - "layers": [3, 24, 36, 3], - "block_inplanes": get_inplanes(), - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, - ] + parameter_list = [] + for backbone_name in cls.backbone_names: + parameter_list.append( + {"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1} + ) return parameter_list @classmethod @@ -638,10 +443,6 @@ def num_channels_per_output(cls) -> list[tuple[int, ...]]: (64, 64, 128, 256, 512), (64, 64, 128, 256, 512), (64, 64, 128, 256, 512), - (64, 64, 128, 256, 512), - (64, 64, 128, 256, 512), - (64, 64, 128, 256, 512), - (64, 256, 512, 1024, 2048), (64, 256, 512, 1024, 2048), (64, 256, 512, 1024, 2048), (64, 256, 512, 1024, 2048), @@ -654,7 +455,7 @@ def num_outputs(cls) -> list[int]: Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`. """ - return [5] * 11 + return [5] * 7 @classmethod def get_encoder_names(cls) -> list[str]: @@ -870,7 +671,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): return bias_downsample, shortcut_type -def _load_state_dict(model: nn.Module, model_name: str) -> None: +def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None: search_res = re.search(r"resnet(\d+)", model_name) if search_res: resnet_depth = int(search_res.group(1)) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 82e7c4e0ab..fc470fda02 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -23,7 +23,6 @@ EfficientNetBNFeatures, FlexibleUNet, FlexUNetEncoderRegister, - ResNetBlock, ResNetEncoder, ResNetFeatures, ) @@ -269,23 +268,7 @@ def make_error_case(): "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), }, ResNetFeatures, - { - "model_name": SEL_MODELS[2], - "pretrained": True, - "block": ResNetBlock, - "layers": [1, 1, 1, 1], - "block_inplanes": [64, 128, 256, 512], - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, + {"model_name": SEL_MODELS[2], "pretrained": True, "spatial_dims": 3, "in_channels": 1}, ["conv1.weight"], ), ] diff --git a/tests/test_resnet.py b/tests/test_resnet.py index e069e95894..5f7d75a0d8 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -194,23 +194,7 @@ CASE_EXTRACT_FEATURES = [ ( - { - "model_name": "resnet10", - "pretrained": True, - "block": ResNetBlock, - "layers": [1, 1, 1, 1], - "block_inplanes": [64, 128, 256, 512], - "spatial_dims": 3, - "in_channels": 1, - "conv1_t_size": 7, - "conv1_t_stride": 2, - "no_max_pool": False, - "shortcut_type": "B", - "widen_factor": 1.0, - "num_classes": 400, - "feed_forward": False, - "bias_downsample": False, - }, + {"model_name": "resnet10", "pretrained": True, "spatial_dims": 3, "in_channels": 1}, [1, 1, 64, 64, 64], ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]), ) From c66b875c0f35717d44620f13659080ce82971d51 Mon Sep 17 00:00:00 2001 From: Konstantin Sukharev Date: Sun, 24 Mar 2024 22:30:17 +0500 Subject: [PATCH 5/7] Fix tests if hf_hub_download is not installed Signed-off-by: Konstantin Sukharev --- tests/test_flexible_unet.py | 3 ++- tests/test_resnet.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index fc470fda02..8fb2de5b33 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -27,7 +27,7 @@ ResNetFeatures, ) from monai.utils import optional_import -from tests.utils import skip_if_downloading_fails, skip_if_quick +from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick torchvision, has_torchvision = optional_import("torchvision") PIL, has_pil = optional_import("PIL") @@ -279,6 +279,7 @@ def make_error_case(): CASE_REGISTER_ENCODER = ["EfficientNetEncoder", "monai.networks.nets.EfficientNetEncoder"] +@SkipIfNoModule("hf_hub_download") @skip_if_quick class TestFLEXIBLEUNET(unittest.TestCase): diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 5f7d75a0d8..449edba4bf 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -37,7 +37,14 @@ ) from monai.networks.nets.resnet import ResNetBlock from monai.utils import optional_import -from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save +from tests.utils import ( + SkipIfNoModule, + equal_state_dict, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, + test_script_save, +) if TYPE_CHECKING: import torchvision @@ -280,6 +287,7 @@ def test_script(self, model, input_param, input_shape, expected_shape): test_script_save(net, test_data) +@SkipIfNoModule("hf_hub_download") class TestExtractFeatures(unittest.TestCase): @parameterized.expand(CASE_EXTRACT_FEATURES) From c445889ddf5746f8ef611fdfa580d46a07d3e808 Mon Sep 17 00:00:00 2001 From: Konstantin Sukharev Date: Sun, 14 Apr 2024 00:04:50 +0500 Subject: [PATCH 6/7] Rename eff_net variable to feature_extractor_net Signed-off-by: Konstantin Sukharev --- tests/test_flexible_unet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 8fb2de5b33..42baa28b71 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -305,12 +305,12 @@ def test_pretrain(self, flexunet_input_param, feature_extractor_class, feature_e net = FlexibleUNet(**flexunet_input_param).to(device) with skip_if_downloading_fails(): - eff_net = feature_extractor_class(**feature_extractor_input_param).to(device) + feature_extractor_net = feature_extractor_class(**feature_extractor_input_param).to(device) for weight_name in weight_list: - if weight_name in net.encoder.state_dict() and weight_name in eff_net.state_dict(): + if weight_name in net.encoder.state_dict() and weight_name in feature_extractor_net.state_dict(): net_weight = net.encoder.state_dict()[weight_name] - download_weight = eff_net.state_dict()[weight_name] + download_weight = feature_extractor_net.state_dict()[weight_name] weight_diff = torch.abs(net_weight - download_weight) diff_sum = torch.sum(weight_diff) # check if a weight in weight_list equals to the downloaded weight. From 4272f3fc1086a8d413bbd3163e2199cced22d850 Mon Sep 17 00:00:00 2001 From: Konstantin Sukharev Date: Sun, 14 Apr 2024 00:05:36 +0500 Subject: [PATCH 7/7] Update docstrings Signed-off-by: Konstantin Sukharev --- monai/networks/nets/flexible_unet.py | 5 +++-- monai/networks/nets/resnet.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index 747ee76163..c27b0fc17b 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -252,8 +252,9 @@ def __init__( out_channels: number of output channels. backbone: name of backbones to initialize, only support efficientnet and resnet right now, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200]. - pretrained: whether to initialize pretrained ImageNet weights, only available - for spatial_dims=2 and batch norm is used, default to False. + pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks + if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks + if spatial_dims=3 and in_channels=1. Default to False. decoder_channels: number of output channels for all feature maps in decoder. `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default to (256, 128, 64, 32, 16). diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 51dfd39b5f..99975271da 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -359,7 +359,7 @@ def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = Args: model_name: name of model to initialize, can be from [resnet10, ..., resnet200]. - pretrained: whether to initialize pretrained Med3D weights, + pretrained: whether to initialize pretrained MedicalNet weights, only available for spatial_dims=3 and in_channels=1. spatial_dims: number of spatial dimensions of the input image. in_channels: number of input channels for first convolutional layer. @@ -605,7 +605,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): """ - Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet + Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 @@ -661,7 +661,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat def get_medicalnet_pretrained_resnet_args(resnet_depth: int): """ Return correct shortcut_type and bias_downsample - for pretrained MedicalNet weights according to resnet depth + for pretrained MedicalNet weights according to resnet depth. """ # After testing # False: 10, 50, 101, 152, 200