diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 79e8967406..1341b6d17c 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -228,9 +228,9 @@ def forward(self, x): x = self.relu4(x) new_features = self.conv4(x) - self.dropout_prob = 0 # Dropout will make trouble! + self.dropout_prob = 0.0 # Dropout will make trouble! # since we use the train mode for inference - if self.dropout_prob > 0: + if self.dropout_prob > 0.0: new_features = F.dropout(new_features, p=self.dropout_prob, training=self.training) return torch.cat([inx, new_features], 1) @@ -300,7 +300,7 @@ def forward(self, x): x16 = self.up16(self.proj16(self.pool16(x))) x8 = self.up8(self.proj8(self.pool8(x))) else: - interpolate_size = tuple(x.size()[2:]) + interpolate_size = x.shape[2:] x64 = F.interpolate( self.proj64(self.pool64(x)), size=interpolate_size, diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 537b3238b3..acff225e8b 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -16,7 +16,7 @@ from monai.networks.blocks import FCN, MCFCN from monai.networks.nets import AHNet -from tests.utils import skip_if_quick +from tests.utils import skip_if_quick, test_script_save TEST_CASE_FCN_1 = [ {"out_channels": 3, "upsample_mode": "transpose"}, @@ -139,6 +139,12 @@ def test_ahnet_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = AHNet(spatial_dims=3, out_channels=2) + test_data = torch.randn(1, 1, 128, 128, 64) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + class TestAHNETWithPretrain(unittest.TestCase): @parameterized.expand( diff --git a/tests/test_discriminator.py b/tests/test_discriminator.py index 35792fa132..345cc72ec4 100644 --- a/tests/test_discriminator.py +++ b/tests/test_discriminator.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.networks.nets import Discriminator +from tests.utils import test_script_save TEST_CASE_0 = [ {"in_shape": (1, 64, 64), "channels": (2, 4, 8), "strides": (2, 2, 2), "num_res_units": 0}, @@ -46,6 +47,12 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = Discriminator(in_shape=(1, 64, 64), channels=(2, 4), strides=(2, 2), num_res_units=0) + test_data = torch.rand(16, 1, 64, 64) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_generator.py b/tests/test_generator.py index 7af35a6711..044c662754 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.networks.nets import Generator +from tests.utils import test_script_save TEST_CASE_0 = [ {"latent_shape": (64,), "start_shape": (8, 8, 8), "channels": (8, 4, 1), "strides": (2, 2, 2), "num_res_units": 0}, @@ -46,6 +47,12 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = Generator(latent_shape=(64,), start_shape=(8, 8, 8), channels=(8, 1), strides=(2, 2), num_res_units=2) + test_data = torch.rand(16, 64) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_unet.py b/tests/test_unet.py index e779561f58..0b48d477d8 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -16,6 +16,7 @@ from monai.networks.layers import Act, Norm from monai.networks.nets import UNet +from tests.utils import test_script_save TEST_CASE_0 = [ # single channel 2D, batch 16, no residual { @@ -123,6 +124,12 @@ def test_shape(self, input_param, input_data, expected_shape): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + def test_script(self): + net = UNet(dimensions=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0) + test_data = torch.randn(16, 1, 32, 32) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vnet.py b/tests/test_vnet.py index 6067518a7a..a15910784d 100644 --- a/tests/test_vnet.py +++ b/tests/test_vnet.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.networks.nets import VNet +from tests.utils import test_script_save TEST_CASE_VNET_2D_1 = [ {"spatial_dims": 2, "in_channels": 4, "out_channels": 1, "act": "elu", "dropout_dim": 1}, @@ -65,3 +66,9 @@ def test_vnet_shape(self, input_param, input_data, expected_shape): with torch.no_grad(): result = net.forward(input_data) self.assertEqual(result.shape, expected_shape) + + def test_script(self): + net = VNet(spatial_dims=3, in_channels=1, out_channels=3, dropout_dim=3) + test_data = torch.randn(1, 1, 32, 32, 32) + out_orig, out_reloaded = test_script_save(net, test_data) + assert torch.allclose(out_orig, out_reloaded) diff --git a/tests/utils.py b/tests/utils.py index 5a98694539..dbbd9ba4c3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from io import BytesIO from subprocess import PIPE, Popen import numpy as np @@ -97,6 +98,19 @@ def expect_failure_if_no_gpu(test): return test +def test_script_save(net, inputs): + scripted = torch.jit.script(net) + buffer = scripted.save_to_buffer() + reloaded_net = torch.jit.load(BytesIO(buffer)) + net.eval() + reloaded_net.eval() + with torch.no_grad(): + result1 = net(inputs) + result2 = reloaded_net(inputs) + + return result1, result2 + + def query_memory(n=2): """ Find best n idle devices and return a string of device ids.