diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 39f1d0e4d1..5d488afbb3 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -73,6 +73,11 @@ Segmentation Losses
 .. autoclass:: ContrastiveLoss
     :members:
 
+`HausdorffDTLoss`
+~~~~~~~~~~~~~~~~~
+.. autoclass:: HausdorffDTLoss
+    :members:
+
 Registration Losses
 -------------------
 
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index 75f4d181d0..d734a9d44d 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -33,6 +33,7 @@
 from .ds_loss import DeepSupervisionLoss
 from .focal_loss import FocalLoss
 from .giou_loss import BoxGIoULoss, giou
+from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
 from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
 from .multi_scale import MultiScaleLoss
 from .perceptual import PerceptualLoss
diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py
new file mode 100644
index 0000000000..eeba96933c
--- /dev/null
+++ b/monai/losses/hausdorff_loss.py
@@ -0,0 +1,242 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Hausdorff loss implementation based on paper:
+# https://arxiv.org/pdf/1904.10030.pdf
+
+# Repo: https://github.com/PatRyg99/HausdorffLoss
+
+
+from __future__ import annotations
+
+import warnings
+from typing import Callable
+
+import numpy as np
+import torch
+from torch.nn.modules.loss import _Loss
+
+from monai.metrics.utils import distance_transform_edt
+from monai.networks import one_hot
+from monai.utils import LossReduction
+
+
+class HausdorffDTLoss(_Loss):
+    """
+    Compute channel-wise binary Hausdorff loss based on distance transform. It supports both multi-class and
+    multi-label tasks. The data `input` (BNHW[D] where N is number of classes) is compared with ground truth
+    `target` (BNHW[D]).
+
+    Note that axis N of `input` is expected to be logits or probabilities for each class; if passing logits as input,
+    set `sigmoid=True` or `softmax=True`, or specify `other_act`. The same axis of `target`
+    can be 1 or N (one-hot format).
+
+    The original paper: Karimi, D. et al. (2019) Reducing the Hausdorff Distance in Medical Image Segmentation with
+    Convolutional Neural Networks, IEEE Transactions on Medical Imaging, 39(2), 499-513.
+    """
+
+    def __init__(
+        self,
+        alpha: float = 2.0,
+        include_background: bool = False,
+        to_onehot_y: bool = False,
+        sigmoid: bool = False,
+        softmax: bool = False,
+        other_act: Callable | None = None,
+        reduction: LossReduction | str = LossReduction.MEAN,
+        batch: bool = False,
+    ) -> None:
+        """
+        Args:
+            alpha: exponent applied to the distance transform values; larger values weight errors that occur
+                far from the object boundaries more heavily. Defaults to 2.0.
+            include_background: if False, channel index 0 (background category) is excluded from the calculation.
+                If the non-background segmentations are small compared to the total image size, they can be
+                overwhelmed by the signal from the background, so excluding it in such cases helps convergence.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            other_act: callable function to execute other activation layers, for example:
+                ``other_act = torch.tanh``. Defaults to ``None``.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+            batch: whether to reduce the loss over the batch dimension together with the spatial dimensions.
+                Defaults to False, a loss value is computed independently for each item in the batch
+                before any `reduction`.
+
+        Raises:
+            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
+            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
+                Incompatible values.
+
+        """
+        super(HausdorffDTLoss, self).__init__(reduction=LossReduction(reduction).value)
+        if other_act is not None and not callable(other_act):
+            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
+        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
+            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
+
+        self.alpha = alpha
+        self.include_background = include_background
+        self.to_onehot_y = to_onehot_y
+        self.sigmoid = sigmoid
+        self.softmax = softmax
+        self.other_act = other_act
+        self.batch = batch
+
+    @torch.no_grad()
+    def distance_field(self, img: np.ndarray) -> np.ndarray:
+        """Generate distance transform.
+
+        Args:
+            img (np.ndarray): input mask as NCHWD or NCHW.
+
+        Returns:
+            np.ndarray: Distance field.
+        """
+        field = np.zeros_like(img)
+
+        for batch in range(len(img)):
+            fg_mask = img[batch] > 0.5
+
+            if fg_mask.any():
+                bg_mask = ~fg_mask
+
+                fg_dist = distance_transform_edt(fg_mask)
+                bg_dist = distance_transform_edt(bg_mask)
+
+                field[batch] = fg_dist + bg_dist
+
+        return field
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            input: the shape should be BNHW[D], where N is the number of classes.
+            target: the shape should be BNHW[D] or B1HW[D], where N is the number of classes.
+
+        Raises:
+            ValueError: If the input is not 2D (NCHW) or 3D (NCHWD).
+            AssertionError: When input and target (after one hot transform if set)
+                have different shapes.
+            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
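+
+        Note:
+            The distance transforms are computed on detached CPU copies of the prediction and target with
+            gradients disabled, so gradients propagate only through the squared prediction error term.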
+
+        Example:
+            >>> import numpy as np
+            >>> import torch
+            >>> from monai.losses.hausdorff_loss import HausdorffDTLoss
+            >>> from monai.networks.utils import one_hot
+            >>> B, C, H, W = 7, 5, 3, 2
+            >>> input = torch.rand(B, C, H, W)
+            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
+            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
+            >>> self = HausdorffDTLoss(reduction='none')
+            >>> loss = self(input, target)
+            >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
+        """
+        if input.dim() != 4 and input.dim() != 5:
+            raise ValueError("Only 2D (NCHW) and 3D (NCHWD) supported")
+
+        if self.sigmoid:
+            input = torch.sigmoid(input)
+
+        n_pred_ch = input.shape[1]
+        if self.softmax:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `softmax=True` ignored.")
+            else:
+                input = torch.softmax(input, 1)
+
+        if self.other_act is not None:
+            input = self.other_act(input)
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                target = one_hot(target, num_classes=n_pred_ch)
+
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, remove the first channel
+                target = target[:, 1:]
+                input = input[:, 1:]
+
+        if target.shape != input.shape:
+            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
+
+        device = input.device
+        all_f = []
+        for i in range(input.shape[1]):
+            ch_input = input[:, [i]]
+            ch_target = target[:, [i]]
+            pred_dt = torch.from_numpy(self.distance_field(ch_input.detach().cpu().numpy())).float()
+            target_dt = torch.from_numpy(self.distance_field(ch_target.detach().cpu().numpy())).float()
+
+            pred_error = (ch_input - ch_target) ** 2
+            distance = pred_dt**self.alpha + target_dt**self.alpha
+
+            running_f = pred_error * distance.to(device)
+            reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
+            if self.batch:
+                # reducing spatial dimensions and batch
+                reduce_axis = [0] + reduce_axis
+            all_f.append(running_f.mean(dim=reduce_axis, keepdim=True))
+        f = torch.cat(all_f, dim=1)
+        if self.reduction == LossReduction.MEAN.value:
+            f = torch.mean(f)  # the batch and channel average
+        elif self.reduction == LossReduction.SUM.value:
+            f = torch.sum(f)  # sum over the batch and channel dims
+        elif self.reduction == LossReduction.NONE.value:
+            # If we are not computing voxelwise loss components at least make sure a none reduction maintains a
+            # broadcastable shape
+            broadcast_shape = list(f.shape[0:2]) + [1] * (len(ch_input.shape) - 2)
+            f = f.view(broadcast_shape)
+        else:
+            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+
+        return f
+
+
+class LogHausdorffDTLoss(HausdorffDTLoss):
+    """
+    Compute the logarithm of the Hausdorff Distance Transform Loss.
+
+    This class computes the logarithm of the Hausdorff Distance Transform Loss, which is based on the distance
+    transform. The logarithm is computed to potentially stabilize and scale the loss values, especially when the
+    original loss values are very large.
+
+    The formula for the loss is given by:
+        log_loss = log(HausdorffDTLoss + 1)
+
+    Inherits from the HausdorffDTLoss class to utilize its distance transform computation.
+    """
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the logarithm of the Hausdorff Distance Transform Loss.
+ + Args: + input (torch.Tensor): The shape should be BNHW[D], where N is the number of classes. + target (torch.Tensor): The shape should be BNHW[D] or B1HW[D], where N is the number of classes. + + Returns: + torch.Tensor: The computed Log Hausdorff Distance Transform Loss for the given input and target. + + Raises: + Any exceptions raised by the parent class HausdorffDTLoss. + """ + log_loss: torch.Tensor = torch.log(super().forward(input, target) + 1) + return log_loss diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py new file mode 100644 index 0000000000..a4d1fb04cc --- /dev/null +++ b/tests/test_hausdorff_loss.py @@ -0,0 +1,268 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import HausdorffDTLoss, LogHausdorffDTLoss +from monai.utils import optional_import + +_, has_scipy = optional_import("scipy") + +TEST_CASES = [] +for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: + TEST_CASES.append( + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device), + }, + 0.509329, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 1, 2, 2), (1, 1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]], device=device), + "target": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]], device=device), + }, + 0.509329, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 2, 2, 2), (1, 1, 2, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[[1.0, -1.0], [1.0, -1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]], device=device), + "target": torch.tensor([[[[[1.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]]], device=device), + }, + 0.375718, + ] + ) + TEST_CASES.append( + [ # shape: (1, 2, 2, 2), (1, 2, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device), + }, + 0.326994, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 0.758470, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": False, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], 
device=device), + }, + 0.144659, + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "none"}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + [[[[0.407765]], [[0.407765]]], [[[0.5000]], [[0.5000]]]], + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "softmax": True}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + 0.357016, + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3, 1), (2, 1, 3, 1) + {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + 1.428062, + ] + ) + TEST_CASES.append( + [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) + {"include_background": True, "sigmoid": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device), + }, + 0.509329, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 3.450064, + ] + ) + TEST_CASES.append( + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, + { + "input": torch.tensor( + [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]], + device=device, + ), + "target": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device), + }, + 4.366613, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 2.661359, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 2.661359, + ] + ) + TEST_CASES.append( + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "batch": False}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), + }, + 2.661359, + ] + ) + + +TEST_CASES_LOG = [[*inputs, np.log(np.array(output) + 1)] for 
*inputs, output in TEST_CASES]
+
+
+@skipUnless(has_scipy, "Scipy required")
+class TestHausdorffDTLoss(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_shape(self, input_param, input_data, expected_val):
+        result = HausdorffDTLoss(**input_param).forward(**input_data)
+        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
+
+    def test_ill_shape(self):
+        loss = HausdorffDTLoss()
+        with self.assertRaisesRegex(AssertionError, ""):
+            loss.forward(torch.ones((1, 1, 2, 3)), torch.ones((1, 4, 5, 6)))
+
+    def test_ill_opts(self):
+        with self.assertRaisesRegex(ValueError, ""):
+            HausdorffDTLoss(sigmoid=True, softmax=True)
+        chn_input = torch.ones((1, 1, 3))
+        chn_target = torch.ones((1, 1, 3))
+        with self.assertRaisesRegex(ValueError, ""):
+            HausdorffDTLoss(reduction="unknown")(chn_input, chn_target)
+        with self.assertRaisesRegex(ValueError, ""):
+            HausdorffDTLoss(reduction=None)(chn_input, chn_target)
+
+    def test_input_warnings(self):
+        chn_input = torch.ones((1, 1, 1, 3))
+        chn_target = torch.ones((1, 1, 1, 3))
+        with self.assertWarns(Warning):
+            loss = HausdorffDTLoss(include_background=False)
+            loss.forward(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            loss = HausdorffDTLoss(softmax=True)
+            loss.forward(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            loss = HausdorffDTLoss(to_onehot_y=True)
+            loss.forward(chn_input, chn_target)
+
+
+@skipUnless(has_scipy, "Scipy required")
+class TestLogHausdorffDTLoss(unittest.TestCase):
+    @parameterized.expand(TEST_CASES_LOG)
+    def test_shape(self, input_param, input_data, expected_val):
+        result = LogHausdorffDTLoss(**input_param).forward(**input_data)
+        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
+
+    def test_ill_shape(self):
+        loss = LogHausdorffDTLoss()
+        with self.assertRaisesRegex(AssertionError, ""):
+            loss.forward(torch.ones((1, 1, 2, 3)), torch.ones((1, 4, 5, 6)))
+
+    def test_ill_opts(self):
+        with self.assertRaisesRegex(ValueError, ""):
+            LogHausdorffDTLoss(sigmoid=True, softmax=True)
+        chn_input = torch.ones((1, 1, 3))
+        chn_target = torch.ones((1, 1, 3))
+        with self.assertRaisesRegex(ValueError, ""):
+            LogHausdorffDTLoss(reduction="unknown")(chn_input, chn_target)
+        with self.assertRaisesRegex(ValueError, ""):
+            LogHausdorffDTLoss(reduction=None)(chn_input, chn_target)
+
+    def test_input_warnings(self):
+        chn_input = torch.ones((1, 1, 1, 3))
+        chn_target = torch.ones((1, 1, 1, 3))
+        with self.assertWarns(Warning):
+            loss = LogHausdorffDTLoss(include_background=False)
+            loss.forward(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            loss = LogHausdorffDTLoss(softmax=True)
+            loss.forward(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            loss = LogHausdorffDTLoss(to_onehot_y=True)
+            loss.forward(chn_input, chn_target)
+
+
+if __name__ == "__main__":
+    unittest.main()
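For reviewers, a minimal usage sketch (not part of the diff): it pairs the new `HausdorffDTLoss` with MONAI's existing `DiceLoss`, in the spirit of the Dice-plus-Hausdorff training discussed in the Karimi et al. paper. The toy shapes, the absence of a network, and the 0.5 weight are illustrative assumptions, not values taken from this PR.

```python
import torch
from monai.losses import DiceLoss, HausdorffDTLoss

# toy 2-class example: logits of shape (B, C, H, W) and a one-hot target of the same shape
B, C, H, W = 2, 2, 32, 32
logits = torch.randn(B, C, H, W, requires_grad=True)
labels = torch.randint(0, C, (B, H, W))
target = torch.nn.functional.one_hot(labels, num_classes=C).permute(0, 3, 1, 2).float()

dice = DiceLoss(softmax=True, include_background=False)
hausdorff = HausdorffDTLoss(softmax=True, include_background=False)

# weighted sum of the two terms; the 0.5 weight is an arbitrary illustrative choice
loss = dice(logits, target) + 0.5 * hausdorff(logits, target)

# gradients flow only through the squared prediction error inside HausdorffDTLoss,
# since its distance transforms are computed with gradients disabled
loss.backward()
```

`LogHausdorffDTLoss` can be swapped in for `HausdorffDTLoss` in the same way when the raw loss values are large enough for the `log(x + 1)` compression to be useful.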