diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py
index 4c98fb9936..3e30987bdc 100644
--- a/monai/networks/nets/densenet.py
+++ b/monai/networks/nets/densenet.py
@@ -17,7 +17,8 @@
 import torch.nn as nn
 from torch.hub import load_state_dict_from_url
 
-from monai.networks.layers.factories import Conv, Dropout, Norm, Pool
+from monai.networks.layers.factories import Conv, Dropout, Pool
+from monai.networks.layers.utils import get_act_layer, get_norm_layer
 
 __all__ = [
     "DenseNet",
@@ -40,7 +41,14 @@
 
 class _DenseLayer(nn.Module):
     def __init__(
-        self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        growth_rate: int,
+        bn_size: int,
+        dropout_prob: float,
+        act: Union[str, tuple] = ("relu", {"inplace": True}),
+        norm: Union[str, tuple] = "batch",
     ) -> None:
         """
         Args:
@@ -50,22 +58,23 @@ def __init__(
             bn_size: multiplicative factor for number of bottle neck layers.
                 (i.e. bn_size * k features in the bottleneck layer)
             dropout_prob: dropout rate after each dense layer.
+            act: activation type and arguments. Defaults to relu.
+            norm: feature normalization type and arguments. Defaults to batch norm.
         """
         super(_DenseLayer, self).__init__()
 
         out_channels = bn_size * growth_rate
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
         dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]
 
         self.layers = nn.Sequential()
-        self.layers.add_module("norm1", norm_type(in_channels))
-        self.layers.add_module("relu1", nn.ReLU(inplace=True))
+        self.layers.add_module("norm1", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels))
+        self.layers.add_module("relu1", get_act_layer(name=act))
         self.layers.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
-        self.layers.add_module("norm2", norm_type(out_channels))
-        self.layers.add_module("relu2", nn.ReLU(inplace=True))
+        self.layers.add_module("norm2", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels))
+        self.layers.add_module("relu2", get_act_layer(name=act))
         self.layers.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))
 
         if dropout_prob > 0:
@@ -78,7 +87,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class _DenseBlock(nn.Sequential):
     def __init__(
-        self, spatial_dims: int, layers: int, in_channels: int, bn_size: int, growth_rate: int, dropout_prob: float
+        self,
+        spatial_dims: int,
+        layers: int,
+        in_channels: int,
+        bn_size: int,
+        growth_rate: int,
+        dropout_prob: float,
+        act: Union[str, tuple] = ("relu", {"inplace": True}),
+        norm: Union[str, tuple] = "batch",
     ) -> None:
         """
         Args:
@@ -89,30 +106,40 @@ def __init__(
                 (i.e. bn_size * k features in the bottleneck layer)
             growth_rate: how many filters to add each layer (k in paper).
             dropout_prob: dropout rate after each dense layer.
+            act: activation type and arguments. Defaults to relu.
+            norm: feature normalization type and arguments. Defaults to batch norm.
         """
         super(_DenseBlock, self).__init__()
         for i in range(layers):
-            layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob)
+            layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob, act=act, norm=norm)
             in_channels += growth_rate
             self.add_module("denselayer%d" % (i + 1), layer)
 
 
 class _Transition(nn.Sequential):
-    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        act: Union[str, tuple] = ("relu", {"inplace": True}),
+        norm: Union[str, tuple] = "batch",
+    ) -> None:
         """
         Args:
             spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of the input channel.
            out_channels: number of the output classes.
+            act: activation type and arguments. Defaults to relu.
+            norm: feature normalization type and arguments. Defaults to batch norm.
         """
         super(_Transition, self).__init__()
 
         conv_type: Callable = Conv[Conv.CONV, spatial_dims]
-        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
         pool_type: Callable = Pool[Pool.AVG, spatial_dims]
 
-        self.add_module("norm", norm_type(in_channels))
-        self.add_module("relu", nn.ReLU(inplace=True))
+        self.add_module("norm", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels))
+        self.add_module("relu", get_act_layer(name=act))
         self.add_module("conv", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
         self.add_module("pool", pool_type(kernel_size=2, stride=2))
@@ -131,6 +158,8 @@ class DenseNet(nn.Module):
         block_config: how many layers in each pooling block.
         bn_size: multiplicative factor for number of bottle neck layers.
             (i.e. bn_size * k features in the bottleneck layer)
+        act: activation type and arguments. Defaults to relu.
+        norm: feature normalization type and arguments. Defaults to batch norm.
         dropout_prob: dropout rate after each dense layer.
     """
@@ -143,13 +172,14 @@ def __init__(
         growth_rate: int = 32,
         block_config: Sequence[int] = (6, 12, 24, 16),
         bn_size: int = 4,
+        act: Union[str, tuple] = ("relu", {"inplace": True}),
+        norm: Union[str, tuple] = "batch",
         dropout_prob: float = 0.0,
     ) -> None:
         super(DenseNet, self).__init__()
 
         conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
-        norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
         pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims]
         avg_pool_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[
             Pool.ADAPTIVEAVG, spatial_dims
         ]
@@ -159,8 +189,8 @@ def __init__(
             OrderedDict(
                 [
                     ("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
-                    ("norm0", norm_type(init_features)),
-                    ("relu0", nn.ReLU(inplace=True)),
+                    ("norm0", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=init_features)),
+                    ("relu0", get_act_layer(name=act)),
                     ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                 ]
             )
@@ -175,14 +205,20 @@ def __init__(
                 bn_size=bn_size,
                 growth_rate=growth_rate,
                 dropout_prob=dropout_prob,
+                act=act,
+                norm=norm,
             )
             self.features.add_module(f"denseblock{i + 1}", block)
             in_channels += num_layers * growth_rate
             if i == len(block_config) - 1:
-                self.features.add_module("norm5", norm_type(in_channels))
+                self.features.add_module(
+                    "norm5", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
+                )
             else:
                 _out_channels = in_channels // 2
-                trans = _Transition(spatial_dims, in_channels=in_channels, out_channels=_out_channels)
+                trans = _Transition(
+                    spatial_dims, in_channels=in_channels, out_channels=_out_channels, act=act, norm=norm
+                )
                 self.features.add_module(f"transition{i + 1}", trans)
                 in_channels = _out_channels
@@ -190,7 +226,7 @@ def __init__(
         self.class_layers = nn.Sequential(
             OrderedDict(
                 [
-                    ("relu", nn.ReLU(inplace=True)),
+                    ("relu", get_act_layer(name=act)),
                     ("pool", avg_pool_type(1)),
                     ("flatten", nn.Flatten(1)),
                     ("out", nn.Linear(in_channels, out_channels)),
@@ -201,7 +237,7 @@ def __init__(
         for m in self.modules():
             if isinstance(m, conv_type):
                 nn.init.kaiming_normal_(torch.as_tensor(m.weight))
-            elif isinstance(m, norm_type):
+            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                 nn.init.constant_(torch.as_tensor(m.weight), 1)
                 nn.init.constant_(torch.as_tensor(m.bias), 0)
             elif isinstance(m, nn.Linear):
diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py
index fcb50c29f3..cb8e195b04 100644
--- a/monai/networks/nets/efficientnet.py
+++ b/monai/networks/nets/efficientnet.py
@@ -19,7 +19,8 @@
 from torch import nn
 from torch.utils import model_zoo
 
-from monai.networks.layers.factories import Act, Conv, Norm, Pad, Pool
+from monai.networks.layers.factories import Act, Conv, Pad, Pool
+from monai.networks.layers.utils import get_norm_layer
 
 __all__ = ["EfficientNet", "EfficientNetBN", "get_efficientnet_image_size", "drop_connect"]
@@ -48,8 +49,7 @@ def __init__(
         expand_ratio: int,
         se_ratio: Optional[float],
         id_skip: Optional[bool] = True,
-        batch_norm_momentum: float = 0.99,
-        batch_norm_epsilon: float = 1e-3,
+        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}),
         drop_connect_rate: Optional[float] = 0.2,
     ) -> None:
         """
@@ -65,8 +65,7 @@ def __init__(
             expand_ratio: expansion ratio for inverted bottleneck.
             se_ratio: squeeze-excitation ratio for se layers.
             id_skip: whether to use skip connection.
-            batch_norm_momentum: momentum for batch norm.
-            batch_norm_epsilon: epsilon for batch norm.
+            norm: feature normalization type and arguments. Defaults to batch norm.
             drop_connect_rate: dropconnect rate for drop connection (individual weights) layers.
 
         References:
@@ -79,7 +78,6 @@ def __init__(
         # select the type of N-Dimensional layers to use
         # these are based on spatial dims and selected from MONAI factories
         conv_type = Conv["conv", spatial_dims]
-        batchnorm_type = Norm["batch", spatial_dims]
         adaptivepool_type = Pool["adaptiveavg", spatial_dims]
 
         self.in_channels = in_channels
@@ -95,9 +93,6 @@ def __init__(
         else:
             self.has_se = False
 
-        bn_mom = 1.0 - batch_norm_momentum  # pytorch"s difference from tensorflow
-        bn_eps = batch_norm_epsilon
-
         # Expansion phase (Inverted Bottleneck)
         inp = in_channels  # number of input channels
         oup = in_channels * expand_ratio  # number of output channels
@@ -105,7 +100,7 @@ def __init__(
             self._expand_conv = conv_type(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
             self._expand_conv_padding = _make_same_padder(self._expand_conv, image_size)
-            self._bn0 = batchnorm_type(num_features=oup, momentum=bn_mom, eps=bn_eps)
+            self._bn0 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=oup)
         else:
             # need to have the following to fix JIT error:
             #     "Module 'MBConvBlock' has no attribute '_expand_conv'"
@@ -125,7 +120,7 @@ def __init__(
             bias=False,
         )
         self._depthwise_conv_padding = _make_same_padder(self._depthwise_conv, image_size)
-        self._bn1 = batchnorm_type(num_features=oup, momentum=bn_mom, eps=bn_eps)
+        self._bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=oup)
         image_size = _calculate_output_image_size(image_size, self.stride)
 
         # Squeeze and Excitation layer, if desired
@@ -141,7 +136,7 @@ def __init__(
         final_oup = out_channels
         self._project_conv = conv_type(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
         self._project_conv_padding = _make_same_padder(self._project_conv, image_size)
-        self._bn2 = batchnorm_type(num_features=final_oup, momentum=bn_mom, eps=bn_eps)
+        self._bn2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=final_oup)
 
         # swish activation to use - using memory efficient swish by default
         # can be switched to normal swish using self.set_swish() function call
@@ -207,8 +202,7 @@ def __init__(
         depth_coefficient: float = 1.0,
         dropout_rate: float = 0.2,
         image_size: int = 224,
-        batch_norm_momentum: float = 0.99,
-        batch_norm_epsilon: float = 1e-3,
+        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}),
         drop_connect_rate: float = 0.2,
         depth_divisor: int = 8,
     ) -> None:
@@ -226,8 +220,7 @@ def __init__(
             depth_coefficient: depth multiplier coefficient (d in paper).
             dropout_rate: dropout rate for dropout layers.
             image_size: input image resolution.
-            batch_norm_momentum: momentum for batch norm.
-            batch_norm_epsilon: epsilon for batch norm.
+            norm: feature normalization type and arguments. Defaults to batch norm.
             drop_connect_rate: dropconnect rate for drop connection (individual weights) layers.
             depth_divisor: depth divisor for channel rounding.
         """
@@ -239,7 +232,6 @@ def __init__(
         # select the type of N-Dimensional layers to use
         # these are based on spatial dims and selected from MONAI factories
         conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", spatial_dims]
-        batchnorm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm["batch", spatial_dims]
         adaptivepool_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[
             "adaptiveavg", spatial_dims
         ]
@@ -262,16 +254,12 @@ def __init__(
         # expand input image dimensions to list
         current_image_size = [image_size] * spatial_dims
 
-        # parameters for batch norm
-        bn_mom = 1 - batch_norm_momentum  # 1 - bn_m to convert tensorflow's arg to pytorch bn compatible
-        bn_eps = batch_norm_epsilon
-
         # Stem
         stride = 2
         out_channels = _round_filters(32, width_coefficient, depth_divisor)  # number of output channels
         self._conv_stem = conv_type(self.in_channels, out_channels, kernel_size=3, stride=stride, bias=False)
         self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)
-        self._bn0 = batchnorm_type(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        self._bn0 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels)
         current_image_size = _calculate_output_image_size(current_image_size, stride)
 
         # build MBConv blocks
@@ -312,8 +300,7 @@ def __init__(
                     expand_ratio=block_args.expand_ratio,
                     se_ratio=block_args.se_ratio,
                     id_skip=block_args.id_skip,
-                    batch_norm_momentum=batch_norm_momentum,
-                    batch_norm_epsilon=batch_norm_epsilon,
+                    norm=norm,
                     drop_connect_rate=blk_drop_connect_rate,
                 ),
             )
@@ -344,8 +331,7 @@ def __init__(
                         expand_ratio=block_args.expand_ratio,
                         se_ratio=block_args.se_ratio,
                         id_skip=block_args.id_skip,
-                        batch_norm_momentum=batch_norm_momentum,
-                        batch_norm_epsilon=batch_norm_epsilon,
+                        norm=norm,
                         drop_connect_rate=blk_drop_connect_rate,
                     ),
                 )
@@ -360,7 +346,7 @@ def __init__(
         out_channels = _round_filters(1280, width_coefficient, depth_divisor)
         self._conv_head = conv_type(head_in_channels, out_channels, kernel_size=1, bias=False)
         self._conv_head_padding = _make_same_padder(self._conv_head, current_image_size)
-        self._bn1 = batchnorm_type(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        self._bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels)
 
         # final linear layer
         self._avg_pooling = adaptivepool_type(1)
@@ -449,6 +435,7 @@ def __init__(
         spatial_dims: int = 2,
         in_channels: int = 3,
         num_classes: int = 1000,
+        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}),
     ) -> None:
         """
         Generic wrapper around EfficientNet, used to initialize EfficientNet-B0 to EfficientNet-B7 models
@@ -457,11 +444,13 @@ def __init__(
 
         Args:
             model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b7].
-            pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2.
+            pretrained: whether to initialize pretrained ImageNet weights, only available when spatial_dims=2 and
+                batch norm is used.
             progress: whether to show download progress for pretrained weights download.
             spatial_dims: number of spatial dimensions.
             in_channels: number of input channels.
             num_classes: number of output classes.
+            norm: feature normalization type and arguments. Defaults to batch norm.
 
         Examples::
@@ -515,6 +504,7 @@ def __init__(
             dropout_rate=dropout_rate,
             image_size=image_size,
             drop_connect_rate=dropconnect_rate,
+            norm=norm,
         )
 
         # attempt to load pretrained
@@ -527,12 +517,6 @@ def __init__(
             # only pretrained for when `spatial_dims` is 2
             _load_state_dict(self, model_name, progress, load_fc)
-        else:
-            print(
-                "Skipping loading pretrained weights for non-default {}, pretrained={}, is_default_model={}".format(
-                    model_name, pretrained, is_default_model
-                )
-            )
 
 
 def get_efficientnet_image_size(model_name: str) -> int:
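With `norm` now threaded through `EfficientNet` and every `MBConvBlock`, the same override works on the wrapper class; pretrained ImageNet weights remain limited to the 2D batch-norm configuration, as the docstring change above notes. A hedged sketch with illustrative argument values:

```python
import torch

from monai.networks.nets import EfficientNetBN

# a 3D EfficientNet-B0 with instance norm; no pretrained weights exist
# for this configuration, so pretrained is left False
net = EfficientNetBN(
    "efficientnet-b0",
    pretrained=False,
    spatial_dims=3,
    in_channels=1,
    num_classes=2,
    norm=("instance", {"eps": 1e-5}),
)
logits = net(torch.randn(1, 1, 64, 64, 64))  # expected shape: (1, 2)
```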
diff --git a/tests/test_densenet.py b/tests/test_densenet.py
index fe0a3a5222..ba4b7afcb4 100644
--- a/tests/test_densenet.py
+++ b/tests/test_densenet.py
@@ -32,13 +32,13 @@
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 TEST_CASE_1 = [  # 4-channel 3D, batch 2
-    {"pretrained": False, "spatial_dims": 3, "in_channels": 2, "out_channels": 3},
+    {"pretrained": False, "spatial_dims": 3, "in_channels": 2, "out_channels": 3, "norm": ("instance", {"eps": 1e-5})},
     (2, 2, 32, 64, 48),
     (2, 3),
 ]
 
 TEST_CASE_2 = [  # 4-channel 2D, batch 2
-    {"pretrained": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3},
+    {"pretrained": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3, "act": "PRELU"},
     (2, 2, 32, 64),
     (2, 3),
 ]
diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py
index f11fc8d433..6567e3af9a 100644
--- a/tests/test_efficientnet.py
+++ b/tests/test_efficientnet.py
@@ -75,7 +75,15 @@ def get_block_args():
 ]
 
 
-def make_shape_cases(models, spatial_dims, batches, pretrained, in_channels=3, num_classes=1000):
+def make_shape_cases(
+    models,
+    spatial_dims,
+    batches,
+    pretrained,
+    in_channels=3,
+    num_classes=1000,
+    norm=("batch", {"eps": 1e-3, "momentum": 0.01}),
+):
     ret_tests = []
     for spatial_dim in spatial_dims:  # selected spatial_dims
         for batch in batches:  # check single batch as well as multiple batch input
@@ -88,6 +96,7 @@ def make_shape_cases(models, spatial_dims, batches, pretrained, in_channels=3, n
                     "spatial_dims": spatial_dim,
                     "in_channels": in_channels,
                     "num_classes": num_classes,
+                    "norm": norm,
                 }
                 ret_tests.append(
                     [
@@ -115,10 +124,22 @@ def make_shape_cases(models, spatial_dims, batches, pretrained, in_channels=3, n
 
 # 2D and 3D models are expensive so use selected models
 CASES_2D = make_shape_cases(
-    models=SEL_MODELS, spatial_dims=[2], batches=[1, 4], pretrained=[False], in_channels=3, num_classes=1000
+    models=SEL_MODELS,
+    spatial_dims=[2],
+    batches=[1, 4],
+    pretrained=[False],
+    in_channels=3,
+    num_classes=1000,
+    norm="instance",
 )
 CASES_3D = make_shape_cases(
-    models=[SEL_MODELS[0]], spatial_dims=[3], batches=[1], pretrained=[False], in_channels=3, num_classes=1000
+    models=[SEL_MODELS[0]],
+    spatial_dims=[3],
+    batches=[1],
+    pretrained=[False],
+    in_channels=3,
+    num_classes=1000,
+    norm="batch",
 )
 
 # pretrained=True cases
@@ -134,6 +155,7 @@ def make_shape_cases(models, spatial_dims, batches, pretrained, in_channels=3, n
             "spatial_dims": 2,
             "in_channels": 3,
             "num_classes": 1000,
+            "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
         },
         os.path.join(os.path.dirname(__file__), "testing_data", "kitty_test.jpg"),
         282,  # ~ tiger cat
@@ -209,7 +231,6 @@ class TestEFFICIENTNET(unittest.TestCase):
     @parameterized.expand(CASES_1D + CASES_2D + CASES_3D + CASES_VARIATIONS)
     def test_shape(self, input_param, input_shape, expected_shape):
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(input_param)
 
         # initialize model
         net = EfficientNetBN(**input_param).to(device)
@@ -224,7 +245,6 @@ def test_shape(self, input_param, input_shape, expected_shape):
     @parameterized.expand(CASES_1D + CASES_2D)
     def test_non_default_shapes(self, input_param, input_shape, expected_shape):
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(input_param)
 
         # initialize model
         net = EfficientNetBN(**input_param).to(device)
@@ -234,7 +254,6 @@ def test_non_default_shapes(self, input_param, input_shape, expected_shape):
         non_default_sizes = [128, 256, 512]
         for candidate_size in non_default_sizes:
             input_shape = input_shape[0:2] + (candidate_size,) * num_dims
-            print(input_shape)
             # run inference with random tensor
             with eval_mode(net):
                 result = net(torch.randn(input_shape).to(device))
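For reference, the factory helpers backing the new arguments resolve either a plain name or a `(name, args)` tuple; `get_norm_layer` additionally needs `spatial_dims` and `channels`, which is why every call site above passes them. A quick sketch of the expected behavior, assuming the `monai.networks.layers.utils` signatures at this revision:

```python
from monai.networks.layers.utils import get_act_layer, get_norm_layer

# a bare string uses the factory defaults for that layer type
relu = get_act_layer("relu")
# a (name, args) tuple forwards args to the layer constructor
prelu = get_act_layer(("prelu", {"init": 0.1}))

# norm layers are dimensionality- and channel-aware
bn2d = get_norm_layer("batch", spatial_dims=2, channels=16)  # nn.BatchNorm2d(16)
in3d = get_norm_layer(("instance", {"eps": 1e-5}), spatial_dims=3, channels=8)
```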