Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions monai/networks/nets/dynunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ class DynUNet(nn.Module):
upsample_kernel_size: convolution kernel size for transposed convolution layers.
norm_name: [``"batch"``, ``"instance"``, ``"group"``]
feature normalization type and arguments.
deep_supervision: whether to add deep supervision head before output. Defaults to ``True``.
If added, in training mode, the network will output not only the last feature maps
(after being converted via output block), but also the previous feature maps that come
from the intermediate up sample layers.
deep_supr_num: number of feature maps that will output during deep supervision head. The
value should be less than the number of up sample layers. Defaults to 1.
res_block: whether to use residual connection based convolution blocks during the network.
Expand All @@ -98,7 +94,6 @@ def __init__(
strides: Sequence[Union[Sequence[int], int]],
upsample_kernel_size: Sequence[Union[Sequence[int], int]],
norm_name: str = "instance",
deep_supervision: bool = True,
deep_supr_num: int = 1,
res_block: bool = False,
):
Expand All @@ -110,7 +105,6 @@ def __init__(
self.strides = strides
self.upsample_kernel_size = upsample_kernel_size
self.norm_name = norm_name
self.deep_supervision = deep_supervision
self.conv_block = UnetResBlock if res_block else UnetBasicBlock
self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))]
self.input_block = self.get_input_block()
Expand Down Expand Up @@ -143,8 +137,6 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
return bottleneck
elif index == 0: # don't associate a supervision head with self.input_block
current_head, rest_heads = nn.Identity(), superheads
elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one
current_head, rest_heads = nn.Identity(), superheads[1:]
else:
current_head, rest_heads = superheads[0], superheads[1:]

Expand Down Expand Up @@ -183,12 +175,14 @@ def check_deep_supr_num(self):

def forward(self, x):
out = self.skip_layers(x)
out = self.output_block(out)
return self.output_block(out)

if self.training and self.deep_supervision:
return [out] + self.heads[1 : self.deep_supr_num + 1]
def get_feature_maps(self):
"""
Return the feature maps.

return [out]
"""
return self.heads[1 : self.deep_supr_num + 1]

def get_input_block(self):
return self.conv_block(
Expand Down
7 changes: 2 additions & 5 deletions tests/test_dynunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
"strides": strides,
"upsample_kernel_size": strides[1:],
"norm_name": "batch",
"deep_supervision": False,
"res_block": res_block,
},
(1, in_channels, in_size, in_size),
Expand All @@ -66,7 +65,6 @@
"strides": ((1, 2, 1), 2, 2, 1),
"upsample_kernel_size": (2, 2, 1),
"norm_name": "instance",
"deep_supervision": False,
"res_block": res_block,
},
(1, in_channels, in_size, in_size, in_size),
Expand All @@ -88,7 +86,6 @@
"strides": strides,
"upsample_kernel_size": strides[1:],
"norm_name": "group",
"deep_supervision": True,
"deep_supr_num": deep_supr_num,
"res_block": res_block,
},
Expand All @@ -110,7 +107,7 @@ def test_shape(self, input_param, input_shape, expected_shape):
net = DynUNet(**input_param).to(device)
with eval_mode(net):
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result.shape, expected_shape)

def test_script(self):
input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0]
Expand All @@ -124,7 +121,7 @@ class TestDynUNetDeepSupervision(unittest.TestCase):
def test_shape(self, input_param, input_shape, expected_shape):
net = DynUNet(**input_param).to(device)
with torch.no_grad():
results = net(torch.randn(input_shape).to(device))
results = [net(torch.randn(input_shape).to(device))] + net.get_feature_maps()
self.assertEqual(len(results), len(expected_shape))
for idx in range(len(results)):
result, sub_expected_shape = results[idx], expected_shape[idx]
Expand Down