Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@ Segmentation Losses
~~~~~~~~~~~~~
.. autoclass:: TverskyLoss
:members:

Registration Losses
-------------------

`BendingEnergyLoss`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: BendingEnergyLoss
:members:
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .deform import BendingEnergyLoss
from .dice import (
Dice,
DiceCELoss,
Expand Down
99 changes: 99 additions & 0 deletions monai/losses/deform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Union

import torch
from torch.nn.modules.loss import _Loss

from monai.utils import LossReduction


def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor:
"""
Calculate gradients on single dimension of a tensor using central finite difference.
It moves the tensor along the dimension to calculate the approximate gradient
dx[i] = (x[i+1] - x[i-1]) / 2.
Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)

Args:
input: the shape should be BCH(WD).
dim: dimension to calculate gradient along.
Returns:
gradient_dx: the shape should be BCH(WD)
"""
slice_1 = slice(1, -1)
slice_2_s = slice(2, None)
slice_2_e = slice(None, -2)
slice_all = slice(None)
slicing_s, slicing_e = [slice_all, slice_all], [slice_all, slice_all]
while len(slicing_s) < input.ndim:
slicing_s = slicing_s + [slice_1]
slicing_e = slicing_e + [slice_1]
slicing_s[dim] = slice_2_s
slicing_e[dim] = slice_2_e
return (input[slicing_s] - input[slicing_e]) / 2.0


class BendingEnergyLoss(_Loss):
"""
Calculate the bending energy based on second-order differentiation of input using central finite difference.

Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(
self,
reduction: Union[LossReduction, str] = LossReduction.MEAN,
) -> None:
"""
Args:
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
"""
super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value)

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BCH(WD)

Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

"""
assert input.ndim in [3, 4, 5], f"expecting 3-d, 4-d or 5-d input, instead got input of shape {input.shape}"
if input.ndim == 3:
assert input.shape[-1] > 4, f"all spatial dimensions must > 4, got input of shape {input.shape}"
elif input.ndim == 4:
assert (
input.shape[-1] > 4 and input.shape[-2] > 4
), f"all spatial dimensions must > 4, got input of shape {input.shape}"
elif input.ndim == 5:
assert (
input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4
), f"all spatial dimensions must > 4, got input of shape {input.shape}"

# first order gradient
first_order_gradient = [spatial_gradient(input, dim) for dim in range(2, input.ndim)]

energy = torch.tensor(0)
for dim_1, g in enumerate(first_order_gradient):
dim_1 += 2
energy = spatial_gradient(g, dim_1) ** 2 + energy
for dim_2 in range(dim_1 + 1, input.ndim):
energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy

if self.reduction == LossReduction.MEAN.value:
energy = torch.mean(energy) # the batch and channel average
elif self.reduction == LossReduction.SUM.value:
energy = torch.sum(energy) # sum over the batch and channel dims
elif self.reduction == LossReduction.NONE.value:
pass # returns [N, n_classes] losses
else:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return energy
68 changes: 68 additions & 0 deletions tests/test_bending_energy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses.deform import BendingEnergyLoss

TEST_CASES = [
[
{},
{"input": torch.ones((1, 3, 5, 5, 5))},
0.0,
],
[
{},
{"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)},
0.0,
],
[
{},
{"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
4.0,
],
[
{},
{"input": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2},
4.0,
],
[
{},
{"input": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2},
4.0,
],
]


class TestBendingEnergy(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = BendingEnergyLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

def test_ill_shape(self):
loss = BendingEnergyLoss()
# not in 3-d, 4-d, 5-d
with self.assertRaisesRegex(AssertionError, ""):
loss.forward(torch.ones((1, 3)))
with self.assertRaisesRegex(AssertionError, ""):
loss.forward(torch.ones((1, 3, 5, 5, 5, 5)))
# spatial_dim < 5
with self.assertRaisesRegex(AssertionError, ""):
loss.forward(torch.ones((1, 3, 4, 5, 5)))
with self.assertRaisesRegex(AssertionError, ""):
loss.forward(torch.ones((1, 3, 5, 4, 5)))
with self.assertRaisesRegex(AssertionError, ""):
loss.forward(torch.ones((1, 3, 5, 5, 4)))

def test_ill_opts(self):
input = torch.rand(1, 3, 5, 5, 5)
with self.assertRaisesRegex(ValueError, ""):
BendingEnergyLoss(reduction="unknown")(input)
with self.assertRaisesRegex(ValueError, ""):
BendingEnergyLoss(reduction=None)(input)


if __name__ == "__main__":
unittest.main()