From 28f51d16d9a9add15e916f9fd3adcedc0166d273 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 13 Apr 2022 11:34:35 +0100 Subject: [PATCH] fixes pytorch version tests Signed-off-by: Wenqi Li --- tests/test_meta_tensor.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 1721e7d2b9..7688950a4b 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -24,11 +24,9 @@ from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor from monai.utils.enums import PostFix -from monai.utils.module import get_torch_version_tuple +from monai.utils.module import pytorch_after from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda -PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple() - DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] TESTS = [] for _device in TEST_DEVICES: @@ -230,7 +228,7 @@ def test_torchscript(self, device): traced_fn = torch.jit.load(fname) out = traced_fn(im) self.assertIsInstance(out, torch.Tensor) - if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 9: + if not isinstance(out, MetaTensor) and not pytorch_after(1, 9, 1): warnings.warn( "When calling `nn.Module(MetaTensor) on a module traced with " "`torch.jit.trace`, your version of pytorch returns a " @@ -246,7 +244,7 @@ def test_pickling(self): fname = os.path.join(tmp_dir, "im.pt") torch.save(m, fname) m2 = torch.load(fname) - if not isinstance(m2, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 7: + if not isinstance(m2, MetaTensor) and not pytorch_after(1, 8, 1): warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.") m = m.as_tensor() self.check(m2, m, ids=False)