diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 568c7dfc77..e929e9d605 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -96,6 +96,11 @@ Registration Losses .. autoclass:: BendingEnergyLoss :members: +`DiffusionLoss` +~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionLoss + :members: + `LocalNormalizedCrossCorrelationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d734a9d44d..92898c81ca 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -14,7 +14,7 @@ from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss -from .deform import BendingEnergyLoss +from .deform import BendingEnergyLoss, DiffusionLoss from .dice import ( Dice, DiceCELoss, diff --git a/monai/losses/deform.py b/monai/losses/deform.py index dd03a8eb3d..129abeedd2 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -116,3 +116,85 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return energy + + +class DiffusionLoss(_Loss): + """ + Calculate the diffusion based on first-order differentiation of pred using central finite difference. + For the original paper, please refer to + VoxelMorph: A Learning Framework for Deformable Medical Image Registration, + Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + Adapted from: + VoxelMorph (https://github.com/voxelmorph/voxelmorph) + """ + + def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None: + """ + Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize + + def forward(self, pred: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: + Predicted dense displacement field (DDF) with shape BCH[WD], + where C is the number of spatial dimensions. + Note that diffusion loss can only be calculated + when the sizes of the DDF along all spatial dimensions are greater than 2. + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. + + """ + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" + ) + + # first order gradient + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + + diffusion = torch.tensor(0) + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + if self.normalize: + # We divide the partial derivative for each vector component at each voxel by the spatial size + # corresponding to that component relative to the spatial size of the vector component with respect + # to which the partial derivative is taken. + g *= pred.shape[dim_1] / spatial_dims + diffusion = diffusion + g**2 + + if self.reduction == LossReduction.MEAN.value: + diffusion = torch.mean(diffusion) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + diffusion = torch.sum(diffusion) # sum over the batch and channel dims + elif self.reduction != LossReduction.NONE.value: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return diffusion diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py new file mode 100644 index 0000000000..05dfab95fb --- /dev/null +++ b/tests/test_diffusion_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import DiffusionLoss + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + # all first partials are zero, so the diffusion loss is also zero + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + # all first partials are one, so the diffusion loss is also one + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0], + # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67 + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # we have shown in the demo notebook that + # diffusion loss is scale-invariant when the all axes have the same resolution + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # for the following case, consider the following 2D matrix: + # tensor([[[[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]], + # [[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]]]]) + # the first partials wrt x are all ones, and so are the first partials wrt y + # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2 + [{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0], + # consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook, + # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y + # the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689 + [ + {"normalize": True}, + {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, + (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0, + ], +] + + +class TestDiffusionLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = DiffusionLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = DiffusionLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 3), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 2, 5))) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 5, 2))) + + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + + def test_ill_opts(self): + pred = torch.rand(1, 3, 5, 5, 5).to(device=device) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction="unknown")(pred) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction=None)(pred) + + +if __name__ == "__main__": + unittest.main()