diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 874f01a945..7b728fde48 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -22,11 +22,31 @@ Generic Interfaces :members: :special-members: __call__ +`RandomizableTrait` +^^^^^^^^^^^^^^^^^^^ +.. autoclass:: RandomizableTrait + :members: + +`LazyTrait` +^^^^^^^^^^^ +.. autoclass:: LazyTrait + :members: + +`MultiSampleTrait` +^^^^^^^^^^^^^^^^^^ +.. autoclass:: MultiSampleTrait + :members: + `Randomizable` ^^^^^^^^^^^^^^ .. autoclass:: Randomizable :members: +`LazyTransform` +^^^^^^^^^^^^^^^ +.. autoclass:: LazyTransform + :members: + `RandomizableTransform` ^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: RandomizableTransform diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 389571d16f..9cabc167a7 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -449,7 +449,18 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + LazyTrait, + LazyTransform, + MapTransform, + MultiSampleTrait, + Randomizable, + RandomizableTrait, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddCoordinateChannels, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 21d057f5d3..b1a7d9b4db 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -26,7 +26,18 @@ from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = [ + "ThreadUnsafe", + "apply_transform", + "LazyTrait", + "RandomizableTrait", + "MultiSampleTrait", + "Randomizable", + "LazyTransform", + "RandomizableTransform", + "Transform", + "MapTransform", +] ReturnType = TypeVar("ReturnType") @@ -118,6 +129,56 @@ def _log_stats(data, prefix: Optional[str] = "Data"): raise RuntimeError(f"applying transform {transform}") from e +class LazyTrait: + """ + An interface to indicate that the transform has the capability to execute using + MONAI's lazy resampling feature. In order to do this, the implementing class needs + to be able to describe its operation as an affine matrix or grid with accompanying metadata. + This interface can be extended from by people adapting transforms to the MONAI framework as + well as by implementors of MONAI transforms. + """ + + @property + def lazy_evaluation(self): + """ + Get whether lazy_evaluation is enabled for this transform instance. + Returns: + True if the transform is operating in a lazy fashion, False if not. + """ + raise NotImplementedError() + + @lazy_evaluation.setter + def lazy_evaluation(self, enabled: bool): + """ + Set whether lazy_evaluation is enabled for this transform instance. + Args: + enabled: True if the transform should operate in a lazy fashion, False if not. + """ + raise NotImplementedError() + + +class RandomizableTrait: + """ + An interface to indicate that the transform has the capability to perform + randomized transforms to the data that it is called upon. This interface + can be extended from by people adapting transforms to the MONAI framework as well as by + implementors of MONAI transforms. + """ + + pass + + +class MultiSampleTrait: + """ + An interface to indicate that the transform has the capability to return multiple samples + given an input, such as when performing random crops of a sample. This interface can be + extended from by people adapting transforms to the MONAI framework as well as by implementors + of MONAI transforms. + """ + + pass + + class ThreadUnsafe: """ A class to denote that the transform will mutate its member variables, @@ -251,7 +312,27 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class RandomizableTransform(Randomizable, Transform): +class LazyTransform(Transform, LazyTrait): + """ + An implementation of functionality for lazy transforms that can be subclassed by array and + dictionary transforms to simplify implementation of new lazy transforms. + """ + + def __init__(self, lazy_evaluation: Optional[bool] = True): + self.lazy_evaluation = lazy_evaluation + + @property + def lazy_evaluation(self): + return self.lazy_evaluation + + @lazy_evaluation.setter + def lazy_evaluation(self, lazy_evaluation: bool): + if not isinstance(lazy_evaluation, bool): + raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'") + self.lazy_evaluation = lazy_evaluation + + +class RandomizableTransform(Randomizable, Transform, RandomizableTrait): """ An interface for handling random state locally, currently based on a class variable `R`, which is an instance of `np.random.RandomState`. diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 7f4ddce1d0..7ab6ef260d 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -90,7 +90,7 @@ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_gra x.requires_grad = True self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs) - grad: torch.Tensor = x.grad.detach() + grad: torch.Tensor = x.grad.detach() # type: ignore return grad def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py new file mode 100644 index 0000000000..9f77d2cd5a --- /dev/null +++ b/tests/test_randomizable_transform_type.py @@ -0,0 +1,33 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.transform import RandomizableTrait, RandomizableTransform + + +class InheritsInterface(RandomizableTrait): + pass + + +class InheritsImplementation(RandomizableTransform): + def __call__(self, data): + return data + + +class TestRandomizableTransformType(unittest.TestCase): + def test_is_randomizable_transform_type(self): + inst = InheritsInterface() + self.assertIsInstance(inst, RandomizableTrait) + + def test_set_random_state_randomizable_transform(self): + inst = InheritsImplementation() + inst.set_random_state(0)