diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py
index a70da683ba..0915785db6 100644
--- a/monai/networks/nets/dynunet.py
+++ b/monai/networks/nets/dynunet.py
@@ -19,6 +19,37 @@
 __all__ = ["DynUNet", "DynUnet", "Dynunet"]
 
 
+class DynUNetSkipLayer(nn.Module):
+    """
+    Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection.
+    The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom of the
+    UNet structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on
+    looping over lists of layers and accumulating lists of output tensors which must be indexed. The `heads` list is
+    shared amongst all the instances of this class and is used to store the output from the supervision heads during
+    forward passes of the network.
+    """
+
+    heads: List[torch.Tensor]
+
+    def __init__(self, index, heads, downsample, upsample, super_head, next_layer):
+        super().__init__()
+        self.downsample = downsample
+        self.upsample = upsample
+        self.next_layer = next_layer
+        self.super_head = super_head
+        self.heads = heads
+        self.index = index
+
+    def forward(self, x):
+        downout = self.downsample(x)
+        nextout = self.next_layer(downout)
+        upout = self.upsample(nextout, downout)
+
+        self.heads[self.index] = self.super_head(upout)
+
+        return upout
+
+
 class DynUNet(nn.Module):
     """
     This reimplementation of a dynamic UNet (DynUNet) is based on:
@@ -93,6 +124,43 @@ def __init__(
         self.check_kernel_stride()
         self.check_deep_supr_num()
 
+        # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
+        self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)
+
+        def create_skips(index, downsamples, upsamples, superheads, bottleneck):
+            """
+            Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
+            done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
+            with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads`,
+            since the `input_block` is passed to this function as the first item in `downsamples`; however, this
+            block shouldn't be associated with a supervision head.
+ """ + + assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}" + assert (len(downsamples) - len(superheads)) in (1, 0), f"{len(downsamples)}-(0,1) != {len(superheads)}" + + if len(downsamples) == 0: # bottom of the network, pass the bottleneck block + return bottleneck + elif index == 0: # don't associate a supervision head with self.input_block + current_head, rest_heads = nn.Identity(), superheads + elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one + current_head, rest_heads = nn.Identity(), superheads[1:] + else: + current_head, rest_heads = superheads[0], superheads[1:] + + # create the next layer down, this will stop at the bottleneck layer + next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck) + + return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer) + + self.skip_layers = create_skips( + 0, + [self.input_block] + list(self.downsamples), + self.upsamples[::-1], + self.deep_supervision_heads, + self.bottleneck, + ) + def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." @@ -114,29 +182,13 @@ def check_deep_supr_num(self): assert 1 <= deep_supr_num < num_up_layers, error_msg def forward(self, x): - out = self.input_block(x) - outputs = [out] - - for downsample in self.downsamples: - out = downsample(out) - outputs.insert(0, out) - - out = self.bottleneck(out) - upsample_outs = [] - - for upsample, skip in zip(self.upsamples, outputs): - out = upsample(out, skip) - upsample_outs.append(out) - + out = self.skip_layers(x) out = self.output_block(out) if self.training and self.deep_supervision: - start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num - upsample_outs = upsample_outs[start_output_idx:-1][::-1] - preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)] - return [out] + preds + return [out] + self.heads[1 : self.deep_supr_num + 1] - return out + return [out] def get_input_block(self): return self.conv_block( diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index ca5e056a16..6b89c8c4fd 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -16,8 +16,7 @@ from parameterized import parameterized from monai.networks.nets import DynUNet - -# from tests.utils import test_script_save +from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -111,14 +110,13 @@ def test_shape(self, input_param, input_shape, expected_shape): net.eval() with torch.no_grad(): result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - + self.assertEqual(result[0].shape, expected_shape) -# def test_script(self): -# input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] -# net = DynUNet(**input_param) -# test_data = torch.randn(input_shape) -# test_script_save(net, test_data) + def test_script(self): + input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] + net = DynUNet(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) class TestDynUNetDeepSupervision(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index 3ab73a4fcd..6c717264ac 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,7 +29,7 @@ import torch.distributed as dist from monai.data import create_test_image_2d, create_test_image_3d -from monai.utils import optional_import, 
+from monai.utils import ensure_tuple, optional_import, set_determinism
 
 nib, _ = optional_import("nibabel")
 
@@ -457,11 +457,10 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4):
     result1 = net(*inputs)
     result2 = reloaded_net(*inputs)
     set_determinism(seed=None)
-    # When using e.g., VAR, we will produce a tuple of outputs.
-    # Hence, convert all to tuples and then compare all elements.
-    if not isinstance(result1, tuple):
-        result1 = (result1,)
-        result2 = (result2,)
+
+    # convert results to tuples if needed to allow iterating over pairs of outputs
+    result1 = ensure_tuple(result1)
+    result2 = ensure_tuple(result2)
 
     for i, (r1, r2) in enumerate(zip(result1, result2)):
         if None not in (r1, r2):  # might be None
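
As a quick sanity check of the new Torchscript path, something along these lines should work. This is a minimal sketch, not part of the patch: the constructor arguments are illustrative assumptions (chosen to satisfy the kernel/stride and deep_supr_num checks above), not values taken from TEST_CASE_DYNUNET_2D, and they follow my reading of the DynUNet signature at this revision.

    import torch

    from monai.networks.nets import DynUNet

    # assumed configuration: kernel_size and strides have equal length (>= 3),
    # upsample_kernel_size has one entry fewer, and 1 <= deep_supr_num < number of upsample layers
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        kernel_size=[3, 3, 3],
        strides=[1, 2, 2],
        upsample_kernel_size=[2, 2],
        deep_supervision=True,
        deep_supr_num=1,
    )

    # previously failed; the recursive DynUNetSkipLayer structure is what makes this scriptable
    scripted = torch.jit.script(net)

    net.eval()
    scripted.eval()
    x = torch.randn(1, 1, 64, 64)
    with torch.no_grad():
        out = net(x)[0]        # forward now returns a list; element 0 is the full-resolution prediction
        out_s = scripted(x)[0]
    print(out.shape, out_s.shape)  # both should be torch.Size([1, 2, 64, 64]) for this configuration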