From 4690a3c66b8e5042c1cf4dd256fab0c10d63bfe0 Mon Sep 17 00:00:00 2001 From: Imad Toubal Date: Sun, 17 Sep 2023 14:48:39 -0500 Subject: [PATCH 1/8] Add hausdorff loss Integrating an existing implementation publicly available on GitHub by Patryk Rygiel into the MONAI framework. Signed-off-by: Imad Toubal --- docs/source/losses.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/hausdorff_loss.py | 214 +++++++++++++++++++++++++++++++++ tests/test_hausdorff_loss.py | 168 ++++++++++++++++++++++++++ 4 files changed, 388 insertions(+) create mode 100644 monai/losses/hausdorff_loss.py create mode 100644 tests/test_hausdorff_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 39f1d0e4d1..5d488afbb3 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -73,6 +73,11 @@ Segmentation Losses .. autoclass:: ContrastiveLoss :members: +`HausdorffDTLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: HausdorffDTLoss + :members: + Registration Losses ------------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 75f4d181d0..d734a9d44d 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -33,6 +33,7 @@ from .ds_loss import DeepSupervisionLoss from .focal_loss import FocalLoss from .giou_loss import BoxGIoULoss, giou +from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .multi_scale import MultiScaleLoss from .perceptual import PerceptualLoss diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py new file mode 100644 index 0000000000..bfdb809b00 --- /dev/null +++ b/monai/losses/hausdorff_loss.py @@ -0,0 +1,214 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Hausdorff loss implementation based on paper: +# https://arxiv.org/pdf/1904.10030.pdf + +# Repo: https://github.com/PatRyg99/HausdorffLoss + + +from __future__ import annotations + +import warnings +from typing import Callable + +import numpy as np +import torch +from torch.nn.modules.loss import _Loss + +from monai.metrics.utils import distance_transform_edt +from monai.networks import one_hot +from monai.utils import LossReduction + + +class HausdorffDTLoss(_Loss): + """ + Compute channel-wise binary Hausdorff loss based on distance transform. It can support both multi-classes and + multi-labels tasks. The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` + (BNHW[D]). + + Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input, + must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target` + can be 1 or N (one-hot format). + + The original paper: Karimi, D. et. al. (2019) Reducing the Hausdorff Distance in Medical Image Segmentation with + Convolutional Neural Networks, IEEE Transactions on medical imaging, 39(2), 499-513 + """ + + def __init__( + self, + alpha: float = 2.0, + include_background: bool = False, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + reduction: LossReduction | str = LossReduction.MEAN, + batch: bool = False, + ) -> None: + """ + Args: + include_background: if False, channel index 0 (background category) is excluded from the calculation. + if the non-background segmentations are small compared to the total image size they can get overwhelmed + by the signal from the background so excluding it in such cases helps convergence. + to_onehot_y: whether to convert the ``target`` into the one-hot format, + using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False. + sigmoid: if True, apply a sigmoid function to the prediction. + softmax: if True, apply a softmax function to the prediction. + other_act: callable function to execute other activation layers, Defaults to ``None``. for example: + ``other_act = torch.tanh``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + batch: whether to sum the intersection and union areas over the batch dimension before the dividing. + Defaults to False, a loss value is computed independently from each item in the batch + before any `reduction`. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. + + """ + super(HausdorffDTLoss, self).__init__(reduction=LossReduction(reduction).value) + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if int(sigmoid) + int(softmax) > 1: + raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") + + self.alpha = alpha + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act + self.batch = batch + + @torch.no_grad() + def distance_field(self, img: np.ndarray) -> np.ndarray: + """Generate distance transform. + + Args: + img (np.ndarray): input mask as NCHWD or NCHW. + + Returns: + np.ndarray: Distance field. + """ + field = np.zeros_like(img) + + for batch in range(len(img)): + fg_mask = img[batch] > 0.5 + + if fg_mask.any(): + bg_mask = ~fg_mask + + fg_dist = distance_transform_edt(fg_mask) + bg_dist = distance_transform_edt(bg_mask) + + field[batch] = fg_dist + bg_dist + + return field + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNHW[D], where N is the number of classes. + target: the shape should be BNHW[D] or B1HW[D], where N is the number of classes. + + Raises: + ValueError: If the input is not 2D (NCHW) or 3D (NCHWD). + AssertionError: When input and target (after one hot transform if set) + have different shapes. + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + + Example: + >>> import torch + >>> from monai.losses.hausdorff_loss import HausdorffDTLoss, one_hot + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = HausdorffDTLoss(reduction='none') + >>> loss = self(input, target) + >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape + """ + if input.dim() != 4 and input.dim() != 5: + raise ValueError("Only 2D (NCHW) and 3D (NCHWD) supported") + + if self.sigmoid: + input = torch.sigmoid(input) + + n_pred_ch = input.shape[1] + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + + if self.other_act is not None: + input = self.other_act(input) + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + # If skipping background, removing first channel + target = target[:, 1:] + input = input[:, 1:] + + if target.shape != input.shape: + raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + + device = input.device + all_f = [] + for i in range(input.shape[1]): + ch_input = input[:, [i]] + ch_target = target[:, [i]] + pred_dt = torch.from_numpy(self.distance_field(ch_input.detach().cpu().numpy())).float() + target_dt = torch.from_numpy(self.distance_field(ch_target.detach().cpu().numpy())).float() + + pred_error = (ch_input - ch_target) ** 2 + distance = pred_dt**self.alpha + target_dt**self.alpha + + f = pred_error * distance.to(device) + reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() + if self.batch: + # reducing spatial dimensions and batch + reduce_axis = [0] + reduce_axis + all_f.append(f.mean(dim=reduce_axis, keepdim=True)) + f = torch.cat(all_f, dim=1) + if self.reduction == LossReduction.MEAN.value: + f = torch.mean(f) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + f = torch.sum(f) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least make sure a none reduction maintains a + # broadcastable shape + broadcast_shape = list(f.shape[0:2]) + [1] * (len(ch_input.shape) - 2) + f = f.view(broadcast_shape) + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return f + + +class LogHausdorffDTLoss(HausdorffDTLoss): + def forward(self, *args, **kwargs): + return torch.log(super().forward(*args, **kwargs) + 1) diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py new file mode 100644 index 0000000000..f77a6acdec --- /dev/null +++ b/tests/test_hausdorff_loss.py @@ -0,0 +1,168 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import HausdorffDTLoss + +TEST_CASES = [ + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.509329, + ], + [ # shape: (1, 2, 2, 2), (1, 2, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]]), + }, + 0.326994, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.758470, + ], + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": False, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]]), + }, + 0.144659, + ], + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "none"}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + }, + [[[[0.407765]], [[0.407765]]], [[[0.5000]], [[0.5000]]]], + ], + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "softmax": True}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + }, + 0.357016, + ], + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + }, + 1.428062, + ], + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.509329, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 3.450064, + ], + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + }, + 4.366613, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 2.661359, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 2.661359, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": False}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 2.661359, + ], +] + + +class TestHausdorffDTLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = HausdorffDTLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = HausdorffDTLoss() + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 1, 2, 3)), torch.ones((1, 4, 5, 6))) + + def test_ill_opts(self): + with self.assertRaisesRegex(ValueError, ""): + HausdorffDTLoss(sigmoid=True, softmax=True) + chn_input = torch.ones((1, 1, 3)) + chn_target = torch.ones((1, 1, 3)) + with self.assertRaisesRegex(ValueError, ""): + HausdorffDTLoss(reduction="unknown")(chn_input, chn_target) + with self.assertRaisesRegex(ValueError, ""): + HausdorffDTLoss(reduction=None)(chn_input, chn_target) + + def test_input_warnings(self): + chn_input = torch.ones((1, 1, 1, 3)) + chn_target = torch.ones((1, 1, 1, 3)) + with self.assertWarns(Warning): + loss = HausdorffDTLoss(include_background=False) + loss.forward(chn_input, chn_target) + with self.assertWarns(Warning): + loss = HausdorffDTLoss(softmax=True) + loss.forward(chn_input, chn_target) + with self.assertWarns(Warning): + loss = HausdorffDTLoss(to_onehot_y=True) + loss.forward(chn_input, chn_target) + + +if __name__ == "__main__": + unittest.main() From 8c866de1a15b625c8df444df50a71d9c5f08c3f4 Mon Sep 17 00:00:00 2001 From: Imad Toubal Date: Mon, 18 Sep 2023 15:01:20 -0500 Subject: [PATCH 2/8] Add tests on multiple devices. Signed-off-by: Imad Toubal --- tests/test_hausdorff_loss.py | 114 +++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py index f77a6acdec..51228a5c27 100644 --- a/tests/test_hausdorff_loss.py +++ b/tests/test_hausdorff_loss.py @@ -19,114 +19,122 @@ from monai.losses import HausdorffDTLoss -TEST_CASES = [ - [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) +TEST_CASES = [] +for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: + TEST_CASES.append([ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True}, - {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device) + }, 0.509329, - ], - [ # shape: (1, 2, 2, 2), (1, 2, 2, 2) + ]) + TEST_CASES.append([ # shape: (1, 2, 2, 2), (1, 2, 2, 2) {"include_background": True, "sigmoid": True}, { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]]), + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device), }, 0.326994, - ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + ]) + TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "sigmoid": True}, { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, 0.758470, - ], - [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + ]) + TEST_CASES.append([ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": False, "sigmoid": True}, { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]]), + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device), }, 0.144659, - ], - [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + ]) + TEST_CASES.append([ # shape: (2, 2, 3, 1), (2, 1, 3, 1) {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "none"}, { "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device ), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), }, [[[[0.407765]], [[0.407765]]], [[[0.5000]], [[0.5000]]]], - ], - [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + ]) + TEST_CASES.append([ # shape: (2, 2, 3, 1), (2, 1, 3, 1) {"include_background": True, "to_onehot_y": True, "softmax": True}, { "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device ), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), }, 0.357016, - ], - [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + ]) + TEST_CASES.append([ # shape: (2, 2, 3, 1), (2, 1, 3, 1) {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, { "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device ), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), }, 1.428062, - ], - [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + ]) + TEST_CASES.append([ # shape: (1, 1, 2, 2), (1, 1, 2, 2) {"include_background": True, "sigmoid": True}, - {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device)}, 0.509329, - ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + ]) + TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh}, { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, 3.450064, - ], - [ # shape: (2, 2, 3), (2, 1, 3) + ]) + TEST_CASES.append([ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, { "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]] - ), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]]), + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), }, 4.366613, - ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + ]) + TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh, "batch": True}, { - "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, 2.661359, - ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + ]) + TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh, "batch": True}, { - "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, 2.661359, - ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + ]) + TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh, "batch": False}, { - "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, 2.661359, - ], -] + ]) class TestHausdorffDTLoss(unittest.TestCase): From ec38df79728967f5e74e6b78e393808abc3cb3c9 Mon Sep 17 00:00:00 2001 From: Imad Toubal Date: Mon, 18 Sep 2023 15:03:55 -0500 Subject: [PATCH 3/8] Add a `skipUnless` to Hausdorff tests Add a `skipUnless` decorator to skip the monai.utils.module.OptionalImportError: `from scipy.ndimage.morphology import distance_transform_edt` Signed-off-by: Imad Toubal --- tests/test_hausdorff_loss.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py index 51228a5c27..6f3f98e3ec 100644 --- a/tests/test_hausdorff_loss.py +++ b/tests/test_hausdorff_loss.py @@ -12,12 +12,16 @@ from __future__ import annotations import unittest +from unittest.case import skipUnless import numpy as np import torch from parameterized import parameterized from monai.losses import HausdorffDTLoss +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") TEST_CASES = [] for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: @@ -137,6 +141,7 @@ ]) +@skipUnless(has_scipy, "Scipy required") class TestHausdorffDTLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): From 31637978cc817b8c085c7754fcf28f91ae3a515c Mon Sep 17 00:00:00 2001 From: Imad Toubal Date: Mon, 18 Sep 2023 15:30:03 -0500 Subject: [PATCH 4/8] Add 3D test cases for Hausdorff loss Signed-off-by: Imad Toubal --- tests/test_hausdorff_loss.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py index 6f3f98e3ec..36d8056f79 100644 --- a/tests/test_hausdorff_loss.py +++ b/tests/test_hausdorff_loss.py @@ -33,6 +33,22 @@ }, 0.509329, ]) + TEST_CASES.append([ # shape: (1, 1, 1, 2, 2), (1, 1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]], device=device), + "target": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]], device=device) + }, + 0.509329, + ]) + TEST_CASES.append([ # shape: (1, 1, 2, 2, 2), (1, 1, 2, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[[1.0, -1.0], [1.0, -1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]], device=device), + "target": torch.tensor([[[[[1.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]]], device=device) + }, + 0.375718, + ]) TEST_CASES.append([ # shape: (1, 2, 2, 2), (1, 2, 2, 2) {"include_background": True, "sigmoid": True}, { From 211b04f27c82cb62252292273c0011929dec1f66 Mon Sep 17 00:00:00 2001 From: Imad Toubal Date: Mon, 18 Sep 2023 15:51:29 -0500 Subject: [PATCH 5/8] Autofix linting issues Signed-off-by: Imad Toubal --- tests/test_hausdorff_loss.py | 292 +++++++++++++++++++---------------- 1 file changed, 162 insertions(+), 130 deletions(-) diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py index 36d8056f79..aa4c77975e 100644 --- a/tests/test_hausdorff_loss.py +++ b/tests/test_hausdorff_loss.py @@ -25,136 +25,168 @@ TEST_CASES = [] for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: - TEST_CASES.append([ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device) - }, - 0.509329, - ]) - TEST_CASES.append([ # shape: (1, 1, 1, 2, 2), (1, 1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, - { - "input": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]], device=device), - "target": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]], device=device) - }, - 0.509329, - ]) - TEST_CASES.append([ # shape: (1, 1, 2, 2, 2), (1, 1, 2, 2, 2) - {"include_background": True, "sigmoid": True}, - { - "input": torch.tensor([[[[[1.0, -1.0], [1.0, -1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]], device=device), - "target": torch.tensor([[[[[1.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]]], device=device) - }, - 0.375718, - ]) - TEST_CASES.append([ # shape: (1, 2, 2, 2), (1, 2, 2, 2) - {"include_background": True, "sigmoid": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device), - }, - 0.326994, - ]) - TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "sigmoid": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), - }, - 0.758470, - ]) - TEST_CASES.append([ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": False, "sigmoid": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device), - }, - 0.144659, - ]) - TEST_CASES.append([ # shape: (2, 2, 3, 1), (2, 1, 3, 1) - {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "none"}, - { - "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], - device=device - ), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), - }, - [[[[0.407765]], [[0.407765]]], [[[0.5000]], [[0.5000]]]], - ]) - TEST_CASES.append([ # shape: (2, 2, 3, 1), (2, 1, 3, 1) - {"include_background": True, "to_onehot_y": True, "softmax": True}, - { - "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], - device=device - ), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), - }, - 0.357016, - ]) - TEST_CASES.append([ # shape: (2, 2, 3, 1), (2, 1, 3, 1) - {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, - { - "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], - device=device - ), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), - }, - 1.428062, - ]) - TEST_CASES.append([ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device)}, - 0.509329, - ]) - TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "other_act": torch.tanh}, - { - "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), - }, - 3.450064, - ]) - TEST_CASES.append([ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, - { - "input": torch.tensor( - [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], - device=device), - "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), - }, - 4.366613, - ]) - TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "other_act": torch.tanh, "batch": True}, - { - "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), - }, - 2.661359, - ]) - TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "other_act": torch.tanh, "batch": True}, - { - "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), - }, - 2.661359, - ]) - TEST_CASES.append([ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "other_act": torch.tanh, "batch": False}, - { - "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), - "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), - }, - 2.661359, - ]) + TEST_CASES.append( + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device), + }, + 0.509329, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 1, 2, 2), (1, 1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]], device=device), + "target": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]], device=device), + }, + 0.509329, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 2, 2, 2), (1, 1, 2, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[[1.0, -1.0], [1.0, -1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]], device=device), + "target": torch.tensor([[[[[1.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]]], device=device), + }, + 0.375718, + ] + ) + TEST_CASES.append( + [ # shape: (1, 2, 2, 2), (1, 2, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device), + }, + 0.326994, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 0.758470, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": False, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device), + }, + 0.144659, + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "none"}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + [[[[0.407765]], [[0.407765]]], [[[0.5000]], [[0.5000]]]], + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "softmax": True}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + 0.357016, + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + 1.428062, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device), + }, + 0.509329, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 3.450064, + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + 4.366613, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 2.661359, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 2.661359, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": False}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 2.661359, + ] + ) @skipUnless(has_scipy, "Scipy required") From 4fea7aaf96e5a12bd8572cdee0772bf544a86e6b Mon Sep 17 00:00:00 2001 From: Imad Toubal Date: Mon, 18 Sep 2023 17:03:06 -0500 Subject: [PATCH 6/8] Add log hausdorff loss tests and typing Signed-off-by: Imad Toubal --- monai/losses/hausdorff_loss.py | 2 +- tests/test_hausdorff_loss.py | 41 +++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py index bfdb809b00..7384823997 100644 --- a/monai/losses/hausdorff_loss.py +++ b/monai/losses/hausdorff_loss.py @@ -210,5 +210,5 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: class LogHausdorffDTLoss(HausdorffDTLoss): - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> torch.Tensor: return torch.log(super().forward(*args, **kwargs) + 1) diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py index aa4c77975e..a4d1fb04cc 100644 --- a/tests/test_hausdorff_loss.py +++ b/tests/test_hausdorff_loss.py @@ -18,7 +18,7 @@ import torch from parameterized import parameterized -from monai.losses import HausdorffDTLoss +from monai.losses import HausdorffDTLoss, LogHausdorffDTLoss from monai.utils import optional_import _, has_scipy = optional_import("scipy") @@ -189,6 +189,9 @@ ) +TEST_CASES_LOG = [[*inputs, np.log(np.array(output) + 1)] for *inputs, output in TEST_CASES] + + @skipUnless(has_scipy, "Scipy required") class TestHausdorffDTLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -225,5 +228,41 @@ def test_input_warnings(self): loss.forward(chn_input, chn_target) +@skipUnless(has_scipy, "Scipy required") +class TesLogtHausdorffDTLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES_LOG) + def test_shape(self, input_param, input_data, expected_val): + result = LogHausdorffDTLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = LogHausdorffDTLoss() + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 1, 2, 3)), torch.ones((1, 4, 5, 6))) + + def test_ill_opts(self): + with self.assertRaisesRegex(ValueError, ""): + LogHausdorffDTLoss(sigmoid=True, softmax=True) + chn_input = torch.ones((1, 1, 3)) + chn_target = torch.ones((1, 1, 3)) + with self.assertRaisesRegex(ValueError, ""): + LogHausdorffDTLoss(reduction="unknown")(chn_input, chn_target) + with self.assertRaisesRegex(ValueError, ""): + LogHausdorffDTLoss(reduction=None)(chn_input, chn_target) + + def test_input_warnings(self): + chn_input = torch.ones((1, 1, 1, 3)) + chn_target = torch.ones((1, 1, 1, 3)) + with self.assertWarns(Warning): + loss = LogHausdorffDTLoss(include_background=False) + loss.forward(chn_input, chn_target) + with self.assertWarns(Warning): + loss = LogHausdorffDTLoss(softmax=True) + loss.forward(chn_input, chn_target) + with self.assertWarns(Warning): + loss = LogHausdorffDTLoss(to_onehot_y=True) + loss.forward(chn_input, chn_target) + + if __name__ == "__main__": unittest.main() From 8b44bfe1181bf6b6c5fc5f9d85faac60c2c1222f Mon Sep 17 00:00:00 2001 From: Imad Toubal Date: Mon, 18 Sep 2023 17:52:22 -0500 Subject: [PATCH 7/8] Add docstring and correct type annotations Signed-off-by: Imad Toubal --- monai/losses/hausdorff_loss.py | 35 ++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py index 7384823997..b040568765 100644 --- a/monai/losses/hausdorff_loss.py +++ b/monai/losses/hausdorff_loss.py @@ -187,12 +187,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: pred_error = (ch_input - ch_target) ** 2 distance = pred_dt**self.alpha + target_dt**self.alpha - f = pred_error * distance.to(device) + running_f = pred_error * distance.to(device) reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - all_f.append(f.mean(dim=reduce_axis, keepdim=True)) + all_f.append(running_f.mean(dim=reduce_axis, keepdim=True)) f = torch.cat(all_f, dim=1) if self.reduction == LossReduction.MEAN.value: f = torch.mean(f) # the batch and channel average @@ -210,5 +210,32 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: class LogHausdorffDTLoss(HausdorffDTLoss): - def forward(self, *args, **kwargs) -> torch.Tensor: - return torch.log(super().forward(*args, **kwargs) + 1) + """ + Compute the logarithm of the Hausdorff Distance Transform Loss. + + This class computes the logarithm of the Hausdorff Distance Transform Loss, which is based on the distance transform. + The logarithm is computed to potentially stabilize and scale the loss values, especially when the original loss + values are very small. + + The formula for the loss is given by: + log_loss = log(HausdorffDTLoss + 1) + + Inherits from the HausdorffDTLoss class to utilize its distance transform computation. + """ + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute the logarithm of the Hausdorff Distance Transform Loss. + + Args: + input (torch.Tensor): The shape should be BNHW[D], where N is the number of classes. + target (torch.Tensor): The shape should be BNHW[D] or B1HW[D], where N is the number of classes. + + Returns: + torch.Tensor: The computed Log Hausdorff Distance Transform Loss for the given input and target. + + Raises: + Any exceptions raised by the parent class HausdorffDTLoss. + """ + log_loss: torch.Tensor = torch.log(super().forward(input, target) + 1) + return log_loss From 3c8abf6fa609c1e82c9c298b44fbca7bc593a8ca Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Tue, 19 Sep 2023 07:50:02 +0100 Subject: [PATCH 8/8] Update monai/losses/hausdorff_loss.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Wenqi Li <831580+wyli@users.noreply.github.com> --- monai/losses/hausdorff_loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py index b040568765..eeba96933c 100644 --- a/monai/losses/hausdorff_loss.py +++ b/monai/losses/hausdorff_loss.py @@ -134,7 +134,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Example: >>> import torch - >>> from monai.losses.hausdorff_loss import HausdorffDTLoss, one_hot + >>> from monai.losses.hausdorff_loss import HausdorffDTLoss + >>> from monai.networks.utils import one_hot >>> B, C, H, W = 7, 5, 3, 2 >>> input = torch.rand(B, C, H, W) >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()