From dbb5eef808a95a75218a7fa8e563d5f3dfe99021 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 28 Apr 2022 15:27:49 +0800 Subject: [PATCH 01/12] implement the base class Signed-off-by: Yiheng Wang --- .../layers/instance_norm_3dnvfuser.py | 191 ++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 monai/networks/layers/instance_norm_3dnvfuser.py diff --git a/monai/networks/layers/instance_norm_3dnvfuser.py b/monai/networks/layers/instance_norm_3dnvfuser.py new file mode 100644 index 0000000000..d96c985bb8 --- /dev/null +++ b/monai/networks/layers/instance_norm_3dnvfuser.py @@ -0,0 +1,191 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import Tensor +from torch.nn.modules.batchnorm import _NormBase + +from monai.utils import optional_import + +instance_norm_nvfuser_cuda, has_nvfuser = optional_import("instance_norm_nvfuser_cuda") + + +class InstanceNormNVFuserFunction(torch.autograd.Function): + @staticmethod + def forward( # type: ignore + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + running_mean: torch.Tensor, + running_var: torch.Tensor, + use_input_stats: bool, + momentum: float, + eps: float, + ): + + channels_last = input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous( + memory_format=torch.channels_last_3d + ) + # for channels_last format input, reorder it into NCHW[D] format + if channels_last: + order = [0] + [i for i in range(2, len(input.shape))] + [1] + _input = input.permute(order) + else: + _input = input + if not _input.is_contiguous(): + raise AssertionError("In NCHW[D] order, `input` must be contiguous.") + result = instance_norm_nvfuser_cuda.forward( + _input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, channels_last + ) + if len(result) == 3: + out, mean, invstd = result + else: + running_mean, running_var, out, mean, invstd = result + ctx.use_input_stats = use_input_stats + ctx.eps = eps + ctx.channels_last = channels_last + # saving for backward in "explicit channels-last format" + ctx.save_for_backward(_input, weight, running_mean, running_var, mean, invstd) + if channels_last: + order = [0, len(_input.shape) - 1] + [i for i in range(1, len(_input.shape) - 1)] + out = out.permute(order) + + if len(out.shape) == 4: + memory_format = torch.channels_last + elif len(out.shape) == 5: + memory_format = torch.channels_last_3d + else: + raise AssertionError("unhandled channels_last format variation in forward.") + if not out.is_contiguous(memory_format=memory_format): + raise AssertionError(f"In {memory_format} order, output of forward is not contiguous.") + if not input.is_contiguous(memory_format=memory_format): + raise AssertionError(f"In {memory_format} order, `input` is not contiguous.") + + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore + + if ctx.channels_last: + order = [0] + [i for i in range(2, len(grad_output.shape))] + [1] + grad_output = grad_output.permute(order) + # input was saved in "explicit channels-last format" + if not ctx.saved_tensors[0].is_contiguous(): + raise AssertionError("In NCHW order, `ctx.saved_tensors[0]` is not contiguous.") + grad_output = grad_output.contiguous() + saved = list(ctx.saved_tensors) + saved.insert(1, grad_output) + grad_input, grad_weight, grad_bias = instance_norm_nvfuser_cuda.backward( + *saved, ctx.use_input_stats, ctx.eps, ctx.channels_last + ) + if ctx.channels_last: + order = [0, len(grad_input.shape) - 1] + [i for i in range(1, len(grad_input.shape) - 1)] + grad_input = grad_input.permute(order) + if len(grad_input.shape) == 4: + memory_format = torch.channels_last + elif len(grad_input.shape) == 5: + memory_format = torch.channels_last_3d + else: + raise AssertionError("unhandled channels_last format variation in backward.") + if not grad_input.is_contiguous(memory_format=memory_format): + raise AssertionError(f"In {memory_format} order, output of backward is not contiguous.") + + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None + + +class _InstanceNormNVFuser(_NormBase): + """ + Base of InstanceNorm3dNVFuser. This class only works on non-Windows OS and tensors in GPU mode. + This class refers to `APEX`. + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(_InstanceNormNVFuser, self).__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.dummy = torch.empty([], device="cuda") + + def _check_input_dim(self, input: torch.Tensor): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + # at version 1: removed running_mean and running_var when + # track_running_stats=False (default) + if version is None and not self.track_running_stats: + running_stats_keys = [] + for name in ("running_mean", "running_var"): + key = prefix + name + if key in state_dict: + running_stats_keys.append(key) + if len(running_stats_keys) > 0: + error_msgs.append( + "Unexpected running stats buffer(s) {names} for {klass} " + "with track_running_stats=False. If state_dict is a " + "checkpoint saved before 0.4.0, this may be expected " + "because {klass} does not track running stats by default " + "since 0.4.0. Please remove these keys from state_dict. If " + "the running stats are actually needed, instead set " + "track_running_stats=True in {klass} to enable them. See " + "the documentation of {klass} for details.".format( + names=" and ".join("{}".format(k) for k in running_stats_keys), klass=self.__class__.__name__ + ) + ) + for key in running_stats_keys: + state_dict.pop(key) + + super(_InstanceNormNVFuser, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, input: Tensor): + if not input.is_cuda: + raise AssertionError("NVFuser InstanceNorm is CUDA only.") + self._check_input_dim(input) + + out = InstanceNormNVFuserFunction.apply( + input=input, + weight=self.weight if self.weight is not None else self.dummy, + bias=self.bias if self.bias is not None else self.dummy, + running_mean=self.running_mean if self.running_mean is not None else self.dummy, + running_var=self.running_var if self.running_mean is not None else self.dummy, + use_input_stats=self.training or not self.track_running_stats, + momentum=self.momentum, + eps=self.eps, + ) + + return out + + +class InstanceNorm3dNVFuser(_InstanceNormNVFuser): + def _check_input_dim(self, input: torch.Tensor): + if input.dim() != 5: + raise ValueError("expected 5D input (got {}D input)".format(input.dim())) From 0163f39c91277a4b445e7ca04426e61ccdf0f7f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Apr 2022 07:29:33 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/layers/instance_norm_3dnvfuser.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/layers/instance_norm_3dnvfuser.py b/monai/networks/layers/instance_norm_3dnvfuser.py index d96c985bb8..105a510d1e 100644 --- a/monai/networks/layers/instance_norm_3dnvfuser.py +++ b/monai/networks/layers/instance_norm_3dnvfuser.py @@ -119,7 +119,7 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - super(_InstanceNormNVFuser, self).__init__( + super().__init__( num_features, eps, momentum, affine, track_running_stats, **factory_kwargs ) self.dummy = torch.empty([], device="cuda") @@ -156,13 +156,13 @@ def _load_from_state_dict( "the running stats are actually needed, instead set " "track_running_stats=True in {klass} to enable them. See " "the documentation of {klass} for details.".format( - names=" and ".join("{}".format(k) for k in running_stats_keys), klass=self.__class__.__name__ + names=" and ".join(f"{k}" for k in running_stats_keys), klass=self.__class__.__name__ ) ) for key in running_stats_keys: state_dict.pop(key) - super(_InstanceNormNVFuser, self)._load_from_state_dict( + super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) @@ -188,4 +188,4 @@ def forward(self, input: Tensor): class InstanceNorm3dNVFuser(_InstanceNormNVFuser): def _check_input_dim(self, input: torch.Tensor): if input.dim() != 5: - raise ValueError("expected 5D input (got {}D input)".format(input.dim())) + raise ValueError(f"expected 5D input (got {input.dim()}D input)") From 09e85593af2c016338dc9fbdd6c75923cd36bf69 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 28 Apr 2022 20:05:59 +0800 Subject: [PATCH 03/12] add unittest Signed-off-by: Yiheng Wang --- monai/networks/layers/__init__.py | 2 +- monai/networks/layers/factories.py | 19 +++++- ...{instance_norm_3dnvfuser.py => nvfuser.py} | 28 +++++---- tests/test_instancenorm_nvfuser.py | 60 +++++++++++++++++++ 4 files changed, 97 insertions(+), 12 deletions(-) rename monai/networks/layers/{instance_norm_3dnvfuser.py => nvfuser.py} (90%) create mode 100644 tests/test_instancenorm_nvfuser.py diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 5115c00af3..4aa912d38d 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -26,4 +26,4 @@ separable_filtering, ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push -from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer +from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer \ No newline at end of file diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 6379f49449..9d14fa52db 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -64,7 +64,13 @@ def use_factory(fact_args): import torch.nn as nn -from monai.utils import look_up_option +import warnings + +from monai.utils import look_up_option, optional_import + +instance_norm_nvfuser_cuda, has_nvfuser = optional_import("instance_norm_nvfuser_cuda") +if has_nvfuser: + from monai.networks.layers.nvfuser import InstanceNorm3dNVFuser __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] @@ -241,6 +247,17 @@ def local_response_factory(_dim) -> Type[nn.LocalResponseNorm]: def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]: return nn.SyncBatchNorm +@Norm.factory_function("instance_nvfuser") +def instance_nvfuser_factory(_dim): + types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) + if _dim != 3: + warnings.warn("Only 3d instance norm nvfuser has been implemented, use common instance norm instead.") + return types[dim-1] + if not has_nvfuser: + warnings.warn("`instance_norm_nvfuser_cuda` is not installed, use common instance norm instead.") + return types[dim-1] + return InstanceNorm3dNVFuser + Act.add_factory_callable("elu", lambda: nn.modules.ELU) Act.add_factory_callable("relu", lambda: nn.modules.ReLU) diff --git a/monai/networks/layers/instance_norm_3dnvfuser.py b/monai/networks/layers/nvfuser.py similarity index 90% rename from monai/networks/layers/instance_norm_3dnvfuser.py rename to monai/networks/layers/nvfuser.py index d96c985bb8..21d1de9ac9 100644 --- a/monai/networks/layers/instance_norm_3dnvfuser.py +++ b/monai/networks/layers/nvfuser.py @@ -15,7 +15,9 @@ from monai.utils import optional_import -instance_norm_nvfuser_cuda, has_nvfuser = optional_import("instance_norm_nvfuser_cuda") +instance_norm_nvfuser_cuda, _ = optional_import("instance_norm_nvfuser_cuda") + +__all__ = ["InstanceNorm3dNVFuser"] class InstanceNormNVFuserFunction(torch.autograd.Function): @@ -104,7 +106,7 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore class _InstanceNormNVFuser(_NormBase): """ - Base of InstanceNorm3dNVFuser. This class only works on non-Windows OS and tensors in GPU mode. + Base of InstanceNorm3dNVFuser. This class only works on non-Windows OS and input tensors should be in GPU mode. This class refers to `APEX`. """ @@ -172,20 +174,26 @@ def forward(self, input: Tensor): self._check_input_dim(input) out = InstanceNormNVFuserFunction.apply( - input=input, - weight=self.weight if self.weight is not None else self.dummy, - bias=self.bias if self.bias is not None else self.dummy, - running_mean=self.running_mean if self.running_mean is not None else self.dummy, - running_var=self.running_var if self.running_mean is not None else self.dummy, - use_input_stats=self.training or not self.track_running_stats, - momentum=self.momentum, - eps=self.eps, + input, + self.weight if self.weight is not None else self.dummy, + self.bias if self.bias is not None else self.dummy, + self.running_mean if self.running_mean is not None else self.dummy, + self.running_var if self.running_mean is not None else self.dummy, + self.training or not self.track_running_stats, + self.momentum, + self.eps, ) return out class InstanceNorm3dNVFuser(_InstanceNormNVFuser): + """ + A faster version of 3d instance norm layer. + This class only works on non-Windows OS and input tensors should be in GPU mode. + This class refers to `APEX`. + """ + def _check_input_dim(self, input: torch.Tensor): if input.dim() != 5: raise ValueError("expected 5D input (got {}D input)".format(input.dim())) diff --git a/tests/test_instancenorm_nvfuser.py b/tests/test_instancenorm_nvfuser.py new file mode 100644 index 0000000000..19b63e5d37 --- /dev/null +++ b/tests/test_instancenorm_nvfuser.py @@ -0,0 +1,60 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers.nvfuser import InstanceNorm3dNVFuser +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows + +_, has_nvfuser = optional_import("instance_norm_nvfuser_cuda") + + +TEST_CASES = [] +input_shape = (1, 3, 64, 64, 64) +for eps in [1e-4, 1e-5]: + for momentum in [0.1, 0.01]: + for affine in [True, False]: + test_case = [ + { + "num_features": input_shape[1], + "eps": eps, + "momentum": momentum, + "affine": affine, + "device": "cuda", + }, + input_shape, + ] + TEST_CASES.append(test_case) + + +@skip_if_no_cuda +@skip_if_windows +@skip_if_quick +@SkipIfBeforePyTorchVersion((1, 10)) +@unittest.skipUnless(has_nvfuser, "`instance_norm_nvfuser_cuda` is necessary for `InstanceNorm3dNVFuser`.") +class TestInstanceNorm3dNVFuser(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_layer_consistency(self, input_param, input_shape): + input_tensor = torch.randn(input_shape).to("cuda") + in_layer = torch.nn.InstanceNorm3d(**input_param) + in_3dnvfuser_layer = InstanceNorm3dNVFuser(**input_param) + out_in = in_layer(input_tensor) + out_3dnvfuser = in_3dnvfuser_layer(input_tensor) + + torch.testing.assert_close(out_in, out_3dnvfuser) + + +if __name__ == "__main__": + unittest.main() From 9443d8667f1f9640a7b2022cd36664743a4c2f67 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Apr 2022 12:06:47 +0000 Subject: [PATCH 04/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/layers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 4aa912d38d..5115c00af3 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -26,4 +26,4 @@ separable_filtering, ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push -from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer \ No newline at end of file +from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer From caade5eb56c781768848fa403293fe11e2940242 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 28 Apr 2022 20:07:42 +0800 Subject: [PATCH 05/12] autofix Signed-off-by: Yiheng Wang --- monai/networks/layers/__init__.py | 2 +- monai/networks/layers/factories.py | 8 ++++---- monai/networks/layers/nvfuser.py | 13 ++----------- tests/test_instancenorm_nvfuser.py | 8 +------- 4 files changed, 8 insertions(+), 23 deletions(-) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 4aa912d38d..5115c00af3 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -26,4 +26,4 @@ separable_filtering, ) from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push -from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer \ No newline at end of file +from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 9d14fa52db..f000972bb2 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -60,12 +60,11 @@ def use_factory(fact_args): layer = use_factory( (fact.TEST, kwargs) ) """ +import warnings from typing import Any, Callable, Dict, Tuple, Type, Union import torch.nn as nn -import warnings - from monai.utils import look_up_option, optional_import instance_norm_nvfuser_cuda, has_nvfuser = optional_import("instance_norm_nvfuser_cuda") @@ -247,15 +246,16 @@ def local_response_factory(_dim) -> Type[nn.LocalResponseNorm]: def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]: return nn.SyncBatchNorm + @Norm.factory_function("instance_nvfuser") def instance_nvfuser_factory(_dim): types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) if _dim != 3: warnings.warn("Only 3d instance norm nvfuser has been implemented, use common instance norm instead.") - return types[dim-1] + return types[dim - 1] if not has_nvfuser: warnings.warn("`instance_norm_nvfuser_cuda` is not installed, use common instance norm instead.") - return types[dim-1] + return types[dim - 1] return InstanceNorm3dNVFuser diff --git a/monai/networks/layers/nvfuser.py b/monai/networks/layers/nvfuser.py index 6598a74fb3..e9c51a9877 100644 --- a/monai/networks/layers/nvfuser.py +++ b/monai/networks/layers/nvfuser.py @@ -121,23 +121,14 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs - ) + super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs) self.dummy = torch.empty([], device="cuda") def _check_input_dim(self, input: torch.Tensor): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): version = local_metadata.get("version", None) # at version 1: removed running_mean and running_var when diff --git a/tests/test_instancenorm_nvfuser.py b/tests/test_instancenorm_nvfuser.py index 19b63e5d37..e8bf4681f0 100644 --- a/tests/test_instancenorm_nvfuser.py +++ b/tests/test_instancenorm_nvfuser.py @@ -27,13 +27,7 @@ for momentum in [0.1, 0.01]: for affine in [True, False]: test_case = [ - { - "num_features": input_shape[1], - "eps": eps, - "momentum": momentum, - "affine": affine, - "device": "cuda", - }, + {"num_features": input_shape[1], "eps": eps, "momentum": momentum, "affine": affine, "device": "cuda"}, input_shape, ] TEST_CASES.append(test_case) From 8d69e51a6812c6be6e2f52780a8164ad508edb36 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 28 Apr 2022 21:29:27 +0800 Subject: [PATCH 06/12] switch to call apex directly Signed-off-by: Yiheng Wang --- monai/networks/layers/factories.py | 21 ++-- monai/networks/layers/nvfuser.py | 190 ----------------------------- monai/networks/nets/dynunet.py | 2 + tests/test_dynunet.py | 75 ++++++++---- tests/test_instancenorm_nvfuser.py | 54 -------- 5 files changed, 67 insertions(+), 275 deletions(-) delete mode 100644 monai/networks/layers/nvfuser.py delete mode 100644 tests/test_instancenorm_nvfuser.py diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index f000972bb2..e099983103 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -67,9 +67,7 @@ def use_factory(fact_args): from monai.utils import look_up_option, optional_import -instance_norm_nvfuser_cuda, has_nvfuser = optional_import("instance_norm_nvfuser_cuda") -if has_nvfuser: - from monai.networks.layers.nvfuser import InstanceNorm3dNVFuser +InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] @@ -248,14 +246,21 @@ def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]: @Norm.factory_function("instance_nvfuser") -def instance_nvfuser_factory(_dim): +def instance_nvfuser_factory(dim): + """ + `InstanceNorm3dNVFuser` is a faster verison of InstanceNorm layer and implemented in `apex`. + It only supports 3d tensors as the input. It also requires to use with CUDA and non-Windows OS. + In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist, + `nn.InstanceNorm3d` will be returned instead. + + """ types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) - if _dim != 3: - warnings.warn("Only 3d instance norm nvfuser has been implemented, use common instance norm instead.") + if dim != 3: + warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.") return types[dim - 1] if not has_nvfuser: - warnings.warn("`instance_norm_nvfuser_cuda` is not installed, use common instance norm instead.") - return types[dim - 1] + warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.") + return nn.InstanceNorm3d return InstanceNorm3dNVFuser diff --git a/monai/networks/layers/nvfuser.py b/monai/networks/layers/nvfuser.py deleted file mode 100644 index e9c51a9877..0000000000 --- a/monai/networks/layers/nvfuser.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import Tensor -from torch.nn.modules.batchnorm import _NormBase - -from monai.utils import optional_import - -instance_norm_nvfuser_cuda, _ = optional_import("instance_norm_nvfuser_cuda") - -__all__ = ["InstanceNorm3dNVFuser"] - - -class InstanceNormNVFuserFunction(torch.autograd.Function): - @staticmethod - def forward( # type: ignore - ctx, - input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - running_mean: torch.Tensor, - running_var: torch.Tensor, - use_input_stats: bool, - momentum: float, - eps: float, - ): - - channels_last = input.is_contiguous(memory_format=torch.channels_last) or input.is_contiguous( - memory_format=torch.channels_last_3d - ) - # for channels_last format input, reorder it into NCHW[D] format - if channels_last: - order = [0] + [i for i in range(2, len(input.shape))] + [1] - _input = input.permute(order) - else: - _input = input - if not _input.is_contiguous(): - raise AssertionError("In NCHW[D] order, `input` must be contiguous.") - result = instance_norm_nvfuser_cuda.forward( - _input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, channels_last - ) - if len(result) == 3: - out, mean, invstd = result - else: - running_mean, running_var, out, mean, invstd = result - ctx.use_input_stats = use_input_stats - ctx.eps = eps - ctx.channels_last = channels_last - # saving for backward in "explicit channels-last format" - ctx.save_for_backward(_input, weight, running_mean, running_var, mean, invstd) - if channels_last: - order = [0, len(_input.shape) - 1] + [i for i in range(1, len(_input.shape) - 1)] - out = out.permute(order) - - if len(out.shape) == 4: - memory_format = torch.channels_last - elif len(out.shape) == 5: - memory_format = torch.channels_last_3d - else: - raise AssertionError("unhandled channels_last format variation in forward.") - if not out.is_contiguous(memory_format=memory_format): - raise AssertionError(f"In {memory_format} order, output of forward is not contiguous.") - if not input.is_contiguous(memory_format=memory_format): - raise AssertionError(f"In {memory_format} order, `input` is not contiguous.") - - return out - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): # type: ignore - - if ctx.channels_last: - order = [0] + [i for i in range(2, len(grad_output.shape))] + [1] - grad_output = grad_output.permute(order) - # input was saved in "explicit channels-last format" - if not ctx.saved_tensors[0].is_contiguous(): - raise AssertionError("In NCHW order, `ctx.saved_tensors[0]` is not contiguous.") - grad_output = grad_output.contiguous() - saved = list(ctx.saved_tensors) - saved.insert(1, grad_output) - grad_input, grad_weight, grad_bias = instance_norm_nvfuser_cuda.backward( - *saved, ctx.use_input_stats, ctx.eps, ctx.channels_last - ) - if ctx.channels_last: - order = [0, len(grad_input.shape) - 1] + [i for i in range(1, len(grad_input.shape) - 1)] - grad_input = grad_input.permute(order) - if len(grad_input.shape) == 4: - memory_format = torch.channels_last - elif len(grad_input.shape) == 5: - memory_format = torch.channels_last_3d - else: - raise AssertionError("unhandled channels_last format variation in backward.") - if not grad_input.is_contiguous(memory_format=memory_format): - raise AssertionError(f"In {memory_format} order, output of backward is not contiguous.") - - return grad_input, grad_weight, grad_bias, None, None, None, None, None, None - - -class _InstanceNormNVFuser(_NormBase): - """ - Base of InstanceNorm3dNVFuser. This class only works on non-Windows OS and input tensors should be in GPU mode. - This class refers to `APEX`. - """ - - def __init__( - self, - num_features: int, - eps: float = 1e-5, - momentum: float = 0.1, - affine: bool = False, - track_running_stats: bool = False, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs) - self.dummy = torch.empty([], device="cuda") - - def _check_input_dim(self, input: torch.Tensor): - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - version = local_metadata.get("version", None) - # at version 1: removed running_mean and running_var when - # track_running_stats=False (default) - if version is None and not self.track_running_stats: - running_stats_keys = [] - for name in ("running_mean", "running_var"): - key = prefix + name - if key in state_dict: - running_stats_keys.append(key) - if len(running_stats_keys) > 0: - error_msgs.append( - "Unexpected running stats buffer(s) {names} for {klass} " - "with track_running_stats=False. If state_dict is a " - "checkpoint saved before 0.4.0, this may be expected " - "because {klass} does not track running stats by default " - "since 0.4.0. Please remove these keys from state_dict. If " - "the running stats are actually needed, instead set " - "track_running_stats=True in {klass} to enable them. See " - "the documentation of {klass} for details.".format( - names=" and ".join(f"{k}" for k in running_stats_keys), klass=self.__class__.__name__ - ) - ) - for key in running_stats_keys: - state_dict.pop(key) - - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - - def forward(self, input: Tensor): - if not input.is_cuda: - raise AssertionError("NVFuser InstanceNorm is CUDA only.") - self._check_input_dim(input) - - out = InstanceNormNVFuserFunction.apply( - input, - self.weight if self.weight is not None else self.dummy, - self.bias if self.bias is not None else self.dummy, - self.running_mean if self.running_mean is not None else self.dummy, - self.running_var if self.running_mean is not None else self.dummy, - self.training or not self.track_running_stats, - self.momentum, - self.eps, - ) - - return out - - -class InstanceNorm3dNVFuser(_InstanceNormNVFuser): - """ - A faster version of 3d instance norm layer. - This class only works on non-Windows OS and input tensors should be in GPU mode. - This class refers to `APEX`. - """ - - def _check_input_dim(self, input: torch.Tensor): - if input.dim() != 5: - raise ValueError(f"expected 5D input (got {input.dim()}D input)") diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index e858dcbb9b..053ab255b8 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -104,6 +104,8 @@ class DynUNet(nn.Module): If not specified, the way which nnUNet used will be employed. Defaults to ``None``. dropout: dropout ratio. Defaults to no dropout. norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``. + `INSTANCE_NVFUSER` is a faster version of the instance norm layer, it can be used when: + 1) `spatial_dims=3`, 2) CUDA device is available, 3) `apex` is installed and 4) non-Windows OS is used. act_name: activation layer type and arguments. Defaults to ``leakyrelu``. deep_supervision: whether to add deep supervision head before output. Defaults to ``False``. If ``True``, in training mode, the forward function will output not only the final feature map diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 36ac9d0309..6bc0bbf7c2 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -17,7 +17,10 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet -from tests.utils import test_script_save +from monai.utils import optional_import +from tests.utils import skip_if_no_cuda, skip_if_windows, test_script_save + +_, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -103,28 +106,54 @@ TEST_CASE_DEEP_SUPERVISION.append(test_case) -class TestDynUNet(unittest.TestCase): - @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNet(**input_param).to(device) - with eval_mode(net): - result = net(torch.randn(input_shape).to(device)) - self.assertEqual(result.shape, expected_shape) - - def test_script(self): - input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] - net = DynUNet(**input_param) - test_data = torch.randn(input_shape) - test_script_save(net, test_data) - - -class TestDynUNetDeepSupervision(unittest.TestCase): - @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) - def test_shape(self, input_param, input_shape, expected_shape): - net = DynUNet(**input_param).to(device) - with torch.no_grad(): - results = net(torch.randn(input_shape).to(device)) - self.assertEqual(results.shape, expected_shape) +# class TestDynUNet(unittest.TestCase): +# @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) +# def test_shape(self, input_param, input_shape, expected_shape): +# net = DynUNet(**input_param).to(device) +# with eval_mode(net): +# result = net(torch.randn(input_shape).to(device)) +# self.assertEqual(result.shape, expected_shape) + +# def test_script(self): +# input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] +# net = DynUNet(**input_param) +# test_data = torch.randn(input_shape) +# test_script_save(net, test_data) + + +@skip_if_no_cuda +@skip_if_windows +@unittest.skipUnless(has_nvfuser, "To use `instance_nvfuser`, `apex.normalization.InstanceNorm3dNVFuser` is needed.") +class TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase): + @parameterized.expand([TEST_CASE_DYNUNET_3D[0]]) + def test_consistency(self, input_param, input_shape, _): + for eps in [1e-4, 1e-5]: + for momentum in [0.1, 0.01]: + for affine in [True, False]: + norm_param = {"eps": eps, "momentum": momentum, "affine": affine} + input_param["norm_name"] = ("instance", norm_param) + input_param_fuser = input_param.copy() + input_param_fuser["norm_name"] = ("instance_nvfuser", norm_param) + net = DynUNet(**input_param).to("cuda") + net_fuser = DynUNet(**input_param_fuser).to("cuda") + net_fuser.load_state_dict(net.state_dict()) + + input_tensor = torch.randn(input_shape).to("cuda") + with eval_mode(net): + result = net(input_tensor) + with eval_mode(net_fuser): + result_fuser = net_fuser(input_tensor) + + torch.testing.assert_close(result, result_fuser) + + +# class TestDynUNetDeepSupervision(unittest.TestCase): +# @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) +# def test_shape(self, input_param, input_shape, expected_shape): +# net = DynUNet(**input_param).to(device) +# with torch.no_grad(): +# results = net(torch.randn(input_shape).to(device)) +# self.assertEqual(results.shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_instancenorm_nvfuser.py b/tests/test_instancenorm_nvfuser.py deleted file mode 100644 index e8bf4681f0..0000000000 --- a/tests/test_instancenorm_nvfuser.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks.layers.nvfuser import InstanceNorm3dNVFuser -from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows - -_, has_nvfuser = optional_import("instance_norm_nvfuser_cuda") - - -TEST_CASES = [] -input_shape = (1, 3, 64, 64, 64) -for eps in [1e-4, 1e-5]: - for momentum in [0.1, 0.01]: - for affine in [True, False]: - test_case = [ - {"num_features": input_shape[1], "eps": eps, "momentum": momentum, "affine": affine, "device": "cuda"}, - input_shape, - ] - TEST_CASES.append(test_case) - - -@skip_if_no_cuda -@skip_if_windows -@skip_if_quick -@SkipIfBeforePyTorchVersion((1, 10)) -@unittest.skipUnless(has_nvfuser, "`instance_norm_nvfuser_cuda` is necessary for `InstanceNorm3dNVFuser`.") -class TestInstanceNorm3dNVFuser(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_layer_consistency(self, input_param, input_shape): - input_tensor = torch.randn(input_shape).to("cuda") - in_layer = torch.nn.InstanceNorm3d(**input_param) - in_3dnvfuser_layer = InstanceNorm3dNVFuser(**input_param) - out_in = in_layer(input_tensor) - out_3dnvfuser = in_3dnvfuser_layer(input_tensor) - - torch.testing.assert_close(out_in, out_3dnvfuser) - - -if __name__ == "__main__": - unittest.main() From 332eb6a4118eaf22d7db6d425d1e9d1e882e8a03 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 28 Apr 2022 21:34:47 +0800 Subject: [PATCH 07/12] uncomment unittest Signed-off-by: Yiheng Wang --- tests/test_dynunet.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 6bc0bbf7c2..154986ad15 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -106,19 +106,19 @@ TEST_CASE_DEEP_SUPERVISION.append(test_case) -# class TestDynUNet(unittest.TestCase): -# @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) -# def test_shape(self, input_param, input_shape, expected_shape): -# net = DynUNet(**input_param).to(device) -# with eval_mode(net): -# result = net(torch.randn(input_shape).to(device)) -# self.assertEqual(result.shape, expected_shape) +class TestDynUNet(unittest.TestCase): + @parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D) + def test_shape(self, input_param, input_shape, expected_shape): + net = DynUNet(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) -# def test_script(self): -# input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] -# net = DynUNet(**input_param) -# test_data = torch.randn(input_shape) -# test_script_save(net, test_data) + def test_script(self): + input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0] + net = DynUNet(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) @skip_if_no_cuda @@ -147,13 +147,13 @@ def test_consistency(self, input_param, input_shape, _): torch.testing.assert_close(result, result_fuser) -# class TestDynUNetDeepSupervision(unittest.TestCase): -# @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) -# def test_shape(self, input_param, input_shape, expected_shape): -# net = DynUNet(**input_param).to(device) -# with torch.no_grad(): -# results = net(torch.randn(input_shape).to(device)) -# self.assertEqual(results.shape, expected_shape) +class TestDynUNetDeepSupervision(unittest.TestCase): + @parameterized.expand(TEST_CASE_DEEP_SUPERVISION) + def test_shape(self, input_param, input_shape, expected_shape): + net = DynUNet(**input_param).to(device) + with torch.no_grad(): + results = net(torch.randn(input_shape).to(device)) + self.assertEqual(results.shape, expected_shape) if __name__ == "__main__": From b83b453c660b3dcd057fcd146012af0baebf20a1 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 29 Apr 2022 15:37:13 +0800 Subject: [PATCH 08/12] add apex install link in docstring Signed-off-by: Yiheng Wang --- monai/networks/layers/factories.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index e099983103..53334a31f1 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -253,6 +253,9 @@ def instance_nvfuser_factory(dim): In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist, `nn.InstanceNorm3d` will be returned instead. + Please check the following link for more details about how to install `apex`: + https://github.com/NVIDIA/apex#installation + """ types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) if dim != 3: From d69304907bde18448ad1604bf2430745b865e29e Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 29 Apr 2022 16:00:08 +0800 Subject: [PATCH 09/12] add channels_last_3d test case Signed-off-by: Yiheng Wang --- tests/test_dynunet.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py index 154986ad15..14006b96e6 100644 --- a/tests/test_dynunet.py +++ b/tests/test_dynunet.py @@ -134,17 +134,18 @@ def test_consistency(self, input_param, input_shape, _): input_param["norm_name"] = ("instance", norm_param) input_param_fuser = input_param.copy() input_param_fuser["norm_name"] = ("instance_nvfuser", norm_param) - net = DynUNet(**input_param).to("cuda") - net_fuser = DynUNet(**input_param_fuser).to("cuda") - net_fuser.load_state_dict(net.state_dict()) - - input_tensor = torch.randn(input_shape).to("cuda") - with eval_mode(net): - result = net(input_tensor) - with eval_mode(net_fuser): - result_fuser = net_fuser(input_tensor) - - torch.testing.assert_close(result, result_fuser) + for memory_format in [torch.contiguous_format, torch.channels_last_3d]: + net = DynUNet(**input_param).to("cuda:0", memory_format=memory_format) + net_fuser = DynUNet(**input_param_fuser).to("cuda:0", memory_format=memory_format) + net_fuser.load_state_dict(net.state_dict()) + + input_tensor = torch.randn(input_shape).to("cuda:0", memory_format=memory_format) + with eval_mode(net): + result = net(input_tensor) + with eval_mode(net_fuser): + result_fuser = net_fuser(input_tensor) + + torch.testing.assert_close(result, result_fuser) class TestDynUNetDeepSupervision(unittest.TestCase): From e33a75725bca895e758255b87a92f94f3f57d937 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 29 Apr 2022 22:51:47 +0800 Subject: [PATCH 10/12] rewrite types Signed-off-by: Yiheng Wang --- monai/networks/layers/factories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 53334a31f1..59bdd77f9d 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -263,7 +263,7 @@ def instance_nvfuser_factory(dim): return types[dim - 1] if not has_nvfuser: warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.") - return nn.InstanceNorm3d + return types[dim - 1] return InstanceNorm3dNVFuser From 3595cb14e6ea07a969bfba8fe8086f0d6a1cac0d Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 29 Apr 2022 23:07:42 +0800 Subject: [PATCH 11/12] change types Signed-off-by: Yiheng Wang --- monai/networks/layers/factories.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 59bdd77f9d..9abb99d73d 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -257,13 +257,13 @@ def instance_nvfuser_factory(dim): https://github.com/NVIDIA/apex#installation """ - types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) + types = (nn.InstanceNorm1d, nn.InstanceNorm2d) if dim != 3: warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.") return types[dim - 1] if not has_nvfuser: warnings.warn("`apex.normalization.InstanceNorm3dNVFuser` is not found, use nn.InstanceNorm3d instead.") - return types[dim - 1] + return nn.InstanceNorm3d return InstanceNorm3dNVFuser From b985cb8fb31b762d2a6fac637478fb0c49b149a3 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 5 May 2022 11:08:23 +0800 Subject: [PATCH 12/12] add docstrings Signed-off-by: Yiheng Wang --- monai/networks/layers/factories.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index 9abb99d73d..b808c24de0 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -252,6 +252,8 @@ def instance_nvfuser_factory(dim): It only supports 3d tensors as the input. It also requires to use with CUDA and non-Windows OS. In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist, `nn.InstanceNorm3d` will be returned instead. + This layer is based on a customized autograd function, which is not supported in TorchScript currently. + Please switch to use `nn.InstanceNorm3d` if TorchScript is necessary. Please check the following link for more details about how to install `apex`: https://github.com/NVIDIA/apex#installation