From 4595dbd70bb026787070a71c957f27b4e773ba37 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Mon, 24 Jul 2023 22:28:28 +0530 Subject: [PATCH 1/6] feat: add clDice loss Signed-off-by: Saurav Maheshkar --- monai/losses/cldice.py | 172 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100644 monai/losses/cldice.py diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py new file mode 100644 index 0000000000..f8d00c2ac6 --- /dev/null +++ b/monai/losses/cldice.py @@ -0,0 +1,172 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + + +def soft_erode(img: torch.Tensor) -> torch.Tensor: + """ + Perform soft erosion on the input image + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6 + """ + if len(img.shape) == 4: + p1 = -F.max_pool2d(-img, (3, 1), (1, 1), (1, 0)) + p2 = -F.max_pool2d(-img, (1, 3), (1, 1), (0, 1)) + return torch.min(p1, p2) + elif len(img.shape) == 5: + p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)) + p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)) + p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)) + return torch.min(torch.min(p1, p2), p3) + + +def soft_dilate(img: torch.Tensor) -> torch.Tensor: + """ + Perform soft dilation on the input image + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18 + """ + if len(img.shape) == 4: + return F.max_pool2d(img, (3, 3), (1, 1), (1, 1)) + elif len(img.shape) == 5: + return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) + + +def soft_open(img: torch.Tensor) -> torch.Tensor: + """ + Wrapper function to perform soft opening on the input image + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25 + """ + return soft_dilate(soft_erode(img)) + + +def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: + """ + Perform soft skeletonization on the input image + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29 + + Args: + img: input image + iter_: number of iterations for skeletonization + + Returns: + skeletonized image + """ + img1 = soft_open(img) + skel = F.relu(img - img1) + for _ in range(iter_): + img = soft_erode(img) + img1 = soft_open(img) + delta = F.relu(img - img1) + skel = skel + F.relu(delta - skel * delta) + return skel + + +def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + """ + Function to compute soft dice loss + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22 + + Args: + y_true: the shape should be BCH(WD) + y_pred: the shape should be BCH(WD) + + Returns: + dice loss + """ + smooth = 1 + intersection = torch.sum((y_true * y_pred)[:, 1:, ...]) + coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth) + return 1.0 - coeff + + +class SoftclDiceLoss(_Loss): + """ + Compute the Soft clDice loss defined in: + + Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function + for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7 + """ + + def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None: + """ + Args: + iter_: Number of iterations for skeletonization + smooth: Smoothing parameter + """ + super().__init__() + self.iter = iter_ + self.smooth = smooth + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + skel_pred = soft_skel(y_pred, self.iter) + skel_true = soft_skel(y_true, self.iter) + tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + ) + tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_true[:, 1:, ...]) + self.smooth + ) + cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + return cl_dice + + +class SoftDiceclDiceLoss(_Loss): + """ + Compute the Soft clDice loss defined in: + + Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function + for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311) + + Adapted from: + https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38 + """ + + def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None: + """ + Args: + iter_: Number of iterations for skeletonization + smooth: Smoothing parameter + alpha: Weighing factor for cldice + """ + super().__init__() + self.iter = iter_ + self.smooth = smooth + self.alpha = alpha + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + dice = soft_dice(y_true, y_pred) + skel_pred = soft_skel(y_pred, self.iter) + skel_true = soft_skel(y_true, self.iter) + tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_pred[:, 1:, ...]) + self.smooth + ) + tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( + torch.sum(skel_true[:, 1:, ...]) + self.smooth + ) + cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + return (1.0 - self.alpha) * dice + self.alpha * cl_dice From aa7200c32f7cbdbc611488dc260c1941c980576f Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Mon, 24 Jul 2023 22:30:02 +0530 Subject: [PATCH 2/6] feat: add clDice imports to losses/__init__.py Signed-off-by: Saurav Maheshkar --- monai/losses/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 9e09b0b123..db6b133ef0 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( From aadce2a8540796b75f311b1df5439bb41d38600e Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Mon, 24 Jul 2023 22:54:29 +0530 Subject: [PATCH 3/6] fix: add type hints Signed-off-by: Saurav Maheshkar --- monai/losses/cldice.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index f8d00c2ac6..f3dbb9be68 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -16,7 +16,7 @@ from torch.nn.modules.loss import _Loss -def soft_erode(img: torch.Tensor) -> torch.Tensor: +def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore """ Perform soft erosion on the input image @@ -26,15 +26,15 @@ def soft_erode(img: torch.Tensor) -> torch.Tensor: if len(img.shape) == 4: p1 = -F.max_pool2d(-img, (3, 1), (1, 1), (1, 0)) p2 = -F.max_pool2d(-img, (1, 3), (1, 1), (0, 1)) - return torch.min(p1, p2) + return torch.min(p1, p2) # type: ignore elif len(img.shape) == 5: p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)) p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)) p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)) - return torch.min(torch.min(p1, p2), p3) + return torch.min(torch.min(p1, p2), p3) # type: ignore -def soft_dilate(img: torch.Tensor) -> torch.Tensor: +def soft_dilate(img: torch.Tensor) -> torch.Tensor: # type: ignore """ Perform soft dilation on the input image @@ -42,9 +42,9 @@ def soft_dilate(img: torch.Tensor) -> torch.Tensor: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18 """ if len(img.shape) == 4: - return F.max_pool2d(img, (3, 3), (1, 1), (1, 1)) + return F.max_pool2d(img, (3, 3), (1, 1), (1, 1)) # type: ignore elif len(img.shape) == 5: - return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) + return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) # type: ignore def soft_open(img: torch.Tensor) -> torch.Tensor: @@ -98,7 +98,8 @@ def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: smooth = 1 intersection = torch.sum((y_true * y_pred)[:, 1:, ...]) coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth) - return 1.0 - coeff + soft_dice: torch.Tensor = 1.0 - coeff + return soft_dice class SoftclDiceLoss(_Loss): @@ -131,7 +132,7 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / ( torch.sum(skel_true[:, 1:, ...]) + self.smooth ) - cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) + cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) return cl_dice @@ -169,4 +170,5 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: torch.sum(skel_true[:, 1:, ...]) + self.smooth ) cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) - return (1.0 - self.alpha) * dice + self.alpha * cl_dice + total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice + return total_loss From 4dd80a45256a08e0a3b1cff7b5f7306d0a70be39 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 25 Jul 2023 12:32:08 +0530 Subject: [PATCH 4/6] feat: add tests for clDice Signed-off-by: Saurav Maheshkar --- monai/losses/cldice.py | 30 +++++++++++++------- tests/test_cldice_loss.py | 59 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 10 deletions(-) create mode 100644 tests/test_cldice_loss.py diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index f3dbb9be68..ae769ce62d 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -20,17 +20,20 @@ def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore """ Perform soft erosion on the input image + Args: + img: the shape should be BCH(WD) + Adapted from: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6 """ if len(img.shape) == 4: - p1 = -F.max_pool2d(-img, (3, 1), (1, 1), (1, 0)) - p2 = -F.max_pool2d(-img, (1, 3), (1, 1), (0, 1)) + p1 = -(F.max_pool2d(-img, (3, 1), (1, 1), (1, 0))) + p2 = -(F.max_pool2d(-img, (1, 3), (1, 1), (0, 1))) return torch.min(p1, p2) # type: ignore elif len(img.shape) == 5: - p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)) - p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)) - p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)) + p1 = -(F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0))) + p2 = -(F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0))) + p3 = -(F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1))) return torch.min(torch.min(p1, p2), p3) # type: ignore @@ -38,6 +41,9 @@ def soft_dilate(img: torch.Tensor) -> torch.Tensor: # type: ignore """ Perform soft dilation on the input image + Args: + img: the shape should be BCH(WD) + Adapted from: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18 """ @@ -51,10 +57,15 @@ def soft_open(img: torch.Tensor) -> torch.Tensor: """ Wrapper function to perform soft opening on the input image + Args: + img: the shape should be BCH(WD) + Adapted from: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25 """ - return soft_dilate(soft_erode(img)) + eroded_image = soft_erode(img) + dilated_image = soft_dilate(eroded_image) + return dilated_image def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: @@ -65,7 +76,7 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29 Args: - img: input image + img: the shape should be BCH(WD) iter_: number of iterations for skeletonization Returns: @@ -81,7 +92,7 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: return skel -def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: +def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: int = 1) -> torch.Tensor: """ Function to compute soft dice loss @@ -95,7 +106,6 @@ def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: Returns: dice loss """ - smooth = 1 intersection = torch.sum((y_true * y_pred)[:, 1:, ...]) coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth) soft_dice: torch.Tensor = 1.0 - coeff @@ -160,7 +170,7 @@ def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> N self.alpha = alpha def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - dice = soft_dice(y_true, y_pred) + dice = soft_dice(y_true, y_pred, self.smooth) skel_pred = soft_skel(y_pred, self.iter) skel_true = soft_skel(y_true, self.iter) tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / ( diff --git a/tests/test_cldice_loss.py b/tests/test_cldice_loss.py new file mode 100644 index 0000000000..e69a7670ae --- /dev/null +++ b/tests/test_cldice_loss.py @@ -0,0 +1,59 @@ +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss + +TEST_CASES = [ + [ # shape: (1, 4), (1, 4) + {"y_pred": torch.ones((100, 3, 256, 256)), "y_true": torch.ones((100, 3, 256, 256))}, + 0.0, + ], + [ # shape: (1, 5), (1, 5) + { + "y_pred": torch.ones((100, 3, 256, 256, 5)), + "y_true": torch.ones((100, 3, 256, 256, 5)), + }, + 0.0, + ], +] + + +class TestclDiceLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, y_pred_data, expected_val): + loss = SoftclDiceLoss() + loss_dice = SoftDiceclDiceLoss() + result = loss(**y_pred_data) + result_dice = loss_dice(**y_pred_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(result_dice.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_with_cuda(self): + loss = SoftclDiceLoss() + loss_dice = SoftDiceclDiceLoss() + i = torch.ones((100, 3, 256, 256)) + j = torch.ones((100, 3, 256, 256)) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = loss(i, j) + output_dice = loss_dice(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() From a2dab6066f344f08ae48a0c4725789012e3aa46a Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 25 Jul 2023 12:41:46 +0530 Subject: [PATCH 5/6] style: lint clDice tests Signed-off-by: Saurav Maheshkar --- tests/test_cldice_loss.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_cldice_loss.py b/tests/test_cldice_loss.py index e69a7670ae..109186b5d1 100644 --- a/tests/test_cldice_loss.py +++ b/tests/test_cldice_loss.py @@ -22,10 +22,7 @@ 0.0, ], [ # shape: (1, 5), (1, 5) - { - "y_pred": torch.ones((100, 3, 256, 256, 5)), - "y_true": torch.ones((100, 3, 256, 256, 5)), - }, + {"y_pred": torch.ones((100, 3, 256, 256, 5)), "y_true": torch.ones((100, 3, 256, 256, 5))}, 0.0, ], ] From 2303b308865b6f369d1f19e6eb75ce23d35f8737 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Tue, 25 Jul 2023 13:01:49 +0530 Subject: [PATCH 6/6] fix: type hint Signed-off-by: Saurav Maheshkar --- monai/losses/cldice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/cldice.py b/monai/losses/cldice.py index ae769ce62d..5c6a721e1d 100644 --- a/monai/losses/cldice.py +++ b/monai/losses/cldice.py @@ -92,7 +92,7 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor: return skel -def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: int = 1) -> torch.Tensor: +def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor: """ Function to compute soft dice loss