diff --git a/docs/source/losses.rst b/docs/source/losses.rst index a6aa4d566d..5e19219fee 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -74,4 +74,12 @@ Registration Losses `GlobalMutualInformationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: GlobalMutualInformationLoss + :members: + +Loss Wrappers +-------------- + +`MultiScaleLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: MultiScaleLoss :members: \ No newline at end of file diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 591fb08f7b..b9146a6962 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -23,4 +23,5 @@ ) from .focal_loss import FocalLoss from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss +from .multi_scale import MultiScaleLoss from .tversky import TverskyLoss diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py new file mode 100644 index 0000000000..5a17bc2d07 --- /dev/null +++ b/monai/losses/multi_scale.py @@ -0,0 +1,100 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +from torch.nn.modules.loss import _Loss + +from monai.networks.layers import gaussian_1d, separable_filtering +from monai.utils import LossReduction + + +def make_gaussian_kernel(sigma: int) -> torch.Tensor: + if sigma <= 0: + raise ValueError(f"expecting positive sigma, got sigma={sigma}") + kernel = gaussian_1d(sigma=torch.tensor(sigma), truncated=3, approx="sampled", normalize=False) + return kernel + + +def make_cauchy_kernel(sigma: int) -> torch.Tensor: + if sigma <= 0: + raise ValueError(f"expecting positive sigma, got sigma={sigma}") + tail = int(sigma * 5) + k = torch.tensor([((x / sigma) ** 2 + 1) for x in range(-tail, tail + 1)]) + k = torch.reciprocal(k) + k = k / torch.sum(k) + return k + + +kernel_fn_dict = { + "gaussian": make_gaussian_kernel, + "cauchy": make_cauchy_kernel, +} + + +class MultiScaleLoss(_Loss): + """ + This is a wrapper class. + It smooths the input and target at different scales before passing them into the wrapped loss function. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + + def __init__( + self, + loss: _Loss, + scales: Optional[List] = None, + kernel: str = "gaussian", + reduction: Union[LossReduction, str] = LossReduction.MEAN, + ) -> None: + """ + Args: + loss: loss function to be wrapped + scales: list of scalars or None, if None, do not apply any scaling. + kernel: gaussian or cauchy. + """ + super(MultiScaleLoss, self).__init__(reduction=LossReduction(reduction).value) + if kernel not in kernel_fn_dict.keys(): + raise ValueError(f"got unsupported kernel type: {kernel}", "only support gaussian and cauchy") + self.kernel_fn = kernel_fn_dict[kernel] + self.loss = loss + self.scales = scales + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + if self.scales is None: + loss: torch.Tensor = self.loss(y_pred, y_true) + else: + loss_list = [] + for s in self.scales: + if s == 0: + # no smoothing + loss_list.append(self.loss(y_pred, y_true)) + else: + loss_list.append( + self.loss( + separable_filtering(y_pred, [self.kernel_fn(s)] * (y_true.ndim - 2)), + separable_filtering(y_true, [self.kernel_fn(s)] * (y_true.ndim - 2)), + ) + ) + loss = torch.stack(loss_list, dim=0) + + if self.reduction == LossReduction.MEAN.value: + loss = torch.mean(loss) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + loss = torch.sum(loss) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + pass # returns [N, n_classes] losses + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return loss diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py new file mode 100644 index 0000000000..722ae7cfce --- /dev/null +++ b/tests/test_multi_scale.py @@ -0,0 +1,60 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import DiceLoss +from monai.losses.multi_scale import MultiScaleLoss + +dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) + +TEST_CASES = [ + [ + {"loss": dice_loss, "scales": None, "kernel": "gaussian"}, + {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.307576, + ], + [ + {"loss": dice_loss, "scales": [0, 1], "kernel": "gaussian"}, + {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + 0.463116, + ], + [ + {"loss": dice_loss, "scales": [0, 1, 2], "kernel": "cauchy"}, + { + "y_pred": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]]), + "y_true": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]]), + }, + 0.715228, + ], +] + + +class TestMultiScale(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = MultiScaleLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_opts(self): + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(loss=dice_loss, kernel="none") + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(loss=dice_loss, scales=[-1])(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + + +if __name__ == "__main__": + unittest.main()