From 002bc27e20c1fc38e625b20f86b3001d750a3e55 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 12 Nov 2021 10:18:28 +0800 Subject: [PATCH 1/6] 3251 Add dependency check in WSIReader (#3312) * [DLMED] add dep check Signed-off-by: Nic Ma * [DLMED] fix typo Signed-off-by: Nic Ma * [DLMED] update according to comments Signed-off-by: Nic Ma Signed-off-by: myron --- monai/data/image_reader.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index c7d77e0781..4830b56aa8 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -37,6 +37,10 @@ Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") PILImage, has_pil = optional_import("PIL.Image") +OpenSlide, _ = optional_import("openslide", name="OpenSlide") +CuImage, _ = optional_import("cucim", name="CuImage") +TiffFile, _ = optional_import("tifffile", name="TiffFile") + __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"] @@ -690,16 +694,20 @@ class WSIReader(ImageReader): def __init__(self, backend: str = "OpenSlide", level: int = 0): super().__init__() self.backend = backend.lower() - if self.backend == "openslide": - self.wsi_reader, *_ = optional_import("openslide", name="OpenSlide") - elif self.backend == "cucim": - self.wsi_reader, *_ = optional_import("cucim", name="CuImage") - elif self.backend == "tifffile": - self.wsi_reader, *_ = optional_import("tifffile", name="TiffFile") - else: - raise ValueError('`backend` should be "cuCIM", "OpenSlide", or "TiffFile') + func = require_pkg(self.backend)(self._set_reader) + self.wsi_reader = func(self.backend) self.level = level + @staticmethod + def _set_reader(backend: str): + if backend == "openslide": + return OpenSlide + if backend == "cucim": + return CuImage + if backend == "tifffile": + return TiffFile + raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.") + def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. 
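The effect of the patch above is that the heavy WSI backends are imported once at module load through `optional_import`, and the dependency check only fires when a particular backend is actually requested. A minimal self-contained sketch of that pattern, assuming simplified stand-ins for `optional_import` and the reader classes (illustrative only, not MONAI's exact implementation):

    import importlib
    from typing import Any, Optional, Tuple

    def optional_import(module: str, name: str = "") -> Tuple[Optional[Any], bool]:
        # Import `module` (optionally one attribute from it) if available;
        # return (object, True) on success, or (None, False) instead of raising.
        try:
            mod = importlib.import_module(module)
            return (getattr(mod, name) if name else mod), True
        except (ImportError, AttributeError):
            return None, False

    # imported once at module load; errors are deferred until use
    OpenSlide, has_openslide = optional_import("openslide", name="OpenSlide")
    CuImage, has_cucim = optional_import("cucim", name="CuImage")
    TiffFile, has_tifffile = optional_import("tifffile", name="TiffFile")

    def select_wsi_reader(backend: str) -> Any:
        # Map a backend name to its reader class, raising a clear error
        # only when the requested backend's package is missing.
        registry = {
            "openslide": (OpenSlide, has_openslide),
            "cucim": (CuImage, has_cucim),
            "tifffile": (TiffFile, has_tifffile),
        }
        backend = backend.lower()
        if backend not in registry:
            raise ValueError("`backend` should be 'cuCIM', 'OpenSlide' or 'TiffFile'.")
        reader, available = registry[backend]
        if not available:
            raise ImportError(f"package for backend '{backend}' is not installed.")
        return reader

With this check in place, constructing the reader with an unavailable backend fails immediately with an informative message instead of failing deep inside a read call.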
From bd113cd6c5a9e44eea8bec090dee86ec45721883 Mon Sep 17 00:00:00 2001 From: myron Date: Mon, 1 Nov 2021 16:51:11 -0700 Subject: [PATCH 2/6] [DLMED] MILmodel PR Signed-off-by: myron --- monai/networks/nets/__init__.py | 1 + monai/networks/nets/milmodel.py | 226 ++++++++++++++++++++++++++++++++ tests/min_tests.py | 1 + tests/test_milmodel.py | 92 +++++++++++++ 4 files changed, 320 insertions(+) create mode 100644 monai/networks/nets/milmodel.py create mode 100644 tests/test_milmodel.py diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index a07297be13..514a7ed501 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -41,6 +41,7 @@ from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator from .highresnet import HighResBlock, HighResNet +from .milmodel import MilMode, MILModel from .netadapter import NetAdapter from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py new file mode 100644 index 0000000000..1e3d680ebc --- /dev/null +++ b/monai/networks/nets/milmodel.py @@ -0,0 +1,226 @@ +from enum import Enum +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn + +from monai.utils.module import optional_import + +models, _ = optional_import("torchvision.models") + + +class MilMode(Enum): + MEAN = "mean" + MAX = "max" + ATT = "att" + ATT_TRANS = "att_trans" + ATT_TRANS_PYRAMID = "att_trans_pyramid" + + +class MILModel(nn.Module): + """ + A wrapper around backbone classification model suitable for MIL + + Args: + num_classes: number of output classes + mil_mode: MIL variant (supported max, mean, att, att_trans, att_trans_pyramid + pretrained: init backbone with pretrained weights. Defaults to True. + backbone: Backbone classifier CNN. Defaults to None, it which case ResNet50 will be used. 
+ backbone_nfeatures: Number of output featues of the backbone CNN (necessary only when using custom backbone) + + mil_mode: + MilMode.MEAN - average features from all instances, equivalent to pure CNN (non MIL) + MilMode.MAX - retain only the instance with the max probability for loss calculation + MilMode.ATT - attention based MIL https://arxiv.org/abs/1802.04712 + MilMode.ATT_TRANS - transformer MIL https://arxiv.org/abs/2111.01556 + MilMode.ATT_TRANS_PYRAMID - transformer pyramid MIL https://arxiv.org/abs/2111.01556 + + """ + + def __init__( + self, + num_classes: int, + mil_mode: MilMode = MilMode.ATT, + pretrained: bool = True, + backbone: Optional[Union[str, nn.Module]] = None, + backbone_nfeatures: Optional[int] = None, + trans_blocks: int = 4, + trans_dropout: float = 0.0, + ) -> None: + + super().__init__() + + if num_classes <= 0: + raise ValueError("Number of classes must be positive: " + str(num_classes)) + + self.mil_mode = mil_mode + print("MILModel with mode", mil_mode, "num_classes", num_classes) + self.attention = nn.Sequential() + self.transformer = None + + if backbone is None: + + net = models.resnet50(pretrained=pretrained) + nfc = net.fc.in_features # save the number of final features + net.fc = torch.nn.Identity() # remove final linear layer + + self.extra_outputs = {} # type: Dict[str, torch.Tensor] + + if mil_mode == MilMode.ATT_TRANS_PYRAMID: + # register hooks to capture outputs of intermediate layers + def forward_hook(layer_name): + def hook(module, input, output): + self.extra_outputs[layer_name] = output + + return hook + + net.layer1.register_forward_hook(forward_hook("layer1")) + net.layer2.register_forward_hook(forward_hook("layer2")) + net.layer3.register_forward_hook(forward_hook("layer3")) + net.layer4.register_forward_hook(forward_hook("layer4")) + + elif isinstance(backbone, str): + + # assume torchvision model string is provided + trch_model = getattr(models, backbone, None) + if trch_model is None: + raise ValueError("Unknown torch vision model" + str(backbone)) + net = trch_model(pretrained=pretrained) + + if getattr(net, "fc", None) is not None: + nfc = net.fc.in_features # save the number of final features + net.fc = torch.nn.Identity() # remove final linear layer + else: + raise ValueError( + "Unable to detect FC layer for torch vision model " + str(backbone), + ". 
Please initialize the backbone model manually.", + ) + + else: + # use a custom backbone (untested) + net = backbone + nfc = backbone_nfeatures + + if backbone_nfeatures is None: + raise ValueError("Number of endencoder features must be provided for a custom backbone model") + + if backbone is not None and mil_mode not in [MilMode.MEAN, MilMode.MAX, MilMode.ATT, MilMode.ATT_TRANS]: + raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode)) + + if self.mil_mode in [MilMode.MEAN, MilMode.MAX]: + pass + elif self.mil_mode == MilMode.ATT: + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + elif self.mil_mode == MilMode.ATT_TRANS: + transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout) + self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks) + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + elif self.mil_mode == MilMode.ATT_TRANS_PYRAMID: + + transformer_list = nn.ModuleList( + [ + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks + ), + nn.Sequential( + nn.Linear(768, 256), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ), + nn.Sequential( + nn.Linear(1280, 256), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ), + nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout), + num_layers=trans_blocks, + ), + ] + ) + self.transformer = transformer_list # type: ignore + nfc = nfc + 256 + self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) + + else: + raise ValueError("Unsupported mil_mode: " + str(mil_mode)) + + self.myfc = nn.Linear(nfc, num_classes) + self.net = net + + def calc_head(self, x: torch.Tensor) -> torch.Tensor: + + sh = x.shape + + if self.mil_mode == MilMode.MEAN: + x = self.myfc(x) + x = torch.mean(x, dim=1) + + elif self.mil_mode == MilMode.MAX: + x = self.myfc(x) + x, _ = torch.max(x, dim=1) + + elif self.mil_mode == MilMode.ATT: + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + elif self.mil_mode == MilMode.ATT_TRANS and self.transformer is not None: + + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + elif self.mil_mode == MilMode.ATT_TRANS_PYRAMID and self.transformer is not None: + + l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) + + transformer_list: List = self.transformer # type: ignore + x = transformer_list[0](l1) + x = transformer_list[1](torch.cat((x, l2), dim=2)) + x = transformer_list[2](torch.cat((x, l3), dim=2)) + x = transformer_list[3](torch.cat((x, l4), dim=2)) + + x = x.permute(1, 0, 2) + + a = self.attention(x) + a = torch.softmax(a, dim=1) + x = torch.sum(x * a, dim=1) + + x = self.myfc(x) + + else: + raise ValueError("Wrong model mode" + 
str(self.mil_mode)) + + return x + + def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor: + + sh = x.shape + x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4]) + + x = self.net(x) + x = x.reshape(sh[0], sh[1], -1) + + if not no_head: + x = self.calc_head(x) + + return x diff --git a/tests/min_tests.py b/tests/min_tests.py index d1f5384c2c..829f67c485 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -104,6 +104,7 @@ def run_testsuit(): "test_load_imaged", "test_load_spacing_orientation", "test_mednistdataset", + "test_milmodel", "test_mlp", "test_nifti_header_revise", "test_nifti_rw", diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py new file mode 100644 index 0000000000..92be812176 --- /dev/null +++ b/tests/test_milmodel.py @@ -0,0 +1,92 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import MilMode, MILModel +from monai.utils.module import optional_import +from tests.utils import test_script_save + +models, _ = optional_import("torchvision.models") + + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +TEST_CASE_MILMODEL = [] +for num_classes in [1, 5]: + for mil_mode in [MilMode.MEAN, MilMode.MAX, MilMode.ATT, MilMode.ATT_TRANS, MilMode.ATT_TRANS_PYRAMID]: + test_case = [ + {"num_classes": num_classes, "mil_mode": mil_mode, "pretrained": False}, + (1, 2, 3, 512, 512), + (1, num_classes), + ] + TEST_CASE_MILMODEL.append(test_case) + + +for trans_blocks in [1, 3]: + test_case = [ + {"num_classes": 5, "pretrained": False, "trans_blocks": trans_blocks, "trans_dropout": 0.5}, + (1, 2, 3, 512, 512), + (1, 5), + ] + TEST_CASE_MILMODEL.append(test_case) + +# torchvision backbone +TEST_CASE_MILMODEL.append( + [{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)] +) +TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)]) + +# custom backbone +backbone = models.densenet121(pretrained=False) +backbone_nfeatures = backbone.classifier.in_features +backbone.classifier = torch.nn.Identity() +TEST_CASE_MILMODEL.append( + [ + {"num_classes": 5, "backbone": backbone, "backbone_nfeatures": backbone_nfeatures, "pretrained": False}, + (2, 2, 3, 512, 512), + (2, 5), + ] +) + + +class TestResNet(unittest.TestCase): + @parameterized.expand(TEST_CASE_MILMODEL) + def test_shape(self, input_param, input_shape, expected_shape): + net = MILModel(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape, dtype=torch.float).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_args(self): + with self.assertRaises(ValueError): + MILModel( + num_classes=5, + pretrained=False, + backbone="resnet50", + backbone_nfeatures=2048, + mil_mode=MilMode.ATT_TRANS_PYRAMID, + ) + + def test_script(self): + input_param, input_shape, expected_shape = 
TEST_CASE_MILMODEL[0] + net = MILModel(**input_param) + test_data = torch.randn(input_shape, dtype=torch.float) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() From 0a220b8f522eeccd464555e58e5b7b457160ffb0 Mon Sep 17 00:00:00 2001 From: myron Date: Fri, 12 Nov 2021 01:07:27 -0800 Subject: [PATCH 3/6] small updates Signed-off-by: myron --- monai/networks/nets/milmodel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index 1e3d680ebc..fc3b258111 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -56,7 +56,7 @@ def __init__( self.mil_mode = mil_mode print("MILModel with mode", mil_mode, "num_classes", num_classes) self.attention = nn.Sequential() - self.transformer = None + self.transformer = nn.Module() if backbone is None: @@ -144,7 +144,7 @@ def hook(module, input, output): ), ] ) - self.transformer = transformer_list # type: ignore + self.transformer = transformer_list nfc = nfc + 256 self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) From d6859a8fdadede0227c5919e0e96fa4feb503e40 Mon Sep 17 00:00:00 2001 From: myron Date: Fri, 12 Nov 2021 10:47:46 -0800 Subject: [PATCH 4/6] fix jit issues Signed-off-by: myron --- monai/networks/nets/milmodel.py | 2 +- tests/test_milmodel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index fc3b258111..175264b36f 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -56,7 +56,7 @@ def __init__( self.mil_mode = mil_mode print("MILModel with mode", mil_mode, "num_classes", num_classes) self.attention = nn.Sequential() - self.transformer = nn.Module() + self.transformer = None # type: Optional[nn.Module] if backbone is None: diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py index 92be812176..af055ac60e 100644 --- a/tests/test_milmodel.py +++ b/tests/test_milmodel.py @@ -63,7 +63,7 @@ ) -class TestResNet(unittest.TestCase): +class TestMilModel(unittest.TestCase): @parameterized.expand(TEST_CASE_MILMODEL) def test_shape(self, input_param, input_shape, expected_shape): net = MILModel(**input_param).to(device) From a4d9ff4913a7f1e01955604feaa732d908e63f6a Mon Sep 17 00:00:00 2001 From: myron Date: Fri, 12 Nov 2021 12:30:31 -0800 Subject: [PATCH 5/6] jit fix attempt Signed-off-by: myron --- monai/networks/nets/milmodel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index 175264b36f..a28f9002a9 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, cast import torch import torch.nn as nn @@ -193,7 +193,8 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) - transformer_list: List = self.transformer # type: ignore + transformer_list = cast(nn.ModuleList, self.transformer) + x = transformer_list[0](l1) x = transformer_list[1](torch.cat((x, l2), dim=2)) x = transformer_list[2](torch.cat((x, l3), dim=2)) From 51c76d4e6492dedcb6990bad4311eab3179e6e72 Mon Sep 17 00:00:00 2001 From: myron Date: Sat, 13 Nov 
2021 16:17:41 -0800 Subject: [PATCH 6/6] removing Enum Signed-off-by: myron --- monai/networks/nets/__init__.py | 2 +- monai/networks/nets/milmodel.py | 84 +++++++++++++++++---------------- tests/test_milmodel.py | 11 ++--- 3 files changed, 49 insertions(+), 48 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 514a7ed501..f7348b7b44 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -41,7 +41,7 @@ from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator from .highresnet import HighResBlock, HighResNet -from .milmodel import MilMode, MILModel +from .milmodel import MILModel from .netadapter import NetAdapter from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index a28f9002a9..213a86864f 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import Dict, Optional, Union, cast import torch @@ -9,41 +8,38 @@ models, _ = optional_import("torchvision.models") -class MilMode(Enum): - MEAN = "mean" - MAX = "max" - ATT = "att" - ATT_TRANS = "att_trans" - ATT_TRANS_PYRAMID = "att_trans_pyramid" - - class MILModel(nn.Module): """ - A wrapper around backbone classification model suitable for MIL + Multiple Instance Learning (MIL) model, with a backbone classification model. Args: num_classes: number of output classes - mil_mode: MIL variant (supported max, mean, att, att_trans, att_trans_pyramid - pretrained: init backbone with pretrained weights. Defaults to True. - backbone: Backbone classifier CNN. Defaults to None, it which case ResNet50 will be used. - backbone_nfeatures: Number of output featues of the backbone CNN (necessary only when using custom backbone) + mil_mode: MIL algorithm (either mean, max, att, att_trans, att_trans_pyramid) + Defaults to ``att``. + pretrained: init backbone with pretrained weights. + Defaults to ``True``. + backbone: Backbone classifier CNN. (either None, nn.Module that returns features, + or a string name of a torchvision model) + Defaults to ``None``, in which case ResNet50 is used. 
+ backbone_num_features: Number of output features of the backbone CNN + Defaults to ``None`` (necessary only when using a custom backbone) mil_mode: - MilMode.MEAN - average features from all instances, equivalent to pure CNN (non MIL) - MilMode.MAX - retain only the instance with the max probability for loss calculation - MilMode.ATT - attention based MIL https://arxiv.org/abs/1802.04712 - MilMode.ATT_TRANS - transformer MIL https://arxiv.org/abs/2111.01556 - MilMode.ATT_TRANS_PYRAMID - transformer pyramid MIL https://arxiv.org/abs/2111.01556 + "mean" - average features from all instances, equivalent to pure CNN (non MIL) + "max - retain only the instance with the max probability for loss calculation + "att" - attention based MIL https://arxiv.org/abs/1802.04712 + "att_trans" - transformer MIL https://arxiv.org/abs/2111.01556 + "att_trans_pyramid" - transformer pyramid MIL https://arxiv.org/abs/2111.01556 """ def __init__( self, num_classes: int, - mil_mode: MilMode = MilMode.ATT, + mil_mode: str = "att", pretrained: bool = True, backbone: Optional[Union[str, nn.Module]] = None, - backbone_nfeatures: Optional[int] = None, + backbone_num_features: Optional[int] = None, trans_blocks: int = 4, trans_dropout: float = 0.0, ) -> None: @@ -53,7 +49,10 @@ def __init__( if num_classes <= 0: raise ValueError("Number of classes must be positive: " + str(num_classes)) - self.mil_mode = mil_mode + if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]: + raise ValueError("Unsupported mil_mode: " + str(mil_mode)) + + self.mil_mode = mil_mode.lower() print("MILModel with mode", mil_mode, "num_classes", num_classes) self.attention = nn.Sequential() self.transformer = None # type: Optional[nn.Module] @@ -66,7 +65,7 @@ def __init__( self.extra_outputs = {} # type: Dict[str, torch.Tensor] - if mil_mode == MilMode.ATT_TRANS_PYRAMID: + if mil_mode == "att_trans_pyramid": # register hooks to capture outputs of intermediate layers def forward_hook(layer_name): def hook(module, input, output): @@ -82,42 +81,45 @@ def hook(module, input, output): elif isinstance(backbone, str): # assume torchvision model string is provided - trch_model = getattr(models, backbone, None) - if trch_model is None: + torch_model = getattr(models, backbone, None) + if torch_model is None: raise ValueError("Unknown torch vision model" + str(backbone)) - net = trch_model(pretrained=pretrained) + net = torch_model(pretrained=pretrained) if getattr(net, "fc", None) is not None: nfc = net.fc.in_features # save the number of final features net.fc = torch.nn.Identity() # remove final linear layer else: raise ValueError( - "Unable to detect FC layer for torch vision model " + str(backbone), + "Unable to detect FC layer for the torchvision model " + str(backbone), ". 
Please initialize the backbone model manually.", ) - else: - # use a custom backbone (untested) + elif isinstance(backbone, nn.Module): + # use a custom backbone net = backbone - nfc = backbone_nfeatures + nfc = backbone_num_features - if backbone_nfeatures is None: + if backbone_num_features is None: raise ValueError("Number of endencoder features must be provided for a custom backbone model") - if backbone is not None and mil_mode not in [MilMode.MEAN, MilMode.MAX, MilMode.ATT, MilMode.ATT_TRANS]: + else: + raise ValueError("Unsupported backbone") + + if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]: raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode)) - if self.mil_mode in [MilMode.MEAN, MilMode.MAX]: + if self.mil_mode in ["mean", "max"]: pass - elif self.mil_mode == MilMode.ATT: + elif self.mil_mode == "att": self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) - elif self.mil_mode == MilMode.ATT_TRANS: + elif self.mil_mode == "att_trans": transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout) self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks) self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1)) - elif self.mil_mode == MilMode.ATT_TRANS_PYRAMID: + elif self.mil_mode == "att_trans_pyramid": transformer_list = nn.ModuleList( [ @@ -158,15 +160,15 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: sh = x.shape - if self.mil_mode == MilMode.MEAN: + if self.mil_mode == "mean": x = self.myfc(x) x = torch.mean(x, dim=1) - elif self.mil_mode == MilMode.MAX: + elif self.mil_mode == "max": x = self.myfc(x) x, _ = torch.max(x, dim=1) - elif self.mil_mode == MilMode.ATT: + elif self.mil_mode == "att": a = self.attention(x) a = torch.softmax(a, dim=1) @@ -174,7 +176,7 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: x = self.myfc(x) - elif self.mil_mode == MilMode.ATT_TRANS and self.transformer is not None: + elif self.mil_mode == "att_trans" and self.transformer is not None: x = x.permute(1, 0, 2) x = self.transformer(x) @@ -186,7 +188,7 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor: x = self.myfc(x) - elif self.mil_mode == MilMode.ATT_TRANS_PYRAMID and self.transformer is not None: + elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None: l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2) diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py index af055ac60e..9b21d4e2d1 100644 --- a/tests/test_milmodel.py +++ b/tests/test_milmodel.py @@ -15,19 +15,18 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import MilMode, MILModel +from monai.networks.nets import MILModel from monai.utils.module import optional_import from tests.utils import test_script_save models, _ = optional_import("torchvision.models") - device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_MILMODEL = [] for num_classes in [1, 5]: - for mil_mode in [MilMode.MEAN, MilMode.MAX, MilMode.ATT, MilMode.ATT_TRANS, MilMode.ATT_TRANS_PYRAMID]: + for mil_mode in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]: test_case = [ {"num_classes": num_classes, "mil_mode": mil_mode, "pretrained": False}, (1, 2, 3, 512, 512), @@ -56,7 +55,7 @@ backbone.classifier = torch.nn.Identity() 
TEST_CASE_MILMODEL.append( [ - {"num_classes": 5, "backbone": backbone, "backbone_nfeatures": backbone_nfeatures, "pretrained": False}, + {"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False}, (2, 2, 3, 512, 512), (2, 5), ] @@ -77,8 +76,8 @@ def test_ill_args(self): num_classes=5, pretrained=False, backbone="resnet50", - backbone_nfeatures=2048, - mil_mode=MilMode.ATT_TRANS_PYRAMID, + backbone_num_features=2048, + mil_mode="att_trans_pyramid", ) def test_script(self):
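Taken together, the series leaves MILModel with a plain-string `mil_mode` argument and a 5D input layout of batch x instances x channel x height x width. A short usage sketch mirroring the shapes exercised by the tests above (the feature size 2048 assumes the default ResNet50 backbone):

    import torch
    from monai.networks.nets import MILModel

    # one bag of 2 instances, each a 3-channel 512x512 tile
    model = MILModel(num_classes=5, mil_mode="att", pretrained=False)
    x = torch.randn(1, 2, 3, 512, 512)

    logits = model(x)               # bag-level prediction, shape (1, 5)
    feats = model(x, no_head=True)  # per-instance backbone features, shape (1, 2, 2048)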