From 6058b4e20cb4575931ef68e2484f87518d396130 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 23 Feb 2021 00:35:24 +0000 Subject: [PATCH 1/6] Fixing issues preventing loss functions from being compatible with Torchscript Signed-off-by: Eric Kerfoot --- monai/losses/dice.py | 28 +++++++++---- monai/networks/utils.py | 42 +++++++++++++------ tests/test_dice_ce_loss.py | 6 +++ tests/test_dice_loss.py | 7 ++++ tests/test_generalized_dice_loss.py | 8 ++++ .../test_generalized_wasserstein_dice_loss.py | 14 +++++++ 6 files changed, 84 insertions(+), 21 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index c284660cc6..4c6bdacfa4 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Callable, Optional, Union +from typing import Callable, Optional, Union, List import numpy as np import torch @@ -139,7 +139,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, len(input.shape))) + reduce_axis: List[int] = torch.arange(2,len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis @@ -268,22 +268,32 @@ def __init__( raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + self.include_background = include_background self.to_onehot_y = to_onehot_y self.sigmoid = sigmoid self.softmax = softmax self.other_act = other_act - w_type = Weight(w_type) - self.w_func: Callable = torch.ones_like - if w_type == Weight.SIMPLE: - self.w_func = torch.reciprocal - elif w_type == Weight.SQUARE: - self.w_func = lambda x: torch.reciprocal(x * x) + self.w_type = Weight(w_type) + +# self.w_func: Callable = torch.ones_like +# if w_type == Weight.SIMPLE: +# self.w_func = torch.reciprocal +# elif w_type == Weight.SQUARE: +# self.w_func = lambda x: torch.reciprocal(x * x) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + + def w_func(self, grnd): + if self.w_type == Weight.SIMPLE: + return torch.reciprocal(grnd) + elif self.w_type == Weight.SQUARE: + return torch.reciprocal(grnd * grnd) + else: + return torch.ones_like(grnd) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -325,7 +335,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, len(input.shape))) + reduce_axis: List[int] = torch.arange(2,len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis intersection = torch.sum(target * input, reduce_axis) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 847bfc97c2..e1795df2a4 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -45,24 +45,42 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. Note that this will include the background label, thus a binary mask should be treated as having 2 classes. """ + if labels.dim() <= 0: - raise AssertionError("labels should have dim of 1 or more.") + raise ValueError("`labels` should have dim of 1 or more") + elif labels.shape[1] != 1: + raise ValueError("`labels` should have a single channel") + + oh = torch.nn.functional.one_hot(labels.long(), num_classes) + + return oh.transpose(1, -1)[..., 0] # swap class axis with channel axis (which should be 1) and drop channel axis + +# oh=oh[:,0] + +# if oh.ndim() == 3: +# return oh.permute(0,2,1) +# if oh.ndim() == 4: +# return oh.permute(0,3,1,2) +# elif oh.ndim() == 5: +# return oh.permute(0,4,1,2,3) +# else: +# raise ValueError(f"Unknown tensor format with {oh.ndim()} dimensions") - # if `dim` is bigger, add singleton dim at the end - if labels.ndim < dim + 1: - shape = ensure_tuple_size(labels.shape, dim + 1, 1) - labels = labels.reshape(*shape) +# # if `dim` is bigger, add singleton dim at the end +# if labels.ndim < dim + 1: +# shape = ensure_tuple_size(labels.shape, dim + 1, 1) +# labels = labels.reshape(*shape) - sh = list(labels.shape) +# sh = list(labels.shape) - if sh[dim] != 1: - raise AssertionError("labels should have a channel with length equals to one.") - sh[dim] = num_classes +# if sh[dim] != 1: +# raise AssertionError("labels should have a channel with length equals to one.") +# sh[dim] = num_classes - o = torch.zeros(size=sh, dtype=dtype, device=labels.device) - labels = o.scatter_(dim=dim, index=labels.long(), value=1) +# o = torch.zeros(size=sh, dtype=dtype, device=labels.device) +# labels = o.scatter_(dim=dim, index=labels.long(), value=1) - return labels +# return labels def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 443d9a9baf..dee23b1b60 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses import DiceCELoss +from tests.utils import test_script_save TEST_CASES = [ [ # shape: (2, 2, 3), (2, 1, 3) @@ -63,6 +64,11 @@ def test_ill_shape(self): loss = DiceCELoss() with self.assertRaisesRegex(ValueError, ""): loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_script(self): + loss = DiceCELoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index aa4a7cbc34..e2e86b57d0 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -16,6 +16,8 @@ from parameterized import parameterized from monai.losses import DiceLoss +from tests.utils import test_script_save + TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -194,6 +196,11 @@ def test_input_warnings(self): with self.assertWarns(Warning): loss = DiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + + def test_script(self): + loss = DiceLoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index e88253ccba..fcff51c569 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses import GeneralizedDiceLoss +from tests.utils import test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -177,6 +178,13 @@ def test_input_warnings(self): with self.assertWarns(Warning): loss = GeneralizedDiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + + def test_script(self): + loss = GeneralizedDiceLoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) + + if __name__ == "__main__": diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index 6865b53027..e8beb41b28 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -18,6 +18,7 @@ import torch.optim as optim from monai.losses import GeneralizedWassersteinDiceLoss +from tests.utils import test_script_save class TestGeneralizedWassersteinDiceLoss(unittest.TestCase): @@ -214,6 +215,19 @@ def forward(self, x): # check that the predicted segmentation has improved self.assertGreater(diff_start, diff_end) + + def test_script(self): + target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) + + # add another dimension corresponding to the batch (batch size = 1 here) + target = target.unsqueeze(0) + pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() + + loss = GeneralizedWassersteinDiceLoss( + dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode="default" + ) + + test_script_save(loss, pred_very_good, target) if __name__ == "__main__": From 3dba85c268d5a9ce3408a051ad11041b247fec3d Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 25 Feb 2021 00:08:13 +0000 Subject: [PATCH 2/6] Updates Signed-off-by: Eric Kerfoot --- monai/losses/dice.py | 4 ++-- monai/losses/focal_loss.py | 4 ++-- monai/losses/tversky.py | 2 +- tests/test_focal_loss.py | 6 ++++++ tests/test_multi_scale.py | 6 ++++++ tests/test_tversky_loss.py | 6 ++++++ 6 files changed, 23 insertions(+), 5 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 4c6bdacfa4..9bda505672 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -139,7 +139,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis: List[int] = torch.arange(2,len(input.shape)).tolist() + reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis @@ -335,7 +335,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") # reducing only spatial dimensions (not batch nor channels) - reduce_axis: List[int] = torch.arange(2,len(input.shape)).tolist() + reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis intersection = torch.sum(target * input, reduce_axis) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index da7c63e571..920661f76f 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -80,8 +80,8 @@ def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: i = logits t = target - if i.ndimension() != t.ndimension(): - raise ValueError(f"logits and target ndim must match, got logits={i.ndimension()} target={t.ndimension()}.") + if i.ndim != t.ndim: + raise ValueError(f"logits and target ndim must match, got logits={i.ndim} target={t.ndim}.") if t.shape[1] != 1 and t.shape[1] != i.shape[1]: raise ValueError( diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index b1c45a74a2..c629b040be 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -139,7 +139,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: g1 = 1 - g0 # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, len(input.shape))) + reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index d06e2b4c36..c10bbed815 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -16,6 +16,7 @@ import torch.nn.functional as F from monai.losses import FocalLoss +from tests.utils import test_script_save class TestFocalLoss(unittest.TestCase): @@ -163,6 +164,11 @@ def test_ill_shape(self): chn_target = torch.ones((1, 1, 30)) with self.assertRaisesRegex(NotImplementedError, ""): FocalLoss()(chn_input, chn_target) + + def test_script(self): + loss = FocalLoss() + test_input = torch.ones(2, 2, 8, 8) + test_script_save(loss, test_input, test_input) if __name__ == "__main__": diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 722ae7cfce..121b2be89b 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -16,6 +16,7 @@ from monai.losses import DiceLoss from monai.losses.multi_scale import MultiScaleLoss +from tests.utils import test_script_save dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) @@ -54,6 +55,11 @@ def test_ill_opts(self): MultiScaleLoss(loss=dice_loss, scales=[-1])(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + + def test_script(self): + input_param, input_data, expected_val = TEST_CASES[0] + loss = MultiScaleLoss(**input_param) + test_script_save(loss, input_data["y_pred"], input_data["y_true"]) if __name__ == "__main__": diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index a1befa062d..0847bc3a89 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses import TverskyLoss +from tests.utils import test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -182,6 +183,11 @@ def test_input_warnings(self): with self.assertWarns(Warning): loss = TverskyLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + + def test_script(self): + loss = TverskyLoss() + test_input = torch.ones(2, 1, 8, 8) + test_script_save(loss, test_input, test_input) if __name__ == "__main__": From b65e8e64decadba9e29bcc44b021cae8a84e139a Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 25 Feb 2021 00:46:50 +0000 Subject: [PATCH 3/6] Updates Signed-off-by: Eric Kerfoot --- monai/networks/utils.py | 43 +++++++++++++---------------------------- 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index e1795df2a4..18382ea739 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -45,42 +45,25 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. Note that this will include the background label, thus a binary mask should be treated as having 2 classes. """ - if labels.dim() <= 0: - raise ValueError("`labels` should have dim of 1 or more") - elif labels.shape[1] != 1: - raise ValueError("`labels` should have a single channel") - - oh = torch.nn.functional.one_hot(labels.long(), num_classes) - - return oh.transpose(1, -1)[..., 0] # swap class axis with channel axis (which should be 1) and drop channel axis - -# oh=oh[:,0] - -# if oh.ndim() == 3: -# return oh.permute(0,2,1) -# if oh.ndim() == 4: -# return oh.permute(0,3,1,2) -# elif oh.ndim() == 5: -# return oh.permute(0,4,1,2,3) -# else: -# raise ValueError(f"Unknown tensor format with {oh.ndim()} dimensions") + raise AssertionError("labels should have dim of 1 or more.") -# # if `dim` is bigger, add singleton dim at the end -# if labels.ndim < dim + 1: -# shape = ensure_tuple_size(labels.shape, dim + 1, 1) -# labels = labels.reshape(*shape) + # if `dim` is bigger, add singleton dim at the end + if labels.ndim < dim + 1: + shape=list(labels.shape)+[1]*(dim+1-len(labels.shape)) + labels = torch.reshape(labels, shape) -# sh = list(labels.shape) + sh = list(labels.shape) -# if sh[dim] != 1: -# raise AssertionError("labels should have a channel with length equals to one.") -# sh[dim] = num_classes + if sh[dim] != 1: + raise AssertionError("labels should have a channel with length equal to one.") + + sh[dim] = num_classes -# o = torch.zeros(size=sh, dtype=dtype, device=labels.device) -# labels = o.scatter_(dim=dim, index=labels.long(), value=1) + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) + labels = o.scatter_(dim=dim, index=labels.long(), value=1) -# return labels + return labels def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: From 6042180d7fcb35f8f858a99ac4563e8ab85c4424 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 25 Feb 2021 13:29:22 +0000 Subject: [PATCH 4/6] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/losses/dice.py | 16 ++++++++-------- monai/networks/utils.py | 4 ++-- tests/test_dice_ce_loss.py | 2 +- tests/test_dice_loss.py | 3 +-- tests/test_focal_loss.py | 2 +- tests/test_generalized_dice_loss.py | 4 +--- tests/test_generalized_wasserstein_dice_loss.py | 8 +++----- tests/test_multi_scale.py | 2 +- tests/test_tversky_loss.py | 2 +- 9 files changed, 19 insertions(+), 24 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 9bda505672..1c1808270a 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Callable, Optional, Union, List +from typing import Callable, List, Optional, Union import numpy as np import torch @@ -276,17 +276,17 @@ def __init__( self.other_act = other_act self.w_type = Weight(w_type) - -# self.w_func: Callable = torch.ones_like -# if w_type == Weight.SIMPLE: -# self.w_func = torch.reciprocal -# elif w_type == Weight.SQUARE: -# self.w_func = lambda x: torch.reciprocal(x * x) + + # self.w_func: Callable = torch.ones_like + # if w_type == Weight.SIMPLE: + # self.w_func = torch.reciprocal + # elif w_type == Weight.SQUARE: + # self.w_func = lambda x: torch.reciprocal(x * x) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch - + def w_func(self, grnd): if self.w_type == Weight.SIMPLE: return torch.reciprocal(grnd) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 18382ea739..a418e7cf1d 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -50,14 +50,14 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: - shape=list(labels.shape)+[1]*(dim+1-len(labels.shape)) + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) labels = torch.reshape(labels, shape) sh = list(labels.shape) if sh[dim] != 1: raise AssertionError("labels should have a channel with length equal to one.") - + sh[dim] = num_classes o = torch.zeros(size=sh, dtype=dtype, device=labels.device) diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index dee23b1b60..24a4c0998b 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -64,7 +64,7 @@ def test_ill_shape(self): loss = DiceCELoss() with self.assertRaisesRegex(ValueError, ""): loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) - + def test_script(self): loss = DiceCELoss() test_input = torch.ones(2, 1, 8, 8) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index e2e86b57d0..4efdfd4454 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -18,7 +18,6 @@ from monai.losses import DiceLoss from tests.utils import test_script_save - TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, @@ -196,7 +195,7 @@ def test_input_warnings(self): with self.assertWarns(Warning): loss = DiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) - + def test_script(self): loss = DiceLoss() test_input = torch.ones(2, 1, 8, 8) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index c10bbed815..a1a5b5edbe 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -164,7 +164,7 @@ def test_ill_shape(self): chn_target = torch.ones((1, 1, 30)) with self.assertRaisesRegex(NotImplementedError, ""): FocalLoss()(chn_input, chn_target) - + def test_script(self): loss = FocalLoss() test_input = torch.ones(2, 2, 8, 8) diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index fcff51c569..f6b778f264 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -178,13 +178,11 @@ def test_input_warnings(self): with self.assertWarns(Warning): loss = GeneralizedDiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) - + def test_script(self): loss = GeneralizedDiceLoss() test_input = torch.ones(2, 1, 8, 8) test_script_save(loss, test_input, test_input) - - if __name__ == "__main__": diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index e8beb41b28..6a960bf495 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -215,7 +215,7 @@ def forward(self, x): # check that the predicted segmentation has improved self.assertGreater(diff_start, diff_end) - + def test_script(self): target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) @@ -223,10 +223,8 @@ def test_script(self): target = target.unsqueeze(0) pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - loss = GeneralizedWassersteinDiceLoss( - dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode="default" - ) - + loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode="default") + test_script_save(loss, pred_very_good, target) diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 121b2be89b..55bcdf8091 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -55,7 +55,7 @@ def test_ill_opts(self): MultiScaleLoss(loss=dice_loss, scales=[-1])(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) - + def test_script(self): input_param, input_data, expected_val = TEST_CASES[0] loss = MultiScaleLoss(**input_param) diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 0847bc3a89..2a791f9b3d 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -183,7 +183,7 @@ def test_input_warnings(self): with self.assertWarns(Warning): loss = TverskyLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) - + def test_script(self): loss = TverskyLoss() test_input = torch.ones(2, 1, 8, 8) From 0b56f2f7403c74f4ac6ca6c277750dfdd4145ee8 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 25 Feb 2021 14:22:21 +0000 Subject: [PATCH 5/6] Updates Signed-off-by: Eric Kerfoot --- monai/losses/dice.py | 6 -- monai/losses/tversky.py | 2 +- monai/networks/layers/simplelayers.py | 76 +++++++++++-------- monai/networks/utils.py | 2 - ...local_normalized_cross_correlation_loss.py | 6 ++ 5 files changed, 51 insertions(+), 41 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 1c1808270a..24bd038b68 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -277,12 +277,6 @@ def __init__( self.w_type = Weight(w_type) - # self.w_func: Callable = torch.ones_like - # if w_type == Weight.SIMPLE: - # self.w_func = torch.reciprocal - # elif w_type == Weight.SQUARE: - # self.w_func = lambda x: torch.reciprocal(x * x) - self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index c629b040be..1d75b9e8cc 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Callable, Optional, Union +from typing import Callable, List, Optional, Union import torch from torch.nn.modules.loss import _Loss diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index f560526db8..b2af4fcbcd 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -10,14 +10,14 @@ # limitations under the License. import math -from typing import Sequence, Union, cast +from typing import List, Sequence, Union import torch import torch.nn.functional as F from torch import nn from torch.autograd import Function -from monai.networks.layers.convutils import gaussian_1d, same_padding +from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( PT_BEFORE_1_7, @@ -164,9 +164,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(shape) -def separable_filtering( - x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros" +def _separable_filtering_conv( + input_: torch.Tensor, + kernels: List[torch.Tensor], + pad_mode: str, + d: int, + spatial_dims: int, + paddings: List[int], + num_channels: int, ) -> torch.Tensor: + + if d < 0: + return input_ + + s = [1] * len(input_.shape) + s[d + 2] = -1 + _kernel = kernels[d].reshape(s) + + # if filter kernel is unity, don't convolve + if _kernel.numel() == 1 and _kernel[0] == 1: + return _separable_filtering_conv(input_, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels) + + _kernel = _kernel.repeat([num_channels, 1] + [1] * spatial_dims) + _padding = [0] * spatial_dims + _padding[d] = paddings[d] + conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] + + # translate padding for input to torch.nn.functional.pad + _reversed_padding_repeated_twice: List[List[int]] = [[p, p] for p in reversed(_padding)] + _sum_reversed_padding_repeated_twice: List[int] = sum(_reversed_padding_repeated_twice, []) + padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode) + + return conv_type( + input=_separable_filtering_conv(padded_input, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels), + weight=_kernel, + groups=num_channels, + ) + + +def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str = "zeros") -> torch.Tensor: """ Apply 1-D convolutions along each spatial dimension of `x`. @@ -186,36 +222,12 @@ def separable_filtering( raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") spatial_dims = len(x.shape) - 2 - _kernels = [ - torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None) - for s in ensure_tuple_rep(kernels, spatial_dims) - ] - _paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels] + _kernels = [s.float() for s in kernels] + _paddings = [(k.shape[0] - 1) // 2 for k in _kernels] n_chs = x.shape[1] + pad_mode = "constant" if mode == "zeros" else mode - def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: - if d < 0: - return input_ - s = [1] * len(input_.shape) - s[d + 2] = -1 - _kernel = kernels[d].reshape(s) - # if filter kernel is unity, don't convolve - if _kernel.numel() == 1 and _kernel[0] == 1: - return _conv(input_, d - 1) - _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims) - _padding = [0] * spatial_dims - _padding[d] = _paddings[d] - conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] - # translate padding for input to torch.nn.functional.pad - _reversed_padding_repeated_twice = [p for p in reversed(_padding) for _ in range(2)] - pad_mode = "constant" if mode == "zeros" else mode - return conv_type( - input=_conv(F.pad(input_, _reversed_padding_repeated_twice, mode=pad_mode), d - 1), - weight=_kernel, - groups=n_chs, - ) - - return _conv(x, spatial_dims - 1) + return _separable_filtering_conv(x, kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs) class SavitzkyGolayFilter(nn.Module): diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a418e7cf1d..48efe3934e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -19,8 +19,6 @@ import torch import torch.nn as nn -from monai.utils import ensure_tuple_size - __all__ = [ "one_hot", "slice_channels", diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index cf8566a559..8e9482596f 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -110,5 +110,11 @@ def test_ill_opts(self): LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(pred, target) +# def test_script(self): +# input_param, input_data, _ = TEST_CASES[0] +# loss = LocalNormalizedCrossCorrelationLoss(**input_param) +# test_script_save(loss, input_data["pred"], input_data["target"]) + + if __name__ == "__main__": unittest.main() From ce57a9f030294b376a93ec6d433532ecf41930ad Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 25 Feb 2021 15:19:44 +0000 Subject: [PATCH 6/6] Adding conditional skip to Torchscript tests Signed-off-by: Eric Kerfoot --- tests/test_dice_ce_loss.py | 3 ++- tests/test_dice_loss.py | 3 ++- tests/test_focal_loss.py | 3 ++- tests/test_generalized_dice_loss.py | 3 ++- tests/test_generalized_wasserstein_dice_loss.py | 3 ++- tests/test_multi_scale.py | 3 ++- tests/test_tversky_loss.py | 3 ++- 7 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 24a4c0998b..8627c6d130 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.losses import DiceCELoss -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (2, 2, 3), (2, 1, 3) @@ -65,6 +65,7 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, ""): loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): loss = DiceCELoss() test_input = torch.ones(2, 1, 8, 8) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 4efdfd4454..ef0a51eb15 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.losses import DiceLoss -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -196,6 +196,7 @@ def test_input_warnings(self): loss = DiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): loss = DiceLoss() test_input = torch.ones(2, 1, 8, 8) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index a1a5b5edbe..2d1df602c7 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from monai.losses import FocalLoss -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save class TestFocalLoss(unittest.TestCase): @@ -165,6 +165,7 @@ def test_ill_shape(self): with self.assertRaisesRegex(NotImplementedError, ""): FocalLoss()(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): loss = FocalLoss() test_input = torch.ones(2, 2, 8, 8) diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index f6b778f264..06446204fb 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.losses import GeneralizedDiceLoss -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -179,6 +179,7 @@ def test_input_warnings(self): loss = GeneralizedDiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): loss = GeneralizedDiceLoss() test_input = torch.ones(2, 1, 8, 8) diff --git a/tests/test_generalized_wasserstein_dice_loss.py b/tests/test_generalized_wasserstein_dice_loss.py index 6a960bf495..295a4a6d70 100644 --- a/tests/test_generalized_wasserstein_dice_loss.py +++ b/tests/test_generalized_wasserstein_dice_loss.py @@ -18,7 +18,7 @@ import torch.optim as optim from monai.losses import GeneralizedWassersteinDiceLoss -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save class TestGeneralizedWassersteinDiceLoss(unittest.TestCase): @@ -216,6 +216,7 @@ def forward(self, x): # check that the predicted segmentation has improved self.assertGreater(diff_start, diff_end) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 55bcdf8091..9ce1734e28 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -16,7 +16,7 @@ from monai.losses import DiceLoss from monai.losses.multi_scale import MultiScaleLoss -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) @@ -56,6 +56,7 @@ def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): input_param, input_data, expected_val = TEST_CASES[0] loss = MultiScaleLoss(**input_param) diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 2a791f9b3d..0bc2ca2e70 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.losses import TverskyLoss -from tests.utils import test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, test_script_save TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) @@ -184,6 +184,7 @@ def test_input_warnings(self): loss = TverskyLoss(to_onehot_y=True) loss.forward(chn_input, chn_target) + @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self): loss = TverskyLoss() test_input = torch.ones(2, 1, 8, 8)