diff --git a/monai/losses/dice.py b/monai/losses/dice.py index c284660cc6..24bd038b68 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Callable, Optional, Union +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -139,7 +139,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, len(input.shape))) + reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis @@ -268,23 +268,27 @@ def __init__( raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + self.include_background = include_background self.to_onehot_y = to_onehot_y self.sigmoid = sigmoid self.softmax = softmax self.other_act = other_act - w_type = Weight(w_type) - self.w_func: Callable = torch.ones_like - if w_type == Weight.SIMPLE: - self.w_func = torch.reciprocal - elif w_type == Weight.SQUARE: - self.w_func = lambda x: torch.reciprocal(x * x) + self.w_type = Weight(w_type) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + def w_func(self, grnd): + if self.w_type == Weight.SIMPLE: + return torch.reciprocal(grnd) + elif self.w_type == Weight.SQUARE: + return torch.reciprocal(grnd * grnd) + else: + return torch.ones_like(grnd) + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: @@ -325,7 +329,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, len(input.shape))) + reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis intersection = torch.sum(target * input, reduce_axis) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index da7c63e571..920661f76f 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -80,8 +80,8 @@ def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: i = logits t = target - if i.ndimension() != t.ndimension(): - raise ValueError(f"logits and target ndim must match, got logits={i.ndimension()} target={t.ndimension()}.") + if i.ndim != t.ndim: + raise ValueError(f"logits and target ndim must match, got logits={i.ndim} target={t.ndim}.") if t.shape[1] != 1 and t.shape[1] != i.shape[1]: raise ValueError( diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index b1c45a74a2..1d75b9e8cc 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Callable, Optional, Union +from typing import Callable, List, Optional, Union import torch from torch.nn.modules.loss import _Loss @@ -139,7 +139,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: g1 = 1 - g0 # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, len(input.shape))) + reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index f560526db8..b2af4fcbcd 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -10,14 +10,14 @@ # limitations under the License. import math -from typing import Sequence, Union, cast +from typing import List, Sequence, Union import torch import torch.nn.functional as F from torch import nn from torch.autograd import Function -from monai.networks.layers.convutils import gaussian_1d, same_padding +from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( PT_BEFORE_1_7, @@ -164,9 +164,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(shape) -def separable_filtering( - x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros" +def _separable_filtering_conv( + input_: torch.Tensor, + kernels: List[torch.Tensor], + pad_mode: str, + d: int, + spatial_dims: int, + paddings: List[int], + num_channels: int, ) -> torch.Tensor: + + if d < 0: + return input_ + + s = [1] * len(input_.shape) + s[d + 2] = -1 + _kernel = kernels[d].reshape(s) + + # if filter kernel is unity, don't convolve + if _kernel.numel() == 1 and _kernel[0] == 1: + return _separable_filtering_conv(input_, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels) + + _kernel = _kernel.repeat([num_channels, 1] + [1] * spatial_dims) + _padding = [0] * spatial_dims + _padding[d] = paddings[d] + conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] + + # translate padding for input to torch.nn.functional.pad + _reversed_padding_repeated_twice: List[List[int]] = [[p, p] for p in reversed(_padding)] + _sum_reversed_padding_repeated_twice: List[int] = sum(_reversed_padding_repeated_twice, []) + padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode) + + return conv_type( + input=_separable_filtering_conv(padded_input, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels), + weight=_kernel, + groups=num_channels, + ) + + +def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str = "zeros") -> torch.Tensor: """ Apply 1-D convolutions along each spatial dimension of `x`. @@ -186,36 +222,12 @@ def separable_filtering( raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") spatial_dims = len(x.shape) - 2 - _kernels = [ - torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None) - for s in ensure_tuple_rep(kernels, spatial_dims) - ] - _paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels] + _kernels = [s.float() for s in kernels] + _paddings = [(k.shape[0] - 1) // 2 for k in _kernels] n_chs = x.shape[1] + pad_mode = "constant" if mode == "zeros" else mode - def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: - if d < 0: - return input_ - s = [1] * len(input_.shape) - s[d + 2] = -1 - _kernel = kernels[d].reshape(s) - # if filter kernel is unity, don't convolve - if _kernel.numel() == 1 and _kernel[0] == 1: - return _conv(input_, d - 1) - _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims) - _padding = [0] * spatial_dims - _padding[d] = _paddings[d] - conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] - # translate padding for input to torch.nn.functional.pad - _reversed_padding_repeated_twice = [p for p in reversed(_padding) for _ in range(2)] - pad_mode = "constant" if mode == "zeros" else mode - return conv_type( - input=_conv(F.pad(input_, _reversed_padding_repeated_twice, mode=pad_mode), d - 1), - weight=_kernel, - groups=n_chs, - ) - - return _conv(x, spatial_dims - 1) + return _separable_filtering_conv(x, kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs) class SavitzkyGolayFilter(nn.Module): diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 847bfc97c2..48efe3934e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -19,8 +19,6 @@ import torch import torch.nn as nn -from monai.utils import ensure_tuple_size - __all__ = [ "one_hot", "slice_channels", @@ -50,13 +48,14 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: - shape = ensure_tuple_size(labels.shape, dim + 1, 1) - labels = labels.reshape(*shape) + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) + labels = torch.reshape(labels, shape) sh = list(labels.shape) if sh[dim] != 1: - raise AssertionError("labels should have a channel with length equals to one.") + raise AssertionError("labels should have a channel with length equal to one.") + sh[dim] = num_classes o = torch.zeros(size=sh, dtype=dtype, device=labels.device) diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 443d9a9baf..8627c6d130 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses import DiceCELoss +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (2, 2, 3), (2, 1, 3) @@ -64,6 +65,12 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, ""): loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + @SkipIfBeforePyTorchVersion((1, 7, 0)) + def test_script(self): + loss = DiceCELoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index aa4a7cbc34..ef0a51eb15 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses import DiceLoss +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -195,6 +196,12 @@ def test_input_warnings(self): loss = DiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) + def test_script(self): + loss = DiceLoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index d06e2b4c36..2d1df602c7 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -16,6 +16,7 @@ import torch.nn.functional as F from monai.losses import FocalLoss +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save class TestFocalLoss(unittest.TestCase): @@ -164,6 +165,12 @@ def test_ill_shape(self): with self.assertRaisesRegex(NotImplementedError, ""): FocalLoss()(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) + def test_script(self): + loss = FocalLoss() + test_input = torch.ones(2, 2, 8, 8) + test_script_save(loss, test_input, test_input) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index e88253ccba..06446204fb 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses import GeneralizedDiceLoss +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -178,6 +179,12 @@ def test_input_warnings(self): loss = GeneralizedDiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) + def test_script(self): + loss = GeneralizedDiceLoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index 6865b53027..295a4a6d70 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -18,6 +18,7 @@ import torch.optim as optim from monai.losses import GeneralizedWassersteinDiceLoss +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save class TestGeneralizedWassersteinDiceLoss(unittest.TestCase): @@ -215,6 +216,18 @@ def forward(self, x): # check that the predicted segmentation has improved self.assertGreater(diff_start, diff_end) + @SkipIfBeforePyTorchVersion((1, 7, 0)) + def test_script(self): + target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) + + # add another dimension corresponding to the batch (batch size = 1 here) + target = target.unsqueeze(0) + pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() + + loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode="default") + + test_script_save(loss, pred_very_good, target) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index cf8566a559..8e9482596f 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -110,5 +110,11 @@ def test_ill_opts(self): LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(pred, target) +# def test_script(self): +# input_param, input_data, _ = TEST_CASES[0] +# loss = LocalNormalizedCrossCorrelationLoss(**input_param) +# test_script_save(loss, input_data["pred"], input_data["target"]) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 722ae7cfce..9ce1734e28 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -16,6 +16,7 @@ from monai.losses import DiceLoss from monai.losses.multi_scale import MultiScaleLoss +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) @@ -55,6 +56,12 @@ def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + @SkipIfBeforePyTorchVersion((1, 7, 0)) + def test_script(self): + input_param, input_data, expected_val = TEST_CASES[0] + loss = MultiScaleLoss(**input_param) + test_script_save(loss, input_data["y_pred"], input_data["y_true"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index a1befa062d..0bc2ca2e70 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses import TverskyLoss +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -183,6 +184,12 @@ def test_input_warnings(self): loss = TverskyLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) + def test_script(self): + loss = TverskyLoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) + if __name__ == "__main__": unittest.main()