From 8680c14c251a3b552e21eb431c4536a2593ab121 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 29 Jul 2021 13:50:41 +0800 Subject: [PATCH 1/9] [DLMED] add RandCompose Signed-off-by: Nic Ma --- monai/transforms/compose.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b380f7d42a..d98fcef44d 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -27,7 +27,7 @@ Transform, apply_transform, ) -from monai.utils import MAX_SEED, ensure_tuple, get_seed +from monai.utils import MAX_SEED, ensure_tuple, ensure_tuple_rep, get_seed __all__ = ["Compose"] @@ -159,8 +159,12 @@ def __call__(self, input_): input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) return input_ + def _get_applied_transforms(self): + return self.flatten().transforms + def inverse(self, data): - invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] + invertible_transforms = [t for t in self._get_applied_transforms() if isinstance(t, InvertibleTransform)] + if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") @@ -168,3 +172,31 @@ def inverse(self, data): for t in reversed(invertible_transforms): data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) return data + + +class RandCompose(Compose): + def __init__( + self, + prob: Union[Sequence[float], float], + transforms: Optional[Union[Sequence[Callable], Callable]] = None, + map_items: bool = True, + unpack_items: bool = False, + ) -> None: + super().__init__(transforms=transforms, map_items=map_items, unpack_items=unpack_items) + self.prob = ensure_tuple_rep(prob, len(self.transforms)) + self.applied: List[Callable] = [] + + def flatten(self): + return self + + def __call__(self, input_): + rands = self.R.rand(len(self)) + self.applied = [] + for _transform, r, p in enumerate(self.transforms, rands, self.prob): + if r < min(max(p, 0.0), 1.0): + input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) + self.applied.append(_transform) + return input_ + + def _get_applied_transforms(self): + return self.applied From b60908d829dcd192da1acaf012720e580f4ef092 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 29 Jul 2021 20:13:49 +0800 Subject: [PATCH 2/9] [DLMED] add unit tests Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 ++ monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 17 ++-- tests/test_rand_compose.py | 169 +++++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 8 deletions(-) create mode 100644 tests/test_rand_compose.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 962e1f3769..3e691002d3 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -38,6 +38,12 @@ Generic Interfaces :members: :special-members: __call__ +`RandCompose` +^^^^^^^^^^^^^ +.. autoclass:: RandCompose + :members: + :special-members: __call__ + `InvertibleTransform` ^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: InvertibleTransform diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 45eecd266c..9a45bfa0b9 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .compose import Compose +from .compose import Compose, RandCompose from .croppad.array import ( BorderPad, BoundingRect, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index d98fcef44d..cbfaa93818 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -29,7 +29,7 @@ ) from monai.utils import MAX_SEED, ensure_tuple, ensure_tuple_rep, get_seed -__all__ = ["Compose"] +__all__ = ["Compose", "RandCompose"] class Compose(Randomizable, InvertibleTransform): @@ -143,7 +143,7 @@ def flatten(self): """ new_transforms = [] for t in self.transforms: - if isinstance(t, Compose): + if type(t) == Compose: new_transforms += t.flatten().transforms else: new_transforms.append(t) @@ -164,7 +164,7 @@ def _get_applied_transforms(self): def inverse(self, data): invertible_transforms = [t for t in self._get_applied_transforms() if isinstance(t, InvertibleTransform)] - + if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") @@ -185,14 +185,17 @@ def __init__( super().__init__(transforms=transforms, map_items=map_items, unpack_items=unpack_items) self.prob = ensure_tuple_rep(prob, len(self.transforms)) self.applied: List[Callable] = [] - + def flatten(self): - return self - + raise NotImplementedError("flatten method not yet implemented for `RandCompose` class.") + + def __len__(self): + return len(self.transforms) + def __call__(self, input_): rands = self.R.rand(len(self)) self.applied = [] - for _transform, r, p in enumerate(self.transforms, rands, self.prob): + for _transform, r, p in zip(self.transforms, rands, self.prob): if r < min(max(p, 0.0), 1.0): input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) self.applied.append(_transform) diff --git a/tests/test_rand_compose.py b/tests/test_rand_compose.py new file mode 100644 index 0000000000..cd2bc4d493 --- /dev/null +++ b/tests/test_rand_compose.py @@ -0,0 +1,169 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +from monai.data import DataLoader, Dataset +from monai.transforms import AddChannel, Compose, RandCompose +from monai.transforms.transform import Randomizable +from monai.utils import set_determinism + + +class _RandXform(Randomizable): + def randomize(self): + self.val = self.R.random_sample() + + def __call__(self, img): + self.randomize() + return img + self.val + + +class TestRandCompose(unittest.TestCase): + def test_non_dict_compose(self): + def a(i): + return i + "a" + + def b(i): + return i + "b" + + c = RandCompose(prob=[1.0, 0.0, 0.0, 1.0], transforms=[a, b, a, b]) + self.assertEqual(c(""), "ab") + + def test_dict_compose(self): + def a(d): + d = dict(d) + d["a"] += 1 + return d + + def b(d): + d = dict(d) + d["b"] += 1 + return d + + c = RandCompose(prob=[1.0, 0.0, 1.0, 0.0, 1.0], transforms=[a, b, a, b, a]) + self.assertDictEqual(c({"a": 0, "b": 0}), {"a": 3, "b": 0}) + + def test_list_dict_compose(self): + def a(d): # transform to handle dict data + d = dict(d) + d["a"] += 1 + return d + + def b(d): # transform to generate a batch list of data + d = dict(d) + d["b"] += 1 + d = [d] * 5 + return d + + def c(d): # transform to handle dict data + d = dict(d) + d["c"] += 1 + return d + + transforms = RandCompose(prob=[1.0, 0.0, 1.0, 0.0, 1.0], transforms=[a, a, b, c, c]) + value = transforms({"a": 0, "b": 0, "c": 0}) + for item in value: + self.assertDictEqual(item, {"a": 1, "b": 1, "c": 1}) + + def test_non_dict_compose_with_unpack(self): + def a(i, i2): + return i + "a", i2 + "a2" + + def b(i, i2): + return i + "b", i2 + "b2" + + c = RandCompose(prob=[1.0, 0.0, 0.0, 1.0], transforms=[a, b, a, b], map_items=False, unpack_items=True) + self.assertEqual(c(("", "")), ("ab", "a2b2")) + + def test_list_non_dict_compose_with_unpack(self): + def a(i, i2): + return i + "a", i2 + "a2" + + def b(i, i2): + return i + "b", i2 + "b2" + + c = RandCompose(prob=[1.0, 0.0, 0.0, 1.0], transforms=[a, b, a, b], unpack_items=True) + self.assertEqual(c([("", ""), ("t", "t")]), [("ab", "a2b2"), ("tab", "ta2b2")]) + + def test_list_dict_compose_no_map(self): + def a(d): # transform to handle dict data + d = dict(d) + d["a"] += 1 + return d + + def b(d): # transform to generate a batch list of data + d = dict(d) + d["b"] += 1 + d = [d] * 5 + return d + + def c(d): # transform to handle dict data + d = [dict(di) for di in d] + for di in d: + di["c"] += 1 + return d + + transforms = RandCompose(prob=[1.0, 0.0, 1.0, 0.0, 1.0], transforms=[a, a, b, c, c], map_items=False) + value = transforms({"a": 0, "b": 0, "c": 0}) + for item in value: + self.assertDictEqual(item, {"a": 1, "b": 1, "c": 1}) + + def test_random_compose(self): + class _Acc(Randomizable): + self.rand = 0.0 + + def randomize(self, data=None): + self.rand = self.R.rand() + + def __call__(self, data): + self.randomize() + return self.rand + data + + c = RandCompose(prob=0.5, transforms=[_Acc(), _Acc()]) + self.assertNotAlmostEqual(c(0), c(0)) + c.set_random_state(123) + self.assertAlmostEqual(c(1), 1.61381597) + c.set_random_state(456) + c.randomize() + self.assertAlmostEqual(c(1), 1.17330701) + + def test_data_loader(self): + xform_1 = RandCompose(prob=0.5, transforms=[_RandXform(), _RandXform(), _RandXform()]) + train_ds = Dataset([1], transform=xform_1) + + set_determinism(seed=123) + train_loader = DataLoader(train_ds, num_workers=0) + out_1 = next(iter(train_loader)) + self.assertAlmostEqual(out_1.item(), 1.58704446) + + if sys.platform != "win32": # skip multi-worker tests on win32 + train_loader = DataLoader(train_ds, num_workers=1) + out_1 = next(iter(train_loader)) + self.assertAlmostEqual(out_1.item(), 1.15912328) + + train_loader = DataLoader(train_ds, num_workers=2) + out_1 = next(iter(train_loader)) + self.assertAlmostEqual(out_1.item(), 1.65850210) + set_determinism(None) + + def test_flatten_and_len(self): + x = AddChannel() + t1 = Compose([x, x, x, x, Compose([RandCompose(prob=[0.1, 0.2], transforms=[x, x]), x, x])]) + + t2 = t1.flatten() + # test length + self.assertEqual(len(t1), 7) + self.assertEqual(len(t2.transforms[4]), 2) + + +if __name__ == "__main__": + unittest.main() From bf313996701cd477ed3143c5ae61d074c252b3c1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 30 Jul 2021 11:58:08 +0800 Subject: [PATCH 3/9] [DLMED] change to enhance RandLambda Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 +++ monai/transforms/__init__.py | 1 + monai/transforms/utility/array.py | 22 ++++++++++- monai/transforms/utility/dictionary.py | 22 +++++++++-- tests/test_rand_lambda.py | 53 ++++++++++++++++++++++++++ tests/test_rand_lambdad.py | 9 +++++ 6 files changed, 108 insertions(+), 5 deletions(-) create mode 100644 tests/test_rand_lambda.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3e691002d3..7a3573870d 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -610,6 +610,12 @@ Utility :members: :special-members: __call__ +`RandLambda` +"""""""""""" +.. autoclass:: RandLambda + :members: + :special-members: __call__ + `LabelToMask` """"""""""""" .. autoclass:: LabelToMask diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9a45bfa0b9..842cbe1a75 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -323,6 +323,7 @@ LabelToMask, Lambda, MapLabelValue, + RandLambda, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 7f06f119c2..17d837830f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -23,7 +23,7 @@ import torch from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_to_numpy, convert_to_tensor, @@ -58,6 +58,7 @@ "DataStats", "SimulateDelay", "Lambda", + "RandLambda", "LabelToMask", "FgBgToIndices", "ClassesToIndices", @@ -616,6 +617,25 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable return self.func(img) raise ValueError("Incompatible values: func=None and self.func=None.") +class RandLambda(Lambda, RandomizableTransform): + """ + Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` contains random logic. + + Args: + func: Lambda/function to be applied. + prob: probability of executing the random function, default to 1.0, with 100% probability to execute. + + For more details, please check :py:class:`monai.transforms.Lambda`. + + """ + def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None: + Lambda.__init__(self=self, func=func) + RandomizableTransform.__init__(self=self, prob=prob) + + def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + self.randomize(img) + return super().__call__(img=img, func=func) if self._do_transform else img + class LabelToMask(Transform): """ diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 6fa672e6c4..7d296a5cef 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -25,7 +25,7 @@ from monai.config import DtypeLike, KeysCollection, NdarrayTensor from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, @@ -878,7 +878,7 @@ def __call__(self, data): return d -class RandLambdad(Lambdad, Randomizable): +class RandLambdad(Lambdad, RandomizableTransform): """ Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. It's a randomizable transform so `CacheDataset` will not execute it and cache the results. @@ -890,13 +890,27 @@ class RandLambdad(Lambdad, Randomizable): each element corresponds to a key in ``keys``. overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. + prob: probability of executing the random function, default to 1.0, with 100% probability to execute. + note that all the data specified by `keys` will share the same random probability to execute or not. + allow_missing_keys: don't raise exception if key is missing. For more details, please check :py:class:`monai.transforms.Lambdad`. """ + def __init__( + self, + keys: KeysCollection, + func: Union[Sequence[Callable], Callable], + overwrite: Union[Sequence[bool], bool] = True, + prob: float = 1.0, + allow_missing_keys: bool = False, + ) -> None: + Lambdad.__init__(self=self, keys=keys, func=func, overwrite=overwrite, allow_missing_keys=allow_missing_keys) + RandomizableTransform.__init__(self=self, prob=prob, do_transform=True) - def randomize(self, data: Any) -> None: - pass + def __call__(self, data): + self.randomize(data) + return super().__call__(data) if self._do_transform else data class LabelToMaskd(MapTransform): diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py new file mode 100644 index 0000000000..bf537883cf --- /dev/null +++ b/tests/test_rand_lambda.py @@ -0,0 +1,53 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from monai.transforms.transform import Randomizable +from monai.transforms.utility.array import RandLambda + + +class RandTest(Randomizable): + """ + randomisable transform for testing. + """ + + def randomize(self, data=None): + self._a = self.R.random() + + def __call__(self, data): + self.randomize() + return data + self._a + + +class TestRandLambda(unittest.TestCase): + def test_rand_lambdad_identity(self): + img = np.zeros((10, 10)) + + test_func = RandTest() + test_func.set_random_state(seed=134) + expected = test_func(img) + test_func.set_random_state(seed=134) + ret = RandLambda(func=test_func)(img) + np.testing.assert_allclose(expected, ret) + ret = RandLambda(func=test_func, prob=0.0)(img) + np.testing.assert_allclose(img, ret) + + trans = RandLambda(func=test_func, prob=0.5) + trans.set_random_state(seed=123) + ret = trans(img) + np.testing.assert_allclose(img, ret) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py index a450b67413..0a127839b8 100644 --- a/tests/test_rand_lambdad.py +++ b/tests/test_rand_lambdad.py @@ -42,6 +42,15 @@ def test_rand_lambdad_identity(self): ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data) np.testing.assert_allclose(expected["img"], ret["img"]) np.testing.assert_allclose(expected["prop"], ret["prop"]) + ret = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.0)(data) + np.testing.assert_allclose(data["img"], ret["img"]) + np.testing.assert_allclose(data["prop"], ret["prop"]) + + trans = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.5) + trans.set_random_state(seed=123) + ret = trans(data) + np.testing.assert_allclose(data["img"], ret["img"]) + np.testing.assert_allclose(data["prop"], ret["prop"]) if __name__ == "__main__": From 7358326455153f00dd4debc7035a833c179b1fc1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 30 Jul 2021 12:00:49 +0800 Subject: [PATCH 4/9] [DLMED] remove RandCompose Signed-off-by: Nic Ma --- docs/source/transforms.rst | 6 -- monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 42 +-------- tests/test_rand_compose.py | 169 ----------------------------------- 4 files changed, 5 insertions(+), 214 deletions(-) delete mode 100644 tests/test_rand_compose.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 7a3573870d..01b1cb00bb 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -38,12 +38,6 @@ Generic Interfaces :members: :special-members: __call__ -`RandCompose` -^^^^^^^^^^^^^ -.. autoclass:: RandCompose - :members: - :special-members: __call__ - `InvertibleTransform` ^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: InvertibleTransform diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 842cbe1a75..487a995e5e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .compose import Compose, RandCompose +from .compose import Compose from .croppad.array import ( BorderPad, BoundingRect, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index cbfaa93818..65dd3e92a4 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -27,9 +27,9 @@ Transform, apply_transform, ) -from monai.utils import MAX_SEED, ensure_tuple, ensure_tuple_rep, get_seed +from monai.utils import MAX_SEED, ensure_tuple, get_seed -__all__ = ["Compose", "RandCompose"] +__all__ = ["Compose"] class Compose(Randomizable, InvertibleTransform): @@ -143,7 +143,7 @@ def flatten(self): """ new_transforms = [] for t in self.transforms: - if type(t) == Compose: + if isinstance(t, Compose): new_transforms += t.flatten().transforms else: new_transforms.append(t) @@ -159,11 +159,8 @@ def __call__(self, input_): input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) return input_ - def _get_applied_transforms(self): - return self.flatten().transforms - def inverse(self, data): - invertible_transforms = [t for t in self._get_applied_transforms() if isinstance(t, InvertibleTransform)] + invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") @@ -172,34 +169,3 @@ def inverse(self, data): for t in reversed(invertible_transforms): data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) return data - - -class RandCompose(Compose): - def __init__( - self, - prob: Union[Sequence[float], float], - transforms: Optional[Union[Sequence[Callable], Callable]] = None, - map_items: bool = True, - unpack_items: bool = False, - ) -> None: - super().__init__(transforms=transforms, map_items=map_items, unpack_items=unpack_items) - self.prob = ensure_tuple_rep(prob, len(self.transforms)) - self.applied: List[Callable] = [] - - def flatten(self): - raise NotImplementedError("flatten method not yet implemented for `RandCompose` class.") - - def __len__(self): - return len(self.transforms) - - def __call__(self, input_): - rands = self.R.rand(len(self)) - self.applied = [] - for _transform, r, p in zip(self.transforms, rands, self.prob): - if r < min(max(p, 0.0), 1.0): - input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) - self.applied.append(_transform) - return input_ - - def _get_applied_transforms(self): - return self.applied diff --git a/tests/test_rand_compose.py b/tests/test_rand_compose.py deleted file mode 100644 index cd2bc4d493..0000000000 --- a/tests/test_rand_compose.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import unittest - -from monai.data import DataLoader, Dataset -from monai.transforms import AddChannel, Compose, RandCompose -from monai.transforms.transform import Randomizable -from monai.utils import set_determinism - - -class _RandXform(Randomizable): - def randomize(self): - self.val = self.R.random_sample() - - def __call__(self, img): - self.randomize() - return img + self.val - - -class TestRandCompose(unittest.TestCase): - def test_non_dict_compose(self): - def a(i): - return i + "a" - - def b(i): - return i + "b" - - c = RandCompose(prob=[1.0, 0.0, 0.0, 1.0], transforms=[a, b, a, b]) - self.assertEqual(c(""), "ab") - - def test_dict_compose(self): - def a(d): - d = dict(d) - d["a"] += 1 - return d - - def b(d): - d = dict(d) - d["b"] += 1 - return d - - c = RandCompose(prob=[1.0, 0.0, 1.0, 0.0, 1.0], transforms=[a, b, a, b, a]) - self.assertDictEqual(c({"a": 0, "b": 0}), {"a": 3, "b": 0}) - - def test_list_dict_compose(self): - def a(d): # transform to handle dict data - d = dict(d) - d["a"] += 1 - return d - - def b(d): # transform to generate a batch list of data - d = dict(d) - d["b"] += 1 - d = [d] * 5 - return d - - def c(d): # transform to handle dict data - d = dict(d) - d["c"] += 1 - return d - - transforms = RandCompose(prob=[1.0, 0.0, 1.0, 0.0, 1.0], transforms=[a, a, b, c, c]) - value = transforms({"a": 0, "b": 0, "c": 0}) - for item in value: - self.assertDictEqual(item, {"a": 1, "b": 1, "c": 1}) - - def test_non_dict_compose_with_unpack(self): - def a(i, i2): - return i + "a", i2 + "a2" - - def b(i, i2): - return i + "b", i2 + "b2" - - c = RandCompose(prob=[1.0, 0.0, 0.0, 1.0], transforms=[a, b, a, b], map_items=False, unpack_items=True) - self.assertEqual(c(("", "")), ("ab", "a2b2")) - - def test_list_non_dict_compose_with_unpack(self): - def a(i, i2): - return i + "a", i2 + "a2" - - def b(i, i2): - return i + "b", i2 + "b2" - - c = RandCompose(prob=[1.0, 0.0, 0.0, 1.0], transforms=[a, b, a, b], unpack_items=True) - self.assertEqual(c([("", ""), ("t", "t")]), [("ab", "a2b2"), ("tab", "ta2b2")]) - - def test_list_dict_compose_no_map(self): - def a(d): # transform to handle dict data - d = dict(d) - d["a"] += 1 - return d - - def b(d): # transform to generate a batch list of data - d = dict(d) - d["b"] += 1 - d = [d] * 5 - return d - - def c(d): # transform to handle dict data - d = [dict(di) for di in d] - for di in d: - di["c"] += 1 - return d - - transforms = RandCompose(prob=[1.0, 0.0, 1.0, 0.0, 1.0], transforms=[a, a, b, c, c], map_items=False) - value = transforms({"a": 0, "b": 0, "c": 0}) - for item in value: - self.assertDictEqual(item, {"a": 1, "b": 1, "c": 1}) - - def test_random_compose(self): - class _Acc(Randomizable): - self.rand = 0.0 - - def randomize(self, data=None): - self.rand = self.R.rand() - - def __call__(self, data): - self.randomize() - return self.rand + data - - c = RandCompose(prob=0.5, transforms=[_Acc(), _Acc()]) - self.assertNotAlmostEqual(c(0), c(0)) - c.set_random_state(123) - self.assertAlmostEqual(c(1), 1.61381597) - c.set_random_state(456) - c.randomize() - self.assertAlmostEqual(c(1), 1.17330701) - - def test_data_loader(self): - xform_1 = RandCompose(prob=0.5, transforms=[_RandXform(), _RandXform(), _RandXform()]) - train_ds = Dataset([1], transform=xform_1) - - set_determinism(seed=123) - train_loader = DataLoader(train_ds, num_workers=0) - out_1 = next(iter(train_loader)) - self.assertAlmostEqual(out_1.item(), 1.58704446) - - if sys.platform != "win32": # skip multi-worker tests on win32 - train_loader = DataLoader(train_ds, num_workers=1) - out_1 = next(iter(train_loader)) - self.assertAlmostEqual(out_1.item(), 1.15912328) - - train_loader = DataLoader(train_ds, num_workers=2) - out_1 = next(iter(train_loader)) - self.assertAlmostEqual(out_1.item(), 1.65850210) - set_determinism(None) - - def test_flatten_and_len(self): - x = AddChannel() - t1 = Compose([x, x, x, x, Compose([RandCompose(prob=[0.1, 0.2], transforms=[x, x]), x, x])]) - - t2 = t1.flatten() - # test length - self.assertEqual(len(t1), 7) - self.assertEqual(len(t2.transforms[4]), 2) - - -if __name__ == "__main__": - unittest.main() From 69090c4f9238100433fff455e2e6a1315e6d6326 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 30 Jul 2021 12:04:05 +0800 Subject: [PATCH 5/9] [DLMED] fix format Signed-off-by: Nic Ma --- monai/transforms/compose.py | 1 - monai/transforms/utility/array.py | 2 ++ monai/transforms/utility/dictionary.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 65dd3e92a4..b380f7d42a 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -161,7 +161,6 @@ def __call__(self, input_): def inverse(self, data): invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] - if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 17d837830f..7b9382ca94 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -617,6 +617,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable return self.func(img) raise ValueError("Incompatible values: func=None and self.func=None.") + class RandLambda(Lambda, RandomizableTransform): """ Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` contains random logic. @@ -628,6 +629,7 @@ class RandLambda(Lambda, RandomizableTransform): For more details, please check :py:class:`monai.transforms.Lambda`. """ + def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None: Lambda.__init__(self=self, func=func) RandomizableTransform.__init__(self=self, prob=prob) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7d296a5cef..797a0c9820 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -897,6 +897,7 @@ class RandLambdad(Lambdad, RandomizableTransform): For more details, please check :py:class:`monai.transforms.Lambdad`. """ + def __init__( self, keys: KeysCollection, From a6b5d185e82c20410b56ca5bcb2c51d9d14eee00 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 30 Jul 2021 12:09:00 +0800 Subject: [PATCH 6/9] [DLMED] enhance doc Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 3 ++- monai/transforms/utility/dictionary.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 7b9382ca94..4e0141652f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -620,7 +620,8 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable class RandLambda(Lambda, RandomizableTransform): """ - Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` contains random logic. + Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` may contain random logic, + or randomly execute the function based on `prob`. Args: func: Lambda/function to be applied. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 797a0c9820..0b0d47f9fc 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -880,8 +880,8 @@ def __call__(self, data): class RandLambdad(Lambdad, RandomizableTransform): """ - Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic. - It's a randomizable transform so `CacheDataset` will not execute it and cache the results. + Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` may contain random logic, + or randomly execute the function based on `prob`. so `CacheDataset` will not execute it and cache the results. Args: keys: keys of the corresponding items to be transformed. From 1df334f8ae050d6227413e5b4765b8a3386efb83 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 30 Jul 2021 18:13:12 +0800 Subject: [PATCH 7/9] [DLMED] add inverse operation Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 42 ++++++++++++++++++++++++-- tests/test_inverse.py | 12 ++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 0b0d47f9fc..a555a731aa 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -833,7 +833,7 @@ def __call__(self, data): return d -class Lambdad(MapTransform): +class Lambdad(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Lambda`. @@ -852,20 +852,28 @@ class Lambdad(MapTransform): See also: :py:class:`monai.transforms.compose.MapTransform` func: Lambda/function to be applied. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`. + It also can be a sequence of Callable, each element corresponds to a key in ``keys``. overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + + Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the + image's original size. If need these complicated information, please write a new InvertibleTransform directly. + """ def __init__( self, keys: KeysCollection, func: Union[Sequence[Callable], Callable], + inv_func: Union[Sequence[Callable], Callable] = lambda x: x, overwrite: Union[Sequence[bool], bool] = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.func = ensure_tuple_rep(func, len(self.keys)) + self.inv_func = ensure_tuple_rep(inv_func, len(self.keys)) self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) self._lambd = Lambda() @@ -875,6 +883,20 @@ def __call__(self, data): ret = self._lambd(d[key], func=func) if overwrite: d[key] = ret + self.push_transform(d, key) + return d + + def _inverse_transform(self, transform: Dict, data: Any, func: Callable): + return self._lambd(data, func=func) + + def inverse(self, data): + d = deepcopy(dict(data)) + for key, inv_func, overwrite in self.key_iterator(d, self.inv_func, self.overwrite): + transform = self.get_most_recent_transform(d, key) + ret = self._inverse_transform(transform=transform, data=d[key], func=inv_func) + if overwrite: + d[key] = ret + self.pop_transform(d, key) return d @@ -888,6 +910,8 @@ class RandLambdad(Lambdad, RandomizableTransform): See also: :py:class:`monai.transforms.compose.MapTransform` func: Lambda/function to be applied. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`. + It also can be a sequence of Callable, each element corresponds to a key in ``keys``. overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. prob: probability of executing the random function, default to 1.0, with 100% probability to execute. @@ -896,23 +920,37 @@ class RandLambdad(Lambdad, RandomizableTransform): For more details, please check :py:class:`monai.transforms.Lambdad`. + Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the + image's original size. If need these complicated information, please write a new InvertibleTransform directly. + """ def __init__( self, keys: KeysCollection, func: Union[Sequence[Callable], Callable], + inv_func: Union[Sequence[Callable], Callable] = lambda x: x, overwrite: Union[Sequence[bool], bool] = True, prob: float = 1.0, allow_missing_keys: bool = False, ) -> None: - Lambdad.__init__(self=self, keys=keys, func=func, overwrite=overwrite, allow_missing_keys=allow_missing_keys) + Lambdad.__init__( + self=self, + keys=keys, + func=func, + inv_func=inv_func, + overwrite=overwrite, + allow_missing_keys=allow_missing_keys, + ) RandomizableTransform.__init__(self=self, prob=prob, do_transform=True) def __call__(self, data): self.randomize(data) return super().__call__(data) if self._do_transform else data + def _inverse_transform(self, transform: Dict, data: Any, func: Callable): + return self._lambd(data, func=func) if transform[InverseKeys.DO_TRANSFORM] else data + class LabelToMaskd(MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index a1c171200f..0af74ae085 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -35,6 +35,7 @@ DivisiblePadd, Flipd, InvertibleTransform, + Lambdad, LoadImaged, Orientationd, RandAffined, @@ -42,6 +43,7 @@ RandCropByLabelClassesd, RandCropByPosNegLabeld, RandFlipd, + RandLambdad, Randomizable, RandRotate90d, RandRotated, @@ -314,6 +316,16 @@ TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True))) +TESTS.append(("Lambdad 2d", "2D", 5e-2, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True))) + +TESTS.append( + ( + "RandLambdad 3d", + "3D", + 5e-2, + Lambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True), + ) +) TESTS.append( ( From 0a14072d702ea4206cd3760bf81b49092c58bfb4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 30 Jul 2021 18:24:44 +0800 Subject: [PATCH 8/9] [DLMED] add more tests Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 18 ++++++++++++------ tests/test_inverse.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index a555a731aa..2302755077 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -877,23 +877,26 @@ def __init__( self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) self._lambd = Lambda() + def _transform(self, data: Any, func: Callable): + return self._lambd(data, func=func) + def __call__(self, data): d = dict(data) for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite): - ret = self._lambd(d[key], func=func) + ret = self._transform(data=d[key], func=func) if overwrite: d[key] = ret self.push_transform(d, key) return d - def _inverse_transform(self, transform: Dict, data: Any, func: Callable): + def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable): return self._lambd(data, func=func) def inverse(self, data): d = deepcopy(dict(data)) for key, inv_func, overwrite in self.key_iterator(d, self.inv_func, self.overwrite): transform = self.get_most_recent_transform(d, key) - ret = self._inverse_transform(transform=transform, data=d[key], func=inv_func) + ret = self._inverse_transform(transform_info=transform, data=d[key], func=inv_func) if overwrite: d[key] = ret self.pop_transform(d, key) @@ -944,12 +947,15 @@ def __init__( ) RandomizableTransform.__init__(self=self, prob=prob, do_transform=True) + def _transform(self, data: Any, func: Callable): + return self._lambd(data, func=func) if self._do_transform else data + def __call__(self, data): self.randomize(data) - return super().__call__(data) if self._do_transform else data + return super().__call__(data) - def _inverse_transform(self, transform: Dict, data: Any, func: Callable): - return self._lambd(data, func=func) if transform[InverseKeys.DO_TRANSFORM] else data + def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable): + return self._lambd(data, func=func) if transform_info[InverseKeys.DO_TRANSFORM] else data class LabelToMaskd(MapTransform): diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0af74ae085..f2470d47fd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -323,7 +323,7 @@ "RandLambdad 3d", "3D", 5e-2, - Lambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True), + RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5), ) ) From 24e43f2ff8ae50ab9b024cacff26865800074d36 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 30 Jul 2021 22:58:27 +0800 Subject: [PATCH 9/9] [DLMED] fix subprogress issue Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 2302755077..75be9685c4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -24,6 +24,7 @@ import torch from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( @@ -867,7 +868,7 @@ def __init__( self, keys: KeysCollection, func: Union[Sequence[Callable], Callable], - inv_func: Union[Sequence[Callable], Callable] = lambda x: x, + inv_func: Union[Sequence[Callable], Callable] = no_collation, overwrite: Union[Sequence[bool], bool] = True, allow_missing_keys: bool = False, ) -> None: @@ -932,7 +933,7 @@ def __init__( self, keys: KeysCollection, func: Union[Sequence[Callable], Callable], - inv_func: Union[Sequence[Callable], Callable] = lambda x: x, + inv_func: Union[Sequence[Callable], Callable] = no_collation, overwrite: Union[Sequence[bool], bool] = True, prob: float = 1.0, allow_missing_keys: bool = False,