From 8a0d218a25ce402ccab0752d4a27bf8567895105 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 25 Aug 2021 14:56:42 +0200 Subject: [PATCH 1/7] rename n_classes Signed-off-by: Jirka --- CHANGELOG.md | 2 ++ monai/losses/tversky.py | 2 +- monai/metrics/meandice.py | 2 +- monai/metrics/rocauc.py | 4 ++-- monai/networks/nets/netadapter.py | 8 +++---- monai/networks/nets/resnet.py | 6 ++--- monai/networks/nets/torchvision_fc.py | 12 +++++----- monai/transforms/post/array.py | 16 ++++++------- monai/transforms/post/dictionary.py | 12 +++++----- monai/transforms/utility/array.py | 6 ++--- tests/test_as_discrete.py | 8 +++---- tests/test_as_discreted.py | 6 ++--- tests/test_compute_roc_auc.py | 4 ++-- tests/test_handler_decollate_batch.py | 2 +- tests/test_handler_post_processing.py | 2 +- tests/test_handler_rocauc.py | 2 +- tests/test_handler_rocauc_dist.py | 2 +- tests/test_integration_classification_2d.py | 2 +- tests/test_net_adapter.py | 10 ++++---- tests/test_resnet.py | 6 ++--- tests/test_torchvision_fc_model.py | 26 ++++++++++----------- tests/test_torchvision_fully_conv_model.py | 14 +++++------ 22 files changed, 78 insertions(+), 76 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55a0ca11e9..bdbd23e7dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] +* renamed model's `n_classes` to `num_classes` + ## [0.6.0] - 2021-07-08 ### Added * 10 new transforms, a masked loss wrapper, and a `NetAdapter` for transfer learning diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 1d75b9e8cc..1cc0e1d8d7 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -155,7 +155,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.reduction == LossReduction.SUM.value: return torch.sum(score) # sum over the batch and channel dims if self.reduction == LossReduction.NONE.value: - return score # returns [N, n_classes] losses + return score # returns [N, num_classes] losses if self.reduction == LossReduction.MEAN.value: return torch.mean(score) raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 1bfd85a83e..226c106f7e 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -114,7 +114,7 @@ def compute_meandice( the predicted output. Defaults to True. Returns: - Dice scores per batch and per class, (shape [batch_size, n_classes]). + Dice scores per batch and per class, (shape [batch_size, num_classes]). Raises: ValueError: when `y_pred` and `y` have different shapes. 
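The rename above is mechanical at this point in the series (backward compatibility for the old spelling is added in the next patch), but for orientation, here is a minimal sketch of the new call spelling. This snippet is illustrative only and not part of the diff; the tensor shape is made up:

    import torch

    from monai.transforms import AsDiscrete

    # `n_classes` becomes `num_classes` across transforms, networks, and docs:
    transform = AsDiscrete(argmax=True, to_onehot=True, num_classes=3)
    one_hot = transform(torch.randn(3, 16, 16))  # argmax over channels, then one-hot
    print(one_hot.shape)  # torch.Size([3, 16, 16]) -> (num_classes, H, W)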
diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 3bd6c0d69c..c2679cc2ea 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -131,9 +131,9 @@ def compute_roc_auc( y_pred_ndim = y_pred.ndimension() y_ndim = y.ndimension() if y_pred_ndim not in (1, 2): - raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).") + raise ValueError("Predictions should be of shape (batch_size, num_classes) or (batch_size, ).") if y_ndim not in (1, 2): - raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).") + raise ValueError("Targets should be of shape (batch_size, num_classes) or (batch_size, ).") if y_pred_ndim == 2 and y_pred.shape[1] == 1: y_pred = y_pred.squeeze(dim=-1) y_pred_ndim = 1 diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py index bc88454f87..602136ec3d 100644 --- a/monai/networks/nets/netadapter.py +++ b/monai/networks/nets/netadapter.py @@ -26,7 +26,7 @@ class NetAdapter(torch.nn.Module): model: a PyTorch model, support both 2D and 3D models. typically, it can be a pretrained model in Torchvision, like: ``resnet18``, ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, etc. more details: https://pytorch.org/vision/stable/models.html. - n_classes: number of classes for the last classification layer. Default to 1. + num_classes: number of classes for the last classification layer. Default to 1. dim: number of spatial dimensions, default to 2. in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer. use_conv: whether use convolutional layer to replace the last layer, default to False. @@ -41,7 +41,7 @@ class NetAdapter(torch.nn.Module): def __init__( self, model: torch.nn.Module, - n_classes: int = 1, + num_classes: int = 1, dim: int = 2, in_channels: Optional[int] = None, use_conv: bool = False, @@ -74,7 +74,7 @@ def __init__( # add 1x1 conv (it behaves like a FC layer) self.fc = Conv[Conv.CONV, dim]( in_channels=in_channels_, - out_channels=n_classes, + out_channels=num_classes, kernel_size=1, bias=bias, ) @@ -84,7 +84,7 @@ def __init__( # replace the out_features of FC layer self.fc = torch.nn.Linear( in_features=in_channels_, - out_features=n_classes, + out_features=num_classes, bias=bias, ) self.use_conv = use_conv diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index f34de563ce..0dd3546760 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -162,7 +162,7 @@ class ResNet(nn.Module): no_max_pool: bool argument to determine if to use maxpool layer. shortcut_type: which downsample block to use. widen_factor: widen output for each layer. 
- n_classes: number of output (classifications) + num_classes: number of output (classifications) """ def __init__( @@ -177,7 +177,7 @@ def __init__( no_max_pool: bool = False, shortcut_type: str = "B", widen_factor: float = 1.0, - n_classes: int = 400, + num_classes: int = 400, feed_forward: bool = True, ) -> None: @@ -215,7 +215,7 @@ def __init__( self.avgpool = avgp_type(block_avgpool[spatial_dims]) if feed_forward: - self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes) + self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes) for m in self.modules(): if isinstance(m, conv_type): diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 2c4c7c8c32..0d420252c6 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -29,7 +29,7 @@ class TorchVisionFCModel(NetAdapter): ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``. model details: https://pytorch.org/vision/stable/models.html. - n_classes: number of classes for the last classification layer. Default to 1. + num_classes: number of classes for the last classification layer. Default to 1. dim: number of spatial dimensions, default to 2. in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer. use_conv: whether use convolutional layer to replace the last layer, default to False. @@ -44,7 +44,7 @@ class TorchVisionFCModel(NetAdapter): def __init__( self, model_name: str = "resnet18", - n_classes: int = 1, + num_classes: int = 1, dim: int = 2, in_channels: Optional[int] = None, use_conv: bool = False, @@ -59,7 +59,7 @@ def __init__( super().__init__( model=model, - n_classes=n_classes, + num_classes=num_classes, dim=dim, in_channels=in_channels, use_conv=use_conv, @@ -77,7 +77,7 @@ class TorchVisionFullyConvModel(TorchVisionFCModel): model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end. ``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``. - n_classes: number of classes for the last classification layer. Default to 1. + num_classes: number of classes for the last classification layer. Default to 1. pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7). pool_stride: the stride for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to 1. pretrained: whether to use the imagenet pretrained weights. Default to False. @@ -90,14 +90,14 @@ class TorchVisionFullyConvModel(TorchVisionFCModel): def __init__( self, model_name: str = "resnet18", - n_classes: int = 1, + num_classes: int = 1, pool_size: Union[int, Tuple[int, int]] = (7, 7), pool_stride: Union[int, Tuple[int, int]] = 1, pretrained: bool = False, ): super().__init__( model_name=model_name, - n_classes=n_classes, + num_classes=num_classes, use_conv=True, pool=("avg", {"kernel_size": pool_size, "stride": pool_stride}), pretrained=pretrained, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 7b3e7b4fd2..2c4f11a989 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -120,7 +120,7 @@ class AsDiscrete(Transform): Defaults to ``False``. to_onehot: whether to convert input data into the one-hot format. Defaults to ``False``. 
- n_classes: the number of classes to convert to One-Hot format. + num_classes: the number of classes to convert to One-Hot format. Defaults to ``None``. threshold_values: whether threshold the float value to int number 0 or 1. Defaults to ``False``. @@ -135,14 +135,14 @@ def __init__( self, argmax: bool = False, to_onehot: bool = False, - n_classes: Optional[int] = None, + num_classes: Optional[int] = None, threshold_values: bool = False, logit_thresh: float = 0.5, rounding: Optional[str] = None, ) -> None: self.argmax = argmax self.to_onehot = to_onehot - self.n_classes = n_classes + self.num_classes = num_classes self.threshold_values = threshold_values self.logit_thresh = logit_thresh self.rounding = rounding @@ -152,7 +152,7 @@ def __call__( img: torch.Tensor, argmax: Optional[bool] = None, to_onehot: Optional[bool] = None, - n_classes: Optional[int] = None, + num_classes: Optional[int] = None, threshold_values: Optional[bool] = None, logit_thresh: Optional[float] = None, rounding: Optional[str] = None, @@ -165,8 +165,8 @@ def __call__( Defaults to ``self.argmax``. to_onehot: whether to convert input data into the one-hot format. Defaults to ``self.to_onehot``. - n_classes: the number of classes to convert to One-Hot format. - Defaults to ``self.n_classes``. + num_classes: the number of classes to convert to One-Hot format. + Defaults to ``self.num_classes``. threshold_values: whether threshold the float value to int number 0 or 1. Defaults to ``self.threshold_values``. logit_thresh: the threshold value for thresholding operation.. @@ -179,9 +179,9 @@ def __call__( img = torch.argmax(img, dim=0, keepdim=True) if to_onehot or self.to_onehot: - _nclasses = self.n_classes if n_classes is None else n_classes + _nclasses = self.num_classes if num_classes is None else num_classes if not isinstance(_nclasses, int): - raise AssertionError("One of self.n_classes or n_classes must be an integer") + raise AssertionError("One of self.num_classes or num_classes must be an integer") img = one_hot(img, num_classes=_nclasses, dim=0) if threshold_values or self.threshold_values: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index d4e039339b..9aea6c0cab 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -131,7 +131,7 @@ def __init__( keys: KeysCollection, argmax: Union[Sequence[bool], bool] = False, to_onehot: Union[Sequence[bool], bool] = False, - n_classes: Optional[Union[Sequence[int], int]] = None, + num_classes: Optional[Union[Sequence[int], int]] = None, threshold_values: Union[Sequence[bool], bool] = False, logit_thresh: Union[Sequence[float], float] = 0.5, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, @@ -145,7 +145,7 @@ def __init__( it also can be a sequence of bool, each element corresponds to a key in ``keys``. to_onehot: whether to convert input data into the one-hot format. Defaults to False. it also can be a sequence of bool, each element corresponds to a key in ``keys``. - n_classes: the number of classes to convert to One-Hot format. it also can be a + num_classes: the number of classes to convert to One-Hot format. it also can be a sequence of int, each element corresponds to a key in ``keys``. threshold_values: whether threshold the float value to int number 0 or 1, default is False. it also can be a sequence of bool, each element corresponds to a key in ``keys``. 
@@ -160,7 +160,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) - self.n_classes = ensure_tuple_rep(n_classes, len(self.keys)) + self.num_classes = ensure_tuple_rep(num_classes, len(self.keys)) self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys)) self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys)) self.rounding = ensure_tuple_rep(rounding, len(self.keys)) @@ -168,14 +168,14 @@ def __init__( def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) - for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh, rounding in self.key_iterator( - d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh, self.rounding + for key, argmax, to_onehot, num_classes, threshold_values, logit_thresh, rounding in self.key_iterator( + d, self.argmax, self.to_onehot, self.num_classes, self.threshold_values, self.logit_thresh, self.rounding ): d[key] = self.converter( d[key], argmax, to_onehot, - n_classes, + num_classes, threshold_values, logit_thresh, rounding, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index dd045817fb..2eb6c447c6 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -282,13 +282,13 @@ def __init__(self, channel_dim: int = 0) -> None: self.channel_dim = channel_dim def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: - n_classes = img.shape[self.channel_dim] - if n_classes <= 1: + num_classes = img.shape[self.channel_dim] + if num_classes <= 1: raise RuntimeError("input image does not contain multiple channels.") outputs = [] slices = [slice(None)] * len(img.shape) - for i in range(n_classes): + for i in range(num_classes): slices[self.channel_dim] = slice(i, i + 1) outputs.append(img[tuple(slices)]) diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index b87fafd8f3..bb9457a357 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -17,28 +17,28 @@ from monai.transforms import AsDiscrete TEST_CASE_1 = [ - {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5}, + {"argmax": True, "to_onehot": False, "num_classes": None, "threshold_values": False, "logit_thresh": 0.5}, torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), torch.tensor([[[1.0, 1.0]]]), (1, 1, 2), ] TEST_CASE_2 = [ - {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, + {"argmax": True, "to_onehot": True, "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), (2, 1, 2), ] TEST_CASE_3 = [ - {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6}, + {"argmax": False, "to_onehot": False, "num_classes": None, "threshold_values": True, "logit_thresh": 0.6}, torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), (1, 2, 2), ] TEST_CASE_4 = [ - {"argmax": False, "to_onehot": True, "n_classes": 3}, + {"argmax": False, "to_onehot": True, "num_classes": 3}, torch.tensor(1), torch.tensor([0.0, 1.0, 0.0]), (3,), diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index ac594f0daa..90e98b297b 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -21,7 +21,7 @@ 
"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": True, - "n_classes": 2, + "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5, }, @@ -35,7 +35,7 @@ "keys": ["pred", "label"], "argmax": False, "to_onehot": False, - "n_classes": None, + "num_classes": None, "threshold_values": [True, False], "logit_thresh": 0.6, }, @@ -49,7 +49,7 @@ "keys": ["pred"], "argmax": True, "to_onehot": True, - "n_classes": 2, + "num_classes": 2, "threshold_values": False, "logit_thresh": 0.5, }, diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 79d62b6436..1cec357b93 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -87,7 +87,7 @@ class TestComputeROCAUC(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, n_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0) y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0) result = compute_roc_auc(y_pred=y_pred, y=y, average=average) @@ -96,7 +96,7 @@ def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, n_classes=2)]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, num_classes=2)]) y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)] y = [y_trans(i) for i in decollate_batch(y)] metric = ROCAUCMetric(average=average) diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py index bc74cf5328..8f0ffb2b5c 100644 --- a/tests/test_handler_decollate_batch.py +++ b/tests/test_handler_decollate_batch.py @@ -32,7 +32,7 @@ def test_compute(self): [ Activationsd(keys="pred", sigmoid=True), CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), + AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), ] ) ), diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index 552cde9eb1..e9d57128cb 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -26,7 +26,7 @@ "transform": Compose( [ CopyItemsd(keys="filename", times=1, names="filename_bak"), - AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), + AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, num_classes=2), ] ), "event": "iteration_completed", diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 46594eb629..5b80bc43eb 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -22,7 +22,7 @@ class TestHandlerROCAUC(unittest.TestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, n_classes=2) + to_onehot = AsDiscrete(to_onehot=True, num_classes=2) y_pred = [torch.Tensor([0.1, 
0.9]), torch.Tensor([0.3, 1.4])] y = [torch.Tensor([0]), torch.Tensor([1])] diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index e728c80be6..8316d4c4b6 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -26,7 +26,7 @@ class DistributedROCAUC(DistTestCase): def test_compute(self): auc_metric = ROCAUC() act = Activations(softmax=True) - to_onehot = AsDiscrete(to_onehot=True, n_classes=2) + to_onehot = AsDiscrete(to_onehot=True, num_classes=2) device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" if dist.get_rank() == 0: diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index db435ee4e4..03b5571973 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -80,7 +80,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", [LoadImage(image_only=True), AddChannel(), Transpose(indices=[0, 2, 1]), ScaleIntensity(), ToTensor()] ) y_pred_trans = Compose([ToTensor(), Activations(softmax=True)]) - y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=True, n_classes=len(np.unique(train_y)))]) + y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=True, num_classes=len(np.unique(train_y)))]) auc_metric = ROCAUCMetric() # create train, val data loaders diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py index 1ec3e26203..b2d55129a7 100644 --- a/tests/test_net_adapter.py +++ b/tests/test_net_adapter.py @@ -20,31 +20,31 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_0 = [ - {"n_classes": 1, "use_conv": True, "dim": 2}, + {"num_classes": 1, "use_conv": True, "dim": 2}, (2, 3, 224, 224), (2, 1, 8, 1), ] TEST_CASE_1 = [ - {"n_classes": 1, "use_conv": True, "dim": 3, "pool": None}, + {"num_classes": 1, "use_conv": True, "dim": 3, "pool": None}, (2, 3, 32, 32, 32), (2, 1, 1, 1, 1), ] TEST_CASE_2 = [ - {"n_classes": 5, "use_conv": True, "dim": 3, "pool": None}, + {"num_classes": 5, "use_conv": True, "dim": 3, "pool": None}, (2, 3, 32, 32, 32), (2, 5, 1, 1, 1), ] TEST_CASE_3 = [ - {"n_classes": 5, "use_conv": True, "pool": ("avg", {"kernel_size": 4, "stride": 1}), "dim": 3}, + {"num_classes": 5, "use_conv": True, "pool": ("avg", {"kernel_size": 4, "stride": 1}), "dim": 3}, (2, 3, 128, 128, 128), (2, 5, 5, 1, 1), ] TEST_CASE_4 = [ - {"n_classes": 5, "use_conv": False, "pool": ("adaptiveavg", {"output_size": (1, 1, 1)}), "dim": 3}, + {"num_classes": 5, "use_conv": False, "pool": ("adaptiveavg", {"output_size": (1, 1, 1)}), "dim": 3}, (2, 3, 32, 32, 32), (2, 5), ] diff --git a/tests/test_resnet.py b/tests/test_resnet.py index a20be298b9..c4ba5c2e16 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -31,19 +31,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ # 3D, batch 3, 2 input channel - {"pretrained": False, "spatial_dims": 3, "n_input_channels": 2, "n_classes": 3}, + {"pretrained": False, "spatial_dims": 3, "n_input_channels": 2, "num_classes": 3}, (3, 2, 32, 64, 48), (3, 3), ] TEST_CASE_2 = [ # 2D, batch 2, 1 input channel - {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "n_classes": 3}, + {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3}, (2, 1, 32, 64), (2, 3), ] TEST_CASE_3 = [ # 1D, batch 1, 2 input channels - {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "n_classes": 3}, + {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, 
"num_classes": 3}, (1, 2, 32), (1, 3), ] diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py index ae39968266..d6d3ea69c9 100644 --- a/tests/test_torchvision_fc_model.py +++ b/tests/test_torchvision_fc_model.py @@ -24,19 +24,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_0 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": True, "pretrained": False}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": True, "pretrained": False}, (2, 3, 224, 224), (2, 1, 1, 1), ] TEST_CASE_1 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": True, "pretrained": False}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": True, "pretrained": False}, (2, 3, 256, 256), (2, 1, 2, 2), ] TEST_CASE_2 = [ - {"model_name": "resnet101", "n_classes": 5, "use_conv": True, "pretrained": False}, + {"model_name": "resnet101", "num_classes": 5, "use_conv": True, "pretrained": False}, (2, 3, 256, 256), (2, 5, 2, 2), ] @@ -44,7 +44,7 @@ TEST_CASE_3 = [ { "model_name": "resnet101", - "n_classes": 5, + "num_classes": 5, "use_conv": True, "pool": ("avg", {"kernel_size": 6, "stride": 1}), "pretrained": False, @@ -54,60 +54,60 @@ ] TEST_CASE_4 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": False, "pool": None, "pretrained": False}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": False, "pool": None, "pretrained": False}, (2, 3, 224, 224), (2, 1), ] TEST_CASE_5 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": False, "pool": None, "pretrained": False}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": False, "pool": None, "pretrained": False}, (2, 3, 256, 256), (2, 1), ] TEST_CASE_6 = [ - {"model_name": "resnet101", "n_classes": 5, "use_conv": False, "pool": None, "pretrained": False}, + {"model_name": "resnet101", "num_classes": 5, "use_conv": False, "pool": None, "pretrained": False}, (2, 3, 256, 256), (2, 5), ] TEST_CASE_PRETRAINED_0 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": True, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": True, "pretrained": True}, (2, 3, 224, 224), (2, 1, 1, 1), -0.010419349186122417, ] TEST_CASE_PRETRAINED_1 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": True, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": True, "pretrained": True}, (2, 3, 256, 256), (2, 1, 2, 2), -0.010419349186122417, ] TEST_CASE_PRETRAINED_2 = [ - {"model_name": "resnet18", "n_classes": 5, "use_conv": True, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 5, "use_conv": True, "pretrained": True}, (2, 3, 256, 256), (2, 5, 2, 2), -0.010419349186122417, ] TEST_CASE_PRETRAINED_3 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": False, "pool": None, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": False, "pool": None, "pretrained": True}, (2, 3, 224, 224), (2, 1), -0.010419349186122417, ] TEST_CASE_PRETRAINED_4 = [ - {"model_name": "resnet18", "n_classes": 1, "use_conv": False, "pool": None, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 1, "use_conv": False, "pool": None, "pretrained": True}, (2, 3, 256, 256), (2, 1), -0.010419349186122417, ] TEST_CASE_PRETRAINED_5 = [ - {"model_name": "resnet18", "n_classes": 5, "use_conv": False, "pool": None, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 5, "use_conv": False, "pool": None, "pretrained": True}, (2, 3, 256, 256), (2, 5), -0.010419349186122417, diff --git 
a/tests/test_torchvision_fully_conv_model.py b/tests/test_torchvision_fully_conv_model.py index 2c65f0d32c..af2c1458d3 100644 --- a/tests/test_torchvision_fully_conv_model.py +++ b/tests/test_torchvision_fully_conv_model.py @@ -24,45 +24,45 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_0 = [ - {"model_name": "resnet18", "n_classes": 1, "pretrained": False}, + {"model_name": "resnet18", "num_classes": 1, "pretrained": False}, (2, 3, 224, 224), (2, 1, 1, 1), ] TEST_CASE_1 = [ - {"model_name": "resnet18", "n_classes": 1, "pretrained": False}, + {"model_name": "resnet18", "num_classes": 1, "pretrained": False}, (2, 3, 256, 256), (2, 1, 2, 2), ] TEST_CASE_2 = [ - {"model_name": "resnet101", "n_classes": 5, "pretrained": False}, + {"model_name": "resnet101", "num_classes": 5, "pretrained": False}, (2, 3, 256, 256), (2, 5, 2, 2), ] TEST_CASE_3 = [ - {"model_name": "resnet101", "n_classes": 5, "pool_size": 6, "pretrained": False}, + {"model_name": "resnet101", "num_classes": 5, "pool_size": 6, "pretrained": False}, (2, 3, 224, 224), (2, 5, 2, 2), ] TEST_CASE_PRETRAINED_0 = [ - {"model_name": "resnet18", "n_classes": 1, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 1, "pretrained": True}, (2, 3, 224, 224), (2, 1, 1, 1), -0.010419349186122417, ] TEST_CASE_PRETRAINED_1 = [ - {"model_name": "resnet18", "n_classes": 1, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 1, "pretrained": True}, (2, 3, 256, 256), (2, 1, 2, 2), -0.010419349186122417, ] TEST_CASE_PRETRAINED_2 = [ - {"model_name": "resnet18", "n_classes": 5, "pretrained": True}, + {"model_name": "resnet18", "num_classes": 5, "pretrained": True}, (2, 3, 256, 256), (2, 5, 2, 2), -0.010419349186122417, From cf3f483af29b8e448d5905b44c73ba675f8f082f Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 26 Aug 2021 11:54:57 +0200 Subject: [PATCH 2/7] back compatibility Signed-off-by: Jirka --- monai/networks/nets/netadapter.py | 6 ++++++ monai/networks/nets/resnet.py | 9 ++++++++- monai/networks/nets/torchvision_fc.py | 12 +++++++++++- monai/transforms/post/array.py | 12 +++++++++++- monai/transforms/post/dictionary.py | 7 ++++++- 5 files changed, 42 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py index 602136ec3d..0a57bf0780 100644 --- a/monai/networks/nets/netadapter.py +++ b/monai/networks/nets/netadapter.py @@ -14,6 +14,7 @@ import torch from monai.networks.layers import Conv, get_pool_layer +from monai.utils import deprecated_arg class NetAdapter(torch.nn.Module): @@ -38,6 +39,7 @@ class NetAdapter(torch.nn.Module): """ + @deprecated_arg("n_classes") def __init__( self, model: torch.nn.Module, @@ -47,8 +49,12 @@ def __init__( use_conv: bool = False, pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}), bias: bool = True, + n_classes: Optional[int] = None, ): super().__init__() + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 1: + num_classes = n_classes layers = list(model.children()) orig_fc = layers[-1] in_channels_: int diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 0dd3546760..cbfe8300db 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -10,7 +10,7 @@ # limitations under the License. 
from functools import partial -from typing import Any, Callable, List, Type, Union +from typing import Any, Callable, List, Type, Union, Optional import torch import torch.nn as nn @@ -20,6 +20,8 @@ __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] +from monai.utils import deprecated_arg + def get_inplanes(): return [64, 128, 256, 512] @@ -165,6 +167,7 @@ class ResNet(nn.Module): num_classes: number of output (classifications) """ + @deprecated_arg("n_classes") def __init__( self, block: Type[Union[ResNetBlock, ResNetBottleneck]], @@ -179,9 +182,13 @@ def __init__( widen_factor: float = 1.0, num_classes: int = 400, feed_forward: bool = True, + n_classes: Optional[int] = None, ) -> None: super(ResNet, self).__init__() + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 400: + num_classes = n_classes conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims] diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 0d420252c6..c681d3b066 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -12,7 +12,7 @@ from typing import Any, Dict, Optional, Tuple, Union from monai.networks.nets import NetAdapter -from monai.utils import deprecated, optional_import +from monai.utils import deprecated, optional_import, deprecated_arg models, _ = optional_import("torchvision.models") @@ -41,6 +41,7 @@ class TorchVisionFCModel(NetAdapter): pretrained: whether to use the imagenet pretrained weights. Default to False. """ + @deprecated_arg("n_classes") def __init__( self, model_name: str = "resnet18", @@ -51,7 +52,11 @@ def __init__( pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}), bias: bool = True, pretrained: bool = False, + n_classes: Optional[int] = None, ): + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 1: + num_classes = n_classes model = getattr(models, model_name)(pretrained=pretrained) # check if the model is compatible, should have a FC layer at the end if not str(list(model.children())[-1]).startswith("Linear"): @@ -87,6 +92,7 @@ class TorchVisionFullyConvModel(TorchVisionFCModel): """ + @deprecated_arg("n_classes") def __init__( self, model_name: str = "resnet18", @@ -94,7 +100,11 @@ def __init__( pool_size: Union[int, Tuple[int, int]] = (7, 7), pool_stride: Union[int, Tuple[int, int]] = 1, pretrained: bool = False, + n_classes: Optional[int] = None, ): + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes == 1: + num_classes = n_classes super().__init__( model_name=model_name, num_classes=num_classes, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 2c4f11a989..4f11d3bca3 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import ensure_tuple, look_up_option +from monai.utils import ensure_tuple, look_up_option, deprecated_arg __all__ = [ "Activations", @@ -131,6 +131,7 @@ class 
AsDiscrete(Transform): """ + @deprecated_arg("n_classes") def __init__( self, argmax: bool = False, @@ -139,7 +140,11 @@ def __init__( threshold_values: bool = False, logit_thresh: float = 0.5, rounding: Optional[str] = None, + n_classes: Optional[int] = None, ) -> None: + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes is None: + num_classes = n_classes self.argmax = argmax self.to_onehot = to_onehot self.num_classes = num_classes @@ -147,6 +152,7 @@ def __init__( self.logit_thresh = logit_thresh self.rounding = rounding + @deprecated_arg("n_classes") def __call__( self, img: torch.Tensor, @@ -156,6 +162,7 @@ def __call__( threshold_values: Optional[bool] = None, logit_thresh: Optional[float] = None, rounding: Optional[str] = None, + n_classes: Optional[int] = None, ) -> torch.Tensor: """ Args: @@ -175,6 +182,9 @@ def __call__( available options: ["torchrounding"]. """ + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes is None: + num_classes = n_classes if argmax or self.argmax: img = torch.argmax(img, dim=0, keepdim=True) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 9aea6c0cab..573cab81b9 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -39,7 +39,7 @@ from monai.transforms.transform import MapTransform from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils import ensure_tuple, ensure_tuple_rep +from monai.utils import ensure_tuple, ensure_tuple_rep, deprecated_arg from monai.utils.enums import InverseKeys __all__ = [ @@ -126,6 +126,7 @@ class AsDiscreted(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`. """ + @deprecated_arg("n_classes") def __init__( self, keys: KeysCollection, @@ -136,6 +137,7 @@ def __init__( logit_thresh: Union[Sequence[float], float] = 0.5, rounding: Union[Sequence[Optional[str]], Optional[str]] = None, allow_missing_keys: bool = False, + n_classes: Optional[int] = None, ) -> None: """ Args: @@ -157,6 +159,9 @@ def __init__( allow_missing_keys: don't raise exception if key is missing. """ + # in case the new num_classes is default but you still call deprecated n_classes + if n_classes is not None and num_classes is None: + num_classes = n_classes super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys)) From 0f806d37477baa660afccd5f2d003543540321f5 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 26 Aug 2021 14:26:57 +0200 Subject: [PATCH 3/7] fix isort Signed-off-by: Jirka --- monai/networks/nets/resnet.py | 2 +- monai/networks/nets/torchvision_fc.py | 2 +- monai/transforms/post/array.py | 2 +- monai/transforms/post/dictionary.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index cbfe8300db..298276f602 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -10,7 +10,7 @@ # limitations under the License. 
from functools import partial -from typing import Any, Callable, List, Type, Union, Optional +from typing import Any, Callable, List, Optional, Type, Union import torch import torch.nn as nn diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index c681d3b066..cc1f422c17 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -12,7 +12,7 @@ from typing import Any, Dict, Optional, Tuple, Union from monai.networks.nets import NetAdapter -from monai.utils import deprecated, optional_import, deprecated_arg +from monai.utils import deprecated, deprecated_arg, optional_import models, _ = optional_import("torchvision.models") diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 4f11d3bca3..bc22f13c30 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import ensure_tuple, look_up_option, deprecated_arg +from monai.utils import deprecated_arg, ensure_tuple, look_up_option __all__ = [ "Activations", diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 573cab81b9..2dde3fa9c4 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -39,7 +39,7 @@ from monai.transforms.transform import MapTransform from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils import ensure_tuple, ensure_tuple_rep, deprecated_arg +from monai.utils import deprecated_arg, ensure_tuple, ensure_tuple_rep from monai.utils.enums import InverseKeys __all__ = [ From f3432d69eeb16182d85cd1f0d937784b18ba49c9 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 3 Sep 2021 11:47:52 -0400 Subject: [PATCH 4/7] Split On Grid (#2879) * Implement SplitOnGrid Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Implement dictionary-based SplitOnGrid Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update inits Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Change imports Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Update input logic in SplitOnGrid) Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Add unittests for SplitOnGrid and SplitOnGridDict Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Sort import Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Remove imports Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Address comments Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Remove optional Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> * Address thread safety issues Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/apps.rst | 8 ++ monai/apps/pathology/transforms/__init__.py | 2 + .../pathology/transforms/spatial/__init__.py | 13 ++ .../pathology/transforms/spatial/array.py | 77 ++++++++++ .../transforms/spatial/dictionary.py | 56 ++++++++ tests/test_split_on_grid.py | 131 ++++++++++++++++++ tests/test_split_on_grid_dict.py | 131 ++++++++++++++++++ 7 files changed, 418 insertions(+) create mode 
100644 monai/apps/pathology/transforms/spatial/__init__.py
 create mode 100644 monai/apps/pathology/transforms/spatial/array.py
 create mode 100644 monai/apps/pathology/transforms/spatial/dictionary.py
 create mode 100644 tests/test_split_on_grid.py
 create mode 100644 tests/test_split_on_grid_dict.py

diff --git a/docs/source/apps.rst b/docs/source/apps.rst
index 1a2efeff48..11d60767ec 100644
--- a/docs/source/apps.rst
+++ b/docs/source/apps.rst
@@ -110,3 +110,11 @@ Clara MMARs
    :members:
 .. autoclass:: NormalizeHEStainsd
    :members:
+
+.. automodule:: monai.apps.pathology.transforms.spatial.array
+.. autoclass:: SplitOnGrid
+    :members:
+
+.. automodule:: monai.apps.pathology.transforms.spatial.dictionary
+.. autoclass:: SplitOnGridd
+    :members:
diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py
index 0df016244b..1be96b8e34 100644
--- a/monai/apps/pathology/transforms/__init__.py
+++ b/monai/apps/pathology/transforms/__init__.py
@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .spatial.array import SplitOnGrid
+from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
 from .stain.array import ExtractHEStains, NormalizeHEStains
 from .stain.dictionary import (
     ExtractHEStainsd,
diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py
new file mode 100644
index 0000000000..07ba222ab0
--- /dev/null
+++ b/monai/apps/pathology/transforms/spatial/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .array import SplitOnGrid
+from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py
new file mode 100644
index 0000000000..53e0c63715
--- /dev/null
+++ b/monai/apps/pathology/transforms/spatial/array.py
@@ -0,0 +1,77 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from monai.transforms.transform import Transform
+
+__all__ = ["SplitOnGrid"]
+
+
+class SplitOnGrid(Transform):
+    """
+    Split the image into patches based on the provided grid shape.
+    This transform works only with torch.Tensor inputs.
+
+    Args:
+        grid_size: a tuple or an integer that defines the shape of the grid upon which to extract patches.
+            If it's an integer, the value will be repeated for each dimension. Default is (2, 2).
+        patch_size: a tuple or an integer that defines the output patch sizes.
+            If it's an integer, the value will be repeated for each dimension.
+            The default is None, where the patch size will be inferred from the grid shape.
+
+    Note: the shape of the input image is inferred based on the first image used.
+    """
+
+    def __init__(
+        self,
+        grid_size: Union[int, Tuple[int, int]] = (2, 2),
+        patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+    ):
+        # Grid size
+        if isinstance(grid_size, int):
+            self.grid_size = (grid_size, grid_size)
+        else:
+            self.grid_size = grid_size
+        # Patch size
+        self.patch_size = None
+        if isinstance(patch_size, int):
+            self.patch_size = (patch_size, patch_size)
+        else:
+            self.patch_size = patch_size
+
+    def __call__(self, image: torch.Tensor) -> torch.Tensor:
+        if self.grid_size == (1, 1) and self.patch_size is None:
+            return torch.stack([image])
+        patch_size, steps = self.get_params(image.shape[1:])
+        patches = (
+            image.unfold(1, patch_size[0], steps[0])
+            .unfold(2, patch_size[1], steps[1])
+            .flatten(1, 2)
+            .transpose(0, 1)
+            .contiguous()
+        )
+        return patches
+
+    def get_params(self, image_size):
+        if self.patch_size is None:
+            patch_size = tuple(image_size[i] // self.grid_size[i] for i in range(2))
+        else:
+            patch_size = self.patch_size
+
+        steps = tuple(
+            (image_size[i] - patch_size[i]) // (self.grid_size[i] - 1) if self.grid_size[i] > 1 else image_size[i]
+            for i in range(2)
+        )
+
+        return patch_size, steps
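For reference, a minimal usage sketch of the transform added above (the shapes here are made up; with no explicit `patch_size`, the patch size is inferred from the grid):

    import torch

    from monai.apps.pathology.transforms import SplitOnGrid

    image = torch.randn(3, 256, 256)          # channel-first 2D image (C, H, W)
    splitter = SplitOnGrid(grid_size=(2, 2))  # patch size inferred as 256 // 2 = 128
    patches = splitter(image)
    print(patches.shape)                      # torch.Size([4, 3, 128, 128])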
+ """ + + def __init__( + self, + keys: KeysCollection, + grid_size: Union[int, Tuple[int, int]] = (2, 2), + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) + self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.splitter(d[key]) + return d + + +SplitOnGridDict = SplitOnGridD = SplitOnGridd diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py new file mode 100644 index 0000000000..a187835e7b --- /dev/null +++ b/tests/test_split_on_grid.py @@ -0,0 +1,131 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps.pathology.transforms import SplitOnGrid + +A11 = torch.randn(3, 2, 2) +A12 = torch.randn(3, 2, 2) +A21 = torch.randn(3, 2, 2) +A22 = torch.randn(3, 2, 2) + +A1 = torch.cat([A11, A12], 2) +A2 = torch.cat([A21, A22], 2) +A = torch.cat([A1, A2], 1) + +TEST_CASE_0 = [ + {"grid_size": (2, 2)}, + A, + torch.stack([A11, A12, A21, A22]), +] + +TEST_CASE_1 = [ + {"grid_size": (2, 1)}, + A, + torch.stack([A1, A2]), +] + +TEST_CASE_2 = [ + {"grid_size": (1, 2)}, + A1, + torch.stack([A11, A12]), +] + +TEST_CASE_3 = [ + {"grid_size": (1, 2)}, + A2, + torch.stack([A21, A22]), +] + +TEST_CASE_4 = [ + {"grid_size": (1, 1), "patch_size": (2, 2)}, + A, + torch.stack([A11]), +] + +TEST_CASE_5 = [ + {"grid_size": 1, "patch_size": 4}, + A, + torch.stack([A]), +] + +TEST_CASE_6 = [ + {"grid_size": 2, "patch_size": 2}, + A, + torch.stack([A11, A12, A21, A22]), +] + +TEST_CASE_7 = [ + {"grid_size": 1}, + A, + torch.stack([A]), +] + +TEST_CASE_MC_0 = [ + {"grid_size": (2, 2)}, + [A, A], + [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], +] + + +TEST_CASE_MC_1 = [ + {"grid_size": (2, 1)}, + [A] * 5, + [torch.stack([A1, A2])] * 5, +] + + +TEST_CASE_MC_2 = [ + {"grid_size": (1, 2)}, + [A1, A2], + [torch.stack([A11, A12]), torch.stack([A21, A22])], +] + + +class TestSplitOnGrid(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + ] + ) + def test_split_pathce_single_call(self, input_parameters, img, expected): + splitter = SplitOnGrid(**input_parameters) + output = splitter(img) + np.testing.assert_equal(output.numpy(), expected.numpy()) + + @parameterized.expand( + [ + TEST_CASE_MC_0, + TEST_CASE_MC_1, + TEST_CASE_MC_2, + ] + ) + def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + splitter = SplitOnGrid(**input_parameters) + for img, expected in zip(img_list, expected_list): + output = splitter(img) + np.testing.assert_equal(output.numpy(), expected.numpy()) + + +if __name__ == "__main__": + unittest.main() diff 
diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py
new file mode 100644
index 0000000000..96ec095423
--- /dev/null
+++ b/tests/test_split_on_grid_dict.py
@@ -0,0 +1,131 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms import SplitOnGridDict
+
+A11 = torch.randn(3, 2, 2)
+A12 = torch.randn(3, 2, 2)
+A21 = torch.randn(3, 2, 2)
+A22 = torch.randn(3, 2, 2)
+
+A1 = torch.cat([A11, A12], 2)
+A2 = torch.cat([A21, A22], 2)
+A = torch.cat([A1, A2], 1)
+
+TEST_CASE_0 = [
+    {"keys": "image", "grid_size": (2, 2)},
+    {"image": A},
+    torch.stack([A11, A12, A21, A22]),
+]
+
+TEST_CASE_1 = [
+    {"keys": "image", "grid_size": (2, 1)},
+    {"image": A},
+    torch.stack([A1, A2]),
+]
+
+TEST_CASE_2 = [
+    {"keys": "image", "grid_size": (1, 2)},
+    {"image": A1},
+    torch.stack([A11, A12]),
+]
+
+TEST_CASE_3 = [
+    {"keys": "image", "grid_size": (1, 2)},
+    {"image": A2},
+    torch.stack([A21, A22]),
+]
+
+TEST_CASE_4 = [
+    {"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)},
+    {"image": A},
+    torch.stack([A11]),
+]
+
+TEST_CASE_5 = [
+    {"keys": "image", "grid_size": 1, "patch_size": 4},
+    {"image": A},
+    torch.stack([A]),
+]
+
+TEST_CASE_6 = [
+    {"keys": "image", "grid_size": 2, "patch_size": 2},
+    {"image": A},
+    torch.stack([A11, A12, A21, A22]),
+]
+
+TEST_CASE_7 = [
+    {"keys": "image", "grid_size": 1},
+    {"image": A},
+    torch.stack([A]),
+]
+
+TEST_CASE_MC_0 = [
+    {"keys": "image", "grid_size": (2, 2)},
+    [{"image": A}, {"image": A}],
+    [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])],
+]
+
+
+TEST_CASE_MC_1 = [
+    {"keys": "image", "grid_size": (2, 1)},
+    [{"image": A}] * 5,
+    [torch.stack([A1, A2])] * 5,
+]
+
+
+TEST_CASE_MC_2 = [
+    {"keys": "image", "grid_size": (1, 2)},
+    [{"image": A1}, {"image": A2}],
+    [torch.stack([A11, A12]), torch.stack([A21, A22])],
+]
+
+
+class TestSplitOnGridDict(unittest.TestCase):
+    @parameterized.expand(
+        [
+            TEST_CASE_0,
+            TEST_CASE_1,
+            TEST_CASE_2,
+            TEST_CASE_3,
+            TEST_CASE_4,
+            TEST_CASE_5,
+            TEST_CASE_6,
+            TEST_CASE_7,
+        ]
+    )
+    def test_split_patch_single_call(self, input_parameters, img_dict, expected):
+        splitter = SplitOnGridDict(**input_parameters)
+        output = splitter(img_dict)[input_parameters["keys"]]
+        np.testing.assert_equal(output.numpy(), expected.numpy())
+
+    @parameterized.expand(
+        [
+            TEST_CASE_MC_0,
+            TEST_CASE_MC_1,
+            TEST_CASE_MC_2,
+        ]
+    )
+    def test_split_patch_multiple_call(self, input_parameters, img_list, expected_list):
+        splitter = SplitOnGridDict(**input_parameters)
+        for img_dict, expected in zip(img_list, expected_list):
+            output = splitter(img_dict)[input_parameters["keys"]]
+            np.testing.assert_equal(output.numpy(), expected.numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
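One property of the `get_params` step formula above is worth noting before moving on: when the requested grid of patches cannot tile the image exactly, neighboring patches overlap rather than fail. A small sketch of that behavior (the sizes here are made up):

    import torch

    from monai.apps.pathology.transforms import SplitOnGrid

    image = torch.arange(2 * 5 * 5, dtype=torch.float32).reshape(2, 5, 5)
    # A 2x2 grid of 3x3 patches on a 5x5 image: step = (5 - 3) // (2 - 1) = 2,
    # so horizontally adjacent patches share one column.
    patches = SplitOnGrid(grid_size=2, patch_size=3)(image)
    print(patches.shape)  # torch.Size([4, 2, 3, 3])
    print(torch.equal(patches[0][:, :, 2], patches[1][:, :, 0]))  # True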
From d257e6968f54a3effd71452fbcd7c5a8d8337fdb Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Mon, 6 Sep 2021 23:59:38 +0800
Subject: [PATCH 5/7] 2885 2899 Add support for subclasses of Tensor or numpy
 array (#2900)

* [DLMED] fix type issue

Signed-off-by: Nic Ma

* [DLMED] fix test

Signed-off-by: Nic Ma

* [DLMED] simplify the change

Signed-off-by: Nic Ma

* [DLMED] fix flake8

Signed-off-by: Nic Ma
---
 monai/utils/type_conversion.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py
index 14300eeca0..b0ce187e38 100644
--- a/monai/utils/type_conversion.py
+++ b/monai/utils/type_conversion.py
@@ -171,7 +171,14 @@ def convert_data_type(
     Returns:
         modified data, orig_type, orig_device
     """
-    orig_type = type(data)
+    orig_type: Any
+    if isinstance(data, torch.Tensor):
+        orig_type = torch.Tensor
+    elif isinstance(data, np.ndarray):
+        orig_type = np.ndarray
+    else:
+        orig_type = type(data)
+
     orig_device = data.device if isinstance(data, torch.Tensor) else None
 
     output_type = output_type or orig_type

From 76f490d3a7589d18b10a88c0bbfd0e57018250b9 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Mon, 6 Sep 2021 17:55:09 +0100
Subject: [PATCH 6/7] fixes resnet type

Signed-off-by: Wenqi Li
---
 monai/networks/nets/resnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index 298276f602..4dcb86c7a3 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -310,7 +310,7 @@ def _resnet(
     progress: bool,
     **kwargs: Any,
 ) -> ResNet:
-    model = ResNet(block, layers, block_inplanes, **kwargs)
+    model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
     if pretrained:
         # Author of paper zipped the state_dict on googledrive,
         # so would need to download, unzip and read (2.8gb file for a ~150mb state dict).

From 9daba91d10b5ba7c8c4a0c7db049220ac0eb3f01 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Mon, 6 Sep 2021 18:28:22 +0100
Subject: [PATCH 7/7] since 0.6

Signed-off-by: Wenqi Li
---
 monai/networks/nets/netadapter.py     | 2 +-
 monai/networks/nets/resnet.py         | 2 +-
 monai/networks/nets/torchvision_fc.py | 4 ++--
 monai/transforms/post/array.py        | 4 ++--
 monai/transforms/post/dictionary.py   | 2 +-
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py
index 0a57bf0780..80288f7945 100644
--- a/monai/networks/nets/netadapter.py
+++ b/monai/networks/nets/netadapter.py
@@ -39,7 +39,7 @@ class NetAdapter(torch.nn.Module):
 
     """
 
-    @deprecated_arg("n_classes")
+    @deprecated_arg("n_classes", since="0.6")
     def __init__(
         self,
         model: torch.nn.Module,
diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index 4dcb86c7a3..a5e6b7ab81 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -167,7 +167,7 @@ class ResNet(nn.Module):
         num_classes: number of output (classifications)
     """
 
-    @deprecated_arg("n_classes")
+    @deprecated_arg("n_classes", since="0.6")
     def __init__(
         self,
         block: Type[Union[ResNetBlock, ResNetBottleneck]],
diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py
index cc1f422c17..1619f877e7 100644
--- a/monai/networks/nets/torchvision_fc.py
+++ b/monai/networks/nets/torchvision_fc.py
@@ -41,7 +41,7 @@ class TorchVisionFCModel(NetAdapter):
         pretrained: whether to use the imagenet pretrained weights. Default to False.
""" - @deprecated_arg("n_classes") + @deprecated_arg("n_classes", since="0.6") def __init__( self, model_name: str = "resnet18", @@ -92,7 +92,7 @@ class TorchVisionFullyConvModel(TorchVisionFCModel): """ - @deprecated_arg("n_classes") + @deprecated_arg("n_classes", since="0.6") def __init__( self, model_name: str = "resnet18", diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index bc22f13c30..631947025c 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -131,7 +131,7 @@ class AsDiscrete(Transform): """ - @deprecated_arg("n_classes") + @deprecated_arg("n_classes", since="0.6") def __init__( self, argmax: bool = False, @@ -152,7 +152,7 @@ def __init__( self.logit_thresh = logit_thresh self.rounding = rounding - @deprecated_arg("n_classes") + @deprecated_arg("n_classes", since="0.6") def __call__( self, img: torch.Tensor, diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 2dde3fa9c4..2fc3993e3e 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -126,7 +126,7 @@ class AsDiscreted(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`. """ - @deprecated_arg("n_classes") + @deprecated_arg("n_classes", since="0.6") def __init__( self, keys: KeysCollection,