diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 688f664089..f748eb8732 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -54,7 +54,7 @@ from monai.utils import set_determinism from monai.utils.enums import PostFix from tests.testing_data.integration_answers import test_integration_value -from tests.utils import DistTestCase, TimedCall, skip_if_quick +from tests.utils import DistTestCase, TimedCall, pytorch_after, skip_if_quick TASK = "integration_workflows" @@ -149,7 +149,7 @@ def _forward_completed(self, engine): val_handlers=val_handlers, amp=bool(amp), to_kwargs={"memory_format": torch.preserve_format}, - amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, ) train_postprocessing = Compose( @@ -205,7 +205,7 @@ def _model_completed(self, engine): amp=bool(amp), optim_set_to_none=True, to_kwargs={"memory_format": torch.preserve_format}, - amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32}, + amp_kwargs={"dtype": torch.float16 if bool(amp) else torch.float32} if pytorch_after(1, 10, 0) else {}, ) trainer.run()