Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@
from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from monai.data.meta_tensor import MetaTensor
from monai.utils.enums import PostFix
from monai.utils.module import get_torch_version_tuple
from monai.utils.module import pytorch_after
from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda

PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple()

DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]]
TESTS = []
for _device in TEST_DEVICES:
Expand Down Expand Up @@ -230,7 +228,7 @@ def test_torchscript(self, device):
traced_fn = torch.jit.load(fname)
out = traced_fn(im)
self.assertIsInstance(out, torch.Tensor)
if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 9:
if not isinstance(out, MetaTensor) and not pytorch_after(1, 9, 1):
warnings.warn(
"When calling `nn.Module(MetaTensor) on a module traced with "
"`torch.jit.trace`, your version of pytorch returns a "
Expand All @@ -246,7 +244,7 @@ def test_pickling(self):
fname = os.path.join(tmp_dir, "im.pt")
torch.save(m, fname)
m2 = torch.load(fname)
if not isinstance(m2, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 7:
if not isinstance(m2, MetaTensor) and not pytorch_after(1, 8, 1):
warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.")
m = m.as_tensor()
self.check(m2, m, ids=False)
Expand Down