diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py index 373913da99..cd8ef14097 100644 --- a/monai/transforms/regularization/dictionary.py +++ b/monai/transforms/regularization/dictionary.py @@ -11,6 +11,8 @@ from __future__ import annotations +from numpy.random import RandomState + from monai.config import KeysCollection from monai.utils.misc import ensure_tuple @@ -34,6 +36,10 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mixup = MixUp(batch_size, alpha) + def set_random_state(self, seed: int | None = None, state: RandomState | None = None): + self.mixup.set_random_state(seed, state) + return self + def __call__(self, data): self.mixup.randomize() result = dict(data) @@ -63,6 +69,10 @@ def __init__( self.mixer = CutMix(batch_size, alpha) self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] + def set_random_state(self, seed: int | None = None, state: RandomState | None = None): + self.mixer.set_random_state(seed, state) + return self + def __call__(self, data): self.mixer.randomize() result = dict(data) @@ -84,6 +94,10 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo super().__init__(keys, allow_missing_keys) self.cutout = CutOut(batch_size) + def set_random_state(self, seed: int | None = None, state: RandomState | None = None): + self.cutout.set_random_state(seed, state) + return self + def __call__(self, data): result = dict(data) self.cutout.randomize() diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 4df60b9808..c84e876e36 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -20,7 +20,6 @@ class TestMixup(unittest.TestCase): - def setUp(self) -> None: set_determinism(seed=0) @@ -60,7 +59,6 @@ def test_mixupd(self): class TestCutMix(unittest.TestCase): - def setUp(self) -> None: set_determinism(seed=0) @@ -75,23 +73,32 @@ def test_cutmix(self): output = cutmix(sample) self.assertEqual(output.shape, sample.shape) self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) + # croppings are different on each application... most of the times! + checks = [torch.allclose(sample, cutmix(sample)) for _ in range(1000)] + # 1000/(32*32*32) + self.assertTrue(sum(checks) < 5) def test_cutmixd(self): + batch_size = 6 for dims in [2, 3]: - shape = (6, 3) + (32,) * dims + shape = (batch_size, 3) + (32,) * dims t = torch.rand(*shape, dtype=torch.float32) label = torch.randint(0, 1, shape) sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} - cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) - output = cutmix(sample) - # croppings are different on each application - self.assertTrue(not torch.allclose(output["a"], output["b"])) + cutmix = CutMixd(["a", "b"], batch_size, label_keys=("lbl1", "lbl2")) + # croppings are different on each application... most of the times! + checks = [] + for _ in range(1000): + output = cutmix(sample) + checks.append(torch.allclose(output["a"], output["b"])) + # 1000/(32*32*32) + self.assertTrue(sum(checks) < 5) + # but mixing of labels is not affected by it self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) class TestCutOut(unittest.TestCase): - def setUp(self) -> None: set_determinism(seed=0)