From 7e4f30656042c9e877e623c61ac48c13b6481569 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 3 Feb 2021 23:47:57 +0000 Subject: [PATCH 1/4] 1525 add MultiScaleLoss Signed-off-by: kate-sann5100 --- monai/losses/multi_scale.py | 80 +++++++++++++++++++++++++++++++++++ tests/test_multi_scale.py | 84 +++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 monai/losses/multi_scale.py create mode 100644 tests/test_multi_scale.py diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py new file mode 100644 index 0000000000..77c6a90f2f --- /dev/null +++ b/monai/losses/multi_scale.py @@ -0,0 +1,80 @@ +from typing import Optional, List, Union + +import torch +from torch.nn.modules.loss import _Loss + +from monai.networks.layers import separable_filtering, gaussian_1d + + +def make_gaussian_kernel(sigma: int) -> torch.Tensor: + if sigma <= 0: + raise ValueError(f"expecting postive sigma, got sign={sigma}") + sigma = torch.tensor(sigma) + kernel = gaussian_1d(sigma=sigma, truncated=3, approx="sampled", normalize=False) + return kernel + + +def make_cauchy_kernel(sigma: int) -> torch.Tensor: + """ + Approximating cauchy kernel in 1d. + + :param sigma: int, defining standard deviation of kernel. + :return: shape = (dim, ) + """ + assert sigma > 0 + tail = int(sigma * 5) + k = torch.tensor([((x / sigma) ** 2 + 1) for x in range(-tail, tail + 1)]) + k = torch.reciprocal(k) + k = k / torch.sum(k) + return k + + +kernel_fn_dict = { + "gaussian": make_gaussian_kernel, + "cauchy": make_cauchy_kernel, +} + + +class MultiScaleLoss(_Loss): + def __init__( + self, + loss: _Loss, + scales: Optional[List] = None, + kernel: str = "gaussian", + ): + """ + Args: + scales: list of scalars or None, if None, do not apply any scaling. + kernel: gaussian or cauchy. + reduction: using SUM reduction over batch axis, + calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. + name: str, name of the loss. + """ + super(MultiScaleLoss, self).__init__() + if kernel not in kernel_fn_dict.keys(): + raise ValueError(f"got unsupported kernel type: {kernel}", + "only support gaussian and cauchy") + self.kernel_fn = kernel_fn_dict[kernel] + self.loss = loss + self.scales = scales + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + if self.scales is None: + return self.loss(y_pred, y_true) + losses = [] + for s in self.scales: + if s == 0: + # no smoothing + losses.append( + self.loss(y_pred, y_true) + ) + else: + losses.append( + self.loss( + separable_filtering(y_pred, [self.kernel_fn(s)] * (y_true.ndim - 2)), + separable_filtering(y_true, [self.kernel_fn(s)] * (y_true.ndim - 2)), + ) + ) + loss = torch.mean(torch.stack(losses, dim=0), dim=0) + return loss + diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py new file mode 100644 index 0000000000..df8fd9188b --- /dev/null +++ b/tests/test_multi_scale.py @@ -0,0 +1,84 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import torch +import numpy as np +from parameterized import parameterized + +from monai.losses import DiceLoss +from monai.losses.multi_scale import MultiScaleLoss + +dice_loss = DiceLoss( + include_background=True, + sigmoid=True, + smooth_nr=1e-5, + smooth_dr=1e-5 +) + +TEST_CASES = [ + [ + { + "loss": dice_loss, + "scales": None, + "kernel": "gaussian" + }, + { + "y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), + "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) + }, + 0.307576, + ], + [ + { + "loss": dice_loss, + "scales": [0, 1], + "kernel": "gaussian" + }, + { + "y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), + "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) + }, + 0.463116, + ], + [ + { + "loss": dice_loss, + "scales": [0, 1, 2], + "kernel": "cauchy" + }, + { + "y_pred": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]]), + "y_true": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]]) + }, + 0.715228, + ] + +] + + +class TestMultiScale(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = MultiScaleLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_opts(self): + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(loss=dice_loss, kernel="none") + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(loss=dice_loss, scales=[-1])( + torch.ones((1, 1, 3)), torch.ones((1, 1, 3)) + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 4137d0449cb530a146b76c531fbdcd51888d1bfb Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 4 Feb 2021 00:04:35 +0000 Subject: [PATCH 2/4] 1525 add documentation Signed-off-by: kate-sann5100 --- docs/source/losses.rst | 8 +++++++ monai/losses/__init__.py | 1 + monai/losses/multi_scale.py | 36 +++++++++++++--------------- tests/test_multi_scale.py | 48 +++++++++---------------------------- 4 files changed, 37 insertions(+), 56 deletions(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index a6aa4d566d..5e19219fee 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -74,4 +74,12 @@ Registration Losses `GlobalMutualInformationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: GlobalMutualInformationLoss + :members: + +Loss Wrappers +-------------- + +`MultiScaleLoss` +~~~~~~~~~~~~~~~~~ +.. 
autoclass:: MultiScaleLoss :members: \ No newline at end of file diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 591fb08f7b..b9146a6962 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -23,4 +23,5 @@ ) from .focal_loss import FocalLoss from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss +from .multi_scale import MultiScaleLoss from .tversky import TverskyLoss diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index 77c6a90f2f..4201548c92 100644 --- a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -1,9 +1,9 @@ -from typing import Optional, List, Union +from typing import List, Optional import torch from torch.nn.modules.loss import _Loss -from monai.networks.layers import separable_filtering, gaussian_1d +from monai.networks.layers import gaussian_1d, separable_filtering def make_gaussian_kernel(sigma: int) -> torch.Tensor: @@ -15,13 +15,8 @@ def make_gaussian_kernel(sigma: int) -> torch.Tensor: def make_cauchy_kernel(sigma: int) -> torch.Tensor: - """ - Approximating cauchy kernel in 1d. - - :param sigma: int, defining standard deviation of kernel. - :return: shape = (dim, ) - """ - assert sigma > 0 + if sigma <= 0: + raise ValueError(f"expecting postive sigma, got sign={sigma}") tail = int(sigma * 5) k = torch.tensor([((x / sigma) ** 2 + 1) for x in range(-tail, tail + 1)]) k = torch.reciprocal(k) @@ -36,24 +31,30 @@ def make_cauchy_kernel(sigma: int) -> torch.Tensor: class MultiScaleLoss(_Loss): + """ + This is a wrapper class. + It smooths the input and target at different scales before passing them into the wrapped loss function. + The output is the average loss at all scales. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + def __init__( self, loss: _Loss, scales: Optional[List] = None, kernel: str = "gaussian", - ): + ) -> None: """ Args: + loss: loss function to be wrapped scales: list of scalars or None, if None, do not apply any scaling. kernel: gaussian or cauchy. - reduction: using SUM reduction over batch axis, - calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. - name: str, name of the loss. """ super(MultiScaleLoss, self).__init__() if kernel not in kernel_fn_dict.keys(): - raise ValueError(f"got unsupported kernel type: {kernel}", - "only support gaussian and cauchy") + raise ValueError(f"got unsupported kernel type: {kernel}", "only support gaussian and cauchy") self.kernel_fn = kernel_fn_dict[kernel] self.loss = loss self.scales = scales @@ -65,9 +66,7 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: for s in self.scales: if s == 0: # no smoothing - losses.append( - self.loss(y_pred, y_true) - ) + losses.append(self.loss(y_pred, y_true)) else: losses.append( self.loss( @@ -77,4 +76,3 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: ) loss = torch.mean(torch.stack(losses, dim=0), dim=0) return loss - diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index df8fd9188b..aa91a4380e 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -10,58 +10,34 @@ # limitations under the License. 
import unittest -import torch import numpy as np +import torch from parameterized import parameterized from monai.losses import DiceLoss from monai.losses.multi_scale import MultiScaleLoss -dice_loss = DiceLoss( - include_background=True, - sigmoid=True, - smooth_nr=1e-5, - smooth_dr=1e-5 -) +dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) TEST_CASES = [ [ - { - "loss": dice_loss, - "scales": None, - "kernel": "gaussian" - }, - { - "y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) - }, + {"loss": dice_loss, "scales": None, "kernel": "gaussian"}, + {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ - { - "loss": dice_loss, - "scales": [0, 1], - "kernel": "gaussian" - }, - { - "y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), - "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) - }, + {"loss": dice_loss, "scales": [0, 1], "kernel": "gaussian"}, + {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.463116, ], [ - { - "loss": dice_loss, - "scales": [0, 1, 2], - "kernel": "cauchy" - }, + {"loss": dice_loss, "scales": [0, 1, 2], "kernel": "cauchy"}, { "y_pred": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]]), - "y_true": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]]) + "y_true": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]]), }, 0.715228, - ] - + ], ] @@ -75,10 +51,8 @@ def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(loss=dice_loss, kernel="none") with self.assertRaisesRegex(ValueError, ""): - MultiScaleLoss(loss=dice_loss, scales=[-1])( - torch.ones((1, 1, 3)), torch.ones((1, 1, 3)) - ) + MultiScaleLoss(loss=dice_loss, scales=[-1])(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From bc74c57bcef52524163ad9118604b5cf4194b45d Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 4 Feb 2021 00:16:04 +0000 Subject: [PATCH 3/4] 1525 reformat Signed-off-by: kate-sann5100 --- monai/losses/multi_scale.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index 4201548c92..347faf238a 100644 --- a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -9,8 +9,7 @@ def make_gaussian_kernel(sigma: int) -> torch.Tensor: if sigma <= 0: raise ValueError(f"expecting postive sigma, got sign={sigma}") - sigma = torch.tensor(sigma) - kernel = gaussian_1d(sigma=sigma, truncated=3, approx="sampled", normalize=False) + kernel = gaussian_1d(sigma=torch.tensor(sigma), truncated=3, approx="sampled", normalize=False) return kernel @@ -61,7 +60,8 @@ def __init__( def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: if self.scales is None: - return self.loss(y_pred, y_true) + loss: torch.Tensor = self.loss(y_pred, y_true) + return loss losses = [] for s in self.scales: if s == 0: From 85431948acbc84ac672cfe3f3659832c98ce429f Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 4 Feb 2021 13:57:31 +0000 Subject: [PATCH 4/4] 1525 reformat Signed-off-by: kate-sann5100 --- monai/losses/multi_scale.py | 58 +++++++++++++++++++++++++------------ tests/test_multi_scale.py | 2 ++ 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index 347faf238a..5a17bc2d07 100644 --- 
a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -1,21 +1,33 @@ -from typing import List, Optional +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union import torch from torch.nn.modules.loss import _Loss from monai.networks.layers import gaussian_1d, separable_filtering +from monai.utils import LossReduction def make_gaussian_kernel(sigma: int) -> torch.Tensor: if sigma <= 0: - raise ValueError(f"expecting postive sigma, got sign={sigma}") + raise ValueError(f"expecting positive sigma, got sigma={sigma}") kernel = gaussian_1d(sigma=torch.tensor(sigma), truncated=3, approx="sampled", normalize=False) return kernel def make_cauchy_kernel(sigma: int) -> torch.Tensor: if sigma <= 0: - raise ValueError(f"expecting postive sigma, got sign={sigma}") + raise ValueError(f"expecting positive sigma, got sigma={sigma}") tail = int(sigma * 5) k = torch.tensor([((x / sigma) ** 2 + 1) for x in range(-tail, tail + 1)]) k = torch.reciprocal(k) @@ -33,7 +45,6 @@ class MultiScaleLoss(_Loss): """ This is a wrapper class. It smooths the input and target at different scales before passing them into the wrapped loss function. - The output is the average loss at all scales. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -44,6 +55,7 @@ def __init__( loss: _Loss, scales: Optional[List] = None, kernel: str = "gaussian", + reduction: Union[LossReduction, str] = LossReduction.MEAN, ) -> None: """ Args: @@ -51,7 +63,7 @@ def __init__( scales: list of scalars or None, if None, do not apply any scaling. kernel: gaussian or cauchy. 
""" - super(MultiScaleLoss, self).__init__() + super(MultiScaleLoss, self).__init__(reduction=LossReduction(reduction).value) if kernel not in kernel_fn_dict.keys(): raise ValueError(f"got unsupported kernel type: {kernel}", "only support gaussian and cauchy") self.kernel_fn = kernel_fn_dict[kernel] @@ -61,18 +73,28 @@ def __init__( def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: if self.scales is None: loss: torch.Tensor = self.loss(y_pred, y_true) - return loss - losses = [] - for s in self.scales: - if s == 0: - # no smoothing - losses.append(self.loss(y_pred, y_true)) - else: - losses.append( - self.loss( - separable_filtering(y_pred, [self.kernel_fn(s)] * (y_true.ndim - 2)), - separable_filtering(y_true, [self.kernel_fn(s)] * (y_true.ndim - 2)), + else: + loss_list = [] + for s in self.scales: + if s == 0: + # no smoothing + loss_list.append(self.loss(y_pred, y_true)) + else: + loss_list.append( + self.loss( + separable_filtering(y_pred, [self.kernel_fn(s)] * (y_true.ndim - 2)), + separable_filtering(y_true, [self.kernel_fn(s)] * (y_true.ndim - 2)), + ) ) - ) - loss = torch.mean(torch.stack(losses, dim=0), dim=0) + loss = torch.stack(loss_list, dim=0) + + if self.reduction == LossReduction.MEAN.value: + loss = torch.mean(loss) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + loss = torch.sum(loss) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + pass # returns [N, n_classes] losses + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + return loss diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index aa91a4380e..722ae7cfce 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -52,6 +52,8 @@ def test_ill_opts(self): MultiScaleLoss(loss=dice_loss, kernel="none") with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(loss=dice_loss, scales=[-1])(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + with self.assertRaisesRegex(ValueError, ""): + MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) if __name__ == "__main__":