Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,12 @@ Utility
:members:
:special-members: __call__

`RandLambda`
""""""""""""
.. autoclass:: RandLambda
:members:
:special-members: __call__

`LabelToMask`
"""""""""""""
.. autoclass:: LabelToMask
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@
LabelToMask,
Lambda,
MapLabelValue,
RandLambda,
RemoveRepeatedChannel,
RepeatChannel,
SimulateDelay,
Expand Down
25 changes: 24 additions & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch

from monai.config import DtypeLike, NdarrayTensor
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
from monai.transforms.utils import (
convert_to_numpy,
convert_to_tensor,
Expand Down Expand Up @@ -58,6 +58,7 @@
"DataStats",
"SimulateDelay",
"Lambda",
"RandLambda",
"LabelToMask",
"FgBgToIndices",
"ClassesToIndices",
Expand Down Expand Up @@ -617,6 +618,28 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable
raise ValueError("Incompatible values: func=None and self.func=None.")


class RandLambda(Lambda, RandomizableTransform):
"""
Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` may contain random logic,
or randomly execute the function based on `prob`.

Args:
func: Lambda/function to be applied.
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.

For more details, please check :py:class:`monai.transforms.Lambda`.

"""

def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None:
Lambda.__init__(self=self, func=func)
RandomizableTransform.__init__(self=self, prob=prob)

def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None):
self.randomize(img)
return super().__call__(img=img, func=func) if self._do_transform else img


class LabelToMask(Transform):
"""
Convert labels to mask for other tasks. A typical usage is to convert segmentation labels
Expand Down
76 changes: 68 additions & 8 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
import torch

from monai.config import DtypeLike, KeysCollection, NdarrayTensor
from monai.data.utils import no_collation
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utility.array import (
AddChannel,
AsChannelFirst,
Expand Down Expand Up @@ -833,7 +834,7 @@ def __call__(self, data):
return d


class Lambdad(MapTransform):
class Lambdad(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Lambda`.

Expand All @@ -852,51 +853,110 @@ class Lambdad(MapTransform):
See also: :py:class:`monai.transforms.compose.MapTransform`
func: Lambda/function to be applied. It also can be a sequence of Callable,
each element corresponds to a key in ``keys``.
inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.

Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
image's original size. If need these complicated information, please write a new InvertibleTransform directly.

"""

def __init__(
self,
keys: KeysCollection,
func: Union[Sequence[Callable], Callable],
inv_func: Union[Sequence[Callable], Callable] = no_collation,
overwrite: Union[Sequence[bool], bool] = True,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.func = ensure_tuple_rep(func, len(self.keys))
self.inv_func = ensure_tuple_rep(inv_func, len(self.keys))
self.overwrite = ensure_tuple_rep(overwrite, len(self.keys))
self._lambd = Lambda()

def _transform(self, data: Any, func: Callable):
return self._lambd(data, func=func)

def __call__(self, data):
d = dict(data)
for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite):
ret = self._lambd(d[key], func=func)
ret = self._transform(data=d[key], func=func)
if overwrite:
d[key] = ret
self.push_transform(d, key)
return d

def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable):
return self._lambd(data, func=func)

def inverse(self, data):
d = deepcopy(dict(data))
for key, inv_func, overwrite in self.key_iterator(d, self.inv_func, self.overwrite):
transform = self.get_most_recent_transform(d, key)
ret = self._inverse_transform(transform_info=transform, data=d[key], func=inv_func)
if overwrite:
d[key] = ret
self.pop_transform(d, key)
return d


class RandLambdad(Lambdad, Randomizable):
class RandLambdad(Lambdad, RandomizableTransform):
"""
Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic.
It's a randomizable transform so `CacheDataset` will not execute it and cache the results.
Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` may contain random logic,
or randomly execute the function based on `prob`. so `CacheDataset` will not execute it and cache the results.

Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
func: Lambda/function to be applied. It also can be a sequence of Callable,
each element corresponds to a key in ``keys``.
inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
note that all the data specified by `keys` will share the same random probability to execute or not.
allow_missing_keys: don't raise exception if key is missing.

For more details, please check :py:class:`monai.transforms.Lambdad`.

Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
image's original size. If need these complicated information, please write a new InvertibleTransform directly.

"""

def randomize(self, data: Any) -> None:
pass
def __init__(
self,
keys: KeysCollection,
func: Union[Sequence[Callable], Callable],
inv_func: Union[Sequence[Callable], Callable] = no_collation,
overwrite: Union[Sequence[bool], bool] = True,
prob: float = 1.0,
allow_missing_keys: bool = False,
) -> None:
Lambdad.__init__(
self=self,
keys=keys,
func=func,
inv_func=inv_func,
overwrite=overwrite,
allow_missing_keys=allow_missing_keys,
)
RandomizableTransform.__init__(self=self, prob=prob, do_transform=True)

def _transform(self, data: Any, func: Callable):
return self._lambd(data, func=func) if self._do_transform else data

def __call__(self, data):
self.randomize(data)
return super().__call__(data)

def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable):
return self._lambd(data, func=func) if transform_info[InverseKeys.DO_TRANSFORM] else data


class LabelToMaskd(MapTransform):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@
DivisiblePadd,
Flipd,
InvertibleTransform,
Lambdad,
LoadImaged,
Orientationd,
RandAffined,
RandAxisFlipd,
RandCropByLabelClassesd,
RandCropByPosNegLabeld,
RandFlipd,
RandLambdad,
Randomizable,
RandRotate90d,
RandRotated,
Expand Down Expand Up @@ -314,6 +316,16 @@

TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True)))

TESTS.append(("Lambdad 2d", "2D", 5e-2, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True)))

TESTS.append(
(
"RandLambdad 3d",
"3D",
5e-2,
RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5),
)
)

TESTS.append(
(
Expand Down
53 changes: 53 additions & 0 deletions tests/test_rand_lambda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from monai.transforms.transform import Randomizable
from monai.transforms.utility.array import RandLambda


class RandTest(Randomizable):
"""
randomisable transform for testing.
"""

def randomize(self, data=None):
self._a = self.R.random()

def __call__(self, data):
self.randomize()
return data + self._a


class TestRandLambda(unittest.TestCase):
def test_rand_lambdad_identity(self):
img = np.zeros((10, 10))

test_func = RandTest()
test_func.set_random_state(seed=134)
expected = test_func(img)
test_func.set_random_state(seed=134)
ret = RandLambda(func=test_func)(img)
np.testing.assert_allclose(expected, ret)
ret = RandLambda(func=test_func, prob=0.0)(img)
np.testing.assert_allclose(img, ret)

trans = RandLambda(func=test_func, prob=0.5)
trans.set_random_state(seed=123)
ret = trans(img)
np.testing.assert_allclose(img, ret)


if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions tests/test_rand_lambdad.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def test_rand_lambdad_identity(self):
ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data)
np.testing.assert_allclose(expected["img"], ret["img"])
np.testing.assert_allclose(expected["prop"], ret["prop"])
ret = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.0)(data)
np.testing.assert_allclose(data["img"], ret["img"])
np.testing.assert_allclose(data["prop"], ret["prop"])

trans = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.5)
trans.set_random_state(seed=123)
ret = trans(data)
np.testing.assert_allclose(data["img"], ret["img"])
np.testing.assert_allclose(data["prop"], ret["prop"])


if __name__ == "__main__":
Expand Down