diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 0915785db6..5c241903fa 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -79,10 +79,6 @@ class DynUNet(nn.Module): upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: [``"batch"``, ``"instance"``, ``"group"``] feature normalization type and arguments. - deep_supervision: whether to add deep supervision head before output. Defaults to ``True``. - If added, in training mode, the network will output not only the last feature maps - (after being converted via output block), but also the previous feature maps that come - from the intermediate up sample layers. deep_supr_num: number of feature maps that will output during deep supervision head. The value should be less than the number of up sample layers. Defaults to 1. res_block: whether to use residual connection based convolution blocks during the network. @@ -98,7 +94,6 @@ def __init__( strides: Sequence[Union[Sequence[int], int]], upsample_kernel_size: Sequence[Union[Sequence[int], int]], norm_name: str = "instance", - deep_supervision: bool = True, deep_supr_num: int = 1, res_block: bool = False, ): @@ -110,7 +105,6 @@ def __init__( self.strides = strides self.upsample_kernel_size = upsample_kernel_size self.norm_name = norm_name - self.deep_supervision = deep_supervision self.conv_block = UnetResBlock if res_block else UnetBasicBlock self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))] self.input_block = self.get_input_block() @@ -143,8 +137,6 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck): return bottleneck elif index == 0: # don't associate a supervision head with self.input_block current_head, rest_heads = nn.Identity(), superheads - elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one - current_head, rest_heads = nn.Identity(), superheads[1:] else: current_head, rest_heads = superheads[0], superheads[1:] @@ -183,12 +175,14 @@ def check_deep_supr_num(self): def forward(self, x): out = self.skip_layers(x) - out = self.output_block(out) + return self.output_block(out) - if self.training and self.deep_supervision: - return [out] + self.heads[1 : self.deep_supr_num + 1] + def get_feature_maps(self): + """ + Return the feature maps. - return [out] + """ + return self.heads[1 : self.deep_supr_num + 1] def get_input_block(self): return self.conv_block( diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 565188df80..101f490bfe 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -43,7 +43,6 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "batch", - "deep_supervision": False, "res_block": res_block, }, (1, in_channels, in_size, in_size), @@ -66,7 +65,6 @@ "strides": ((1, 2, 1), 2, 2, 1), "upsample_kernel_size": (2, 2, 1), "norm_name": "instance", - "deep_supervision": False, "res_block": res_block, }, (1, in_channels, in_size, in_size, in_size), @@ -88,7 +86,6 @@ "strides": strides, "upsample_kernel_size": strides[1:], "norm_name": "group", - "deep_supervision": True, "deep_supr_num": deep_supr_num, "res_block": res_block, }, @@ -110,7 +107,7 @@ def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) with eval_mode(net): result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result[0].shape, expected_shape) + self.assertEqual(result.shape, expected_shape) def test_script(self): input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] @@ -124,7 +121,7 @@ class TestDynUNetDeepSupervision(unittest.TestCase): def test_shape(self, input_param, input_shape, expected_shape): net = DynUNet(**input_param).to(device) with torch.no_grad(): - results = net(torch.randn(input_shape).to(device)) + results = [net(torch.randn(input_shape).to(device))] + net.get_feature_maps() self.assertEqual(len(results), len(expected_shape)) for idx in range(len(results)): result, sub_expected_shape = results[idx], expected_shape[idx]