From 26224025458b6fac52f72efc85e25340cb85abe8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 14:32:00 +0100 Subject: [PATCH 01/27] collate , decollate, dataset, dataloader, out= Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 3 ++ monai/data/utils.py | 18 ++++++++-- tests/test_meta_tensor.py | 73 +++++++++++++++++++++++++++++++++++---- 3 files changed, 85 insertions(+), 9 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 30270d89e2..00943ba433 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -102,6 +102,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: if kwargs is None: kwargs = {} ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) + # if `out` has been used as argument, metadata is not copied, nothing to do. + if "out" in kwargs: + return ret # e.g., __repr__ returns a string if not isinstance(ret, torch.Tensor): return ret diff --git a/monai/data/utils.py b/monai/data/utils.py index 495daf15e2..5a3c1235a2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,6 +28,7 @@ from torch.utils.data._utils.collate import default_collate from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -346,9 +347,15 @@ def list_data_collate(batch: Sequence): ret = {} for k in elem: key = k - ret[key] = default_collate([d[key] for d in data]) - return ret - return default_collate(data) + data_for_batch = [d[key] for d in data] + ret[key] = default_collate(data_for_batch) + if isinstance(ret[key], MetaTensor) and all(isinstance(d, MetaTensor) for d in data_for_batch): + ret[key].meta = list_data_collate([i.meta for i in data_for_batch]) + else: + ret = default_collate(data) + if isinstance(ret, MetaTensor) and all(isinstance(d, MetaTensor) for d in data): + ret.meta = list_data_collate([i.meta for i in data]) + return ret except RuntimeError as re: re_str = str(re) if "equal size" in re_str: @@ -466,6 +473,11 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) + # if of type MetaTensor, decollate the metadata and affines + if isinstance(batch, MetaTensor): + metas = decollate_batch(batch.meta) + for i in range(len(out_list)): + out_list[i].meta = metas[i] if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index c18ef08b85..90bfb8d835 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -21,8 +21,10 @@ import torch from parameterized import parameterized +from monai.data import DataLoader, Dataset from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix from monai.utils.module import pytorch_after from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda @@ -261,12 +263,71 @@ def test_amp(self): im_conv2 = conv(im) self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) - # TODO - # collate - # decollate - # dataset - # dataloader - # matplotlib + def test_out(self): + """Test when `out` is given as an argument.""" + m1, _ = self.get_im() + m1_orig = deepcopy(m1) + m2, _ = self.get_im() + m3, _ = self.get_im() + torch.add(m2, m3, out=m1) + m1_add = m2 + m3 + + assert_allclose(m1, m1_add) + aff1, aff1_orig = m1.affine, m1_orig.affine + assert_allclose(aff1, aff1_orig) + meta1 = {k: v for k, v in m1.meta.items() if k != "affine"} + meta1_orig = {k: v for k, v in m1_orig.meta.items() if k != "affine"} + self.assertEqual(meta1, meta1_orig) + + @parameterized.expand(TESTS) + def test_collate(self, device, dtype): + numel = 3 + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)] + collated = list_data_collate(ims) + # tensor + self.assertIsInstance(collated, MetaTensor) + expected_shape = (numel,) + tuple(ims[0].shape) + self.assertTupleEqual(tuple(collated.shape), expected_shape) + for i, im in enumerate(ims): + self.check(im, ims[i], ids=True) + # affine + self.assertIsInstance(collated.affine, torch.Tensor) + expected_shape = (numel,) + tuple(ims[0].affine.shape) + self.assertTupleEqual(tuple(collated.affine.shape), expected_shape) + + @parameterized.expand(TESTS) + def test_dataset(self, device, dtype): + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(4)] + ds = Dataset(ims) + for i, im in enumerate(ds): + self.check(im, ims[i], ids=True) + + @parameterized.expand(DTYPES) + def test_dataloader(self, dtype): + batch_size = 5 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + expected_im_shape = (batch_size,) + tuple(ims[0].shape) + expected_affine_shape = (batch_size,) + tuple(ims[0].affine.shape) + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + for batch in dl: + self.assertIsInstance(batch, MetaTensor) + self.assertTupleEqual(tuple(batch.shape), expected_im_shape) + self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + + @parameterized.expand(DTYPES) + def test_decollate(self, dtype): + batch_size = 3 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + batch = next(iter(dl)) + decollated = decollate_batch(batch) + self.assertIsInstance(decollated, list) + self.assertEqual(len(decollated), batch_size) + for elem, im in zip(decollated, ims): + self.assertIsInstance(elem, MetaTensor) + self.check(elem, im, ids=False) if __name__ == "__main__": From 19e68c9ccfa2da781464f7247ae0b406047f0bb9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 15:37:39 +0100 Subject: [PATCH 02/27] mypy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 5a3c1235a2..18a467803b 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -473,11 +473,11 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) - # if of type MetaTensor, decollate the metadata and affines + # if of type MetaTensor, decollate the metadata if isinstance(batch, MetaTensor): metas = decollate_batch(batch.meta) for i in range(len(out_list)): - out_list[i].meta = metas[i] + out_list[i].meta = metas[i] # type: ignore if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) From d017918de51ff8e3ab5048d534a14fb535ff797e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 16:27:18 +0100 Subject: [PATCH 03/27] skip decollation for pytorch 1.7 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 90bfb8d835..651c9e0d22 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -27,7 +27,7 @@ from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix from monai.utils.module import pytorch_after -from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda +from tests.utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] TESTS = [] @@ -315,6 +315,7 @@ def test_dataloader(self, dtype): self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + @SkipIfBeforePyTorchVersion((1, 8)) @parameterized.expand(DTYPES) def test_decollate(self, dtype): batch_size = 3 From b36cd10c1c623b926d8e9dd6aca6acf091a00db0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 16:39:53 +0100 Subject: [PATCH 04/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 651c9e0d22..e14695c3e1 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -315,8 +315,8 @@ def test_dataloader(self, dtype): self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) - @SkipIfBeforePyTorchVersion((1, 8)) @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) def test_decollate(self, dtype): batch_size = 3 ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] From a8f0373aa728ec36452610071fc8f668cabf7efb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 18:48:08 +0100 Subject: [PATCH 05/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index e14695c3e1..8fd31b58b6 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -303,6 +303,7 @@ def test_dataset(self, device, dtype): self.check(im, ims[i], ids=True) @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) def test_dataloader(self, dtype): batch_size = 5 ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] From 12afd4a4d760f8b9376b6bdda3ff90c0f6c4a60a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:20:55 +0100 Subject: [PATCH 06/27] add batch index testing Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 13 +++++ monai/data/meta_tensor.py | 100 ++++++++++++++++++++++++++++++++---- monai/data/utils.py | 13 +++-- tests/test_meta_tensor.py | 104 +++++++++++++++++++++++++++++++++----- 4 files changed, 201 insertions(+), 29 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 0e213f130b..00e10ca816 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -111,6 +111,7 @@ class MetaObj: def __init__(self): self._meta: dict = self.get_default_meta() + self._is_batch: bool = False @staticmethod def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: @@ -176,6 +177,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: id_in = id(input_objs[0]) if len(input_objs) > 0 else None deep_copy = id(self) != id_in self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) + self.is_batch = input_objs[0].is_batch def get_default_meta(self) -> dict: """Get the default meta. @@ -194,6 +196,7 @@ def __repr__(self) -> str: out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) else: out += "None" + out += f"\nIs batch?: {self.is_batch}" return out @@ -206,3 +209,13 @@ def meta(self) -> dict: def meta(self, d: dict) -> None: """Set the meta.""" self._meta = d + + @property + def is_batch(self) -> bool: + """Return whether object is part of batch or not.""" + return self._is_batch + + @is_batch.setter + def is_batch(self, val: bool) -> None: + """Set whether object is part of batch or not.""" + self._is_batch = val diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index f1d87bd5f0..3a3fdc47db 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,11 +13,12 @@ import warnings from copy import deepcopy -from typing import Callable +from typing import Callable, Sequence import torch from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix __all__ = ["MetaTensor"] @@ -59,6 +60,14 @@ class MetaTensor(MetaObj, torch.Tensor): `torch.jit.trace(net, im.as_tensor())`. - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. + - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute. + - With a batch of data, `batch[0]` will return the 0th image + with the 0th metadata. When the batch dimension is non-singleton, e.g., + `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the + last example) of the metadata will be returned, and `is_batch` will return `True`. + - When creating a batch with this class, use `monai.data.DataLoader` as opposed + to `torch.utils.data.DataLoader`, as this will take care of collating the + metadata properly. """ @staticmethod @@ -101,24 +110,93 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call if isinstance(val, torch.Tensor): setattr(self, attribute, val.to(self.device)) + @staticmethod + def update_meta(rets: Sequence, func, args, kwargs): + """Update the metadata from the output of `__torch_function__`. + The output could be a single object, or a sequence of them. Hence, they get + converted to a sequence if necessary and then processed by iterating across them. + + For each element, if not of type `MetaTensor`, then nothing to do + """ + out = [] + metas = None + for idx, ret in enumerate(rets): + # if not `MetaTensor`, nothing to do. + if not isinstance(ret, MetaTensor): + pass + # if not tracking, convert to `torch.Tensor`. + elif not (get_track_meta() or get_track_transforms()): + ret = ret.as_tensor() + # else, handle the `MetaTensor` metadata. + else: + meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) + ret._copy_meta(meta_args) + + # If we have a batch of data, then we need to be careful if a slice of + # the data is returned. Depending on how the data are indexed, we return + # some or all of the metadata, and the return object may or may not be a + # batch of data (e.g., `batch[:,-1]` versus `batch[0]`.) + if ret.is_batch: + # only decollate metadata once + if metas is None: + metas = decollate_batch(ret.meta) + # if indexing e.g., `batch[0]` + if func == torch.Tensor.__getitem__: + idx = args[1] + if isinstance(idx, Sequence): + idx = idx[0] + # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the + # first element will be `slice(None, None, None)` and `Ellipsis`, + # respectively. Don't need to do anything with the metadata. + if idx not in (slice(None, None, None), Ellipsis): + meta = metas[idx] + # if using e.g., `batch[0:2]`, then `is_batch` should still be + # `True`. Also re-collate the remaining elements. + if isinstance(meta, list) and len(meta) > 1: + ret.meta = list_data_collate(meta) + # if using e.g., `batch[0]` or `batch[0, 1]`, then return single + # element from batch, and set `is_batch` to `False`. + else: + ret.meta = meta + ret.is_batch = False + # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. + # But we only want to split the batch if the `unbind` is along the 0th + # dimension. + elif func == torch.Tensor.unbind: + if len(args) > 1: + dim = args[1] + elif "dim" in kwargs: + dim = kwargs["dim"] + else: + dim = 0 + if dim == 0: + ret.meta = metas[idx] + ret.is_batch = False + + ret.affine = ret.affine.to(ret.device) + out.append(ret) + # if the input was a tuple, then return it as a tuple + return tuple(out) if isinstance(rets, tuple) else out + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: """Wraps all torch functions.""" if kwargs is None: kwargs = {} - ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) + ret = super().__torch_function__(func, types, args, kwargs) # if `out` has been used as argument, metadata is not copied, nothing to do. if "out" in kwargs: return ret - # e.g., __repr__ returns a string - if not isinstance(ret, torch.Tensor): - return ret - if not (get_track_meta() or get_track_transforms()): - return ret.as_tensor() - meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) - ret._copy_meta(meta_args) - ret.affine = ret.affine.to(ret.device) - return ret + # we might have 1 or multiple outputs. Might be MetaTensor, might be something + # else (e.g., `__repr__` returns a string). + # Convert to list (if necessary), process, and at end remove list if one was added. + if not isinstance(ret, Sequence): + ret = [ret] + unpack = True + else: + unpack = False + ret = MetaTensor.update_meta(ret, func, args, kwargs) + return ret[0] if unpack else ret def get_default_affine(self, dtype=torch.float64) -> torch.Tensor: return torch.eye(4, device=self.device, dtype=dtype) diff --git a/monai/data/utils.py b/monai/data/utils.py index 18a467803b..7f67088a06 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,7 +28,7 @@ from torch.utils.data._utils.collate import default_collate from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike -from monai.data.meta_tensor import MetaTensor +from monai.data.meta_obj import MetaObj from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -349,12 +349,14 @@ def list_data_collate(batch: Sequence): key = k data_for_batch = [d[key] for d in data] ret[key] = default_collate(data_for_batch) - if isinstance(ret[key], MetaTensor) and all(isinstance(d, MetaTensor) for d in data_for_batch): + if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch): ret[key].meta = list_data_collate([i.meta for i in data_for_batch]) + ret[key].is_batch = True else: ret = default_collate(data) - if isinstance(ret, MetaTensor) and all(isinstance(d, MetaTensor) for d in data): + if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data): ret.meta = list_data_collate([i.meta for i in data]) + ret.is_batch = True return ret except RuntimeError as re: re_str = str(re) @@ -473,11 +475,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) - # if of type MetaTensor, decollate the metadata - if isinstance(batch, MetaTensor): + # if of type MetaObj, decollate the metadata + if isinstance(batch, MetaObj): metas = decollate_batch(batch.meta) for i in range(len(out_list)): out_list[i].meta = metas[i] # type: ignore + out_list[i].is_batch = False if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 8fd31b58b6..7968b8202e 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -61,6 +61,17 @@ def check_ids(self, a, b, should_match): comp = self.assertEqual if should_match else self.assertNotEqual comp(id(a), id(b)) + def check_meta(self, a: MetaTensor, b: MetaTensor) -> None: + self.assertEqual(a.is_batch, b.is_batch) + meta_a, meta_b = a.meta, b.meta + # need to split affine from rest of metadata + aff_a = meta_a.get("affine", None) + aff_b = meta_b.get("affine", None) + assert_allclose(aff_a, aff_b) + meta_a = {k: v for k, v in meta_a.items() if k != "affine"} + meta_b = {k: v for k, v in meta_b.items() if k != "affine"} + self.assertEqual(meta_a, meta_b) + def check( self, out: torch.Tensor, @@ -89,12 +100,7 @@ def check( # check meta and affine are equal and affine is on correct device if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: - orig_meta_no_affine = deepcopy(orig.meta) - del orig_meta_no_affine["affine"] - out_meta_no_affine = deepcopy(out.meta) - del out_meta_no_affine["affine"] - self.assertEqual(orig_meta_no_affine, out_meta_no_affine) - assert_allclose(out.affine, orig.affine) + self.check_meta(orig, out) self.assertTrue(str(device) in str(out.affine.device)) if check_ids: self.check_ids(out.affine, orig.affine, ids) @@ -273,11 +279,7 @@ def test_out(self): m1_add = m2 + m3 assert_allclose(m1, m1_add) - aff1, aff1_orig = m1.affine, m1_orig.affine - assert_allclose(aff1, aff1_orig) - meta1 = {k: v for k, v in m1.meta.items() if k != "affine"} - meta1_orig = {k: v for k, v in m1_orig.meta.items() if k != "affine"} - self.assertEqual(meta1, meta1_orig) + self.check_meta(m1, m1_orig) @parameterized.expand(TESTS) def test_collate(self, device, dtype): @@ -308,14 +310,90 @@ def test_dataloader(self, dtype): batch_size = 5 ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] ds = Dataset(ims) - expected_im_shape = (batch_size,) + tuple(ims[0].shape) - expected_affine_shape = (batch_size,) + tuple(ims[0].affine.shape) + im_shape = tuple(ims[0].shape) + affine_shape = tuple(ims[0].affine.shape) + expected_im_shape = (batch_size,) + im_shape + expected_affine_shape = (batch_size,) + affine_shape dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) for batch in dl: self.assertIsInstance(batch, MetaTensor) self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + def test_indexing(self): + """ + Check the metadata is returned in the expected format depending on whether + the input `MetaTensor` is a batch of data or not. + """ + ims = [self.get_im()[0] for _ in range(5)] + data = list_data_collate(ims) + + # check that when using non-batch data, metadata is copied wholly when indexing + # or iterating across data. + im = ims[0] + self.check_meta(im[0], im) + self.check_meta(next(iter(im)), im) + + # index + d = data[0] + self.check(d, ims[0], ids=False) + + # iter + d = next(iter(data)) + self.check(d, ims[0], ids=False) + + # complex indexing + + # `is_batch==True`, should have subset of image and metadata. + d = data[1:3] + self.check(d, list_data_collate(ims[1:3]), ids=False) + + # is_batch==True, should have subset of image and same metadata as `[1:3]`. + d = data[1:3, 0] + self.check(d, list_data_collate([i[0] for i in ims[1:3]]), ids=False) + + # `is_batch==False`, should have first metadata and subset of first image. + d = data[0, 0] + self.check(d, ims[0][0], ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[:, 0] + self.check(d, list_data_collate([i[0] for i in ims]), ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[..., -1] + self.check(d, list_data_collate([i[..., -1] for i in ims]), ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(dim=0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(dim=-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + @parameterized.expand(DTYPES) @SkipIfBeforePyTorchVersion((1, 8)) def test_decollate(self, dtype): From fb9b10f0991ba1596f656ef674c2c788ab52d891 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:30:02 +0100 Subject: [PATCH 07/27] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 4 ++-- monai/data/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 3a3fdc47db..9bfcb0cfae 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,7 +13,7 @@ import warnings from copy import deepcopy -from typing import Callable, Sequence +from typing import Any, Callable, Sequence import torch @@ -179,7 +179,7 @@ def update_meta(rets: Sequence, func, args, kwargs): return tuple(out) if isinstance(rets, tuple) else out @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: + def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: """Wraps all torch functions.""" if kwargs is None: kwargs = {} diff --git a/monai/data/utils.py b/monai/data/utils.py index 7f67088a06..2bd7b49731 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -476,11 +476,11 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) # if of type MetaObj, decollate the metadata - if isinstance(batch, MetaObj): + if isinstance(batch, MetaObj) and all(isinstance(i, MetaObj) for i in out_list): metas = decollate_batch(batch.meta) for i in range(len(out_list)): out_list[i].meta = metas[i] # type: ignore - out_list[i].is_batch = False + out_list[i].is_batch = False # type: ignore if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) From 487578441a7067cf52a7ef6aeb91e100d8d02b0c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:31:21 +0100 Subject: [PATCH 08/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 00e10ca816..e38e009e96 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -177,7 +177,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: id_in = id(input_objs[0]) if len(input_objs) > 0 else None deep_copy = id(self) != id_in self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) - self.is_batch = input_objs[0].is_batch + self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False def get_default_meta(self) -> dict: """Get the default meta. From b307d466f63d51488e5a5cdcc255b7a9ca32070b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:44:24 +0100 Subject: [PATCH 09/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 7968b8202e..05356fcc84 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -320,6 +320,7 @@ def test_dataloader(self, dtype): self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + @SkipIfBeforePyTorchVersion((1, 9)) def test_indexing(self): """ Check the metadata is returned in the expected format depending on whether From e40553a5e54b481f09dae24656844f649d8c6832 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:45:16 +0100 Subject: [PATCH 10/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 9bfcb0cfae..e3fb7846ae 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -55,6 +55,7 @@ class MetaTensor(MetaObj, torch.Tensor): assert m2.affine == affine Notes: + - Requires pytorch 1.9 or newer for full compatibility. - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. From f2c254877663e5774703d0eb1dbe822f47b4027d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:59:51 +0100 Subject: [PATCH 11/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index e3fb7846ae..9196f0186c 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -136,7 +136,7 @@ def update_meta(rets: Sequence, func, args, kwargs): # If we have a batch of data, then we need to be careful if a slice of # the data is returned. Depending on how the data are indexed, we return # some or all of the metadata, and the return object may or may not be a - # batch of data (e.g., `batch[:,-1]` versus `batch[0]`.) + # batch of data (e.g., `batch[:,-1]` versus `batch[0]`). if ret.is_batch: # only decollate metadata once if metas is None: From f9fd14a59e741b0038534953781674f3ac920cc1 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 15:24:03 +0100 Subject: [PATCH 12/27] load image meta tensor Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 13 +++ monai/data/meta_tensor.py | 104 ++++++++++++++++-- monai/data/utils.py | 21 +++- monai/transforms/io/array.py | 72 +++++++++---- monai/transforms/io/dictionary.py | 44 ++------ monai/transforms/utility/array.py | 3 + tests/test_load_image.py | 164 +++++++++++++++-------------- tests/test_load_imaged.py | 110 ++++++++++++++----- tests/test_meta_tensor.py | 168 +++++++++++++++++++++++++++--- tests/test_numpy_reader.py | 3 - 10 files changed, 510 insertions(+), 192 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 0e213f130b..e38e009e96 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -111,6 +111,7 @@ class MetaObj: def __init__(self): self._meta: dict = self.get_default_meta() + self._is_batch: bool = False @staticmethod def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: @@ -176,6 +177,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: id_in = id(input_objs[0]) if len(input_objs) > 0 else None deep_copy = id(self) != id_in self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) + self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False def get_default_meta(self) -> dict: """Get the default meta. @@ -194,6 +196,7 @@ def __repr__(self) -> str: out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) else: out += "None" + out += f"\nIs batch?: {self.is_batch}" return out @@ -206,3 +209,13 @@ def meta(self) -> dict: def meta(self, d: dict) -> None: """Set the meta.""" self._meta = d + + @property + def is_batch(self) -> bool: + """Return whether object is part of batch or not.""" + return self._is_batch + + @is_batch.setter + def is_batch(self, val: bool) -> None: + """Set whether object is part of batch or not.""" + self._is_batch = val diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index ba80f93e74..9196f0186c 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,11 +13,12 @@ import warnings from copy import deepcopy -from typing import Callable +from typing import Any, Callable, Sequence import torch from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix __all__ = ["MetaTensor"] @@ -54,11 +55,20 @@ class MetaTensor(MetaObj, torch.Tensor): assert m2.affine == affine Notes: + - Requires pytorch 1.9 or newer for full compatibility. - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. + - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute. + - With a batch of data, `batch[0]` will return the 0th image + with the 0th metadata. When the batch dimension is non-singleton, e.g., + `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the + last example) of the metadata will be returned, and `is_batch` will return `True`. + - When creating a batch with this class, use `monai.data.DataLoader` as opposed + to `torch.utils.data.DataLoader`, as this will take care of collating the + metadata properly. """ @staticmethod @@ -101,21 +111,93 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call if isinstance(val, torch.Tensor): setattr(self, attribute, val.to(self.device)) + @staticmethod + def update_meta(rets: Sequence, func, args, kwargs): + """Update the metadata from the output of `__torch_function__`. + The output could be a single object, or a sequence of them. Hence, they get + converted to a sequence if necessary and then processed by iterating across them. + + For each element, if not of type `MetaTensor`, then nothing to do + """ + out = [] + metas = None + for idx, ret in enumerate(rets): + # if not `MetaTensor`, nothing to do. + if not isinstance(ret, MetaTensor): + pass + # if not tracking, convert to `torch.Tensor`. + elif not (get_track_meta() or get_track_transforms()): + ret = ret.as_tensor() + # else, handle the `MetaTensor` metadata. + else: + meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) + ret._copy_meta(meta_args) + + # If we have a batch of data, then we need to be careful if a slice of + # the data is returned. Depending on how the data are indexed, we return + # some or all of the metadata, and the return object may or may not be a + # batch of data (e.g., `batch[:,-1]` versus `batch[0]`). + if ret.is_batch: + # only decollate metadata once + if metas is None: + metas = decollate_batch(ret.meta) + # if indexing e.g., `batch[0]` + if func == torch.Tensor.__getitem__: + idx = args[1] + if isinstance(idx, Sequence): + idx = idx[0] + # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the + # first element will be `slice(None, None, None)` and `Ellipsis`, + # respectively. Don't need to do anything with the metadata. + if idx not in (slice(None, None, None), Ellipsis): + meta = metas[idx] + # if using e.g., `batch[0:2]`, then `is_batch` should still be + # `True`. Also re-collate the remaining elements. + if isinstance(meta, list) and len(meta) > 1: + ret.meta = list_data_collate(meta) + # if using e.g., `batch[0]` or `batch[0, 1]`, then return single + # element from batch, and set `is_batch` to `False`. + else: + ret.meta = meta + ret.is_batch = False + # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. + # But we only want to split the batch if the `unbind` is along the 0th + # dimension. + elif func == torch.Tensor.unbind: + if len(args) > 1: + dim = args[1] + elif "dim" in kwargs: + dim = kwargs["dim"] + else: + dim = 0 + if dim == 0: + ret.meta = metas[idx] + ret.is_batch = False + + ret.affine = ret.affine.to(ret.device) + out.append(ret) + # if the input was a tuple, then return it as a tuple + return tuple(out) if isinstance(rets, tuple) else out + @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: + def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: """Wraps all torch functions.""" if kwargs is None: kwargs = {} - ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) - # e.g., __repr__ returns a string - if not isinstance(ret, torch.Tensor): + ret = super().__torch_function__(func, types, args, kwargs) + # if `out` has been used as argument, metadata is not copied, nothing to do. + if "out" in kwargs: return ret - if not (get_track_meta() or get_track_transforms()): - return ret.as_tensor() - meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) - ret._copy_meta(meta_args) - ret.affine = ret.affine.to(ret.device) - return ret + # we might have 1 or multiple outputs. Might be MetaTensor, might be something + # else (e.g., `__repr__` returns a string). + # Convert to list (if necessary), process, and at end remove list if one was added. + if not isinstance(ret, Sequence): + ret = [ret] + unpack = True + else: + unpack = False + ret = MetaTensor.update_meta(ret, func, args, kwargs) + return ret[0] if unpack else ret def get_default_affine(self, dtype=torch.float64) -> torch.Tensor: return torch.eye(4, device=self.device, dtype=dtype) diff --git a/monai/data/utils.py b/monai/data/utils.py index 495daf15e2..2bd7b49731 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,6 +28,7 @@ from torch.utils.data._utils.collate import default_collate from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike +from monai.data.meta_obj import MetaObj from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -346,9 +347,17 @@ def list_data_collate(batch: Sequence): ret = {} for k in elem: key = k - ret[key] = default_collate([d[key] for d in data]) - return ret - return default_collate(data) + data_for_batch = [d[key] for d in data] + ret[key] = default_collate(data_for_batch) + if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch): + ret[key].meta = list_data_collate([i.meta for i in data_for_batch]) + ret[key].is_batch = True + else: + ret = default_collate(data) + if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data): + ret.meta = list_data_collate([i.meta for i in data]) + ret.is_batch = True + return ret except RuntimeError as re: re_str = str(re) if "equal size" in re_str: @@ -466,6 +475,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) + # if of type MetaObj, decollate the metadata + if isinstance(batch, MetaObj) and all(isinstance(i, MetaObj) for i in out_list): + metas = decollate_batch(batch.meta) + for i in range(len(out_list)): + out_list[i].meta = metas[i] # type: ignore + out_list[i].is_batch = False # type: ignore if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 5bafd84eaf..03f87f1671 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -29,11 +29,20 @@ from monai.data import image_writer from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor from monai.transforms.transform import Transform from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import +from monai.utils import ( + InterpolateMode, + OptionalImportError, + deprecated_arg, + ensure_tuple, + look_up_option, + optional_import, +) nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -93,14 +102,11 @@ class LoadImage(Transform): """ + @deprecated_arg( + name="image_only", since="0.8", msg_suffix="If necessary, please extract meta data with `MetaTensor.meta`" + ) def __init__( - self, - reader=None, - image_only: bool = False, - dtype: DtypeLike = np.float32, - ensure_channel_first: bool = False, - *args, - **kwargs, + self, reader=None, dtype: DtypeLike = np.float32, ensure_channel_first: bool = False, *args, **kwargs ) -> None: """ Args: @@ -111,7 +117,6 @@ def __init__( ``"ITKReader"``, ``"NibabelReader"``, ``"NumpyReader"``. a reader instance will be constructed with the `*args` and `**kwargs` parameters. - if `reader` is a reader class/instance, it will be registered to this loader accordingly. - image_only: if True return only the image volume, otherwise return image data array and header dict. dtype: if not None convert the loaded image to this data type. ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert the image array shape to `channel first`. default to `False`. @@ -120,8 +125,8 @@ def __init__( Note: - - The transform returns an image data array if `image_only` is True, - or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise. + - The transform returns a MetaTensor, unless `set_track_meta(False)` has been used, in which case, a + `torch.Tensor` will be returned. - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. In this case, it is therefore recommended setting the most appropriate reader as @@ -130,7 +135,6 @@ def __init__( """ self.auto_select = reader is None - self.image_only = image_only self.dtype = dtype self.ensure_channel_first = ensure_channel_first @@ -241,14 +245,46 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option raise ValueError("`meta_data` must be a dict.") # make sure all elements in metadata are little endian meta_data = switch_endianness(meta_data, "<") - if self.ensure_channel_first: - img_array = EnsureChannelFirst()(img_array, meta_data) - if self.image_only: - return img_array meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader - - return img_array, meta_data + img = self.join_im_and_meta(img_array, meta_data) + if self.ensure_channel_first: + img = EnsureChannelFirst()(img) + return img + + @staticmethod + def join_im_and_meta(im, meta: dict): + img = torch.as_tensor(im) + + # if not tracking metadata, return torch.Tensor + if not get_track_meta() or meta is None: + return img + + if "affine" in meta: + meta["affine"] = torch.as_tensor(meta["affine"]) + + # delete extra metadata + for i in range(8): + for k in ("dim", "pixdim"): + if f"{k}[{i}]" in meta: + del meta[f"{k}[{i}]"] + for k in ( + "original_affine", + "spatial_shape", + "spacing", + "srow_x", + "srow_y", + "srow_z", + "quatern_b", + "quatern_c", + "quatern_d", + "qoffset_x", + "qoffset_y", + "qoffset_z", + ): + if k in meta: + del meta[k] + return MetaTensor(img, meta=meta) class SaveImage(Transform): diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 30dedc7810..1aa6b934af 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -25,7 +25,7 @@ from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform -from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple, ensure_tuple_rep +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, deprecated_arg, ensure_tuple_rep from monai.utils.enums import PostFix __all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"] @@ -64,6 +64,10 @@ class LoadImaged(MapTransform): """ + @deprecated_arg(name="image_only", since="0.8") + @deprecated_arg(name="meta_keys", since="0.8") + @deprecated_arg(name="meta_key_postfix", since="0.8") + @deprecated_arg(name="overwriting", since="0.8") def __init__( self, keys: KeysCollection, @@ -90,17 +94,6 @@ def __init__( a reader instance will be constructed with the `*args` and `**kwargs` parameters. - if `reader` is a reader class/instance, it will be registered to this loader accordingly. dtype: if not None, convert the loaded image data to this data type. - meta_keys: explicitly indicate the key to store the corresponding meta data dictionary. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to store the metadata of the nifti image, - default is `meta_dict`. The meta data is a dictionary object. - For example, load nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow overwriting existing meta data of same key. - default is False, which will raise exception if encountering existing key. - image_only: if True return dictionary containing just only the image volumes, otherwise return - dictionary containing image data array and header dict per input key. ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert the image array shape to `channel first`. default to `False`. allow_missing_keys: don't raise exception if key is missing. @@ -108,14 +101,7 @@ def __init__( kwargs: additional parameters for reader if providing a reader name. """ super().__init__(keys, allow_missing_keys) - self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs) - if not isinstance(meta_key_postfix, str): - raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.overwriting = overwriting + self._loader = LoadImage(reader, dtype, ensure_channel_first, *args, **kwargs) def register(self, reader: ImageReader): self._loader.register(reader) @@ -127,22 +113,8 @@ def __call__(self, data, reader: Optional[ImageReader] = None): """ d = dict(data) - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - data = self._loader(d[key], reader) - if self._loader.image_only: - if not isinstance(data, np.ndarray): - raise ValueError("loader must return a numpy array (because image_only=True was used).") - d[key] = data - else: - if not isinstance(data, (tuple, list)): - raise ValueError("loader must return a tuple or list (because image_only=False was used).") - d[key] = data[0] - if not isinstance(data[1], dict): - raise ValueError("metadata must be a dict.") - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key in d and not self.overwriting: - raise KeyError(f"Meta data with key {meta_key} already exists and overwriting=False.") - d[meta_key] = data[1] + for key in self.key_iterator(d): + d[key] = self._loader(d[key], reader) return d diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bc0c09e949..f512c94dc4 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -24,6 +24,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor +from monai.data.meta_tensor import MetaTensor from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( extreme_points_to_image, @@ -210,6 +211,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> """ Apply the transform to `img`. """ + if isinstance(img, MetaTensor): + meta_dict = img.meta if not isinstance(meta_dict, Mapping): msg = "meta_dict not available, EnsureChannelFirst is not in use." if self.strict_check: diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 201fe2fd5b..03d26b4b77 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -10,6 +10,7 @@ # limitations under the License. import os +import shutil import tempfile import unittest from pathlib import Path @@ -17,11 +18,15 @@ import itk import nibabel as nib import numpy as np +import torch from parameterized import parameterized from PIL import Image from monai.data import ITKReader, NibabelReader +from monai.data.meta_obj import set_track_meta +from monai.data.meta_tensor import MetaTensor from monai.transforms import LoadImage +from tests.utils import assert_allclose class _MiniReader: @@ -40,75 +45,57 @@ def get_data(self, _obj): return np.zeros((1, 1, 1)), {"name": "my test"} -TEST_CASE_1 = [{"image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_1 = [{}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_2 = [{}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_3 = [ - {"image_only": True}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] +TEST_CASE_3 = [{}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128)] TEST_CASE_3_1 = [ # .mgz format - {"image_only": True, "reader": "nibabelreader"}, + {"reader": "nibabelreader"}, ["test_image.mgz", "test_image2.mgz", "test_image3.mgz"], (3, 128, 128, 128), ] -TEST_CASE_4 = [ - {"image_only": False}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - (3, 128, 128, 128), -] +TEST_CASE_4 = [{}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128)] TEST_CASE_4_1 = [ # additional parameter - {"image_only": False, "mmap": False}, + {"mmap": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128), ] -TEST_CASE_5 = [{"reader": NibabelReader(mmap=False), "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_6 = [{"reader": ITKReader(), "image_only": True}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_6 = [{"reader": ITKReader()}, ["test_image.nii.gz"], (128, 128, 128)] -TEST_CASE_7 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], (128, 128, 128)] +TEST_CASE_7 = [{"reader": ITKReader()}, ["test_image.nii.gz"], (128, 128, 128)] TEST_CASE_8 = [ - {"reader": ITKReader(), "image_only": True}, + {"reader": ITKReader()}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128), ] TEST_CASE_8_1 = [ - {"reader": ITKReader(channel_dim=0), "image_only": True}, + {"reader": ITKReader(channel_dim=0)}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (384, 128, 128), ] TEST_CASE_9 = [ - {"reader": ITKReader(), "image_only": False}, + {"reader": ITKReader()}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], (3, 128, 128, 128), ] -TEST_CASE_10 = [ - {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, - "tests/testing_data/CT_DICOM", - (16, 16, 4), - (16, 16, 4), -] +TEST_CASE_10 = [{"reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)] -TEST_CASE_11 = [ - {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC}, - "tests/testing_data/CT_DICOM", - (16, 16, 4), - (16, 16, 4), -] +TEST_CASE_11 = [{"reader": "ITKReader", "pixel_type": itk.UC}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)] TEST_CASE_12 = [ - {"image_only": False, "reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True}, + {"reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True}, "tests/testing_data/CT_DICOM", (16, 16, 4), (4, 16, 16), @@ -135,6 +122,12 @@ def get_data(self, _obj): ] +TESTS_META = [] +for track_meta in (True, False): + TESTS_META.append([{}, (128, 128, 128), track_meta]) + TESTS_META.append([{"reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta]) + + class TestLoadImage(unittest.TestCase): @parameterized.expand( [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_3_1, TEST_CASE_4, TEST_CASE_4_1, TEST_CASE_5] @@ -146,13 +139,9 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) result = LoadImage(**input_param)(filenames) - - if isinstance(result, tuple): - result, header = result - self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) - np.testing.assert_allclose(header["affine"], np.eye(4)) - np.testing.assert_allclose(header["original_affine"], np.eye(4)) + ext = "".join(Path(name).suffixes) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext)) + assert_allclose(result.affine, torch.eye(4)) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) @@ -164,24 +153,18 @@ def test_itk_reader(self, input_param, filenames, expected_shape): itk_np_view = itk.image_view_from_array(test_image) itk.imwrite(itk_np_view, filenames[i]) result = LoadImage(**input_param)(filenames) - - if isinstance(result, tuple): - result, header = result - self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) - np_diag = np.diag([-1, -1, 1, 1]) - np.testing.assert_allclose(header["affine"], np_diag) - np.testing.assert_allclose(header["original_affine"], np_diag) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) + diag = torch.as_tensor(np.diag([-1, -1, 1, 1])) + np.testing.assert_allclose(result.affine, diag) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12]) def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, expected_np_shape): - result, header = LoadImage(**input_param)(filenames) - self.assertTrue("affine" in header) - self.assertEqual(header["filename_or_obj"], f"{Path(filenames)}") - np.testing.assert_allclose( - header["affine"], - np.array( + result = LoadImage(**input_param)(filenames) + self.assertEqual(result.meta["filename_or_obj"], f"{Path(filenames)}") + assert_allclose( + result.affine, + torch.tensor( [ [-0.488281, 0.0, 0.0, 125.0], [0.0, -0.488281, 0.0, 128.100006], @@ -190,7 +173,6 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e ] ), ) - self.assertTupleEqual(tuple(header["spatial_shape"]), expected_shape) self.assertTupleEqual(result.shape, expected_np_shape) def test_itk_reader_multichannel(self): @@ -200,9 +182,7 @@ def test_itk_reader_multichannel(self): itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) for flag in (False, True): - result, header = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename)) - - self.assertTupleEqual(tuple(header["spatial_shape"]), (224, 256)) + result = LoadImage(reader=ITKReader(reverse_indexing=flag))(Path(filename)) test_image = test_image.transpose(1, 0, 2) np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0]) np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1]) @@ -215,12 +195,10 @@ def test_load_nifti_multichannel(self): itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) - itk_img, itk_header = LoadImage(reader=ITKReader())(Path(filename)) - self.assertTupleEqual(tuple(itk_header["spatial_shape"]), (16, 64, 31)) + itk_img = LoadImage(reader=ITKReader())(Path(filename)) self.assertTupleEqual(tuple(itk_img.shape), (16, 64, 31, 2)) - nib_image, nib_header = LoadImage(reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename)) - self.assertTupleEqual(tuple(nib_header["spatial_shape"]), (16, 64, 31)) + nib_image = LoadImage(reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename)) self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2)) np.testing.assert_allclose(itk_img, nib_image, atol=1e-3, rtol=1e-3) @@ -231,8 +209,7 @@ def test_load_png(self): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) - result, header = LoadImage(image_only=False)(filename) - self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size[::-1]) + result = LoadImage()(filename) self.assertTupleEqual(result.shape, spatial_size[::-1]) np.testing.assert_allclose(result.T, test_image) @@ -244,10 +221,9 @@ def test_register(self): itk_np_view = itk.image_view_from_array(test_image) itk.imwrite(itk_np_view, filename) - loader = LoadImage(image_only=False) + loader = LoadImage() loader.register(ITKReader()) - result, header = loader(filename) - self.assertTupleEqual(tuple(header["spatial_shape"]), spatial_size[::-1]) + result = loader(filename) self.assertTupleEqual(result.shape, spatial_size[::-1]) def test_kwargs(self): @@ -258,35 +234,35 @@ def test_kwargs(self): itk_np_view = itk.image_view_from_array(test_image) itk.imwrite(itk_np_view, filename) - loader = LoadImage(image_only=False) + loader = LoadImage() reader = ITKReader(fallback_only=False) loader.register(reader) - result, header = loader(filename) + result = loader(filename) reader = ITKReader() img = reader.read(filename, fallback_only=False) - result_raw, header_raw = reader.get_data(img) - np.testing.assert_allclose(header["spatial_shape"], header_raw["spatial_shape"]) + result_raw = reader.get_data(img) + result_raw = LoadImage.join_im_and_meta(*result_raw) self.assertTupleEqual(result.shape, result_raw.shape) def test_my_reader(self): """test customised readers""" out = LoadImage(reader=_MiniReader, is_compatible=True)("test") - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") out = LoadImage(reader=_MiniReader, is_compatible=False)("test") - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") for item in (_MiniReader, _MiniReader(is_compatible=False)): out = LoadImage(reader=item)("test") - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") out = LoadImage()("test", reader=_MiniReader(is_compatible=False)) - self.assertEqual(out[1]["name"], "my test") + self.assertEqual(out.meta["name"], "my test") def test_itk_meta(self): """test metadata from a directory""" - out, meta = LoadImage(reader="ITKReader", pixel_type=itk.UC, series_meta=True)("tests/testing_data/CT_DICOM") + out = LoadImage(reader="ITKReader", pixel_type=itk.UC, series_meta=True)("tests/testing_data/CT_DICOM") idx = "0008|103e" label = itk.GDCMImageIO.GetLabelFromTag(idx, "")[1] - val = meta[idx] + val = out.meta[idx] expected = "Series Description=Routine Brain " self.assertEqual(f"{label}={val}", expected) @@ -299,10 +275,38 @@ def test_channel_dim(self, input_param, filename, expected_shape): result = LoadImage(**input_param)(filename) self.assertTupleEqual( - result[0].shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape + result.shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape ) - self.assertTupleEqual(tuple(result[1]["spatial_shape"]), (128, 128, 128)) - self.assertEqual(result[1]["original_channel_dim"], input_param["channel_dim"]) + self.assertEqual(result.meta["original_channel_dim"], input_param["channel_dim"]) + + +class TestLoadImageMeta(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(__class__, cls).setUpClass() + cls.tmpdir = tempfile.mkdtemp() + test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4)) + nib.save(test_image, os.path.join(cls.tmpdir, "im.nii.gz")) + cls.test_data = os.path.join(cls.tmpdir, "im.nii.gz") + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir) + super(__class__, cls).tearDownClass() + + @parameterized.expand(TESTS_META) + def test_correct(self, input_param, expected_shape, track_meta): + set_track_meta(track_meta) + r = LoadImage(**input_param)(self.test_data) + self.assertTupleEqual(r.shape, expected_shape) + if track_meta: + self.assertIsInstance(r, MetaTensor) + self.assertTrue(hasattr(r, "affine")) + self.assertIsInstance(r.affine, torch.Tensor) + else: + self.assertIsInstance(r, torch.Tensor) + self.assertNotIsInstance(r, MetaTensor) + self.assertFalse(hasattr(r, "affine")) if __name__ == "__main__": diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index bc001cf2fd..af7886b63b 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -10,6 +10,7 @@ # limitations under the License. import os +import shutil import tempfile import unittest from pathlib import Path @@ -17,11 +18,16 @@ import itk import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data import ITKReader -from monai.transforms import Compose, EnsureChannelFirstD, LoadImaged, SaveImageD +from monai.data.meta_obj import set_track_meta +from monai.data.meta_tensor import MetaTensor +from monai.transforms import Compose, EnsureChannelFirstD, FromMetaTensord, LoadImaged, SaveImageD +from monai.transforms.meta_utility.dictionary import ToMetaTensord from monai.utils.enums import PostFix +from tests.utils import assert_allclose KEYS = ["image", "label", "extra"] @@ -29,6 +35,11 @@ TEST_CASE_2 = [{"keys": KEYS, "reader": "ITKReader", "fallback_only": False}, (128, 128, 128)] +TESTS_META = [] +for track_meta in (True, False): + TESTS_META.append([{"keys": KEYS}, (128, 128, 128), track_meta]) + TESTS_META.append([{"keys": KEYS, "reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta]) + class TestLoadImaged(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @@ -55,7 +66,6 @@ def test_register(self): loader = LoadImaged(keys="img") loader.register(ITKReader()) result = loader({"img": Path(filename)}) - self.assertTupleEqual(tuple(result[PostFix.meta("img")]["spatial_shape"]), spatial_size[::-1]) self.assertTupleEqual(result["img"].shape, spatial_size[::-1]) def test_channel_dim(self): @@ -67,8 +77,8 @@ def test_channel_dim(self): loader = LoadImaged(keys="img") loader.register(ITKReader(channel_dim=2)) - result = EnsureChannelFirstD("img")(loader({"img": filename})) - self.assertTupleEqual(tuple(result[PostFix.meta("img")]["spatial_shape"]), (32, 64, 128)) + t = Compose([FromMetaTensord("img"), EnsureChannelFirstD("img")]) + result = t(loader({"img": filename})) self.assertTupleEqual(result["img"].shape, (3, 32, 64, 128)) def test_no_file(self): @@ -79,49 +89,57 @@ def test_no_file(self): class TestConsistency(unittest.TestCase): - def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): + def _cmp(self, filename, ch_shape, reader_1, reader_2, outname, ext): data_dict = {"img": filename} keys = data_dict.keys() xforms = Compose([LoadImaged(keys, reader=reader_1, ensure_channel_first=True)]) img_dict = xforms(data_dict) # load dicom with itk self.assertTupleEqual(img_dict["img"].shape, ch_shape) - self.assertTupleEqual(tuple(img_dict[PostFix.meta("img")]["spatial_shape"]), shape) with tempfile.TemporaryDirectory() as tempdir: - save_xform = SaveImageD( - keys, meta_keys=PostFix.meta("img"), output_dir=tempdir, squeeze_end_dims=False, output_ext=ext + save_xform = Compose( + [ + FromMetaTensord(keys), + SaveImageD( + keys, meta_keys=PostFix.meta("img"), output_dir=tempdir, squeeze_end_dims=False, output_ext=ext + ), + ] ) save_xform(img_dict) # save to nifti - new_xforms = Compose([LoadImaged(keys, reader=reader_2), EnsureChannelFirstD(keys)]) + new_xforms = Compose( + [ + LoadImaged(keys, reader=reader_2), + FromMetaTensord(keys), + EnsureChannelFirstD(keys), + ToMetaTensord(keys), + ] + ) out = new_xforms({"img": os.path.join(tempdir, outname)}) # load nifti with itk self.assertTupleEqual(out["img"].shape, ch_shape) - self.assertTupleEqual(tuple(out[PostFix.meta("img")]["spatial_shape"]), shape) - if "affine" in img_dict[PostFix.meta("img")] and "affine" in out[PostFix.meta("img")]: - np.testing.assert_allclose( - img_dict[PostFix.meta("img")]["affine"], out[PostFix.meta("img")]["affine"], rtol=1e-3 - ) - np.testing.assert_allclose(out["img"], img_dict["img"], rtol=1e-3) + + def is_identity(x): + return (x == torch.eye(x.shape[0])).all() + + if not is_identity(img_dict["img"].affine) and not is_identity(out["img"].affine): + assert_allclose(img_dict["img"].affine, out["img"].affine, rtol=1e-3) + assert_allclose(out["img"], img_dict["img"], rtol=1e-3) def test_dicom(self): img_dir = "tests/testing_data/CT_DICOM" - self._cmp( - img_dir, (16, 16, 4), (1, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz" - ) + self._cmp(img_dir, (1, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz") output_name = "CT_DICOM/CT_DICOM_trans.nii.gz" - self._cmp(img_dir, (16, 16, 4), (1, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") - self._cmp(img_dir, (16, 16, 4), (1, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") + self._cmp(img_dir, (1, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") + self._cmp(img_dir, (1, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") def test_multi_dicom(self): """multichannel dicom reading, saving to nifti, then load with itk or nibabel""" img_dir = ["tests/testing_data/CT_DICOM", "tests/testing_data/CT_DICOM"] - self._cmp( - img_dir, (16, 16, 4), (2, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz" - ) + self._cmp(img_dir, (2, 16, 16, 4), "itkreader", "itkreader", "CT_DICOM/CT_DICOM_trans.nii.gz", ".nii.gz") output_name = "CT_DICOM/CT_DICOM_trans.nii.gz" - self._cmp(img_dir, (16, 16, 4), (2, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") - self._cmp(img_dir, (16, 16, 4), (2, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") + self._cmp(img_dir, (2, 16, 16, 4), "nibabelreader", "itkreader", output_name, ".nii.gz") + self._cmp(img_dir, (2, 16, 16, 4), "itkreader", "nibabelreader", output_name, ".nii.gz") def test_png(self): """png reading with itk, saving to nifti, then load with itk or nibabel or PIL""" @@ -132,9 +150,45 @@ def test_png(self): itk_np_view = itk.image_view_from_array(test_image, is_vector=True) itk.imwrite(itk_np_view, filename) output_name = "test_image/test_image_trans.png" - self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "itkreader", output_name, ".png") - self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "PILReader", output_name, ".png") - self._cmp(filename, (224, 256), (3, 224, 256), "itkreader", "nibabelreader", output_name, ".png") + self._cmp(filename, (3, 224, 256), "itkreader", "itkreader", output_name, ".png") + self._cmp(filename, (3, 224, 256), "itkreader", "PILReader", output_name, ".png") + self._cmp(filename, (3, 224, 256), "itkreader", "nibabelreader", output_name, ".png") + + +class TestLoadImagedMeta(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(__class__, cls).setUpClass() + cls.tmpdir = tempfile.mkdtemp() + test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4)) + cls.test_data = {} + for key in KEYS: + nib.save(test_image, os.path.join(cls.tmpdir, key + ".nii.gz")) + cls.test_data.update({key: os.path.join(cls.tmpdir, key + ".nii.gz")}) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir) + super(__class__, cls).tearDownClass() + + @parameterized.expand(TESTS_META) + def test_correct(self, input_param, expected_shape, track_meta): + set_track_meta(track_meta) + result = LoadImaged(**input_param)(self.test_data) + + # shouldn't have any extra meta data keys + self.assertEqual(len(result), len(KEYS)) + for key in KEYS: + r = result[key] + self.assertTupleEqual(r.shape, expected_shape) + if track_meta: + self.assertIsInstance(r, MetaTensor) + self.assertTrue(hasattr(r, "affine")) + self.assertIsInstance(r.affine, torch.Tensor) + else: + self.assertIsInstance(r, torch.Tensor) + self.assertNotIsInstance(r, MetaTensor) + self.assertFalse(hasattr(r, "affine")) if __name__ == "__main__": diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index c18ef08b85..05356fcc84 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -21,11 +21,13 @@ import torch from parameterized import parameterized +from monai.data import DataLoader, Dataset from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix from monai.utils.module import pytorch_after -from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda +from tests.utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] TESTS = [] @@ -59,6 +61,17 @@ def check_ids(self, a, b, should_match): comp = self.assertEqual if should_match else self.assertNotEqual comp(id(a), id(b)) + def check_meta(self, a: MetaTensor, b: MetaTensor) -> None: + self.assertEqual(a.is_batch, b.is_batch) + meta_a, meta_b = a.meta, b.meta + # need to split affine from rest of metadata + aff_a = meta_a.get("affine", None) + aff_b = meta_b.get("affine", None) + assert_allclose(aff_a, aff_b) + meta_a = {k: v for k, v in meta_a.items() if k != "affine"} + meta_b = {k: v for k, v in meta_b.items() if k != "affine"} + self.assertEqual(meta_a, meta_b) + def check( self, out: torch.Tensor, @@ -87,12 +100,7 @@ def check( # check meta and affine are equal and affine is on correct device if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: - orig_meta_no_affine = deepcopy(orig.meta) - del orig_meta_no_affine["affine"] - out_meta_no_affine = deepcopy(out.meta) - del out_meta_no_affine["affine"] - self.assertEqual(orig_meta_no_affine, out_meta_no_affine) - assert_allclose(out.affine, orig.affine) + self.check_meta(orig, out) self.assertTrue(str(device) in str(out.affine.device)) if check_ids: self.check_ids(out.affine, orig.affine, ids) @@ -261,12 +269,146 @@ def test_amp(self): im_conv2 = conv(im) self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) - # TODO - # collate - # decollate - # dataset - # dataloader - # matplotlib + def test_out(self): + """Test when `out` is given as an argument.""" + m1, _ = self.get_im() + m1_orig = deepcopy(m1) + m2, _ = self.get_im() + m3, _ = self.get_im() + torch.add(m2, m3, out=m1) + m1_add = m2 + m3 + + assert_allclose(m1, m1_add) + self.check_meta(m1, m1_orig) + + @parameterized.expand(TESTS) + def test_collate(self, device, dtype): + numel = 3 + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)] + collated = list_data_collate(ims) + # tensor + self.assertIsInstance(collated, MetaTensor) + expected_shape = (numel,) + tuple(ims[0].shape) + self.assertTupleEqual(tuple(collated.shape), expected_shape) + for i, im in enumerate(ims): + self.check(im, ims[i], ids=True) + # affine + self.assertIsInstance(collated.affine, torch.Tensor) + expected_shape = (numel,) + tuple(ims[0].affine.shape) + self.assertTupleEqual(tuple(collated.affine.shape), expected_shape) + + @parameterized.expand(TESTS) + def test_dataset(self, device, dtype): + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(4)] + ds = Dataset(ims) + for i, im in enumerate(ds): + self.check(im, ims[i], ids=True) + + @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) + def test_dataloader(self, dtype): + batch_size = 5 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + im_shape = tuple(ims[0].shape) + affine_shape = tuple(ims[0].affine.shape) + expected_im_shape = (batch_size,) + im_shape + expected_affine_shape = (batch_size,) + affine_shape + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + for batch in dl: + self.assertIsInstance(batch, MetaTensor) + self.assertTupleEqual(tuple(batch.shape), expected_im_shape) + self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + + @SkipIfBeforePyTorchVersion((1, 9)) + def test_indexing(self): + """ + Check the metadata is returned in the expected format depending on whether + the input `MetaTensor` is a batch of data or not. + """ + ims = [self.get_im()[0] for _ in range(5)] + data = list_data_collate(ims) + + # check that when using non-batch data, metadata is copied wholly when indexing + # or iterating across data. + im = ims[0] + self.check_meta(im[0], im) + self.check_meta(next(iter(im)), im) + + # index + d = data[0] + self.check(d, ims[0], ids=False) + + # iter + d = next(iter(data)) + self.check(d, ims[0], ids=False) + + # complex indexing + + # `is_batch==True`, should have subset of image and metadata. + d = data[1:3] + self.check(d, list_data_collate(ims[1:3]), ids=False) + + # is_batch==True, should have subset of image and same metadata as `[1:3]`. + d = data[1:3, 0] + self.check(d, list_data_collate([i[0] for i in ims[1:3]]), ids=False) + + # `is_batch==False`, should have first metadata and subset of first image. + d = data[0, 0] + self.check(d, ims[0][0], ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[:, 0] + self.check(d, list_data_collate([i[0] for i in ims]), ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[..., -1] + self.check(d, list_data_collate([i[..., -1] for i in ims]), ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(dim=0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(dim=-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + + @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) + def test_decollate(self, dtype): + batch_size = 3 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + batch = next(iter(dl)) + decollated = decollate_batch(batch) + self.assertIsInstance(decollated, list) + self.assertEqual(len(decollated), batch_size) + for elem, im in zip(decollated, ims): + self.assertIsInstance(elem, MetaTensor) + self.check(elem, im, ids=False) if __name__ == "__main__": diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py index c2f3679e33..bb7686f67d 100644 --- a/tests/test_numpy_reader.py +++ b/tests/test_numpy_reader.py @@ -19,7 +19,6 @@ from monai.data import DataLoader, Dataset, NumpyReader from monai.transforms import LoadImaged -from monai.utils.enums import PostFix class TestNumpyReader(unittest.TestCase): @@ -110,8 +109,6 @@ def test_dataloader(self): num_workers=num_workers, ) for d in loader: - for s in d[PostFix.meta("image")]["spatial_shape"]: - torch.testing.assert_allclose(s, torch.as_tensor([3, 4, 5])) for c in d["image"]: torch.testing.assert_allclose(c, test_data) From b227fdd04657a74f4cf48f353f6a5932813a6c9b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 16:17:32 +0100 Subject: [PATCH 13/27] splitdims fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 7 ++++++- tests/test_splitdimd.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 564b2993e7..c7d249a246 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -425,7 +425,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[split_meta_key] = deepcopy(orig_meta) dim = self.splitter.dim if dim > 0: # don't update affine if channel dim - shift = np.eye(len(d[split_meta_key]["affine"])) # type: ignore + affine = d[split_meta_key]["affine"] + ndim = len(affine) + if isinstance(affine, torch.Tensor): + shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype) + else: + shift = np.eye(ndim) shift[dim - 1, -1] = i # type: ignore d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index 6b164a3cb8..f204ec277d 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -13,9 +13,10 @@ from copy import deepcopy import numpy as np +import torch from parameterized import parameterized -from monai.transforms import LoadImaged +from monai.transforms import Compose, FromMetaTensord, LoadImaged from monai.transforms.utility.dictionary import SplitDimd from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine @@ -33,7 +34,8 @@ def setUpClass(cls): affine = make_rand_affine() data = {"i": make_nifti_image(arr, affine)} - cls.data = LoadImaged("i")(data) + loader = Compose([LoadImaged("i"), FromMetaTensord("i")]) + cls.data = loader(data) @parameterized.expand(TESTS) def test_correct(self, keepdim, im_type, update_meta): @@ -54,8 +56,10 @@ def test_correct(self, keepdim, im_type, update_meta): split_idx = deepcopy(idx) split_idx[dim] = 0 # idx[1:] to remove channel and then add 1 for 4th element - real_world = data["i_meta_dict"]["affine"] @ (idx[1:] + [1]) - real_world2 = out[f"i_{split_im_idx}_meta_dict"]["affine"] @ (split_idx[1:] + [1]) + real_world = data["i_meta_dict"]["affine"] @ torch.tensor(idx[1:] + [1]).double() + real_world2 = ( + out[f"i_{split_im_idx}_meta_dict"]["affine"] @ torch.tensor(split_idx[1:] + [1]).double() + ) assert_allclose(real_world, real_world2) out = out["i_0"] From 603cdb55cf65a5a248626f7ebb175f8f6d812fb6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 16:52:43 +0100 Subject: [PATCH 14/27] flake8 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index c7d249a246..2143e729a9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -425,12 +425,13 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[split_meta_key] = deepcopy(orig_meta) dim = self.splitter.dim if dim > 0: # don't update affine if channel dim - affine = d[split_meta_key]["affine"] + affine = d[split_meta_key]["affine"] # type: ignore ndim = len(affine) - if isinstance(affine, torch.Tensor): - shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype) - else: - shift = np.eye(ndim) + shift = ( + torch.eye(ndim, device=affine.device, dtype=affine.dtype) + if isinstance(affine, torch.Tensor) + else np.eye(ndim) + ) shift[dim - 1, -1] = i # type: ignore d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore From 91eff3a7ed327e98c54d0cf93f1ac2b96c0cfea0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 11:43:11 +0100 Subject: [PATCH 15/27] fix test_nifti_rw Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/io/array.py | 8 ++++---- tests/test_nifti_rw.py | 12 +++++++----- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 03f87f1671..c99e59b965 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -263,15 +263,15 @@ def join_im_and_meta(im, meta: dict): if "affine" in meta: meta["affine"] = torch.as_tensor(meta["affine"]) - # delete extra metadata + # TODO: delete extra metadata for i in range(8): for k in ("dim", "pixdim"): if f"{k}[{i}]" in meta: del meta[f"{k}[{i}]"] for k in ( - "original_affine", - "spatial_shape", - "spacing", + # "original_affine", + # "spatial_shape", + # "spacing", "srow_x", "srow_y", "srow_z", diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index 2c0a8dc9a3..f9a987052b 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -85,11 +85,12 @@ def test_orientation(self, array, affine, reader_param, expected): # read test cases loader = LoadImage(**reader_param) load_result = loader(test_image) - if isinstance(load_result, tuple): - data_array, header = load_result - else: - data_array = load_result + data_array = load_result.numpy() + if reader_param.get("image_only", False): header = None + else: + header = load_result.meta + header["affine"] = header["affine"].numpy() if os.path.exists(test_image): os.remove(test_image) @@ -114,7 +115,8 @@ def test_orientation(self, array, affine, reader_param, expected): def test_consistency(self): np.set_printoptions(suppress=True, precision=3) test_image = make_nifti_image(np.arange(64).reshape(1, 8, 8), np.diag([1.5, 1.5, 1.5, 1])) - data, header = LoadImage(reader="NibabelReader", as_closest_canonical=False)(test_image) + data = LoadImage(reader="NibabelReader", as_closest_canonical=False)(test_image) + header = data.meta data, original_affine, new_affine = Spacing([0.8, 0.8, 0.8])(data[None], header["affine"], mode="nearest") data, _, new_affine = Orientation("ILP")(data, new_affine) if os.path.exists(test_image): From 0f548c60f86483f4ae3d217b601a38becf2f3f68 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 12:01:00 +0100 Subject: [PATCH 16/27] test_smartcachedataset Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_smartcachedataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index e7d51be63a..d3d71145e9 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -17,10 +17,12 @@ import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data import DataLoader, SmartCacheDataset from monai.transforms import Compose, Lambda, LoadImaged +from tests.utils import assert_allclose TEST_CASE_1 = [0.1, 0, Compose([LoadImaged(keys=["image", "label", "extra"])])] @@ -66,8 +68,8 @@ def test_shape(self, replace_rate, num_replace_workers, transform): for _ in range(3): dataset.update_cache() self.assertIsNotNone(dataset[15]) - if isinstance(dataset[15]["image"], np.ndarray): - np.testing.assert_allclose(dataset[15]["image"], dataset[15]["label"]) + if isinstance(dataset[15]["image"], (np.ndarray, torch.Tensor)): + assert_allclose(dataset[15]["image"], dataset[15]["label"]) else: self.assertIsInstance(dataset[15]["image"], str) dataset.shutdown() From 7e3b50c0b0cb8e868471beeb5a3768b430bcc74c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 12:43:47 +0100 Subject: [PATCH 17/27] test fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_load_spacing_orientation.py | 68 +++++++++++++++++++++----- tests/test_wsireader.py | 3 +- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 2792822c3d..b98a8c8627 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -18,7 +18,7 @@ from nibabel.processing import resample_to_output from parameterized import parameterized -from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd +from monai.transforms import AddChanneld, Compose, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd from monai.utils.enums import PostFix FILES = tuple( @@ -31,8 +31,15 @@ class TestLoadSpacingOrientation(unittest.TestCase): @parameterized.expand(FILES) def test_load_spacingd(self, filename): data = {"image": filename} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) t = time.time() res_dict = Spacingd(keys="image", pixdim=(1, 0.2, 1), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() @@ -49,8 +56,15 @@ def test_load_spacingd(self, filename): @parameterized.expand(FILES) def test_load_spacingd_rotate(self, filename): data = {"image": filename} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) affine = data_dict[PostFix.meta("image")]["affine"] data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine @@ -75,8 +89,15 @@ def test_load_spacingd_rotate(self, filename): def test_load_spacingd_non_diag(self): data = {"image": FILES[1]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) affine = data_dict[PostFix.meta("image")]["affine"] data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine @@ -96,8 +117,15 @@ def test_load_spacingd_non_diag(self): def test_load_spacingd_rotate_non_diag(self): data = {"image": FILES[0]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border")(data_dict) np.testing.assert_allclose( res_dict[PostFix.meta("image")]["affine"], @@ -106,8 +134,15 @@ def test_load_spacingd_rotate_non_diag(self): def test_load_spacingd_rotate_non_diag_ornt(self): data = {"image": FILES[0]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=False, padding_mode="border")(data_dict) res_dict = Orientationd(keys="image", axcodes="LPI")(res_dict) np.testing.assert_allclose( @@ -117,8 +152,15 @@ def test_load_spacingd_rotate_non_diag_ornt(self): def test_load_spacingd_non_diag_ornt(self): data = {"image": FILES[1]} - data_dict = LoadImaged(keys="image")(data) - data_dict = AddChanneld(keys="image")(data_dict) + t = Compose( + [ + LoadImaged(keys="image"), + FromMetaTensord(keys="image"), + AddChanneld(keys="image"), + ToNumpyd(keys=["image", "image_meta_dict"]), + ] + ) + data_dict = t(data) affine = data_dict[PostFix.meta("image")]["affine"] data_dict[PostFix.meta("image")]["original_affine"] = data_dict[PostFix.meta("image")]["affine"] = ( np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]]) @ affine diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py index 6ee02143b8..2450cb5b11 100644 --- a/tests/test_wsireader.py +++ b/tests/test_wsireader.py @@ -20,7 +20,7 @@ from monai.data import DataLoader, Dataset from monai.data.image_reader import WSIReader -from monai.transforms import Compose, LoadImaged, ToTensord +from monai.transforms import Compose, FromMetaTensord, LoadImaged, ToTensord from monai.utils import first, optional_import from monai.utils.enums import PostFix from tests.utils import download_url_or_skip_test, testing_data_config @@ -203,6 +203,7 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte train_transform = Compose( [ LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), + FromMetaTensord(keys=["image"]), ToTensord(keys=["image"]), ] ) From 549c2eed81e1a33fa1b3e4fb1d877196e01482ee Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 15:13:09 +0100 Subject: [PATCH 18/27] test fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/deepgrow/dataset.py | 5 ++++- monai/data/image_dataset.py | 15 +++++++-------- tests/test_arraydataset.py | 17 +++++++++-------- tests/test_decollate.py | 12 ++---------- tests/test_ensure_channel_first.py | 29 +++++++++++++---------------- tests/test_ensure_channel_firstd.py | 4 +++- tests/test_image_dataset.py | 3 ++- tests/test_image_rw.py | 11 +++++++---- tests/test_invertd.py | 4 +++- tests/test_nifti_saver.py | 5 +++-- tests/test_resample_to_match.py | 10 ++++++++-- tests/test_resample_to_matchd.py | 2 ++ 12 files changed, 63 insertions(+), 54 deletions(-) diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index 721781196b..c79434038b 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -15,8 +15,9 @@ import numpy as np -from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd +from monai.transforms import AsChannelFirstd, Compose, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd from monai.utils import GridSampleMode +from monai.utils.enums import PostFix def create_dataset( @@ -125,6 +126,8 @@ def _default_transforms(image_key, label_key, pixdim): return Compose( [ LoadImaged(keys=keys), + FromMetaTensord(keys=keys), + ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]), AsChannelFirstd(keys=keys), Orientationd(keys=keys, axcodes="RAS"), Spacingd(keys=keys, pixdim=pixdim, mode=mode), diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 51f4e04959..c2a3e32d1e 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -86,7 +86,7 @@ def __init__( raise ValueError("transform_with_metadata=True requires image_only=False.") self.image_only = image_only self.transform_with_metadata = transform_with_metadata - self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs) + self.loader = LoadImage(reader, dtype, *args, **kwargs) self.set_random_state(seed=get_seed()) self._seed = 0 # transform synchronization seed @@ -101,14 +101,13 @@ def __getitem__(self, index: int): meta_data, seg_meta_data, seg, label = None, None, None, None # load data and optionally meta - if self.image_only: - img = self.loader(self.image_files[index]) + img = self.loader(self.image_files[index]) + if self.seg_files is not None: + seg = self.loader(self.seg_files[index]) + if not self.image_only: + meta_data = img.meta if self.seg_files is not None: - seg = self.loader(self.seg_files[index]) - else: - img, meta_data = self.loader(self.image_files[index]) - if self.seg_files is not None: - seg, seg_meta_data = self.loader(self.seg_files[index]) + seg_meta_data = seg.meta # apply the transforms if self.transform is not None: diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index ee1a92cf97..689fc0cb3d 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -23,15 +23,15 @@ from monai.transforms import AddChannel, Compose, LoadImage, RandAdjustContrast, RandGaussianNoise, Spacing TEST_CASE_1 = [ - Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), - Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), (0, 1), (1, 128, 128, 128), ] TEST_CASE_2 = [ - Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]), - Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]), + Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]), (0, 1), (1, 128, 128, 128), ] @@ -39,20 +39,21 @@ class TestCompose(Compose): def __call__(self, input_): - img, metadata = self.transforms[0](input_) + img = self.transforms[0](input_) + metadata = img.meta img = self.transforms[1](img) img, _, _ = self.transforms[2](img, metadata["affine"]) return self.transforms[3](img), metadata TEST_CASE_3 = [ - TestCompose([LoadImage(image_only=False), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), - TestCompose([LoadImage(image_only=False), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), + TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), + TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]), (0, 2), (1, 64, 64, 33), ] -TEST_CASE_4 = [Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)] +TEST_CASE_4 = [Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)] class TestArrayDataset(unittest.TestCase): diff --git a/tests/test_decollate.py b/tests/test_decollate.py index adeaa73337..60e3eb37e3 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -79,14 +79,6 @@ ] -class _ListCompose(Compose): - def __call__(self, input_): - img, metadata = self.transforms[0](input_) - for t in self.transforms[1:]: - img = t(img) - return img, metadata - - class TestDeCollate(unittest.TestCase): def setUp(self) -> None: set_determinism(seed=0) @@ -148,7 +140,7 @@ def test_decollation_tensor(self, *transforms): t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()]) # If nibabel present, read from disk if has_nib: - t_compose = Compose([LoadImage(image_only=True), t_compose]) + t_compose = Compose([LoadImage(), t_compose]) dataset = Dataset(self.data_list, t_compose) self.check_decollate(dataset=dataset) @@ -158,7 +150,7 @@ def test_decollation_list(self, *transforms): t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()]) # If nibabel present, read from disk if has_nib: - t_compose = _ListCompose([LoadImage(image_only=False), t_compose]) + t_compose = Compose([LoadImage(), t_compose]) dataset = Dataset(self.data_list, t_compose) self.check_decollate(dataset=dataset) diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py index dd6168ec75..4f22a16f11 100644 --- a/tests/test_ensure_channel_first.py +++ b/tests/test_ensure_channel_first.py @@ -23,23 +23,19 @@ from monai.transforms import EnsureChannelFirst, LoadImage from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] +TEST_CASE_1 = [{}, ["test_image.nii.gz"], None] -TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1] +TEST_CASE_2 = [{}, ["test_image.nii.gz"], -1] -TEST_CASE_3 = [{"image_only": False}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] +TEST_CASE_3 = [{}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] -TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None] +TEST_CASE_4 = [{"reader": ITKReader()}, ["test_image.nii.gz"], None] -TEST_CASE_5 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], -1] +TEST_CASE_5 = [{"reader": ITKReader()}, ["test_image.nii.gz"], -1] -TEST_CASE_6 = [ - {"reader": ITKReader(), "image_only": False}, - ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], - None, -] +TEST_CASE_6 = [{"reader": ITKReader()}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None] -TEST_CASE_7 = [{"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", None] +TEST_CASE_7 = [{"reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", None] class TestEnsureChannelFirst(unittest.TestCase): @@ -55,14 +51,15 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): filenames[i] = os.path.join(tempdir, name) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) for p in TEST_NDARRAYS: - result, header = LoadImage(**input_param)(filenames) + result = LoadImage(**input_param)(filenames) + header = result.meta result = EnsureChannelFirst()(p(result), header) self.assertEqual(result.shape[0], len(filenames)) @parameterized.expand([TEST_CASE_7]) def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): - result, header = LoadImage(**input_param)(filenames) - result = EnsureChannelFirst()(result, header) + result = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(result) self.assertEqual(result.shape[0], 1) def test_load_png(self): @@ -71,8 +68,8 @@ def test_load_png(self): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) - result, header = LoadImage(image_only=False)(filename) - result = EnsureChannelFirst()(result, header) + result = LoadImage()(filename) + result = EnsureChannelFirst()(result) self.assertEqual(result.shape[0], 3) def test_check(self): diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py index 7f1a57a207..cb1694d4e9 100644 --- a/tests/test_ensure_channel_firstd.py +++ b/tests/test_ensure_channel_firstd.py @@ -18,7 +18,7 @@ from parameterized import parameterized from PIL import Image -from monai.transforms import EnsureChannelFirstd, LoadImaged +from monai.transforms import EnsureChannelFirstd, FromMetaTensord, LoadImaged from monai.utils.enums import PostFix from tests.utils import TEST_NDARRAYS @@ -43,6 +43,7 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim): nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) for p in TEST_NDARRAYS: result = LoadImaged(**input_param)({"img": filenames}) + result = FromMetaTensord("img")(result) result["img"] = p(result["img"]) result = EnsureChannelFirstd(**input_param)(result) self.assertEqual(result["img"].shape[0], len(filenames)) @@ -54,6 +55,7 @@ def test_load_png(self): filename = os.path.join(tempdir, "test_image.png") Image.fromarray(test_image.astype("uint8")).save(filename) result = LoadImaged(keys="img")({"img": filename}) + result = FromMetaTensord(keys="img")(result) result = EnsureChannelFirstd(keys="img")(result) self.assertEqual(result["img"].shape[0], 3) diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 41eda803dc..fae6cedff9 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -15,6 +15,7 @@ import nibabel as nib import numpy as np +import torch from monai.data import ImageDataset from monai.transforms import ( @@ -93,7 +94,7 @@ def test_dataset(self): # loading no meta, int dataset = ImageDataset(full_names, dtype=np.float16) for d, _ in zip(dataset, ref_data): - self.assertEqual(d.dtype, np.float16) + self.assertEqual(d.dtype, torch.float16) # loading with meta, no transform dataset = ImageDataset(full_names, image_only=False) diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py index 62b1147aa5..32f31a6af1 100644 --- a/tests/test_image_rw.py +++ b/tests/test_image_rw.py @@ -16,6 +16,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data.image_reader import ITKReader, NibabelReader, PILReader @@ -52,14 +53,15 @@ def nifti_rw(self, test_data, reader, writer, dtype, resample=True): saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) self.assertTrue(os.path.exists(saved_path)) loader = LoadImage(reader=reader, squeeze_non_spatial_dims=True) - data, meta = loader(saved_path) + data = loader(saved_path) + meta = data.meta if meta["original_channel_dim"] == -1: _test_data = moveaxis(test_data, 0, -1) else: _test_data = test_data[0] if resample: _test_data = moveaxis(_test_data, 0, 1) - assert_allclose(data, _test_data) + assert_allclose(data, torch.as_tensor(_test_data)) @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, "ITKWriter"])) def test_2d(self, reader, writer): @@ -99,12 +101,13 @@ def png_rw(self, test_data, reader, writer, dtype, resample=True): saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) self.assertTrue(os.path.exists(saved_path)) loader = LoadImage(reader=reader) - data, meta = loader(saved_path) + data = loader(saved_path) + meta = data.meta if meta["original_channel_dim"] == -1: _test_data = moveaxis(test_data, 0, -1) else: _test_data = test_data[0] - assert_allclose(data, _test_data) + assert_allclose(data, torch.as_tensor(_test_data)) @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter])) def test_2d(self, reader, writer): diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 64c26c4012..183689113a 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -22,6 +22,7 @@ Compose, CopyItemsd, EnsureTyped, + FromMetaTensord, Invertd, LoadImaged, Orientationd, @@ -50,6 +51,7 @@ def test_invert(self): transform = Compose( [ LoadImaged(KEYS), + FromMetaTensord(KEYS), AddChanneld(KEYS), Orientationd(KEYS, "RPS"), Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), @@ -156,7 +158,7 @@ def test_invert(self): reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32) original = LoadImaged(KEYS)(data[-1])["label"] n_good = np.sum(np.isclose(reverted, original, atol=1e-3)) - reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"] + reverted_name = item["label_inverted"].meta["filename_or_obj"] original_name = data[-1]["label"] self.assertEqual(reverted_name, original_name) print("invert diff", reverted.size - n_good) diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index 6855a59041..45ca1e865b 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -80,7 +80,7 @@ def test_saved_3d_no_resize_content(self): saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) for i in range(8): filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - img, _ = LoadImage("nibabelreader")(filepath) + img = LoadImage("nibabelreader")(filepath) self.assertEqual(img.shape, (1, 2, 2, 8)) def test_squeeze_end_dims(self): @@ -102,7 +102,8 @@ def test_squeeze_end_dims(self): # 2d image w channel saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) - im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) + im = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) + meta = im.meta self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) self.assertTrue(meta["dim"][0] == im.ndim) diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index e1f6a28998..f6aa58c191 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -21,7 +21,7 @@ from monai.data.image_reader import ITKReader, NibabelReader from monai.data.image_writer import ITKWriter -from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged +from monai.transforms import Compose, EnsureChannelFirstd, FromMetaTensord, LoadImaged, ResampleToMatch, SaveImaged from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config TEST_CASES = ["itkreader", "nibabelreader"] @@ -41,7 +41,13 @@ def setUp(self): @parameterized.expand(itertools.product([NibabelReader, ITKReader], ["monai.data.NibabelWriter", ITKWriter])) def test_correct(self, reader, writer): with tempfile.TemporaryDirectory() as temp_dir: - loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))]) + loader = Compose( + [ + LoadImaged(("im1", "im2"), reader=reader), + FromMetaTensord(("im1", "im2")), + EnsureChannelFirstd(("im1", "im2")), + ] + ) data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"]) diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py index d9dbeee133..14536891df 100644 --- a/tests/test_resample_to_matchd.py +++ b/tests/test_resample_to_matchd.py @@ -17,6 +17,7 @@ Compose, CopyItemsd, EnsureChannelFirstd, + FromMetaTensord, Invertd, Lambda, LoadImaged, @@ -47,6 +48,7 @@ def test_correct(self): transforms = Compose( [ LoadImaged(("im1", "im2")), + FromMetaTensord(("im1", "im2")), EnsureChannelFirstd(("im1", "im2")), CopyItemsd(("im2", "im2_meta_dict"), names=("im3", "im3_meta_dict")), ResampleToMatchd("im3", "im1_meta_dict"), From 55f483ac8789492d38760cf740247928347eca5e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 16:16:00 +0100 Subject: [PATCH 19/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/image_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index c2a3e32d1e..61722c5490 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -102,11 +102,11 @@ def __getitem__(self, index: int): # load data and optionally meta img = self.loader(self.image_files[index]) - if self.seg_files is not None: - seg = self.loader(self.seg_files[index]) if not self.image_only: meta_data = img.meta - if self.seg_files is not None: + if self.seg_files is not None: + seg = self.loader(self.seg_files[index]) + if not self.image_only: seg_meta_data = seg.meta # apply the transforms From ef0150480bd6a25c47e4c68e5935eb7e8399b88e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 16:30:33 +0100 Subject: [PATCH 20/27] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_load_image.py | 2 +- tests/test_load_imaged.py | 2 +- tests/test_warp.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 03d26b4b77..1a96e8bbe8 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -123,7 +123,7 @@ def get_data(self, _obj): TESTS_META = [] -for track_meta in (True, False): +for track_meta in (False, True): TESTS_META.append([{}, (128, 128, 128), track_meta]) TESTS_META.append([{"reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta]) diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index af7886b63b..3b2fc4f58b 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -36,7 +36,7 @@ TEST_CASE_2 = [{"keys": KEYS, "reader": "ITKReader", "fallback_only": False}, (128, 128, 128)] TESTS_META = [] -for track_meta in (True, False): +for track_meta in (False, True): TESTS_META.append([{"keys": KEYS}, (128, 128, 128), track_meta]) TESTS_META.append([{"keys": KEYS, "reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta]) diff --git a/tests/test_warp.py b/tests/test_warp.py index c039b57211..56f1de23f2 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -153,6 +153,7 @@ def test_grad(self): def load_img_and_sample_ddf(): # load image img = LoadImaged(keys="img")({"img": FILE_PATH})["img"] + img = img.detach().numpy() # W, H, D -> D, H, W img = img.transpose((2, 1, 0)) From 752ed8b8100fd8b482ccaa51205a82c24849be4f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 17:26:45 +0100 Subject: [PATCH 21/27] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_dataset_summary.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/test_dataset_summary.py b/tests/test_dataset_summary.py index 51840f77ea..d0531b28a0 100644 --- a/tests/test_dataset_summary.py +++ b/tests/test_dataset_summary.py @@ -19,6 +19,9 @@ from monai.data import Dataset, DatasetSummary, create_test_image_3d from monai.transforms import LoadImaged +from monai.transforms.compose import Compose +from monai.transforms.meta_utility.dictionary import FromMetaTensord +from monai.transforms.utility.dictionary import ToNumpyd from monai.utils import set_determinism from monai.utils.enums import PostFix @@ -50,12 +53,17 @@ def test_spacing_intensity(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - dataset = Dataset( - data=data_dicts, transform=LoadImaged(keys=["image", "label"], meta_keys=["test1", "test2"]) + t = Compose( + [ + LoadImaged(keys=["image", "label"]), + FromMetaTensord(keys=["image", "label"]), + ToNumpyd(keys=["image", "label", "image_meta_dict", "label_meta_dict"]), + ] ) + dataset = Dataset(data=data_dicts, transform=t) # test **kwargs of `DatasetSummary` for `DataLoader` - calculator = DatasetSummary(dataset, num_workers=4, meta_key="test1", collate_fn=test_collate) + calculator = DatasetSummary(dataset, num_workers=4, meta_key="image_meta_dict", collate_fn=test_collate) target_spacing = calculator.get_target_spacing() self.assertEqual(target_spacing, (1.0, 1.0, 1.0)) @@ -85,7 +93,8 @@ def test_anisotropic_spacing(self): {"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels) ] - dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"])) + t = Compose([LoadImaged(keys=["image", "label"]), FromMetaTensord(keys=["image", "label"])]) + dataset = Dataset(data=data_dicts, transform=t) calculator = DatasetSummary(dataset, num_workers=4, meta_key_postfix=PostFix.meta()) From 76c64591785636c22a1d088cf5de9e46fc60800e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 21 Apr 2022 18:15:53 +0100 Subject: [PATCH 22/27] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/testing_data/inference.json | 8 ++++++++ tests/testing_data/inference.yaml | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index cc9ddef866..67f5b0e46f 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -34,6 +34,14 @@ "_target_": "LoadImaged", "keys": "image" }, + { + "_target_": "FromMetaTensord", + "keys": "image" + }, + { + "_target_": "ToNumpyd", + "keys": "image" + }, { "_target_": "EnsureChannelFirstd", "keys": "image" diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index 4973d4473f..0d29085a05 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -27,6 +27,10 @@ preprocessing: transforms: - _target_: LoadImaged keys: image + - _target_: FromMetaTensord + keys: image + - _target_: ToNumpyd + keys: image - _target_: EnsureChannelFirstd keys: image - _target_: ScaleIntensityd From ad1d780336f40b0c4f79d6197d82187b69b68658 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Apr 2022 10:15:03 +0100 Subject: [PATCH 23/27] fix wsi test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_wsireader_new.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py index 63d61dfeb3..69e147a9c3 100644 --- a/tests/test_wsireader_new.py +++ b/tests/test_wsireader_new.py @@ -20,7 +20,7 @@ from monai.data import DataLoader, Dataset from monai.data.wsi_reader import WSIReader -from monai.transforms import Compose, LoadImaged, ToTensord +from monai.transforms import Compose, FromMetaTensord, LoadImaged from monai.utils import first, optional_import from monai.utils.enums import PostFix from tests.utils import download_url_or_skip_test, testing_data_config @@ -196,7 +196,7 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte train_transform = Compose( [ LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level), - ToTensord(keys=["image"]), + FromMetaTensord(keys=["image"]), ] ) dataset = Dataset([{"image": file_path}], transform=train_transform) From 9e7b3d68022d4479de999f7d4bc55dbd521f2b13 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Apr 2022 11:28:49 +0100 Subject: [PATCH 24/27] changes after code review Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 32 +++++++++++++ monai/transforms/__init__.py | 3 ++ monai/transforms/io/array.py | 37 +-------------- monai/transforms/utility/dictionary.py | 10 ++-- monai/transforms/utils.py | 66 ++++++++++++++++++++++++++ tests/test_load_image.py | 2 +- 6 files changed, 108 insertions(+), 42 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 9196f0186c..4f011bc3ed 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -17,8 +17,10 @@ import torch +from monai.config.type_definitions import NdarrayTensor from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms from monai.data.utils import decollate_batch, list_data_collate +from monai.transforms.utils import remove_extra_metadata from monai.utils.enums import PostFix __all__ = ["MetaTensor"] @@ -232,3 +234,33 @@ def affine(self) -> torch.Tensor: def affine(self, d: torch.Tensor) -> None: """Set the affine.""" self.meta["affine"] = d + + @staticmethod + def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict): + """ + Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary, + convert that to `torch.Tensor`, too. Remove any superfluous metadata. + + Args: + im: Input image (`np.ndarray` or `torch.Tensor`) + meta: Metadata dictionary. + + Returns: + By default, a `MetaTensor` is returned. + However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned. + """ + img = torch.as_tensor(im) + + # if not tracking metadata, return `torch.Tensor` + if not get_track_meta() or meta is None: + return img + + # ensure affine is of type `torch.Tensor` + if "affine" in meta: + meta["affine"] = torch.as_tensor(meta["affine"]) + + # remove any superfluous metadata. + remove_extra_metadata(meta) + + # return the `MetaTensor` + return MetaTensor(img, meta=meta) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index c2385499b3..75f95f4d5b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -569,6 +569,7 @@ generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, + get_extra_metadata_keys, get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, @@ -582,6 +583,8 @@ map_spatial_axes, print_transform_backends, rand_choice, + remove_extra_metadata, + remove_keys, rescale_array, rescale_array_int_max, rescale_instance_array, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index c99e59b965..3c4e8d59dd 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -29,7 +29,6 @@ from monai.data import image_writer from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.transforms.transform import Transform from monai.transforms.utility.array import EnsureChannelFirst @@ -247,45 +246,11 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option meta_data = switch_endianness(meta_data, "<") meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader - img = self.join_im_and_meta(img_array, meta_data) + img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data) if self.ensure_channel_first: img = EnsureChannelFirst()(img) return img - @staticmethod - def join_im_and_meta(im, meta: dict): - img = torch.as_tensor(im) - - # if not tracking metadata, return torch.Tensor - if not get_track_meta() or meta is None: - return img - - if "affine" in meta: - meta["affine"] = torch.as_tensor(meta["affine"]) - - # TODO: delete extra metadata - for i in range(8): - for k in ("dim", "pixdim"): - if f"{k}[{i}]" in meta: - del meta[f"{k}[{i}]"] - for k in ( - # "original_affine", - # "spatial_shape", - # "spacing", - "srow_x", - "srow_y", - "srow_z", - "quatern_b", - "quatern_c", - "quatern_d", - "qoffset_x", - "qoffset_y", - "qoffset_z", - ): - if k in meta: - del meta[k] - return MetaTensor(img, meta=meta) - class SaveImage(Transform): """ diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 2143e729a9..1b10b6ee85 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -427,11 +427,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if dim > 0: # don't update affine if channel dim affine = d[split_meta_key]["affine"] # type: ignore ndim = len(affine) - shift = ( - torch.eye(ndim, device=affine.device, dtype=affine.dtype) - if isinstance(affine, torch.Tensor) - else np.eye(ndim) - ) + shift: NdarrayOrTensor + if isinstance(affine, torch.Tensor): + shift = torch.eye(ndim, device=affine.device, dtype=affine.dtype) + else: + shift = np.eye(ndim) shift[dim - 1, -1] = i # type: ignore d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 847614adfe..3be45b570d 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -105,6 +105,9 @@ "convert_pad_mode", "convert_to_contiguous", "get_unique_labels", + "remove_keys", + "remove_extra_metadata", + "get_extra_metadata_keys", ] @@ -1573,5 +1576,68 @@ def convert_to_contiguous(data, **kwargs): return data +def remove_keys(data: dict, keys: List[str]) -> None: + """ + Remove keys from a dictionary. Operates in-place so nothing is returned. + + Args: + data: dictionary to be modified. + keys: keys to be deleted from dictionary. + + Returns: + `None` + """ + for k in keys: + _ = data.pop(k, None) + + +def remove_extra_metadata(meta: dict) -> None: + """ + Remove extra metadata from the dictionary. Operates in-place so nothing is returned. + + Args: + meta: dictionary containing metadata to be modified. + + Returns: + `None` + """ + keys = get_extra_metadata_keys() + remove_keys(data=meta, keys=keys) + + +def get_extra_metadata_keys() -> List[str]: + """ + Get a list of unnecessary keys for metadata that can be removed. + + Returns: + List of keys to be removed. + """ + keys = [ + "srow_x", + "srow_y", + "srow_z", + "quatern_b", + "quatern_c", + "quatern_d", + "qoffset_x", + "qoffset_y", + "qoffset_z", + "dim", + "pixdim", + *[f"dim[{i}]" for i in range(8)], + *[f"pixdim[{i}]" for i in range(8)], + ] + + # TODO: it would be good to remove these, but they are currently being used in the + # codebase. + # keys += [ + # "original_affine", + # "spatial_shape", + # "spacing", + # ] + + return keys + + if __name__ == "__main__": print_transform_backends() diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 1a96e8bbe8..9509d26283 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -242,7 +242,7 @@ def test_kwargs(self): reader = ITKReader() img = reader.read(filename, fallback_only=False) result_raw = reader.get_data(img) - result_raw = LoadImage.join_im_and_meta(*result_raw) + result_raw = MetaTensor.ensure_torch_and_prune_meta(*result_raw) self.assertTupleEqual(result.shape, result_raw.shape) def test_my_reader(self): From ad58d7979b656fdfc04a9f57a09c622060633738 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 3 May 2022 10:31:49 +0100 Subject: [PATCH 25/27] fixes unit tests Signed-off-by: Wenqi Li --- monai/data/dataset_summary.py | 11 ++++++----- monai/data/utils.py | 2 ++ tests/test_nifti_saver.py | 2 -- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index b447585d3e..fe9618737b 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -18,9 +18,9 @@ from monai.config import KeysCollection from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset +from monai.data.utils import affine_to_spacing from monai.transforms import concatenate -from monai.utils import convert_data_type -from monai.utils.enums import PostFix +from monai.utils import PostFix, convert_data_type DEFAULT_POST_FIX = PostFix.meta() @@ -84,7 +84,7 @@ def collect_meta_data(self): raise ValueError(f"To collect meta data for the dataset, key `{self.meta_key}` must exist in `data`.") self.all_meta_data.append(data[self.meta_key]) - def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0): + def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0): """ Calculate the target spacing according to all spacings. If the target spacing is very anisotropic, @@ -93,7 +93,7 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. Args: - spacing_key: key of spacing in meta data (default: ``pixdim``). + spacing_key: key of the affine used to compute spacing in meta data (default: ``affine``). anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``). percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to replace that axis. @@ -103,7 +103,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: self.collect_meta_data() if spacing_key not in self.all_meta_data[0]: raise ValueError("The provided spacing_key is not in self.all_meta_data.") - all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0) + spacings = [affine_to_spacing(data[spacing_key][0], 3)[None] for data in self.all_meta_data] + all_spacings = concatenate(to_cat=spacings, axis=0) all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True) target_spacing = np.median(all_spacings, axis=0) diff --git a/monai/data/utils.py b/monai/data/utils.py index 2bd7b49731..df40ca3af3 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -581,6 +581,8 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z Returns: an `r` dimensional vector of spacing. """ + if len(affine.shape) != 2 or affine.shape[0] != affine.shape[1]: + raise ValueError(f"affine must be a square matrix, got {affine.shape}.") _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype) if isinstance(_affine, torch.Tensor): spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index 45ca1e865b..bd1bf86207 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -103,9 +103,7 @@ def test_squeeze_end_dims(self): saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) im = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) - meta = im.meta self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) - self.assertTrue(meta["dim"][0] == im.ndim) if __name__ == "__main__": From 2fb16f91bb61bcf2373f986352571af9f9760656 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 3 May 2022 14:46:35 +0100 Subject: [PATCH 26/27] fixes: TypeError: __str__ returned non-string (type list) Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 4 +--- tests/test_meta_tensor.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index e38e009e96..3c173dc7c4 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -189,9 +189,7 @@ def get_default_meta(self) -> dict: def __repr__(self) -> str: """String representation of class.""" - out: str = super().__repr__() - - out += "\nMetaData\n" + out: str = "\nMetaData\n" if self.meta is not None: out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) else: diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 05356fcc84..1c1066ce8e 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -198,6 +198,7 @@ def test_conv(self, device): conv = torch.nn.Conv3d(im.shape[1], 5, 3) conv.to(device) out = conv(im) + self.assertTrue(str(out).startswith("\nMetaData")) self.check(out, im, shape=False, vals=False, ids=False) @parameterized.expand(TESTS) From c74f0b3aaf8a197278c6b6759b7d193dbe96ac4b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 3 May 2022 17:34:41 +0100 Subject: [PATCH 27/27] fixes unit test tests.test_lr_finder Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 2 +- monai/data/meta_tensor.py | 4 ++-- tests/test_meta_tensor.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 3c173dc7c4..7196ce31f1 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -153,7 +153,7 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call Returns: Returns `None`, but `self` should be updated to have the copied attribute. """ - attributes = [getattr(i, attribute) for i in input_objs] + attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)] if len(attributes) > 0: val = attributes[0] if deep_copy: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 4f011bc3ed..aae012fec0 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -188,8 +188,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: kwargs = {} ret = super().__torch_function__(func, types, args, kwargs) # if `out` has been used as argument, metadata is not copied, nothing to do. - if "out" in kwargs: - return ret + # if "out" in kwargs: + # return ret # we might have 1 or multiple outputs. Might be MetaTensor, might be something # else (e.g., `__repr__` returns a string). # Convert to list (if necessary), process, and at end remove list if one was added. diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 1c1066ce8e..17fbb3cb35 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -273,14 +273,13 @@ def test_amp(self): def test_out(self): """Test when `out` is given as an argument.""" m1, _ = self.get_im() - m1_orig = deepcopy(m1) m2, _ = self.get_im() m3, _ = self.get_im() torch.add(m2, m3, out=m1) m1_add = m2 + m3 assert_allclose(m1, m1_add) - self.check_meta(m1, m1_orig) + # self.check_meta(m1, m2) # meta is from first input tensor @parameterized.expand(TESTS) def test_collate(self, device, dtype):