diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 462c303e65..a6aa4d566d 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -70,3 +70,8 @@ Registration Losses ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss :members: + +`GlobalMutualInformationLoss` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: GlobalMutualInformationLoss + :members: \ No newline at end of file diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index ba6acafd47..591fb08f7b 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -22,5 +22,5 @@ generalized_wasserstein_dice, ) from .focal_loss import FocalLoss -from .image_dissimilarity import LocalNormalizedCrossCorrelationLoss +from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .tversky import TverskyLoss diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 1005587021..acba229121 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -17,7 +17,7 @@ from monai.utils import LossReduction -def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: +def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: """ Calculate gradients on single dimension of a tensor using central finite difference. It moves the tensor along the dimension to calculate the approximate gradient @@ -26,7 +26,7 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: DeepReg (https://github.com/DeepRegNet/DeepReg) Args: - input: the shape should be BCH(WD). + x: the shape should be BCH(WD). dim: dimension to calculate gradient along. Returns: gradient_dx: the shape should be BCH(WD) @@ -36,17 +36,17 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: slice_2_e = slice(None, -2) slice_all = slice(None) slicing_s, slicing_e = [slice_all, slice_all], [slice_all, slice_all] - while len(slicing_s) < input.ndim: + while len(slicing_s) < x.ndim: slicing_s = slicing_s + [slice_1] slicing_e = slicing_e + [slice_1] slicing_s[dim] = slice_2_s slicing_e[dim] = slice_2_e - return (input[slicing_s] - input[slicing_e]) / 2.0 + return (x[slicing_s] - x[slicing_e]) / 2.0 class BendingEnergyLoss(_Loss): """ - Calculate the bending energy based on second-order differentiation of input using central finite difference. + Calculate the bending energy based on second-order differentiation of pred using central finite difference. Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -67,35 +67,29 @@ def __init__( """ super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, pred: torch.Tensor) -> torch.Tensor: """ Args: - input: the shape should be BCH(WD) + pred: the shape should be BCH(WD) Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ - assert input.ndim in [3, 4, 5], f"expecting 3-d, 4-d or 5-d input, instead got input of shape {input.shape}" - if input.ndim == 3: - assert input.shape[-1] > 4, f"all spatial dimensions must > 4, got input of shape {input.shape}" - elif input.ndim == 4: - assert ( - input.shape[-1] > 4 and input.shape[-2] > 4 - ), f"all spatial dimensions must > 4, got input of shape {input.shape}" - elif input.ndim == 5: - assert ( - input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4 - ), f"all spatial dimensions must > 4, got input of shape {input.shape}" + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 4: + raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}") # first order gradient - first_order_gradient = [spatial_gradient(input, dim) for dim in range(2, input.ndim)] + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 energy = spatial_gradient(g, dim_1) ** 2 + energy - for dim_2 in range(dim_1 + 1, input.ndim): + for dim_2 in range(dim_1 + 1, pred.ndim): energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy if self.reduction == LossReduction.MEAN.value: diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index d42303e154..b229a0c08f 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Tuple, Union import torch from torch.nn import functional as F @@ -63,7 +63,7 @@ def __init__( self, in_channels: int, ndim: int = 3, - kernel_size: int = 9, + kernel_size: int = 3, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-7, @@ -100,40 +100,44 @@ def __init__( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) self.kernel = kernel_dict[kernel_type](self.kernel_size) - self.kernel_vol = torch.sum(self.kernel) ** self.ndim + self.kernel_vol = self.get_kernel_vol() + self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def get_kernel_vol(self): + vol = self.kernel + for _ in range(self.ndim - 1): + vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0)) + return torch.sum(vol) + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: - input: the shape should be BNH[WD]. + pred: the shape should be BNH[WD]. target: the shape should be BNH[WD]. Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ - assert ( - input.shape[1] == self.in_channels - ), f"expecting input with {self.in_channels} channels, got input of shape {input.shape}" - assert ( - input.ndim - 2 == self.ndim - ), f"expecting input with {self.ndim} spatial dimensions, got input of shape {input.shape}" - assert ( - target.shape == input.shape - ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" - - t2, p2, tp = target ** 2, input ** 2, target * input - + if pred.shape[1] != self.in_channels: + raise ValueError(f"expecting pred with {self.in_channels} channels, got pred of shape {pred.shape}") + if pred.ndim - 2 != self.ndim: + raise ValueError(f"expecting pred with {self.ndim} spatial dimensions, got pred of shape {pred.shape}") + if target.shape != pred.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") + + t2, p2, tp = target ** 2, pred ** 2, target * pred + kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) # sum over kernel - t_sum = separable_filtering(target, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - p_sum = separable_filtering(input, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - t2_sum = separable_filtering(t2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - p2_sum = separable_filtering(p2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - tp_sum = separable_filtering(tp, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + t_sum = separable_filtering(target, kernels=[kernel] * self.ndim) + p_sum = separable_filtering(pred, kernels=[kernel] * self.ndim) + t2_sum = separable_filtering(t2, kernels=[kernel] * self.ndim) + p2_sum = separable_filtering(p2, kernels=[kernel] * self.ndim) + tp_sum = separable_filtering(tp, kernels=[kernel] * self.ndim) # average over kernel - t_avg = t_sum / self.kernel_vol - p_avg = p_sum / self.kernel_vol + t_avg = t_sum / kernel_vol + p_avg = p_sum / kernel_vol # normalized cross correlation between t and p # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p] @@ -151,9 +155,90 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return torch.sum(ncc).neg() # sum over the batch and spatial ndims + return torch.sum(ncc).neg() # sum over the batch, channel and spatial ndims if self.reduction == LossReduction.NONE.value: return ncc.neg() if self.reduction == LossReduction.MEAN.value: - return torch.mean(ncc).neg() # average over the batch and spatial ndims + return torch.mean(ncc).neg() # average over the batch, channel and spatial ndims + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + +class GlobalMutualInformationLoss(_Loss): + """ + Differentiable global mutual information loss via Parzen windowing method. + + Reference: + https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1 + """ + + def __init__( + self, + num_bins: int = 23, + sigma_ratio: float = 0.5, + reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, + smooth_dr: float = 1e-7, + ) -> None: + """ + Args: + num_bins: number of bins for intensity + sigma_ratio: a hyper param for gaussian function + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid nan. + smooth_dr: a small constant added to the denominator to avoid nan. + """ + super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) + if num_bins <= 0: + raise ValueError("num_bins must > 0, got {num_bins}") + bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) + sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio + self.preterm = 1 / (2 * sigma ** 2) + self.bin_centers = bin_centers[None, None, ...] + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + + def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + pred: the shape should be B[NDHW]. + """ + pred = torch.clamp(pred, 0, 1) + pred = pred.reshape(pred.shape[0], -1, 1) # (batch, num_sample, 1) + weight = torch.exp( + -self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2 + ) # (batch, num_sample, num_bin) + weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) + probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) + return weight, probability + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: the shape should be B[NDHW]. + target: the shape should be same as the pred shape. + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + """ + if target.shape != pred.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") + wa, pa = self.parzen_windowing(pred) # (batch, num_sample, num_bin), (batch, 1, num_bin) + wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) + pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) + + papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins) + mi = torch.sum( + pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2) + ) # (batch) + + if self.reduction == LossReduction.SUM.value: + return torch.sum(mi).neg() # sum over the batch and channel ndims + if self.reduction == LossReduction.NONE.value: + return mi.neg() + if self.reduction == LossReduction.MEAN.value: + return torch.mean(mi).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index 3ba22ebac0..f2b9a41cae 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -20,27 +20,27 @@ TEST_CASES = [ [ {}, - {"input": torch.ones((1, 3, 5, 5, 5))}, + {"pred": torch.ones((1, 3, 5, 5, 5))}, 0.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, + {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, + {"pred": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, 4.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2}, + {"pred": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2}, 4.0, ], ] @@ -55,24 +55,24 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = BendingEnergyLoss() # not in 3-d, 4-d, 5-d - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3))) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 5, 5, 5))) # spatial_dim < 5 - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 4, 5, 5))) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 4, 5))) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 5, 4))) def test_ill_opts(self): - input = torch.rand(1, 3, 5, 5, 5) + pred = torch.rand(1, 3, 5, 5, 5) with self.assertRaisesRegex(ValueError, ""): - BendingEnergyLoss(reduction="unknown")(input) + BendingEnergyLoss(reduction="unknown")(pred) with self.assertRaisesRegex(ValueError, ""): - BendingEnergyLoss(reduction=None)(input) + BendingEnergyLoss(reduction=None)(pred) if __name__ == "__main__": diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py new file mode 100644 index 0000000000..252a70e85e --- /dev/null +++ b/tests/test_global_mutual_information_loss.py @@ -0,0 +1,100 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.image_dissimilarity import GlobalMutualInformationLoss + +TEST_CASES = [ + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + }, + -1.0986018, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3) + ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float)[None, :].div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :].div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "pred": torch.arange(0, 3, dtype=torch.float).div(3), + "target": torch.arange(0, 3, dtype=torch.float).div(3) ** 2, + }, + -1.1920927e-07, + ], +] + + +class TestGlobalMutualInformationLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = GlobalMutualInformationLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_ill_shape(self): + loss = GlobalMutualInformationLoss() + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + with self.assertRaisesRegex(ValueError, ""): + loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + + def test_ill_opts(self): + pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(num_bins=0)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(num_bins=-1)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(reduction="unknown")(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(reduction=None)(pred, target) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index cb2f446dfc..cf8566a559 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -19,92 +19,60 @@ TEST_CASES = [ [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular", "reduction": "sum"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), }, - -1.0, - ], - [ - {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "rectangular"}, - { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - }, - -1.0, + -1.0 * 3, ], [ - {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "triangular"}, + {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), }, -1.0, ], [ - {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "gaussian"}, + {"in_channels": 1, "ndim": 2, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), }, -1.0, ], [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "rectangular"}, + {"in_channels": 1, "ndim": 3, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), }, -1.0, ], [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, }, - -1.0, - ], - [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian"}, - { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), - }, - -1.0, + -0.95801723, ], [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian", "reduction": "sum"}, + {"in_channels": 3, "ndim": 3, "kernel_type": "triangular", "kernel_size": 5}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), + "pred": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float), + "target": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float) ** 2, }, - -6.0, + -0.918672, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + {"in_channels": 3, "ndim": 3, "kernel_type": "gaussian"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, }, - -0.06062524, - ], - [ - {"in_channels": 3, "ndim": 3, "kernel_size": 5, "kernel_type": "triangular"}, - { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, - }, - -0.923356, - ], - [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, - { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, - }, - -1.306177, + -0.95406944, ], ] @@ -113,33 +81,33 @@ class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) def test_ill_shape(self): loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) # in_channel unmatch - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 2, 3, 3, 3), dtype=torch.float), torch.ones((1, 2, 3, 3, 3), dtype=torch.float)) # ndim unmatch - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 3, 3), dtype=torch.float)) - # input, target shape unmatch - with self.assertRaisesRegex(AssertionError, ""): + # pred, target shape unmatch + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 4, 4, 4), dtype=torch.float)) def test_ill_opts(self): - input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type="unknown")(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type="unknown")(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type=None)(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type=None)(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(pred, target) if __name__ == "__main__": diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py new file mode 100644 index 0000000000..da6af6f66d --- /dev/null +++ b/tests/test_reg_loss_integration.py @@ -0,0 +1,111 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.nn as nn +import torch.optim as optim +from parameterized import parameterized + +from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss + +TEST_CASES = [ + [BendingEnergyLoss, {}, ["pred"]], + [ + LocalNormalizedCrossCorrelationLoss, + {"in_channels": 1, "kernel_size": 7, "kernel_type": "rectangular"}, + ["pred", "target"], + ], + [ + LocalNormalizedCrossCorrelationLoss, + {"in_channels": 1, "kernel_size": 5, "kernel_type": "triangular"}, + ["pred", "target"], + ], + [ + LocalNormalizedCrossCorrelationLoss, + {"in_channels": 1, "kernel_size": 3, "kernel_type": "gaussian"}, + ["pred", "target"], + ], + [GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]], +] + + +class TestRegLossIntegration(unittest.TestCase): + def setUp(self): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(0) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") + + def tearDown(self): + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + + @parameterized.expand(TEST_CASES) + def test_convergence(self, loss_type, loss_args, forward_args): + """ + The goal of this test is to assess if the gradient of the loss function + is correct by testing if we can train a one layer neural network + to segment one image. + We verify that the loss is decreasing in almost all SGD steps. + """ + learning_rate = 0.001 + max_iter = 100 + + # define a simple 3d example + target = torch.rand((1, 1, 5, 5, 5), device=self.device) + image = 12 * target + 27 + image = image.to(device=self.device) + + # define a one layer model + class OnelayerNet(nn.Module): + def __init__(self): + super(OnelayerNet, self).__init__() + self.layer = nn.Sequential( + nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + ) + + def forward(self, x): + return self.layer(x) + + # initialise the network + net = OnelayerNet().to(self.device) + + # initialize the loss + loss = loss_type(**loss_args).to(self.device) + + # initialize a SGD optimizer + optimizer = optim.Adam(net.parameters(), lr=learning_rate) + + # train the network + for iter in range(max_iter): + # set the gradient to zero + optimizer.zero_grad() + + # forward pass + output = net(image) + loss_input = {"pred": output, "target": target} + + loss_val = loss(**{k: loss_input[k] for k in forward_args}) + if iter == 0: + init_loss = loss_val + + # backward pass + loss_val.backward() + optimizer.step() + self.assertTrue(init_loss > loss_val, "loss did not decrease") + + +if __name__ == "__main__": + unittest.main()