From fcf8736916c55f335717b187ced0d66dda85313d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 28 Jan 2021 10:41:26 +0800 Subject: [PATCH] [DLMED] add overwrite option Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 17 +++++++++++++---- tests/test_lambdad.py | 15 +++++++-------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1427f24356..ef89dbe32d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -599,18 +599,27 @@ class Lambdad(MapTransform): See also: :py:class:`monai.transforms.compose.MapTransform` func: Lambda/function to be applied. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output. + default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. """ - def __init__(self, keys: KeysCollection, func: Union[Sequence[Callable], Callable]) -> None: + def __init__( + self, + keys: KeysCollection, + func: Union[Sequence[Callable], Callable], + overwrite: Union[Sequence[bool], bool] = True, + ) -> None: super().__init__(keys) self.func = ensure_tuple_rep(func, len(self.keys)) - self.lambd = Lambda() + self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) + self._lambd = Lambda() def __call__(self, data): d = dict(data) for idx, key in enumerate(self.keys): - d[key] = self.lambd(d[key], func=self.func[idx]) - + ret = self._lambd(d[key], func=self.func[idx]) + if self.overwrite[idx]: + d[key] = ret return d diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index 8f7e6b1133..ca28af778b 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -20,16 +20,15 @@ class TestLambdad(NumpyImageTestCase2D): def test_lambdad_identity(self): img = self.imt - data = {} - data["img"] = img + data = {"img": img, "prop": 1.0} - def identity_func(x): - return x + def noise_func(x): + return x + 1.0 - lambd = Lambdad(keys=data.keys(), func=identity_func) - expected = data - expected["img"] = identity_func(data["img"]) - self.assertTrue(np.allclose(expected["img"], lambd(data)["img"])) + expected = {"img": noise_func(data["img"]), "prop": 1.0} + ret = Lambdad(keys=["img", "prop"], func=noise_func, overwrite=[True, False])(data) + self.assertTrue(np.allclose(expected["img"], ret["img"])) + self.assertTrue(np.allclose(expected["prop"], ret["prop"])) def test_lambdad_slicing(self): img = self.imt