From 7c9d4dc75e89dc8507ac82af6db6db504244be76 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 21 Jan 2021 18:35:28 +0800 Subject: [PATCH 1/3] [DLMED] fix TorchScript issue in AHNet Signed-off-by: Nic Ma --- monai/networks/nets/ahnet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 5146930fca..847993bd44 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -371,6 +371,7 @@ def __init__( self.pool_type = pool_type self.spatial_dims = spatial_dims self.psp_block_num = psp_block_num + self.psp = None if spatial_dims not in [2, 3]: raise AssertionError("spatial_dims can only be 2 or 3.") @@ -510,7 +511,7 @@ def forward(self, x): sum4 = self.up3(d3) + conv_x d4 = self.dense4(sum4) - if self.psp_block_num > 0: + if self.psp_block_num > 0 and self.psp is not None: psp = self.psp(d4) x = torch.cat((psp, d4), dim=1) else: From 8aa2acfa5c71719333d4be70d6148901c77c9ed3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 21 Jan 2021 21:15:34 +0800 Subject: [PATCH 2/3] [DLMED] add test cases Signed-off-by: Nic Ma --- tests/test_ahnet.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_ahnet.py b/tests/test_ahnet.py index 3dc8c05cf2..777e2637a7 100644 --- a/tests/test_ahnet.py +++ b/tests/test_ahnet.py @@ -191,9 +191,14 @@ def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape): @skip_if_quick def test_script(self): + # test 2D network net = AHNet(spatial_dims=2, out_channels=2) test_data = torch.randn(1, 1, 128, 64) test_script_save(net, test_data) + # test 3D network + net = AHNet(spatial_dims=3, out_channels=2, psp_block_num=0, upsample_mode="nearest") + test_data = torch.randn(1, 1, 32, 32, 64) + test_script_save(net, test_data) class TestAHNETWithPretrain(unittest.TestCase): From c6d30c37a7534d0de8268823792e43a5d96abe0c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Feb 2021 17:39:36 +0800 Subject: [PATCH 3/3] [DLMED] fix torchscript issue in PyTorch 1.5 Signed-off-by: Nic Ma --- monai/networks/nets/ahnet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/ahnet.py b/monai/networks/nets/ahnet.py index 847993bd44..3321001af0 100644 --- a/monai/networks/nets/ahnet.py +++ b/monai/networks/nets/ahnet.py @@ -434,8 +434,7 @@ def __init__( self.dense4 = DenseBlock(spatial_dims, ndenselayer, num_init_features, densebn, densegrowth, 0.0) noutdense4 = num_init_features + densegrowth * ndenselayer - if psp_block_num > 0: - self.psp = PSP(spatial_dims, psp_block_num, noutdense4, upsample_mode) + self.psp = PSP(spatial_dims, psp_block_num, noutdense4, upsample_mode) self.final = Final(spatial_dims, psp_block_num + noutdense4, out_channels, upsample_mode) # Initialise parameters @@ -511,7 +510,7 @@ def forward(self, x): sum4 = self.up3(d3) + conv_x d4 = self.dense4(sum4) - if self.psp_block_num > 0 and self.psp is not None: + if self.psp_block_num > 0: psp = self.psp(d4) x = torch.cat((psp, d4), dim=1) else: