diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index a5ce86287a..54c2756535 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -358,6 +358,11 @@ Nets
 .. autoclass:: EfficientNetBN
     :members:
 
+`EfficientNetBNFeatures`
+~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: EfficientNetBNFeatures
+    :members:
+
 `SegResNet`
 ~~~~~~~~~~~
 .. autoclass:: SegResNet
diff --git a/monai/networks/blocks/activation.py b/monai/networks/blocks/activation.py
index ef2d19b550..a380f8e757 100644
--- a/monai/networks/blocks/activation.py
+++ b/monai/networks/blocks/activation.py
@@ -16,16 +16,28 @@
 if optional_import("torch.nn.functional", name="mish")[1]:
 
-    def monai_mish(x):
-        return torch.nn.functional.mish(x, inplace=True)
+    def monai_mish(x, inplace: bool = False):
+        return torch.nn.functional.mish(x, inplace=inplace)
 
 else:
 
-    def monai_mish(x):
+    def monai_mish(x, inplace: bool = False):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
+if optional_import("torch.nn.functional", name="silu")[1]:
+
+    def monai_swish(x, inplace: bool = False):
+        return torch.nn.functional.silu(x, inplace=inplace)
+
+
+else:
+
+    def monai_swish(x, inplace: bool = False):
+        return SwishImplementation.apply(x)
+
+
 class Swish(nn.Module):
     r"""Applies the element-wise function:
@@ -92,6 +104,9 @@ class MemoryEfficientSwish(nn.Module):
     Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.
 
+    From PyTorch 1.7.0, the optimized version of `Swish` is available as the built-in `SiLU`;
+    this class uses `torch.nn.functional.silu` for the calculation when that version requirement is met.
+
     Shape:
         - Input: :math:`(N, *)` where `*` means, any number of additional
           dimensions
@@ -107,8 +122,13 @@ class MemoryEfficientSwish(nn.Module):
         >>> output = m(input)
     """
 
+    def __init__(self, inplace: bool = False):
+        super(MemoryEfficientSwish, self).__init__()
+        # inplace only works when using torch.nn.functional.silu
+        self.inplace = inplace
+
     def forward(self, input: torch.Tensor):
-        return SwishImplementation.apply(input)
+        return monai_swish(input, self.inplace)
 
 
 class Mish(nn.Module):
@@ -119,6 +139,8 @@ class Mish(nn.Module):
     Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.
 
+    From PyTorch 1.9.0, the optimized built-in version of `Mish` is available;
+    this class uses `torch.nn.functional.mish` for the calculation when that version requirement is met.
 
     Shape:
         - Input: :math:`(N, *)` where `*` means, any number of additional
@@ -135,5 +157,10 @@ class Mish(nn.Module):
         >>> output = m(input)
     """
 
+    def __init__(self, inplace: bool = False):
+        super(Mish, self).__init__()
+        # inplace only works when using torch.nn.functional.mish
+        self.inplace = inplace
+
     def forward(self, input: torch.Tensor):
-        return monai_mish(input)
+        return monai_mish(input, self.inplace)
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index 9cf6c5e07f..ad1ca2418b 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -31,7 +31,14 @@
     densenet264,
 )
 from .dynunet import DynUNet, DynUnet, Dynunet, dynunet
-from .efficientnet import BlockArgs, EfficientNet, EfficientNetBN, drop_connect, get_efficientnet_image_size
+from .efficientnet import (
+    BlockArgs,
+    EfficientNet,
+    EfficientNetBN,
+    EfficientNetBNFeatures,
+    drop_connect,
+    get_efficientnet_image_size,
+)
 from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py
index 3e30987bdc..e9f3b6d33e 100644
--- a/monai/networks/nets/densenet.py
+++ b/monai/networks/nets/densenet.py
@@ -19,6 +19,7 @@
 from monai.networks.layers.factories import Conv, Dropout, Pool
 from monai.networks.layers.utils import get_act_layer, get_norm_layer
+from monai.utils.module import look_up_option
 
 __all__ = [
     "DenseNet",
@@ -249,7 +250,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-def _load_state_dict(model, arch, progress):
+def _load_state_dict(model: nn.Module, arch: str, progress: bool):
     """
     This function is used to load pretrained models.
     Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
@@ -260,12 +261,12 @@
         "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
         "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
     }
-    if arch in model_urls:
-        model_url = model_urls[arch]
-    else:
+    model_url = look_up_option(arch, model_urls, None)
+    if model_url is None:
         raise ValueError(
             "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
         )
+
     pattern = re.compile(
         r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
     )
diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py
index cb8e195b04..453916758a 100644
--- a/monai/networks/nets/efficientnet.py
+++ b/monai/networks/nets/efficientnet.py
@@ -21,6 +21,7 @@
 from monai.networks.layers.factories import Act, Conv, Pad, Pool
 from monai.networks.layers.utils import get_norm_layer
+from monai.utils.module import look_up_option
 
 __all__ = ["EfficientNet", "EfficientNetBN", "get_efficientnet_image_size", "drop_connect"]
 
@@ -34,6 +35,29 @@
     "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
     "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
     "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
+    "efficientnet-b8": (2.2, 3.6, 672, 0.5, 0.2),
+    "efficientnet-l2": (4.3, 5.3, 800, 0.5, 0.2),
+}
+
+url_map = {
+    "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
+    "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
+    "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
+    "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
+    "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
+    "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
+    "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
+    "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
+    # trained with adversarial examples, simplify the name to decrease string length
+    "b0-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth",
+    "b1-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth",
+    "b2-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth",
+    "b3-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth",
+    "b4-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth",
+    "b5-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth",
+    "b6-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth",
+    "b7-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth",
+    "b8-ap": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth",
 }
 
@@ -140,7 +164,7 @@ def __init__(
         # swish activation to use - using memory efficient swish by default
         # can be switched to normal swish using self.set_swish() function call
-        self._swish = Act["memswish"]()
+        self._swish = Act["memswish"](inplace=True)
 
     def forward(self, inputs: torch.Tensor):
         """MBConvBlock"s forward function.
@@ -188,7 +212,7 @@ def set_swish(self, memory_efficient: bool = True) -> None:
         Args:
             memory_efficient (bool): Whether to use memory-efficient version of swish.
""" - self._swish = Act["memswish"]() if memory_efficient else Act["swish"](alpha=1.0) + self._swish = Act["memswish"](inplace=True) if memory_efficient else Act["swish"](alpha=1.0) class EfficientNet(nn.Module): @@ -208,8 +232,7 @@ def __init__( ) -> None: """ EfficientNet based on `Rethinking Model Scaling for Convolutional Neural Networks `_. - Adapted from `EfficientNet-PyTorch - `_. + Adapted from `EfficientNet-PyTorch `_. Args: blocks_args_str: block definitions. @@ -220,9 +243,10 @@ def __init__( depth_coefficient: depth multiplier coefficient (d in paper). dropout_rate: dropout rate for dropout layers. image_size: input image resolution. - norm: feature normalization type and arguments. Defaults to batch norm. + norm: feature normalization type and arguments. drop_connect_rate: dropconnect rate for drop connection (individual weights) layers. depth_divisor: depth divisor for channel rounding. + """ super().__init__() @@ -266,6 +290,8 @@ def __init__( num_blocks = 0 self._blocks = nn.Sequential() + self.extract_stacks = [] + # update baseline blocks to input/output filters and number of repeats based on width and depth multipliers. for idx, block_args in enumerate(self._blocks_args): block_args = block_args._replace( @@ -278,17 +304,23 @@ def __init__( # calculate the total number of blocks - needed for drop_connect estimation num_blocks += block_args.num_repeat + if block_args.stride > 1: + self.extract_stacks.append(idx) + + self.extract_stacks.append(len(self._blocks_args)) + # create and add MBConvBlocks to self._blocks idx = 0 # block index counter - for block_args in self._blocks_args: + for stack_idx, block_args in enumerate(self._blocks_args): blk_drop_connect_rate = self.drop_connect_rate # scale drop connect_rate if blk_drop_connect_rate: blk_drop_connect_rate *= float(idx) / num_blocks + sub_stack = nn.Sequential() # the first block needs to take care of stride and filter size increase. - self._blocks.add_module( + sub_stack.add_module( str(idx), MBConvBlock( spatial_dims=spatial_dims, @@ -319,7 +351,7 @@ def __init__( blk_drop_connect_rate *= float(idx) / num_blocks # add blocks - self._blocks.add_module( + sub_stack.add_module( str(idx), MBConvBlock( spatial_dims=spatial_dims, @@ -337,9 +369,14 @@ def __init__( ) idx += 1 # increment blocks index counter + self._blocks.add_module( + str(stack_idx), + sub_stack, + ) + # sanity check to see if len(self._blocks) equal expected num_blocks - if len(self._blocks) != num_blocks: - raise ValueError("number of blocks created != num_blocks") + if idx != num_blocks: + raise ValueError("total number of blocks created != num_blocks") # Head head_in_channels = block_args.output_filters @@ -369,8 +406,9 @@ def set_swish(self, memory_efficient: bool = True) -> None: """ self._swish = Act["memswish"]() if memory_efficient else Act["swish"](alpha=1.0) - for block in self._blocks: - block.set_swish(memory_efficient) + for sub_stack in self._blocks: + for block in sub_stack: + block.set_swish(memory_efficient) def forward(self, inputs: torch.Tensor): """ @@ -379,8 +417,7 @@ def forward(self, inputs: torch.Tensor): ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`. Returns: - A torch Tensor of classification prediction in shape - ``(Batch, num_classes)``. + a torch Tensor of classification prediction in shape ``(Batch, num_classes)``. 
""" # Stem x = self._conv_stem(self._conv_stem_padding(inputs)) @@ -436,21 +473,24 @@ def __init__( in_channels: int = 3, num_classes: int = 1000, norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + adv_prop: bool = False, ) -> None: """ Generic wrapper around EfficientNet, used to initialize EfficientNet-B0 to EfficientNet-B7 models model_name is mandatory argument as there is no EfficientNetBN itself, - it needs the N in [0, 1, 2, 3, 4, 5, 6, 7] to be a model + it needs the N in [0, 1, 2, 3, 4, 5, 6, 7, 8] to be a model Args: - model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b7]. + model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2]. pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2 and batch norm is used. progress: whether to show download progress for pretrained weights download. spatial_dims: number of spatial dimensions. in_channels: number of input channels. num_classes: number of output classes. - norm: feature normalization type and arguments. Defaults to batch norm. + norm: feature normalization type and arguments. + adv_prop: whether to use weights trained with adversarial examples. + This argument only works when `pretrained` is `True`. Examples:: @@ -471,7 +511,7 @@ def __init__( >>> model = EfficientNetBN("efficientnet-b7", spatial_dims=2) """ - # block args for EfficientNet-B0 to EfficientNet-B7 + # block args blocks_args_str = [ "r1_k3_s11_e1_i32_o16_se0.25", "r2_k3_s22_e6_i16_o24_se0.25", @@ -507,16 +547,91 @@ def __init__( norm=norm, ) - # attempt to load pretrained - is_default_model = (spatial_dims == 2) and (in_channels == 3) - loadable_from_file = pretrained and is_default_model + # only pretrained for when `spatial_dims` is 2 + if pretrained and (spatial_dims == 2): + _load_state_dict(self, model_name, progress, adv_prop) + + +class EfficientNetBNFeatures(EfficientNet): + def __init__( + self, + model_name: str, + pretrained: bool = True, + progress: bool = True, + spatial_dims: int = 2, + in_channels: int = 3, + num_classes: int = 1000, + norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.01}), + adv_prop: bool = False, + ) -> None: + """ + Initialize EfficientNet-B0 to EfficientNet-B7 models as a backbone, the backbone can + be used as an encoder for segmentation and objection models. + Compared with the class `EfficientNetBN`, the only different place is the forward function. + + This class refers to `PyTorch image models `_. 
-        if loadable_from_file:
-            # skip loading fc layers for transfer learning applications
-            load_fc = num_classes == 1000
+        """
+        blocks_args_str = [
+            "r1_k3_s11_e1_i32_o16_se0.25",
+            "r2_k3_s22_e6_i16_o24_se0.25",
+            "r2_k5_s22_e6_i24_o40_se0.25",
+            "r3_k3_s22_e6_i40_o80_se0.25",
+            "r3_k5_s11_e6_i80_o112_se0.25",
+            "r4_k5_s22_e6_i112_o192_se0.25",
+            "r1_k3_s11_e6_i192_o320_se0.25",
+        ]
+
+        # check if model_name is valid model
+        if model_name not in efficientnet_params.keys():
+            raise ValueError(
+                "invalid model_name {} found, must be one of {} ".format(
+                    model_name, ", ".join(efficientnet_params.keys())
+                )
+            )
+
+        # get network parameters
+        weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]
+
+        # create model and initialize random weights
+        super(EfficientNetBNFeatures, self).__init__(
+            blocks_args_str=blocks_args_str,
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            num_classes=num_classes,
+            width_coefficient=weight_coeff,
+            depth_coefficient=depth_coeff,
+            dropout_rate=dropout_rate,
+            image_size=image_size,
+            drop_connect_rate=dropconnect_rate,
+            norm=norm,
+        )
+
+        # only pretrained for when `spatial_dims` is 2
+        if pretrained and (spatial_dims == 2):
+            _load_state_dict(self, model_name, progress, adv_prop)
+
+    def forward(self, inputs: torch.Tensor):
+        """
+        Args:
+            inputs: input should have spatially N dimensions
+                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.
+
+        Returns:
+            a list of torch Tensors.
+        """
+        # Stem
+        x = self._conv_stem(self._conv_stem_padding(inputs))
+        x = self._swish(self._bn0(x))
 
-            # only pretrained for when `spatial_dims` is 2
-            _load_state_dict(self, model_name, progress, load_fc)
+        features = []
+        if 0 in self.extract_stacks:
+            features.append(x)
+        for i, block in enumerate(self._blocks):
+            x = block(x)
+            if i + 1 in self.extract_stacks:
+                features.append(x)
+        return features
 
 
 def get_efficientnet_image_size(model_name: str) -> int:
@@ -588,38 +703,25 @@ def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor:
     return output
 
 
-def _load_state_dict(model: nn.Module, model_name: str, progress: bool, load_fc: bool) -> None:
-    url_map = {
-        "efficientnet-b0": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
-        "efficientnet-b1": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
-        "efficientnet-b2": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
-        "efficientnet-b3": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
-        "efficientnet-b4": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
-        "efficientnet-b5": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
-        "efficientnet-b6": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
-        "efficientnet-b7": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
-    }
-    # load state dict from url
-    model_url = url_map[model_name]
-    state_dict = model_zoo.load_url(model_url, progress=progress)
-
-    # load state dict into model parameters
-    if load_fc:  # load everything
-        ret = model.load_state_dict(state_dict, strict=False)
-        if ret.missing_keys:
-            raise ValueError("Found missing keys when loading pretrained weights: {}".format(ret.missing_keys))
-    else:  # skip final FC layers, for transfer learning cases
-        state_dict.pop("_fc.weight")
-        state_dict.pop("_fc.bias")
-        ret = model.load_state_dict(state_dict, strict=False)
-
-        # check if no other keys missing except FC layer parameters
-        if set(ret.missing_keys) != {"_fc.weight", "_fc.bias"}:
-            raise ValueError("Found missing keys when loading pretrained weights: {}".format(ret.missing_keys))
-
-    # check for any unexpected keys
-    if ret.unexpected_keys:
-        raise ValueError("Missing keys when loading pretrained weights: {}".format(ret.unexpected_keys))
+def _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool) -> None:
+    if adv_prop:
+        arch = arch.split("efficientnet-")[-1] + "-ap"
+    model_url = look_up_option(arch, url_map, None)
+    if model_url is None:
+        print("pretrained weights of {} are not provided".format(arch))
+    else:
+        # load state dict from url
+        model_url = url_map[arch]
+        pretrain_state_dict = model_zoo.load_url(model_url, progress=progress)
+        model_state_dict = model.state_dict()
+
+        pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)")
+        for key, value in model_state_dict.items():
+            pretrain_key = re.sub(pattern, r"\1\2", key)
+            if pretrain_key in pretrain_state_dict and value.shape == pretrain_state_dict[pretrain_key].shape:
+                model_state_dict[key] = pretrain_state_dict[pretrain_key]
+
+        model.load_state_dict(model_state_dict)
 
 
 def _get_same_padding_conv_nd(
diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py
index 7292b2a1d5..9b7035c259 100644
--- a/monai/networks/nets/senet.py
+++ b/monai/networks/nets/senet.py
@@ -20,6 +20,7 @@
 from monai.networks.blocks.convolutions import Convolution
 from monai.networks.blocks.squeeze_and_excitation import SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck
 from monai.networks.layers.factories import Act, Conv, Dropout, Norm, Pool
+from monai.utils.module import look_up_option
 
 __all__ = ["SENet", "SENet154", "SEResNet50", "SEResNet101", "SEResNet152", "SEResNeXt50", "SEResNext101"]
 
@@ -249,7 +250,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-def _load_state_dict(model, arch, progress):
+def _load_state_dict(model: nn.Module, arch: str, progress: bool):
     """
     This function is used to load pretrained models.
     """
@@ -261,9 +262,8 @@
         "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth",
         "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth",
     }
-    if arch in model_urls:
-        model_url = model_urls[arch]
-    else:
+    model_url = look_up_option(arch, model_urls, None)
+    if model_url is None:
         raise ValueError(
             "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', "
             + "and se_resnext101_32x4d are supported to load pretrained weights."
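
Reviewer note on the weight-loading change above: the new `_load_state_dict` in `monai/networks/nets/efficientnet.py` remaps the model's nested `_blocks.<stack idx>.<global block idx>` keys onto the flat `_blocks.<block idx>` keys used by the released checkpoints via the `(.+)\.\d+(\.\d+\..+)` pattern. Below is a minimal standalone sketch of that remapping; the parameter names are hypothetical and only chosen to illustrate the rename, they are not taken from the checkpoints.

import re

# same pattern as the new efficientnet._load_state_dict above
pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)")

# hypothetical model keys following the new nested layout: _blocks.<stack idx>.<global block idx>.<param>
model_keys = [
    "_blocks.0.0._depthwise_conv.conv.weight",
    "_blocks.1.2._expand_conv.conv.weight",
    "_conv_stem.conv.weight",  # no stack/block index pair, left unchanged
]

for key in model_keys:
    # drop the stack-level index so the key matches the flat layout of the pretrained checkpoints
    pretrain_key = re.sub(pattern, r"\1\2", key)
    print(key, "->", pretrain_key)

# _blocks.0.0._depthwise_conv.conv.weight -> _blocks.0._depthwise_conv.conv.weight
# _blocks.1.2._expand_conv.conv.weight -> _blocks.2._expand_conv.conv.weight
# _conv_stem.conv.weight -> _conv_stem.conv.weight
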
diff --git a/tests/test_efficientnet.py b/tests/test_efficientnet.py
index 6567e3af9a..6befba108a 100644
--- a/tests/test_efficientnet.py
+++ b/tests/test_efficientnet.py
@@ -18,7 +18,13 @@
 from parameterized import parameterized
 
 from monai.networks import eval_mode
-from monai.networks.nets import BlockArgs, EfficientNetBN, drop_connect, get_efficientnet_image_size
+from monai.networks.nets import (
+    BlockArgs,
+    EfficientNetBN,
+    EfficientNetBNFeatures,
+    drop_connect,
+    get_efficientnet_image_size,
+)
 from monai.utils import optional_import
 from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save
 
@@ -156,6 +162,7 @@ def make_shape_cases(
             "in_channels": 3,
             "num_classes": 1000,
             "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
+            "adv_prop": False,
         },
         os.path.join(os.path.dirname(__file__), "testing_data", "kitty_test.jpg"),
         282,  # ~ tiger cat
@@ -226,6 +233,21 @@ def make_shape_cases(
     )
 )
 
+CASE_EXTRACT_FEATURES = [
+    (
+        {
+            "model_name": "efficientnet-b8",
+            "pretrained": True,
+            "progress": False,
+            "spatial_dims": 2,
+            "in_channels": 2,
+            "adv_prop": True,
+        },
+        [1, 2, 224, 224],
+        ([1, 32, 112, 112], [1, 56, 56, 56], [1, 88, 28, 28], [1, 248, 14, 14], [1, 704, 7, 7]),
+    ),
+]
+
 
 class TestEFFICIENTNET(unittest.TestCase):
     @parameterized.expand(CASES_1D + CASES_2D + CASES_3D + CASES_VARIATIONS)
@@ -355,5 +377,23 @@ def test_script(self):
         test_script_save(net, test_data)
 
 
+class TestExtractFeatures(unittest.TestCase):
+    @parameterized.expand(CASE_EXTRACT_FEATURES)
+    def test_shape(self, input_param, input_shape, expected_shapes):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # initialize model
+        net = EfficientNetBNFeatures(**input_param).to(device)
+
+        # run inference with random tensor
+        with eval_mode(net):
+            features = net(torch.randn(input_shape).to(device))
+
+        # check output shape
+        self.assertEqual(len(features), len(expected_shapes))
+        for feature, expected_shape in zip(features, expected_shapes):
+            self.assertEqual(feature.shape, torch.Size(expected_shape))
+
+
 if __name__ == "__main__":
     unittest.main()
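
For reference, a minimal usage sketch of the feature-extraction entry point added by this patch, contrasting it with the existing classification wrapper. It mirrors what `TestExtractFeatures` exercises; the model name, input size and `pretrained=False` flag are illustrative choices, not values mandated by the change.

import torch

from monai.networks.nets import EfficientNetBN, EfficientNetBNFeatures

# classification wrapper: a single logits tensor of shape (Batch, num_classes)
clf = EfficientNetBN("efficientnet-b0", pretrained=False, spatial_dims=2, in_channels=3, num_classes=10)
logits = clf(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10])

# backbone wrapper: one tensor per entry in `extract_stacks`, usable as encoder/skip features
backbone = EfficientNetBNFeatures("efficientnet-b0", pretrained=False, spatial_dims=2, in_channels=3)
features = backbone(torch.randn(1, 3, 224, 224))
for feat in features:
    # spatial sizes shrink stack by stack, e.g. 112, 56, 28, 14, 7 for a 224 input
    print(feat.shape)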