From c2fadf23b0a733a28d4c13cf093884deb709de09 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 20 Oct 2020 00:04:22 +0800 Subject: [PATCH 1/4] [DLMED] add TorchScript tests Signed-off-by: Nic Ma --- monai/networks/nets/ahnet.py | 4 ++-- tests/test_ahnet.py | 8 +++++++- tests/utils.py | 9 ++++++++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 413c515a0a..e166f1e3ee 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -227,9 +227,9 @@ def forward(self, x): x = self.relu4(x) new_features = self.conv4(x) - self.dropout_prob = 0 # Dropout will make trouble! + self.dropout_prob = 0.0 # Dropout will make trouble! # since we use the train mode for inference - if self.dropout_prob > 0: + if self.dropout_prob > 0.0: new_features = F.dropout(new_features, p=self.dropout_prob, training=self.training) return torch.cat([inx, new_features], 1) diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 509cfbc59c..4290ad427e 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -16,7 +16,7 @@ from monai.networks.blocks import FCN, MCFCN from monai.networks.nets import AHNet -from tests.utils import skip_if_quick +from tests.utils import skip_if_quick, test_script_save TEST_CASE_FCN_1 = [ {"out_channels": 3, "upsample_mode": "transpose"}, @@ -139,6 +139,12 @@ def test_ahnet_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_ahnet_script(self): + net = AHNet(in_channels=2) + test_data = torch.randn(2, 2, 128, 128, 64) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + class TestAHNETWithPretrain(unittest.TestCase): @parameterized.expand( diff --git a/tests/utils.py b/tests/utils.py index 5a98694539..418b1c2775 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,7 @@ import tempfile import unittest from subprocess import PIPE, Popen - +from io import BytesIO import numpy as np import torch @@ -96,6 +96,13 @@ def expect_failure_if_no_gpu(test): else: return test +def test_script_save(net, *inputs): + scripted = torch.jit.script(net) + buffer = scripted.save_to_buffer() + reloaded_net = torch.jit.load(BytesIO(buffer)) + + return net(*inputs), reloaded_net(*inputs) + def query_memory(n=2): """ From 8ec40febfeaf9c3c76fda3e868b1452f4fc1158a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 20 Oct 2020 07:21:11 +0800 Subject: [PATCH 2/4] [DLMED] fix dynamic input data size Signed-off-by: Nic Ma --- monai/networks/nets/ahnet.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index e166f1e3ee..a294cf41c2 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -299,28 +299,27 @@ def forward(self, x): x16 = self.up16(self.proj16(self.pool16(x))) x8 = self.up8(self.proj8(self.pool8(x))) else: - interpolate_size = tuple(x.size()[2:]) x64 = F.interpolate( self.proj64(self.pool64(x)), - size=interpolate_size, + scale_factor=(64, 64, 1)[-self.spatial_dims:], mode=self.upsample_mode, align_corners=True, ) x32 = F.interpolate( self.proj32(self.pool32(x)), - size=interpolate_size, + scale_factor=(32, 32, 1)[-self.spatial_dims:], mode=self.upsample_mode, align_corners=True, ) x16 = F.interpolate( self.proj16(self.pool16(x)), - size=interpolate_size, + scale_factor=(16, 16, 1)[-self.spatial_dims:], mode=self.upsample_mode, align_corners=True, ) x8 = F.interpolate( self.proj8(self.pool8(x)), - size=interpolate_size, + scale_factor=(8, 8, 1)[-self.spatial_dims:], mode=self.upsample_mode, align_corners=True, ) From f9e86c68018c89bba5c3b31c0c19348c12d6bae1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 20 Oct 2020 23:09:13 +0800 Subject: [PATCH 3/4] [DLMED] Add test for networks Signed-off-by: Nic Ma --- monai/networks/nets/ahnet.py | 11 +++++++---- monai/networks/nets/dynunet.py | 2 ++ tests/test_ahnet.py | 8 +------- tests/test_discriminator.py | 8 +++++++- tests/test_generator.py | 8 +++++++- tests/test_unet.py | 8 +++++++- tests/test_vnet.py | 8 +++++++- tests/utils.py | 9 +++++++-- 8 files changed, 45 insertions(+), 17 deletions(-) diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index a294cf41c2..77787e9cc7 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -299,27 +299,28 @@ def forward(self, x): x16 = self.up16(self.proj16(self.pool16(x))) x8 = self.up8(self.proj8(self.pool8(x))) else: + interpolate_size = tuple(x.size()[2:]) x64 = F.interpolate( self.proj64(self.pool64(x)), - scale_factor=(64, 64, 1)[-self.spatial_dims:], + size=interpolate_size, mode=self.upsample_mode, align_corners=True, ) x32 = F.interpolate( self.proj32(self.pool32(x)), - scale_factor=(32, 32, 1)[-self.spatial_dims:], + size=interpolate_size, mode=self.upsample_mode, align_corners=True, ) x16 = F.interpolate( self.proj16(self.pool16(x)), - scale_factor=(16, 16, 1)[-self.spatial_dims:], + size=interpolate_size, mode=self.upsample_mode, align_corners=True, ) x8 = F.interpolate( self.proj8(self.pool8(x)), - scale_factor=(8, 8, 1)[-self.spatial_dims:], + size=interpolate_size, mode=self.upsample_mode, align_corners=True, ) @@ -350,6 +351,8 @@ class AHNet(nn.Module): model2d = monai.networks.blocks.FCN() model.copy_from(model2d) + Note that as this network contains dynamic parameters, it can't support ``torch.jit.script`` export. + Args: layers: number of residual blocks for 4 layers of the network (layer1...layer4). Defaults to ``(3, 4, 6, 3)``. spatial_dims: spatial dimension of the input data. Defaults to 3. diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 9f9a9a3900..294132c2ba 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -37,6 +37,8 @@ class DynUNet(nn.Module): Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``) is no less than 3 in order to have at least one downsample upsample blocks. + Note that as this network contains dynamic parameters, it can't support ``torch.jit.script`` export. + Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 4290ad427e..509cfbc59c 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -16,7 +16,7 @@ from monai.networks.blocks import FCN, MCFCN from monai.networks.nets import AHNet -from tests.utils import skip_if_quick, test_script_save +from tests.utils import skip_if_quick TEST_CASE_FCN_1 = [ {"out_channels": 3, "upsample_mode": "transpose"}, @@ -139,12 +139,6 @@ def test_ahnet_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) - def test_ahnet_script(self): - net = AHNet(in_channels=2) - test_data = torch.randn(2, 2, 128, 128, 64) - out_orig, out_reloaded = test_script_save(net, test_data) - assert torch.allclose(out_orig, out_reloaded) - class TestAHNETWithPretrain(unittest.TestCase): @parameterized.expand( diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index 35792fa132..b35a4541e4 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -13,7 +13,7 @@ import torch from parameterized import parameterized - +from tests.utils import test_script_save from monai.networks.nets import Discriminator TEST_CASE_0 = [ @@ -46,6 +46,12 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = Discriminator(in_shape=(1, 64, 64), channels=(2, 4), strides=(2, 2), num_res_units=0) + test_data = torch.rand(16, 1, 64, 64) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_generator.py b/tests/test_generator.py index 7af35a6711..95cc406789 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -13,7 +13,7 @@ import torch from parameterized import parameterized - +from tests.utils import test_script_save from monai.networks.nets import Generator TEST_CASE_0 = [ @@ -46,6 +46,12 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = Generator(latent_shape=(64,), start_shape=(8, 8, 8), channels=(8, 1), strides=(2, 2), num_res_units=2) + test_data = torch.rand(16, 64) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_unet.py b/tests/test_unet.py index e779561f58..a740527d3c 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -13,7 +13,7 @@ import torch from parameterized import parameterized - +from tests.utils import test_script_save from monai.networks.layers import Act, Norm from monai.networks.nets import UNet @@ -123,6 +123,12 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = UNet(dimensions=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0) + test_data = torch.randn(16, 1, 32, 32) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 6067518a7a..23fc084744 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -13,7 +13,7 @@ import torch from parameterized import parameterized - +from tests.utils import test_script_save from monai.networks.nets import VNet TEST_CASE_VNET_2D_1 = [ @@ -65,3 +65,9 @@ def test_vnet_shape(self, input_param, input_data, expected_shape): with torch.no_grad(): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + + def test_script(self): + net = VNet(spatial_dims=3, in_channels=1, out_channels=3, dropout_dim=3) + test_data = torch.randn(1, 1, 32, 32, 32) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) diff --git a/tests/utils.py b/tests/utils.py index 418b1c2775..c24a310741 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -96,12 +96,17 @@ def expect_failure_if_no_gpu(test): else: return test -def test_script_save(net, *inputs): +def test_script_save(net, inputs): scripted = torch.jit.script(net) buffer = scripted.save_to_buffer() reloaded_net = torch.jit.load(BytesIO(buffer)) + net.eval() + reloaded_net.eval() + with torch.no_grad(): + result1 = net.forward(inputs) + result2 = reloaded_net.forward(inputs) - return net(*inputs), reloaded_net(*inputs) + return result1, result2 def query_memory(n=2): From fbeb952e383fa8db8c73e14a1fa0e8591a0748df Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 21 Oct 2020 07:16:12 +0800 Subject: [PATCH 4/4] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/networks/nets/ahnet.py | 4 +--- monai/networks/nets/dynunet.py | 2 -- tests/test_ahnet.py | 8 +++++++- tests/test_discriminator.py | 3 ++- tests/test_generator.py | 3 ++- tests/test_unet.py | 3 ++- tests/test_vnet.py | 3 ++- tests/utils.py | 8 +++++--- 8 files changed, 21 insertions(+), 13 deletions(-) diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 77787e9cc7..aeadee20f3 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -299,7 +299,7 @@ def forward(self, x): x16 = self.up16(self.proj16(self.pool16(x))) x8 = self.up8(self.proj8(self.pool8(x))) else: - interpolate_size = tuple(x.size()[2:]) + interpolate_size = x.shape[2:] x64 = F.interpolate( self.proj64(self.pool64(x)), size=interpolate_size, @@ -351,8 +351,6 @@ class AHNet(nn.Module): model2d = monai.networks.blocks.FCN() model.copy_from(model2d) - Note that as this network contains dynamic parameters, it can't support ``torch.jit.script`` export. - Args: layers: number of residual blocks for 4 layers of the network (layer1...layer4). Defaults to ``(3, 4, 6, 3)``. spatial_dims: spatial dimension of the input data. Defaults to 3. diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index 294132c2ba..9f9a9a3900 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -37,8 +37,6 @@ class DynUNet(nn.Module): Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``) is no less than 3 in order to have at least one downsample upsample blocks. - Note that as this network contains dynamic parameters, it can't support ``torch.jit.script`` export. - Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 509cfbc59c..f0136a6c44 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -16,7 +16,7 @@ from monai.networks.blocks import FCN, MCFCN from monai.networks.nets import AHNet -from tests.utils import skip_if_quick +from tests.utils import skip_if_quick, test_script_save TEST_CASE_FCN_1 = [ {"out_channels": 3, "upsample_mode": "transpose"}, @@ -139,6 +139,12 @@ def test_ahnet_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = AHNet(spatial_dims=3, out_channels=2) + test_data = torch.randn(1, 1, 128, 128, 64) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + class TestAHNETWithPretrain(unittest.TestCase): @parameterized.expand( diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index b35a4541e4..345cc72ec4 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -13,8 +13,9 @@ import torch from parameterized import parameterized -from tests.utils import test_script_save + from monai.networks.nets import Discriminator +from tests.utils import test_script_save TEST_CASE_0 = [ {"in_shape": (1, 64, 64), "channels": (2, 4, 8), "strides": (2, 2, 2), "num_res_units": 0}, diff --git a/tests/test_generator.py b/tests/test_generator.py index 95cc406789..044c662754 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -13,8 +13,9 @@ import torch from parameterized import parameterized -from tests.utils import test_script_save + from monai.networks.nets import Generator +from tests.utils import test_script_save TEST_CASE_0 = [ {"latent_shape": (64,), "start_shape": (8, 8, 8), "channels": (8, 4, 1), "strides": (2, 2, 2), "num_res_units": 0}, diff --git a/tests/test_unet.py b/tests/test_unet.py index a740527d3c..0b48d477d8 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -13,9 +13,10 @@ import torch from parameterized import parameterized -from tests.utils import test_script_save + from monai.networks.layers import Act, Norm from monai.networks.nets import UNet +from tests.utils import test_script_save TEST_CASE_0 = [ # single channel 2D, batch 16, no residual { diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 23fc084744..a15910784d 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -13,8 +13,9 @@ import torch from parameterized import parameterized -from tests.utils import test_script_save + from monai.networks.nets import VNet +from tests.utils import test_script_save TEST_CASE_VNET_2D_1 = [ {"spatial_dims": 2, "in_channels": 4, "out_channels": 1, "act": "elu", "dropout_dim": 1}, diff --git a/tests/utils.py b/tests/utils.py index c24a310741..dbbd9ba4c3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,8 +12,9 @@ import os import tempfile import unittest -from subprocess import PIPE, Popen from io import BytesIO +from subprocess import PIPE, Popen + import numpy as np import torch @@ -96,6 +97,7 @@ def expect_failure_if_no_gpu(test): else: return test + def test_script_save(net, inputs): scripted = torch.jit.script(net) buffer = scripted.save_to_buffer() @@ -103,8 +105,8 @@ def test_script_save(net, inputs): net.eval() reloaded_net.eval() with torch.no_grad(): - result1 = net.forward(inputs) - result2 = reloaded_net.forward(inputs) + result1 = net(inputs) + result2 = reloaded_net(inputs) return result1, result2