diff --git a/monai/networks/blocks/squeeze_and_excitation.py b/monai/networks/blocks/squeeze_and_excitation.py index 1d6b823497..e1533d454d 100644 --- a/monai/networks/blocks/squeeze_and_excitation.py +++ b/monai/networks/blocks/squeeze_and_excitation.py @@ -32,6 +32,7 @@ def __init__( r: int = 2, acti_type_1: Union[Tuple[str, Dict], str] = ("relu", {"inplace": True}), acti_type_2: Union[Tuple[str, Dict], str] = "sigmoid", + add_residual: bool = False, ) -> None: """ Args: @@ -51,6 +52,8 @@ def __init__( """ super(ChannelSELayer, self).__init__() + self.add_residual = add_residual + pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims] self.avg_pool = pool_type(1) # spatial size (1, 1, ...) @@ -74,8 +77,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ b, c = x.shape[:2] y: torch.Tensor = self.avg_pool(x).view(b, c) - y = self.fc(y).view([b, c] + [1] * (x.ndimension() - 2)) - return x * y + y = self.fc(y).view([b, c] + [1] * (x.ndim - 2)) + result = x * y + + # Residual connection is moved here instead of providing an override of forward in ResidualSELayer since + # Torchscript has an issue with using super(). + if self.add_residual: + result += x + + return result class ResidualSELayer(ChannelSELayer): @@ -85,7 +95,6 @@ class ResidualSELayer(ChannelSELayer): --+-- SE --o-- | | +--------+ - """ def __init__( @@ -105,21 +114,17 @@ def __init__( acti_type_2: defaults to "relu". See also: - :py:class:`monai.networks.blocks.ChannelSELayer` - """ super().__init__( - spatial_dims=spatial_dims, in_channels=in_channels, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2 + spatial_dims=spatial_dims, + in_channels=in_channels, + r=r, + acti_type_1=acti_type_1, + acti_type_2=acti_type_2, + add_residual=True, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]). - """ - return x + super().forward(x) - class SEBlock(nn.Module): """ @@ -196,28 +201,31 @@ def __init__( spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2 ) - self.project = project - if self.project is None and in_channels != n_chns_3: + if project is None and in_channels != n_chns_3: self.project = Conv[Conv.CONV, spatial_dims](in_channels, n_chns_3, kernel_size=1) + elif project is None: + self.project = nn.Identity() + else: + self.project = project - self.act = None if acti_type_final is not None: act_final, act_final_args = split_args(acti_type_final) self.act = Act[act_final](**act_final_args) + else: + self.act = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]). """ - residual = x if self.project is None else self.project(x) + residual = self.project(x) x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.se_layer(x) x += residual - if self.act is not None: - x = self.act(x) + x = self.act(x) return x @@ -358,7 +366,7 @@ def __init__( conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False} width = math.floor(planes * (base_width / 64)) * groups - super(SEResNeXtBottleneck, self).__init__( + super().__init__( spatial_dims=spatial_dims, in_channels=inplanes, n_chns_1=width, diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py index 162a11b0ad..c2239450f2 100644 --- a/monai/networks/nets/autoencoder.py +++ b/monai/networks/nets/autoencoder.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -194,11 +194,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i return decode - def forward( - self, x: torch.Tensor - ) -> Union[ - torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - ]: # big tuple return necessary for VAE, which inherits + def forward(self, x: torch.Tensor) -> Any: x = self.encode(x) x = self.intermediate(x) x = self.decode(x) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 0d1aa8c447..1bfd2ce68a 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -20,7 +20,7 @@ from monai.networks.layers.factories import Conv, Dropout, Norm, Pool -class _DenseLayer(nn.Sequential): +class _DenseLayer(nn.Module): def __init__( self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float ) -> None: @@ -38,21 +38,23 @@ def __init__( out_channels = bn_size * growth_rate conv_type: Callable = Conv[Conv.CONV, spatial_dims] norm_type: Callable = Norm[Norm.BATCH, spatial_dims] - dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims] + dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims] - self.add_module("norm1", norm_type(in_channels)) - self.add_module("relu1", nn.ReLU(inplace=True)) - self.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False)) + self.layers = nn.Sequential() - self.add_module("norm2", norm_type(out_channels)) - self.add_module("relu2", nn.ReLU(inplace=True)) - self.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False)) + self.layers.add_module("norm1", norm_type(in_channels)) + self.layers.add_module("relu1", nn.ReLU(inplace=True)) + self.layers.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False)) + + self.layers.add_module("norm2", norm_type(out_channels)) + self.layers.add_module("relu2", nn.ReLU(inplace=True)) + self.layers.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False)) if dropout_prob > 0: - self.add_module("dropout", dropout_type(dropout_prob)) + self.layers.add_module("dropout", dropout_type(dropout_prob)) def forward(self, x: torch.Tensor) -> torch.Tensor: - new_features = super(_DenseLayer, self).forward(x) + new_features = self.layers(x) return torch.cat([x, new_features], 1) diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index b530f0c6cb..a70da683ba 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -97,6 +97,7 @@ def check_kernel_stride(self): kernels, strides = self.kernel_size, self.strides error_msg = "length of kernel_size and strides should be the same, and no less than 3." assert len(kernels) == len(strides) and len(kernels) >= 3, error_msg + for idx in range(len(kernels)): kernel, stride = kernels[idx], strides[idx] if not isinstance(kernel, int): @@ -115,20 +116,26 @@ def check_deep_supr_num(self): def forward(self, x): out = self.input_block(x) outputs = [out] + for downsample in self.downsamples: out = downsample(out) - outputs.append(out) + outputs.insert(0, out) + out = self.bottleneck(out) upsample_outs = [] - for upsample, skip in zip(self.upsamples, reversed(outputs)): + + for upsample, skip in zip(self.upsamples, outputs): out = upsample(out, skip) upsample_outs.append(out) + out = self.output_block(out) + if self.training and self.deep_supervision: start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num upsample_outs = upsample_outs[start_output_idx:-1][::-1] preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)] return [out] + preds + return out def get_input_block(self): diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index fba0d17097..c2adfd237a 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -86,6 +86,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.as_tensor(self.layers(x)) +class ChannelPad(nn.Module): + def __init__(self, pad): + super().__init__() + self.pad = tuple(pad) + + def forward(self, x): + return F.pad(x, self.pad) + + class HighResBlock(nn.Module): def __init__( self, @@ -124,21 +133,26 @@ def __init__( norm_type = Normalisation(norm_type) acti_type = Activation(acti_type) - self.project, self.pad = None, None + self.project = None + self.pad = None + if in_channels != out_channels: channel_matching = ChannelMatching(channel_matching) + if channel_matching == ChannelMatching.PROJECT: self.project = conv_type(in_channels, out_channels, kernel_size=1) + if channel_matching == ChannelMatching.PAD: if in_channels > out_channels: raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.') pad_1 = (out_channels - in_channels) // 2 pad_2 = out_channels - in_channels - pad_1 pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0] - self.pad = lambda input: F.pad(input, pad) + self.pad = ChannelPad(pad) layers = nn.ModuleList() _in_chns, _out_chns = in_channels, out_channels + for kernel_size in kernels: layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(_in_chns)) layers.append(SUPPORTED_ACTI[acti_type](inplace=True)) @@ -148,14 +162,18 @@ def __init__( ) ) _in_chns = _out_chns + self.layers = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: x_conv: torch.Tensor = self.layers(x) + if self.project is not None: - return x_conv + torch.as_tensor(self.project(x)) + return x_conv + torch.as_tensor(self.project(x)) # as_tensor used to get around mypy typing bug + if self.pad is not None: return x_conv + torch.as_tensor(self.pad(x)) + return x_conv + x diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index dd0d146c98..21275e96f6 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -82,7 +82,7 @@ def __init__( self.relu = Act[Act.RELU](inplace=True) self.conv_final = self._make_final_conv(out_channels) - if dropout_prob: + if dropout_prob is not None: self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob) def _make_down_layers(self): @@ -151,18 +151,20 @@ def _make_final_conv(self, out_channels: int): def forward(self, x): x = self.convInit(x) - if self.dropout_prob: + if self.dropout_prob is not None: x = self.dropout(x) down_x = [] - for i in range(len(self.blocks_down)): - x = self.down_layers[i](x) + + for down in self.down_layers: + x = down(x) down_x.append(x) + down_x.reverse() - for i in range(len(self.blocks_up)): - x = self.up_samples[i](x) + down_x[i + 1] - x = self.up_layers[i](x) + for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): + x = up(x) + down_x[i + 1] + x = upl(x) if self.use_conv_final: x = self.conv_final(x) @@ -239,6 +241,11 @@ def __init__( ) self.input_image_size = input_image_size + self.smallest_filters = 16 + + zoom = 2 ** (len(self.blocks_down) - 1) + self.fc_insize = [s // (2 * zoom) for s in self.input_image_size] + self.vae_estimate_std = vae_estimate_std self.vae_default_std = vae_default_std self.vae_nz = vae_nz @@ -246,10 +253,8 @@ def __init__( self.vae_conv_final = self._make_final_conv(in_channels) def _prepare_vae_modules(self): - self.smallest_filters = 16 zoom = 2 ** (len(self.blocks_down) - 1) v_filters = self.init_filters * zoom - self.fc_insize = list(np.array(self.input_image_size) // (2 * zoom)) total_elements = int(self.smallest_filters * np.prod(self.fc_insize)) self.vae_down = nn.Sequential( @@ -281,23 +286,31 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): x_vae = self.vae_down(vae_input) x_vae = x_vae.view(-1, self.vae_fc1.in_features) z_mean = self.vae_fc1(x_vae) + + z_mean_rand = torch.randn_like(z_mean) + z_mean_rand.requires_grad_(False) + if self.vae_estimate_std: z_sigma = self.vae_fc2(x_vae) z_sigma = F.softplus(z_sigma) vae_reg_loss = 0.5 * torch.mean(z_mean ** 2 + z_sigma ** 2 - torch.log(1e-8 + z_sigma ** 2) - 1) + + x_vae = z_mean + z_sigma * z_mean_rand else: z_sigma = self.vae_default_std vae_reg_loss = torch.mean(z_mean ** 2) - x_vae = z_mean + z_sigma * torch.randn( - z_mean.shape, dtype=z_mean.dtype, device=z_mean.device, requires_grad=False - ) + + x_vae = z_mean + z_sigma * z_mean_rand + x_vae = self.vae_fc3(x_vae) x_vae = self.relu(x_vae) x_vae = x_vae.view([-1, self.smallest_filters] + self.fc_insize) x_vae = self.vae_fc_up_sample(x_vae) - for i in range(len(self.blocks_up)): - x_vae = self.up_samples[i](x_vae) - x_vae = self.up_layers[i](x_vae) + + for up, upl in zip(self.up_samples, self.up_layers): + x_vae = up(x_vae) + x_vae = upl(x_vae) + x_vae = self.vae_conv_final(x_vae) vae_mse_loss = F.mse_loss(net_input, x_vae) vae_loss = vae_reg_loss + vae_mse_loss @@ -306,20 +319,21 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): def forward(self, x): net_input = x x = self.convInit(x) - if self.dropout_prob: + if self.dropout_prob is not None: x = self.dropout(x) down_x = [] - for i in range(len(self.blocks_down)): - x = self.down_layers[i](x) + for down in self.down_layers: + x = down(x) down_x.append(x) + down_x.reverse() vae_input = x - for i in range(len(self.blocks_up)): - x = self.up_samples[i](x) + down_x[i + 1] - x = self.up_layers[i](x) + for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)): + x = up(x) + down_x[i + 1] + x = upl(x) if self.use_conv_final: x = self.conv_final(x) @@ -327,4 +341,5 @@ def forward(self, x): if self.training: vae_loss = self._get_vae_loss(net_input, vae_input) return x, vae_loss - return x + + return x, None diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index d9326176eb..55901430d2 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -87,7 +87,7 @@ def _create_block( c = channels[0] s = strides[0] - subblock: Union[nn.Sequential, ResidualUnit, Convolution] + subblock: nn.Module if len(channels) > 2: subblock = _create_block(c, c, channels[1:], strides[1:], False) # continue recursion down @@ -104,9 +104,7 @@ def _create_block( self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True) - def _get_down_layer( - self, in_channels: int, out_channels: int, strides: int, is_top: bool - ) -> Union[ResidualUnit, Convolution]: + def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ Args: in_channels: number of input channels. @@ -138,7 +136,7 @@ def _get_down_layer( dropout=self.dropout, ) - def _get_bottom_layer(self, in_channels: int, out_channels: int) -> Union[ResidualUnit, Convolution]: + def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: """ Args: in_channels: number of input channels. @@ -146,9 +144,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> Union[Residu """ return self._get_down_layer(in_channels, out_channels, 1, False) - def _get_up_layer( - self, in_channels: int, out_channels: int, strides: int, is_top: bool - ) -> Union[Convolution, nn.Sequential]: + def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ Args: in_channels: number of input channels. diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 7d68cca9a9..a46e8e66d7 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -34,7 +34,7 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f assert labels.dim() > 0, "labels should have dim of 1 or more." # if `dim` is bigger, add singleton dim at the end - if labels.ndimension() < dim + 1: + if labels.ndim < dim + 1: shape = ensure_tuple_size(labels.shape, dim + 1, 1) labels = labels.reshape(*shape) diff --git a/tests/test_autoencoder.py b/tests/test_autoencoder.py index ead9e57689..962f41e303 100644 --- a/tests/test_autoencoder.py +++ b/tests/test_autoencoder.py @@ -5,6 +5,7 @@ from monai.networks.layers import Act from monai.networks.nets import AutoEncoder +from tests.utils import test_script_save device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -70,6 +71,12 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = AutoEncoder(dimensions=2, in_channels=1, out_channels=1, channels=(4, 8), strides=(2, 2)) + test_data = torch.randn(2, 1, 32, 32) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_densenet.py b/tests/test_densenet.py index f7c22632b0..726b4f13e3 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -1,83 +1,93 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks.nets import densenet121, densenet169, densenet201, densenet264 -from tests.utils import skip_if_quick, test_pretrained_networks - -device = "cuda" if torch.cuda.is_available() else "cpu" - -TEST_CASE_1 = [ # 4-channel 3D, batch 2 - {"pretrained": False, "spatial_dims": 3, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64, 48), - (2, 3), -] - -TEST_CASE_2 = [ # 4-channel 2D, batch 2 - {"pretrained": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64), - (2, 3), -] - -TEST_CASE_3 = [ # 4-channel 1D, batch 1 - {"pretrained": False, "spatial_dims": 1, "in_channels": 2, "out_channels": 3}, - (1, 2, 32), - (1, 3), -] - -TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: - for model in [densenet121, densenet169, densenet201, densenet264]: - TEST_CASES.append([model, *case]) - - -TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2 - densenet121, - {"pretrained": True, "progress": True, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64), - (2, 3), -] - -TEST_PRETRAINED_2D_CASE_2 = [ # 4-channel 2D, batch 2 - densenet121, - {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64), - (2, 3), -] - - -class TestPretrainedDENSENET(unittest.TestCase): - @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2]) - @skip_if_quick - def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape): - net = test_pretrained_networks(model, input_param, device) - net.eval() - with torch.no_grad(): - result = net.forward(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - -class TestDENSENET(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_densenet_shape(self, model, input_param, input_shape, expected_shape): - net = model(**input_param).to(device) - net.eval() - with torch.no_grad(): - result = net.forward(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main() +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets import densenet121, densenet169, densenet201, densenet264 +from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_1 = [ # 4-channel 3D, batch 2 + {"pretrained": False, "spatial_dims": 3, "in_channels": 2, "out_channels": 3}, + (2, 2, 32, 64, 48), + (2, 3), +] + +TEST_CASE_2 = [ # 4-channel 2D, batch 2 + {"pretrained": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, + (2, 2, 32, 64), + (2, 3), +] + +TEST_CASE_3 = [ # 4-channel 1D, batch 1 + {"pretrained": False, "spatial_dims": 1, "in_channels": 2, "out_channels": 3}, + (1, 2, 32), + (1, 3), +] + +TEST_CASES = [] +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: + for model in [densenet121, densenet169, densenet201, densenet264]: + TEST_CASES.append([model, *case]) + + +TEST_SCRIPT_CASES = [[model, *TEST_CASE_1] for model in [densenet121, densenet169, densenet201, densenet264]] + + +TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2 + densenet121, + {"pretrained": True, "progress": True, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, + (2, 2, 32, 64), + (2, 3), +] + +TEST_PRETRAINED_2D_CASE_2 = [ # 4-channel 2D, batch 2 + densenet121, + {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, + (2, 2, 32, 64), + (2, 3), +] + + +class TestPretrainedDENSENET(unittest.TestCase): + @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2]) + @skip_if_quick + def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape): + net = test_pretrained_networks(model, input_param, device) + net.eval() + with torch.no_grad(): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + +class TestDENSENET(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_densenet_shape(self, model, input_param, input_shape, expected_shape): + net = model(**input_param).to(device) + net.eval() + with torch.no_grad(): + result = net.forward(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + @parameterized.expand(TEST_SCRIPT_CASES) + def test_script(self, model, input_param, input_shape, expected_shape): + net = model(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 64dba1f6a4..1fe0cc188e 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,6 +17,8 @@ from monai.networks.nets import DynUNet +# from tests.utils import test_script_save + device = "cuda" if torch.cuda.is_available() else "cpu" strides: Sequence[Union[Sequence[int], int]] @@ -112,6 +114,14 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) +# def test_script(self): +# input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] +# net = DynUNet(**input_param) +# test_data = torch.randn(input_shape) +# out_orig, out_reloaded = test_script_save(net, test_data) +# assert torch.allclose(out_orig, out_reloaded) + + class TestDynUNetDeepSupervision(unittest.TestCase): @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) def test_shape(self, input_param, input_shape, expected_shape): diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 26d67abe6a..cb4fe40cbf 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding +from tests.utils import test_script_save TEST_CASE_RES_BASIC_BLOCK = [] for spatial_dims in range(2, 4): @@ -80,6 +81,15 @@ def test_ill_arg(self): with self.assertRaises(AssertionError): UnetResBlock(3, 4, 2, kernel_size=1, stride=4, norm_name="batch") + def test_script(self): + input_param, input_shape, _ = TEST_CASE_RES_BASIC_BLOCK[0] + + for net_type in (UnetResBlock, UnetBasicBlock): + net = net_type(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + class TestUpBlock(unittest.TestCase): @parameterized.expand(TEST_UP_BLOCK) @@ -90,6 +100,15 @@ def test_shape(self, input_param, input_shape, expected_shape, skip_shape): result = net(torch.randn(input_shape), torch.randn(skip_shape)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + input_param, input_shape, _, skip_shape = TEST_UP_BLOCK[0] + + net = UnetUpBlock(**input_param) + test_data = torch.randn(input_shape) + skip_data = torch.randn(skip_shape) + out_orig, out_reloaded = test_script_save(net, test_data, skip_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index e594198171..79da208d5f 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.networks.nets import HighResNet +from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -52,6 +53,13 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + input_param, input_shape, expected_shape = TEST_CASE_1 + net = HighResNet(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_se_block.py b/tests/test_se_block.py index cc0b3979e2..c2318b5027 100644 --- a/tests/test_se_block.py +++ b/tests/test_se_block.py @@ -16,6 +16,7 @@ from monai.networks.blocks import SEBlock from monai.networks.layers.factories import Act, Norm +from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -67,6 +68,13 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + input_param, input_shape, _ = TEST_CASES[0] + net = SEBlock(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + def test_ill_arg(self): with self.assertRaises(ValueError): SEBlock(spatial_dims=1, in_channels=4, n_chns_1=2, n_chns_2=3, n_chns_3=4, r=100) diff --git a/tests/test_se_blocks.py b/tests/test_se_blocks.py index c8ad05e8aa..8a62fffb43 100644 --- a/tests/test_se_blocks.py +++ b/tests/test_se_blocks.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.networks.blocks import ChannelSELayer, ResidualSELayer +from tests.utils import test_script_save TEST_CASES = [ # single channel 3D, batch 16 [{"spatial_dims": 2, "in_channels": 4, "r": 3}, (7, 4, 64, 48), (7, 4, 64, 48)], # 4-channel 2D, batch 7 @@ -45,6 +46,13 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + input_param, input_shape, _ = TEST_CASES[0] + net = ChannelSELayer(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + def test_ill_arg(self): with self.assertRaises(ValueError): ChannelSELayer(spatial_dims=1, in_channels=4, r=100) @@ -59,6 +67,13 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + input_param, input_shape, _ = TEST_CASES[0] + net = ResidualSELayer(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_segresnet.py b/tests/test_segresnet.py index 713febbfba..c9225a7811 100644 --- a/tests/test_segresnet.py +++ b/tests/test_segresnet.py @@ -16,6 +16,7 @@ from monai.networks.nets import SegResNet, SegResNetVAE from monai.utils import UpsampleMode +from tests.utils import test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -77,7 +78,7 @@ TEST_CASE_SEGRESNET_VAE.append(test_case) -class TestResBlock(unittest.TestCase): +class TestResNet(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET + TEST_CASE_SEGRESNET_2) def test_shape(self, input_param, input_shape, expected_shape): net = SegResNet(**input_param).to(device) @@ -90,8 +91,15 @@ def test_ill_arg(self): with self.assertRaises(AssertionError): SegResNet(spatial_dims=4) + def test_script(self): + input_param, input_shape, expected_shape = TEST_CASE_SEGRESNET[0] + net = SegResNet(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) -class TestResBlockVAE(unittest.TestCase): + +class TestResNetVAE(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET_VAE) def test_vae_shape(self, input_param, input_shape, expected_shape): net = SegResNetVAE(**input_param).to(device) @@ -99,6 +107,13 @@ def test_vae_shape(self, input_param, input_shape, expected_shape): result, _ = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_script(self): + input_param, input_shape, expected_shape = TEST_CASE_SEGRESNET_VAE[0] + net = SegResNetVAE(**input_param) + test_data = torch.randn(input_shape) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig[0], out_reloaded[0]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_senet.py b/tests/test_senet.py index 9bcd8e8cb8..f8688fc24a 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -22,16 +22,16 @@ se_resnext101_32x4d, senet154, ) -from tests.utils import test_pretrained_networks +from tests.utils import test_pretrained_networks, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_1 = [senet154(3, 2, 2).to(device)] -TEST_CASE_2 = [se_resnet50(3, 2, 2).to(device)] -TEST_CASE_3 = [se_resnet101(3, 2, 2).to(device)] -TEST_CASE_4 = [se_resnet152(3, 2, 2).to(device)] -TEST_CASE_5 = [se_resnext50_32x4d(3, 2, 2).to(device)] -TEST_CASE_6 = [se_resnext101_32x4d(3, 2, 2).to(device)] +TEST_CASE_1 = [senet154(3, 2, 2)] +TEST_CASE_2 = [se_resnet50(3, 2, 2)] +TEST_CASE_3 = [se_resnet101(3, 2, 2)] +TEST_CASE_4 = [se_resnet152(3, 2, 2)] +TEST_CASE_5 = [se_resnext50_32x4d(3, 2, 2)] +TEST_CASE_6 = [se_resnext101_32x4d(3, 2, 2)] TEST_CASE_PRETRAINED = [se_resnet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}] @@ -41,11 +41,17 @@ class TestSENET(unittest.TestCase): def test_senet_shape(self, net): input_data = torch.randn(2, 2, 64, 64, 64).to(device) expected_shape = (2, 2) - net.eval() + net = net.to(device).eval() with torch.no_grad(): - result = net.forward(input_data) + result = net(input_data) self.assertEqual(result.shape, expected_shape) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_script(self, net): + input_data = torch.randn(2, 2, 64, 64, 64) + out_orig, out_reloaded = test_script_save(net.cpu(), input_data) + assert torch.allclose(out_orig, out_reloaded) + class TestPretrainedSENET(unittest.TestCase): @parameterized.expand( @@ -57,9 +63,9 @@ def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) input_data = torch.randn(3, 3, 64, 64).to(device) expected_shape = (3, 2) - net.eval() + net = net.to(device).eval() with torch.no_grad(): - result = net.forward(input_data) + result = net(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index f14fcb2941..f2756008aa 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -5,6 +5,7 @@ from monai.networks.layers import Act from monai.networks.nets import VarAutoEncoder +from tests.utils import test_script_save device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -74,6 +75,14 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device))[0] self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = VarAutoEncoder( + dimensions=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2) + ) + test_data = torch.randn(2, 1, 32, 32) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig[0], out_reloaded[0]) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 7d9dd6689e..9ba7023bff 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -123,15 +123,25 @@ def setUp(self): self.segn = torch.tensor(self.segn) -def test_script_save(net, inputs): +def test_script_save(net, *inputs, eval_nets=True): + """ + Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is + forward-passed through the original and loaded copy of the network and their results returned. Both `net` and its + reloaded copy are set to evaluation mode if `eval_nets` is True. The forward pass for both is done without + gradient accumulation. + """ + scripted = torch.jit.script(net) buffer = scripted.save_to_buffer() reloaded_net = torch.jit.load(BytesIO(buffer)) - net.eval() - reloaded_net.eval() + + if eval_nets: + net.eval() + reloaded_net.eval() + with torch.no_grad(): - result1 = net(inputs) - result2 = reloaded_net(inputs) + result1 = net(*inputs) + result2 = reloaded_net(*inputs) return result1, result2