From a25417cbf6b036cf8e111d8a1afbfe672677b822 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 26 Oct 2022 12:44:24 +0100 Subject: [PATCH 1/3] step 1 based on PR #4922, adding a lazy transform Signed-off-by: Wenqi Li --- docs/source/transforms.rst | 5 +++++ monai/data/meta_obj.py | 14 ++++++++++++++ monai/data/meta_tensor.py | 9 ++++++++- monai/transforms/__init__.py | 10 +++++++++- monai/transforms/spatial/array.py | 19 +++++++++++++------ monai/transforms/transform.py | 30 +++++++++++++++++++++++++++++- monai/utils/__init__.py | 1 + monai/utils/enums.py | 14 ++++++++++++++ 8 files changed, 93 insertions(+), 9 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 874f01a945..2384c9cce1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -63,6 +63,11 @@ Generic Interfaces .. autoclass:: OneOf :members: +`LazyTransform` +^^^^^^^^^^^^^^^ +.. autoclass:: LazyTransform + :members: + Vanilla Transforms ------------------ diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 5061efc1ce..6aab05dc94 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -82,6 +82,7 @@ class MetaObj: def __init__(self): self._meta: dict = MetaObj.get_default_meta() self._applied_operations: list = MetaObj.get_default_applied_operations() + self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops self._is_batch: bool = False @staticmethod @@ -199,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None: def pop_applied_operation(self) -> Any: return self._applied_operations.pop() + @property + def pending_operations(self) -> list[dict]: + """Get the pending operations. Defaults to ``[]``.""" + if hasattr(self, "_pending_operations"): + return self._pending_operations + return MetaObj.get_default_applied_operations() # the same default as applied_ops + + def push_pending_operation(self, t: Any) -> None: + self._pending_operations.append(t) + + def pop_pending_operation(self) -> Any: + return self._pending_operations.pop() + @property def is_batch(self) -> bool: """Return whether object is part of batch or not.""" diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 5a7d81ad8e..74b5b386a8 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -23,7 +23,7 @@ from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option -from monai.utils.enums import MetaKeys, PostFix, SpaceKeys +from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys from monai.utils.type_conversion import convert_data_type, convert_to_tensor __all__ = ["MetaTensor"] @@ -445,6 +445,13 @@ def pixdim(self): return [affine_to_spacing(a) for a in self.affine] return affine_to_spacing(self.affine) + def peek_pending_shape(self): + """Get the currently expected spatial shape as if all the pending operations are executed.""" + return self.pending_operations[-1][LazyAttr.SHAPE] if self.pending_operations else self.array.shape[1:] + + def peek_pending_affine(self): + return self.pending_operations[-1][LazyAttr.AFFINE] if self.pending_operations else self.affine + def new_empty(self, size, dtype=None, device=None, requires_grad=False): """ must be defined for deepcopy to work diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 389571d16f..fe96d60d70 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -449,7 +449,15 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + LazyTransform, + MapTransform, + Randomizable, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddCoordinateChannels, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index dcddefce3a..67104ba7b7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -31,7 +31,7 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import Randomizable, RandomizableTransform, Transform +from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_pad_mode, create_control_grid, @@ -48,6 +48,7 @@ GridSampleMode, GridSamplePadMode, InterpolateMode, + LazyAttr, NdimageMode, NumpyPadMode, SplineMode, @@ -751,7 +752,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return data -class Flip(InvertibleTransform): +class Flip(InvertibleTransform, LazyTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. See `torch.flip` documentation for additional details: @@ -771,14 +772,13 @@ class Flip(InvertibleTransform): def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def update_meta(self, img, shape, axes): + def update_meta(self, affine, shape, axes): # shape and axes include the channel dim - affine = img.affine mat = convert_to_dst_type(torch.eye(len(affine)), affine)[0] for axis in axes: sp = axis - 1 mat[sp, sp], mat[sp, -1] = mat[sp, sp] * -1, shape[axis] - 1 - img.affine = affine @ mat + return affine @ mat def forward_image(self, img, axes) -> torch.Tensor: return torch.flip(img, axes) @@ -790,9 +790,16 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axis) + if self.lazy_evaluation and isinstance(img, MetaTensor): + spatial_chn_shape = [1, *convert_to_numpy(img.peek_pending_shape()).tolist()] + affine = img.peek_pending_affine() + lazy_affine = self.update_meta(affine, spatial_chn_shape, axes) + img.push_pending_operation({LazyAttr.SHAPE: img.peek_pending_shape(), LazyAttr.AFFINE: lazy_affine}) + self.push_transform(img) + return img out = self.forward_image(img, axes) if get_track_meta(): - self.update_meta(out, out.shape, axes) + out.affine = self.update_meta(out.affine, out.shape, axes) # type: ignore self.push_transform(out) return out diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 21d057f5d3..764fc4d25b 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -26,7 +26,15 @@ from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = [ + "ThreadUnsafe", + "apply_transform", + "LazyTransform", + "Randomizable", + "RandomizableTransform", + "Transform", + "MapTransform", +] ReturnType = TypeVar("ReturnType") @@ -131,6 +139,26 @@ class ThreadUnsafe: pass +class LazyTransform: + """ + An interface to denote whether a transform can be applied lazily. It is designed as part of lazy resampling of + multiple transforms. Classes inheriting this interface should be able to operate in two modes: + + - ``set_lazy_eval(False)`` (eagerly evaluating), the transform should output the finalized transform + results without any pending operations. Both primary data and metadata of the outputs should be up-to-date. + - ``set_lazy_eval(True)`` (lazily evaluating), the transform should only execute necessary/lightweight/lossless + metadata updates to track any pending operations. The goal is that, in a later stage, the pending operations + can be grouped together and evaluated more efficiently and accurately -- each transforms when evaluated + independently may cause some information losses. + + """ + + lazy_evaluation: bool = False + + def set_lazy_eval(self, value: bool): + self.lazy_evaluation = value + + class Randomizable(ThreadUnsafe): """ An interface for handling random state locally, currently based on a class diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index c5419cb9af..21d3621090 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -34,6 +34,7 @@ InterpolateMode, InverseKeys, JITMetadataKeys, + LazyAttr, LossReduction, MetaKeys, Method, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 79edbd7451..4fd9bea557 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -54,6 +54,7 @@ "AlgoEnsembleKeys", "HoVerNetMode", "HoVerNetBranch", + "LazyAttr", ] @@ -616,3 +617,16 @@ class HoVerNetBranch(StrEnum): HV = "horizontal_vertical" NP = "nucleus_prediction" NC = "type_prediction" + + +class LazyAttr(StrEnum): + """ + MetaTensor with pending operations requires some key attributes tracked especially when the primary array + is not up-to-date due to lazy evaluation. + This class specifies the set of key attributes to be tracked for each MetaTensor. + """ + + SHAPE = "lazy_shape" # spatial shape + AFFINE = "lazy_affine" + PADDING_MODE = "lazy_padding_mode" + INTERP_MODE = "lazy_interpolation_mode" From 52b097edaa39f1d59405149638f7fb513d7219af Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 29 Oct 2022 08:27:58 +0100 Subject: [PATCH 2/3] update Signed-off-by: Wenqi Li --- docs/source/transforms.rst | 25 +++++++-- monai/transforms/__init__.py | 3 + monai/transforms/spatial/array.py | 19 ++----- monai/transforms/transform.py | 93 ++++++++++++++++++++++++------- 4 files changed, 102 insertions(+), 38 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 2384c9cce1..7b728fde48 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -22,11 +22,31 @@ Generic Interfaces :members: :special-members: __call__ +`RandomizableTrait` +^^^^^^^^^^^^^^^^^^^ +.. autoclass:: RandomizableTrait + :members: + +`LazyTrait` +^^^^^^^^^^^ +.. autoclass:: LazyTrait + :members: + +`MultiSampleTrait` +^^^^^^^^^^^^^^^^^^ +.. autoclass:: MultiSampleTrait + :members: + `Randomizable` ^^^^^^^^^^^^^^ .. autoclass:: Randomizable :members: +`LazyTransform` +^^^^^^^^^^^^^^^ +.. autoclass:: LazyTransform + :members: + `RandomizableTransform` ^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: RandomizableTransform @@ -63,11 +83,6 @@ Generic Interfaces .. autoclass:: OneOf :members: -`LazyTransform` -^^^^^^^^^^^^^^^ -.. autoclass:: LazyTransform - :members: - Vanilla Transforms ------------------ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index fe96d60d70..9cabc167a7 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -450,9 +450,12 @@ ZoomDict, ) from .transform import ( + LazyTrait, LazyTransform, MapTransform, + MultiSampleTrait, Randomizable, + RandomizableTrait, RandomizableTransform, ThreadUnsafe, Transform, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 67104ba7b7..dcddefce3a 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -31,7 +31,7 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform +from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_pad_mode, create_control_grid, @@ -48,7 +48,6 @@ GridSampleMode, GridSamplePadMode, InterpolateMode, - LazyAttr, NdimageMode, NumpyPadMode, SplineMode, @@ -752,7 +751,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return data -class Flip(InvertibleTransform, LazyTransform): +class Flip(InvertibleTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. See `torch.flip` documentation for additional details: @@ -772,13 +771,14 @@ class Flip(InvertibleTransform, LazyTransform): def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def update_meta(self, affine, shape, axes): + def update_meta(self, img, shape, axes): # shape and axes include the channel dim + affine = img.affine mat = convert_to_dst_type(torch.eye(len(affine)), affine)[0] for axis in axes: sp = axis - 1 mat[sp, sp], mat[sp, -1] = mat[sp, sp] * -1, shape[axis] - 1 - return affine @ mat + img.affine = affine @ mat def forward_image(self, img, axes) -> torch.Tensor: return torch.flip(img, axes) @@ -790,16 +790,9 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axis) - if self.lazy_evaluation and isinstance(img, MetaTensor): - spatial_chn_shape = [1, *convert_to_numpy(img.peek_pending_shape()).tolist()] - affine = img.peek_pending_affine() - lazy_affine = self.update_meta(affine, spatial_chn_shape, axes) - img.push_pending_operation({LazyAttr.SHAPE: img.peek_pending_shape(), LazyAttr.AFFINE: lazy_affine}) - self.push_transform(img) - return img out = self.forward_image(img, axes) if get_track_meta(): - out.affine = self.update_meta(out.affine, out.shape, axes) # type: ignore + self.update_meta(out, out.shape, axes) self.push_transform(out) return out diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 764fc4d25b..b1a7d9b4db 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -29,8 +29,11 @@ __all__ = [ "ThreadUnsafe", "apply_transform", - "LazyTransform", + "LazyTrait", + "RandomizableTrait", + "MultiSampleTrait", "Randomizable", + "LazyTransform", "RandomizableTransform", "Transform", "MapTransform", @@ -126,37 +129,67 @@ def _log_stats(data, prefix: Optional[str] = "Data"): raise RuntimeError(f"applying transform {transform}") from e -class ThreadUnsafe: +class LazyTrait: + """ + An interface to indicate that the transform has the capability to execute using + MONAI's lazy resampling feature. In order to do this, the implementing class needs + to be able to describe its operation as an affine matrix or grid with accompanying metadata. + This interface can be extended from by people adapting transforms to the MONAI framework as + well as by implementors of MONAI transforms. """ - A class to denote that the transform will mutate its member variables, - when being applied. Transforms inheriting this class should be used - cautiously in a multi-thread context. - This type is typically used by :py:class:`monai.data.CacheDataset` and - its extensions, where the transform cache is built with multiple threads. + @property + def lazy_evaluation(self): + """ + Get whether lazy_evaluation is enabled for this transform instance. + Returns: + True if the transform is operating in a lazy fashion, False if not. + """ + raise NotImplementedError() + + @lazy_evaluation.setter + def lazy_evaluation(self, enabled: bool): + """ + Set whether lazy_evaluation is enabled for this transform instance. + Args: + enabled: True if the transform should operate in a lazy fashion, False if not. + """ + raise NotImplementedError() + + +class RandomizableTrait: + """ + An interface to indicate that the transform has the capability to perform + randomized transforms to the data that it is called upon. This interface + can be extended from by people adapting transforms to the MONAI framework as well as by + implementors of MONAI transforms. """ pass -class LazyTransform: +class MultiSampleTrait: """ - An interface to denote whether a transform can be applied lazily. It is designed as part of lazy resampling of - multiple transforms. Classes inheriting this interface should be able to operate in two modes: + An interface to indicate that the transform has the capability to return multiple samples + given an input, such as when performing random crops of a sample. This interface can be + extended from by people adapting transforms to the MONAI framework as well as by implementors + of MONAI transforms. + """ + + pass - - ``set_lazy_eval(False)`` (eagerly evaluating), the transform should output the finalized transform - results without any pending operations. Both primary data and metadata of the outputs should be up-to-date. - - ``set_lazy_eval(True)`` (lazily evaluating), the transform should only execute necessary/lightweight/lossless - metadata updates to track any pending operations. The goal is that, in a later stage, the pending operations - can be grouped together and evaluated more efficiently and accurately -- each transforms when evaluated - independently may cause some information losses. +class ThreadUnsafe: """ + A class to denote that the transform will mutate its member variables, + when being applied. Transforms inheriting this class should be used + cautiously in a multi-thread context. - lazy_evaluation: bool = False + This type is typically used by :py:class:`monai.data.CacheDataset` and + its extensions, where the transform cache is built with multiple threads. + """ - def set_lazy_eval(self, value: bool): - self.lazy_evaluation = value + pass class Randomizable(ThreadUnsafe): @@ -279,7 +312,27 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class RandomizableTransform(Randomizable, Transform): +class LazyTransform(Transform, LazyTrait): + """ + An implementation of functionality for lazy transforms that can be subclassed by array and + dictionary transforms to simplify implementation of new lazy transforms. + """ + + def __init__(self, lazy_evaluation: Optional[bool] = True): + self.lazy_evaluation = lazy_evaluation + + @property + def lazy_evaluation(self): + return self.lazy_evaluation + + @lazy_evaluation.setter + def lazy_evaluation(self, lazy_evaluation: bool): + if not isinstance(lazy_evaluation, bool): + raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'") + self.lazy_evaluation = lazy_evaluation + + +class RandomizableTransform(Randomizable, Transform, RandomizableTrait): """ An interface for handling random state locally, currently based on a class variable `R`, which is an instance of `np.random.RandomState`. From a308b6bb737372e47ab20aa2f0d1fc0020550735 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 29 Oct 2022 08:56:38 +0100 Subject: [PATCH 3/3] unit tests Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 13 ++++++++++--- tests/test_meta_tensor.py | 9 +++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 74b5b386a8..493aef848b 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -24,7 +24,7 @@ from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys -from monai.utils.type_conversion import convert_data_type, convert_to_tensor +from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor __all__ = ["MetaTensor"] @@ -447,10 +447,17 @@ def pixdim(self): def peek_pending_shape(self): """Get the currently expected spatial shape as if all the pending operations are executed.""" - return self.pending_operations[-1][LazyAttr.SHAPE] if self.pending_operations else self.array.shape[1:] + res = None + if self.pending_operations: + res = self.pending_operations[-1].get(LazyAttr.SHAPE, None) + # default to spatial shape (assuming channel-first input) + return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res def peek_pending_affine(self): - return self.pending_operations[-1][LazyAttr.AFFINE] if self.pending_operations else self.affine + res = None + if self.pending_operations: + res = self.pending_operations[-1].get(LazyAttr.AFFINE, None) + return self.affine if res is None else res def new_empty(self, size, dtype=None, device=None, requires_grad=False): """ diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index b46905f3c1..20d25ef61c 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -495,6 +495,15 @@ def test_construct_with_pre_applied_transforms(self): m = MetaTensor(im, applied_operations=data["im"].applied_operations) self.assertEqual(len(m.applied_operations), len(tr.transforms)) + def test_pending_ops(self): + m, _ = self.get_im() + self.assertEqual(m.pending_operations, []) + self.assertEqual(m.peek_pending_shape(), (10, 8)) + self.assertIsInstance(m.peek_pending_affine(), torch.Tensor) + m.push_pending_operation({}) + self.assertEqual(m.peek_pending_shape(), (10, 8)) + self.assertIsInstance(m.peek_pending_affine(), torch.Tensor) + @parameterized.expand(TESTS) def test_multiprocessing(self, device=None, dtype=None): """multiprocessing sharing with 'device' and 'dtype'"""