90 changes: 71 additions & 19 deletions monai/networks/nets/dynunet.py
@@ -19,6 +19,37 @@
__all__ = ["DynUNet", "DynUnet", "Dynunet"]


class DynUNetSkipLayer(nn.Module):
"""
Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection.
The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom of the UNet
structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on
looping over lists of layers and accumulating lists of output tensors which must be indexed. The `heads` list is
shared amongst all the instances of this class and is used to store the output from the supervision heads during
forward passes of the network.
"""

heads: List[torch.Tensor]

def __init__(self, index, heads, downsample, upsample, super_head, next_layer):
super().__init__()
self.downsample = downsample
self.upsample = upsample
self.next_layer = next_layer
self.super_head = super_head
self.heads = heads
self.index = index

def forward(self, x):
downout = self.downsample(x)
nextout = self.next_layer(downout)
upout = self.upsample(nextout, downout)

self.heads[self.index] = self.super_head(upout)

return upout
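
As an aside, a minimal toy sketch of this recursive pattern (hypothetical names, not part of the PR): each layer writes its output into a shared, type-annotated list instead of returning an accumulated list of tensors, which is what keeps the structure TorchScript-compatible.

from typing import List

import torch
import torch.nn as nn


class ToySkip(nn.Module):
    # class-level annotation so TorchScript knows the element type of the shared list
    heads: List[torch.Tensor]

    def __init__(self, index: int, heads: List[torch.Tensor], next_layer: nn.Module):
        super().__init__()
        self.index = index
        self.heads = heads
        self.next_layer = next_layer

    def forward(self, x):
        out = self.next_layer(x)
        self.heads[self.index] = out  # stash this level's output rather than building a result list
        return out


heads: List[torch.Tensor] = [torch.zeros(1), torch.zeros(1)]
net = ToySkip(0, heads, ToySkip(1, heads, nn.Identity()))
scripted = torch.jit.script(net)  # compiles because no tensor list is accumulated and indexed in a loop
print(scripted(torch.rand(2, 3)).shape)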


class DynUNet(nn.Module):
"""
This reimplementation of a dynamic UNet (DynUNet) is based on:
@@ -93,6 +124,43 @@ def __init__(
self.check_kernel_stride()
self.check_deep_supr_num()

# initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)

def create_skips(index, downsamples, upsamples, superheads, bottleneck):
"""
Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads`
since the `input_block` is passed to this function as the first item in `downsamples`; however, the input block
shouldn't be associated with a supervision head.
"""

assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}"
assert (len(downsamples) - len(superheads)) in (1, 0), f"{len(downsamples)}-(0,1) != {len(superheads)}"

if len(downsamples) == 0: # bottom of the network, pass the bottleneck block
return bottleneck
elif index == 0: # don't associate a supervision head with self.input_block
current_head, rest_heads = nn.Identity(), superheads
elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one
current_head, rest_heads = nn.Identity(), superheads[1:]
else:
current_head, rest_heads = superheads[0], superheads[1:]

# create the next layer down, this will stop at the bottleneck layer
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck)

return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)

self.skip_layers = create_skips(
0,
[self.input_block] + list(self.downsamples),
self.upsamples[::-1],
self.deep_supervision_heads,
self.bottleneck,
)
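
For orientation, an illustrative unrolling (not code from the PR): with an input block, two further downsample stages, and three upsample stages, the recursion above nests roughly as shown below. Index 0 never receives a real supervision head, the real heads are likewise replaced by nn.Identity() when deep supervision is disabled, and the bottleneck terminates the chain as the innermost `next_layer`.

# skip_layers = DynUNetSkipLayer(0, heads, input_block,     upsamples[2], nn.Identity(),
#                 DynUNetSkipLayer(1, heads, downsamples[0], upsamples[1], deep_supervision_heads[0],
#                   DynUNetSkipLayer(2, heads, downsamples[1], upsamples[0], deep_supervision_heads[1],
#                     bottleneck)))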

def check_kernel_stride(self):
kernels, strides = self.kernel_size, self.strides
error_msg = "length of kernel_size and strides should be the same, and no less than 3."
@@ -114,29 +182,13 @@ def check_deep_supr_num(self):
assert 1 <= deep_supr_num < num_up_layers, error_msg

def forward(self, x):
out = self.input_block(x)
outputs = [out]

for downsample in self.downsamples:
out = downsample(out)
outputs.insert(0, out)

out = self.bottleneck(out)
upsample_outs = []

for upsample, skip in zip(self.upsamples, outputs):
out = upsample(out, skip)
upsample_outs.append(out)

out = self.skip_layers(x)
out = self.output_block(out)

if self.training and self.deep_supervision:
start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num
upsample_outs = upsample_outs[start_output_idx:-1][::-1]
preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)]
return [out] + preds
return [out] + self.heads[1 : self.deep_supr_num + 1]

return out
return [out]
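
Since the network now always returns a list (the full-resolution prediction first, followed by any deep-supervision outputs while training), a caller might consume it roughly as in this sketch. The 0.5 ** i weighting and the nearest-neighbour resizing of the target to each auxiliary output are illustrative assumptions rather than anything this PR prescribes, and net, images, labels and loss_fn are placeholder names.

import torch.nn.functional as F

outputs = net(images)                  # [full-resolution prediction, *deep-supervision outputs]
loss = loss_fn(outputs[0], labels)
for i, aux in enumerate(outputs[1:], start=1):
    # auxiliary outputs may be at reduced spatial size, so match the (float, e.g. one-hot) target to each
    target = F.interpolate(labels, size=aux.shape[2:], mode="nearest")
    loss = loss + (0.5 ** i) * loss_fn(aux, target)
loss.backward()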

def get_input_block(self):
return self.conv_block(
16 changes: 7 additions & 9 deletions tests/test_dynunet.py
@@ -16,8 +16,7 @@
from parameterized import parameterized

from monai.networks.nets import DynUNet

# from tests.utils import test_script_save
from tests.utils import test_script_save

device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -111,14 +110,13 @@ def test_shape(self, input_param, input_shape, expected_shape):
net.eval()
with torch.no_grad():
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

self.assertEqual(result[0].shape, expected_shape)

# def test_script(self):
# input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0]
# net = DynUNet(**input_param)
# test_data = torch.randn(input_shape)
# test_script_save(net, test_data)
def test_script(self):
input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0]
net = DynUNet(**input_param)
test_data = torch.randn(input_shape)
test_script_save(net, test_data)
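
Roughly, the re-enabled helper compiles the network with TorchScript, round-trips it through serialization, and checks that the reloaded module reproduces the original outputs. A simplified sketch (eval-mode handling, determinism, and tolerance checks omitted; net and test_data as defined in the test above):

import os
import tempfile

import torch

scripted = torch.jit.script(net)        # TorchScript compilation must succeed
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net.ts")
    torch.jit.save(scripted, path)      # round-trip through serialization
    reloaded = torch.jit.load(path)
    expected = net(test_data)
    actual = reloaded(test_data)        # compared elementwise against expected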


class TestDynUNetDeepSupervision(unittest.TestCase):
11 changes: 5 additions & 6 deletions tests/utils.py
@@ -29,7 +29,7 @@
import torch.distributed as dist

from monai.data import create_test_image_2d, create_test_image_3d
from monai.utils import optional_import, set_determinism
from monai.utils import ensure_tuple, optional_import, set_determinism

nib, _ = optional_import("nibabel")

@@ -457,11 +457,10 @@ def test_script_save(net, *inputs, eval_nets=True, device=None, rtol=1e-4):
result1 = net(*inputs)
result2 = reloaded_net(*inputs)
set_determinism(seed=None)
# When using e.g., VAR, we will produce a tuple of outputs.
# Hence, convert all to tuples and then compare all elements.
if not isinstance(result1, tuple):
result1 = (result1,)
result2 = (result2,)

# convert results to tuples if needed to allow iterating over pairs of outputs
result1 = ensure_tuple(result1)
result2 = ensure_tuple(result2)

for i, (r1, r2) in enumerate(zip(result1, result2)):
if None not in (r1, r2): # might be None