Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [Unreleased]
* renamed model's `n_classes` to `num_classes`

## [0.6.0] - 2021-07-08
### Added
* 10 new transforms, a masked loss wrapper, and a `NetAdapter` for transfer learning
Expand Down
2 changes: 1 addition & 1 deletion monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.reduction == LossReduction.SUM.value:
return torch.sum(score) # sum over the batch and channel dims
if self.reduction == LossReduction.NONE.value:
return score # returns [N, n_classes] losses
return score # returns [N, num_classes] losses
if self.reduction == LossReduction.MEAN.value:
return torch.mean(score)
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
2 changes: 1 addition & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def compute_meandice(
the predicted output. Defaults to True.

Returns:
Dice scores per batch and per class, (shape [batch_size, n_classes]).
Dice scores per batch and per class, (shape [batch_size, num_classes]).

Raises:
ValueError: when `y_pred` and `y` have different shapes.
Expand Down
4 changes: 2 additions & 2 deletions monai/metrics/rocauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def compute_roc_auc(
y_pred_ndim = y_pred.ndimension()
y_ndim = y.ndimension()
if y_pred_ndim not in (1, 2):
raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).")
raise ValueError("Predictions should be of shape (batch_size, num_classes) or (batch_size, ).")
if y_ndim not in (1, 2):
raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).")
raise ValueError("Targets should be of shape (batch_size, num_classes) or (batch_size, ).")
if y_pred_ndim == 2 and y_pred.shape[1] == 1:
y_pred = y_pred.squeeze(dim=-1)
y_pred_ndim = 1
Expand Down
14 changes: 10 additions & 4 deletions monai/networks/nets/netadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from monai.networks.layers import Conv, get_pool_layer
from monai.utils import deprecated_arg


class NetAdapter(torch.nn.Module):
Expand All @@ -26,7 +27,7 @@ class NetAdapter(torch.nn.Module):
model: a PyTorch model, support both 2D and 3D models. typically, it can be a pretrained model in Torchvision,
like: ``resnet18``, ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, etc.
more details: https://pytorch.org/vision/stable/models.html.
n_classes: number of classes for the last classification layer. Default to 1.
num_classes: number of classes for the last classification layer. Default to 1.
dim: number of spatial dimensions, default to 2.
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
use_conv: whether use convolutional layer to replace the last layer, default to False.
Expand All @@ -38,17 +39,22 @@ class NetAdapter(torch.nn.Module):

"""

@deprecated_arg("n_classes", since="0.6")
def __init__(
self,
model: torch.nn.Module,
n_classes: int = 1,
num_classes: int = 1,
dim: int = 2,
in_channels: Optional[int] = None,
use_conv: bool = False,
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
n_classes: Optional[int] = None,
):
super().__init__()
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 1:
num_classes = n_classes
layers = list(model.children())
orig_fc = layers[-1]
in_channels_: int
Expand All @@ -74,7 +80,7 @@ def __init__(
# add 1x1 conv (it behaves like a FC layer)
self.fc = Conv[Conv.CONV, dim](
in_channels=in_channels_,
out_channels=n_classes,
out_channels=num_classes,
kernel_size=1,
bias=bias,
)
Expand All @@ -84,7 +90,7 @@ def __init__(
# replace the out_features of FC layer
self.fc = torch.nn.Linear(
in_features=in_channels_,
out_features=n_classes,
out_features=num_classes,
bias=bias,
)
self.use_conv = use_conv
Expand Down
17 changes: 12 additions & 5 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from functools import partial
from typing import Any, Callable, List, Type, Union
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
Expand All @@ -20,6 +20,8 @@

__all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

from monai.utils import deprecated_arg


def get_inplanes():
return [64, 128, 256, 512]
Expand Down Expand Up @@ -162,9 +164,10 @@ class ResNet(nn.Module):
no_max_pool: bool argument to determine if to use maxpool layer.
shortcut_type: which downsample block to use.
widen_factor: widen output for each layer.
n_classes: number of output (classifications)
num_classes: number of output (classifications)
"""

@deprecated_arg("n_classes", since="0.6")
def __init__(
self,
block: Type[Union[ResNetBlock, ResNetBottleneck]],
Expand All @@ -177,11 +180,15 @@ def __init__(
no_max_pool: bool = False,
shortcut_type: str = "B",
widen_factor: float = 1.0,
n_classes: int = 400,
num_classes: int = 400,
feed_forward: bool = True,
n_classes: Optional[int] = None,
) -> None:

super(ResNet, self).__init__()
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 400:
num_classes = n_classes

conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[Norm.BATCH, spatial_dims]
Expand Down Expand Up @@ -215,7 +222,7 @@ def __init__(
self.avgpool = avgp_type(block_avgpool[spatial_dims])

if feed_forward:
self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)
self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, conv_type):
Expand Down Expand Up @@ -303,7 +310,7 @@ def _resnet(
progress: bool,
**kwargs: Any,
) -> ResNet:
model = ResNet(block, layers, block_inplanes, **kwargs)
model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
if pretrained:
# Author of paper zipped the state_dict on googledrive,
# so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
Expand Down
24 changes: 17 additions & 7 deletions monai/networks/nets/torchvision_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any, Dict, Optional, Tuple, Union

from monai.networks.nets import NetAdapter
from monai.utils import deprecated, optional_import
from monai.utils import deprecated, deprecated_arg, optional_import

