From d639b838b8e2a275edf21f8d53cb1ca2e867ef16 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 15:54:19 +0000 Subject: [PATCH 01/17] revivse utilities Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 12 +++++++++--- monai/data/utils.py | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 67f4109c86..70ce3d49ca 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -113,7 +113,7 @@ def copy_items(data): return data.detach().clone() return deepcopy(data) - def copy_meta_from(self, input_objs, copy_attr=True) -> None: + def copy_meta_from(self, input_objs, copy_attr=True, keys=None): """ Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances. @@ -121,13 +121,19 @@ def copy_meta_from(self, input_objs, copy_attr=True) -> None: input_objs: list of `MetaObj` to copy data from. copy_attr: whether to copy each attribute with `MetaObj.copy_item`. note that if the attribute is a nested list or dict, only a shallow copy will be done. + keys: the keys of attributes to copy from the ``input_objs``. + If None, all keys from the input_objs will be copied. """ first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self) + if not hasattr(first_meta, "__dict__"): + return self first_meta = first_meta.__dict__ + keys = first_meta.keys() if keys is None else keys if not copy_attr: - self.__dict__ = first_meta.copy() # shallow copy for performance + self.__dict__ = {a: first_meta[a] for a in keys if a in first_meta} # shallow copy for performance else: - self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in first_meta}) + self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in keys if a in first_meta}) + return self @staticmethod def get_default_meta() -> dict: diff --git a/monai/data/utils.py b/monai/data/utils.py index 96e3e15d95..ec4de6aa01 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -46,6 +46,7 @@ ensure_tuple_size, fall_back_tuple, first, + get_equivalent_dtype, issequenceiterable, look_up_option, optional_import, @@ -924,6 +925,7 @@ def to_affine_nd(r: np.ndarray | int, affine: NdarrayTensor, dtype=np.float64) - an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type) """ + dtype = get_equivalent_dtype(dtype, np.ndarray) affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] affine_np = affine_np.copy() if affine_np.ndim != 2: From c30dbc8655ccdf8d094f1640a218820ff753d99e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 16:22:28 +0000 Subject: [PATCH 02/17] adding new traceable keys Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 11 +++++++++++ monai/utils/enums.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 6d9060723a..88afb30898 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -72,6 +72,17 @@ def trace_key(key: Hashable = None): return f"{TraceKeys.KEY_SUFFIX}" return f"{key}{TraceKeys.KEY_SUFFIX}" + @staticmethod + def transform_keys(): + """The keys to store necessary info of an applied transform.""" + return ( + TraceKeys.CLASS_NAME, + TraceKeys.ID, + TraceKeys.TRACING, + TraceKeys.LAZY_EVALUATION, + TraceKeys.DO_TRANSFORM, + ) + def get_transform_info( self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None ) -> dict: diff --git a/monai/utils/enums.py 
b/monai/utils/enums.py index d1ac19f4b4..f1c75f71c3 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -310,6 +310,8 @@ class TraceKeys(StrEnum): DO_TRANSFORM: str = "do_transforms" KEY_SUFFIX: str = "_transforms" NONE: str = "none" + TRACING: str = "tracing" + LAZY_EVALUATION: str = "lazy_evaluation" class CommonKeys(StrEnum): From 23ccf839e748324cd548ba2e8ab3446dbe76ff1f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 16:50:09 +0000 Subject: [PATCH 03/17] update apply Signed-off-by: Wenqi Li --- monai/transforms/lazy/functional.py | 44 +++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 13aa753a55..455d3b088d 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -11,6 +11,7 @@ from __future__ import annotations +import numpy as np import torch from monai.data.meta_tensor import MetaTensor @@ -22,37 +23,70 @@ kwargs_from_pending, resample, ) +from monai.utils import LazyAttr __all__ = ["apply_transforms"] -def apply_transforms(data: torch.Tensor | MetaTensor, pending: list | None = None): +def apply_transforms( + data: torch.Tensor | MetaTensor, + pending: list | None = None, + mode: str | None = None, + padding_mode: str | None = None, + dtype=np.float64, +): """ This method applies pending transforms to `data` tensors. Args: data: A torch Tensor or a monai MetaTensor. pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor. + mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers). + Interpolation mode to calculate output values. Defaults to None. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used + and the value represents the order of the spline interpolation. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to None. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `mode` is an integer, using numpy/cupy backends, this argument accepts + {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}. + See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html + dtype: data type for resampling computation. Defaults to ``float64``. + If ``None``, use the data type of input data`. 
""" if isinstance(data, MetaTensor) and pending is None: - pending = data.pending_operations + pending = data.pending_operations.copy() + data.clear_pending_operations() pending = [] if pending is None else pending if not pending: - return data + return data, [] cumulative_xform = affine_from_pending(pending[0]) cur_kwargs = kwargs_from_pending(pending[0]) + override_kwargs = {} + if mode is not None: + override_kwargs[LazyAttr.INTERP_MODE] = mode + if padding_mode is not None: + override_kwargs[LazyAttr.PADDING_MODE] = padding_mode + override_kwargs[LazyAttr.DTYPE] = data.dtype if dtype is None else dtype for p in pending[1:]: new_kwargs = kwargs_from_pending(p) if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs): # carry out an intermediate resample here due to incompatibility between arguments - data = resample(data, cumulative_xform, cur_kwargs) + _cur_kwargs = cur_kwargs.copy() + _cur_kwargs.update(override_kwargs) + sp_size = _cur_kwargs.pop(LazyAttr.SHAPE, None) + data = resample(data, cumulative_xform, sp_size, _cur_kwargs) next_matrix = affine_from_pending(p) cumulative_xform = combine_transforms(cumulative_xform, next_matrix) cur_kwargs.update(new_kwargs) - data = resample(data, cumulative_xform, cur_kwargs) + cur_kwargs.update(override_kwargs) + sp_size = cur_kwargs.pop(LazyAttr.SHAPE, None) + data = resample(data, cumulative_xform, sp_size, cur_kwargs) if isinstance(data, MetaTensor): data.clear_pending_operations() data.affine = data.affine @ to_affine_nd(3, cumulative_xform) From 17f4e53e66a05c9bf93db8f1ac51d8964a0724d2 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 17:53:57 +0000 Subject: [PATCH 04/17] update utilities Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 156 ++++++++++++++++++---------- monai/transforms/lazy/functional.py | 12 ++- monai/transforms/lazy/utils.py | 19 +++- monai/utils/enums.py | 1 + tests/test_apply.py | 4 +- tests/test_meta_tensor.py | 1 + tests/test_resample.py | 4 +- tests/test_traceable_transform.py | 22 ++-- 8 files changed, 140 insertions(+), 79 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 88afb30898..f2f04fe85f 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -20,9 +20,11 @@ import torch from monai import transforms +from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.transforms.transform import Transform -from monai.utils.enums import TraceKeys +from monai.data.utils import to_affine_nd +from monai.transforms.transform import LazyTransform, Transform +from monai.utils import LazyAttr, MetaKeys, TraceKeys, convert_to_dst_type, convert_to_numpy, convert_to_tensor __all__ = ["TraceableTransform", "InvertibleTransform"] @@ -83,76 +85,122 @@ def transform_keys(): TraceKeys.DO_TRANSFORM, ) - def get_transform_info( - self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None - ) -> dict: + def get_transform_info(self) -> dict: """ Return a dictionary with the relevant information pertaining to an applied transform. - - Args: - data: input data. Can be dictionary or MetaTensor. We can use `shape` to - determine the original size of the object (unless that has been given - explicitly, see `orig_size`). - key: if data is a dictionary, data[key] will be modified. - extra_info: if desired, any extra information pertaining to the applied - transform can be stored in this dictionary. These are often needed for - computing the inverse transformation. 
- orig_size: sometimes during the inverse it is useful to know what the size - of the original image was, in which case it can be supplied here. - - Returns: - Dictionary of data pertaining to the applied transformation. """ - info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} - if orig_size is not None: - info[TraceKeys.ORIG_SIZE] = orig_size - elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"): - info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] - elif hasattr(data, "shape"): - info[TraceKeys.ORIG_SIZE] = data.shape[1:] - if extra_info is not None: - info[TraceKeys.EXTRA_INFO] = extra_info - # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) - if hasattr(self, "_do_transform"): # RandomizableTransform - info[TraceKeys.DO_TRANSFORM] = self._do_transform - return info - - def push_transform( - self, data, key: Hashable = None, extra_info: dict | None = None, orig_size: tuple | None = None - ) -> None: + vals = ( + self.__class__.__name__, + id(self), + self.tracing, + self.lazy_evaluation if isinstance(self, LazyTransform) else False, + self._do_transform if hasattr(self, "_do_transform") else True, + ) + return dict(zip(self.transform_keys(), vals)) + + def push_transform(self, data, *args, **kwargs): + """replace bool, whether to rewrite applied_operation (default False)""" + transform_info = self.get_transform_info() + lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False) + do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True) + kwargs = kwargs or {} + replace = kwargs.pop("replace", False) # whether to rewrite the most recently pushed transform info + if replace and get_track_meta() and isinstance(data, MetaTensor): + if not lazy_eval: + xform = self.pop_transform(data, check=False) if do_transform else {} + meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform) + return data.copy_meta_from(meta_obj) + if do_transform: + meta_obj = self.push_transform(data, pending_info=data.pending_operations.pop()) # type: ignore + return data.copy_meta_from(meta_obj) + return data + kwargs["lazy_evaluation"] = lazy_eval + kwargs["transform_info"] = transform_info + meta_obj = TraceableTransform.track_transform_tensor(data, *args, **kwargs) + return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data + + @classmethod + def track_transform_tensor( + cls, + data, + key: Hashable = None, + sp_size=None, + affine=None, + extra_info: dict | None = None, + orig_size: tuple | None = None, + transform_info=None, + pending_info=None, + lazy_evaluation=False, + ): """ Push to a stack of applied transforms. - Args: data: dictionary of data or `MetaTensor`. key: if data is a dictionary, data[key] will be modified. + sp_size: can be tensor or numpy, but will be converted to a list of ints. + affine: extra_info: if desired, any extra information pertaining to the applied transform can be stored in this dictionary. These are often needed for computing the inverse transformation. orig_size: sometimes during the inverse it is useful to know what the size of the original image was, in which case it can be supplied here. - + transform_info: info from self.get_transform_info(). + pending_info: info from self.get_transform_info() and previously pushed to pending_operations + lazy_evaluation: Returns: None, but data has been updated to store the applied transformation. 
""" - if not self.tracing: - return - info = self.get_transform_info(data, key, extra_info, orig_size) - - if isinstance(data, MetaTensor): - data.push_applied_operation(info) - elif isinstance(data, Mapping): - if key in data and isinstance(data[key], MetaTensor): - data[key].push_applied_operation(info) + data_t = data[key] if key is not None else data # compatible with the dict data representation + out_obj = MetaObj() + data_t = convert_to_tensor(data=data_t, track_meta=get_track_meta()) + out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys()) + + # not lazy evaluation, directly update the affine but don't push the stacks + if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor): + orig_affine = data_t.peek_pending_affine() + orig_affine = convert_to_dst_type(orig_affine, affine)[0] + affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype) + out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu")) + if not ( + isinstance(data_t, MetaTensor) + and get_track_meta() + and transform_info + and transform_info.get(TraceKeys.TRACING) + ): + if key is not None: + data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t + return data + return out_obj # return with data_t as tensor if get_track_meta() is False + + info = transform_info + # track the current spatial shape + info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape() if orig_size is None else orig_size + if extra_info is not None: + info[TraceKeys.EXTRA_INFO] = extra_info + if isinstance(pending_info, dict): + for k in TraceableTransform.transform_keys(): + pending_info.pop(k, None) + info.update(pending_info) + + # push the transform info to the applied_operation or pending_operation stack + if lazy_evaluation: + if sp_size is None: + if LazyAttr.SHAPE not in info: + warnings.warn("spatial size is None in push transform.") + else: + info[LazyAttr.SHAPE] = tuple(convert_to_numpy(sp_size, wrap_sequence=True).tolist()) + if affine is None: + if LazyAttr.AFFINE not in info: + warnings.warn("affine is None in push transform.") else: - # If this is the first, create list - if self.trace_key(key) not in data: - if not isinstance(data, dict): - data = dict(data) - data[self.trace_key(key)] = [] - data[self.trace_key(key)].append(info) + info[LazyAttr.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu")) + out_obj.push_pending_operation(info) else: - warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.") + out_obj.push_applied_operation(info) + if key is not None: + data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t + return data + return out_obj def check_transforms_match(self, transform: Mapping) -> None: """Check transforms are of same instance.""" diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 455d3b088d..2ae8be2201 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import Any + import numpy as np import torch @@ -34,6 +36,7 @@ def apply_transforms( mode: str | None = None, padding_mode: str | None = None, dtype=np.float64, + align_corners: bool | None = None, ): """ This method applies pending transforms to `data` tensors. 
@@ -55,6 +58,9 @@ def apply_transforms( See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html dtype: data type for resampling computation. Defaults to ``float64``. If ``None``, use the data type of input data`. + align_corners: Geometrically, we consider the pixels of the input as squares rather than points, when using + the PyTorch resampling backend. Defaults to ``None``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ if isinstance(data, MetaTensor) and pending is None: pending = data.pending_operations.copy() @@ -66,11 +72,13 @@ def apply_transforms( cumulative_xform = affine_from_pending(pending[0]) cur_kwargs = kwargs_from_pending(pending[0]) - override_kwargs = {} + override_kwargs: dict[str, Any] = {} if mode is not None: override_kwargs[LazyAttr.INTERP_MODE] = mode if padding_mode is not None: override_kwargs[LazyAttr.PADDING_MODE] = padding_mode + if align_corners is not None: + override_kwargs[LazyAttr.ALIGN_CORNERS] = align_corners override_kwargs[LazyAttr.DTYPE] = data.dtype if dtype is None else dtype for p in pending[1:]: @@ -89,7 +97,7 @@ def apply_transforms( data = resample(data, cumulative_xform, sp_size, cur_kwargs) if isinstance(data, MetaTensor): data.clear_pending_operations() - data.affine = data.affine @ to_affine_nd(3, cumulative_xform) + data.affine = data.affine @ to_affine_nd(len(data.affine) - 1, cumulative_xform) for p in pending: data.push_applied_operation(p) diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py index e03314d655..1672695ed2 100644 --- a/monai/transforms/lazy/utils.py +++ b/monai/transforms/lazy/utils.py @@ -105,21 +105,30 @@ def is_compatible_apply_kwargs(kwargs_1, kwargs_2): return True -def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None): +def resample(data: torch.Tensor, matrix: NdarrayOrTensor, spatial_size, kwargs: dict | None = None): """ - This is a minimal implementation of resample that always uses Affine. + This is a minimal implementation of resample that always uses SpatialResample. + `kwargs` supports "lazy_dtype", "lazy_padding_mode", "lazy_interpolation_mode", "lazy_dtype", "lazy_align_corners". 
+ + See Also: + :py:class:`monai.transforms.SpatialResample` """ if not Affine.is_affine_shaped(matrix): raise NotImplementedError("calling dense grid resample API not implemented") kwargs = {} if kwargs is None else kwargs init_kwargs = { - "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:], "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype), + "align_corners": kwargs.pop(LazyAttr.ALIGN_CORNERS, None), } + img = convert_to_tensor(data=data, track_meta=monai.data.get_track_meta()) + init_affine = monai.data.to_affine_nd(len(matrix) - 1, img.affine) call_kwargs = { + "spatial_size": img.peek_pending_shape() if spatial_size is None else spatial_size, + "dst_affine": init_affine @ monai.utils.convert_to_dst_type(matrix, init_affine)[0], "mode": kwargs.pop(LazyAttr.INTERP_MODE, None), "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None), } - resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs) + resampler = monai.transforms.SpatialResample(**init_kwargs) + # resampler.lazy_evaluation = False with resampler.trace_transform(False): # don't track this transform in `data` - return resampler(img=data, **call_kwargs) + return resampler(img=img, **call_kwargs) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index f1c75f71c3..7a4aaaece7 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -625,3 +625,4 @@ class LazyAttr(StrEnum): PADDING_MODE = "lazy_padding_mode" INTERP_MODE = "lazy_interpolation_mode" DTYPE = "lazy_dtype" + ALIGN_CORNERS = "lazy_align_corners" diff --git a/tests/test_apply.py b/tests/test_apply.py index 8974360381..cf74721267 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -32,7 +32,7 @@ def single_2d_transform_cases(): (torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)), ( torch.as_tensor(get_arange_img((16, 16))), - [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}], + [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (45, 45)}], (1, 45, 45), ), ] @@ -51,6 +51,8 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape else: for p in pending_transforms: tensor_.push_pending_operation(p) + if not isinstance(p, dict): + return result, transforms = apply_transforms(tensor_) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 936b3526c4..2d8fd3abe6 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -447,6 +447,7 @@ def test_astype(self): self.assertIsInstance(t.astype(pt_types), torch.Tensor) self.assertIsInstance(t.astype("torch.float", device="cpu"), torch.Tensor) + @unittest.skip("non metatensor tests") def test_transforms(self): key = "im" _, im = self.get_im() diff --git a/tests/test_resample.py b/tests/test_resample.py index 3ebdd23e02..8b2ffea194 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -28,13 +28,13 @@ def rotate_90_2d(): return t -RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])] +RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]])] class TestResampleFunction(unittest.TestCase): @parameterized.expand(RESAMPLE_FUNCTION_CASES) def test_resample_function_impl(self, img, matrix, expected): - out = resample(convert_to_tensor(img), matrix) + out = resample(convert_to_tensor(img), matrix, img.shape[1:]) assert_allclose(out[0], expected, type_test=False) diff --git 
a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index cf3da7139a..d7506ef6a1 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -13,16 +13,18 @@ import unittest +import torch + from monai.transforms.inverse import TraceableTransform class _TraceTest(TraceableTransform): def __call__(self, data): - self.push_transform(data) + self.push_transform(data, "image") return data def pop(self, data): - self.pop_transform(data) + self.pop_transform(data, "image") return data @@ -34,21 +36,11 @@ def test_default(self): data = {"image": "test"} data = a(data) # adds to the stack - self.assertTrue(isinstance(data[expected_key], list)) - self.assertEqual(data[expected_key][0]["class"], "_TraceTest") + self.assertEqual(data["image"], "test") + data = {"image": torch.tensor(1.0)} data = a(data) # adds to the stack - self.assertEqual(len(data[expected_key]), 2) - self.assertEqual(data[expected_key][-1]["class"], "_TraceTest") - - with self.assertRaises(IndexError): - a.pop({"test": "test"}) # no stack in the data - data = a.pop(data) - data = a.pop(data) - self.assertEqual(data[expected_key], []) - - with self.assertRaises(IndexError): # no more items - a.pop(data) + self.assertEqual(data["image"].applied_operations[0]["class"], "_TraceTest") if __name__ == "__main__": From 4727d60cc0d240099c6fd0129f6713da02cdbe0f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 19:09:10 +0000 Subject: [PATCH 05/17] update tests Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 30 +++++++++++++++++++----------- tests/test_box_transform.py | 16 +++++++--------- tests/test_random_order.py | 14 ++------------ 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index f2f04fe85f..4fd1fc7917 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -111,7 +111,9 @@ def push_transform(self, data, *args, **kwargs): meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform) return data.copy_meta_from(meta_obj) if do_transform: - meta_obj = self.push_transform(data, pending_info=data.pending_operations.pop()) # type: ignore + xform = data.pending_operations.pop() # type: ignore + xform.update(transform_info) + meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval) return data.copy_meta_from(meta_obj) return data kwargs["lazy_evaluation"] = lazy_eval @@ -129,7 +131,6 @@ def track_transform_tensor( extra_info: dict | None = None, orig_size: tuple | None = None, transform_info=None, - pending_info=None, lazy_evaluation=False, ): """ @@ -137,16 +138,17 @@ def track_transform_tensor( Args: data: dictionary of data or `MetaTensor`. key: if data is a dictionary, data[key] will be modified. - sp_size: can be tensor or numpy, but will be converted to a list of ints. - affine: + sp_size: the expected output spatial size when the transform is applied. + it can be tensor or numpy, but will be converted to a list of integers. + affine: the affine representation of the (spatial) transform in the image space. + When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``. extra_info: if desired, any extra information pertaining to the applied transform can be stored in this dictionary. These are often needed for computing the inverse transformation. 
orig_size: sometimes during the inverse it is useful to know what the size of the original image was, in which case it can be supplied here. transform_info: info from self.get_transform_info(). - pending_info: info from self.get_transform_info() and previously pushed to pending_operations - lazy_evaluation: + lazy_evaluation: whether to push the transform to pending_operations or applied_operations. Returns: None, but data has been updated to store the applied transformation. """ @@ -175,12 +177,9 @@ def track_transform_tensor( info = transform_info # track the current spatial shape info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape() if orig_size is None else orig_size + # include extra_info if extra_info is not None: info[TraceKeys.EXTRA_INFO] = extra_info - if isinstance(pending_info, dict): - for k in TraceableTransform.transform_keys(): - pending_info.pop(k, None) - info.update(pending_info) # push the transform info to the applied_operation or pending_operation stack if lazy_evaluation: @@ -198,7 +197,16 @@ def track_transform_tensor( else: out_obj.push_applied_operation(info) if key is not None: - data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t + if isinstance(data_t, MetaTensor): + data[key] = data_t.copy_meta_from(out_obj) + else: + # If this is the first, create list + x_k = TraceableTransform.trace_key(key) + if x_k not in data: + if not isinstance(data, dict): + data = dict(data) + data[x_k] = [] + data[x_k].append(info) return data return out_obj diff --git a/tests/test_box_transform.py b/tests/test_box_transform.py index 94bd6ade52..ecd54d189c 100644 --- a/tests/test_box_transform.py +++ b/tests/test_box_transform.py @@ -150,7 +150,7 @@ def test_value_3d( transform_convert_mode = ConvertBoxModed(**keys) convert_result = transform_convert_mode(data) assert_allclose( - convert_result["boxes"], expected_convert_result, type_test=True, device_test=True, atol=1e-3 + convert_result["boxes"], expected_convert_result, type_test=False, device_test=False, atol=1e-3 ) invert_transform_convert_mode = Invertd( @@ -159,7 +159,7 @@ def test_value_3d( data_back = invert_transform_convert_mode(convert_result) if "boxes_transforms" in data_back: # if the transform is tracked in dict: self.assertEqual(data_back["boxes_transforms"], []) # it should be updated - assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3) + assert_allclose(data_back["boxes"], data["boxes"], type_test=False, atol=1e-3) # test ZoomBoxd transform_zoom = ZoomBoxd( @@ -167,7 +167,7 @@ def test_value_3d( ) zoom_result = transform_zoom(data) self.assertEqual(len(zoom_result["image"].applied_operations), 1) - assert_allclose(zoom_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=1e-3) + assert_allclose(zoom_result["boxes"], expected_zoom_result, type_test=False, atol=1e-3) invert_transform_zoom = Invertd( keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"] ) @@ -181,9 +181,7 @@ def test_value_3d( ) zoom_result = transform_zoom(data) self.assertEqual(len(zoom_result["image"].applied_operations), 1) - assert_allclose( - zoom_result["boxes"], expected_zoom_keepsize_result, type_test=True, device_test=True, atol=1e-3 - ) + assert_allclose(zoom_result["boxes"], expected_zoom_keepsize_result, type_test=False, atol=1e-3) # test RandZoomBoxd transform_zoom = RandZoomBoxd( @@ -216,7 +214,7 @@ def test_value_3d( affine_result = transform_affine(data) if "boxes_transforms" in affine_result: 
self.assertEqual(len(affine_result["boxes_transforms"]), 1) - assert_allclose(affine_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=0.01) + assert_allclose(affine_result["boxes"], expected_zoom_result, type_test=False, atol=0.01) invert_transform_affine = Invertd(keys=["boxes"], transform=transform_affine, orig_keys=["boxes"]) data_back = invert_transform_affine(affine_result) if "boxes_transforms" in data_back: @@ -233,7 +231,7 @@ def test_value_3d( flip_result = transform_flip(data) if "boxes_transforms" in flip_result: self.assertEqual(len(flip_result["boxes_transforms"]), 1) - assert_allclose(flip_result["boxes"], expected_flip_result, type_test=True, device_test=True, atol=1e-3) + assert_allclose(flip_result["boxes"], expected_flip_result, type_test=False, atol=1e-3) invert_transform_flip = Invertd( keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"] ) @@ -307,7 +305,7 @@ def test_value_3d( ) rotate_result = transform_rotate(data) self.assertEqual(len(rotate_result["image"].applied_operations), 1) - assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=True, device_test=True, atol=1e-3) + assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=False, atol=1e-3) invert_transform_rotate = Invertd( keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"] ) diff --git a/tests/test_random_order.py b/tests/test_random_order.py index a60202dd78..eb3284c2ae 100644 --- a/tests/test_random_order.py +++ b/tests/test_random_order.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.data import MetaTensor -from monai.transforms import RandomOrder, TraceableTransform +from monai.transforms import RandomOrder from monai.transforms.compose import Compose from monai.utils import set_determinism from monai.utils.enums import TraceKeys @@ -77,11 +77,7 @@ def test_inverse(self, transform, invertible, use_metatensor): if invertible: for k in KEYS: - t = ( - fwd_data1[TraceableTransform.trace_key(k)][-1] - if not use_metatensor - else fwd_data1[k].applied_operations[-1] - ) + t = fwd_data1[k].applied_operations[-1] # make sure the RandomOrder applied_order was stored self.assertEqual(t[TraceKeys.CLASS_NAME], RandomOrder.__name__) @@ -94,12 +90,6 @@ def test_inverse(self, transform, invertible, use_metatensor): for i, _fwd_inv_data in enumerate(fwd_inv_data): if invertible: for k in KEYS: - # check transform was removed - if not use_metatensor: - self.assertTrue( - len(_fwd_inv_data[TraceableTransform.trace_key(k)]) - < len(fwd_data[i][TraceableTransform.trace_key(k)]) - ) # check data is same as original (and different from forward) self.assertEqual(_fwd_inv_data[k], data[k]) self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k]) From 696e41134b14b2ba7b1cadfc98ae2bcf37ec4ab1 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 22:02:38 +0000 Subject: [PATCH 06/17] backward compatible Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 29 +++++++++++++++++------------ tests/test_box_transform.py | 16 +++++++++------- tests/test_meta_tensor.py | 1 - tests/test_random_order.py | 14 ++++++++++++-- tests/test_traceable_transform.py | 24 +++++++++++++++++------- 5 files changed, 55 insertions(+), 29 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 4fd1fc7917..18c22c82fa 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -149,34 +149,39 @@ def track_transform_tensor( of the original image was, 
in which case it can be supplied here. transform_info: info from self.get_transform_info(). lazy_evaluation: whether to push the transform to pending_operations or applied_operations. + Returns: None, but data has been updated to store the applied transformation. """ data_t = data[key] if key is not None else data # compatible with the dict data representation out_obj = MetaObj() - data_t = convert_to_tensor(data=data_t, track_meta=get_track_meta()) - out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys()) + # after deprecating metadict, we should always convert data_t to metatensor here + if isinstance(data_t, MetaTensor): + out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys()) + else: + warnings.warn("data_t is not a MetaTensor.") - # not lazy evaluation, directly update the affine but don't push the stacks if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor): + # not lazy evaluation, directly update the metatensor affine (don't push to the stack) orig_affine = data_t.peek_pending_affine() orig_affine = convert_to_dst_type(orig_affine, affine)[0] affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype) out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu")) - if not ( - isinstance(data_t, MetaTensor) - and get_track_meta() - and transform_info - and transform_info.get(TraceKeys.TRACING) - ): - if key is not None: + + if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)): + if isinstance(data, Mapping): data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t return data return out_obj # return with data_t as tensor if get_track_meta() is False info = transform_info # track the current spatial shape - info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape() if orig_size is None else orig_size + if orig_size is not None: + info[TraceKeys.ORIG_SIZE] = orig_size + elif isinstance(data_t, MetaTensor): + info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape() + elif hasattr(data_t, "shape"): + info[TraceKeys.ORIG_SIZE] = data_t.shape[1:] # include extra_info if extra_info is not None: info[TraceKeys.EXTRA_INFO] = extra_info @@ -196,7 +201,7 @@ def track_transform_tensor( out_obj.push_pending_operation(info) else: out_obj.push_applied_operation(info) - if key is not None: + if isinstance(data, Mapping): if isinstance(data_t, MetaTensor): data[key] = data_t.copy_meta_from(out_obj) else: diff --git a/tests/test_box_transform.py b/tests/test_box_transform.py index ecd54d189c..94bd6ade52 100644 --- a/tests/test_box_transform.py +++ b/tests/test_box_transform.py @@ -150,7 +150,7 @@ def test_value_3d( transform_convert_mode = ConvertBoxModed(**keys) convert_result = transform_convert_mode(data) assert_allclose( - convert_result["boxes"], expected_convert_result, type_test=False, device_test=False, atol=1e-3 + convert_result["boxes"], expected_convert_result, type_test=True, device_test=True, atol=1e-3 ) invert_transform_convert_mode = Invertd( @@ -159,7 +159,7 @@ def test_value_3d( data_back = invert_transform_convert_mode(convert_result) if "boxes_transforms" in data_back: # if the transform is tracked in dict: self.assertEqual(data_back["boxes_transforms"], []) # it should be updated - assert_allclose(data_back["boxes"], data["boxes"], type_test=False, atol=1e-3) + assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3) # test ZoomBoxd transform_zoom = ZoomBoxd( @@ -167,7 +167,7 @@ def 
test_value_3d( ) zoom_result = transform_zoom(data) self.assertEqual(len(zoom_result["image"].applied_operations), 1) - assert_allclose(zoom_result["boxes"], expected_zoom_result, type_test=False, atol=1e-3) + assert_allclose(zoom_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=1e-3) invert_transform_zoom = Invertd( keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"] ) @@ -181,7 +181,9 @@ def test_value_3d( ) zoom_result = transform_zoom(data) self.assertEqual(len(zoom_result["image"].applied_operations), 1) - assert_allclose(zoom_result["boxes"], expected_zoom_keepsize_result, type_test=False, atol=1e-3) + assert_allclose( + zoom_result["boxes"], expected_zoom_keepsize_result, type_test=True, device_test=True, atol=1e-3 + ) # test RandZoomBoxd transform_zoom = RandZoomBoxd( @@ -214,7 +216,7 @@ def test_value_3d( affine_result = transform_affine(data) if "boxes_transforms" in affine_result: self.assertEqual(len(affine_result["boxes_transforms"]), 1) - assert_allclose(affine_result["boxes"], expected_zoom_result, type_test=False, atol=0.01) + assert_allclose(affine_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=0.01) invert_transform_affine = Invertd(keys=["boxes"], transform=transform_affine, orig_keys=["boxes"]) data_back = invert_transform_affine(affine_result) if "boxes_transforms" in data_back: @@ -231,7 +233,7 @@ def test_value_3d( flip_result = transform_flip(data) if "boxes_transforms" in flip_result: self.assertEqual(len(flip_result["boxes_transforms"]), 1) - assert_allclose(flip_result["boxes"], expected_flip_result, type_test=False, atol=1e-3) + assert_allclose(flip_result["boxes"], expected_flip_result, type_test=True, device_test=True, atol=1e-3) invert_transform_flip = Invertd( keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"] ) @@ -305,7 +307,7 @@ def test_value_3d( ) rotate_result = transform_rotate(data) self.assertEqual(len(rotate_result["image"].applied_operations), 1) - assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=False, atol=1e-3) + assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=True, device_test=True, atol=1e-3) invert_transform_rotate = Invertd( keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"] ) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 2d8fd3abe6..936b3526c4 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -447,7 +447,6 @@ def test_astype(self): self.assertIsInstance(t.astype(pt_types), torch.Tensor) self.assertIsInstance(t.astype("torch.float", device="cpu"), torch.Tensor) - @unittest.skip("non metatensor tests") def test_transforms(self): key = "im" _, im = self.get_im() diff --git a/tests/test_random_order.py b/tests/test_random_order.py index eb3284c2ae..a60202dd78 100644 --- a/tests/test_random_order.py +++ b/tests/test_random_order.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.data import MetaTensor -from monai.transforms import RandomOrder +from monai.transforms import RandomOrder, TraceableTransform from monai.transforms.compose import Compose from monai.utils import set_determinism from monai.utils.enums import TraceKeys @@ -77,7 +77,11 @@ def test_inverse(self, transform, invertible, use_metatensor): if invertible: for k in KEYS: - t = fwd_data1[k].applied_operations[-1] + t = ( + fwd_data1[TraceableTransform.trace_key(k)][-1] + if not use_metatensor + else 
fwd_data1[k].applied_operations[-1] + ) # make sure the RandomOrder applied_order was stored self.assertEqual(t[TraceKeys.CLASS_NAME], RandomOrder.__name__) @@ -90,6 +94,12 @@ def test_inverse(self, transform, invertible, use_metatensor): for i, _fwd_inv_data in enumerate(fwd_inv_data): if invertible: for k in KEYS: + # check transform was removed + if not use_metatensor: + self.assertTrue( + len(_fwd_inv_data[TraceableTransform.trace_key(k)]) + < len(fwd_data[i][TraceableTransform.trace_key(k)]) + ) # check data is same as original (and different from forward) self.assertEqual(_fwd_inv_data[k], data[k]) self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k]) diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index d7506ef6a1..b2e613f388 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -13,18 +13,16 @@ import unittest -import torch - from monai.transforms.inverse import TraceableTransform class _TraceTest(TraceableTransform): def __call__(self, data): - self.push_transform(data, "image") + self.push_transform(data) return data def pop(self, data): - self.pop_transform(data, "image") + self.pop_transform(data) return data @@ -32,15 +30,27 @@ class TestTraceable(unittest.TestCase): def test_default(self): expected_key = "_transforms" a = _TraceTest() + for x in a.transform_keys(): + self.assertTrue(x in a.get_transform_info()) self.assertEqual(a.trace_key(), expected_key) data = {"image": "test"} data = a(data) # adds to the stack - self.assertEqual(data["image"], "test") + self.assertTrue(isinstance(data[expected_key], list)) + self.assertEqual(data[expected_key][0]["class"], "_TraceTest") - data = {"image": torch.tensor(1.0)} data = a(data) # adds to the stack - self.assertEqual(data["image"].applied_operations[0]["class"], "_TraceTest") + self.assertEqual(len(data[expected_key]), 2) + self.assertEqual(data[expected_key][-1]["class"], "_TraceTest") + + with self.assertRaises(IndexError): + a.pop({"test": "test"}) # no stack in the data + data = a.pop(data) + data = a.pop(data) + self.assertEqual(data[expected_key], []) + + with self.assertRaises(IndexError): # no more items + a.pop(data) if __name__ == "__main__": From 47684d736e365750b72ffc3cd0b0915242e97146 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 22:18:33 +0000 Subject: [PATCH 07/17] fixes #5509 Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 7 ++++--- monai/transforms/inverse.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 22f9502708..ccf53a07e5 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -523,10 +523,10 @@ def ensure_torch_and_prune_meta( By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned. """ - img = convert_to_tensor(im) # potentially ascontiguousarray + img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` - if not get_track_meta() or meta is None: + if not isinstance(img, MetaTensor): return img # remove any superfluous metadata. 
@@ -540,7 +540,8 @@ def ensure_torch_and_prune_meta( meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta) # return the `MetaTensor` - return MetaTensor(img, meta=meta) + img.meta = meta + return img def __repr__(self): """ diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 18c22c82fa..c741786e0b 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -170,6 +170,8 @@ def track_transform_tensor( if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)): if isinstance(data, Mapping): + if not isinstance(data, dict): + data = dict(data) data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t return data return out_obj # return with data_t as tensor if get_track_meta() is False @@ -202,15 +204,14 @@ def track_transform_tensor( else: out_obj.push_applied_operation(info) if isinstance(data, Mapping): + if not isinstance(data, dict): + data = dict(data) if isinstance(data_t, MetaTensor): data[key] = data_t.copy_meta_from(out_obj) else: - # If this is the first, create list x_k = TraceableTransform.trace_key(key) if x_k not in data: - if not isinstance(data, dict): - data = dict(data) - data[x_k] = [] + data[x_k] = [] # If this is the first, create list data[x_k].append(info) return data return out_obj From c508d5a438a8d73ff0ca57f3b50f28246f001e5e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 22:28:41 +0000 Subject: [PATCH 08/17] update types Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 70ce3d49ca..111428906b 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -107,6 +107,8 @@ def flatten_meta_objs(*args: Iterable): @staticmethod def copy_items(data): """returns a copy of the data. list and dict are shallow copied for efficiency purposes.""" + if isinstance(data, (bool, int, float, str, type(None))): + return data if isinstance(data, (list, dict, np.ndarray)): return data.copy() if isinstance(data, torch.Tensor): From d562e0162217f6eca18922ed7bccfcb9f5ca99d6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 22:42:45 +0000 Subject: [PATCH 09/17] fixes docstrings Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index c741786e0b..fba889737b 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -99,7 +99,16 @@ def get_transform_info(self) -> dict: return dict(zip(self.transform_keys(), vals)) def push_transform(self, data, *args, **kwargs): - """replace bool, whether to rewrite applied_operation (default False)""" + """ + Push to a stack of applied transforms of ``data``. + + Args: + data: dictionary of data or `MetaTensor`. + args: additional positional arguments to track_transform_meta. + kwargs: additional keyword arguments to track_transform_meta, + set ``replace=True`` (default False) to rewrite the last transform infor in + applied_operation/pending_operation based on ``self.get_transform_info()``. 
+ """ transform_info = self.get_transform_info() lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False) do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True) @@ -118,11 +127,11 @@ def push_transform(self, data, *args, **kwargs): return data kwargs["lazy_evaluation"] = lazy_eval kwargs["transform_info"] = transform_info - meta_obj = TraceableTransform.track_transform_tensor(data, *args, **kwargs) + meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs) return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data @classmethod - def track_transform_tensor( + def track_transform_meta( cls, data, key: Hashable = None, @@ -134,7 +143,8 @@ def track_transform_tensor( lazy_evaluation=False, ): """ - Push to a stack of applied transforms. + Update a stack of applied/pending transforms metadata of ``data``. + Args: data: dictionary of data or `MetaTensor`. key: if data is a dictionary, data[key] will be modified. @@ -151,7 +161,9 @@ def track_transform_tensor( lazy_evaluation: whether to push the transform to pending_operations or applied_operations. Returns: - None, but data has been updated to store the applied transformation. + + For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with + updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata. """ data_t = data[key] if key is not None else data # compatible with the dict data representation out_obj = MetaObj() From 5824f7824d8eb41a06c9d0833602a629d6690c59 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 31 Jan 2023 00:06:09 +0000 Subject: [PATCH 10/17] fixes merging issues Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 2 ++ monai/transforms/inverse.py | 7 ++++--- monai/transforms/lazy/functional.py | 3 --- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index ccf53a07e5..e094642f16 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -541,6 +541,8 @@ def ensure_torch_and_prune_meta( # return the `MetaTensor` img.meta = meta + if MetaKeys.AFFINE in meta: + img.affine = meta[MetaKeys.AFFINE] # this uses the affine property setter return img def __repr__(self): diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index fba889737b..49560eaf6c 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -126,7 +126,10 @@ def push_transform(self, data, *args, **kwargs): return data.copy_meta_from(meta_obj) return data kwargs["lazy_evaluation"] = lazy_eval - kwargs["transform_info"] = transform_info + if "transform_info" in kwargs and isinstance(kwargs["transform_info"], dict): + kwargs["transform_info"].update(transform_info) + else: + kwargs["transform_info"] = transform_info meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs) return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data @@ -170,8 +173,6 @@ def track_transform_meta( # after deprecating metadict, we should always convert data_t to metatensor here if isinstance(data_t, MetaTensor): out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys()) - else: - warnings.warn("data_t is not a MetaTensor.") if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor): # not lazy evaluation, directly update the metatensor affine (don't push to the stack) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 2ae8be2201..773adf270f 100644 --- 
a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -17,7 +17,6 @@ import torch from monai.data.meta_tensor import MetaTensor -from monai.data.utils import to_affine_nd from monai.transforms.lazy.utils import ( affine_from_pending, combine_transforms, @@ -96,8 +95,6 @@ def apply_transforms( sp_size = cur_kwargs.pop(LazyAttr.SHAPE, None) data = resample(data, cumulative_xform, sp_size, cur_kwargs) if isinstance(data, MetaTensor): - data.clear_pending_operations() - data.affine = data.affine @ to_affine_nd(len(data.affine) - 1, cumulative_xform) for p in pending: data.push_applied_operation(p) From 49eaa5fd1e3b4391c75e27fac34fd76802cec14e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 31 Jan 2023 00:51:13 +0000 Subject: [PATCH 11/17] default affine Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index e094642f16..560aaf776c 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -543,6 +543,8 @@ def ensure_torch_and_prune_meta( img.meta = meta if MetaKeys.AFFINE in meta: img.affine = meta[MetaKeys.AFFINE] # this uses the affine property setter + else: + img.affine = MetaTensor.get_default_affine() return img def __repr__(self): From ab7c44c81361e15e42b89b926370d0c943cb1745 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 31 Jan 2023 10:27:12 +0000 Subject: [PATCH 12/17] update based on comments Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 2 +- monai/transforms/inverse.py | 4 ++-- tests/test_traceable_transform.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 111428906b..6c90f41a26 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -107,7 +107,7 @@ def flatten_meta_objs(*args: Iterable): @staticmethod def copy_items(data): """returns a copy of the data. 
list and dict are shallow copied for efficiency purposes.""" - if isinstance(data, (bool, int, float, str, type(None))): + if isinstance(data, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice)): return data if isinstance(data, (list, dict, np.ndarray)): return data.copy() diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 49560eaf6c..80a27b98b5 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -75,7 +75,7 @@ def trace_key(key: Hashable = None): return f"{key}{TraceKeys.KEY_SUFFIX}" @staticmethod - def transform_keys(): + def transform_info_keys(): """The keys to store necessary info of an applied transform.""" return ( TraceKeys.CLASS_NAME, @@ -96,7 +96,7 @@ def get_transform_info(self) -> dict: self.lazy_evaluation if isinstance(self, LazyTransform) else False, self._do_transform if hasattr(self, "_do_transform") else True, ) - return dict(zip(self.transform_keys(), vals)) + return dict(zip(self.transform_info_keys(), vals)) def push_transform(self, data, *args, **kwargs): """ diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index b2e613f388..42906c84d2 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -30,7 +30,7 @@ class TestTraceable(unittest.TestCase): def test_default(self): expected_key = "_transforms" a = _TraceTest() - for x in a.transform_keys(): + for x in a.transform_info_keys(): self.assertTrue(x in a.get_transform_info()) self.assertEqual(a.trace_key(), expected_key) From 9eec6b0734214e8713e6d3ffd27015d7d351f07e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 31 Jan 2023 10:30:11 +0000 Subject: [PATCH 13/17] update based on comments Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 560aaf776c..46463431c6 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -503,15 +503,15 @@ def clone(self): @staticmethod def ensure_torch_and_prune_meta( - im: NdarrayTensor, meta: dict, simple_keys: bool = False, pattern: str | None = None, sep: str = "." + im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "." ): """ - Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary, + Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary, convert that to `torch.Tensor`, too. Remove any superfluous metadata. Args: im: Input image (`np.ndarray` or `torch.Tensor`) - meta: Metadata dictionary. + meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor. simple_keys: whether to keep only a simple subset of metadata keys. pattern: combined with `sep`, a regular expression used to match and prune keys in the metadata (nested dictionary), default to None, no key deletion. @@ -521,7 +521,7 @@ def ensure_torch_and_prune_meta( Returns: By default, a `MetaTensor` is returned. - However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned. + However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned. 
""" img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray From 0ad927961872f259789745a442733fd71c116f2d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 31 Jan 2023 11:54:54 +0000 Subject: [PATCH 14/17] fixes typing Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 46463431c6..d77fd782c2 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -529,6 +529,9 @@ def ensure_torch_and_prune_meta( if not isinstance(img, MetaTensor): return img + if meta is None: + meta = {} + # remove any superfluous metadata. if simple_keys: # ensure affine is of type `torch.Tensor` @@ -540,6 +543,8 @@ def ensure_torch_and_prune_meta( meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta) # return the `MetaTensor` + if meta is None: + meta = {} img.meta = meta if MetaKeys.AFFINE in meta: img.affine = meta[MetaKeys.AFFINE] # this uses the affine property setter From 1e039ad46b6c45aecd8cfb563e8983152ce7b530 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 31 Jan 2023 15:18:47 +0000 Subject: [PATCH 15/17] update based on comments Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 5 ++--- monai/utils/__init__.py | 1 + monai/utils/misc.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 6c90f41a26..86ce7e33fb 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -19,8 +19,7 @@ import numpy as np import torch -from monai.utils.enums import TraceKeys -from monai.utils.misc import first +from monai.utils import TraceKeys, first, is_immutable _TRACK_META = True @@ -107,7 +106,7 @@ def flatten_meta_objs(*args: Iterable): @staticmethod def copy_items(data): """returns a copy of the data. list and dict are shallow copied for efficiency purposes.""" - if isinstance(data, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice)): + if is_immutable(data): return data if isinstance(data, (list, dict, np.ndarray)): return data.copy() diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 49daefcdda..92344d644c 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -73,6 +73,7 @@ is_module_ver_at_least, is_scalar, is_scalar_tensor, + is_immutable, issequenceiterable, list_to_dict, path_to_uri, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 554cc1b278..1674732637 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -36,6 +36,7 @@ "star_zip_with", "first", "issequenceiterable", + "is_immutable", "ensure_tuple", "ensure_tuple_size", "ensure_tuple_rep", @@ -116,6 +117,15 @@ def issequenceiterable(obj: Any) -> bool: return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) +def is_immutable(obj: Any) -> bool: + """ + Determine if the object is an immutable object. + + see also https://github.com/python/cpython/blob/740050af0493030b1f6ebf0b9ac39a356e2e74b6/Lib/copy.py#L109 + """ + return isinstance(obj, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice)) + + def ensure_tuple(vals: Any, wrap_array: bool = False) -> tuple[Any, ...]: """ Returns a tuple of `vals`. 
From 710d1262c24c8650c7802b640f085f3b03c0ad43 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 1 Feb 2023 06:18:33 +0000 Subject: [PATCH 16/17] update based on comments Signed-off-by: Wenqi Li --- monai/transforms/lazy/functional.py | 2 +- monai/utils/misc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 773adf270f..44e46d4bdb 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -32,7 +32,7 @@ def apply_transforms( data: torch.Tensor | MetaTensor, pending: list | None = None, - mode: str | None = None, + mode: str | int | None = None, padding_mode: str | None = None, dtype=np.float64, align_corners: bool | None = None, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 1674732637..f2086f22d5 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -121,7 +121,7 @@ def is_immutable(obj: Any) -> bool: """ Determine if the object is an immutable object. - see also https://github.com/python/cpython/blob/740050af0493030b1f6ebf0b9ac39a356e2e74b6/Lib/copy.py#L109 + see also https://github.com/python/cpython/blob/3.11/Lib/copy.py#L109 """ return isinstance(obj, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice)) From b0def3ef34c0e0158a5ed4d984c7b89ec3dc09bd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 1 Feb 2023 06:20:08 +0000 Subject: [PATCH 17/17] fixes memory layout Signed-off-by: Wenqi Li --- monai/data/image_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index c1cfcfd8ca..14583482ca 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1031,7 +1031,7 @@ def _get_array_data(self, img): img: a Nibabel image object loaded from an image file. """ - return np.asanyarray(img.dataobj) + return np.asanyarray(img.dataobj, order="C") class NumpyReader(ImageReader):
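To make the intent of the series concrete, a few hedged sketches follow. First, the copy semantics that `MetaObj.copy_items` converges on across patches 01, 08, 12 and 15: immutable values are returned as-is, containers and numpy arrays are shallow-copied, torch tensors are detached and cloned, and everything else falls back to `deepcopy`. This is a minimal, self-contained sketch that mirrors the hunks above rather than importing MONAI, so treat it as an illustration of the behaviour, not the released API.

from copy import deepcopy
from typing import Any

import numpy as np
import torch


def is_immutable(obj: Any) -> bool:
    # same check as monai.utils.misc.is_immutable introduced in PATCH 15/17
    return isinstance(obj, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice))


def copy_items(data):
    """Shallow-copy containers, clone tensors, deep-copy anything else."""
    if is_immutable(data):
        return data  # nothing to copy for values that cannot change
    if isinstance(data, (list, dict, np.ndarray)):
        return data.copy()  # shallow copy for efficiency
    if isinstance(data, torch.Tensor):
        return data.detach().clone()
    return deepcopy(data)


meta = {"filename": "img.nii.gz", "spacing": [1.0, 1.0, 1.0], "affine": torch.eye(4)}
copied = {k: copy_items(v) for k, v in meta.items()}
assert copied["filename"] is meta["filename"]                    # str: immutable, shared
assert copied["spacing"] is not meta["spacing"]                  # list: new container
assert copied["affine"].data_ptr() != meta["affine"].data_ptr()  # tensor: cloned storage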
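Second, patches 07, 10, 11, 13 and 14 rework `MetaTensor.ensure_torch_and_prune_meta` so that a missing metadata dictionary yields a plain `torch.Tensor` instead of a `MetaTensor` with empty metadata, and so that a supplied "affine" entry goes through the affine property setter (falling back to the default affine otherwise). A short usage sketch of that behaviour, assuming the method as shown in the hunks above:

import numpy as np

from monai.data import MetaTensor

arr = np.zeros((1, 8, 8), dtype=np.float32)

# meta=None: metadata is not tracked, so a plain torch.Tensor comes back
plain = MetaTensor.ensure_torch_and_prune_meta(arr, meta=None)
print(type(plain))  # expected <class 'torch.Tensor'>

# with a metadata dict, a MetaTensor is returned and "affine" is applied via the setter
tracked = MetaTensor.ensure_torch_and_prune_meta(arr, meta={"affine": np.eye(4)})
print(type(tracked), tracked.affine.shape)  # expected MetaTensor, torch.Size([4, 4])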
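Finally, the core of the series is the lazy-resampling path: pending operations queued on a `MetaTensor` carry their own affine and output shape (`LazyAttr.AFFINE`, `LazyAttr.SHAPE`), and `apply_transforms` flushes them through `SpatialResample`, optionally overriding the interpolation mode, padding mode, dtype and align_corners. The sketch below follows the cases in tests/test_apply.py at this revision; the function name, the override keywords and the returned `(tensor, applied_pending)` pair are specific to this stage of the PR and may not match later releases.

import numpy as np
import torch

from monai.data import MetaTensor
from monai.transforms.lazy.functional import apply_transforms
from monai.transforms.utils import create_rotate
from monai.utils import LazyAttr

img = MetaTensor(torch.arange(16 * 16, dtype=torch.float32).reshape(1, 16, 16))
# queue a lazy 90-degree rotation with an explicit output shape; nothing is resampled yet
img.push_pending_operation({LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (45, 45)})

# flush the pending stack in a single resampling call, overriding a few resampling options
out, applied = apply_transforms(img, mode="bilinear", padding_mode="zeros", dtype=np.float32)
print(out.shape)                    # expected (1, 45, 45), as in tests/test_apply.py
print(len(out.applied_operations))  # the pending op should now be recorded as an applied op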