From cacda36635c918791d192d6822abd7acccf7464f Mon Sep 17 00:00:00 2001
From: Rich <33289025+rijobro@users.noreply.github.com>
Date: Mon, 2 Nov 2020 11:51:09 +0000
Subject: [PATCH 1/2] skip lltm test if code not compiled

Signed-off-by: Rich <33289025+rijobro@users.noreply.github.com>
---
 tests/test_lltm.py | 72 ++++++++++++++++++++++++----------------------
 1 file changed, 38 insertions(+), 34 deletions(-)

diff --git a/tests/test_lltm.py b/tests/test_lltm.py
index 5c666a8794..5b8439cf11 100644
--- a/tests/test_lltm.py
+++ b/tests/test_lltm.py
@@ -15,40 +15,44 @@
 from parameterized import parameterized
 
 from monai.networks.layers import LLTM
-
-TEST_CASE_1 = [
-    {"input_features": 32, "state_size": 2},
-    torch.tensor([[-0.1622, 0.1663], [0.5465, 0.0459], [-0.1436, 0.6171], [0.3632, -0.0111]]),
-    torch.tensor([[-1.3773, 0.3348], [0.8353, 1.3064], [-0.2179, 4.1739], [1.3045, -0.1444]]),
-]
-
-
-class TestLLTM(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1])
-    def test_value(self, input_param, expected_h, expected_c):
-        torch.manual_seed(0)
-        x = torch.randn(4, 32)
-        h = torch.randn(4, 2)
-        c = torch.randn(4, 2)
-        new_h, new_c = LLTM(**input_param)(x, (h, c))
-        (new_h.sum() + new_c.sum()).backward()
-
-        torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)
-        torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
-
-    @parameterized.expand([TEST_CASE_1])
-    def test_value_cuda(self, input_param, expected_h, expected_c):
-        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
-        torch.manual_seed(0)
-        x = torch.randn(4, 32).to(device)
-        h = torch.randn(4, 2).to(device)
-        c = torch.randn(4, 2).to(device)
-        lltm = LLTM(**input_param).to(device)
-        new_h, new_c = lltm(x, (h, c))
-        (new_h.sum() + new_c.sum()).backward()
-
-        torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04)
-        torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04)
+from monai.utils import optional_import
+
+_, is_compiled = optional_import("monai._C")
+
+if is_compiled:
+
+    TEST_CASE_1 = [
+        {"input_features": 32, "state_size": 2},
+        torch.tensor([[-0.1622, 0.1663], [0.5465, 0.0459], [-0.1436, 0.6171], [0.3632, -0.0111]]),
+        torch.tensor([[-1.3773, 0.3348], [0.8353, 1.3064], [-0.2179, 4.1739], [1.3045, -0.1444]]),
+    ]
+
+    class TestLLTM(unittest.TestCase):
+        @parameterized.expand([TEST_CASE_1])
+        def test_value(self, input_param, expected_h, expected_c):
+            torch.manual_seed(0)
+            x = torch.randn(4, 32)
+            h = torch.randn(4, 2)
+            c = torch.randn(4, 2)
+            new_h, new_c = LLTM(**input_param)(x, (h, c))
+            (new_h.sum() + new_c.sum()).backward()
+
+            torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)
+            torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
+
+        @parameterized.expand([TEST_CASE_1])
+        def test_value_cuda(self, input_param, expected_h, expected_c):
+            device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
+            torch.manual_seed(0)
+            x = torch.randn(4, 32).to(device)
+            h = torch.randn(4, 2).to(device)
+            c = torch.randn(4, 2).to(device)
+            lltm = LLTM(**input_param).to(device)
+            new_h, new_c = lltm(x, (h, c))
+            (new_h.sum() + new_c.sum()).backward()
+
+            torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04)
+            torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04)
 
 
 if __name__ == "__main__":

From b29f07c03a58336c856016762771be5b9b739c10 Mon Sep 17 00:00:00 2001
From: Rich <33289025+rijobro@users.noreply.github.com>
Date: Mon, 2 Nov 2020 12:30:56 +0000
Subject: [PATCH 2/2] implement as decorator

Signed-off-by: Rich <33289025+rijobro@users.noreply.github.com>
---
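Notes:
    Reviewer note, not part of the applied diff: the decorator introduced
    here keys off monai.utils.optional_import, which returns a pair whose
    second element reports whether the import succeeded (the same contract
    patch 1 relied on). A minimal sketch of that contract, using a
    hypothetical module name:

        from monai.utils import optional_import

        # the second return value is True only when the import succeeds;
        # "some.optional.module" is a placeholder, not a real dependency
        _, has_module = optional_import("some.optional.module")
        if not has_module:
            print("some.optional.module missing; dependent tests are skipped")
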
 tests/test_lltm.py | 75 +++++++++++++++++++++++-----------------------
 tests/utils.py     | 12 ++++++++
 2 files changed, 49 insertions(+), 38 deletions(-)

diff --git a/tests/test_lltm.py b/tests/test_lltm.py
index 5b8439cf11..41a9ea55fd 100644
--- a/tests/test_lltm.py
+++ b/tests/test_lltm.py
@@ -15,44 +15,43 @@
 from parameterized import parameterized
 
 from monai.networks.layers import LLTM
-from monai.utils import optional_import
-
-_, is_compiled = optional_import("monai._C")
-
-if is_compiled:
-
-    TEST_CASE_1 = [
-        {"input_features": 32, "state_size": 2},
-        torch.tensor([[-0.1622, 0.1663], [0.5465, 0.0459], [-0.1436, 0.6171], [0.3632, -0.0111]]),
-        torch.tensor([[-1.3773, 0.3348], [0.8353, 1.3064], [-0.2179, 4.1739], [1.3045, -0.1444]]),
-    ]
-
-    class TestLLTM(unittest.TestCase):
-        @parameterized.expand([TEST_CASE_1])
-        def test_value(self, input_param, expected_h, expected_c):
-            torch.manual_seed(0)
-            x = torch.randn(4, 32)
-            h = torch.randn(4, 2)
-            c = torch.randn(4, 2)
-            new_h, new_c = LLTM(**input_param)(x, (h, c))
-            (new_h.sum() + new_c.sum()).backward()
-
-            torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)
-            torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
-
-        @parameterized.expand([TEST_CASE_1])
-        def test_value_cuda(self, input_param, expected_h, expected_c):
-            device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
-            torch.manual_seed(0)
-            x = torch.randn(4, 32).to(device)
-            h = torch.randn(4, 2).to(device)
-            c = torch.randn(4, 2).to(device)
-            lltm = LLTM(**input_param).to(device)
-            new_h, new_c = lltm(x, (h, c))
-            (new_h.sum() + new_c.sum()).backward()
-
-            torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04)
-            torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04)
+from tests.utils import SkipIfNoModule
+
+TEST_CASE_1 = [
+    {"input_features": 32, "state_size": 2},
+    torch.tensor([[-0.1622, 0.1663], [0.5465, 0.0459], [-0.1436, 0.6171], [0.3632, -0.0111]]),
+    torch.tensor([[-1.3773, 0.3348], [0.8353, 1.3064], [-0.2179, 4.1739], [1.3045, -0.1444]]),
+]
+
+
+class TestLLTM(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1])
+    @SkipIfNoModule("monai._C")
+    def test_value(self, input_param, expected_h, expected_c):
+        torch.manual_seed(0)
+        x = torch.randn(4, 32)
+        h = torch.randn(4, 2)
+        c = torch.randn(4, 2)
+        new_h, new_c = LLTM(**input_param)(x, (h, c))
+        (new_h.sum() + new_c.sum()).backward()
+
+        torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)
+        torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
+
+    @parameterized.expand([TEST_CASE_1])
+    @SkipIfNoModule("monai._C")
+    def test_value_cuda(self, input_param, expected_h, expected_c):
+        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
+        torch.manual_seed(0)
+        x = torch.randn(4, 32).to(device)
+        h = torch.randn(4, 2).to(device)
+        c = torch.randn(4, 2).to(device)
+        lltm = LLTM(**input_param).to(device)
+        new_h, new_c = lltm(x, (h, c))
+        (new_h.sum() + new_c.sum()).backward()
+
+        torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=0.0001, atol=1e-04)
+        torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=0.0001, atol=1e-04)
 
 
 if __name__ == "__main__":
diff --git a/tests/utils.py b/tests/utils.py
index dbbd9ba4c3..0e1fb9aee4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -32,6 +32,18 @@ def skip_if_quick(obj):
     return unittest.skipIf(is_quick, "Skipping slow tests")(obj)
 
 
+class SkipIfNoModule(object):
+    """Decorator to be used if test should be skipped
+    when optional module is not present."""
+
+    def __init__(self, module_name):
+        self.module_name = module_name
+        self.module_missing = not optional_import(self.module_name)[1]
+
+    def __call__(self, obj):
+        return unittest.skipIf(self.module_missing, f"optional module not present: {self.module_name}")(obj)
+
+
 def make_nifti_image(array, affine=None):
     """
     Create a temporary nifti image on the disk and return the image name.
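
Usage note (illustrative, not part of the patches): SkipIfNoModule
delegates to unittest.skipIf, which accepts classes as well as individual
test methods, so the same decorator should also be usable to skip a whole
TestCase at once; the test class below is hypothetical:

    import unittest

    from tests.utils import SkipIfNoModule

    @SkipIfNoModule("monai._C")  # skips every test in the class when monai._C is absent
    class TestCompiledExtension(unittest.TestCase):
        def test_anything(self):
            ...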