diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index b808c24de0..89fe1912a5 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -63,12 +63,14 @@ def use_factory(fact_args): import warnings from typing import Any, Callable, Dict, Tuple, Type, Union +import torch import torch.nn as nn from monai.utils import look_up_option, optional_import InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") + __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] @@ -263,8 +265,21 @@ def instance_nvfuser_factory(dim): if dim != 3: warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.") return types[dim - 1] - if not has_nvfuser: - warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.") + # test InstanceNorm3dNVFuser installation with a basic example + has_nvfuser_flag = has_nvfuser + if not torch.cuda.is_available(): + return nn.InstanceNorm3d + try: + layer = InstanceNorm3dNVFuser(num_features=1, affine=True).to("cuda:0") + inp = torch.randn([1, 1, 1, 1, 1]).to("cuda:0") + out = layer(inp) + del inp, out, layer + except Exception: + has_nvfuser_flag = False + if not has_nvfuser_flag: + warnings.warn( + "`apex.normalization.InstanceNorm3dNVFuser` is not installed properly, use nn.InstanceNorm3d instead." + ) return nn.InstanceNorm3d return InstanceNorm3dNVFuser diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 14006b96e6..ff5d5efbef 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,11 +17,9 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet -from monai.utils import optional_import +from monai.utils.module import pytorch_after from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save -_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") - device = "cuda" if torch.cuda.is_available() else "cpu" strides: Sequence[Union[Sequence[int], int]] @@ -123,7 +121,6 @@ def test_script(self): @skip_if_no_cuda @skip_if_windows -@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.") class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase): @parameterized.expand([TEST_CASE_DYNUNET_3D[0]]) def test_consistency(self, input_param, input_shape, _): @@ -145,7 +142,11 @@ def test_consistency(self, input_param, input_shape, _): with eval_mode(net_fuser): result_fuser = net_fuser(input_tensor) - torch.testing.assert_close(result, result_fuser) + # torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in 1.14 + if pytorch_after(1, 12): + torch.testing.assert_close(result, result_fuser) + else: + torch.testing.assert_allclose(result, result_fuser) class TestDynUNetDeepSupervision(unittest.TestCase):