Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions monai/transforms/regularization/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from __future__ import annotations

from numpy.random import RandomState

from monai.config import KeysCollection
from monai.utils.misc import ensure_tuple

Expand All @@ -34,6 +36,10 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)

def set_random_state(self, seed: int | None = None, state: RandomState | None = None):
    """Propagate the seed or RandomState to the wrapped MixUp transform.

    Args:
        seed: integer seed used to reseed the delegate's RNG, if given.
        state: an explicit ``numpy.random.RandomState`` to install instead.

    Returns:
        self, to allow fluent chaining (matches the delegate's convention).
    """
    # Forward to the internal MixUp instance so determinism set on this
    # dictionary wrapper actually reaches the transform doing the mixing.
    self.mixup.set_random_state(seed, state)
    return self

def __call__(self, data):
self.mixup.randomize()
result = dict(data)
Expand Down Expand Up @@ -63,6 +69,10 @@ def __init__(
self.mixer = CutMix(batch_size, alpha)
self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []

def set_random_state(self, seed: int | None = None, state: RandomState | None = None):
    """Propagate the seed or RandomState to the wrapped CutMix transform.

    Args:
        seed: integer seed used to reseed the delegate's RNG, if given.
        state: an explicit ``numpy.random.RandomState`` to install instead.

    Returns:
        self, to allow fluent chaining (matches the delegate's convention).
    """
    # Forward to the internal CutMix instance so determinism set on this
    # dictionary wrapper actually reaches the transform doing the mixing.
    self.mixer.set_random_state(seed, state)
    return self

def __call__(self, data):
self.mixer.randomize()
result = dict(data)
Expand All @@ -84,6 +94,10 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo
super().__init__(keys, allow_missing_keys)
self.cutout = CutOut(batch_size)

def set_random_state(self, seed: int | None = None, state: RandomState | None = None):
    """Propagate the seed or RandomState to the wrapped CutOut transform.

    Args:
        seed: integer seed used to reseed the delegate's RNG, if given.
        state: an explicit ``numpy.random.RandomState`` to install instead.

    Returns:
        self, to allow fluent chaining (matches the delegate's convention).
    """
    # Forward to the internal CutOut instance so determinism set on this
    # dictionary wrapper actually reaches the transform doing the cutting.
    self.cutout.set_random_state(seed, state)
    return self

def __call__(self, data):
result = dict(data)
self.cutout.randomize()
Expand Down
23 changes: 15 additions & 8 deletions tests/test_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


class TestMixup(unittest.TestCase):

def setUp(self) -> None:
    # Seed global determinism before every test so random transforms
    # produce reproducible results (see the review discussion on flakiness).
    set_determinism(seed=0)

Expand Down Expand Up @@ -60,7 +59,6 @@ def test_mixupd(self):


class TestCutMix(unittest.TestCase):

def setUp(self) -> None:
    # Seed global determinism before every test so random transforms
    # produce reproducible results (see the review discussion on flakiness).
    set_determinism(seed=0)

Expand All @@ -75,23 +73,32 @@ def test_cutmix(self):
output = cutmix(sample)
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))
# croppings are different on each application... most of the times!
checks = [torch.allclose(sample, cutmix(sample)) for _ in range(1000)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this check might help mitigate the issue, it's not addressing the core problem. After set_determinism(seed=0), we should guarantee consistent transform results every time.

cc @ericspod

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're setting a seed before the test then we should get the same results every time. Any behaviour dependent on a random variable should be determined by this seed so that we can enforce determinism in cases like this. These transforms should inherit from RandomizableTransform but also set the random state of the internal delegate transform appropriately, like here. This ensures the seed value is propagated to the delegate.

Sorry this wasn't discussed in the initial PR for these transforms, it's something we missed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @ericspod for the detailed explanation of how to correctly implement it

Also, I still think that the unit test was missing the case that no transform happens when the mixing only considers one of the samples. Now this is exactly tested and it is checked that it happens as often as expected, so I would leave it as it is now.

Best

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does test that the mixing is a no-op an expected number of times, but I feel that's a problem with the tests not being deterministic. The issue is that for a given seed we should see the exact same behaviour, so we should be able to choose a seed that gives a no-op for a fixed input and test that this continues to happen (changing the seed as needed). This test may pass if the no-op occurs correctly a known number of times, but could also pass if there's a bug causing no-ops at the wrong time but within your threshold. We aren't directly testing that a no-op occurs. Tests should always be deterministic; if we fix the Mixup/CutMix classes to be so, we can do that here as well.

# 1000/(32*32*32)
self.assertTrue(sum(checks) < 5)

def test_cutmixd(self):
    """CutMixd should mix image keys differently per application, but mix all label keys identically."""
    batch_size = 6
    for dims in [2, 3]:
        shape = (batch_size, 3) + (32,) * dims
        t = torch.rand(*shape, dtype=torch.float32)
        label = torch.randint(0, 1, shape)
        # "a" and "b" start identical so any difference after the transform
        # must come from independent croppings applied to each key.
        sample = {"a": t, "b": t, "lbl1": label, "lbl2": label}
        cutmix = CutMixd(["a", "b"], batch_size, label_keys=("lbl1", "lbl2"))
        # croppings are different on each application... most of the times!
        checks = []
        for _ in range(1000):
            output = cutmix(sample)
            checks.append(torch.allclose(output["a"], output["b"]))
        # Probability of an accidental no-op crop is ~1/(32*32*32), so out of
        # 1000 draws we expect well under 5 identical outputs.
        self.assertTrue(sum(checks) < 5)

        # but mixing of labels is not affected by it
        self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"]))


class TestCutOut(unittest.TestCase):

def setUp(self) -> None:
    # Seed global determinism before every test so random transforms
    # produce reproducible results (see the review discussion on flakiness).
    set_determinism(seed=0)

Expand Down