MIL Component - MILModel #3236
Merged
Changes from all commits (11 commits):
002bc27  3251 Add dependency check in WSIReader (#3312)  (Nic-Ma)
bd113cd  [DLMED] MILmodel PR  (myron)
0a220b8  small updates  (myron)
d6859a8  fix jit issues  (myron)
a4d9ff4  jit fix attempt  (myron)
098025a  Merge branch 'dev' into milmodel  (bhashemian)
51c76d4  removing Enum  (myron)
c255897  Merge branch 'dev' into milmodel  (bhashemian)
3ccbbba  Merge branch 'dev' into milmodel  (bhashemian)
e08e94e  Merge branch 'dev' into milmodel  (bhashemian)
0498981  Merge branch 'dev' into milmodel  (bhashemian)
@@ -0,0 +1,229 @@

```python
from typing import Dict, Optional, Union, cast

import torch
import torch.nn as nn

from monai.utils.module import optional_import

models, _ = optional_import("torchvision.models")


class MILModel(nn.Module):
    """
    Multiple Instance Learning (MIL) model, with a backbone classification model.

    Args:
        num_classes: number of output classes.
        mil_mode: MIL algorithm (either ``mean``, ``max``, ``att``, ``att_trans`` or ``att_trans_pyramid``).
            Defaults to ``att``.
        pretrained: init the backbone with pretrained weights. Defaults to ``True``.
        backbone: backbone classifier CNN (either ``None``, an ``nn.Module`` that returns features,
            or a string name of a torchvision model). Defaults to ``None``, in which case ResNet50 is used.
        backbone_num_features: number of output features of the backbone CNN.
            Defaults to ``None`` (necessary only when using a custom backbone).
        trans_blocks: number of transformer encoder blocks. Defaults to ``4``.
        trans_dropout: dropout rate in the transformer encoder. Defaults to ``0.0``.

    mil_mode:
        "mean" - average features from all instances, equivalent to a pure CNN (non-MIL)
        "max" - retain only the instance with the max probability for loss calculation
        "att" - attention-based MIL, https://arxiv.org/abs/1802.04712
        "att_trans" - transformer MIL, https://arxiv.org/abs/2111.01556
        "att_trans_pyramid" - transformer pyramid MIL, https://arxiv.org/abs/2111.01556
    """

    def __init__(
        self,
        num_classes: int,
        mil_mode: str = "att",
        pretrained: bool = True,
        backbone: Optional[Union[str, nn.Module]] = None,
        backbone_num_features: Optional[int] = None,
        trans_blocks: int = 4,
        trans_dropout: float = 0.0,
    ) -> None:

        super().__init__()

        if num_classes <= 0:
            raise ValueError("Number of classes must be positive: " + str(num_classes))

        if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        self.mil_mode = mil_mode.lower()
        print("MILModel with mode", mil_mode, "num_classes", num_classes)
        self.attention = nn.Sequential()
        self.transformer = None  # type: Optional[nn.Module]

        if backbone is None:
            # default backbone: torchvision ResNet50
            net = models.resnet50(pretrained=pretrained)
            nfc = net.fc.in_features  # save the number of final features
            net.fc = torch.nn.Identity()  # remove the final linear layer

            self.extra_outputs = {}  # type: Dict[str, torch.Tensor]

            if mil_mode == "att_trans_pyramid":
                # register hooks to capture outputs of intermediate layers
                def forward_hook(layer_name):
                    def hook(module, input, output):
                        self.extra_outputs[layer_name] = output

                    return hook

                net.layer1.register_forward_hook(forward_hook("layer1"))
                net.layer2.register_forward_hook(forward_hook("layer2"))
                net.layer3.register_forward_hook(forward_hook("layer3"))
                net.layer4.register_forward_hook(forward_hook("layer4"))

        elif isinstance(backbone, str):
            # assume a torchvision model name was provided
            torch_model = getattr(models, backbone, None)
            if torch_model is None:
                raise ValueError("Unknown torchvision model: " + str(backbone))
            net = torch_model(pretrained=pretrained)

            if getattr(net, "fc", None) is not None:
                nfc = net.fc.in_features  # save the number of final features
                net.fc = torch.nn.Identity()  # remove the final linear layer
            else:
                raise ValueError(
                    "Unable to detect FC layer for the torchvision model "
                    + str(backbone)
                    + ". Please initialize the backbone model manually."
                )

        elif isinstance(backbone, nn.Module):
            # use a custom backbone
            if backbone_num_features is None:
                raise ValueError("Number of encoder features must be provided for a custom backbone model")
            net = backbone
            nfc = backbone_num_features

        else:
            raise ValueError("Unsupported backbone")

        if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]:
            raise ValueError("Custom backbone is not supported for the mode: " + str(mil_mode))

        if self.mil_mode in ["mean", "max"]:
            pass

        elif self.mil_mode == "att":
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans":
            transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
            self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans_pyramid":
            # one transformer per ResNet stage; the linear layers project the
            # concatenated multi-scale features (256+512 and 256+1024) back to 256
            transformer_list = nn.ModuleList(
                [
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
                        num_layers=trans_blocks,
                    ),
                    nn.Sequential(
                        nn.Linear(768, 256),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.Sequential(
                        nn.Linear(1280, 256),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout),
                        num_layers=trans_blocks,
                    ),
                ]
            )
            self.transformer = transformer_list
            nfc = nfc + 256
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        else:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        self.myfc = nn.Linear(nfc, num_classes)
        self.net = net
```
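For intuition, the attention head built above is what `calc_head` (below) applies: a small MLP scores each instance, the scores are softmax-normalized across the bag, and the bag embedding is the weighted sum of instance features (https://arxiv.org/abs/1802.04712). A minimal standalone sketch of just that pooling step, with toy sizes rather than the real 2048-dim ResNet features:

```python
import torch
import torch.nn as nn

# toy bag: 2 samples, 4 instances each, 16-dim features (illustrative sizes only)
feats = torch.randn(2, 4, 16)

# same structure as MILModel's attention head, shrunk to toy dimensions
attention = nn.Sequential(nn.Linear(16, 8), nn.Tanh(), nn.Linear(8, 1))

a = torch.softmax(attention(feats), dim=1)  # (2, 4, 1): weights over instances
pooled = torch.sum(feats * a, dim=1)        # (2, 16): one embedding per bag
```

The file continues with `calc_head` and `forward`: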
```python
    def calc_head(self, x: torch.Tensor) -> torch.Tensor:

        sh = x.shape

        if self.mil_mode == "mean":
            x = self.myfc(x)
            x = torch.mean(x, dim=1)

        elif self.mil_mode == "max":
            x = self.myfc(x)
            x, _ = torch.max(x, dim=1)

        elif self.mil_mode == "att":
            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        elif self.mil_mode == "att_trans" and self.transformer is not None:
            x = x.permute(1, 0, 2)
            x = self.transformer(x)
            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
            # spatially average the hooked intermediate feature maps, then restore
            # the (instances, batch, features) layout expected by the transformers
            l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)

            transformer_list = cast(nn.ModuleList, self.transformer)

            x = transformer_list[0](l1)
            x = transformer_list[1](torch.cat((x, l2), dim=2))
            x = transformer_list[2](torch.cat((x, l3), dim=2))
            x = transformer_list[3](torch.cat((x, l4), dim=2))

            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        else:
            raise ValueError("Wrong model mode: " + str(self.mil_mode))

        return x

    def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:

        sh = x.shape
        # fold the bag dimension into the batch for the backbone:
        # (B, N, C, H, W) -> (B*N, C, H, W)
        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])

        x = self.net(x)
        x = x.reshape(sh[0], sh[1], -1)

        if not no_head:
            x = self.calc_head(x)

        return x
```
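A usage sketch (the parameter names are real; tensor sizes are toy values, smaller than the 512x512 inputs used in the tests below). The input is a bag of instances shaped `(batch, instances, channels, height, width)`:

```python
import torch
from monai.networks.nets import MILModel

model = MILModel(num_classes=5, mil_mode="att", pretrained=False)  # pretrained=False skips the weight download
images = torch.randn(2, 3, 3, 64, 64)  # 2 bags of 3 RGB instances each

logits = model(images)               # (2, 5): one prediction per bag
feats = model(images, no_head=True)  # (2, 3, 2048): per-instance ResNet50 features
```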
@@ -0,0 +1,91 @@

```python
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import MILModel
from monai.utils.module import optional_import
from tests.utils import test_script_save

models, _ = optional_import("torchvision.models")

device = "cuda" if torch.cuda.is_available() else "cpu"


TEST_CASE_MILMODEL = []
for num_classes in [1, 5]:
    for mil_mode in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
        test_case = [
            {"num_classes": num_classes, "mil_mode": mil_mode, "pretrained": False},
            (1, 2, 3, 512, 512),
            (1, num_classes),
        ]
        TEST_CASE_MILMODEL.append(test_case)


for trans_blocks in [1, 3]:
    test_case = [
        {"num_classes": 5, "pretrained": False, "trans_blocks": trans_blocks, "trans_dropout": 0.5},
        (1, 2, 3, 512, 512),
        (1, 5),
    ]
    TEST_CASE_MILMODEL.append(test_case)

# torchvision backbone
TEST_CASE_MILMODEL.append(
    [{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)]
)
TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)])

# custom backbone
backbone = models.densenet121(pretrained=False)
backbone_nfeatures = backbone.classifier.in_features
backbone.classifier = torch.nn.Identity()
TEST_CASE_MILMODEL.append(
    [
        {"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False},
        (2, 2, 3, 512, 512),
        (2, 5),
    ]
)


class TestMilModel(unittest.TestCase):
    @parameterized.expand(TEST_CASE_MILMODEL)
    def test_shape(self, input_param, input_shape, expected_shape):
        net = MILModel(**input_param).to(device)
        with eval_mode(net):
            result = net(torch.randn(input_shape, dtype=torch.float).to(device))
            self.assertEqual(result.shape, expected_shape)

    def test_ill_args(self):
        with self.assertRaises(ValueError):
            MILModel(
                num_classes=5,
                pretrained=False,
                backbone="resnet50",
                backbone_num_features=2048,
                mil_mode="att_trans_pyramid",
            )

    def test_script(self):
        input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0]
        net = MILModel(**input_param)
        test_data = torch.randn(input_shape, dtype=torch.float)
        test_script_save(net, test_data)


if __name__ == "__main__":
    unittest.main()
```
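Several commits above are JIT fixes, which `test_script` exercises through the `test_script_save` helper. A rough standalone equivalent of that check, mirroring the first parameterized case (a sketch only; the helper's exact behavior may differ):

```python
import torch
from monai.networks.nets import MILModel

# TEST_CASE_MILMODEL[0] is num_classes=1, mil_mode="mean", pretrained=False
model = MILModel(num_classes=1, mil_mode="mean", pretrained=False).eval()
scripted = torch.jit.script(model)  # the PR's jit fixes make this scriptable

x = torch.randn(1, 2, 3, 64, 64)  # toy bag, smaller than the test inputs
with torch.no_grad():
    assert torch.allclose(model(x), scripted(x), atol=1e-5)
```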