22 changes: 13 additions & 9 deletions monai/losses/dice.py
@@ -10,7 +10,7 @@
 # limitations under the License.

 import warnings
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union

 import numpy as np
 import torch
@@ -139,7 +139,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

         # reducing only spatial dimensions (not batch nor channels)
-        reduce_axis = list(range(2, len(input.shape)))
+        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
@@ -268,23 +268,27 @@ def __init__(
             raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
         if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
             raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
+
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.sigmoid = sigmoid
         self.softmax = softmax
         self.other_act = other_act

-        w_type = Weight(w_type)
-        self.w_func: Callable = torch.ones_like
-        if w_type == Weight.SIMPLE:
-            self.w_func = torch.reciprocal
-        elif w_type == Weight.SQUARE:
-            self.w_func = lambda x: torch.reciprocal(x * x)
+        self.w_type = Weight(w_type)

         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch

+    def w_func(self, grnd):
+        if self.w_type == Weight.SIMPLE:
+            return torch.reciprocal(grnd)
+        elif self.w_type == Weight.SQUARE:
+            return torch.reciprocal(grnd * grnd)
+        else:
+            return torch.ones_like(grnd)
+
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
@@ -325,7 +329,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

         # reducing only spatial dimensions (not batch nor channels)
-        reduce_axis = list(range(2, len(input.shape)))
+        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
         intersection = torch.sum(target * input, reduce_axis)
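The dice.py changes are driven by TorchScript compatibility: torch.jit.script cannot compile a lambda stored on self, and the diff also swaps list(range(...)) for torch.arange(...).tolist() under an explicit List[int] annotation, a form the scripting compiler accepts as of PyTorch 1.7. A minimal sketch of the pattern (ToyLoss is hypothetical, not part of this pull request):

    from typing import List

    import torch
    from torch import nn


    class ToyLoss(nn.Module):
        def w_func(self, grnd: torch.Tensor) -> torch.Tensor:
            # a plain method dispatching on module state is scriptable; a lambda on self is not
            return torch.reciprocal(grnd * grnd)

        def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # reduce over spatial dimensions only (not batch, not channel)
            reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
            w = self.w_func(torch.sum(target, dim=reduce_axis))
            return torch.sum(w * torch.sum(target * input, dim=reduce_axis))


    scripted = torch.jit.script(ToyLoss())  # compiles on PyTorch >= 1.7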
4 changes: 2 additions & 2 deletions monai/losses/focal_loss.py
@@ -80,8 +80,8 @@ def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         i = logits
         t = target

-        if i.ndimension() != t.ndimension():
-            raise ValueError(f"logits and target ndim must match, got logits={i.ndimension()} target={t.ndimension()}.")
+        if i.ndim != t.ndim:
+            raise ValueError(f"logits and target ndim must match, got logits={i.ndim} target={t.ndim}.")

         if t.shape[1] != 1 and t.shape[1] != i.shape[1]:
             raise ValueError(
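Tensor.ndim is a property alias for Tensor.ndimension(); the two checks are equivalent, and the property form is the more idiomatic spelling in recent PyTorch:

    import torch

    t = torch.zeros(2, 3, 4)
    assert t.ndim == t.ndimension() == 3  # property and method agree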
4 changes: 2 additions & 2 deletions monai/losses/tversky.py
@@ -10,7 +10,7 @@
 # limitations under the License.

 import warnings
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union

 import torch
 from torch.nn.modules.loss import _Loss
@@ -139,7 +139,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         g1 = 1 - g0

         # reducing only spatial dimensions (not batch nor channels)
-        reduce_axis = list(range(2, len(input.shape)))
+        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
76 changes: 44 additions & 32 deletions monai/networks/layers/simplelayers.py
@@ -10,14 +10,14 @@
 # limitations under the License.

 import math
-from typing import Sequence, Union, cast
+from typing import List, Sequence, Union

 import torch
 import torch.nn.functional as F
 from torch import nn
 from torch.autograd import Function

-from monai.networks.layers.convutils import gaussian_1d, same_padding
+from monai.networks.layers.convutils import gaussian_1d
 from monai.networks.layers.factories import Conv
 from monai.utils import (
     PT_BEFORE_1_7,
@@ -164,9 +164,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.reshape(shape)


-def separable_filtering(
-    x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros"
+def _separable_filtering_conv(
+    input_: torch.Tensor,
+    kernels: List[torch.Tensor],
+    pad_mode: str,
+    d: int,
+    spatial_dims: int,
+    paddings: List[int],
+    num_channels: int,
 ) -> torch.Tensor:
+
+    if d < 0:
+        return input_
+
+    s = [1] * len(input_.shape)
+    s[d + 2] = -1
+    _kernel = kernels[d].reshape(s)
+
+    # if filter kernel is unity, don't convolve
+    if _kernel.numel() == 1 and _kernel[0] == 1:
+        return _separable_filtering_conv(input_, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels)
+
+    _kernel = _kernel.repeat([num_channels, 1] + [1] * spatial_dims)
+    _padding = [0] * spatial_dims
+    _padding[d] = paddings[d]
+    conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
+
+    # translate padding for input to torch.nn.functional.pad
+    _reversed_padding_repeated_twice: List[List[int]] = [[p, p] for p in reversed(_padding)]
+    _sum_reversed_padding_repeated_twice: List[int] = sum(_reversed_padding_repeated_twice, [])
+    padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode)
+
+    return conv_type(
+        input=_separable_filtering_conv(padded_input, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels),
+        weight=_kernel,
+        groups=num_channels,
+    )
+
+
+def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str = "zeros") -> torch.Tensor:
     """
     Apply 1-D convolutions along each spatial dimension of `x`.

@@ -186,36 +222,12 @@ def separable_filtering(
         raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

     spatial_dims = len(x.shape) - 2
-    _kernels = [
-        torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None)
-        for s in ensure_tuple_rep(kernels, spatial_dims)
-    ]
-    _paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels]
+    _kernels = [s.float() for s in kernels]
+    _paddings = [(k.shape[0] - 1) // 2 for k in _kernels]
     n_chs = x.shape[1]
+    pad_mode = "constant" if mode == "zeros" else mode

-    def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
-        if d < 0:
-            return input_
-        s = [1] * len(input_.shape)
-        s[d + 2] = -1
-        _kernel = kernels[d].reshape(s)
-        # if filter kernel is unity, don't convolve
-        if _kernel.numel() == 1 and _kernel[0] == 1:
-            return _conv(input_, d - 1)
-        _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims)
-        _padding = [0] * spatial_dims
-        _padding[d] = _paddings[d]
-        conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
-        # translate padding for input to torch.nn.functional.pad
-        _reversed_padding_repeated_twice = [p for p in reversed(_padding) for _ in range(2)]
-        pad_mode = "constant" if mode == "zeros" else mode
-        return conv_type(
-            input=_conv(F.pad(input_, _reversed_padding_repeated_twice, mode=pad_mode), d - 1),
-            weight=_kernel,
-            groups=n_chs,
-        )
-
-    return _conv(x, spatial_dims - 1)
+    return _separable_filtering_conv(x, kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs)


 class SavitzkyGolayFilter(nn.Module):
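TorchScript cannot compile closures, so the recursive inner _conv is hoisted to module scope as _separable_filtering_conv, with everything it previously captured (kernels, pad_mode, the padding list, the channel count) passed explicitly, and the same_padding call is inlined as (k.shape[0] - 1) // 2. A usage sketch of the refactored separable_filtering; the box kernel here is a stand-in for MONAI's gaussian_1d output:

    import torch
    from monai.networks.layers.simplelayers import separable_filtering

    x = torch.rand(1, 3, 32, 32)          # a batch of 3-channel 2-D images
    kernel = torch.ones(5) / 5.0          # length-5 1-D box filter
    smoothed = separable_filtering(x, [kernel, kernel])  # filter H, then W
    assert smoothed.shape == x.shape      # padding (5 - 1) // 2 = 2 preserves the size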
9 changes: 4 additions & 5 deletions monai/networks/utils.py
@@ -19,8 +19,6 @@
 import torch
 import torch.nn as nn

-from monai.utils import ensure_tuple_size
-
 __all__ = [
     "one_hot",
     "slice_channels",
@@ -50,13 +48,14 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float

     # if `dim` is bigger, add singleton dim at the end
     if labels.ndim < dim + 1:
-        shape = ensure_tuple_size(labels.shape, dim + 1, 1)
-        labels = labels.reshape(*shape)
+        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
+        labels = torch.reshape(labels, shape)

     sh = list(labels.shape)

     if sh[dim] != 1:
-        raise AssertionError("labels should have a channel with length equals to one.")
+        raise AssertionError("labels should have a channel with length equal to one.")

     sh[dim] = num_classes

     o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
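In one_hot, the ensure_tuple_size helper is replaced by plain list arithmetic that the scripting compiler can follow: trailing singleton dimensions are appended until the label tensor has dim + 1 of them. With the default channel dimension dim=1:

    import torch

    labels = torch.tensor([1, 0, 2])  # shape (3,): no channel dimension yet
    dim = 1
    if labels.ndim < dim + 1:
        # (3,) -> (3, 1): append 1s until there are dim + 1 dimensions
        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
        labels = torch.reshape(labels, shape)
    assert labels.shape == (3, 1)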
7 changes: 7 additions & 0 deletions tests/test_dice_ce_loss.py
@@ -16,6 +16,7 @@
 from parameterized import parameterized

 from monai.losses import DiceCELoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

 TEST_CASES = [
     [  # shape: (2, 2, 3), (2, 1, 3)
@@ -64,6 +65,12 @@ def test_ill_shape(self):
         with self.assertRaisesRegex(ValueError, ""):
             loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

+    @SkipIfBeforePyTorchVersion((1, 7, 0))
+    def test_script(self):
+        loss = DiceCELoss()
+        test_input = torch.ones(2, 1, 8, 8)
+        test_script_save(loss, test_input, test_input)
+

 if __name__ == "__main__":
     unittest.main()
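Each loss gains a test_script case, skipped below PyTorch 1.7. test_script_save comes from tests/utils.py; assuming it performs the usual script/save/reload round-trip, the check is roughly:

    import torch
    from monai.losses import DiceCELoss

    loss = DiceCELoss()
    test_input = torch.ones(2, 1, 8, 8)

    scripted = torch.jit.script(loss)        # compile the module to TorchScript
    torch.jit.save(scripted, "dice_ce.pt")   # serialize ...
    reloaded = torch.jit.load("dice_ce.pt")  # ... and load it back

    # eager and scripted/reloaded outputs should agree
    assert torch.allclose(loss(test_input, test_input), reloaded(test_input, test_input))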
7 changes: 7 additions & 0 deletions tests/test_dice_loss.py
@@ -16,6 +16,7 @@
 from parameterized import parameterized

 from monai.losses import DiceLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

 TEST_CASES = [
     [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)
@@ -195,6 +196,12 @@ def test_input_warnings(self):
             loss = DiceLoss(to_onehot_y=True)
             loss.forward(chn_input, chn_target)

+    @SkipIfBeforePyTorchVersion((1, 7, 0))
+    def test_script(self):
+        loss = DiceLoss()
+        test_input = torch.ones(2, 1, 8, 8)
+        test_script_save(loss, test_input, test_input)
+

 if __name__ == "__main__":
     unittest.main()
7 changes: 7 additions & 0 deletions tests/test_focal_loss.py
@@ -16,6 +16,7 @@
 import torch.nn.functional as F

 from monai.losses import FocalLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save


 class TestFocalLoss(unittest.TestCase):
@@ -164,6 +165,12 @@ def test_ill_shape(self):
         with self.assertRaisesRegex(NotImplementedError, ""):
             FocalLoss()(chn_input, chn_target)

+    @SkipIfBeforePyTorchVersion((1, 7, 0))
+    def test_script(self):
+        loss = FocalLoss()
+        test_input = torch.ones(2, 2, 8, 8)
+        test_script_save(loss, test_input, test_input)
+

 if __name__ == "__main__":
     unittest.main()
7 changes: 7 additions & 0 deletions tests/test_generalized_dice_loss.py
@@ -16,6 +16,7 @@
 from parameterized import parameterized

 from monai.losses import GeneralizedDiceLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

 TEST_CASES = [
     [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)
@@ -178,6 +179,12 @@ def test_input_warnings(self):
             loss = GeneralizedDiceLoss(to_onehot_y=True)
             loss.forward(chn_input, chn_target)

+    @SkipIfBeforePyTorchVersion((1, 7, 0))
+    def test_script(self):
+        loss = GeneralizedDiceLoss()
+        test_input = torch.ones(2, 1, 8, 8)
+        test_script_save(loss, test_input, test_input)
+

 if __name__ == "__main__":
     unittest.main()
13 changes: 13 additions & 0 deletions tests/test_generalized_wasserstein_dice_loss.py
@@ -18,6 +18,7 @@
 import torch.optim as optim

 from monai.losses import GeneralizedWassersteinDiceLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save


 class TestGeneralizedWassersteinDiceLoss(unittest.TestCase):
@@ -215,6 +216,18 @@ def forward(self, x):
         # check that the predicted segmentation has improved
         self.assertGreater(diff_start, diff_end)

+    @SkipIfBeforePyTorchVersion((1, 7, 0))
+    def test_script(self):
+        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+
+        # add another dimension corresponding to the batch (batch size = 1 here)
+        target = target.unsqueeze(0)
+        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
+
+        loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode="default")
+
+        test_script_save(loss, pred_very_good, target)
+

 if __name__ == "__main__":
     unittest.main()
6 changes: 6 additions & 0 deletions tests/test_local_normalized_cross_correlation_loss.py
@@ -110,5 +110,11 @@ def test_ill_opts(self):
             LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(pred, target)


+# def test_script(self):
+#     input_param, input_data, _ = TEST_CASES[0]
+#     loss = LocalNormalizedCrossCorrelationLoss(**input_param)
+#     test_script_save(loss, input_data["pred"], input_data["target"])
+
+
 if __name__ == "__main__":
     unittest.main()
7 changes: 7 additions & 0 deletions tests/test_multi_scale.py
@@ -16,6 +16,7 @@

 from monai.losses import DiceLoss
 from monai.losses.multi_scale import MultiScaleLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

 dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5)

@@ -55,6 +56,12 @@ def test_ill_opts(self):
         with self.assertRaisesRegex(ValueError, ""):
             MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3)))

+    @SkipIfBeforePyTorchVersion((1, 7, 0))
+    def test_script(self):
+        input_param, input_data, expected_val = TEST_CASES[0]
+        loss = MultiScaleLoss(**input_param)
+        test_script_save(loss, input_data["y_pred"], input_data["y_true"])
+

 if __name__ == "__main__":
     unittest.main()
7 changes: 7 additions & 0 deletions tests/test_tversky_loss.py
@@ -16,6 +16,7 @@
 from parameterized import parameterized

 from monai.losses import TverskyLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

 TEST_CASES = [
     [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)
@@ -183,6 +184,12 @@ def test_input_warnings(self):
             loss = TverskyLoss(to_onehot_y=True)
             loss.forward(chn_input, chn_target)

+    @SkipIfBeforePyTorchVersion((1, 7, 0))
+    def test_script(self):
+        loss = TverskyLoss()
+        test_input = torch.ones(2, 1, 8, 8)
+        test_script_save(loss, test_input, test_input)
+

 if __name__ == "__main__":
     unittest.main()
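Each of these test modules guards a unittest.main() call, so the new scripting cases can be run per-file from the repository root (e.g. python -m unittest tests.test_tversky_loss) or as part of the full suite.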