From 34fb74674a09247bf7de837cd035a0fd2de3e18d Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 5 Jan 2021 22:33:10 +0000 Subject: [PATCH 01/11] 1405 add bending energy loss Signed-off-by: kate-sann5100 --- docs/source/losses.rst | 11 ++++ monai/losses/__init__.py | 1 + monai/losses/deform.py | 113 +++++++++++++++++++++++++++++++++++ tests/test_bending_energy.py | 54 +++++++++++++++++ 4 files changed, 179 insertions(+) create mode 100644 monai/losses/deform.py create mode 100644 tests/test_bending_energy.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 3f87f172d5..b550887717 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -52,3 +52,14 @@ Segmentation Losses ~~~~~~~~~~~~~ .. autoclass:: TverskyLoss :members: + +Registration Losses +------------------- + +.. automodule:: monai.losses +.. currentmodule:: monai.losses + +`BendingEnergyLoss` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: BendingEnergyLoss + :members: diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 7c3ca0cfe1..045e62ec28 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .deform import BendingEnergyLoss from .dice import ( Dice, DiceLoss, diff --git a/monai/losses/deform.py b/monai/losses/deform.py new file mode 100644 index 0000000000..e6521355f5 --- /dev/null +++ b/monai/losses/deform.py @@ -0,0 +1,113 @@ +from typing import Union + +import torch +from torch.nn.modules.loss import _Loss + +from monai.utils import LossReduction + + +def gradient_dx(fx: torch.Tensor) -> torch.Tensor: + """ + Calculate gradients on x-axis of a 3D tensor using central finite difference. + It moves the tensor along axis 1 to calculate the approximate gradient, the x axis, + dx[i] = (x[i+1] - x[i-1]) / 2. + Args: + fx: the shape should be BDHW. + + Returns: + gradient_dx: the shape should be BDHW + """ + return (fx[..., 1:-1, 1:-1, 2:] - fx[..., 1:-1, 1:-1, :-2]) / 2 + + +def gradient_dy(fy: torch.Tensor) -> torch.Tensor: + """ + Calculate gradients on y-axis of a 3D tensor using central finite difference. + It moves the tensor along axis 1 to calculate the approximate gradient, the y axis, + dy[i] = (y[i+1] - y[i-1]) / 2. + Args: + fy: the shape should be BDHW. + + Returns: + gradient_dy: the shape should be BDHW + """ + return (fy[..., 1:-1, 2:, 1:-1] - fy[..., 1:-1, :-2, 1:-1]) / 2 + + +def gradient_dz(fz: torch.Tensor) -> torch.Tensor: + """ + Calculate gradients on z-axis of a 3D tensor using central finite difference. + It moves the tensor along axis 1 to calculate the approximate gradient, the z axis, + dz[i] = (z[i+1] - z[i-1]) / 2. + Args: + fz: the shape should be BDHW. + + Returns: + gradient_dy: the shape should be BDHW + """ + return (fz[..., 2:, 1:-1, 1:-1] - fz[..., :-2, 1:-1, 1:-1]) / 2 + + +class BendingEnergyLoss(_Loss): + def __init__( + self, + reduction: Union[LossReduction, str] = LossReduction.MEAN, + ) -> None: + """ + Args: + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Calculate the bending energy based on second-order differentiation of input using central finite difference. + + Args: + input: the shape should be B3DHW + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + + """ + assert len(input.shape) == 5 and input.shape[1] == 3, ( + f"expecting 5-d ddf input with 3 channels, " f"instead got input of shape {input.shape}" + ) + assert ( + input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4 + ), f"all depth, height and width must > 4, got input of shape {input.shape}" + + # first order gradient + # (batch, 3, d-2, h-2, w-2) + dfdx = gradient_dx(input) + dfdy = gradient_dy(input) + dfdz = gradient_dz(input) + + # second order gradient + # (batch, 3, d-4, h-4, w-4) + dfdxx = gradient_dx(dfdx) + dfdyy = gradient_dy(dfdy) + dfdzz = gradient_dz(dfdz) + dfdxy = gradient_dy(dfdx) + dfdyz = gradient_dz(dfdy) + dfdxz = gradient_dz(dfdx) + + # (dx + dy + dz) ** 2 = dxx + dyy + dzz + 2*(dxy + dyz + dzx) + energy = dfdxx ** 2 + dfdyy ** 2 + dfdzz ** 2 + energy += 2 * dfdxy ** 2 + 2 * dfdxz ** 2 + 2 * dfdyz ** 2 + + if self.reduction == LossReduction.MEAN.value: + energy = torch.mean(energy) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + energy = torch.sum(energy) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + pass # returns [N, n_classes] losses + else: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return energy diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py new file mode 100644 index 0000000000..8aefdb4cf7 --- /dev/null +++ b/tests/test_bending_energy.py @@ -0,0 +1,54 @@ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import BendingEnergyLoss + +TEST_CASES = [ + [ + {}, + {"input": torch.ones((1, 3, 5, 5, 5))}, + 0.0, + ], + [ + {}, + {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, + 0.0, + ], + [ + {}, + {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 4.0, + ], +] + + +class TestBendingEnergy(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = BendingEnergyLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = BendingEnergyLoss() + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 5, 5, 5))) + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 4, 5, 5))) + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 5, 4, 5))) + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 5, 5, 4))) + + def test_ill_opts(self): + input = torch.rand(1, 3, 5, 5, 5) + with self.assertRaisesRegex(ValueError, ""): + BendingEnergyLoss(reduction="unknown")(input) + with self.assertRaisesRegex(ValueError, ""): + BendingEnergyLoss(reduction=None)(input) + + +if __name__ == "__main__": + unittest.main() From 5fb43b6a29d87abdcf57c74d38ae83e1ac6c5b60 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 5 Jan 2021 23:02:38 +0000 Subject: [PATCH 02/11] 1405 update documentation Signed-off-by: kate-sann5100 --- docs/source/losses.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index b550887717..e1a492d611 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -56,9 +56,6 @@ Segmentation Losses Registration Losses ------------------- -.. automodule:: monai.losses -.. currentmodule:: monai.losses - `BendingEnergyLoss` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: BendingEnergyLoss From 5386871f3521b6a3b493d047428875acb5b82613 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 5 Jan 2021 23:55:52 +0000 Subject: [PATCH 03/11] 1405 fix integer divison Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index e6521355f5..e060457f49 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -45,7 +45,7 @@ def gradient_dz(fz: torch.Tensor) -> torch.Tensor: Returns: gradient_dy: the shape should be BDHW """ - return (fz[..., 2:, 1:-1, 1:-1] - fz[..., :-2, 1:-1, 1:-1]) / 2 + return (fz[..., 2:, 1:-1, 1:-1] - fz[..., :-2, 1:-1, 1:-1]) / 2.0 class BendingEnergyLoss(_Loss): From c5325172e2e1754e7e9b88180697798e6f88a601 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 00:40:36 +0000 Subject: [PATCH 04/11] 1405 fix integer division Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index e060457f49..109799c2cb 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -17,7 +17,7 @@ def gradient_dx(fx: torch.Tensor) -> torch.Tensor: Returns: gradient_dx: the shape should be BDHW """ - return (fx[..., 1:-1, 1:-1, 2:] - fx[..., 1:-1, 1:-1, :-2]) / 2 + return (fx[..., 1:-1, 1:-1, 2:] - fx[..., 1:-1, 1:-1, :-2]) / 2.0 def gradient_dy(fy: torch.Tensor) -> torch.Tensor: @@ -31,7 +31,7 @@ def gradient_dy(fy: torch.Tensor) -> torch.Tensor: Returns: gradient_dy: the shape should be BDHW """ - return (fy[..., 1:-1, 2:, 1:-1] - fy[..., 1:-1, :-2, 1:-1]) / 2 + return (fy[..., 1:-1, 2:, 1:-1] - fy[..., 1:-1, :-2, 1:-1]) / 2.0 def gradient_dz(fz: torch.Tensor) -> torch.Tensor: From 8e7d071ece49b66abb18f13be38bc2f88c27d04a Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 12:40:21 +0000 Subject: [PATCH 05/11] 1405 unify gradient function for all dimensions Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 52 ++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 35 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 109799c2cb..3f8d622d61 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -6,46 +6,28 @@ from monai.utils import LossReduction -def gradient_dx(fx: torch.Tensor) -> torch.Tensor: +def gradient(input: torch.Tensor, dim: int) -> torch.Tensor: """ - Calculate gradients on x-axis of a 3D tensor using central finite difference. - It moves the tensor along axis 1 to calculate the approximate gradient, the x axis, + Calculate gradients on single dimension of a tensor using central finite difference. + It moves the tensor along the dimension to calculate the approximate gradient dx[i] = (x[i+1] - x[i-1]) / 2. Args: - fx: the shape should be BDHW. - - Returns: - gradient_dx: the shape should be BDHW - """ - return (fx[..., 1:-1, 1:-1, 2:] - fx[..., 1:-1, 1:-1, :-2]) / 2.0 - - -def gradient_dy(fy: torch.Tensor) -> torch.Tensor: - """ - Calculate gradients on y-axis of a 3D tensor using central finite difference. - It moves the tensor along axis 1 to calculate the approximate gradient, the y axis, - dy[i] = (y[i+1] - y[i-1]) / 2. - Args: - fy: the shape should be BDHW. - - Returns: - gradient_dy: the shape should be BDHW - """ - return (fy[..., 1:-1, 2:, 1:-1] - fy[..., 1:-1, :-2, 1:-1]) / 2.0 - - -def gradient_dz(fz: torch.Tensor) -> torch.Tensor: - """ - Calculate gradients on z-axis of a 3D tensor using central finite difference. - It moves the tensor along axis 1 to calculate the approximate gradient, the z axis, - dz[i] = (z[i+1] - z[i-1]) / 2. - Args: - fz: the shape should be BDHW. - + input: the shape should be BCH(WD). + dim: dimension to calculate gradient along. Returns: - gradient_dy: the shape should be BDHW + gradient_dx: the shape should be BCH(WD) """ - return (fz[..., 2:, 1:-1, 1:-1] - fz[..., :-2, 1:-1, 1:-1]) / 2.0 + slice_1 = slice(1, -1) + slice_2_s = slice(2, None) + slice_2_e = slice(None, -2) + slice_all = slice(None) + slicing = [slice_all, slice_all] + while len(slicing) < input.ndim: + slicing = slicing + [slice_1] + slicing_s, slicing_e = slicing, slicing + slicing_s[dim] = slice_2_s + slicing_e[dim] = slice_2_e + return (input[slicing_s] - input[slicing_e]) / 2.0 class BendingEnergyLoss(_Loss): From c36a199ff8295bbbd788947dbe87a197783c489e Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 13:46:09 +0000 Subject: [PATCH 06/11] 1405 extend bending energy to 1-d and 2-d input Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 52 +++++++++++++++++------------------- tests/test_bending_energy.py | 16 ++++++++++- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 3f8d622d61..1388da4b43 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -21,10 +21,10 @@ def gradient(input: torch.Tensor, dim: int) -> torch.Tensor: slice_2_s = slice(2, None) slice_2_e = slice(None, -2) slice_all = slice(None) - slicing = [slice_all, slice_all] - while len(slicing) < input.ndim: - slicing = slicing + [slice_1] - slicing_s, slicing_e = slicing, slicing + slicing_s, slicing_e = [slice_all, slice_all], [slice_all, slice_all] + while len(slicing_s) < input.ndim: + slicing_s = slicing_s + [slice_1] + slicing_e = slicing_e + [slice_1] slicing_s[dim] = slice_2_s slicing_e[dim] = slice_2_e return (input[slicing_s] - input[slicing_e]) / 2.0 @@ -51,37 +51,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: Calculate the bending energy based on second-order differentiation of input using central finite difference. Args: - input: the shape should be B3DHW + input: the shape should be BCH(WD) Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ - assert len(input.shape) == 5 and input.shape[1] == 3, ( - f"expecting 5-d ddf input with 3 channels, " f"instead got input of shape {input.shape}" - ) - assert ( - input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4 - ), f"all depth, height and width must > 4, got input of shape {input.shape}" + assert input.ndim in [3, 4, 5], f"expecting 3-d, 4-d or 5-d input, instead got input of shape {input.shape}" + if input.ndim == 3: + assert input.shape[-1] > 4, f"all spatial dimensions must > 4, got input of shape {input.shape}" + elif input.ndim == 4: + assert ( + input.shape[-1] > 4 and input.shape[-2] > 4 + ), f"all spatial dimensions must > 4, got input of shape {input.shape}" + elif input.ndim == 5: + assert ( + input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4 + ), f"all spatial dimensions must > 4, got input of shape {input.shape}" # first order gradient - # (batch, 3, d-2, h-2, w-2) - dfdx = gradient_dx(input) - dfdy = gradient_dy(input) - dfdz = gradient_dz(input) + first_order_gradient = [gradient(input, dim) for dim in range(2, input.ndim)] - # second order gradient - # (batch, 3, d-4, h-4, w-4) - dfdxx = gradient_dx(dfdx) - dfdyy = gradient_dy(dfdy) - dfdzz = gradient_dz(dfdz) - dfdxy = gradient_dy(dfdx) - dfdyz = gradient_dz(dfdy) - dfdxz = gradient_dz(dfdx) - - # (dx + dy + dz) ** 2 = dxx + dyy + dzz + 2*(dxy + dyz + dzx) - energy = dfdxx ** 2 + dfdyy ** 2 + dfdzz ** 2 - energy += 2 * dfdxy ** 2 + 2 * dfdxz ** 2 + 2 * dfdyz ** 2 + energy = 0 + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + energy += gradient(g, dim_1) ** 2 + for dim_2 in range(dim_1 + 1, input.ndim): + energy += 2 * gradient(g, dim_2) ** 2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average @@ -92,4 +88,4 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') - return energy + return energy \ No newline at end of file diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index 8aefdb4cf7..8488953b70 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -22,6 +22,16 @@ {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0, ], + [ + {}, + {"input": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, + 4.0, + ], + [ + {}, + {"input": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2}, + 4.0, + ], ] @@ -33,8 +43,12 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = BendingEnergyLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3))) with self.assertRaisesRegex(AssertionError, ""): - loss.forward(torch.ones((1, 5, 5, 5))) + loss.forward(torch.ones((1, 3, 5, 5, 5, 5))) + # spatial_dim < 5 with self.assertRaisesRegex(AssertionError, ""): loss.forward(torch.ones((1, 3, 4, 5, 5))) with self.assertRaisesRegex(AssertionError, ""): From a1155aab13c42399d303aa59028485806af79084 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 14:10:17 +0000 Subject: [PATCH 07/11] 1405 auto style fix Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 1388da4b43..6b535b8893 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -66,7 +66,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ), f"all spatial dimensions must > 4, got input of shape {input.shape}" elif input.ndim == 5: assert ( - input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4 + input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4 ), f"all spatial dimensions must > 4, got input of shape {input.shape}" # first order gradient @@ -88,4 +88,4 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') - return energy \ No newline at end of file + return energy From 5342fdd667b165d6f246c94733f860933075ba4c Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 16:18:14 +0000 Subject: [PATCH 08/11] 1405 fix typing overflow Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 6b535b8893..30ff51b4e8 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -6,7 +6,7 @@ from monai.utils import LossReduction -def gradient(input: torch.Tensor, dim: int) -> torch.Tensor: +def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: """ Calculate gradients on single dimension of a tensor using central finite difference. It moves the tensor along the dimension to calculate the approximate gradient @@ -70,14 +70,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ), f"all spatial dimensions must > 4, got input of shape {input.shape}" # first order gradient - first_order_gradient = [gradient(input, dim) for dim in range(2, input.ndim)] + first_order_gradient = [spatial_gradient(input, dim) for dim in range(2, input.ndim)] - energy = 0 + energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 - energy += gradient(g, dim_1) ** 2 + energy += spatial_gradient(g, dim_1) ** 2 for dim_2 in range(dim_1 + 1, input.ndim): - energy += 2 * gradient(g, dim_2) ** 2 + energy += 2 * spatial_gradient(g, dim_2) ** 2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average From da80d5dc7d0a3865c7d4bac5cbef872c721e5cbb Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 16:21:08 +0000 Subject: [PATCH 09/11] 1405 add DeepReg acknowledgement Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 30ff51b4e8..b2aeebbe65 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -11,6 +11,7 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: Calculate gradients on single dimension of a tensor using central finite difference. It moves the tensor along the dimension to calculate the approximate gradient dx[i] = (x[i+1] - x[i-1]) / 2. + Adapted from DeepReg (https://github.com/DeepRegNet/DeepReg) Args: input: the shape should be BCH(WD). dim: dimension to calculate gradient along. @@ -31,11 +32,13 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: class BendingEnergyLoss(_Loss): + def __init__( self, reduction: Union[LossReduction, str] = LossReduction.MEAN, ) -> None: """ + Adapted from DeepReg (https://github.com/DeepRegNet/DeepReg) Args: reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. From 999bd1065a439ae6d26592dec16401db8c6da2eb Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 16:36:34 +0000 Subject: [PATCH 10/11] 1405 auto style fix and docstring update Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index b2aeebbe65..11eb193a00 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -11,7 +11,9 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: Calculate gradients on single dimension of a tensor using central finite difference. It moves the tensor along the dimension to calculate the approximate gradient dx[i] = (x[i+1] - x[i-1]) / 2. - Adapted from DeepReg (https://github.com/DeepRegNet/DeepReg) + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + Args: input: the shape should be BCH(WD). dim: dimension to calculate gradient along. @@ -32,13 +34,18 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: class BendingEnergyLoss(_Loss): + """ + Calculate the bending energy based on second-order differentiation of input using central finite difference. + + Adapted from: + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ def __init__( self, reduction: Union[LossReduction, str] = LossReduction.MEAN, ) -> None: """ - Adapted from DeepReg (https://github.com/DeepRegNet/DeepReg) Args: reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -51,8 +58,6 @@ def __init__( def forward(self, input: torch.Tensor) -> torch.Tensor: """ - Calculate the bending energy based on second-order differentiation of input using central finite difference. - Args: input: the shape should be BCH(WD) From 4c7d14421b036da393a6984e9dc8f2fd9760634c Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 6 Jan 2021 17:00:41 +0000 Subject: [PATCH 11/11] 1405 fix typing overflow Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 11eb193a00..a4be4c0178 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -83,9 +83,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 - energy += spatial_gradient(g, dim_1) ** 2 + energy = spatial_gradient(g, dim_1) ** 2 + energy for dim_2 in range(dim_1 + 1, input.ndim): - energy += 2 * spatial_gradient(g, dim_2) ** 2 + energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average