Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ Registration Losses
.. autoclass:: BendingEnergyLoss
:members:

`DiffusionLoss`
~~~~~~~~~~~~~~~
.. autoclass:: DiffusionLoss
:members:

`LocalNormalizedCrossCorrelationLoss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNormalizedCrossCorrelationLoss
Expand Down
2 changes: 1 addition & 1 deletion monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .adversarial_loss import PatchAdversarialLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss
from .deform import BendingEnergyLoss, DiffusionLoss
from .dice import (
Dice,
DiceCELoss,
Expand Down
82 changes: 82 additions & 0 deletions monai/losses/deform.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,85 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return energy


class DiffusionLoss(_Loss):
"""
Calculate the diffusion based on first-order differentiation of pred using central finite difference.
For the original paper, please refer to
VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.

Adapted from:
VoxelMorph (https://github.com/voxelmorph/voxelmorph)
"""

def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None:
"""
Args:
normalize:
Whether to divide out spatial sizes in order to make the computation roughly
invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.normalize = normalize

def forward(self, pred: torch.Tensor) -> torch.Tensor:
"""
Args:
pred:
Predicted dense displacement field (DDF) with shape BCH[WD],
where C is the number of spatial dimensions.
Note that diffusion loss can only be calculated
when the sizes of the DDF along all spatial dimensions are greater than 2.

Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.
ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.

"""
if pred.ndim not in [3, 4, 5]:
raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
for i in range(pred.ndim - 2):
if pred.shape[-i - 1] <= 2:
raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}")
if pred.shape[1] != pred.ndim - 2:
raise ValueError(
f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, "
f"does not match number of spatial dimensions, {pred.ndim - 2}"
)

# first order gradient
first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]

# spatial dimensions in a shape suited for broadcasting below
if self.normalize:
spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))

diffusion = torch.tensor(0)
for dim_1, g in enumerate(first_order_gradient):
dim_1 += 2
if self.normalize:
# We divide the partial derivative for each vector component at each voxel by the spatial size
# corresponding to that component relative to the spatial size of the vector component with respect
# to which the partial derivative is taken.
g *= pred.shape[dim_1] / spatial_dims
diffusion = diffusion + g**2

if self.reduction == LossReduction.MEAN.value:
diffusion = torch.mean(diffusion) # the batch and channel average
elif self.reduction == LossReduction.SUM.value:
diffusion = torch.sum(diffusion) # sum over the batch and channel dims
elif self.reduction != LossReduction.NONE.value:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return diffusion
116 changes: 116 additions & 0 deletions tests/test_diffusion_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses.deform import DiffusionLoss

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASES = [
# all first partials are zero, so the diffusion loss is also zero
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
# all first partials are one, so the diffusion loss is also one
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0],
# before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
56.0 / 3.0,
],
# same as the previous case
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
56.0 / 3.0,
],
# same as the previous case
[{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
# we have shown in the demo notebook that
# diffusion loss is scale-invariant when the all axes have the same resolution
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
56.0 / 3.0,
],
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
56.0 / 3.0,
],
[{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
# for the following case, consider the following 2D matrix:
# tensor([[[[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4],
# [3, 4, 5],
# [4, 5, 6]],
# [[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4],
# [3, 4, 5],
# [4, 5, 6]]]])
# the first partials wrt x are all ones, and so are the first partials wrt y
# the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2
[{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0],
# consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook,
# the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y
# the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689
[
{"normalize": True},
{"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)},
(1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0,
],
]


class TestDiffusionLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = DiffusionLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

def test_ill_shape(self):
loss = DiffusionLoss()
# not in 3-d, 4-d, 5-d
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 3), device=device))
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 2, 5)))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 5, 2)))

# number of vector components unequal to number of spatial dims
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))

def test_ill_opts(self):
pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
with self.assertRaisesRegex(ValueError, ""):
DiffusionLoss(reduction="unknown")(pred)
with self.assertRaisesRegex(ValueError, ""):
DiffusionLoss(reduction=None)(pred)


if __name__ == "__main__":
unittest.main()