diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 962e1f3769..01b1cb00bb 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -604,6 +604,12 @@ Utility :members: :special-members: __call__ +`RandLambda` +"""""""""""" +.. autoclass:: RandLambda + :members: + :special-members: __call__ + `LabelToMask` """"""""""""" .. autoclass:: LabelToMask diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 45eecd266c..487a995e5e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -323,6 +323,7 @@ LabelToMask, Lambda, MapLabelValue, + RandLambda, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 7f06f119c2..4e0141652f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -23,7 +23,7 @@ import torch from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_to_numpy, convert_to_tensor, @@ -58,6 +58,7 @@ "DataStats", "SimulateDelay", "Lambda", + "RandLambda", "LabelToMask", "FgBgToIndices", "ClassesToIndices", @@ -617,6 +618,28 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable raise ValueError("Incompatible values: func=None and self.func=None.") +class RandLambda(Lambda, RandomizableTransform): + """ + Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` may contain random logic, + or randomly execute the function based on `prob`. + + Args: + func: Lambda/function to be applied. + prob: probability of executing the random function, default to 1.0, with 100% probability to execute. + + For more details, please check :py:class:`monai.transforms.Lambda`. + + """ + + def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None: + Lambda.__init__(self=self, func=func) + RandomizableTransform.__init__(self=self, prob=prob) + + def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + self.randomize(img) + return super().__call__(img=img, func=func) if self._do_transform else img + + class LabelToMask(Transform): """ Convert labels to mask for other tasks. A typical usage is to convert segmentation labels diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 6fa672e6c4..75be9685c4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -24,8 +24,9 @@ import torch from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, @@ -833,7 +834,7 @@ def __call__(self, data): return d -class Lambdad(MapTransform): +class Lambdad(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Lambda`. @@ -852,51 +853,110 @@ class Lambdad(MapTransform): See also: :py:class:`monai.transforms.compose.MapTransform` func: Lambda/function to be applied. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`. + It also can be a sequence of Callable, each element corresponds to a key in ``keys``. overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + + Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the + image's original size. If need these complicated information, please write a new InvertibleTransform directly. + """ def __init__( self, keys: KeysCollection, func: Union[Sequence[Callable], Callable], + inv_func: Union[Sequence[Callable], Callable] = no_collation, overwrite: Union[Sequence[bool], bool] = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.func = ensure_tuple_rep(func, len(self.keys)) + self.inv_func = ensure_tuple_rep(inv_func, len(self.keys)) self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) self._lambd = Lambda() + def _transform(self, data: Any, func: Callable): + return self._lambd(data, func=func) + def __call__(self, data): d = dict(data) for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite): - ret = self._lambd(d[key], func=func) + ret = self._transform(data=d[key], func=func) + if overwrite: + d[key] = ret + self.push_transform(d, key) + return d + + def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable): + return self._lambd(data, func=func) + + def inverse(self, data): + d = deepcopy(dict(data)) + for key, inv_func, overwrite in self.key_iterator(d, self.inv_func, self.overwrite): + transform = self.get_most_recent_transform(d, key) + ret = self._inverse_transform(transform_info=transform, data=d[key], func=inv_func) if overwrite: d[key] = ret + self.pop_transform(d, key) return d -class RandLambdad(Lambdad, Randomizable): +class RandLambdad(Lambdad, RandomizableTransform): """ - Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. - It's a randomizable transform so `CacheDataset` will not execute it and cache the results. + Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` may contain random logic, + or randomly execute the function based on `prob`. so `CacheDataset` will not execute it and cache the results. Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` func: Lambda/function to be applied. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`. + It also can be a sequence of Callable, each element corresponds to a key in ``keys``. overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. + prob: probability of executing the random function, default to 1.0, with 100% probability to execute. + note that all the data specified by `keys` will share the same random probability to execute or not. + allow_missing_keys: don't raise exception if key is missing. For more details, please check :py:class:`monai.transforms.Lambdad`. + Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the + image's original size. If need these complicated information, please write a new InvertibleTransform directly. + """ - def randomize(self, data: Any) -> None: - pass + def __init__( + self, + keys: KeysCollection, + func: Union[Sequence[Callable], Callable], + inv_func: Union[Sequence[Callable], Callable] = no_collation, + overwrite: Union[Sequence[bool], bool] = True, + prob: float = 1.0, + allow_missing_keys: bool = False, + ) -> None: + Lambdad.__init__( + self=self, + keys=keys, + func=func, + inv_func=inv_func, + overwrite=overwrite, + allow_missing_keys=allow_missing_keys, + ) + RandomizableTransform.__init__(self=self, prob=prob, do_transform=True) + + def _transform(self, data: Any, func: Callable): + return self._lambd(data, func=func) if self._do_transform else data + + def __call__(self, data): + self.randomize(data) + return super().__call__(data) + + def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable): + return self._lambd(data, func=func) if transform_info[InverseKeys.DO_TRANSFORM] else data class LabelToMaskd(MapTransform): diff --git a/tests/test_inverse.py b/tests/test_inverse.py index a1c171200f..f2470d47fd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -35,6 +35,7 @@ DivisiblePadd, Flipd, InvertibleTransform, + Lambdad, LoadImaged, Orientationd, RandAffined, @@ -42,6 +43,7 @@ RandCropByLabelClassesd, RandCropByPosNegLabeld, RandFlipd, + RandLambdad, Randomizable, RandRotate90d, RandRotated, @@ -314,6 +316,16 @@ TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True))) +TESTS.append(("Lambdad 2d", "2D", 5e-2, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True))) + +TESTS.append( + ( + "RandLambdad 3d", + "3D", + 5e-2, + RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5), + ) +) TESTS.append( ( diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py new file mode 100644 index 0000000000..bf537883cf --- /dev/null +++ b/tests/test_rand_lambda.py @@ -0,0 +1,53 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from monai.transforms.transform import Randomizable +from monai.transforms.utility.array import RandLambda + + +class RandTest(Randomizable): + """ + randomisable transform for testing. + """ + + def randomize(self, data=None): + self._a = self.R.random() + + def __call__(self, data): + self.randomize() + return data + self._a + + +class TestRandLambda(unittest.TestCase): + def test_rand_lambdad_identity(self): + img = np.zeros((10, 10)) + + test_func = RandTest() + test_func.set_random_state(seed=134) + expected = test_func(img) + test_func.set_random_state(seed=134) + ret = RandLambda(func=test_func)(img) + np.testing.assert_allclose(expected, ret) + ret = RandLambda(func=test_func, prob=0.0)(img) + np.testing.assert_allclose(img, ret) + + trans = RandLambda(func=test_func, prob=0.5) + trans.set_random_state(seed=123) + ret = trans(img) + np.testing.assert_allclose(img, ret) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index a450b67413..0a127839b8 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -42,6 +42,15 @@ def test_rand_lambdad_identity(self): ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data) np.testing.assert_allclose(expected["img"], ret["img"]) np.testing.assert_allclose(expected["prop"], ret["prop"]) + ret = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.0)(data) + np.testing.assert_allclose(data["img"], ret["img"]) + np.testing.assert_allclose(data["prop"], ret["prop"]) + + trans = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.5) + trans.set_random_state(seed=123) + ret = trans(data) + np.testing.assert_allclose(data["img"], ret["img"]) + np.testing.assert_allclose(data["prop"], ret["prop"]) if __name__ == "__main__":