diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py
index 6379f49449..b808c24de0 100644
--- a/monai/networks/layers/factories.py
+++ b/monai/networks/layers/factories.py
@@ -60,11 +60,14 @@ def use_factory(fact_args):
     layer = use_factory( (fact.TEST, kwargs) )
 """
 
+import warnings
 from typing import Any, Callable, Dict, Tuple, Type, Union
 
 import torch.nn as nn
 
-from monai.utils import look_up_option
+from monai.utils import look_up_option, optional_import
+
+InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
 
 __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
 
@@ -242,6 +245,30 @@ def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]:
     return nn.SyncBatchNorm
 
 
+@Norm.factory_function("instance_nvfuser")
+def instance_nvfuser_factory(dim):
+    """
+    `InstanceNorm3dNVFuser` is a faster version of the instance norm layer, implemented in `apex`.
+    It only supports 3d tensors as the input. It also requires CUDA and a non-Windows OS.
+    In this function, if the required `apex.normalization.InstanceNorm3dNVFuser` is not available,
+    `nn.InstanceNorm3d` will be returned instead.
+    This layer is based on a customized autograd function, which is not currently supported by TorchScript.
+    Please switch to `nn.InstanceNorm3d` if TorchScript is required.
+
+    Please check the following link for more details about how to install `apex`:
+    https://github.com/NVIDIA/apex#installation
+
+    """
+    types = (nn.InstanceNorm1d, nn.InstanceNorm2d)
+    if dim != 3:
+        warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.")
+        return types[dim - 1]
+    if not has_nvfuser:
+        warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.")
+        return nn.InstanceNorm3d
+    return InstanceNorm3dNVFuser
+
+
 Act.add_factory_callable("elu", lambda: nn.modules.ELU)
 Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
 Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py
index e858dcbb9b..053ab255b8 100644
--- a/monai/networks/nets/dynunet.py
+++ b/monai/networks/nets/dynunet.py
@@ -104,6 +104,8 @@ class DynUNet(nn.Module):
             If not specified, the way which nnUNet used will be employed. Defaults to ``None``.
         dropout: dropout ratio. Defaults to no dropout.
         norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
+            `INSTANCE_NVFUSER` is a faster version of the instance norm layer. It can be used when:
+            1) `spatial_dims=3`, 2) a CUDA device is available, 3) `apex` is installed, and 4) a non-Windows OS is used.
         act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
         deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
             If ``True``, in training mode, the forward function will output not only the final feature map
diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py
index 36ac9d0309..14006b96e6 100644
--- a/tests/test_dynunet.py
+++ b/tests/test_dynunet.py
@@ -17,7 +17,10 @@
 
 from monai.networks import eval_mode
 from monai.networks.nets import DynUNet
-from tests.utils import test_script_save
+from monai.utils import optional_import
+from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save
+
+_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -118,6 +121,33 @@ def test_script(self):
         test_script_save(net, test_data)
 
 
+@skip_if_no_cuda
+@skip_if_windows
+@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.")
+class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_DYNUNET_3D[0]])
+    def test_consistency(self, input_param, input_shape, _):
+        for eps in [1e-4, 1e-5]:
+            for momentum in [0.1, 0.01]:
+                for affine in [True, False]:
+                    norm_param = {"eps": eps, "momentum": momentum, "affine": affine}
+                    input_param["norm_name"] = ("instance", norm_param)
+                    input_param_fuser = input_param.copy()
+                    input_param_fuser["norm_name"] = ("instance_nvfuser", norm_param)
+                    for memory_format in [torch.contiguous_format, torch.channels_last_3d]:
+                        net = DynUNet(**input_param).to("cuda:0", memory_format=memory_format)
+                        net_fuser = DynUNet(**input_param_fuser).to("cuda:0", memory_format=memory_format)
+                        net_fuser.load_state_dict(net.state_dict())
+
+                        input_tensor = torch.randn(input_shape).to("cuda:0", memory_format=memory_format)
+                        with eval_mode(net):
+                            result = net(input_tensor)
+                        with eval_mode(net_fuser):
+                            result_fuser = net_fuser(input_tensor)
+
+                        torch.testing.assert_close(result, result_fuser)
+
+
 class TestDynUNetDeepSupervision(unittest.TestCase):
     @parameterized.expand(TEST_CASE_DEEP_SUPERVISION)
     def test_shape(self, input_param, input_shape, expected_shape):
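
For reference, a minimal usage sketch of the new `instance_nvfuser` factory entry (not part of the patch): it can be requested directly from the `Norm` factory or passed to `DynUNet` via `norm_name`. The network arguments below (`kernel_size`, `strides`, channel counts, input size) are illustrative only; when `apex` is not installed the factory falls back to `nn.InstanceNorm3d` and emits a warning.

```python
import torch

from monai.networks.layers.factories import Norm
from monai.networks.nets import DynUNet

# Factory lookup: apex.normalization.InstanceNorm3dNVFuser if available, else nn.InstanceNorm3d.
norm_type = Norm["instance_nvfuser", 3]
print(norm_type)

if torch.cuda.is_available():
    # DynUNet accepts the same (case-insensitive) name through `norm_name`;
    # extra keyword arguments are forwarded to the normalization layer.
    net = DynUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        kernel_size=[3, 3, 3],
        strides=[1, 2, 2],
        upsample_kernel_size=[2, 2],
        norm_name=("instance_nvfuser", {"affine": True}),
    ).to("cuda:0")
    out = net(torch.randn(1, 1, 64, 64, 64, device="cuda:0"))
    print(out.shape)  # expected: torch.Size([1, 2, 64, 64, 64])
```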