diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 62a1c82f2f..39889266d9 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -538,6 +538,11 @@ Nets
 .. autoclass:: TorchVisionFullyConvModel
     :members:
 
+`MILModel`
+~~~~~~~~~~
+.. autoclass:: MILModel
+    :members:
+
 Utilities
 ---------
 .. automodule:: monai.networks.utils
diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py
index 213a86864f..c7fb7f3937 100644
--- a/monai/networks/nets/milmodel.py
+++ b/monai/networks/nets/milmodel.py
@@ -13,23 +13,22 @@ class MILModel(nn.Module):
     Multiple Instance Learning (MIL) model, with a backbone classification model.
 
     Args:
-        num_classes: number of output classes
-        mil_mode: MIL algorithm (either mean, max, att, att_trans, att_trans_pyramid)
+        num_classes: number of output classes.
+        mil_mode: MIL algorithm, available values:
+            "mean" - average features from all instances, equivalent to pure CNN (non MIL).
+            "max" - retain only the instance with the max probability for loss calculation.
+            "att" - attention based MIL https://arxiv.org/abs/1802.04712.
+            "att_trans" - transformer MIL https://arxiv.org/abs/2111.01556.
+            "att_trans_pyramid" - transformer pyramid MIL https://arxiv.org/abs/2111.01556.
             Defaults to ``att``.
-        pretrained: init backbone with pretrained weights.
-            Defaults to ``True``.
-        backbone: Backbone classifier CNN. (either None, nn.Module that returns features,
-            or a string name of a torchvision model)
+        pretrained: init backbone with pretrained weights, defaults to ``True``.
+        backbone: Backbone classifier CNN (either None, nn.Module that returns features,
+            or a string name of a torchvision model).
             Defaults to ``None``, in which case ResNet50 is used.
         backbone_num_features: Number of output features of the backbone CNN
             Defaults to ``None`` (necessary only when using a custom backbone)
-
-        mil_mode:
-            "mean" - average features from all instances, equivalent to pure CNN (non MIL)
-            "max - retain only the instance with the max probability for loss calculation
-            "att" - attention based MIL https://arxiv.org/abs/1802.04712
-            "att_trans" - transformer MIL https://arxiv.org/abs/2111.01556
-            "att_trans_pyramid" - transformer pyramid MIL https://arxiv.org/abs/2111.01556
+        trans_blocks: number of blocks in the `TransformerEncoder` layer.
+        trans_dropout: dropout rate in the `TransformerEncoder` layer.
 
     """
 
@@ -53,7 +52,6 @@ def __init__(
             raise ValueError("Unsupported mil_mode: " + str(mil_mode))
 
         self.mil_mode = mil_mode.lower()
-        print("MILModel with mode", mil_mode, "num_classes", num_classes)
 
         self.attention = nn.Sequential()
         self.transformer = None  # type: Optional[nn.Module]
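
A minimal usage sketch of the arguments documented above (not part of the diff). It assumes the bag-of-instances input layout [B, N, C, H, W] and per-bag output logits, which the docstring changes here do not spell out; pretrained=False is used only to skip the weight download in a sketch.

    import torch
    from monai.networks.nets import MILModel

    # Attention-based MIL ("att") on top of the default ResNet50 backbone;
    # pretrained=False here only to avoid downloading weights in a sketch.
    model = MILModel(num_classes=2, mil_mode="att", pretrained=False)

    # Assumed input layout: a bag of N instances per sample, shaped [B, N, C, H, W];
    # the backbone embeds each instance and the MIL head aggregates over N.
    images = torch.randn(1, 4, 3, 224, 224)
    logits = model(images)  # expected: one logit vector per bag, i.e. shape [1, 2]

    # With a custom backbone (here `my_cnn` is a hypothetical nn.Module returning a
    # feature vector per instance), backbone_num_features must give its feature size:
    # model = MILModel(num_classes=2, backbone=my_cnn, backbone_num_features=512)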