diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index a07297be13..f7348b7b44 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -41,6 +41,7 @@
 from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
+from .milmodel import MILModel
 from .netadapter import NetAdapter
 from .regressor import Regressor
 from .regunet import GlobalNet, LocalNet, RegUNet
diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py
new file mode 100644
index 0000000000..213a86864f
--- /dev/null
+++ b/monai/networks/nets/milmodel.py
@@ -0,0 +1,229 @@
+from typing import Dict, Optional, Union, cast
+
+import torch
+import torch.nn as nn
+
+from monai.utils.module import optional_import
+
+models, _ = optional_import("torchvision.models")
+
+
+class MILModel(nn.Module):
+    """
+    Multiple Instance Learning (MIL) model, with a backbone classification model.
+
+    Args:
+        num_classes: number of output classes.
+        mil_mode: MIL algorithm (one of ``mean``, ``max``, ``att``, ``att_trans``, ``att_trans_pyramid``).
+            Defaults to ``att``.
+        pretrained: init backbone with pretrained weights. Defaults to ``True``.
+        backbone: backbone classifier CNN (either ``None``, an ``nn.Module`` that returns features,
+            or a string name of a torchvision model). Defaults to ``None``, in which case ResNet50 is used.
+        backbone_num_features: number of output features of the backbone CNN.
+            Defaults to ``None`` (necessary only when using a custom backbone).
+        trans_blocks: number of the blocks in the transformer encoder. Defaults to ``4``.
+        trans_dropout: dropout rate in the transformer. Defaults to ``0.0``.
+
+    mil_mode:
+        "mean" - average features from all instances, equivalent to a pure CNN (non-MIL).
+        "max" - retain only the instance with the max probability for loss calculation.
+        "att" - attention-based MIL https://arxiv.org/abs/1802.04712
+        "att_trans" - transformer MIL https://arxiv.org/abs/2111.01556
+        "att_trans_pyramid" - transformer pyramid MIL https://arxiv.org/abs/2111.01556
+
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        mil_mode: str = "att",
+        pretrained: bool = True,
+        backbone: Optional[Union[str, nn.Module]] = None,
+        backbone_num_features: Optional[int] = None,
+        trans_blocks: int = 4,
+        trans_dropout: float = 0.0,
+    ) -> None:
+
+        super().__init__()
+
+        if num_classes <= 0:
+            raise ValueError("Number of classes must be positive: " + str(num_classes))
+
+        if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
+            raise ValueError("Unsupported mil_mode: " + str(mil_mode))
+
+        self.mil_mode = mil_mode.lower()
+        print("MILModel with mode", mil_mode, "num_classes", num_classes)
+        self.attention = nn.Sequential()
+        self.transformer = None  # type: Optional[nn.Module]
+
+        if backbone is None:
+
+            net = models.resnet50(pretrained=pretrained)
+            nfc = net.fc.in_features  # save the number of final features
+            net.fc = torch.nn.Identity()  # remove the final linear layer
+
+            self.extra_outputs = {}  # type: Dict[str, torch.Tensor]
+
+            if self.mil_mode == "att_trans_pyramid":
+                # register hooks to capture the outputs of intermediate layers
+                def forward_hook(layer_name):
+                    def hook(module, input, output):
+                        self.extra_outputs[layer_name] = output
+
+                    return hook
+
+                net.layer1.register_forward_hook(forward_hook("layer1"))
+                net.layer2.register_forward_hook(forward_hook("layer2"))
+                net.layer3.register_forward_hook(forward_hook("layer3"))
+                net.layer4.register_forward_hook(forward_hook("layer4"))
+
+        elif isinstance(backbone, str):
+
+            # assume a torchvision model name was provided
+            torch_model = getattr(models, backbone, None)
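+            # getattr returns None when torchvision has no model factory with this
+            # name, which the check below converts into a ValueError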
+            if torch_model is None:
+                raise ValueError("Unknown torchvision model: " + str(backbone))
+            net = torch_model(pretrained=pretrained)
+
+            if getattr(net, "fc", None) is not None:
+                nfc = net.fc.in_features  # save the number of final features
+                net.fc = torch.nn.Identity()  # remove the final linear layer
+            else:
+                raise ValueError(
+                    "Unable to detect the FC layer for the torchvision model "
+                    + str(backbone)
+                    + ". Please initialize the backbone model manually."
+                )
+
+        elif isinstance(backbone, nn.Module):
+            # use a custom backbone
+            net = backbone
+
+            if backbone_num_features is None:
+                raise ValueError("Number of encoder features must be provided for a custom backbone model")
+            nfc = backbone_num_features
+
+        else:
+            raise ValueError("Unsupported backbone")
+
+        if backbone is not None and self.mil_mode not in ["mean", "max", "att", "att_trans"]:
+            raise ValueError("Custom backbone is not supported for the mode: " + str(mil_mode))
+
+        if self.mil_mode in ["mean", "max"]:
+            pass
+        elif self.mil_mode == "att":
+            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))
+
+        elif self.mil_mode == "att_trans":
+            transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
+            self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)
+            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))
+
+        elif self.mil_mode == "att_trans_pyramid":
+
+            transformer_list = nn.ModuleList(
+                [
+                    nn.TransformerEncoder(
+                        nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks
+                    ),
+                    nn.Sequential(
+                        nn.Linear(768, 256),
+                        nn.TransformerEncoder(
+                            nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
+                            num_layers=trans_blocks,
+                        ),
+                    ),
+                    nn.Sequential(
+                        nn.Linear(1280, 256),
+                        nn.TransformerEncoder(
+                            nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
+                            num_layers=trans_blocks,
+                        ),
+                    ),
+                    nn.TransformerEncoder(
+                        nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout),
+                        num_layers=trans_blocks,
+                    ),
+                ]
+            )
+            self.transformer = transformer_list
+            nfc = nfc + 256
+            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))
+
+        else:
+            raise ValueError("Unsupported mil_mode: " + str(mil_mode))
+
+        self.myfc = nn.Linear(nfc, num_classes)
+        self.net = net
+
+    def calc_head(self, x: torch.Tensor) -> torch.Tensor:
+
+        sh = x.shape
+
+        if self.mil_mode == "mean":
+            x = self.myfc(x)
+            x = torch.mean(x, dim=1)
+
+        elif self.mil_mode == "max":
+            x = self.myfc(x)
+            x, _ = torch.max(x, dim=1)
+
+        elif self.mil_mode == "att":
+
+            a = self.attention(x)
+            a = torch.softmax(a, dim=1)
+            x = torch.sum(x * a, dim=1)
+
+            x = self.myfc(x)
+
+        elif self.mil_mode == "att_trans" and self.transformer is not None:
+
+            x = x.permute(1, 0, 2)
+            x = self.transformer(x)
+            x = x.permute(1, 0, 2)
+
+            a = self.attention(x)
+            a = torch.softmax(a, dim=1)
+            x = torch.sum(x * a, dim=1)
+
+            x = self.myfc(x)
+
+        elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
+
+            l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
+            l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
+            l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
+            l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
+
+            transformer_list = cast(nn.ModuleList, self.transformer)
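+            # pyramid stages: each stage consumes the previous stage's 256-dim output
+            # concatenated with the pooled features of the next ResNet layer, i.e.
+            # 256 (layer1) -> 256 + 512 = 768 -> 256 + 1024 = 1280 -> 256 + 2048 = 2304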
+
+            x = transformer_list[0](l1)
+            x = transformer_list[1](torch.cat((x, l2), dim=2))
+            x = transformer_list[2](torch.cat((x, l3), dim=2))
+            x = transformer_list[3](torch.cat((x, l4), dim=2))
+
+            x = x.permute(1, 0, 2)
+
+            a = self.attention(x)
+            a = torch.softmax(a, dim=1)
+            x = torch.sum(x * a, dim=1)
+
+            x = self.myfc(x)
+
+        else:
+            raise ValueError("Wrong model mode: " + str(self.mil_mode))
+
+        return x
+
+    def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:
+
+        sh = x.shape
+        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])
+
+        x = self.net(x)
+        x = x.reshape(sh[0], sh[1], -1)
+
+        if not no_head:
+            x = self.calc_head(x)
+
+        return x
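For orientation, a minimal usage sketch of the model added above (not part of the diff; the bag and image sizes are illustrative). The input is a batch of bags with shape (batch, num_instances, channels, height, width); forward() flattens it to batch * num_instances images for the backbone, then regroups per bag for the MIL head:

    import torch
    from monai.networks.nets import MILModel

    model = MILModel(num_classes=5, mil_mode="att", pretrained=False)
    x = torch.randn(2, 3, 3, 224, 224)  # 2 bags, 3 instances each, 3x224x224 images
    logits = model(x)                   # (2, 5) bag-level logits
    feats = model(x, no_head=True)      # (2, 3, 2048) per-instance features, head skipped
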
diff --git a/tests/min_tests.py b/tests/min_tests.py
index d1f5384c2c..829f67c485 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -104,6 +104,7 @@ def run_testsuit():
         "test_load_imaged",
         "test_load_spacing_orientation",
         "test_mednistdataset",
+        "test_milmodel",
         "test_mlp",
         "test_nifti_header_revise",
         "test_nifti_rw",
diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py
new file mode 100644
index 0000000000..9b21d4e2d1
--- /dev/null
+++ b/tests/test_milmodel.py
@@ -0,0 +1,91 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import MILModel
+from monai.utils.module import optional_import
+from tests.utils import test_script_save
+
+models, _ = optional_import("torchvision.models")
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+TEST_CASE_MILMODEL = []
+for num_classes in [1, 5]:
+    for mil_mode in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
+        test_case = [
+            {"num_classes": num_classes, "mil_mode": mil_mode, "pretrained": False},
+            (1, 2, 3, 512, 512),
+            (1, num_classes),
+        ]
+        TEST_CASE_MILMODEL.append(test_case)
+
+
+for trans_blocks in [1, 3]:
+    test_case = [
+        {"num_classes": 5, "pretrained": False, "trans_blocks": trans_blocks, "trans_dropout": 0.5},
+        (1, 2, 3, 512, 512),
+        (1, 5),
+    ]
+    TEST_CASE_MILMODEL.append(test_case)
+
+# torchvision backbone
+TEST_CASE_MILMODEL.append(
+    [{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)]
+)
+TEST_CASE_MILMODEL.append(
+    [{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)]
+)
+
+# custom backbone
+backbone = models.densenet121(pretrained=False)
+backbone_nfeatures = backbone.classifier.in_features
+backbone.classifier = torch.nn.Identity()
+TEST_CASE_MILMODEL.append(
+    [
+        {"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False},
+        (2, 2, 3, 512, 512),
+        (2, 5),
+    ]
+)
+
+
+class TestMilModel(unittest.TestCase):
+    @parameterized.expand(TEST_CASE_MILMODEL)
+    def test_shape(self, input_param, input_shape, expected_shape):
+        net = MILModel(**input_param).to(device)
+        with eval_mode(net):
+            result = net(torch.randn(input_shape, dtype=torch.float).to(device))
+        self.assertEqual(result.shape, expected_shape)
+
+    def test_ill_args(self):
+        with self.assertRaises(ValueError):
+            MILModel(
+                num_classes=5,
+                pretrained=False,
+                backbone="resnet50",
+                backbone_num_features=2048,
+                mil_mode="att_trans_pyramid",
+            )
+
+    def test_script(self):
+        input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0]
+        net = MILModel(**input_param)
+        test_data = torch.randn(input_shape, dtype=torch.float)
+        test_script_save(net, test_data)
+
+
+if __name__ == "__main__":
+    unittest.main()
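
As background, a standalone sketch (not part of the diff) of the attention pooling used by the "att" mode above, assuming the same 2048-unit hidden layer; the toy shapes are illustrative. A small MLP scores each instance, a softmax over the bag dimension turns the scores into weights, and the bag-level feature is the weighted sum of instance features:

    import torch
    import torch.nn as nn

    nfc, batch, bag_size = 2048, 2, 3
    attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

    x = torch.randn(batch, bag_size, nfc)   # per-instance features
    a = torch.softmax(attention(x), dim=1)  # (batch, bag_size, 1), weights sum to 1 per bag
    pooled = torch.sum(x * a, dim=1)        # (batch, nfc) bag-level feature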