From 202a4e70006f1ba7edb6a449cfc93eca823083db Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 1 Dec 2021 21:07:21 +0800 Subject: [PATCH 1/5] enhance dynunet Signed-off-by: Yiheng Wang --- monai/networks/nets/dynunet.py | 100 ++++++++++++++++++------------ tests/test_network_consistency.py | 8 ++- 2 files changed, 66 insertions(+), 42 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 4cd3046261..c09ee099dd 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -31,13 +31,13 @@ class DynUNetSkipLayer(nn.Module): forward passes of the network. """ - heads: List[torch.Tensor] + heads: Optional[List[torch.Tensor]] - def __init__(self, index, heads, downsample, upsample, super_head, next_layer): + def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None): super().__init__() self.downsample = downsample - self.upsample = upsample self.next_layer = next_layer + self.upsample = upsample self.super_head = super_head self.heads = heads self.index = index @@ -46,8 +46,8 @@ def forward(self, x): downout = self.downsample(x) nextout = self.next_layer(downout) upout = self.upsample(nextout, downout) - - self.heads[self.index] = self.super_head(upout) + if self.super_head is not None and self.heads is not None and self.index > 0: + self.heads[self.index - 1] = self.super_head(upout) return upout @@ -58,30 +58,24 @@ class DynUNet(nn.Module): `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. `Optimized U-Net for Brain Tumor Segmentation `_. - This model is more flexible compared with ``monai.networks.nets.UNet`` in three places: - - Residual connection is supported in conv blocks. - Anisotropic kernel sizes and strides can be used in each layers. - Deep supervision heads can be added. - The model supports 2D or 3D inputs and is consisted with four kinds of blocks: one input block, `n` downsample blocks, one bottleneck and `n+1` upsample blocks. Where, `n>0`. The first and last kernel and stride values of the input sequences are used for input block and bottleneck respectively, and the rest value(s) are used for downsample and upsample blocks. Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``) is no less than 3 in order to have at least one downsample and upsample blocks. - To meet the requirements of the structure, the input size for each spatial dimension should be divisible by `2 * the product of all strides in the corresponding dimension`. The output size for each spatial dimension equals to the input size of the corresponding dimension divided by the stride in strides[0]. For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and the spatial size of the output is `(8, 8, 8)`. - Usage example with medical segmentation decathlon dataset is available at: https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline. - Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. @@ -111,7 +105,6 @@ class DynUNet(nn.Module): When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. (To be added: a corresponding tutorial link) - deep_supr_num: number of feature maps that will output during deep supervision head. The value should be larger than 0 and less than the number of up sample layers. Defaults to 1. @@ -160,16 +153,17 @@ def __init__( self.upsamples = self.get_upsamples() self.output_block = self.get_output_block(0) self.deep_supervision = deep_supervision - self.deep_supervision_heads = self.get_deep_supervision_heads() self.deep_supr_num = deep_supr_num + # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on + self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num + if self.deep_supervision: + self.deep_supervision_heads = self.get_deep_supervision_heads() + self.check_deep_supr_num() + self.apply(self.initialize_weights) self.check_kernel_stride() - self.check_deep_supr_num() - # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on - self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) - - def create_skips(index, downsamples, upsamples, superheads, bottleneck): + def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): """ Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is done recursively from the top down since a recursive nn.Module subclass is being used to be compatible @@ -180,30 +174,55 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): if len(downsamples) != len(upsamples): raise ValueError(f"{len(downsamples)} != {len(upsamples)}") - if (len(downsamples) - len(superheads)) not in (1, 0): - raise ValueError(f"{len(downsamples)}-(0,1) != {len(superheads)}") if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck - if index == 0: # don't associate a supervision head with self.input_block - current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] - else: - current_head, rest_heads = superheads[0], superheads[1:] - # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) - - return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) - - self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.deep_supervision_heads, - self.bottleneck, - ) + if superheads is None: + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck) + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + else: + super_head_flag = False + if index == 0: # don't associate a supervision head with self.input_block + rest_heads = superheads + else: + if len(superheads) > 0: + super_head_flag = True + rest_heads = superheads[1:] + else: + rest_heads = nn.ModuleList() + + # create the next layer down, this will stop at the bottleneck layer + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, rest_heads) + if super_head_flag: + return DynUNetSkipLayer( + index, + downsample=downsamples[0], + upsample=upsamples[0], + next_layer=next_layer, + heads=self.heads, + super_head=superheads[0], + ) + else: + return DynUNetSkipLayer( + index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer + ) + + if not self.deep_supervision: + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.bottleneck, + ) + else: + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.bottleneck, + self.deep_supervision_heads, + ) def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides @@ -242,8 +261,7 @@ def forward(self, x): out = self.output_block(out) if self.training and self.deep_supervision: out_all = [out] - feature_maps = self.heads[1 : self.deep_supr_num + 1] - for feature_map in feature_maps: + for feature_map in self.heads: out_all.append(interpolate(feature_map, out.shape[2:])) return torch.stack(out_all, dim=1) return out @@ -334,7 +352,7 @@ def get_module_list( return nn.ModuleList(layers) def get_deep_supervision_heads(self): - return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) + return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)]) @staticmethod def initialize_weights(module): diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index ccccd9e7f0..7feb4ca713 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -61,7 +61,13 @@ def test_network_consistency(self, net_name, data_path, json_path): # Create model model = nets.__dict__[net_name](**model_params) - model.load_state_dict(loaded_data["model"]) + state_dict = loaded_data["model"] + model_dict = model.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + model.load_state_dict(model_dict) model.eval() in_data = loaded_data["in_data"] From c6e7e5272084ddd2bc4beaf53d38813aa908c5c5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 1 Dec 2021 21:51:55 +0800 Subject: [PATCH 2/5] fix black issue Signed-off-by: Yiheng Wang --- monai/networks/nets/dynunet.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index c09ee099dd..93cd9a2be1 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -58,24 +58,30 @@ class DynUNet(nn.Module): `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. `Optimized U-Net for Brain Tumor Segmentation `_. + This model is more flexible compared with ``monai.networks.nets.UNet`` in three places: + - Residual connection is supported in conv blocks. - Anisotropic kernel sizes and strides can be used in each layers. - Deep supervision heads can be added. + The model supports 2D or 3D inputs and is consisted with four kinds of blocks: one input block, `n` downsample blocks, one bottleneck and `n+1` upsample blocks. Where, `n>0`. The first and last kernel and stride values of the input sequences are used for input block and bottleneck respectively, and the rest value(s) are used for downsample and upsample blocks. Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``) is no less than 3 in order to have at least one downsample and upsample blocks. + To meet the requirements of the structure, the input size for each spatial dimension should be divisible by `2 * the product of all strides in the corresponding dimension`. The output size for each spatial dimension equals to the input size of the corresponding dimension divided by the stride in strides[0]. For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and the spatial size of the output is `(8, 8, 8)`. + Usage example with medical segmentation decathlon dataset is available at: https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline. + Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. @@ -210,10 +216,7 @@ def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): if not self.deep_supervision: self.skip_layers = create_skips( - 0, - [self.input_block] + list(self.downsamples), - self.upsamples[::-1], - self.bottleneck, + 0, [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck ) else: self.skip_layers = create_skips( From 4dbd660dabf8b76fbf98f539345edbf96e8b3964 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 6 Dec 2021 17:46:32 +0800 Subject: [PATCH 3/5] use strict=False Signed-off-by: Yiheng Wang --- monai/networks/nets/dynunet.py | 2 ++ tests/test_network_consistency.py | 10 ++-------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 93cd9a2be1..10e696cb01 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -79,6 +79,8 @@ class DynUNet(nn.Module): For example, if `strides=((1, 2, 4), 2, 1, 1)`, the minimal spatial size of the input is `(8, 16, 32)`, and the spatial size of the output is `(8, 8, 8)`. + For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`. + Usage example with medical segmentation decathlon dataset is available at: https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline. diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index 7feb4ca713..c3e87cd099 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -60,14 +60,8 @@ def test_network_consistency(self, net_name, data_path, json_path): json_file.close() # Create model - model = nets.__dict__[net_name](**model_params) - state_dict = loaded_data["model"] - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) + model = getattr(nets, net_name)(**model_params) + model.load_state_dict(loaded_data["model"], strict=False) model.eval() in_data = loaded_data["in_data"] From 3bc888db3d3823152b39daaa4848b0863d0acabb Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 6 Dec 2021 18:35:54 +0800 Subject: [PATCH 4/5] fix black 21.12 error Signed-off-by: Yiheng Wang --- monai/networks/blocks/activation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py index b136eb7f1f..9b58be04e8 100644 --- a/monai/networks/blocks/activation.py +++ b/monai/networks/blocks/activation.py @@ -19,7 +19,6 @@ def monai_mish(x, inplace: bool = False): return torch.nn.functional.mish(x, inplace=inplace) - else: def monai_mish(x, inplace: bool = False): @@ -31,7 +30,6 @@ def monai_mish(x, inplace: bool = False): def monai_swish(x, inplace: bool = False): return torch.nn.functional.silu(x, inplace=inplace) - else: def monai_swish(x, inplace: bool = False): From 2efe2a1407fc83483d7378197e82c3ea50fb4941 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 7 Dec 2021 17:08:11 +0800 Subject: [PATCH 5/5] enhance code and update docstring Signed-off-by: Yiheng Wang --- monai/networks/nets/dynunet.py | 61 +++++++++++++++---------------- tests/test_network_consistency.py | 2 +- 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 10e696cb01..08938bb3bd 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -102,17 +102,16 @@ class DynUNet(nn.Module): norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. - If ``True``, in training mode, the forward function will output not only the last feature - map, but also the previous feature maps that come from the intermediate up sample layers. + If ``True``, in training mode, the forward function will output not only the final feature map + (from `output_block`), but also the feature maps that come from the intermediate up sample layers. In order to unify the return type (the restriction of TorchScript), all intermediate - feature maps are interpolated into the same size as the last feature map and stacked together + feature maps are interpolated into the same size as the final feature map and stacked together (with a new dimension in the first axis)into one single tensor. - For instance, if there are three feature maps with shapes: (1, 2, 32, 24), (1, 2, 16, 12) and - (1, 2, 8, 6). The last two will be interpolated into (1, 2, 32, 24), and the stacked tensor - will has the shape (1, 3, 2, 8, 6). + For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and + (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps + will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24). When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss one by one with the ground truth, then do a weighted average for all losses to achieve the final loss. - (To be added: a corresponding tutorial link) deep_supr_num: number of feature maps that will output during deep supervision head. The value should be larger than 0 and less than the number of up sample layers. Defaults to 1. @@ -189,32 +188,30 @@ def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): if superheads is None: next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck) return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) + + super_head_flag = False + if index == 0: # don't associate a supervision head with self.input_block + rest_heads = superheads else: - super_head_flag = False - if index == 0: # don't associate a supervision head with self.input_block - rest_heads = superheads - else: - if len(superheads) > 0: - super_head_flag = True - rest_heads = superheads[1:] - else: - rest_heads = nn.ModuleList() - - # create the next layer down, this will stop at the bottleneck layer - next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, rest_heads) - if super_head_flag: - return DynUNetSkipLayer( - index, - downsample=downsamples[0], - upsample=upsamples[0], - next_layer=next_layer, - heads=self.heads, - super_head=superheads[0], - ) + if len(superheads) > 0: + super_head_flag = True + rest_heads = superheads[1:] else: - return DynUNetSkipLayer( - index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer - ) + rest_heads = nn.ModuleList() + + # create the next layer down, this will stop at the bottleneck layer + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads) + if super_head_flag: + return DynUNetSkipLayer( + index, + downsample=downsamples[0], + upsample=upsamples[0], + next_layer=next_layer, + heads=self.heads, + super_head=superheads[0], + ) + + return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer) if not self.deep_supervision: self.skip_layers = create_skips( @@ -226,7 +223,7 @@ def create_skips(index, downsamples, upsamples, bottleneck, superheads=None): [self.input_block] + list(self.downsamples), self.upsamples[::-1], self.bottleneck, - self.deep_supervision_heads, + superheads=self.deep_supervision_heads, ) def check_kernel_stride(self): diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py index c3e87cd099..92cb7a0595 100644 --- a/tests/test_network_consistency.py +++ b/tests/test_network_consistency.py @@ -22,7 +22,7 @@ import monai.networks.nets as nets from monai.utils import set_determinism -extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA", None) +extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA") TESTS = [] if extra_test_data_dir is not None: