Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions monai/networks/nets/ahnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ def forward(self, x):
x = self.relu4(x)
new_features = self.conv4(x)

self.dropout_prob = 0 # Dropout will make trouble!
self.dropout_prob = 0.0 # Dropout will make trouble!
# since we use the train mode for inference
if self.dropout_prob > 0:
if self.dropout_prob > 0.0:
new_features = F.dropout(new_features, p=self.dropout_prob, training=self.training)
return torch.cat([inx, new_features], 1)

Expand Down Expand Up @@ -300,7 +300,7 @@ def forward(self, x):
x16 = self.up16(self.proj16(self.pool16(x)))
x8 = self.up8(self.proj8(self.pool8(x)))
else:
interpolate_size = tuple(x.size()[2:])
interpolate_size = x.shape[2:]
x64 = F.interpolate(
self.proj64(self.pool64(x)),
size=interpolate_size,
Expand Down
8 changes: 7 additions & 1 deletion tests/test_ahnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from monai.networks.blocks import FCN, MCFCN
from monai.networks.nets import AHNet
from tests.utils import skip_if_quick
from tests.utils import skip_if_quick, test_script_save

TEST_CASE_FCN_1 = [
{"out_channels": 3, "upsample_mode": "transpose"},
Expand Down Expand Up @@ -139,6 +139,12 @@ def test_ahnet_shape(self, input_param, input_data, expected_shape):
result = net.forward(input_data)
self.assertEqual(result.shape, expected_shape)

def test_script(self):
net = AHNet(spatial_dims=3, out_channels=2)
test_data = torch.randn(1, 1, 128, 128, 64)
out_orig, out_reloaded = test_script_save(net, test_data)
assert torch.allclose(out_orig, out_reloaded)


class TestAHNETWithPretrain(unittest.TestCase):
@parameterized.expand(
Expand Down
7 changes: 7 additions & 0 deletions tests/test_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from parameterized import parameterized

from monai.networks.nets import Discriminator
from tests.utils import test_script_save

TEST_CASE_0 = [
{"in_shape": (1, 64, 64), "channels": (2, 4, 8), "strides": (2, 2, 2), "num_res_units": 0},
Expand Down Expand Up @@ -46,6 +47,12 @@ def test_shape(self, input_param, input_data, expected_shape):
result = net.forward(input_data)
self.assertEqual(result.shape, expected_shape)

def test_script(self):
net = Discriminator(in_shape=(1, 64, 64), channels=(2, 4), strides=(2, 2), num_res_units=0)
test_data = torch.rand(16, 1, 64, 64)
out_orig, out_reloaded = test_script_save(net, test_data)
assert torch.allclose(out_orig, out_reloaded)


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from parameterized import parameterized

from monai.networks.nets import Generator
from tests.utils import test_script_save

TEST_CASE_0 = [
{"latent_shape": (64,), "start_shape": (8, 8, 8), "channels": (8, 4, 1), "strides": (2, 2, 2), "num_res_units": 0},
Expand Down Expand Up @@ -46,6 +47,12 @@ def test_shape(self, input_param, input_data, expected_shape):
result = net.forward(input_data)
self.assertEqual(result.shape, expected_shape)

def test_script(self):
net = Generator(latent_shape=(64,), start_shape=(8, 8, 8), channels=(8, 1), strides=(2, 2), num_res_units=2)
test_data = torch.rand(16, 64)
out_orig, out_reloaded = test_script_save(net, test_data)
assert torch.allclose(out_orig, out_reloaded)


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions tests/test_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks.layers import Act, Norm
from monai.networks.nets import UNet
from tests.utils import test_script_save

TEST_CASE_0 = [ # single channel 2D, batch 16, no residual
{
Expand Down Expand Up @@ -123,6 +124,12 @@ def test_shape(self, input_param, input_data, expected_shape):
result = net.forward(input_data)
self.assertEqual(result.shape, expected_shape)

def test_script(self):
net = UNet(dimensions=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0)
test_data = torch.randn(16, 1, 32, 32)
out_orig, out_reloaded = test_script_save(net, test_data)
assert torch.allclose(out_orig, out_reloaded)


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions tests/test_vnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from parameterized import parameterized

from monai.networks.nets import VNet
from tests.utils import test_script_save

TEST_CASE_VNET_2D_1 = [
{"spatial_dims": 2, "in_channels": 4, "out_channels": 1, "act": "elu", "dropout_dim": 1},
Expand Down Expand Up @@ -65,3 +66,9 @@ def test_vnet_shape(self, input_param, input_data, expected_shape):
with torch.no_grad():
result = net.forward(input_data)
self.assertEqual(result.shape, expected_shape)

def test_script(self):
net = VNet(spatial_dims=3, in_channels=1, out_channels=3, dropout_dim=3)
test_data = torch.randn(1, 1, 32, 32, 32)
out_orig, out_reloaded = test_script_save(net, test_data)
assert torch.allclose(out_orig, out_reloaded)
14 changes: 14 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import tempfile
import unittest
from io import BytesIO
from subprocess import PIPE, Popen

import numpy as np
Expand Down Expand Up @@ -97,6 +98,19 @@ def expect_failure_if_no_gpu(test):
return test


def test_script_save(net, inputs):
scripted = torch.jit.script(net)
buffer = scripted.save_to_buffer()
reloaded_net = torch.jit.load(BytesIO(buffer))
net.eval()
reloaded_net.eval()
with torch.no_grad():
result1 = net(inputs)
result2 = reloaded_net(inputs)

return result1, result2


def query_memory(n=2):
"""
Find best n idle devices and return a string of device ids.
Expand Down