29 changes: 28 additions & 1 deletion monai/networks/layers/factories.py
@@ -60,11 +60,14 @@ def use_factory(fact_args):
layer = use_factory( (fact.TEST, kwargs) )
"""

import warnings
from typing import Any, Callable, Dict, Tuple, Type, Union

import torch.nn as nn

from monai.utils import look_up_option
from monai.utils import look_up_option, optional_import

InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]

@@ -242,6 +245,30 @@ def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]:
    return nn.SyncBatchNorm


@Norm.factory_function("instance_nvfuser")
def instance_nvfuser_factory(dim):
    """
    `InstanceNorm3dNVFuser` is a faster version of the instance norm layer, implemented in `apex`.
    It only supports 3d input tensors and requires a CUDA device on a non-Windows OS.
    If the required library `apex.normalization.InstanceNorm3dNVFuser` cannot be imported,
    this function returns `nn.InstanceNorm3d` instead.
    The layer is based on a customized autograd function, which is currently not supported by TorchScript.
    Please switch to `nn.InstanceNorm3d` if TorchScript support is needed.

    Please check the following link for more details about how to install `apex`:
    https://github.com/NVIDIA/apex#installation

    """
    types = (nn.InstanceNorm1d, nn.InstanceNorm2d)
    if dim != 3:
        warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.")
        return types[dim - 1]
    if not has_nvfuser:
        warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.")
        return nn.InstanceNorm3d
    return InstanceNorm3dNVFuser


Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
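For reference, a minimal usage sketch of the new factory entry (the `Norm[...]` lookup is MONAI's existing `LayerFactory` pattern; the `num_features`/`affine` values here are illustrative):

from monai.networks.layers.factories import Norm

# Look up the nvfuser-backed instance norm type for 3d inputs; when `apex`
# is unavailable this resolves to nn.InstanceNorm3d (with a warning).
norm_type = Norm["instance_nvfuser", 3]
norm_layer = norm_type(num_features=16, affine=True)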
2 changes: 2 additions & 0 deletions monai/networks/nets/dynunet.py
@@ -104,6 +104,8 @@ class DynUNet(nn.Module):
If not specified, the way which nnUNet used will be employed. Defaults to ``None``.
dropout: dropout ratio. Defaults to no dropout.
norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
``INSTANCE_NVFUSER`` is a faster version of the instance norm layer. It can be used when:
1) `spatial_dims=3`, 2) a CUDA device is available, 3) `apex` is installed, and 4) a non-Windows OS is used.
act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
If ``True``, in training mode, the forward function will output not only the final feature map
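For context, a minimal sketch of constructing a `DynUNet` with this option (hypothetical channel/kernel/stride values, using the `(name, args)` tuple form of `norm_name` exercised in the tests below):

from monai.networks.nets import DynUNet

# Hypothetical 3d configuration; the factory falls back to nn.InstanceNorm3d
# when the apex/CUDA/OS requirements are not met.
net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    kernel_size=[3, 3, 3],
    strides=[1, 2, 2],
    upsample_kernel_size=[2, 2],
    norm_name=("instance_nvfuser", {"eps": 1e-5, "affine": True}),
)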
32 changes: 31 additions & 1 deletion tests/test_dynunet.py
@@ -17,7 +17,10 @@

from monai.networks import eval_mode
from monai.networks.nets import DynUNet
from tests.utils import test_script_save
from monai.utils import optional_import
from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save

_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")

device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -118,6 +121,33 @@ def test_script(self):
        test_script_save(net, test_data)


@skip_if_no_cuda
@skip_if_windows
@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.")
class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase):
    @parameterized.expand([TEST_CASE_DYNUNET_3D[0]])
    def test_consistency(self, input_param, input_shape, _):
        for eps in [1e-4, 1e-5]:
            for momentum in [0.1, 0.01]:
                for affine in [True, False]:
                    norm_param = {"eps": eps, "momentum": momentum, "affine": affine}
                    input_param["norm_name"] = ("instance", norm_param)
                    input_param_fuser = input_param.copy()
                    input_param_fuser["norm_name"] = ("instance_nvfuser", norm_param)
                    for memory_format in [torch.contiguous_format, torch.channels_last_3d]:
                        net = DynUNet(**input_param).to("cuda:0", memory_format=memory_format)
                        net_fuser = DynUNet(**input_param_fuser).to("cuda:0", memory_format=memory_format)
                        net_fuser.load_state_dict(net.state_dict())

                        input_tensor = torch.randn(input_shape).to("cuda:0", memory_format=memory_format)
                        with eval_mode(net):
                            result = net(input_tensor)
                        with eval_mode(net_fuser):
                            result_fuser = net_fuser(input_tensor)

                        torch.testing.assert_close(result, result_fuser)


class TestDynUNetDeepSupervision(unittest.TestCase):
    @parameterized.expand(TEST_CASE_DEEP_SUPERVISION)
    def test_shape(self, input_param, input_shape, expected_shape):