From c27492d999372e6963f083b0160a5cc189a2a27a Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 26 Apr 2022 14:31:11 +0100 Subject: [PATCH 1/2] Test fix for AMP kwargs Signed-off-by: Eric Kerfoot --- tests/test_integration_workflows.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 688f664089..c8bd14e3e0 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -54,7 +54,7 @@ from monai.utils import set_determinism from monai.utils.enums import PostFix from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick +from tests.utils import DistTestCase, TimedCall, skip_if_quick, pytorch_after TASK = "integration_workflows" @@ -149,7 +149,7 @@ def _forward_completed(self, engine): val_handlers=val_handlers, amp=bool(amp), to_kwargs={"memory_format": torch.preserve_format}, - amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, ) train_postprocessing = Compose( @@ -205,7 +205,7 @@ def _model_completed(self, engine): amp=bool(amp), optim_set_to_none=True, to_kwargs={"memory_format": torch.preserve_format}, - amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, ) trainer.run() From 2d639ae7b991ffa2b713782cb56137bfdd3e5486 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 26 Apr 2022 13:54:05 +0000 Subject: [PATCH 2/2] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_integration_workflows.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index c8bd14e3e0..f748eb8732 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -54,7 +54,7 @@ from monai.utils import set_determinism from monai.utils.enums import PostFix from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick, pytorch_after +from tests.utils import DistTestCase, TimedCall, pytorch_after, skip_if_quick TASK = "integration_workflows" @@ -205,7 +205,7 @@ def _model_completed(self, engine): amp=bool(amp), optim_set_to_none=True, to_kwargs={"memory_format": torch.preserve_format}, - amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, ) trainer.run()