19 changes: 17 additions & 2 deletions monai/networks/layers/factories.py
@@ -63,12 +63,14 @@ def use_factory(fact_args):
import warnings
from typing import Any, Callable, Dict, Tuple, Type, Union

import torch
import torch.nn as nn

from monai.utils import look_up_option, optional_import

InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")


__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
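Reviewer note, not part of the diff: `optional_import` only reports whether `InstanceNorm3dNVFuser` could be imported, so `has_nvfuser` being True does not guarantee that the fused CUDA kernels actually run on the current machine; that gap is what the runtime probe added below covers. A minimal sketch of the flag's semantics, assuming only the documented `(object, success_flag)` return value of `monai.utils.optional_import`:

from monai.utils import optional_import

# returns the imported attribute on success, or a lazy placeholder that raises when used;
# the boolean records only that the import itself succeeded
InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
if not has_nvfuser:
    print("apex InstanceNorm3dNVFuser is not importable; MONAI falls back to nn.InstanceNorm3d")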


@@ -263,8 +265,21 @@ def instance_nvfuser_factory(dim):
     if dim != 3:
         warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.")
         return types[dim - 1]
-    if not has_nvfuser:
-        warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.")
+    # test InstanceNorm3dNVFuser installation with a basic example
+    has_nvfuser_flag = has_nvfuser
+    if not torch.cuda.is_available():
+        return nn.InstanceNorm3d
+    try:
+        layer = InstanceNorm3dNVFuser(num_features=1, affine=True).to("cuda:0")
+        inp = torch.randn([1, 1, 1, 1, 1]).to("cuda:0")
+        out = layer(inp)
+        del inp, out, layer
+    except Exception:
+        has_nvfuser_flag = False
+    if not has_nvfuser_flag:
+        warnings.warn(
+            "`apex.normalization.InstanceNorm3dNVFuser` is not installed properly, use nn.InstanceNorm3d instead."
+        )
         return nn.InstanceNorm3d
     return InstanceNorm3dNVFuser

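A usage sketch, not part of the diff, showing how the fallback surfaces to callers of the layer factory. It assumes the factory above is registered under the name "instance_nvfuser" and that the `Norm[name, dim]` lookup pattern from factories.py applies; treat both as assumptions rather than guaranteed API.

import torch
from monai.networks.layers.factories import Norm

# resolves to InstanceNorm3dNVFuser only when apex is installed, CUDA is available,
# and the probe in instance_nvfuser_factory succeeded; otherwise nn.InstanceNorm3d
norm_type = Norm["instance_nvfuser", 3]  # assumed factory name
layer = norm_type(num_features=4, affine=True)
x = torch.randn(2, 4, 8, 8, 8)
if torch.cuda.is_available():
    layer, x = layer.to("cuda:0"), x.to("cuda:0")
print(type(layer).__name__, tuple(layer(x).shape))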
11 changes: 6 additions & 5 deletions tests/test_dynunet.py
@@ -17,11 +17,9 @@

 from monai.networks import eval_mode
 from monai.networks.nets import DynUNet
-from monai.utils import optional_import
+from monai.utils.module import pytorch_after
 from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save
 
-_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
-
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 strides: Sequence[Union[Sequence[int], int]]
@@ -123,7 +121,6 @@ def test_script(self):

 @skip_if_no_cuda
 @skip_if_windows
-@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.")
 class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase):
     @parameterized.expand([TEST_CASE_DYNUNET_3D[0]])
     def test_consistency(self, input_param, input_shape, _):
@@ -145,7 +142,11 @@ def test_consistency(self, input_param, input_shape, _):
         with eval_mode(net_fuser):
             result_fuser = net_fuser(input_tensor)

-        torch.testing.assert_close(result, result_fuser)
+        # torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in 1.14
+        if pytorch_after(1, 12):
+            torch.testing.assert_close(result, result_fuser)
+        else:
+            torch.testing.assert_allclose(result, result_fuser)


 class TestDynUNetDeepSupervision(unittest.TestCase):
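A small sketch, not part of the diff, of how the version gate above could be factored into a reusable test helper; `assert_close_compat` is a hypothetical name, and `pytorch_after` is assumed to behave as documented in `monai.utils.module`.

import torch
from monai.utils.module import pytorch_after

def assert_close_compat(actual, expected, **kwargs):
    # torch.testing.assert_allclose is deprecated since PyTorch 1.12; prefer assert_close when available
    if pytorch_after(1, 12):
        torch.testing.assert_close(actual, expected, **kwargs)
    else:
        torch.testing.assert_allclose(actual, expected, **kwargs)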