From 0dec90a4c9050465c1fda151a445137ddeac0d37 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 24 Dec 2021 11:43:57 +0800 Subject: [PATCH] [DLMED] enhance set_determinism Signed-off-by: Nic Ma --- monai/utils/misc.py | 4 ++++ tests/test_set_determinism.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4fe63744fd..eae0580696 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -251,6 +251,10 @@ def set_determinism( for func in additional_settings: func(seed) + if torch.backends.flags_frozen(): + warnings.warn("PyTorch global flag support of backends is disabled, enable it to set global `cudnn` flags.") + torch.backends.__allow_nonbracketed_mutation_flag = True + if seed is not None: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False diff --git a/tests/test_set_determinism.py b/tests/test_set_determinism.py index 31b3254f5b..7d6c54909d 100644 --- a/tests/test_set_determinism.py +++ b/tests/test_set_determinism.py @@ -40,6 +40,8 @@ def test_values(self): self.assertEqual(seed, get_seed()) a = np.random.randint(seed) b = torch.randint(seed, (1,)) + # tset when global flag support is disabled + torch.backends.disable_global_flags() set_determinism(seed=seed) c = np.random.randint(seed) d = torch.randint(seed, (1,))