From ac157da8c6a4b2b85b66116b6a2af7dae6c1cf73 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 29 May 2024 18:57:26 +0800 Subject: [PATCH 1/8] fix #7697 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/regularization/array.py | 53 ++++++++----- monai/transforms/regularization/dictionary.py | 75 +++++++++++++------ tests/test_regularization.py | 62 ++++++++------- 3 files changed, 123 insertions(+), 67 deletions(-) diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 0b495c8623..7600f1c017 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -16,7 +16,9 @@ import torch +from monai.data.meta_obj import get_track_meta from ..transform import RandomizableTransform +from monai.utils.type_conversion import convert_to_tensor, convert_to_dst_type __all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] @@ -53,9 +55,11 @@ def randomize(self, data=None) -> None: as needed. You need to call this method everytime you apply the transform to a new batch. """ + super().randomize(None) self._params = ( torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32), self.R.permutation(self.batch_size), + [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else None, ) @@ -69,7 +73,7 @@ class MixUp(Mixer): """ def apply(self, data: torch.Tensor): - weight, perm = self._params + weight, perm, _ = self._params nsamples, *dims = data.shape if len(weight) != nsamples: raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}") @@ -80,11 +84,17 @@ def apply(self, data: torch.Tensor): mixweight = weight[(Ellipsis,) + (None,) * len(dims)] return mixweight * data + (1 - mixweight) * data[perm, ...] - def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): - self.randomize() + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): + data = convert_to_tensor(data, track_meta=get_track_meta()) + data_t = convert_to_tensor(data, track_meta=False) + if labels is not None: + labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) + labels_t = convert_to_tensor(labels, track_meta=False) + if randomize: + self.randomize() if labels is None: - return self.apply(data) - return self.apply(data), self.apply(labels) + return convert_to_dst_type(self.apply(data_t), dst=data)[0] + return convert_to_dst_type(self.apply(data_t), dst=data)[0], convert_to_dst_type(self.apply(labels_t), dst=labels)[0] class CutMix(Mixer): @@ -113,14 +123,13 @@ class CutMix(Mixer): """ def apply(self, data: torch.Tensor): - weights, perm = self._params + weights, perm, coords = self._params nsamples, _, *dims = data.shape if len(weights) != nsamples: raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") mask = torch.ones_like(data) for s, weight in enumerate(weights): - coords = [torch.randint(0, d, size=(1,)) for d in dims] lengths = [d * sqrt(1 - weight) for d in dims] idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] mask[s][idx] = 0 @@ -128,7 +137,7 @@ def apply(self, data: torch.Tensor): return mask * data + (1 - mask) * data[perm, ...] def apply_on_labels(self, labels: torch.Tensor): - weights, perm = self._params + weights, perm, _ = self._params nsamples, *dims = labels.shape if len(weights) != nsamples: raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") @@ -136,10 +145,18 @@ def apply_on_labels(self, labels: torch.Tensor): mixweight = weights[(Ellipsis,) + (None,) * len(dims)] return mixweight * labels + (1 - mixweight) * labels[perm, ...] - def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): - self.randomize() - augmented = self.apply(data) - return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): + data = convert_to_tensor(data, track_meta=get_track_meta()) + data_t = convert_to_tensor(data, track_meta=False) + if labels is not None: + labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) + labels_t = convert_to_tensor(labels, track_meta=False) + if randomize: + self.randomize(data) + augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0] + if labels is not None: + augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0] + return (augmented, augmented_label) if labels is not None else augmented class CutOut(Mixer): @@ -155,20 +172,22 @@ class CutOut(Mixer): """ def apply(self, data: torch.Tensor): - weights, _ = self._params + weights, _, coords = self._params nsamples, _, *dims = data.shape if len(weights) != nsamples: raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") mask = torch.ones_like(data) for s, weight in enumerate(weights): - coords = [torch.randint(0, d, size=(1,)) for d in dims] lengths = [d * sqrt(1 - weight) for d in dims] idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] mask[s][idx] = 0 return mask * data - def __call__(self, data: torch.Tensor): - self.randomize() - return self.apply(data) + def __call__(self, data: torch.Tensor, randomize=True): + data = convert_to_tensor(data, track_meta=get_track_meta()) + data_t = convert_to_tensor(data, track_meta=False) + if randomize: + self.randomize(data) + return convert_to_dst_type(self.apply(data_t), dst=data)[0] diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py index 373913da99..a086392dde 100644 --- a/monai/transforms/regularization/dictionary.py +++ b/monai/transforms/regularization/dictionary.py @@ -10,17 +10,22 @@ # limitations under the License. from __future__ import annotations +import numpy as np +from collections.abc import Hashable from monai.config import KeysCollection from monai.utils.misc import ensure_tuple +from monai.config.type_definitions import NdarrayOrTensor +from monai.data.meta_obj import get_track_meta +from monai.utils import convert_to_tensor -from ..transform import MapTransform +from ..transform import MapTransform, RandomizableTransform from .array import CutMix, CutOut, MixUp __all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"] -class MixUpd(MapTransform): +class MixUpd(MapTransform, RandomizableTransform): """ Dictionary-based version :py:class:`monai.transforms.MixUp`. @@ -31,18 +36,26 @@ class MixUpd(MapTransform): def __init__( self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False ) -> None: - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.mixup = MixUp(batch_size, alpha) + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> MixUpd: + super().set_random_state(seed, state) + self.mixup.set_random_state(seed, state) + return self + def __call__(self, data): - self.mixup.randomize() - result = dict(data) - for k in self.keys: - result[k] = self.mixup.apply(data[k]) - return result + d = dict(data) + # all the keys share the same random state + self.mixup.randomize(None) + for k in self.key_iterator(self.keys): + d[k] = self.mixup(data[k], randomize=False) + return d -class CutMixd(MapTransform): +class CutMixd(MapTransform, RandomizableTransform): """ Dictionary-based version :py:class:`monai.transforms.CutMix`. @@ -63,17 +76,29 @@ def __init__( self.mixer = CutMix(batch_size, alpha) self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] - def __call__(self, data): - self.mixer.randomize() - result = dict(data) - for k in self.keys: - result[k] = self.mixer.apply(data[k]) - for k in self.label_keys: - result[k] = self.mixer.apply_on_labels(data[k]) - return result - + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> MixUpd: + super().set_random_state(seed, state) + self.mixer.set_random_state(seed, state) + return self -class CutOutd(MapTransform): + def __call__(self, data): + d = dict(data) + first_key: Hashable = self.first_key(d) + if first_key == (): + out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) + return out + self.mixer.randomize(d[first_key]) + for key, label_key in self.key_iterator(self.keys, self.label_keys): + ret = self.mixer(data[key], data.get(label_key, None), randomize=False) + d[key] = ret[0] + if label_key in d: + d[label_key] = ret[1] + return d + + +class CutOutd(MapTransform, RandomizableTransform): """ Dictionary-based version :py:class:`monai.transforms.CutOut`. @@ -85,11 +110,15 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo self.cutout = CutOut(batch_size) def __call__(self, data): - result = dict(data) - self.cutout.randomize() + d = dict(data) + first_key: Hashable = self.first_key(d) + if first_key == (): + out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) + return out + self.cutout.randomize(d[first_key]) for k in self.keys: - result[k] = self.cutout(data[k]) - return result + d[k] = self.cutout(data[k]) + return d MixUpD = MixUpDict = MixUpd diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 32df2f7b41..73fe407bcb 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -14,28 +14,30 @@ import unittest import torch +import numpy as np from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd -from monai.utils import set_determinism +from tests.utils import assert_allclose -@unittest.skip("Mixup is non-deterministic. Skip it temporarily") class TestMixup(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_mixup(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) mixup = MixUp(6, 1.0) + mixup.set_random_state(seed=0) output = mixup(sample) + np.random.seed(0) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) self.assertEqual(output.shape, sample.shape) - self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10))) + mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)] + expected = mixweight * sample + (1 - mixweight) * sample[perm, ...] + assert_allclose(output, expected, type_test=False, atol=1e-7) with self.assertRaises(ValueError): MixUp(6, -0.5) @@ -53,27 +55,32 @@ def test_mixupd(self): t = torch.rand(*shape, dtype=torch.float32) sample = {"a": t, "b": t} mixup = MixUpd(["a", "b"], 6) + mixup.set_random_state(seed=0) output = mixup(sample) - self.assertTrue(torch.allclose(output["a"], output["b"])) + np.random.seed(0) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + self.assertEqual(output["a"].shape, sample["a"].shape) + mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)] + expected = mixweight * sample["a"] + (1 - mixweight) * sample["a"][perm, ...] + assert_allclose(output["a"], expected, type_test=False, atol=1e-7) + assert_allclose(output["a"], output["b"], type_test=False, atol=1e-7) + # self.assertTrue(torch.allclose(output["a"], output["b"])) with self.assertRaises(ValueError): MixUpd(["k1", "k2"], 6, -0.5) -@unittest.skip("CutMix is non-deterministic. Skip it temporarily") class TestCutMix(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_cutmix(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) cutmix = CutMix(6, 1.0) + cutmix.set_random_state(seed=0) output = cutmix(sample) self.assertEqual(output.shape, sample.shape) self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) @@ -85,30 +92,31 @@ def test_cutmixd(self): label = torch.randint(0, 1, shape) sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) + cutmix.set_random_state(seed=123) output = cutmix(sample) - # croppings are different on each application - self.assertTrue(not torch.allclose(output["a"], output["b"])) # but mixing of labels is not affected by it self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) -@unittest.skip("CutOut is non-deterministic. Skip it temporarily") class TestCutOut(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_cutout(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) cutout = CutOut(6, 1.0) + cutout.set_random_state(seed=123) output = cutout(sample) + np.random.seed(123) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]] + assert_allclose(weight, cutout._params[0]) + assert_allclose(perm, cutout._params[1]) + self.assertSequenceEqual(coords, cutout._params[2]) self.assertEqual(output.shape, sample.shape) - self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10))) if __name__ == "__main__": From da123975f0e546c89dbc6d3f0d967d686ca1d2a1 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 29 May 2024 18:57:37 +0800 Subject: [PATCH 2/8] fix format Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/regularization/array.py | 8 ++++++-- monai/transforms/regularization/dictionary.py | 14 ++++++-------- tests/test_regularization.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 7600f1c017..a98777920a 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -17,8 +17,9 @@ import torch from monai.data.meta_obj import get_track_meta +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor + from ..transform import RandomizableTransform -from monai.utils.type_conversion import convert_to_tensor, convert_to_dst_type __all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] @@ -94,7 +95,10 @@ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, rando self.randomize() if labels is None: return convert_to_dst_type(self.apply(data_t), dst=data)[0] - return convert_to_dst_type(self.apply(data_t), dst=data)[0], convert_to_dst_type(self.apply(labels_t), dst=labels)[0] + return ( + convert_to_dst_type(self.apply(data_t), dst=data)[0], + convert_to_dst_type(self.apply(labels_t), dst=labels)[0], + ) class CutMix(Mixer): diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py index a086392dde..af5f6275d4 100644 --- a/monai/transforms/regularization/dictionary.py +++ b/monai/transforms/regularization/dictionary.py @@ -10,14 +10,16 @@ # limitations under the License. from __future__ import annotations -import numpy as np + from collections.abc import Hashable +import numpy as np + from monai.config import KeysCollection -from monai.utils.misc import ensure_tuple from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.utils import convert_to_tensor +from monai.utils.misc import ensure_tuple from ..transform import MapTransform, RandomizableTransform from .array import CutMix, CutOut, MixUp @@ -39,9 +41,7 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) self.mixup = MixUp(batch_size, alpha) - def set_random_state( - self, seed: int | None = None, state: np.random.RandomState | None = None - ) -> MixUpd: + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd: super().set_random_state(seed, state) self.mixup.set_random_state(seed, state) return self @@ -76,9 +76,7 @@ def __init__( self.mixer = CutMix(batch_size, alpha) self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] - def set_random_state( - self, seed: int | None = None, state: np.random.RandomState | None = None - ) -> MixUpd: + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd: super().set_random_state(seed, state) self.mixer.set_random_state(seed, state) return self diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 73fe407bcb..29c3e9f578 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -13,8 +13,8 @@ import unittest -import torch import numpy as np +import torch from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd from tests.utils import assert_allclose From 6bde51492e7200e69c47b93f75713fa8b1086834 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 29 May 2024 21:53:17 +0800 Subject: [PATCH 3/8] Update monai/transforms/regularization/array.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/regularization/array.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index a98777920a..35fb83fdf4 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -190,8 +190,7 @@ def apply(self, data: torch.Tensor): return mask * data def __call__(self, data: torch.Tensor, randomize=True): - data = convert_to_tensor(data, track_meta=get_track_meta()) - data_t = convert_to_tensor(data, track_meta=False) + data_t = convert_to_tensor(data, track_meta=get_track_meta()) if randomize: self.randomize(data) return convert_to_dst_type(self.apply(data_t), dst=data)[0] From fd14a413b6354517c939afdc4e6f46b198f19240 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 29 May 2024 21:53:34 +0800 Subject: [PATCH 4/8] Update monai/transforms/regularization/array.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/regularization/array.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 35fb83fdf4..ebd3fe8111 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -150,11 +150,9 @@ def apply_on_labels(self, labels: torch.Tensor): return mixweight * labels + (1 - mixweight) * labels[perm, ...] def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): - data = convert_to_tensor(data, track_meta=get_track_meta()) - data_t = convert_to_tensor(data, track_meta=False) + data_t = convert_to_tensor(data, track_meta=get_track_meta()) if labels is not None: labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) - labels_t = convert_to_tensor(labels, track_meta=False) if randomize: self.randomize(data) augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0] From 1e0bd88659a2be0b7ed5e68f6318450f8523bfcd Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 29 May 2024 21:53:43 +0800 Subject: [PATCH 5/8] Update monai/transforms/regularization/array.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/regularization/array.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index ebd3fe8111..12d6d2931a 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -86,11 +86,9 @@ def apply(self, data: torch.Tensor): return mixweight * data + (1 - mixweight) * data[perm, ...] def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True): - data = convert_to_tensor(data, track_meta=get_track_meta()) - data_t = convert_to_tensor(data, track_meta=False) + data_t = convert_to_tensor(data, track_meta=get_track_meta()) if labels is not None: labels_t = convert_to_tensor(labels, track_meta=get_track_meta()) - labels_t = convert_to_tensor(labels, track_meta=False) if randomize: self.randomize() if labels is None: From 37f09dc1351545516043cf60b5595fc09cff6b4e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 29 May 2024 22:12:46 +0800 Subject: [PATCH 6/8] ad more tests Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/regularization/dictionary.py | 15 ++++++++----- tests/test_regularization.py | 21 ++++++++++++++++++- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py index af5f6275d4..d8815e47b9 100644 --- a/monai/transforms/regularization/dictionary.py +++ b/monai/transforms/regularization/dictionary.py @@ -50,7 +50,7 @@ def __call__(self, data): d = dict(data) # all the keys share the same random state self.mixup.randomize(None) - for k in self.key_iterator(self.keys): + for k in self.key_iterator(d): d[k] = self.mixup(data[k], randomize=False) return d @@ -76,7 +76,7 @@ def __init__( self.mixer = CutMix(batch_size, alpha) self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] - def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd: + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutMixd: super().set_random_state(seed, state) self.mixer.set_random_state(seed, state) return self @@ -88,7 +88,7 @@ def __call__(self, data): out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out self.mixer.randomize(d[first_key]) - for key, label_key in self.key_iterator(self.keys, self.label_keys): + for key, label_key in self.key_iterator(d, self.label_keys): ret = self.mixer(data[key], data.get(label_key, None), randomize=False) d[key] = ret[0] if label_key in d: @@ -107,6 +107,11 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo super().__init__(keys, allow_missing_keys) self.cutout = CutOut(batch_size) + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutOutd: + super().set_random_state(seed, state) + self.cutout.set_random_state(seed, state) + return self + def __call__(self, data): d = dict(data) first_key: Hashable = self.first_key(d) @@ -114,8 +119,8 @@ def __call__(self, data): out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out self.cutout.randomize(d[first_key]) - for k in self.keys: - d[k] = self.cutout(data[k]) + for k in self.key_iterator(d): + d[k] = self.cutout(data[k], randomize=False) return d diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 29c3e9f578..7552b30baf 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -16,7 +16,7 @@ import numpy as np import torch -from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd +from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd, CutOutd from tests.utils import assert_allclose @@ -118,6 +118,25 @@ def test_cutout(self): self.assertSequenceEqual(coords, cutout._params[2]) self.assertEqual(output.shape, sample.shape) + def test_cutoutd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + sample = {"a": t, "b": t} + cutout = CutOutd(["a", "b"], 6, 1.0) + cutout.set_random_state(seed=123) + output = cutout(sample) + np.random.seed(123) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in t.shape[2:]] + assert_allclose(weight, cutout.cutout._params[0]) + assert_allclose(perm, cutout.cutout._params[1]) + self.assertSequenceEqual(coords, cutout.cutout._params[2]) + self.assertEqual(output["a"].shape, sample["a"].shape) + if __name__ == "__main__": unittest.main() From 6f65042a3d7443a4d185ab21aae1fc8d47c5e5a8 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 29 May 2024 22:19:13 +0800 Subject: [PATCH 7/8] fix mypy Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/regularization/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 12d6d2931a..a7436bda84 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -60,7 +60,7 @@ def randomize(self, data=None) -> None: self._params = ( torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32), self.R.permutation(self.batch_size), - [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else None, + [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else [], ) From 7090065d4cc9f98fcb15e25ff84143b5999c700b Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 30 May 2024 10:52:57 +0800 Subject: [PATCH 8/8] Fix flake8 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/test_regularization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_regularization.py b/tests/test_regularization.py index 7552b30baf..12d64637d5 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -16,7 +16,7 @@ import numpy as np import torch -from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd, CutOutd +from monai.transforms import CutMix, CutMixd, CutOut, CutOutd, MixUp, MixUpd from tests.utils import assert_allclose