models, _ = optional_import("torchvision.models")

Expand All @@ -29,7 +29,7 @@ class TorchVisionFCModel(NetAdapter):
``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
model details: https://pytorch.org/vision/stable/models.html.
n_classes: number of classes for the last classification layer. Default to 1.
num_classes: number of classes for the last classification layer. Default to 1.
dim: number of spatial dimensions, default to 2.
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
use_conv: whether use convolutional layer to replace the last layer, default to False.
Expand All @@ -41,25 +41,30 @@ class TorchVisionFCModel(NetAdapter):
pretrained: whether to use the imagenet pretrained weights. Default to False.
"""

@deprecated_arg("n_classes", since="0.6")
def __init__(
self,
model_name: str = "resnet18",
n_classes: int = 1,
num_classes: int = 1,
dim: int = 2,
in_channels: Optional[int] = None,
use_conv: bool = False,
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
pretrained: bool = False,
n_classes: Optional[int] = None,
):
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 1:
num_classes = n_classes
model = getattr(models, model_name)(pretrained=pretrained)
# check if the model is compatible, should have a FC layer at the end
if not str(list(model.children())[-1]).startswith("Linear"):
raise ValueError(f"Model ['{model_name}'] does not have a Linear layer at the end.")

super().__init__(
model=model,
n_classes=n_classes,
num_classes=num_classes,
dim=dim,
in_channels=in_channels,
use_conv=use_conv,
Expand All @@ -77,7 +82,7 @@ class TorchVisionFullyConvModel(TorchVisionFCModel):
model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end.
``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
n_classes: number of classes for the last classification layer. Default to 1.
num_classes: number of classes for the last classification layer. Default to 1.
pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7).
pool_stride: the stride for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to 1.
pretrained: whether to use the imagenet pretrained weights. Default to False.
Expand All @@ -87,17 +92,22 @@ class TorchVisionFullyConvModel(TorchVisionFCModel):

"""

@deprecated_arg("n_classes", since="0.6")
def __init__(
self,
model_name: str = "resnet18",
n_classes: int = 1,
num_classes: int = 1,
pool_size: Union[int, Tuple[int, int]] = (7, 7),
pool_stride: Union[int, Tuple[int, int]] = 1,
pretrained: bool = False,
n_classes: Optional[int] = None,
):
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes == 1:
num_classes = n_classes
super().__init__(
model_name=model_name,
n_classes=n_classes,
num_classes=num_classes,
use_conv=True,
pool=("avg", {"kernel_size": pool_size, "stride": pool_stride}),
pretrained=pretrained,
Expand Down
28 changes: 19 additions & 9 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.networks.layers import GaussianFilter
from monai.transforms.transform import Transform
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
from monai.utils import ensure_tuple, look_up_option
from monai.utils import deprecated_arg, ensure_tuple, look_up_option

__all__ = [
"Activations",
Expand Down Expand Up @@ -120,7 +120,7 @@ class AsDiscrete(Transform):
Defaults to ``False``.
to_onehot: whether to convert input data into the one-hot format.
Defaults to ``False``.
n_classes: the number of classes to convert to One-Hot format.
num_classes: the number of classes to convert to One-Hot format.
Defaults to ``None``.
threshold_values: whether threshold the float value to int number 0 or 1.
Defaults to ``False``.
Expand All @@ -131,31 +131,38 @@ class AsDiscrete(Transform):

"""

@deprecated_arg("n_classes", since="0.6")
def __init__(
self,
argmax: bool = False,
to_onehot: bool = False,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
threshold_values: bool = False,
logit_thresh: float = 0.5,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
) -> None:
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
self.argmax = argmax
self.to_onehot = to_onehot
self.n_classes = n_classes
self.num_classes = num_classes
self.threshold_values = threshold_values
self.logit_thresh = logit_thresh
self.rounding = rounding

@deprecated_arg("n_classes", since="0.6")
def __call__(
self,
img: torch.Tensor,
argmax: Optional[bool] = None,
to_onehot: Optional[bool] = None,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
threshold_values: Optional[bool] = None,
logit_thresh: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
Expand All @@ -165,8 +172,8 @@ def __call__(
Defaults to ``self.argmax``.
to_onehot: whether to convert input data into the one-hot format.
Defaults to ``self.to_onehot``.
n_classes: the number of classes to convert to One-Hot format.
Defaults to ``self.n_classes``.
num_classes: the number of classes to convert to One-Hot format.
Defaults to ``self.num_classes``.
threshold_values: whether threshold the float value to int number 0 or 1.
Defaults to ``self.threshold_values``.
logit_thresh: the threshold value for thresholding operation..
Expand All @@ -175,13 +182,16 @@ def __call__(
available options: ["torchrounding"].

"""
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
if argmax or self.argmax:
img = torch.argmax(img, dim=0, keepdim=True)

if to_onehot or self.to_onehot:
_nclasses = self.n_classes if n_classes is None else n_classes
_nclasses = self.num_classes if num_classes is None else num_classes
if not isinstance(_nclasses, int):
raise AssertionError("One of self.n_classes or n_classes must be an integer")
raise AssertionError("One of self.num_classes or num_classes must be an integer")
img = one_hot(img, num_classes=_nclasses, dim=0)

if threshold_values or self.threshold_values:
Expand Down
Loading