From 6b98c8129e6d14eb725c7b95943b6581ffcf01a7 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Dec 2020 03:46:06 +0000 Subject: [PATCH 1/5] Working on DynUNet Torchscript compatibility Signed-off-by: Eric Kerfoot --- monai/networks/nets/dynunet.py | 314 +++++++++++++++++++++++++++++---- tests/test_dynunet.py | 34 +++- tests/utils.py | 11 +- 3 files changed, 308 insertions(+), 51 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index a70da683ba..a0975e4909 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -19,6 +19,249 @@ __all__ = ["DynUNet", "DynUnet", "Dynunet"] +# class DynUNet(nn.Module): +# """ +# This reimplementation of a dynamic UNet (DynUNet) is based on: +# `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. +# `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. + +# This model is more flexible compared with ``monai.networks.nets.UNet`` in three +# places: + +# - Residual connection is supported in conv blocks. +# - Anisotropic kernel sizes and strides can be used in each layers. +# - Deep supervision heads can be added. + +# The model supports 2D or 3D inputs and is consisted with four kinds of blocks: +# one input block, `n` downsample blocks, one bottleneck and `n+1` upsample blocks. Where, `n>0`. +# The first and last kernel and stride values of the input sequences are used for input block and +# bottleneck respectively, and the rest value(s) are used for downsample and upsample blocks. +# Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``) +# is no less than 3 in order to have at least one downsample upsample blocks. + +# Args: +# spatial_dims: number of spatial dimensions. +# in_channels: number of input channels. +# out_channels: number of output channels. +# kernel_size: convolution kernel size. +# strides: convolution strides for each blocks. +# upsample_kernel_size: convolution kernel size for transposed convolution layers. +# norm_name: [``"batch"``, ``"instance"``, ``"group"``] +# feature normalization type and arguments. +# deep_supervision: whether to add deep supervision head before output. Defaults to ``True``. +# If added, in training mode, the network will output not only the last feature maps +# (after being converted via output block), but also the previous feature maps that come +# from the intermediate up sample layers. +# deep_supr_num: number of feature maps that will output during deep supervision head. The +# value should be less than the number of up sample layers. Defaults to 1. +# res_block: whether to use residual connection based convolution blocks during the network. +# Defaults to ``True``. +# """ + +# def __init__( +# self, +# spatial_dims: int, +# in_channels: int, +# out_channels: int, +# kernel_size: Sequence[Union[Sequence[int], int]], +# strides: Sequence[Union[Sequence[int], int]], +# upsample_kernel_size: Sequence[Union[Sequence[int], int]], +# norm_name: str = "instance", +# deep_supervision: bool = True, +# deep_supr_num: int = 1, +# res_block: bool = False, +# ): +# super(DynUNet, self).__init__() +# self.spatial_dims = spatial_dims +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.kernel_size = kernel_size +# self.strides = strides +# self.upsample_kernel_size = upsample_kernel_size +# self.norm_name = norm_name +# self.deep_supervision = deep_supervision +# self.conv_block = UnetResBlock if res_block else UnetBasicBlock +# self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] +# self.input_block = self.get_input_block() +# self.downsamples = self.get_downsamples() +# self.bottleneck = self.get_bottleneck() +# self.upsamples = self.get_upsamples() +# self.output_block = self.get_output_block(0) +# self.deep_supervision_heads = self.get_deep_supervision_heads() +# self.deep_supr_num = deep_supr_num +# self.apply(self.initialize_weights) +# self.check_kernel_stride() +# self.check_deep_supr_num() + +# def check_kernel_stride(self): +# kernels, strides = self.kernel_size, self.strides +# error_msg = "length of kernel_size and strides should be the same, and no less than 3." +# assert len(kernels) == len(strides) and len(kernels) >= 3, error_msg + +# for idx in range(len(kernels)): +# kernel, stride = kernels[idx], strides[idx] +# if not isinstance(kernel, int): +# error_msg = "length of kernel_size in block {} should be the same as spatial_dims.".format(idx) +# assert len(kernel) == self.spatial_dims, error_msg +# if not isinstance(stride, int): +# error_msg = "length of stride in block {} should be the same as spatial_dims.".format(idx) +# assert len(stride) == self.spatial_dims, error_msg + +# def check_deep_supr_num(self): +# deep_supr_num, strides = self.deep_supr_num, self.strides +# num_up_layers = len(strides) - 1 +# error_msg = "deep_supr_num should be less than the number of up sample layers." +# assert 1 <= deep_supr_num < num_up_layers, error_msg + +# def forward(self, x): +# print("Input",self.input_block.conv1.conv) +# out = self.input_block(x) +# outputs = [out] + +# for downsample in self.downsamples: +# print("Down",downsample.conv1.conv) +# out = downsample(out) +# outputs.insert(0, out) + +# print("Bottleneck",self.bottleneck.conv1.conv) +# out = self.bottleneck(out) +# upsample_outs = [] + +# for upsample, skip in zip(self.upsamples, outputs): +# print("Upsample",upsample.transp_conv.conv,out.shape,skip.shape) +# out = upsample(out, skip) +# upsample_outs.append(out) + +# out = self.output_block(out) + +# if self.training and self.deep_supervision: +# start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num +# upsample_outs = upsample_outs[start_output_idx:-1][::-1] +# preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)] +# return [out] + preds + +# return out + +# def get_input_block(self): +# return self.conv_block( +# self.spatial_dims, +# self.in_channels, +# self.filters[0], +# self.kernel_size[0], +# self.strides[0], +# self.norm_name, +# ) + +# def get_bottleneck(self): +# return self.conv_block( +# self.spatial_dims, +# self.filters[-2], +# self.filters[-1], +# self.kernel_size[-1], +# self.strides[-1], +# self.norm_name, +# ) + +# def get_output_block(self, idx: int): +# return UnetOutBlock( +# self.spatial_dims, +# self.filters[idx], +# self.out_channels, +# ) + +# def get_downsamples(self): +# inp, out = self.filters[:-2], self.filters[1:-1] +# strides, kernel_size = self.strides[1:-1], self.kernel_size[1:-1] +# return self.get_module_list(inp, out, kernel_size, strides, self.conv_block) + +# def get_upsamples(self): +# inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] +# strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] +# upsample_kernel_size = self.upsample_kernel_size[::-1] +# return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size) + +# def get_module_list( +# self, +# in_channels: List[int], +# out_channels: List[int], +# kernel_size: Sequence[Union[Sequence[int], int]], +# strides: Sequence[Union[Sequence[int], int]], +# conv_block: nn.Module, +# upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None, +# ): +# layers = [] +# if upsample_kernel_size is not None: +# for in_c, out_c, kernel, stride, up_kernel in zip( +# in_channels, out_channels, kernel_size, strides, upsample_kernel_size +# ): +# params = { +# "spatial_dims": self.spatial_dims, +# "in_channels": in_c, +# "out_channels": out_c, +# "kernel_size": kernel, +# "stride": stride, +# "norm_name": self.norm_name, +# "upsample_kernel_size": up_kernel, +# } +# layer = conv_block(**params) +# layers.append(layer) +# else: +# for in_c, out_c, kernel, stride in zip(in_channels, out_channels, kernel_size, strides): +# params = { +# "spatial_dims": self.spatial_dims, +# "in_channels": in_c, +# "out_channels": out_c, +# "kernel_size": kernel, +# "stride": stride, +# "norm_name": self.norm_name, +# } +# layer = conv_block(**params) +# layers.append(layer) +# return nn.ModuleList(layers) + +# def get_deep_supervision_heads(self): +# return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) + +# @staticmethod +# def initialize_weights(module): +# name = module.__class__.__name__.lower() +# if "conv3d" in name or "conv2d" in name: +# nn.init.kaiming_normal_(module.weight, a=0.01) +# if module.bias is not None: +# nn.init.constant_(module.bias, 0) +# elif "norm" in name: +# nn.init.normal_(module.weight, 1.0, 0.02) +# nn.init.zeros_(module.bias) + + +# DynUnet = Dynunet = DynUNet + + +from typing import List + + +class DynUNetSkipLayer(nn.Module): + heads: List[torch.Tensor] + + def __init__(self, index, heads, downsample, upsample, superhead, nextlayer): + super().__init__() + self.downsample = downsample + self.upsample = upsample + self.nextlayer = nextlayer + self.superhead = superhead + self.heads = heads + self.index = index + + def forward(self, x): + downout = self.downsample(x) + nextout = self.nextlayer(downout) + upout = self.upsample(nextout, downout) + + self.heads[self.index] = self.superhead(upout) + + return upout + + class DynUNet(nn.Module): """ This reimplementation of a dynamic UNet (DynUNet) is based on: @@ -93,6 +336,31 @@ def __init__( self.check_kernel_stride() self.check_deep_supr_num() + self.heads: List[torch.tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) + + def create_skips(index, downsamples, upsamples, superheads, bottleneck): + assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}" + assert (len(downsamples) - len(superheads)) in (1, 0), f"{len(downsamples)}-(0,1) != {len(superheads)}" + + if len(downsamples) == 0: + return bottleneck + elif index == 0: # don't associate a supervision head with self.input_block + current_head, rest_heads = nn.Identity(), superheads + else: + current_head, rest_heads = superheads[0], superheads[1:] + + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) + + return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) + + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.deep_supervision_heads, + self.bottleneck, + ) + def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." @@ -114,56 +382,26 @@ def check_deep_supr_num(self): assert 1 <= deep_supr_num < num_up_layers, error_msg def forward(self, x): - out = self.input_block(x) - outputs = [out] - - for downsample in self.downsamples: - out = downsample(out) - outputs.insert(0, out) - - out = self.bottleneck(out) - upsample_outs = [] - - for upsample, skip in zip(self.upsamples, outputs): - out = upsample(out, skip) - upsample_outs.append(out) - + out = self.skip_layers(x) out = self.output_block(out) - + if self.training and self.deep_supervision: - start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num - upsample_outs = upsample_outs[start_output_idx:-1][::-1] - preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)] - return [out] + preds - - return out + return [out]+self.heads[1:self.deep_supr_num+1] + + return [out] def get_input_block(self): return self.conv_block( - self.spatial_dims, - self.in_channels, - self.filters[0], - self.kernel_size[0], - self.strides[0], - self.norm_name, + self.spatial_dims, self.in_channels, self.filters[0], self.kernel_size[0], self.strides[0], self.norm_name, ) def get_bottleneck(self): return self.conv_block( - self.spatial_dims, - self.filters[-2], - self.filters[-1], - self.kernel_size[-1], - self.strides[-1], - self.norm_name, + self.spatial_dims, self.filters[-2], self.filters[-1], self.kernel_size[-1], self.strides[-1], self.norm_name, ) def get_output_block(self, idx: int): - return UnetOutBlock( - self.spatial_dims, - self.filters[idx], - self.out_channels, - ) + return UnetOutBlock(self.spatial_dims, self.filters[idx], self.out_channels,) def get_downsamples(self): inp, out = self.filters[:-2], self.filters[1:-1] diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index ca5e056a16..c319791755 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,7 +17,7 @@ from monai.networks.nets import DynUNet -# from tests.utils import test_script_save +from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -111,14 +111,30 @@ def test_shape(self, input_param, input_shape, expected_shape): net.eval() with torch.no_grad(): result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - -# def test_script(self): -# input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] -# net = DynUNet(**input_param) -# test_data = torch.randn(input_shape) -# test_script_save(net, test_data) + self.assertEqual(result[0].shape, expected_shape) + + def test_print(self): + input_param={ + 'spatial_dims': 2, + 'in_channels': 2, + 'out_channels': 2, + 'kernel_size': (3, 1, 1), 'strides': (1, 1, 1), + 'upsample_kernel_size': (1,), + 'norm_name': 'batch', + 'deep_supervision': False, + 'res_block': False + } + input_shape=(1,2,64,64) + input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] + test_data = torch.randn(input_shape) + net = DynUNet(**input_param) + assert net(test_data) is not None + + def test_script(self): + input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] + net = DynUNet(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) class TestDynUNetDeepSupervision(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index 3ab73a4fcd..cd3244f783 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,7 +29,7 @@ import torch.distributed as dist from monai.data import create_test_image_2d, create_test_image_3d -from monai.utils import optional_import, set_determinism +from monai.utils import optional_import, set_determinism, ensure_tuple nib, _ = optional_import("nibabel") @@ -457,11 +457,14 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): result1 = net(*inputs) result2 = reloaded_net(*inputs) set_determinism(seed=None) + # When using e.g., VAR, we will produce a tuple of outputs. # Hence, convert all to tuples and then compare all elements. - if not isinstance(result1, tuple): - result1 = (result1,) - result2 = (result2,) + result1=ensure_tuple(result1) + result2=ensure_tuple(result2) +# if not isinstance(result1, tuple): +# result1 = (result1,) +# result2 = (result2,) for i, (r1, r2) in enumerate(zip(result1, result2)): if None not in (r1, r2): # might be None From 01de79f34a56bbe0cd820666d4a0b0e9ae1fc93f Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Dec 2020 03:48:27 +0000 Subject: [PATCH 2/5] Working on DynUNet Torchscript compatibility Signed-off-by: Eric Kerfoot --- monai/networks/nets/dynunet.py | 221 --------------------------------- tests/utils.py | 8 +- 2 files changed, 2 insertions(+), 227 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index a0975e4909..f9cf8e77d5 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -19,227 +19,6 @@ __all__ = ["DynUNet", "DynUnet", "Dynunet"] -# class DynUNet(nn.Module): -# """ -# This reimplementation of a dynamic UNet (DynUNet) is based on: -# `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. -# `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - -# This model is more flexible compared with ``monai.networks.nets.UNet`` in three -# places: - -# - Residual connection is supported in conv blocks. -# - Anisotropic kernel sizes and strides can be used in each layers. -# - Deep supervision heads can be added. - -# The model supports 2D or 3D inputs and is consisted with four kinds of blocks: -# one input block, `n` downsample blocks, one bottleneck and `n+1` upsample blocks. Where, `n>0`. -# The first and last kernel and stride values of the input sequences are used for input block and -# bottleneck respectively, and the rest value(s) are used for downsample and upsample blocks. -# Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``) -# is no less than 3 in order to have at least one downsample upsample blocks. - -# Args: -# spatial_dims: number of spatial dimensions. -# in_channels: number of input channels. -# out_channels: number of output channels. -# kernel_size: convolution kernel size. -# strides: convolution strides for each blocks. -# upsample_kernel_size: convolution kernel size for transposed convolution layers. -# norm_name: [``"batch"``, ``"instance"``, ``"group"``] -# feature normalization type and arguments. -# deep_supervision: whether to add deep supervision head before output. Defaults to ``True``. -# If added, in training mode, the network will output not only the last feature maps -# (after being converted via output block), but also the previous feature maps that come -# from the intermediate up sample layers. -# deep_supr_num: number of feature maps that will output during deep supervision head. The -# value should be less than the number of up sample layers. Defaults to 1. -# res_block: whether to use residual connection based convolution blocks during the network. -# Defaults to ``True``. -# """ - -# def __init__( -# self, -# spatial_dims: int, -# in_channels: int, -# out_channels: int, -# kernel_size: Sequence[Union[Sequence[int], int]], -# strides: Sequence[Union[Sequence[int], int]], -# upsample_kernel_size: Sequence[Union[Sequence[int], int]], -# norm_name: str = "instance", -# deep_supervision: bool = True, -# deep_supr_num: int = 1, -# res_block: bool = False, -# ): -# super(DynUNet, self).__init__() -# self.spatial_dims = spatial_dims -# self.in_channels = in_channels -# self.out_channels = out_channels -# self.kernel_size = kernel_size -# self.strides = strides -# self.upsample_kernel_size = upsample_kernel_size -# self.norm_name = norm_name -# self.deep_supervision = deep_supervision -# self.conv_block = UnetResBlock if res_block else UnetBasicBlock -# self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] -# self.input_block = self.get_input_block() -# self.downsamples = self.get_downsamples() -# self.bottleneck = self.get_bottleneck() -# self.upsamples = self.get_upsamples() -# self.output_block = self.get_output_block(0) -# self.deep_supervision_heads = self.get_deep_supervision_heads() -# self.deep_supr_num = deep_supr_num -# self.apply(self.initialize_weights) -# self.check_kernel_stride() -# self.check_deep_supr_num() - -# def check_kernel_stride(self): -# kernels, strides = self.kernel_size, self.strides -# error_msg = "length of kernel_size and strides should be the same, and no less than 3." -# assert len(kernels) == len(strides) and len(kernels) >= 3, error_msg - -# for idx in range(len(kernels)): -# kernel, stride = kernels[idx], strides[idx] -# if not isinstance(kernel, int): -# error_msg = "length of kernel_size in block {} should be the same as spatial_dims.".format(idx) -# assert len(kernel) == self.spatial_dims, error_msg -# if not isinstance(stride, int): -# error_msg = "length of stride in block {} should be the same as spatial_dims.".format(idx) -# assert len(stride) == self.spatial_dims, error_msg - -# def check_deep_supr_num(self): -# deep_supr_num, strides = self.deep_supr_num, self.strides -# num_up_layers = len(strides) - 1 -# error_msg = "deep_supr_num should be less than the number of up sample layers." -# assert 1 <= deep_supr_num < num_up_layers, error_msg - -# def forward(self, x): -# print("Input",self.input_block.conv1.conv) -# out = self.input_block(x) -# outputs = [out] - -# for downsample in self.downsamples: -# print("Down",downsample.conv1.conv) -# out = downsample(out) -# outputs.insert(0, out) - -# print("Bottleneck",self.bottleneck.conv1.conv) -# out = self.bottleneck(out) -# upsample_outs = [] - -# for upsample, skip in zip(self.upsamples, outputs): -# print("Upsample",upsample.transp_conv.conv,out.shape,skip.shape) -# out = upsample(out, skip) -# upsample_outs.append(out) - -# out = self.output_block(out) - -# if self.training and self.deep_supervision: -# start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num -# upsample_outs = upsample_outs[start_output_idx:-1][::-1] -# preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)] -# return [out] + preds - -# return out - -# def get_input_block(self): -# return self.conv_block( -# self.spatial_dims, -# self.in_channels, -# self.filters[0], -# self.kernel_size[0], -# self.strides[0], -# self.norm_name, -# ) - -# def get_bottleneck(self): -# return self.conv_block( -# self.spatial_dims, -# self.filters[-2], -# self.filters[-1], -# self.kernel_size[-1], -# self.strides[-1], -# self.norm_name, -# ) - -# def get_output_block(self, idx: int): -# return UnetOutBlock( -# self.spatial_dims, -# self.filters[idx], -# self.out_channels, -# ) - -# def get_downsamples(self): -# inp, out = self.filters[:-2], self.filters[1:-1] -# strides, kernel_size = self.strides[1:-1], self.kernel_size[1:-1] -# return self.get_module_list(inp, out, kernel_size, strides, self.conv_block) - -# def get_upsamples(self): -# inp, out = self.filters[1:][::-1], self.filters[:-1][::-1] -# strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1] -# upsample_kernel_size = self.upsample_kernel_size[::-1] -# return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size) - -# def get_module_list( -# self, -# in_channels: List[int], -# out_channels: List[int], -# kernel_size: Sequence[Union[Sequence[int], int]], -# strides: Sequence[Union[Sequence[int], int]], -# conv_block: nn.Module, -# upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None, -# ): -# layers = [] -# if upsample_kernel_size is not None: -# for in_c, out_c, kernel, stride, up_kernel in zip( -# in_channels, out_channels, kernel_size, strides, upsample_kernel_size -# ): -# params = { -# "spatial_dims": self.spatial_dims, -# "in_channels": in_c, -# "out_channels": out_c, -# "kernel_size": kernel, -# "stride": stride, -# "norm_name": self.norm_name, -# "upsample_kernel_size": up_kernel, -# } -# layer = conv_block(**params) -# layers.append(layer) -# else: -# for in_c, out_c, kernel, stride in zip(in_channels, out_channels, kernel_size, strides): -# params = { -# "spatial_dims": self.spatial_dims, -# "in_channels": in_c, -# "out_channels": out_c, -# "kernel_size": kernel, -# "stride": stride, -# "norm_name": self.norm_name, -# } -# layer = conv_block(**params) -# layers.append(layer) -# return nn.ModuleList(layers) - -# def get_deep_supervision_heads(self): -# return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)]) - -# @staticmethod -# def initialize_weights(module): -# name = module.__class__.__name__.lower() -# if "conv3d" in name or "conv2d" in name: -# nn.init.kaiming_normal_(module.weight, a=0.01) -# if module.bias is not None: -# nn.init.constant_(module.bias, 0) -# elif "norm" in name: -# nn.init.normal_(module.weight, 1.0, 0.02) -# nn.init.zeros_(module.bias) - - -# DynUnet = Dynunet = DynUNet - - -from typing import List - - class DynUNetSkipLayer(nn.Module): heads: List[torch.Tensor] diff --git a/tests/utils.py b/tests/utils.py index cd3244f783..0f753e7660 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -458,14 +458,10 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): result2 = reloaded_net(*inputs) set_determinism(seed=None) - # When using e.g., VAR, we will produce a tuple of outputs. - # Hence, convert all to tuples and then compare all elements. + # convert results to tuples if needed to allow iterating over pairs of outputs result1=ensure_tuple(result1) result2=ensure_tuple(result2) -# if not isinstance(result1, tuple): -# result1 = (result1,) -# result2 = (result2,) - + for i, (r1, r2) in enumerate(zip(result1, result2)): if None not in (r1, r2): # might be None np.testing.assert_allclose( From a23889df627f5e7d36e1024f106518e99f8ec8f0 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Dec 2020 03:48:48 +0000 Subject: [PATCH 3/5] Working on DynUNet Torchscript compatibility Signed-off-by: Eric Kerfoot --- tests/test_dynunet.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index c319791755..82a8e09113 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -113,23 +113,6 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result[0].shape, expected_shape) - def test_print(self): - input_param={ - 'spatial_dims': 2, - 'in_channels': 2, - 'out_channels': 2, - 'kernel_size': (3, 1, 1), 'strides': (1, 1, 1), - 'upsample_kernel_size': (1,), - 'norm_name': 'batch', - 'deep_supervision': False, - 'res_block': False - } - input_shape=(1,2,64,64) - input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] - test_data = torch.randn(input_shape) - net = DynUNet(**input_param) - assert net(test_data) is not None - def test_script(self): input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] net = DynUNet(**input_param) From 2c2a212564876b83869432b838fc81dae47bdadb Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Dec 2020 04:06:57 +0000 Subject: [PATCH 4/5] Working on DynUNet Torchscript compatibility Signed-off-by: Eric Kerfoot --- monai/networks/nets/dynunet.py | 30 ++++++++++++++++++++++-------- tests/test_dynunet.py | 1 - tests/utils.py | 10 +++++----- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index f9cf8e77d5..d5ea19329f 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -115,7 +115,7 @@ def __init__( self.check_kernel_stride() self.check_deep_supr_num() - self.heads: List[torch.tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) + self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) def create_skips(index, downsamples, upsamples, superheads, bottleneck): assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}" @@ -123,7 +123,7 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): if len(downsamples) == 0: return bottleneck - elif index == 0: # don't associate a supervision head with self.input_block + elif index == 0: # don't associate a supervision head with self.input_block current_head, rest_heads = nn.Identity(), superheads else: current_head, rest_heads = superheads[0], superheads[1:] @@ -163,24 +163,38 @@ def check_deep_supr_num(self): def forward(self, x): out = self.skip_layers(x) out = self.output_block(out) - + if self.training and self.deep_supervision: - return [out]+self.heads[1:self.deep_supr_num+1] - + return [out] + self.heads[1 : self.deep_supr_num + 1] + return [out] def get_input_block(self): return self.conv_block( - self.spatial_dims, self.in_channels, self.filters[0], self.kernel_size[0], self.strides[0], self.norm_name, + self.spatial_dims, + self.in_channels, + self.filters[0], + self.kernel_size[0], + self.strides[0], + self.norm_name, ) def get_bottleneck(self): return self.conv_block( - self.spatial_dims, self.filters[-2], self.filters[-1], self.kernel_size[-1], self.strides[-1], self.norm_name, + self.spatial_dims, + self.filters[-2], + self.filters[-1], + self.kernel_size[-1], + self.strides[-1], + self.norm_name, ) def get_output_block(self, idx: int): - return UnetOutBlock(self.spatial_dims, self.filters[idx], self.out_channels,) + return UnetOutBlock( + self.spatial_dims, + self.filters[idx], + self.out_channels, + ) def get_downsamples(self): inp, out = self.filters[:-2], self.filters[1:-1] diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 82a8e09113..6b89c8c4fd 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -16,7 +16,6 @@ from parameterized import parameterized from monai.networks.nets import DynUNet - from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/tests/utils.py b/tests/utils.py index 0f753e7660..6c717264ac 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,7 +29,7 @@ import torch.distributed as dist from monai.data import create_test_image_2d, create_test_image_3d -from monai.utils import optional_import, set_determinism, ensure_tuple +from monai.utils import ensure_tuple, optional_import, set_determinism nib, _ = optional_import("nibabel") @@ -457,11 +457,11 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4): result1 = net(*inputs) result2 = reloaded_net(*inputs) set_determinism(seed=None) - + # convert results to tuples if needed to allow iterating over pairs of outputs - result1=ensure_tuple(result1) - result2=ensure_tuple(result2) - + result1 = ensure_tuple(result1) + result2 = ensure_tuple(result2) + for i, (r1, r2) in enumerate(zip(result1, result2)): if None not in (r1, r2): # might be None np.testing.assert_allclose( From a1707faa2d466a278586b99c5edcc80ece5994eb Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Dec 2020 19:43:32 +0000 Subject: [PATCH 5/5] Working on DynUNet Torchscript compatibility Signed-off-by: Eric Kerfoot --- monai/networks/nets/dynunet.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index d5ea19329f..0915785db6 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -20,23 +20,32 @@ class DynUNetSkipLayer(nn.Module): + """ + Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection. + The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom the UNet + structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on + looping over lists of layers and accumulating lists of output tensors which much be indexed. The `heads` list is + shared amongst all the instances of this class and is used to store the output from the supervision heads during + forward passes of the network. + """ + heads: List[torch.Tensor] - def __init__(self, index, heads, downsample, upsample, superhead, nextlayer): + def __init__(self, index, heads, downsample, upsample, super_head, next_layer): super().__init__() self.downsample = downsample self.upsample = upsample - self.nextlayer = nextlayer - self.superhead = superhead + self.next_layer = next_layer + self.super_head = super_head self.heads = heads self.index = index def forward(self, x): downout = self.downsample(x) - nextout = self.nextlayer(downout) + nextout = self.next_layer(downout) upout = self.upsample(nextout, downout) - self.heads[self.index] = self.superhead(upout) + self.heads[self.index] = self.super_head(upout) return upout @@ -115,19 +124,31 @@ def __init__( self.check_kernel_stride() self.check_deep_supr_num() + # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1) def create_skips(index, downsamples, upsamples, superheads, bottleneck): + """ + Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is + done recursively from the top down since a recursive nn.Module subclass is being used to be compatible + with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads` + since the `input_block` is passed to this function as the first item in `downsamples`, however this + shouldn't be associated with a supervision head. + """ + assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}" assert (len(downsamples) - len(superheads)) in (1, 0), f"{len(downsamples)}-(0,1) != {len(superheads)}" - if len(downsamples) == 0: + if len(downsamples) == 0: # bottom of the network, pass the bottleneck block return bottleneck elif index == 0: # don't associate a supervision head with self.input_block current_head, rest_heads = nn.Identity(), superheads + elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one + current_head, rest_heads = nn.Identity(), superheads[1:] else: current_head, rest_heads = superheads[0], superheads[1:] + # create the next layer down, this will stop at the bottleneck layer next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)