From ce61c38c377059f27deef07b26433669ab45c850 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 00:00:21 +0000 Subject: [PATCH 01/23] 1412 add local normalized cross correlation Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 157 ++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 monai/losses/image_dissimilarity.py diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py new file mode 100644 index 0000000000..557a757e79 --- /dev/null +++ b/monai/losses/image_dissimilarity.py @@ -0,0 +1,157 @@ +import torch +from torch.nn.modules.loss import _Loss + +from torch.nn import functional as F + +from monai.utils import LossReduction, Union + + +conv_dict = { + 1: F.conv1d, + 2: F.conv2d, + 3: F.conv3d +} + +EPS = 1e-7 + + +class LocalNormalizedCrossCorrelation(_Loss): + """ + Local squared zero-normalized cross-correlation. + The loss is based on a moving kernel/window over the y_true/y_pred, + within the window the square of zncc is calculated. + The kernel can be a rectangular / triangular / gaussian window. + The final loss is the averaged loss over all windows. + + Adapted from: + https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + + def __init__( + self, + in_channels: int, + dim: int = 3, + kernel_size: int = 9, + kernel_type: str = "rectangular", + reduction: Union[LossReduction, str] = LossReduction.MEAN, + ) -> None: + """ + Args: + in_channels: number of input channels + dim: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. + kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` + kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) + self.in_channels = in_channels + self.dim = dim + if self.dim not in [1, 2, 3]: + raise ValueError(f'Unsupported dim: {self.dim}-d, only 1-d, 2-d, and 3-d inputs are supported') + self.kernel_size = kernel_size + if kernel_type == "rectangular": + self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() + elif kernel_type == "triangular": + self.kernel, self.kernel_vol, self.padding = self.make_triangular_kernel() + elif kernel_type == "gaussian": + self.kernel, self.kernel_vol, self.padding = self.make_gaussian_kernel() + else: + raise ValueError( + f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' 
+ ) + + def make_rectangular_kernel(self): + shape = [1, self.in_channels] + [self.kernel_size] * self.dim + return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.dim, int((self.kernel_size - 1) / 2) + + def make_triangular_kernel(self): + fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) + f1 = torch.ones( + [1, 1] + [fsize] * self.dim, + dtype=torch.float + ) / fsize # (1, 1, D, H, W) + f2 = torch.ones( + [1, self.in_channels] + [fsize] * self.dim, + dtype=torch.float + ) / fsize # (1, in_channels, D, H, W) + # (1, 1, D, H, W) -> (1, in_channels, D, H, W) + fn = conv_dict[self.dim] + kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) + + return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) + + def make_gaussian_kernel(self): + mean = (self.kernel_size - 1) / 2.0 + sigma = self.kernel_size / 3 + + grid_dim = torch.arange(0, self.kernel_size) + grid_dim_ch = torch.arange(0, self.in_channel) + + if self.dim == 1: + grid = torch.meshgrid(grid_dim_ch, grid_dim) + elif self.dim == 2: + grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim) + elif self.dim == 3: + grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim, grid_dim) + else: + raise ValueError + + grid = torch.stack(grid, dim=-1).to(dtype=torch.float) + kernel = torch.exp( + -torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2) + ).unsqueeze(0) # (1, in_channel, kernel_size, kernel_size, kernel_size) + return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD]. + target: the shape should be BNH[WD]. + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + """ + assert ( + target.shape == input.shape + ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" + + t2, p2, tp = target ** 2, input ** 2, target * input + + # sum over kernel + fn = conv_dict[self.dim] + t_sum = fn(target, weight=self.kernel, padding=self.padding) + p_sum = fn(input, weight=self.kernel, padding=self.paddin) + t2_sum = fn(t2, weight=self.kernel, padding=self.paddin) + p2_sum = fn(p2, weight=self.kernel, padding=self.paddin) + tp_sum = fn(tp, weight=self.kernel, padding=self.paddin) + + # average over kernel + t_avg = t_sum / self.kernel_vol + p_avg = p_sum / self.kernel_vol + + # normalized cross correlation between t and p + # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p] + # denoted by num / denom + # assume we sum over N values + # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]] + # = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * mean[p] = cross + # the following is actually squared ncc + cross = tp_sum - p_avg * t_sum + t_var = t2_sum - t_avg * t_sum # std[t] ** 2 + p_var = p2_sum - p_avg * p_sum # std[p] ** 2 + ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) + + if self.reduction == LossReduction.SUM.value: + return torch.sum(ncc) # sum over the batch and channel dims + if self.reduction == LossReduction.NONE.value: + return ncc + if self.reduction == LossReduction.MEAN.value: + return torch.mean(ncc) # average over the batch and channel dims + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') From 5cf91d0463f3b52d36a193dfea290263d9ad0819 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 01:02:13 +0000 
Subject: [PATCH 02/23] 1412 add unit test and documentation Signed-off-by: kate-sann5100 --- docs/source/losses.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/image_dissimilarity.py | 94 +++++++------- ...local_normalized_cross_correlation_loss.py | 121 ++++++++++++++++++ 4 files changed, 172 insertions(+), 49 deletions(-) create mode 100644 tests/test_local_normalized_cross_correlation_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index d2c8e02ca4..fc86c9cd05 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -65,3 +65,8 @@ Registration Losses ~~~~~~~~~~~~~~~~~~~ .. autoclass:: BendingEnergyLoss :members: + +`LocalNormalizedCrossCorrelationLoss` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNormalizedCrossCorrelationLoss + :members: diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d4e21f900c..42ccd0f65a 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -22,4 +22,5 @@ generalized_wasserstein_dice, ) from .focal_loss import FocalLoss +from .image_dissimilarity import LocalNormalizedCrossCorrelationLoss from .tversky import TverskyLoss diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 557a757e79..26ee7ea718 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,21 +1,15 @@ import torch -from torch.nn.modules.loss import _Loss - from torch.nn import functional as F +from torch.nn.modules.loss import _Loss from monai.utils import LossReduction, Union - -conv_dict = { - 1: F.conv1d, - 2: F.conv2d, - 3: F.conv3d -} +conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} EPS = 1e-7 -class LocalNormalizedCrossCorrelation(_Loss): +class LocalNormalizedCrossCorrelationLoss(_Loss): """ Local squared zero-normalized cross-correlation. The loss is based on a moving kernel/window over the y_true/y_pred, @@ -29,17 +23,17 @@ class LocalNormalizedCrossCorrelation(_Loss): """ def __init__( - self, - in_channels: int, - dim: int = 3, - kernel_size: int = 9, - kernel_type: str = "rectangular", - reduction: Union[LossReduction, str] = LossReduction.MEAN, + self, + in_channels: int, + ndim: int = 3, + kernel_size: int = 9, + kernel_type: str = "rectangular", + reduction: Union[LossReduction, str] = LossReduction.MEAN, ) -> None: """ Args: in_channels: number of input channels - dim: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. + ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. 
reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -51,9 +45,9 @@ def __init__( """ super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels - self.dim = dim - if self.dim not in [1, 2, 3]: - raise ValueError(f'Unsupported dim: {self.dim}-d, only 1-d, 2-d, and 3-d inputs are supported') + self.ndim = ndim + if self.ndim not in [1, 2, 3]: + raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") self.kernel_size = kernel_size if kernel_type == "rectangular": self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() @@ -67,21 +61,17 @@ def __init__( ) def make_rectangular_kernel(self): - shape = [1, self.in_channels] + [self.kernel_size] * self.dim - return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.dim, int((self.kernel_size - 1) / 2) + shape = [1, self.in_channels] + [self.kernel_size] * self.ndim + return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) def make_triangular_kernel(self): fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) - f1 = torch.ones( - [1, 1] + [fsize] * self.dim, - dtype=torch.float - ) / fsize # (1, 1, D, H, W) - f2 = torch.ones( - [1, self.in_channels] + [fsize] * self.dim, - dtype=torch.float - ) / fsize # (1, in_channels, D, H, W) + f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize # (1, 1, D, H, W) + f2 = ( + torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize + ) # (1, in_channels, D, H, W) # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - fn = conv_dict[self.dim] + fn = conv_dict[self.ndim] kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) @@ -90,22 +80,22 @@ def make_gaussian_kernel(self): mean = (self.kernel_size - 1) / 2.0 sigma = self.kernel_size / 3 - grid_dim = torch.arange(0, self.kernel_size) - grid_dim_ch = torch.arange(0, self.in_channel) + grid_ndim = torch.arange(0, self.kernel_size) + grid_ndim_ch = torch.arange(0, self.in_channels) - if self.dim == 1: - grid = torch.meshgrid(grid_dim_ch, grid_dim) - elif self.dim == 2: - grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim) - elif self.dim == 3: - grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim, grid_dim) + if self.ndim == 1: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim) + elif self.ndim == 2: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim) + elif self.ndim == 3: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim, grid_ndim) else: raise ValueError grid = torch.stack(grid, dim=-1).to(dtype=torch.float) - kernel = torch.exp( - -torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2) - ).unsqueeze(0) # (1, in_channel, kernel_size, kernel_size, kernel_size) + kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( + 0 + ) # (1, in_channel, kernel_size, kernel_size, kernel_size) return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: @@ -116,6 +106,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 
""" + assert ( + input.shape[1] == self.in_channels + ), f"expecting input with {self.in_channels} channels, got input of shape {input.shape}" + assert ( + input.ndim - 2 == self.ndim + ), f"expecting input with {self.ndim} spatial dimensions, got input of shape {input.shape}" assert ( target.shape == input.shape ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" @@ -123,12 +119,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - fn = conv_dict[self.dim] + fn = conv_dict[self.ndim] t_sum = fn(target, weight=self.kernel, padding=self.padding) - p_sum = fn(input, weight=self.kernel, padding=self.paddin) - t2_sum = fn(t2, weight=self.kernel, padding=self.paddin) - p2_sum = fn(p2, weight=self.kernel, padding=self.paddin) - tp_sum = fn(tp, weight=self.kernel, padding=self.paddin) + p_sum = fn(input, weight=self.kernel, padding=self.padding) + t2_sum = fn(t2, weight=self.kernel, padding=self.padding) + p2_sum = fn(p2, weight=self.kernel, padding=self.padding) + tp_sum = fn(tp, weight=self.kernel, padding=self.padding) # average over kernel t_avg = t_sum / self.kernel_vol @@ -149,9 +145,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return torch.sum(ncc) # sum over the batch and channel dims + return -torch.sum(ncc) # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: - return ncc + return -ncc if self.reduction == LossReduction.MEAN.value: - return torch.mean(ncc) # average over the batch and channel dims + return -torch.mean(ncc) # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py new file mode 100644 index 0000000000..332689be06 --- /dev/null +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -0,0 +1,121 @@ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss + +TEST_CASES = [ + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + 
-1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.06062524, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.9368649, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.50272596, + ], +] + + +class TestBendingEnergy(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-2) + + def test_ill_shape(self): + loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) + # in_channel unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 2, 3, 3, 3), dtype=torch.float), torch.ones((1, 2, 3, 3, 3), dtype=torch.float)) + # ndim unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 3, 3), dtype=torch.float)) + # input, target shape unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 4, 4, 4), dtype=torch.float)) + + def test_ill_opts(self): + input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(input, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(input, target) + + +if __name__ == "__main__": + unittest.main() From 9376195b34df1eab0fc0f94b251102184e23864f Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 01:37:55 +0000 Subject: [PATCH 03/23] 1412 fix bug Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py 
b/monai/losses/image_dissimilarity.py index 26ee7ea718..42f0cfd9d7 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -43,7 +43,7 @@ def __init__( - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. """ - super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) + super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels self.ndim = ndim if self.ndim not in [1, 2, 3]: From ac36a9fc877882cf8da48378479782a9e58c7b8e Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 15:00:23 +0000 Subject: [PATCH 04/23] 1412 reformat code Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 58 ++++++++++++------- ...local_normalized_cross_correlation_loss.py | 15 ++++- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 42f0cfd9d7..232b8e0600 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,3 +1,14 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss @@ -6,8 +17,6 @@ conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} -EPS = 1e-7 - class LocalNormalizedCrossCorrelationLoss(_Loss): """ @@ -29,6 +38,7 @@ def __init__( kernel_size: int = 9, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_dr: float = 1e-7, ) -> None: """ Args: @@ -42,12 +52,14 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels self.ndim = ndim if self.ndim not in [1, 2, 3]: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") + self.fn = conv_dict[self.ndim] self.kernel_size = kernel_size if kernel_type == "rectangular": self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() @@ -59,26 +71,29 @@ def __init__( raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' 
) + self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): shape = [1, self.in_channels] + [self.kernel_size] * self.ndim return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) def make_triangular_kernel(self): - fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) - f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize # (1, 1, D, H, W) - f2 = ( - torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize - ) # (1, in_channels, D, H, W) + fsize = int((self.kernel_size + 1) // 2) + f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) # (1, 1, D, H, W) + f1 = F.pad(f1, [(fsize - 1) // 2, (fsize - 1) // 2] * self.ndim) + f2 = torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) + # (in_channels, 1, D, H, W) # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - fn = conv_dict[self.ndim] - kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) + padding_needed = max(fsize - 1, 0) + padding = [padding_needed // 2, padding_needed - padding_needed // 2] * self.ndim + f1 = F.pad(f1, padding) + kernel = self.fn(f1, f2) - return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) + return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2.0) def make_gaussian_kernel(self): mean = (self.kernel_size - 1) / 2.0 - sigma = self.kernel_size / 3 + sigma = self.kernel_size / 3.0 grid_ndim = torch.arange(0, self.kernel_size) grid_ndim_ch = torch.arange(0, self.in_channels) @@ -96,7 +111,7 @@ def make_gaussian_kernel(self): kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( 0 ) # (1, in_channel, kernel_size, kernel_size, kernel_size) - return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) + return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2.0) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -119,12 +134,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - fn = conv_dict[self.ndim] - t_sum = fn(target, weight=self.kernel, padding=self.padding) - p_sum = fn(input, weight=self.kernel, padding=self.padding) - t2_sum = fn(t2, weight=self.kernel, padding=self.padding) - p2_sum = fn(p2, weight=self.kernel, padding=self.padding) - tp_sum = fn(tp, weight=self.kernel, padding=self.padding) + t_sum = self.fn(target, weight=self.kernel, padding=self.padding) + p_sum = self.fn(input, weight=self.kernel, padding=self.padding) + t2_sum = self.fn(t2, weight=self.kernel, padding=self.padding) + p2_sum = self.fn(p2, weight=self.kernel, padding=self.padding) + tp_sum = self.fn(tp, weight=self.kernel, padding=self.padding) # average over kernel t_avg = t_sum / self.kernel_vol @@ -142,12 +156,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) + ncc = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return -torch.sum(ncc) # sum over the batch and channel ndims + return -torch.sum(ncc).neg() # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: - return -ncc + return ncc.neg() if 
self.reduction == LossReduction.MEAN.value: - return -torch.mean(ncc) # average over the batch and channel ndims + return torch.mean(ncc).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 332689be06..beb1a1dca2 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -1,3 +1,14 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import numpy as np @@ -72,7 +83,7 @@ -0.06062524, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_size": 6, "kernel_type": "triangular"}, { "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, @@ -90,7 +101,7 @@ ] -class TestBendingEnergy(unittest.TestCase): +class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) From ed6c28bca7393a527b5bc21a0442f0a08333d33c Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 15:20:47 +0000 Subject: [PATCH 05/23] 1412 debug type check Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 232b8e0600..4e64d81eb6 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -156,7 +156,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) + ncc: torch.Tensor = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) + # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: return -torch.sum(ncc).neg() # sum over the batch and channel ndims From 43c2f355416b249a34adfa87367aabd5f1b381f0 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 00:42:45 +0000 Subject: [PATCH 06/23] 1412 use separable filter for speed Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 75 +++++++------------ monai/networks/layers/simplelayers.py | 6 ++ ...local_normalized_cross_correlation_loss.py | 10 ++- 3 files changed, 41 insertions(+), 50 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 4e64d81eb6..bd9a06e187 100644 --- 
a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -13,10 +13,9 @@ from torch.nn import functional as F from torch.nn.modules.loss import _Loss +from monai.networks.layers import gaussian_1d, separable_filtering from monai.utils import LossReduction, Union -conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} - class LocalNormalizedCrossCorrelationLoss(_Loss): """ @@ -44,7 +43,7 @@ def __init__( Args: in_channels: number of input channels ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. - kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` + kernel_size: kernel spatial size, must be odd. kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -56,62 +55,46 @@ def __init__( """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels + self.ndim = ndim if self.ndim not in [1, 2, 3]: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") - self.fn = conv_dict[self.ndim] + self.kernel_size = kernel_size + if self.kernel_size % 2 == 0: + raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") + if kernel_type == "rectangular": - self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() + self.kernel = self.make_rectangular_kernel() elif kernel_type == "triangular": - self.kernel, self.kernel_vol, self.padding = self.make_triangular_kernel() + self.kernel = self.make_triangular_kernel() elif kernel_type == "gaussian": - self.kernel, self.kernel_vol, self.padding = self.make_gaussian_kernel() + self.kernel = self.make_gaussian_kernel() else: raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' 
) + + self.kernel_vol = torch.sum(self.kernel) ** self.ndim self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): - shape = [1, self.in_channels] + [self.kernel_size] * self.ndim - return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) + return torch.ones(self.kernel_size) def make_triangular_kernel(self): - fsize = int((self.kernel_size + 1) // 2) - f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) # (1, 1, D, H, W) - f1 = F.pad(f1, [(fsize - 1) // 2, (fsize - 1) // 2] * self.ndim) - f2 = torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) - # (in_channels, 1, D, H, W) - # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - padding_needed = max(fsize - 1, 0) - padding = [padding_needed // 2, padding_needed - padding_needed // 2] * self.ndim - f1 = F.pad(f1, padding) - kernel = self.fn(f1, f2) - - return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2.0) + fsize = (self.kernel_size + 1) // 2 + if fsize % 2 == 0: + fsize -= 1 + f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) + padding = (self.kernel_size - fsize) // 2 + fsize // 2 + return F.conv1d(f, f, padding=padding).reshape(-1) def make_gaussian_kernel(self): - mean = (self.kernel_size - 1) / 2.0 - sigma = self.kernel_size / 3.0 - - grid_ndim = torch.arange(0, self.kernel_size) - grid_ndim_ch = torch.arange(0, self.in_channels) - - if self.ndim == 1: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim) - elif self.ndim == 2: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim) - elif self.ndim == 3: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim, grid_ndim) - else: - raise ValueError - - grid = torch.stack(grid, dim=-1).to(dtype=torch.float) - kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( - 0 - ) # (1, in_channel, kernel_size, kernel_size, kernel_size) - return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2.0) + sigma = torch.tensor(self.kernel_size / 3.0) + kernel = gaussian_1d(sigma=sigma, truncated=self.kernel_size // 2, approx="sampled", normalize=False) * ( + 2.5066282 * sigma + ) + return kernel[: self.kernel_size] def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -134,11 +117,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - t_sum = self.fn(target, weight=self.kernel, padding=self.padding) - p_sum = self.fn(input, weight=self.kernel, padding=self.padding) - t2_sum = self.fn(t2, weight=self.kernel, padding=self.padding) - p2_sum = self.fn(p2, weight=self.kernel, padding=self.padding) - tp_sum = self.fn(tp, weight=self.kernel, padding=self.padding) + t_sum = separable_filtering(target, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + p_sum = separable_filtering(input, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + t2_sum = separable_filtering(t2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + p2_sum = separable_filtering(p2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + tp_sum = separable_filtering(tp, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) # average over kernel t_avg = t_sum / self.kernel_vol diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 48012dfb1c..8ff9464145 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -365,3 
+365,9 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) + + +if __name__ == "__main__": + input = torch.ones((1, 3, 3, 3)) + kernels = [torch.ones(1, 3)] * 2 + print(separable_filtering(input, kernels)) diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index beb1a1dca2..fe455b0597 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -83,12 +83,12 @@ -0.06062524, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 6, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_size": 5, "kernel_type": "triangular"}, { "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, - -0.9368649, + -0.923356, ], [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, @@ -96,7 +96,7 @@ "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, - -0.50272596, + -1.306177, ], ] @@ -105,7 +105,7 @@ class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-2) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) def test_ill_shape(self): loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) @@ -122,6 +122,8 @@ def test_ill_shape(self): def test_ill_opts(self): input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(input, target) with self.assertRaisesRegex(ValueError, ""): LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(input, target) with self.assertRaisesRegex(ValueError, ""): From f76e3f04bcc4f72eb3fb989261e671bcb5252a43 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 00:49:19 +0000 Subject: [PATCH 07/23] 1412 update Union import route Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index bd9a06e187..093aa2546b 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,13 +8,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Union import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss from monai.networks.layers import gaussian_1d, separable_filtering -from monai.utils import LossReduction, Union +from monai.utils import LossReduction class LocalNormalizedCrossCorrelationLoss(_Loss): From d3aba3ff57b5035f4c85056fd6efe4bf41a3eed6 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 03:07:44 +0000 Subject: [PATCH 08/23] 1412 add global mutual information Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 78 ++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 093aa2546b..d5ec1f81d9 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -38,6 +38,7 @@ def __init__( kernel_size: int = 9, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, smooth_dr: float = 1e-7, ) -> None: """ @@ -52,6 +53,7 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the nominator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) @@ -77,6 +79,7 @@ def __init__( ) self.kernel_vol = torch.sum(self.kernel) ** self.ndim + self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): @@ -140,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc: torch.Tensor = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) + ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: @@ -150,3 +153,76 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.reduction == LossReduction.MEAN.value: return torch.mean(ncc).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + +class GlobalMutualInformationLoss(_Loss): + """ + Differentiable global mutual information loss via Parzen windowing method. + + Reference: + https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1 + """ + + def __init__( + self, + num_bins: int = 23, + sigma_ratio: float = 0.5, + reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, + smooth_dr: float = 1e-7, + ) -> None: + """ + Args: + num_bins: number of bins for intensity + sigma_ratio: a hyper param for gaussian function + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the nominator to avoid nan. + smooth_dr: a small constant added to the denominator to avoid nan. 
+ """ + super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) + bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) + sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio + self.preterm = 1 / (2 * sigma ** 2) + self.bin_centers = bin_centers[None, None, ...] + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + + def parzen_windowing(self, input: torch.Tensor) -> (torch.Tensor, torch.Tensor): + """ + Args: + input: the shape should be BNH[WD]. + """ + input = torch.clamp(input, 0, 1) + input = input.reshape(input.shape[0], -1, 1) # (batch, num_sample, 1) + weight = torch.exp(-self.preterm * (input - self.bin_centers) ** 2) # (batch, num_sample, num_bin) + weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) + probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) + return weight, probability + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD]. + target: the shape should be BNH[WD]. + Raises: + """ + wa, pa = self.parzen_windowing(input) # (batch, num_sample, num_bin), (batch, 1, num_bin) + wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) + pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) + + papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins) + mi = torch.sum( + pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2) + ) # (batch) + if self.reduction == LossReduction.SUM.value: + return torch.sum(mi).neg() # sum over the batch and channel ndims + if self.reduction == LossReduction.NONE.value: + return mi.neg() + if self.reduction == LossReduction.MEAN.value: + return torch.mean(mi).neg() # average over the batch and channel ndims + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') From dc3c036897dc97bc763520903f5c412fe058afd5 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 03:07:44 +0000 Subject: [PATCH 09/23] 1412 add global mutual information Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 78 ++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 093aa2546b..cd1ad475bf 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -38,6 +38,7 @@ def __init__( kernel_size: int = 9, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, smooth_dr: float = 1e-7, ) -> None: """ @@ -52,6 +53,7 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. 
""" super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) @@ -77,6 +79,7 @@ def __init__( ) self.kernel_vol = torch.sum(self.kernel) ** self.ndim + self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): @@ -140,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc: torch.Tensor = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) + ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: @@ -150,3 +153,76 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.reduction == LossReduction.MEAN.value: return torch.mean(ncc).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + +class GlobalMutualInformationLoss(_Loss): + """ + Differentiable global mutual information loss via Parzen windowing method. + + Reference: + https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1 + """ + + def __init__( + self, + num_bins: int = 23, + sigma_ratio: float = 0.5, + reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, + smooth_dr: float = 1e-7, + ) -> None: + """ + Args: + num_bins: number of bins for intensity + sigma_ratio: a hyper param for gaussian function + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid nan. + smooth_dr: a small constant added to the denominator to avoid nan. + """ + super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) + bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) + sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio + self.preterm = 1 / (2 * sigma ** 2) + self.bin_centers = bin_centers[None, None, ...] + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + + def parzen_windowing(self, input: torch.Tensor) -> (torch.Tensor, torch.Tensor): + """ + Args: + input: the shape should be BNH[WD]. + """ + input = torch.clamp(input, 0, 1) + input = input.reshape(input.shape[0], -1, 1) # (batch, num_sample, 1) + weight = torch.exp(-self.preterm * (input - self.bin_centers) ** 2) # (batch, num_sample, num_bin) + weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) + probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) + return weight, probability + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD]. + target: the shape should be BNH[WD]. 
+ Raises: + """ + wa, pa = self.parzen_windowing(input) # (batch, num_sample, num_bin), (batch, 1, num_bin) + wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) + pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) + + papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins) + mi = torch.sum( + pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2) + ) # (batch) + if self.reduction == LossReduction.SUM.value: + return torch.sum(mi).neg() # sum over the batch and channel ndims + if self.reduction == LossReduction.NONE.value: + return mi.neg() + if self.reduction == LossReduction.MEAN.value: + return torch.mean(mi).neg() # average over the batch and channel ndims + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') From 79a281903fc0b5d08c7b74173cfd663fb3f22b77 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 03:36:51 +0000 Subject: [PATCH 10/23] 1412 autostyle fix Signed-off-by: kate-sann5100 --- monai/losses/__init__.py | 5 +---- monai/losses/image_dissimilarity.py | 8 ++++++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index eab1995b3d..591fb08f7b 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -22,8 +22,5 @@ generalized_wasserstein_dice, ) from .focal_loss import FocalLoss -from .image_dissimilarity import ( - LocalNormalizedCrossCorrelationLoss, - GlobalMutualInformationLoss -) +from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .tversky import TverskyLoss diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 7313abbd0a..28663dabe9 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Tuple +from typing import Tuple, Union import torch from torch.nn import functional as F @@ -185,6 +185,7 @@ def __init__( smooth_dr: a small constant added to the denominator to avoid nan. """ super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) + assert num_bins > 0, f"num_bins must > 0, got f{num_bins}" bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio self.preterm = 1 / (2 * sigma ** 2) @@ -192,7 +193,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def parzen_windowing(self, input: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]: + def parzen_windowing(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: input: the shape should be BNH[WD]. @@ -211,6 +212,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target: the shape should be BNH[WD]. 
Raises: """ + assert ( + target.shape == input.shape + ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" wa, pa = self.parzen_windowing(input) # (batch, num_sample, num_bin), (batch, 1, num_bin) wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) From 68de8ec9567ab3060e6272efa92b41c9e8c762b2 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 14:29:48 +0000 Subject: [PATCH 11/23] 1412 add unit test Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 9 +- tests/test_global_mutual_information_loss.py | 100 +++++++++++++++++++ 2 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 tests/test_global_mutual_information_loss.py diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 28663dabe9..07f04cec20 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -185,7 +185,7 @@ def __init__( smooth_dr: a small constant added to the denominator to avoid nan. """ super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) - assert num_bins > 0, f"num_bins must > 0, got f{num_bins}" + assert num_bins > 0, f"num_bins must > 0, got {num_bins}" bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio self.preterm = 1 / (2 * sigma ** 2) @@ -196,7 +196,7 @@ def __init__( def parzen_windowing(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: - input: the shape should be BNH[WD]. + input: the shape should be B[NDHW]. """ input = torch.clamp(input, 0, 1) input = input.reshape(input.shape[0], -1, 1) # (batch, num_sample, 1) @@ -208,9 +208,10 @@ def parzen_windowing(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: - input: the shape should be BNH[WD]. - target: the shape should be BNH[WD]. + input: the shape should be B[NDHW]. + target: the shape should be same as the input shape. Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ assert ( target.shape == input.shape diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py new file mode 100644 index 0000000000..376903fc1d --- /dev/null +++ b/tests/test_global_mutual_information_loss.py @@ -0,0 +1,100 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.image_dissimilarity import GlobalMutualInformationLoss + +TEST_CASES = [ + [ + {}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + }, + -1.0986018, + ], + [ + {}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3) + ** 2, + }, + -1.083999, + ], + [ + {}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :].div(3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :].div(3) ** 2, + }, + -1.083999, + ], + [ + {}, + { + "input": torch.arange(0, 3, dtype=torch.float).div(3), + "target": torch.arange(0, 3, dtype=torch.float).div(3) ** 2, + }, + -1.1920927e-07, + ], +] + + +class TestGlobalMutualInformationLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = GlobalMutualInformationLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + + def test_ill_shape(self): + loss = GlobalMutualInformationLoss() + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + + def test_ill_opts(self): + input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(AssertionError, ""): + GlobalMutualInformationLoss(num_bins=0)(input, target) + with self.assertRaisesRegex(AssertionError, ""): + GlobalMutualInformationLoss(num_bins=-1)(input, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(reduction="unknown")(input, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(reduction=None)(input, target) + + +if __name__ == "__main__": + unittest.main() From b66a3fb38d7ce5cf895a6b1780175ef269b7baa1 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 11 Jan 2021 14:16:43 +0000 Subject: [PATCH 12/23] 1412 add integration test Signed-off-by: kate-sann5100 --- monai/networks/layers/simplelayers.py | 6 -- tests/test_reg_loss_integration.py | 108 ++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 6 deletions(-) create mode 100644 tests/test_reg_loss_integration.py diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 55189e3e9b..ba60f4eca4 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -365,9 +365,3 @@ def reset_parameters(self): def forward(self, input, state): return 
LLTMFunction.apply(input, self.weights, self.bias, *state) - - -if __name__ == "__main__": - input = torch.ones((1, 3, 3, 3)) - kernels = [torch.ones(1, 3)] * 2 - print(separable_filtering(input, kernels)) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py new file mode 100644 index 0000000000..078a5de573 --- /dev/null +++ b/tests/test_reg_loss_integration.py @@ -0,0 +1,108 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from parameterized import parameterized + +from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss + +TEST_CASES = [ + [BendingEnergyLoss, {}, ["input"]], + [ + LocalNormalizedCrossCorrelationLoss, + {"in_channels": 3, "kernel_size": 5, "kernel_type": "rectangular"}, + ["input", "target"], + ], + [ + LocalNormalizedCrossCorrelationLoss, + {"in_channels": 3, "kernel_size": 5, "kernel_type": "triangular"}, + ["input", "target"], + ], + [ + LocalNormalizedCrossCorrelationLoss, + {"in_channels": 3, "kernel_size": 5, "kernel_type": "gaussian"}, + ["input", "target"], + ], + [GlobalMutualInformationLoss, {"num_bins": 5}, ["input", "target"]], +] + + +class TestRegLossIntegration(unittest.TestCase): + def setUp(self): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(0) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") + + def tearDown(self): + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + + @parameterized.expand(TEST_CASES) + def test_convergence(self, loss_type, loss_args, forward_args): + """ + The goal of this test is to assess if the gradient of the loss function + is correct by testing if we can train a one layer neural network + to segment one image. + We verify that the loss is decreasing in almost all SGD steps. 
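
Before the full training loop, a minimal usage sketch of the three losses this file exercises; the tensor shapes are illustrative, the constructor arguments are taken from the test cases below, and the imports match the ones used by the test itself.

    import torch
    from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss

    pred = torch.rand(1, 3, 16, 16, 16, requires_grad=True)   # batch, channel, spatial dims
    target = torch.rand(1, 3, 16, 16, 16)

    print(BendingEnergyLoss()(pred))                          # regularises the prediction only
    print(LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=5)(pred, target))
    print(GlobalMutualInformationLoss(num_bins=5)(pred, target))

All three return scalars under the default "mean" reduction, so they can be back-propagated directly, which is what the convergence test below relies on.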
+ """ + learning_rate = 0.001 + max_iter = 40 + + # define a simple 3d example + target = torch.rand((1, 3, 5, 5, 5), device=self.device) + image = 12 * target + 27 + + # define a one layer model + class OnelayerNet(nn.Module): + def __init__(self): + super(OnelayerNet, self).__init__() + self.layer = nn.Sequential( + nn.Conv3d(in_channels=3, out_channels=3, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv3d(in_channels=3, out_channels=3, kernel_size=3, padding=1), + ) + + def forward(self, x): + return self.layer(x) + + # initialise the network + net = OnelayerNet().to(self.device) + + # initialize the loss + loss = loss_type(**loss_args) + + # initialize a SGD optimizer + optimizer = optim.Adam(net.parameters(), lr=learning_rate) + + # train the network + for iter_i in range(max_iter): + # set the gradient to zero + optimizer.zero_grad() + + # forward pass + output = net(image) + loss_input = {"input": output, "target": target} + + loss_val = loss(**{k: loss_input[k] for k in forward_args}) + + # backward pass + loss_val.backward() + optimizer.step() + + +if __name__ == "__main__": + unittest.main() From d2d2a87e943299f058c62d751fc72db911c6e2d0 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 11 Jan 2021 14:55:58 +0000 Subject: [PATCH 13/23] 1412 reformat code Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 34 ++++------ monai/losses/image_dissimilarity.py | 66 +++++++++---------- tests/test_bending_energy.py | 26 ++++---- tests/test_global_mutual_information_loss.py | 30 ++++----- ...local_normalized_cross_correlation_loss.py | 42 ++++++------ tests/test_reg_loss_integration.py | 15 ++--- 6 files changed, 102 insertions(+), 111 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 1005587021..97f389a782 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -17,7 +17,7 @@ from monai.utils import LossReduction -def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: +def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: """ Calculate gradients on single dimension of a tensor using central finite difference. It moves the tensor along the dimension to calculate the approximate gradient @@ -26,7 +26,7 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: DeepReg (https://github.com/DeepRegNet/DeepReg) Args: - input: the shape should be BCH(WD). + pred: the shape should be BCH(WD). dim: dimension to calculate gradient along. Returns: gradient_dx: the shape should be BCH(WD) @@ -36,17 +36,17 @@ def spatial_gradient(input: torch.Tensor, dim: int) -> torch.Tensor: slice_2_e = slice(None, -2) slice_all = slice(None) slicing_s, slicing_e = [slice_all, slice_all], [slice_all, slice_all] - while len(slicing_s) < input.ndim: + while len(slicing_s) < x.ndim: slicing_s = slicing_s + [slice_1] slicing_e = slicing_e + [slice_1] slicing_s[dim] = slice_2_s slicing_e[dim] = slice_2_e - return (input[slicing_s] - input[slicing_e]) / 2.0 + return (x[slicing_s] - x[slicing_e]) / 2.0 class BendingEnergyLoss(_Loss): """ - Calculate the bending energy based on second-order differentiation of input using central finite difference. + Calculate the bending energy based on second-order differentiation of pred using central finite difference. 
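
The `spatial_gradient` refactor above keeps the slicing-based central difference: along the chosen dimension it computes (x[2:] - x[:-2]) / 2 on the interior of the volume, so every spatial extent shrinks by two per call; this is also why `BendingEnergyLoss` requires every spatial dimension to be larger than 4 before taking nested second-order gradients. A quick sketch of that behaviour on a 1-D parabola (values are illustrative; the import path follows the file being patched):

    import torch
    from monai.losses.deform import spatial_gradient

    x = torch.arange(0, 6, dtype=torch.float)[None, None, :] ** 2   # BCH layout, values 0, 1, 4, 9, 16, 25
    g = spatial_gradient(x, dim=2)
    print(g)         # tensor([[[2., 4., 6., 8.]]]): central differences of x**2
    print(g.shape)   # torch.Size([1, 1, 4]); the two border samples are dropped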
Adapted from: DeepReg (https://github.com/DeepRegNet/DeepReg) @@ -67,35 +67,29 @@ def __init__( """ super(BendingEnergyLoss, self).__init__(reduction=LossReduction(reduction).value) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward(self, pred: torch.Tensor) -> torch.Tensor: """ Args: - input: the shape should be BCH(WD) + pred: the shape should be BCH(WD) Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ - assert input.ndim in [3, 4, 5], f"expecting 3-d, 4-d or 5-d input, instead got input of shape {input.shape}" - if input.ndim == 3: - assert input.shape[-1] > 4, f"all spatial dimensions must > 4, got input of shape {input.shape}" - elif input.ndim == 4: - assert ( - input.shape[-1] > 4 and input.shape[-2] > 4 - ), f"all spatial dimensions must > 4, got input of shape {input.shape}" - elif input.ndim == 5: - assert ( - input.shape[-1] > 4 and input.shape[-2] > 4 and input.shape[-3] > 4 - ), f"all spatial dimensions must > 4, got input of shape {input.shape}" + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[- i - 1] <= 4: + raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}") # first order gradient - first_order_gradient = [spatial_gradient(input, dim) for dim in range(2, input.ndim)] + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 energy = spatial_gradient(g, dim_1) ** 2 + energy - for dim_2 in range(dim_1 + 1, input.ndim): + for dim_2 in range(dim_1 + 1, pred.ndim): energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy if self.reduction == LossReduction.MEAN.value: diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 184e341f97..dab743a463 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -76,11 +76,11 @@ def __init__( kernel_size: kernel spatial size, must be odd. kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} - Specifies the reduction to apply to the output. Defaults to ``"mean"``. + Specifies the reduction to apply to the pred. Defaults to ``"mean"``. - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - - ``"sum"``: the output will be summed. + - ``"mean"``: the sum of the pred will be divided by the number of elements in the pred. + - ``"sum"``: the pred will be summed. smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ @@ -89,7 +89,7 @@ def __init__( self.ndim = ndim if self.ndim not in [1, 2, 3]: - raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") + raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d preds are supported") self.kernel_size = kernel_size if self.kernel_size % 2 == 0: @@ -104,29 +104,26 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: - input: the shape should be BNH[WD]. + pred: the shape should be BNH[WD]. 
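
Stepping back to the `BendingEnergyLoss.forward` hunks above: the loss sums squared second-order central differences, counting the mixed derivatives twice, so a displacement field that is constant or varies linearly has zero bending energy while a quadratic ramp has a constant one. A short sketch of that property, mirroring the expected values in the bending-energy test cases later in this patch:

    import torch
    from monai.losses import BendingEnergyLoss

    loss = BendingEnergyLoss()
    linear = torch.arange(0, 5, dtype=torch.float)[None, None, None, None, :].expand(1, 3, 5, 5, 5)

    print(loss(linear))        # tensor(0.): second derivatives of a linear ramp vanish
    print(loss(linear ** 2))   # tensor(4.): d2(x**2)/dx2 = 2 along one axis, squared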
target: the shape should be BNH[WD]. Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ - assert ( - input.shape[1] == self.in_channels - ), f"expecting input with {self.in_channels} channels, got input of shape {input.shape}" - assert ( - input.ndim - 2 == self.ndim - ), f"expecting input with {self.ndim} spatial dimensions, got input of shape {input.shape}" - assert ( - target.shape == input.shape - ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" - - t2, p2, tp = target ** 2, input ** 2, target * input + if pred.shape[1] != self.in_channels: + raise ValueError(f"expecting pred with {self.in_channels} channels, got pred of shape {pred.shape}") + if pred.ndim - 2 != self.ndim: + raise ValueError(f"expecting pred with {self.ndim} spatial dimensions, got pred of shape {pred.shape}") + if target.shape != pred.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") + + t2, p2, tp = target ** 2, pred ** 2, target * pred # sum over kernel t_sum = separable_filtering(target, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - p_sum = separable_filtering(input, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + p_sum = separable_filtering(pred, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) t2_sum = separable_filtering(t2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) p2_sum = separable_filtering(p2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) tp_sum = separable_filtering(tp, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) @@ -180,16 +177,17 @@ def __init__( num_bins: number of bins for intensity sigma_ratio: a hyper param for gaussian function reduction: {``"none"``, ``"mean"``, ``"sum"``} - Specifies the reduction to apply to the output. Defaults to ``"mean"``. + Specifies the reduction to apply to the pred. Defaults to ``"mean"``. - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - - ``"sum"``: the output will be summed. + - ``"mean"``: the sum of the pred will be divided by the number of elements in the pred. + - ``"sum"``: the pred will be summed. smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ super(GlobalMutualInformationLoss, self).__init__(reduction=LossReduction(reduction).value) - assert num_bins > 0, f"num_bins must > 0, got {num_bins}" + if num_bins <= 0: + raise ValueError("num_bins must > 0, got {num_bins}") bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio self.preterm = 1 / (2 * sigma ** 2) @@ -197,30 +195,29 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def parzen_windowing(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: - input: the shape should be B[NDHW]. + pred: the shape should be B[NDHW]. 
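
The window sums above are computed with `separable_filtering`, applying the same 1-D kernel along every spatial dimension; that is equivalent to convolving with the n-dimensional outer product of the kernel with itself, and the kernel volume used for the averages is the sum of that outer-product kernel (computed explicitly by `get_kernel_vol` in a later patch of this series). A small 2-D check of the equivalence, written directly with `torch.nn.functional` so it does not depend on MONAI internals (the kernel values are illustrative):

    import torch
    from torch.nn import functional as F

    k = torch.tensor([1.0, 2.0, 1.0])                   # a 1-D kernel
    x = torch.rand(1, 1, 8, 8)

    # filter each spatial dimension separately with the 1-D kernel
    sep = F.conv2d(x, k.reshape(1, 1, 3, 1), padding=(1, 0))
    sep = F.conv2d(sep, k.reshape(1, 1, 1, 3), padding=(0, 1))

    # filter once with the 2-D outer-product kernel
    outer = k[:, None] * k[None, :]
    full = F.conv2d(x, outer.reshape(1, 1, 3, 3), padding=1)

    print(torch.allclose(sep, full, atol=1e-6))         # True
    print(outer.sum(), k.sum() ** 2)                    # both 16: the kernel volume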
""" - input = torch.clamp(input, 0, 1) - input = input.reshape(input.shape[0], -1, 1) # (batch, num_sample, 1) - weight = torch.exp(-self.preterm * (input - self.bin_centers) ** 2) # (batch, num_sample, num_bin) + pred = torch.clamp(pred, 0, 1) + pred = pred.reshape(pred.shape[0], -1, 1) # (batch, num_sample, 1) + weight = torch.exp(-self.preterm * (pred - self.bin_centers) ** 2) # (batch, num_sample, num_bin) weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) return weight, probability - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: - input: the shape should be B[NDHW]. - target: the shape should be same as the input shape. + pred: the shape should be B[NDHW]. + target: the shape should be same as the pred shape. Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ - assert ( - target.shape == input.shape - ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" - wa, pa = self.parzen_windowing(input) # (batch, num_sample, num_bin), (batch, 1, num_bin) + if target.shape != pred.shape: + raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") + wa, pa = self.parzen_windowing(pred) # (batch, num_sample, num_bin), (batch, 1, num_bin) wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) @@ -228,6 +225,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: mi = torch.sum( pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2) ) # (batch) + if self.reduction == LossReduction.SUM.value: return torch.sum(mi).neg() # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index 3ba22ebac0..f2b9a41cae 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -20,27 +20,27 @@ TEST_CASES = [ [ {}, - {"input": torch.ones((1, 3, 5, 5, 5))}, + {"pred": torch.ones((1, 3, 5, 5, 5))}, 0.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, + {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, + {"pred": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, 4.0, ], [ {}, - {"input": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2}, + {"pred": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2}, 4.0, ], ] @@ -55,24 +55,24 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = BendingEnergyLoss() # not in 3-d, 4-d, 5-d - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3))) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 5, 5, 5))) # spatial_dim < 5 - with self.assertRaisesRegex(AssertionError, ""): + with 
self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 4, 5, 5))) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 4, 5))) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 5, 4))) def test_ill_opts(self): - input = torch.rand(1, 3, 5, 5, 5) + pred = torch.rand(1, 3, 5, 5, 5) with self.assertRaisesRegex(ValueError, ""): - BendingEnergyLoss(reduction="unknown")(input) + BendingEnergyLoss(reduction="unknown")(pred) with self.assertRaisesRegex(ValueError, ""): - BendingEnergyLoss(reduction=None)(input) + BendingEnergyLoss(reduction=None)(pred) if __name__ == "__main__": diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index 376903fc1d..252a70e85e 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -21,7 +21,7 @@ [ {}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), }, -1.0986018, @@ -29,7 +29,7 @@ [ {}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3) ** 2, }, @@ -38,7 +38,7 @@ [ {}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3) ** 2, }, -1.083999, @@ -46,7 +46,7 @@ [ {}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3) ** 2, }, -1.083999, @@ -54,7 +54,7 @@ [ {}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :].div(3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :].div(3), "target": torch.arange(0, 3, dtype=torch.float)[None, :].div(3) ** 2, }, -1.083999, @@ -62,7 +62,7 @@ [ {}, { - "input": torch.arange(0, 3, dtype=torch.float).div(3), + "pred": torch.arange(0, 3, dtype=torch.float).div(3), "target": torch.arange(0, 3, dtype=torch.float).div(3) ** 2, }, -1.1920927e-07, @@ -78,22 +78,22 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = GlobalMutualInformationLoss() - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) def test_ill_opts(self): - input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) - with 
self.assertRaisesRegex(AssertionError, ""): - GlobalMutualInformationLoss(num_bins=0)(input, target) - with self.assertRaisesRegex(AssertionError, ""): - GlobalMutualInformationLoss(num_bins=-1)(input, target) with self.assertRaisesRegex(ValueError, ""): - GlobalMutualInformationLoss(reduction="unknown")(input, target) + GlobalMutualInformationLoss(num_bins=0)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(num_bins=-1)(pred, target) + with self.assertRaisesRegex(ValueError, ""): + GlobalMutualInformationLoss(reduction="unknown")(pred, target) with self.assertRaisesRegex(ValueError, ""): - GlobalMutualInformationLoss(reduction=None)(input, target) + GlobalMutualInformationLoss(reduction=None)(pred, target) if __name__ == "__main__": diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index cb2f446dfc..79027a03fb 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -21,7 +21,7 @@ [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), }, -1.0, @@ -29,7 +29,7 @@ [ {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), }, -1.0, @@ -37,7 +37,7 @@ [ {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "triangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), }, -1.0, @@ -45,7 +45,7 @@ [ {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "gaussian"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), }, -1.0, @@ -53,7 +53,7 @@ [ {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), }, -1.0, @@ -61,7 +61,7 @@ [ {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "triangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), }, -1.0, @@ -69,7 +69,7 @@ [ {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 
3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), }, -1.0, @@ -77,7 +77,7 @@ [ {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian", "reduction": "sum"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), }, -6.0, @@ -85,7 +85,7 @@ [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, -0.06062524, @@ -93,7 +93,7 @@ [ {"in_channels": 3, "ndim": 3, "kernel_size": 5, "kernel_type": "triangular"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, -0.923356, @@ -101,7 +101,7 @@ [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, { - "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, -1.306177, @@ -118,28 +118,28 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) # in_channel unmatch - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 2, 3, 3, 3), dtype=torch.float), torch.ones((1, 2, 3, 3, 3), dtype=torch.float)) # ndim unmatch - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 3, 3), dtype=torch.float)) - # input, target shape unmatch - with self.assertRaisesRegex(AssertionError, ""): + # pred, target shape unmatch + with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 4, 4, 4), dtype=torch.float)) def test_ill_opts(self): - input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type="unknown")(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type="unknown")(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type=None)(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type=None)(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, 
reduction="unknown")(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(pred, target) with self.assertRaisesRegex(ValueError, ""): - LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(input, target) + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(pred, target) if __name__ == "__main__": diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 078a5de573..0063c96b60 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -11,7 +11,6 @@ import unittest -import numpy as np import torch import torch.nn as nn import torch.optim as optim @@ -20,23 +19,23 @@ from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss TEST_CASES = [ - [BendingEnergyLoss, {}, ["input"]], + [BendingEnergyLoss, {}, ["pred"]], [ LocalNormalizedCrossCorrelationLoss, {"in_channels": 3, "kernel_size": 5, "kernel_type": "rectangular"}, - ["input", "target"], + ["pred", "target"], ], [ LocalNormalizedCrossCorrelationLoss, {"in_channels": 3, "kernel_size": 5, "kernel_type": "triangular"}, - ["input", "target"], + ["pred", "target"], ], [ LocalNormalizedCrossCorrelationLoss, {"in_channels": 3, "kernel_size": 5, "kernel_type": "gaussian"}, - ["input", "target"], + ["pred", "target"], ], - [GlobalMutualInformationLoss, {"num_bins": 5}, ["input", "target"]], + [GlobalMutualInformationLoss, {"num_bins": 5}, ["pred", "target"]], ] @@ -89,13 +88,13 @@ def forward(self, x): optimizer = optim.Adam(net.parameters(), lr=learning_rate) # train the network - for iter_i in range(max_iter): + for _ in range(max_iter): # set the gradient to zero optimizer.zero_grad() # forward pass output = net(image) - loss_input = {"input": output, "target": target} + loss_input = {"pred": output, "target": target} loss_val = loss(**{k: loss_input[k] for k in forward_args}) From 32bcdec1c42950065318700400833dea3c1654ae Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 11 Jan 2021 15:12:00 +0000 Subject: [PATCH 14/23] 1412 autofix style Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 97f389a782..0df76c1533 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -79,7 +79,7 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: if pred.ndim not in [3, 4, 5]: raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): - if pred.shape[- i - 1] <= 4: + if pred.shape[-i - 1] <= 4: raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}") # first order gradient From fd62c9644617bef91f2d69de999193a7ad349d10 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 11 Jan 2021 16:25:12 +0000 Subject: [PATCH 15/23] 1412 debug Signed-off-by: kate-sann5100 --- tests/test_reg_loss_integration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 0063c96b60..b84c0f1f48 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -64,6 +64,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): # define a simple 3d example target = torch.rand((1, 3, 5, 5, 5), device=self.device) image = 12 * target + 27 + image = image.to(device=self.device) # define a one layer model class OnelayerNet(nn.Module): @@ -82,7 
+83,7 @@ def forward(self, x): net = OnelayerNet().to(self.device) # initialize the loss - loss = loss_type(**loss_args) + loss = loss_type(**loss_args).to(self.device) # initialize a SGD optimizer optimizer = optim.Adam(net.parameters(), lr=learning_rate) From 5a8b1d7d65ffd555eeeb1b8554fabcb62639db57 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 11 Jan 2021 17:20:09 +0000 Subject: [PATCH 16/23] 1412 fix device bug Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index dab743a463..5c2cd5e571 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -120,17 +120,17 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") t2, p2, tp = target ** 2, pred ** 2, target * pred - + kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) # sum over kernel - t_sum = separable_filtering(target, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - p_sum = separable_filtering(pred, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - t2_sum = separable_filtering(t2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - p2_sum = separable_filtering(p2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) - tp_sum = separable_filtering(tp, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + t_sum = separable_filtering(target, kernels=[kernel] * self.ndim).sum(1, keepdim=True) + p_sum = separable_filtering(pred, kernels=[kernel] * self.ndim).sum(1, keepdim=True) + t2_sum = separable_filtering(t2, kernels=[kernel] * self.ndim).sum(1, keepdim=True) + p2_sum = separable_filtering(p2, kernels=[kernel] * self.ndim).sum(1, keepdim=True) + tp_sum = separable_filtering(tp, kernels=[kernel] * self.ndim).sum(1, keepdim=True) # average over kernel - t_avg = t_sum / self.kernel_vol - p_avg = p_sum / self.kernel_vol + t_avg = t_sum / kernel_vol + p_avg = p_sum / kernel_vol # normalized cross correlation between t and p # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p] @@ -202,7 +202,7 @@ def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens """ pred = torch.clamp(pred, 0, 1) pred = pred.reshape(pred.shape[0], -1, 1) # (batch, num_sample, 1) - weight = torch.exp(-self.preterm * (pred - self.bin_centers) ** 2) # (batch, num_sample, num_bin) + weight = torch.exp(-self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2) # (batch, num_sample, num_bin) weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) return weight, probability From 3d76ee10fe751db2ba66bf8648bd71670107ddbf Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 11 Jan 2021 17:25:38 +0000 Subject: [PATCH 17/23] 1412 autofix style Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 5c2cd5e571..281871cc68 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -202,7 +202,9 @@ def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens """ pred = torch.clamp(pred, 0, 1) pred = pred.reshape(pred.shape[0], -1, 1) # 
(batch, num_sample, 1) - weight = torch.exp(-self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2) # (batch, num_sample, num_bin) + weight = torch.exp( + -self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2 + ) # (batch, num_sample, num_bin) weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) return weight, probability From 30a514598d0490f2ccab2b4e64df4dfb40bbae8c Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Mon, 11 Jan 2021 19:50:06 +0000 Subject: [PATCH 18/23] 1412 fix typo Signed-off-by: kate-sann5100 --- monai/losses/deform.py | 2 +- monai/losses/image_dissimilarity.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 0df76c1533..acba229121 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -26,7 +26,7 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: DeepReg (https://github.com/DeepRegNet/DeepReg) Args: - pred: the shape should be BCH(WD). + x: the shape should be BCH(WD). dim: dimension to calculate gradient along. Returns: gradient_dx: the shape should be BCH(WD) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 281871cc68..9a45b697aa 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -76,11 +76,11 @@ def __init__( kernel_size: kernel spatial size, must be odd. kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} - Specifies the reduction to apply to the pred. Defaults to ``"mean"``. + Specifies the reduction to apply to the output. Defaults to ``"mean"``. - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the pred will be divided by the number of elements in the pred. - - ``"sum"``: the pred will be summed. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ @@ -89,7 +89,7 @@ def __init__( self.ndim = ndim if self.ndim not in [1, 2, 3]: - raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d preds are supported") + raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") self.kernel_size = kernel_size if self.kernel_size % 2 == 0: @@ -177,11 +177,11 @@ def __init__( num_bins: number of bins for intensity sigma_ratio: a hyper param for gaussian function reduction: {``"none"``, ``"mean"``, ``"sum"``} - Specifies the reduction to apply to the pred. Defaults to ``"mean"``. + Specifies the reduction to apply to the output. Defaults to ``"mean"``. - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the pred will be divided by the number of elements in the pred. - - ``"sum"``: the pred will be summed. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. 
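
The device fix above relies on `Tensor.to(other)` accepting another tensor: constants created in `__init__` (the kernel, `bin_centers`, `preterm`) are cast to the dtype and device of the incoming prediction at call time, so the losses work unchanged on CPU and GPU inputs. A minimal illustration of that casting pattern (illustrative tensors, not the loss internals):

    import torch

    bin_centers = torch.linspace(0.0, 1.0, 5)           # created on CPU in __init__
    pred = torch.rand(2, 3)
    if torch.cuda.is_available():
        pred = pred.cuda().half()

    # .to(pred) matches both the device and the dtype of the runtime input
    centers = bin_centers.to(pred)
    print(centers.device, centers.dtype)                # follows pred, e.g. cuda:0 torch.float16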
""" From 5f4a7d9b0d579c02ef091b062ff66dd5ce952192 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 12 Jan 2021 14:58:40 +0000 Subject: [PATCH 19/23] 1412 debug kernel_vol Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 29 +++++-- ...local_normalized_cross_correlation_loss.py | 84 ++++++------------- tests/test_reg_loss_integration.py | 19 +++-- 3 files changed, 57 insertions(+), 75 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 9a45b697aa..747ad6afc4 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -63,7 +63,7 @@ def __init__( self, in_channels: int, ndim: int = 3, - kernel_size: int = 9, + kernel_size: int = 3, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-7, @@ -100,10 +100,17 @@ def __init__( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) self.kernel = kernel_dict[kernel_type](self.kernel_size) - self.kernel_vol = torch.sum(self.kernel) ** self.ndim + self.kernel_vol = self.get_kernel_vol() + self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) + def get_kernel_vol(self): + vol = self.kernel + for _ in range(self.ndim - 1): + vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0)) + return torch.sum(vol) + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: @@ -122,11 +129,11 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, pred ** 2, target * pred kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) # sum over kernel - t_sum = separable_filtering(target, kernels=[kernel] * self.ndim).sum(1, keepdim=True) - p_sum = separable_filtering(pred, kernels=[kernel] * self.ndim).sum(1, keepdim=True) - t2_sum = separable_filtering(t2, kernels=[kernel] * self.ndim).sum(1, keepdim=True) - p2_sum = separable_filtering(p2, kernels=[kernel] * self.ndim).sum(1, keepdim=True) - tp_sum = separable_filtering(tp, kernels=[kernel] * self.ndim).sum(1, keepdim=True) + t_sum = separable_filtering(target, kernels=[kernel] * self.ndim) + p_sum = separable_filtering(pred, kernels=[kernel] * self.ndim) + t2_sum = separable_filtering(t2, kernels=[kernel] * self.ndim) + p2_sum = separable_filtering(p2, kernels=[kernel] * self.ndim) + tp_sum = separable_filtering(tp, kernels=[kernel] * self.ndim) # average over kernel t_avg = t_sum / kernel_vol @@ -148,11 +155,11 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return torch.sum(ncc).neg() # sum over the batch and spatial ndims + return torch.sum(ncc).neg() # sum over the batch, channel and spatial ndims if self.reduction == LossReduction.NONE.value: return ncc.neg() if self.reduction == LossReduction.MEAN.value: - return torch.mean(ncc).neg() # average over the batch and spatial ndims + return torch.mean(ncc).neg() # average over the batch, channel and spatial ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') @@ -235,3 +242,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.reduction == LossReduction.MEAN.value: return torch.mean(mi).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are 
["mean", "sum", "none"].') + + +if __name__ == "__main__": + print(LocalNormalizedCrossCorrelationLoss(in_channels=1)(torch.ones((1, 1, 1, 3, 3)), torch.ones((1, 1, 1, 3, 3)))) diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 79027a03fb..cf8566a559 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -19,92 +19,60 @@ TEST_CASES = [ [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular", "reduction": "sum"}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), }, - -1.0, - ], - [ - {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "rectangular"}, - { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - }, - -1.0, - ], - [ - {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "triangular"}, - { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - }, - -1.0, - ], - [ - {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "gaussian"}, - { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), - }, - -1.0, + -1.0 * 3, ], [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "rectangular"}, + {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), }, -1.0, ], [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "triangular"}, + {"in_channels": 1, "ndim": 2, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), }, -1.0, ], [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian"}, + {"in_channels": 1, "ndim": 3, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), }, -1.0, ], [ - {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian", "reduction": "sum"}, - { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), - "target": torch.arange(0, 3, 
dtype=torch.float)[None, :, None].expand(2, 3, 3), - }, - -6.0, - ], - [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + {"in_channels": 3, "ndim": 3, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, }, - -0.06062524, + -0.95801723, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 5, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_type": "triangular", "kernel_size": 5}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + "pred": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float), + "target": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float) ** 2, }, - -0.923356, + -0.918672, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, + {"in_channels": 3, "ndim": 3, "kernel_type": "gaussian"}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, }, - -1.306177, + -0.95406944, ], ] @@ -113,7 +81,7 @@ class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) def test_ill_shape(self): loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index b84c0f1f48..dede3661a2 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -22,20 +22,20 @@ [BendingEnergyLoss, {}, ["pred"]], [ LocalNormalizedCrossCorrelationLoss, - {"in_channels": 3, "kernel_size": 5, "kernel_type": "rectangular"}, + {"in_channels": 1, "kernel_size": 7, "kernel_type": "rectangular"}, ["pred", "target"], ], [ LocalNormalizedCrossCorrelationLoss, - {"in_channels": 3, "kernel_size": 5, "kernel_type": "triangular"}, + {"in_channels": 1, "kernel_size": 5, "kernel_type": "triangular"}, ["pred", "target"], ], [ LocalNormalizedCrossCorrelationLoss, - {"in_channels": 3, "kernel_size": 5, "kernel_type": "gaussian"}, + {"in_channels": 1, "kernel_size": 3, "kernel_type": "gaussian"}, ["pred", "target"], ], - [GlobalMutualInformationLoss, {"num_bins": 5}, ["pred", "target"]], + [GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]], ] @@ -62,7 +62,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): max_iter = 40 # define a simple 3d example - target = torch.rand((1, 3, 5, 5, 5), device=self.device) + target = torch.rand((1, 1, 10, 10, 
10), device=self.device) image = 12 * target + 27 image = image.to(device=self.device) @@ -71,9 +71,9 @@ class OnelayerNet(nn.Module): def __init__(self): super(OnelayerNet, self).__init__() self.layer = nn.Sequential( - nn.Conv3d(in_channels=3, out_channels=3, kernel_size=3, padding=1), + nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv3d(in_channels=3, out_channels=3, kernel_size=3, padding=1), + nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), ) def forward(self, x): @@ -89,7 +89,7 @@ def forward(self, x): optimizer = optim.Adam(net.parameters(), lr=learning_rate) # train the network - for _ in range(max_iter): + for iter in range(max_iter): # set the gradient to zero optimizer.zero_grad() @@ -98,10 +98,13 @@ def forward(self, x): loss_input = {"pred": output, "target": target} loss_val = loss(**{k: loss_input[k] for k in forward_args}) + if iter == 0: + init_loss = loss_val # backward pass loss_val.backward() optimizer.step() + assert init_loss > loss_val, f"loss did not decrease" if __name__ == "__main__": From 1b33686ef84756b4aa954713a0f515957aca575c Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 12 Jan 2021 15:31:00 +0000 Subject: [PATCH 20/23] 1412 autofix style Signed-off-by: kate-sann5100 --- tests/test_reg_loss_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index dede3661a2..70614b5da4 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -59,7 +59,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): We verify that the loss is decreasing in almost all SGD steps. """ learning_rate = 0.001 - max_iter = 40 + max_iter = 100 # define a simple 3d example target = torch.rand((1, 1, 10, 10, 10), device=self.device) @@ -104,7 +104,7 @@ def forward(self, x): # backward pass loss_val.backward() optimizer.step() - assert init_loss > loss_val, f"loss did not decrease" + assert init_loss > loss_val, "loss did not decrease" if __name__ == "__main__": From ef881a2b2f8113c4dac3dc95606d9d357320526a Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 12 Jan 2021 17:01:46 +0000 Subject: [PATCH 21/23] 1412 simplify simple network Signed-off-by: kate-sann5100 --- tests/test_reg_loss_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 70614b5da4..fc59da1913 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -73,7 +73,7 @@ def __init__(self): self.layer = nn.Sequential( nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + # nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), ) def forward(self, x): From 6fa2469df795ded871f282a58170d48d20a7dcb8 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 12 Jan 2021 17:21:22 +0000 Subject: [PATCH 22/23] 1412 debug Signed-off-by: kate-sann5100 --- tests/test_reg_loss_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index fc59da1913..46914dea3c 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -62,7 +62,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): max_iter = 100 # define a simple 3d example - target = 
torch.rand((1, 1, 10, 10, 10), device=self.device) + target = torch.rand((1, 1, 5, 5, 5), device=self.device) image = 12 * target + 27 image = image.to(device=self.device) @@ -73,7 +73,7 @@ def __init__(self): self.layer = nn.Sequential( nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), nn.ReLU(), - # nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), ) def forward(self, x): From f61681374e67bc2efcce525889a792647dd2be8c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 12 Jan 2021 20:06:05 +0000 Subject: [PATCH 23/23] remove temp. scripts Signed-off-by: Wenqi Li --- docs/source/losses.rst | 4 ++-- monai/losses/image_dissimilarity.py | 4 ---- tests/test_reg_loss_integration.py | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 4613a8a57f..a6aa4d566d 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -67,11 +67,11 @@ Registration Losses :members: `LocalNormalizedCrossCorrelationLoss` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss :members: `GlobalMutualInformationLoss` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: GlobalMutualInformationLoss :members: \ No newline at end of file diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 747ad6afc4..b229a0c08f 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -242,7 +242,3 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.reduction == LossReduction.MEAN.value: return torch.mean(mi).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') - - -if __name__ == "__main__": - print(LocalNormalizedCrossCorrelationLoss(in_channels=1)(torch.ones((1, 1, 1, 3, 3)), torch.ones((1, 1, 1, 3, 3)))) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 46914dea3c..da6af6f66d 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -104,7 +104,7 @@ def forward(self, x): # backward pass loss_val.backward() optimizer.step() - assert init_loss > loss_val, "loss did not decrease" + self.assertTrue(init_loss > loss_val, "loss did not decrease") if __name__ == "__main__":