diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 0b495c8623..a7436bda84 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -16,6 +16,9 @@
 import torch
 
+from monai.data.meta_obj import get_track_meta
+from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
+
 from ..transform import RandomizableTransform
 
 __all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]
@@ -53,9 +56,11 @@ def randomize(self, data=None) -> None:
         as needed. You need to call this method everytime you apply the transform
         to a new batch.
         """
+        super().randomize(None)
         self._params = (
             torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
             self.R.permutation(self.batch_size),
+            [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else [],
         )
@@ -69,7 +74,7 @@ class MixUp(Mixer):
     """
 
     def apply(self, data: torch.Tensor):
-        weight, perm = self._params
+        weight, perm, _ = self._params
         nsamples, *dims = data.shape
         if len(weight) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
@@ -80,11 +85,18 @@ def apply(self, data: torch.Tensor):
         mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
         return mixweight * data + (1 - mixweight) * data[perm, ...]
 
-    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
-        self.randomize()
+    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
+        data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        if labels is not None:
+            labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
+        if randomize:
+            self.randomize()
         if labels is None:
-            return self.apply(data)
-        return self.apply(data), self.apply(labels)
+            return convert_to_dst_type(self.apply(data_t), dst=data)[0]
+        return (
+            convert_to_dst_type(self.apply(data_t), dst=data)[0],
+            convert_to_dst_type(self.apply(labels_t), dst=labels)[0],
+        )
 
 
 class CutMix(Mixer):
@@ -113,14 +125,13 @@ class CutMix(Mixer):
     """
 
     def apply(self, data: torch.Tensor):
-        weights, perm = self._params
+        weights, perm, coords = self._params
         nsamples, _, *dims = data.shape
         if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
 
         mask = torch.ones_like(data)
         for s, weight in enumerate(weights):
-            coords = [torch.randint(0, d, size=(1,)) for d in dims]
             lengths = [d * sqrt(1 - weight) for d in dims]
             idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
             mask[s][idx] = 0
@@ -128,7 +139,7 @@ def apply(self, data: torch.Tensor):
         return mask * data + (1 - mask) * data[perm, ...]
 
     def apply_on_labels(self, labels: torch.Tensor):
-        weights, perm = self._params
+        weights, perm, _ = self._params
         nsamples, *dims = labels.shape
         if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -136,10 +147,16 @@ def apply_on_labels(self, labels: torch.Tensor):
         mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
         return mixweight * labels + (1 - mixweight) * labels[perm, ...]
 
-    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
-        self.randomize()
-        augmented = self.apply(data)
-        return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
+    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
+        data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        if labels is not None:
+            labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
+        if randomize:
+            self.randomize(data)
+        augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0]
+        if labels is not None:
+            augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0]
+        return (augmented, augmented_label) if labels is not None else augmented
 
 
 class CutOut(Mixer):
@@ -155,20 +172,21 @@ class CutOut(Mixer):
     """
 
    def apply(self, data: torch.Tensor):
-        weights, _ = self._params
+        weights, _, coords = self._params
         nsamples, _, *dims = data.shape
         if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
 
         mask = torch.ones_like(data)
         for s, weight in enumerate(weights):
-            coords = [torch.randint(0, d, size=(1,)) for d in dims]
             lengths = [d * sqrt(1 - weight) for d in dims]
             idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
             mask[s][idx] = 0
 
         return mask * data
 
-    def __call__(self, data: torch.Tensor):
-        self.randomize()
-        return self.apply(data)
+    def __call__(self, data: torch.Tensor, randomize=True):
+        data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        if randomize:
+            self.randomize(data)
+        return convert_to_dst_type(self.apply(data_t), dst=data)[0]
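Note on the array-level changes above: `Mixer.randomize` now draws the cut-region corner coordinates from `self.R` and stores them in `_params`, so the corner is sampled once per batch (and shared by all samples) instead of via a fresh `torch.randint` inside `apply`, and `__call__` gains a `randomize` flag plus `convert_to_dst_type` round-tripping. A minimal usage sketch of the reworked API, assuming the patched `monai.transforms` exports; the tensors and seeds are illustrative only:

```python
import torch

from monai.transforms import CutMix, MixUp

batch = torch.rand(6, 3, 32, 32)  # (batch, channel, *spatial)
labels = torch.randint(0, 2, (6, 1)).float()

mixup = MixUp(6, alpha=1.0)
mixup.set_random_state(seed=0)  # seeding now flows through RandomizableTransform
mixed, mixed_labels = mixup(batch, labels)

cutmix = CutMix(6, alpha=1.0)
cutmix.set_random_state(seed=0)
out1 = cutmix(batch)                   # randomize=True (default): samples new _params
out2 = cutmix(batch, randomize=False)  # reuses the parameters sampled for out1
assert torch.allclose(out1, out2)
```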
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 373913da99..d8815e47b9 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -11,16 +11,23 @@
 
 from __future__ import annotations
 
+from collections.abc import Hashable
+
+import numpy as np
+
 from monai.config import KeysCollection
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.meta_obj import get_track_meta
+from monai.utils import convert_to_tensor
 from monai.utils.misc import ensure_tuple
 
-from ..transform import MapTransform
+from ..transform import MapTransform, RandomizableTransform
 from .array import CutMix, CutOut, MixUp
 
 __all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
 
 
-class MixUpd(MapTransform):
+class MixUpd(MapTransform, RandomizableTransform):
     """
     Dictionary-based version :py:class:`monai.transforms.MixUp`.
@@ -31,18 +38,24 @@ class MixUpd(MapTransform):
     def __init__(
         self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
     ) -> None:
-        super().__init__(keys, allow_missing_keys)
+        MapTransform.__init__(self, keys, allow_missing_keys)
         self.mixup = MixUp(batch_size, alpha)
 
+    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd:
+        super().set_random_state(seed, state)
+        self.mixup.set_random_state(seed, state)
+        return self
+
     def __call__(self, data):
-        self.mixup.randomize()
-        result = dict(data)
-        for k in self.keys:
-            result[k] = self.mixup.apply(data[k])
-        return result
+        d = dict(data)
+        # all the keys share the same random state
+        self.mixup.randomize(None)
+        for k in self.key_iterator(d):
+            d[k] = self.mixup(data[k], randomize=False)
+        return d
 
 
-class CutMixd(MapTransform):
+class CutMixd(MapTransform, RandomizableTransform):
     """
     Dictionary-based version :py:class:`monai.transforms.CutMix`.
@@ -63,17 +76,27 @@ def __init__(
         self.mixer = CutMix(batch_size, alpha)
         self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
 
-    def __call__(self, data):
-        self.mixer.randomize()
-        result = dict(data)
-        for k in self.keys:
-            result[k] = self.mixer.apply(data[k])
-        for k in self.label_keys:
-            result[k] = self.mixer.apply_on_labels(data[k])
-        return result
-
+    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutMixd:
+        super().set_random_state(seed, state)
+        self.mixer.set_random_state(seed, state)
+        return self
 
-class CutOutd(MapTransform):
+    def __call__(self, data):
+        d = dict(data)
+        first_key: Hashable = self.first_key(d)
+        if first_key == ():
+            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
+            return out
+        self.mixer.randomize(d[first_key])
+        for key, label_key in self.key_iterator(d, self.label_keys):
+            ret = self.mixer(data[key], data.get(label_key, None), randomize=False)
+            d[key] = ret[0]
+            if label_key in d:
+                d[label_key] = ret[1]
+        return d
+
+
+class CutOutd(MapTransform, RandomizableTransform):
     """
     Dictionary-based version :py:class:`monai.transforms.CutOut`.
@@ -84,12 +107,21 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo
         super().__init__(keys, allow_missing_keys)
         self.cutout = CutOut(batch_size)
 
+    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutOutd:
+        super().set_random_state(seed, state)
+        self.cutout.set_random_state(seed, state)
+        return self
+
     def __call__(self, data):
-        result = dict(data)
-        self.cutout.randomize()
-        for k in self.keys:
-            result[k] = self.cutout(data[k])
-        return result
+        d = dict(data)
+        first_key: Hashable = self.first_key(d)
+        if first_key == ():
+            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
+            return out
+        self.cutout.randomize(d[first_key])
+        for k in self.key_iterator(d):
+            d[k] = self.cutout(data[k], randomize=False)
+        return d
 
 
 MixUpD = MixUpDict = MixUpd
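And on the dictionary wrappers: each one now mixes in `RandomizableTransform` and overrides `set_random_state` to forward the seed to the wrapped array transform, so a single `randomize()` per call is shared by every key (and by the paired label keys in `CutMixd`). A hedged sketch of driving the patched `CutMixd`; the key names, shapes, and keyword spellings here are assumed for illustration:

```python
import torch

from monai.transforms import CutMixd

data = {
    "image": torch.rand(6, 3, 32, 32),
    "label": torch.randint(0, 2, (6, 1)).float(),
}

t = CutMixd(keys="image", batch_size=6, label_keys="label")
t.set_random_state(seed=123)  # forwarded to the wrapped CutMix by the new override
out = t(data)  # parameters sampled once from data["image"], reused per key/label pair
```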
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index 32df2f7b41..12d64637d5 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -13,29 +13,31 @@
 
 import unittest
 
+import numpy as np
 import torch
 
-from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd
-from monai.utils import set_determinism
+from monai.transforms import CutMix, CutMixd, CutOut, CutOutd, MixUp, MixUpd
+from tests.utils import assert_allclose
 
 
-@unittest.skip("Mixup is non-deterministic. Skip it temporarily")
 class TestMixup(unittest.TestCase):
-    def setUp(self) -> None:
-        set_determinism(seed=0)
-
-    def tearDown(self) -> None:
-        set_determinism(None)
-
     def test_mixup(self):
         for dims in [2, 3]:
             shape = (6, 3) + (32,) * dims
             sample = torch.rand(*shape, dtype=torch.float32)
             mixup = MixUp(6, 1.0)
+            mixup.set_random_state(seed=0)
             output = mixup(sample)
+            np.random.seed(0)
+            # simulate the randomize() of transform
+            np.random.random()
+            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+            perm = np.random.permutation(6)
             self.assertEqual(output.shape, sample.shape)
-            self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10)))
+            mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
+            expected = mixweight * sample + (1 - mixweight) * sample[perm, ...]
+            assert_allclose(output, expected, type_test=False, atol=1e-7)
 
         with self.assertRaises(ValueError):
             MixUp(6, -0.5)
@@ -53,27 +55,32 @@ def test_mixupd(self):
             t = torch.rand(*shape, dtype=torch.float32)
             sample = {"a": t, "b": t}
             mixup = MixUpd(["a", "b"], 6)
+            mixup.set_random_state(seed=0)
             output = mixup(sample)
-            self.assertTrue(torch.allclose(output["a"], output["b"]))
+            np.random.seed(0)
+            # simulate the randomize() of transform
+            np.random.random()
+            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+            perm = np.random.permutation(6)
+            self.assertEqual(output["a"].shape, sample["a"].shape)
+            mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
+            expected = mixweight * sample["a"] + (1 - mixweight) * sample["a"][perm, ...]
+            assert_allclose(output["a"], expected, type_test=False, atol=1e-7)
+            assert_allclose(output["a"], output["b"], type_test=False, atol=1e-7)
+            # self.assertTrue(torch.allclose(output["a"], output["b"]))
 
         with self.assertRaises(ValueError):
             MixUpd(["k1", "k2"], 6, -0.5)
 
 
Skip it temporarily") class TestCutMix(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_cutmix(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) cutmix = CutMix(6, 1.0) + cutmix.set_random_state(seed=0) output = cutmix(sample) self.assertEqual(output.shape, sample.shape) self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) @@ -85,30 +92,50 @@ def test_cutmixd(self): label = torch.randint(0, 1, shape) sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) + cutmix.set_random_state(seed=123) output = cutmix(sample) - # croppings are different on each application - self.assertTrue(not torch.allclose(output["a"], output["b"])) # but mixing of labels is not affected by it self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) -@unittest.skip("CutOut is non-deterministic. Skip it temporarily") class TestCutOut(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) - - def tearDown(self) -> None: - set_determinism(None) - def test_cutout(self): for dims in [2, 3]: shape = (6, 3) + (32,) * dims sample = torch.rand(*shape, dtype=torch.float32) cutout = CutOut(6, 1.0) + cutout.set_random_state(seed=123) output = cutout(sample) + np.random.seed(123) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]] + assert_allclose(weight, cutout._params[0]) + assert_allclose(perm, cutout._params[1]) + self.assertSequenceEqual(coords, cutout._params[2]) self.assertEqual(output.shape, sample.shape) - self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10))) + + def test_cutoutd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + sample = {"a": t, "b": t} + cutout = CutOutd(["a", "b"], 6, 1.0) + cutout.set_random_state(seed=123) + output = cutout(sample) + np.random.seed(123) + # simulate the randomize() of transform + np.random.random() + weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32) + perm = np.random.permutation(6) + coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in t.shape[2:]] + assert_allclose(weight, cutout.cutout._params[0]) + assert_allclose(perm, cutout.cutout._params[1]) + self.assertSequenceEqual(coords, cutout.cutout._params[2]) + self.assertEqual(output["a"].shape, sample["a"].shape) if __name__ == "__main__":