From bc6dea3e8658e5e72c0de6c24cf4450dbbe8c8f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?= Date: Tue, 23 Apr 2024 19:41:25 +0200 Subject: [PATCH 1/3] fixed unit tests with random function --- tests/test_regularization.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 4df60b9808..c84e876e36 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -20,7 +20,6 @@ class TestMixup(unittest.TestCase): - def setUp(self) -> None: set_determinism(seed=0) @@ -60,7 +59,6 @@ def test_mixupd(self): class TestCutMix(unittest.TestCase): - def setUp(self) -> None: set_determinism(seed=0) @@ -75,23 +73,32 @@ def test_cutmix(self): output = cutmix(sample) self.assertEqual(output.shape, sample.shape) self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) + # croppings are different on each application... most of the times! + checks = [torch.allclose(sample, cutmix(sample)) for _ in range(1000)] + # 1000/(32*32*32) + self.assertTrue(sum(checks) < 5) def test_cutmixd(self): + batch_size = 6 for dims in [2, 3]: - shape = (6, 3) + (32,) * dims + shape = (batch_size, 3) + (32,) * dims t = torch.rand(*shape, dtype=torch.float32) label = torch.randint(0, 1, shape) sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} - cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) - output = cutmix(sample) - # croppings are different on each application - self.assertTrue(not torch.allclose(output["a"], output["b"])) + cutmix = CutMixd(["a", "b"], batch_size, label_keys=("lbl1", "lbl2")) + # croppings are different on each application... most of the times! + checks = [] + for _ in range(1000): + output = cutmix(sample) + checks.append(torch.allclose(output["a"], output["b"])) + # 1000/(32*32*32) + self.assertTrue(sum(checks) < 5) + # but mixing of labels is not affected by it self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) class TestCutOut(unittest.TestCase): - def setUp(self) -> None: set_determinism(seed=0) From 9a50847f8cae8af429f6fc9b50337ad95bb1b838 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?= Date: Sun, 28 Apr 2024 12:34:16 +0200 Subject: [PATCH 2/3] allow to set rnadom seed of transforms MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Juan Pablo de la Cruz Gutiérrez --- monai/transforms/regularization/dictionary.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py index 373913da99..72d3f13625 100644 --- a/monai/transforms/regularization/dictionary.py +++ b/monai/transforms/regularization/dictionary.py @@ -13,6 +13,7 @@ from monai.config import KeysCollection from monai.utils.misc import ensure_tuple +from numpy.random import RandomState from ..transform import MapTransform from .array import CutMix, CutOut, MixUp @@ -28,12 +29,14 @@ class MixUpd(MapTransform): for consistency, i.e. images and labels must be applied the same augmenation. """ - def __init__( - self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False - ) -> None: + def __init__(self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) self.mixup = MixUp(batch_size, alpha) + def set_random_state(self, seed: int | None = None, state: RandomState | None = None): + self.mixup.set_random_state(seed, state) + return self + def __call__(self, data): self.mixup.randomize() result = dict(data) @@ -63,6 +66,10 @@ def __init__( self.mixer = CutMix(batch_size, alpha) self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] + def set_random_state(self, seed: int | None = None, state: RandomState | None = None): + self.mixer.set_random_state(seed, state) + return self + def __call__(self, data): self.mixer.randomize() result = dict(data) @@ -84,6 +91,10 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo super().__init__(keys, allow_missing_keys) self.cutout = CutOut(batch_size) + def set_random_state(self, seed: int | None = None, state: RandomState | None = None): + self.cutout.set_random_state(seed, state) + return self + def __call__(self, data): result = dict(data) self.cutout.randomize() From efddfef43dc4365a24755882e7a94b2606810f43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?= Date: Sun, 28 Apr 2024 17:57:15 +0200 Subject: [PATCH 3/3] =?UTF-8?q?DCO=20Remediation=20Commit=20for=20Juan=20P?= =?UTF-8?q?ablo=20de=20la=20Cruz=20Guti=C3=A9rrez=20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I, Juan Pablo de la Cruz Gutiérrez , hereby add my Signed-off-by to this commit: bc6dea3e8658e5e72c0de6c24cf4450dbbe8c8f4 Signed-off-by: Juan Pablo de la Cruz Gutiérrez --- monai/transforms/regularization/dictionary.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py index 72d3f13625..cd8ef14097 100644 --- a/monai/transforms/regularization/dictionary.py +++ b/monai/transforms/regularization/dictionary.py @@ -11,9 +11,10 @@ from __future__ import annotations +from numpy.random import RandomState + from monai.config import KeysCollection from monai.utils.misc import ensure_tuple -from numpy.random import RandomState from ..transform import MapTransform from .array import CutMix, CutOut, MixUp @@ -29,7 +30,9 @@ class MixUpd(MapTransform): for consistency, i.e. images and labels must be applied the same augmenation. """ - def __init__(self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False + ) -> None: super().__init__(keys, allow_missing_keys) self.mixup = MixUp(batch_size, alpha)