From ce61c38c377059f27deef07b26433669ab45c850 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 00:00:21 +0000 Subject: [PATCH 01/11] 1412 add local normalized cross correlation Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 157 ++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 monai/losses/image_dissimilarity.py diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py new file mode 100644 index 0000000000..557a757e79 --- /dev/null +++ b/monai/losses/image_dissimilarity.py @@ -0,0 +1,157 @@ +import torch +from torch.nn.modules.loss import _Loss + +from torch.nn import functional as F + +from monai.utils import LossReduction, Union + + +conv_dict = { + 1: F.conv1d, + 2: F.conv2d, + 3: F.conv3d +} + +EPS = 1e-7 + + +class LocalNormalizedCrossCorrelation(_Loss): + """ + Local squared zero-normalized cross-correlation. + The loss is based on a moving kernel/window over the y_true/y_pred, + within the window the square of zncc is calculated. + The kernel can be a rectangular / triangular / gaussian window. + The final loss is the averaged loss over all windows. + + Adapted from: + https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py + DeepReg (https://github.com/DeepRegNet/DeepReg) + """ + + def __init__( + self, + in_channels: int, + dim: int = 3, + kernel_size: int = 9, + kernel_type: str = "rectangular", + reduction: Union[LossReduction, str] = LossReduction.MEAN, + ) -> None: + """ + Args: + in_channels: number of input channels + dim: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. + kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` + kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) + self.in_channels = in_channels + self.dim = dim + if self.dim not in [1, 2, 3]: + raise ValueError(f'Unsupported dim: {self.dim}-d, only 1-d, 2-d, and 3-d inputs are supported') + self.kernel_size = kernel_size + if kernel_type == "rectangular": + self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() + elif kernel_type == "triangular": + self.kernel, self.kernel_vol, self.padding = self.make_triangular_kernel() + elif kernel_type == "gaussian": + self.kernel, self.kernel_vol, self.padding = self.make_gaussian_kernel() + else: + raise ValueError( + f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' + ) + + def make_rectangular_kernel(self): + shape = [1, self.in_channels] + [self.kernel_size] * self.dim + return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.dim, int((self.kernel_size - 1) / 2) + + def make_triangular_kernel(self): + fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) + f1 = torch.ones( + [1, 1] + [fsize] * self.dim, + dtype=torch.float + ) / fsize # (1, 1, D, H, W) + f2 = torch.ones( + [1, self.in_channels] + [fsize] * self.dim, + dtype=torch.float + ) / fsize # (1, in_channels, D, H, W) + # (1, 1, D, H, W) -> (1, in_channels, D, H, W) + fn = conv_dict[self.dim] + kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) + + return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) + + def make_gaussian_kernel(self): + mean = (self.kernel_size - 1) / 2.0 + sigma = self.kernel_size / 3 + + grid_dim = torch.arange(0, self.kernel_size) + grid_dim_ch = torch.arange(0, self.in_channel) + + if self.dim == 1: + grid = torch.meshgrid(grid_dim_ch, grid_dim) + elif self.dim == 2: + grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim) + elif self.dim == 3: + grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim, grid_dim) + else: + raise ValueError + + grid = torch.stack(grid, dim=-1).to(dtype=torch.float) + kernel = torch.exp( + -torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2) + ).unsqueeze(0) # (1, in_channel, kernel_size, kernel_size, kernel_size) + return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be BNH[WD]. + target: the shape should be BNH[WD]. + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + """ + assert ( + target.shape == input.shape + ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" + + t2, p2, tp = target ** 2, input ** 2, target * input + + # sum over kernel + fn = conv_dict[self.dim] + t_sum = fn(target, weight=self.kernel, padding=self.padding) + p_sum = fn(input, weight=self.kernel, padding=self.paddin) + t2_sum = fn(t2, weight=self.kernel, padding=self.paddin) + p2_sum = fn(p2, weight=self.kernel, padding=self.paddin) + tp_sum = fn(tp, weight=self.kernel, padding=self.paddin) + + # average over kernel + t_avg = t_sum / self.kernel_vol + p_avg = p_sum / self.kernel_vol + + # normalized cross correlation between t and p + # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p] + # denoted by num / denom + # assume we sum over N values + # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]] + # = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * sum[p] / N + # = sum[t*p] - sum[t] * mean[p] = cross + # the following is actually squared ncc + cross = tp_sum - p_avg * t_sum + t_var = t2_sum - t_avg * t_sum # std[t] ** 2 + p_var = p2_sum - p_avg * p_sum # std[p] ** 2 + ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) + + if self.reduction == LossReduction.SUM.value: + return torch.sum(ncc) # sum over the batch and channel dims + if self.reduction == LossReduction.NONE.value: + return ncc + if self.reduction == LossReduction.MEAN.value: + return torch.mean(ncc) # average over the batch and channel dims + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') From 5cf91d0463f3b52d36a193dfea290263d9ad0819 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 01:02:13 +0000 Subject: [PATCH 02/11] 1412 add unit test and documentation Signed-off-by: kate-sann5100 --- docs/source/losses.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/image_dissimilarity.py | 94 +++++++------- ...local_normalized_cross_correlation_loss.py | 121 ++++++++++++++++++ 4 files changed, 172 insertions(+), 49 deletions(-) create mode 100644 tests/test_local_normalized_cross_correlation_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index d2c8e02ca4..fc86c9cd05 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -65,3 +65,8 @@ Registration Losses ~~~~~~~~~~~~~~~~~~~ .. autoclass:: BendingEnergyLoss :members: + +`LocalNormalizedCrossCorrelationLoss` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: LocalNormalizedCrossCorrelationLoss + :members: diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d4e21f900c..42ccd0f65a 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -22,4 +22,5 @@ generalized_wasserstein_dice, ) from .focal_loss import FocalLoss +from .image_dissimilarity import LocalNormalizedCrossCorrelationLoss from .tversky import TverskyLoss diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 557a757e79..26ee7ea718 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,21 +1,15 @@ import torch -from torch.nn.modules.loss import _Loss - from torch.nn import functional as F +from torch.nn.modules.loss import _Loss from monai.utils import LossReduction, Union - -conv_dict = { - 1: F.conv1d, - 2: F.conv2d, - 3: F.conv3d -} +conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} EPS = 1e-7 -class LocalNormalizedCrossCorrelation(_Loss): +class LocalNormalizedCrossCorrelationLoss(_Loss): """ Local squared zero-normalized cross-correlation. The loss is based on a moving kernel/window over the y_true/y_pred, @@ -29,17 +23,17 @@ class LocalNormalizedCrossCorrelation(_Loss): """ def __init__( - self, - in_channels: int, - dim: int = 3, - kernel_size: int = 9, - kernel_type: str = "rectangular", - reduction: Union[LossReduction, str] = LossReduction.MEAN, + self, + in_channels: int, + ndim: int = 3, + kernel_size: int = 9, + kernel_type: str = "rectangular", + reduction: Union[LossReduction, str] = LossReduction.MEAN, ) -> None: """ Args: in_channels: number of input channels - dim: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3. + ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -51,9 +45,9 @@ def __init__( """ super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels - self.dim = dim - if self.dim not in [1, 2, 3]: - raise ValueError(f'Unsupported dim: {self.dim}-d, only 1-d, 2-d, and 3-d inputs are supported') + self.ndim = ndim + if self.ndim not in [1, 2, 3]: + raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") self.kernel_size = kernel_size if kernel_type == "rectangular": self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() @@ -67,21 +61,17 @@ def __init__( ) def make_rectangular_kernel(self): - shape = [1, self.in_channels] + [self.kernel_size] * self.dim - return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.dim, int((self.kernel_size - 1) / 2) + shape = [1, self.in_channels] + [self.kernel_size] * self.ndim + return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) def make_triangular_kernel(self): fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) - f1 = torch.ones( - [1, 1] + [fsize] * self.dim, - dtype=torch.float - ) / fsize # (1, 1, D, H, W) - f2 = torch.ones( - [1, self.in_channels] + [fsize] * self.dim, - dtype=torch.float - ) / fsize # (1, in_channels, D, H, W) + f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize # (1, 1, D, H, W) + f2 = ( + torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize + ) # (1, in_channels, D, H, W) # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - fn = conv_dict[self.dim] + fn = conv_dict[self.ndim] kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) @@ -90,22 +80,22 @@ def make_gaussian_kernel(self): mean = (self.kernel_size - 1) / 2.0 sigma = self.kernel_size / 3 - grid_dim = torch.arange(0, self.kernel_size) - grid_dim_ch = torch.arange(0, self.in_channel) + grid_ndim = torch.arange(0, self.kernel_size) + grid_ndim_ch = torch.arange(0, self.in_channels) - if self.dim == 1: - grid = torch.meshgrid(grid_dim_ch, grid_dim) - elif self.dim == 2: - grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim) - elif self.dim == 3: - grid = torch.meshgrid(grid_dim_ch, grid_dim, grid_dim, grid_dim) + if self.ndim == 1: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim) + elif self.ndim == 2: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim) + elif self.ndim == 3: + grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim, grid_ndim) else: raise ValueError grid = torch.stack(grid, dim=-1).to(dtype=torch.float) - kernel = torch.exp( - -torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2) - ).unsqueeze(0) # (1, in_channel, kernel_size, kernel_size, kernel_size) + kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( + 0 + ) # (1, in_channel, kernel_size, kernel_size, kernel_size) return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: @@ -116,6 +106,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. """ + assert ( + input.shape[1] == self.in_channels + ), f"expecting input with {self.in_channels} channels, got input of shape {input.shape}" + assert ( + input.ndim - 2 == self.ndim + ), f"expecting input with {self.ndim} spatial dimensions, got input of shape {input.shape}" assert ( target.shape == input.shape ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" @@ -123,12 +119,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - fn = conv_dict[self.dim] + fn = conv_dict[self.ndim] t_sum = fn(target, weight=self.kernel, padding=self.padding) - p_sum = fn(input, weight=self.kernel, padding=self.paddin) - t2_sum = fn(t2, weight=self.kernel, padding=self.paddin) - p2_sum = fn(p2, weight=self.kernel, padding=self.paddin) - tp_sum = fn(tp, weight=self.kernel, padding=self.paddin) + p_sum = fn(input, weight=self.kernel, padding=self.padding) + t2_sum = fn(t2, weight=self.kernel, padding=self.padding) + p2_sum = fn(p2, weight=self.kernel, padding=self.padding) + tp_sum = fn(tp, weight=self.kernel, padding=self.padding) # average over kernel t_avg = t_sum / self.kernel_vol @@ -149,9 +145,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return torch.sum(ncc) # sum over the batch and channel dims + return -torch.sum(ncc) # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: - return ncc + return -ncc if self.reduction == LossReduction.MEAN.value: - return torch.mean(ncc) # average over the batch and channel dims + return -torch.mean(ncc) # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py new file mode 100644 index 0000000000..332689be06 --- /dev/null +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -0,0 +1,121 @@ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss + +TEST_CASES = [ + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 2, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3), + }, + -1.0, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.06062524, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "triangular"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.9368649, + ], + [ + {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, + }, + -0.50272596, + ], +] + + +class TestBendingEnergy(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-2) + + def test_ill_shape(self): + loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) + # in_channel unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 2, 3, 3, 3), dtype=torch.float), torch.ones((1, 2, 3, 3, 3), dtype=torch.float)) + # ndim unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 3, 3), dtype=torch.float)) + # input, target shape unmatch + with self.assertRaisesRegex(AssertionError, ""): + loss.forward(torch.ones((1, 3, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 4, 4, 4), dtype=torch.float)) + + def test_ill_opts(self): + input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(input, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction=None)(input, target) + + +if __name__ == "__main__": + unittest.main() From 9376195b34df1eab0fc0f94b251102184e23864f Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 01:37:55 +0000 Subject: [PATCH 03/11] 1412 fix bug Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 26ee7ea718..42f0cfd9d7 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -43,7 +43,7 @@ def __init__( - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. """ - super(LocalNormalizedCrossCorrelation, self).__init__(reduction=LossReduction(reduction).value) + super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels self.ndim = ndim if self.ndim not in [1, 2, 3]: From ac36a9fc877882cf8da48378479782a9e58c7b8e Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 15:00:23 +0000 Subject: [PATCH 04/11] 1412 reformat code Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 58 ++++++++++++------- ...local_normalized_cross_correlation_loss.py | 15 ++++- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 42f0cfd9d7..232b8e0600 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -1,3 +1,14 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss @@ -6,8 +17,6 @@ conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} -EPS = 1e-7 - class LocalNormalizedCrossCorrelationLoss(_Loss): """ @@ -29,6 +38,7 @@ def __init__( kernel_size: int = 9, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_dr: float = 1e-7, ) -> None: """ Args: @@ -42,12 +52,14 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels self.ndim = ndim if self.ndim not in [1, 2, 3]: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") + self.fn = conv_dict[self.ndim] self.kernel_size = kernel_size if kernel_type == "rectangular": self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() @@ -59,26 +71,29 @@ def __init__( raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) + self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): shape = [1, self.in_channels] + [self.kernel_size] * self.ndim return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) def make_triangular_kernel(self): - fsize = torch.tensor((self.kernel_size + 1) / 2, dtype=torch.int) - f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize # (1, 1, D, H, W) - f2 = ( - torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float) / fsize - ) # (1, in_channels, D, H, W) + fsize = int((self.kernel_size + 1) // 2) + f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) # (1, 1, D, H, W) + f1 = F.pad(f1, [(fsize - 1) // 2, (fsize - 1) // 2] * self.ndim) + f2 = torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) + # (in_channels, 1, D, H, W) # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - fn = conv_dict[self.ndim] - kernel = fn(f1, f2, padding=int((fsize - 1) / 2)) + padding_needed = max(fsize - 1, 0) + padding = [padding_needed // 2, padding_needed - padding_needed // 2] * self.ndim + f1 = F.pad(f1, padding) + kernel = self.fn(f1, f2) - return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2) + return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2.0) def make_gaussian_kernel(self): mean = (self.kernel_size - 1) / 2.0 - sigma = self.kernel_size / 3 + sigma = self.kernel_size / 3.0 grid_ndim = torch.arange(0, self.kernel_size) grid_ndim_ch = torch.arange(0, self.in_channels) @@ -96,7 +111,7 @@ def make_gaussian_kernel(self): kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( 0 ) # (1, in_channel, kernel_size, kernel_size, kernel_size) - return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2) + return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2.0) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -119,12 +134,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - fn = conv_dict[self.ndim] - t_sum = fn(target, weight=self.kernel, padding=self.padding) - p_sum = fn(input, weight=self.kernel, padding=self.padding) - t2_sum = fn(t2, weight=self.kernel, padding=self.padding) - p2_sum = fn(p2, weight=self.kernel, padding=self.padding) - tp_sum = fn(tp, weight=self.kernel, padding=self.padding) + t_sum = self.fn(target, weight=self.kernel, padding=self.padding) + p_sum = self.fn(input, weight=self.kernel, padding=self.padding) + t2_sum = self.fn(t2, weight=self.kernel, padding=self.padding) + p2_sum = self.fn(p2, weight=self.kernel, padding=self.padding) + tp_sum = self.fn(tp, weight=self.kernel, padding=self.padding) # average over kernel t_avg = t_sum / self.kernel_vol @@ -142,12 +156,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc = (cross * cross + EPS) / (t_var * p_var + EPS) # shape = (batch, 1, D, H, W) + ncc = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return -torch.sum(ncc) # sum over the batch and channel ndims + return -torch.sum(ncc).neg() # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: - return -ncc + return ncc.neg() if self.reduction == LossReduction.MEAN.value: - return -torch.mean(ncc) # average over the batch and channel ndims + return torch.mean(ncc).neg() # average over the batch and channel ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 332689be06..beb1a1dca2 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -1,3 +1,14 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import numpy as np @@ -72,7 +83,7 @@ -0.06062524, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_size": 6, "kernel_type": "triangular"}, { "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, @@ -90,7 +101,7 @@ ] -class TestBendingEnergy(unittest.TestCase): +class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) From ed6c28bca7393a527b5bc21a0442f0a08333d33c Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 7 Jan 2021 15:20:47 +0000 Subject: [PATCH 05/11] 1412 debug type check Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 232b8e0600..4e64d81eb6 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -156,7 +156,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) + ncc: torch.Tensor = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) + # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: return -torch.sum(ncc).neg() # sum over the batch and channel ndims From 43c2f355416b249a34adfa87367aabd5f1b381f0 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 00:42:45 +0000 Subject: [PATCH 06/11] 1412 use separable filter for speed Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 75 +++++++------------ monai/networks/layers/simplelayers.py | 6 ++ ...local_normalized_cross_correlation_loss.py | 10 ++- 3 files changed, 41 insertions(+), 50 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 4e64d81eb6..bd9a06e187 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -13,10 +13,9 @@ from torch.nn import functional as F from torch.nn.modules.loss import _Loss +from monai.networks.layers import gaussian_1d, separable_filtering from monai.utils import LossReduction, Union -conv_dict = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d} - class LocalNormalizedCrossCorrelationLoss(_Loss): """ @@ -44,7 +43,7 @@ def __init__( Args: in_channels: number of input channels ndim: number of spatial ndimensions, {``1``, ``2``, ``3``}. Defaults to 3. - kernel_size: kernel size or kernel sigma for kernel_type=``"gaussian"`` + kernel_size: kernel spatial size, must be odd. kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -56,62 +55,46 @@ def __init__( """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) self.in_channels = in_channels + self.ndim = ndim if self.ndim not in [1, 2, 3]: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") - self.fn = conv_dict[self.ndim] + self.kernel_size = kernel_size + if self.kernel_size % 2 == 0: + raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") + if kernel_type == "rectangular": - self.kernel, self.kernel_vol, self.padding = self.make_rectangular_kernel() + self.kernel = self.make_rectangular_kernel() elif kernel_type == "triangular": - self.kernel, self.kernel_vol, self.padding = self.make_triangular_kernel() + self.kernel = self.make_triangular_kernel() elif kernel_type == "gaussian": - self.kernel, self.kernel_vol, self.padding = self.make_gaussian_kernel() + self.kernel = self.make_gaussian_kernel() else: raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) + + self.kernel_vol = torch.sum(self.kernel) ** self.ndim self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): - shape = [1, self.in_channels] + [self.kernel_size] * self.ndim - return torch.ones(shape, dtype=torch.float), self.kernel_size ** self.ndim, int((self.kernel_size - 1) / 2) + return torch.ones(self.kernel_size) def make_triangular_kernel(self): - fsize = int((self.kernel_size + 1) // 2) - f1 = torch.ones([1, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) # (1, 1, D, H, W) - f1 = F.pad(f1, [(fsize - 1) // 2, (fsize - 1) // 2] * self.ndim) - f2 = torch.ones([self.in_channels, 1] + [fsize] * self.ndim, dtype=torch.float).div(fsize) - # (in_channels, 1, D, H, W) - # (1, 1, D, H, W) -> (1, in_channels, D, H, W) - padding_needed = max(fsize - 1, 0) - padding = [padding_needed // 2, padding_needed - padding_needed // 2] * self.ndim - f1 = F.pad(f1, padding) - kernel = self.fn(f1, f2) - - return kernel, torch.sum(kernel ** 2), int((fsize - 1) / 2.0) + fsize = (self.kernel_size + 1) // 2 + if fsize % 2 == 0: + fsize -= 1 + f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) + padding = (self.kernel_size - fsize) // 2 + fsize // 2 + return F.conv1d(f, f, padding=padding).reshape(-1) def make_gaussian_kernel(self): - mean = (self.kernel_size - 1) / 2.0 - sigma = self.kernel_size / 3.0 - - grid_ndim = torch.arange(0, self.kernel_size) - grid_ndim_ch = torch.arange(0, self.in_channels) - - if self.ndim == 1: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim) - elif self.ndim == 2: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim) - elif self.ndim == 3: - grid = torch.meshgrid(grid_ndim_ch, grid_ndim, grid_ndim, grid_ndim) - else: - raise ValueError - - grid = torch.stack(grid, dim=-1).to(dtype=torch.float) - kernel = torch.exp(-torch.sum(torch.square(grid - mean), dim=-1) / (2 * sigma ** 2)).unsqueeze( - 0 - ) # (1, in_channel, kernel_size, kernel_size, kernel_size) - return kernel, torch.sum(kernel ** 2), int((self.kernel_size - 1) / 2.0) + sigma = torch.tensor(self.kernel_size / 3.0) + kernel = gaussian_1d(sigma=sigma, truncated=self.kernel_size // 2, approx="sampled", normalize=False) * ( + 2.5066282 * sigma + ) + return kernel[: self.kernel_size] def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -134,11 +117,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, input ** 2, target * input # sum over kernel - t_sum = self.fn(target, weight=self.kernel, padding=self.padding) - p_sum = self.fn(input, weight=self.kernel, padding=self.padding) - t2_sum = self.fn(t2, weight=self.kernel, padding=self.padding) - p2_sum = self.fn(p2, weight=self.kernel, padding=self.padding) - tp_sum = self.fn(tp, weight=self.kernel, padding=self.padding) + t_sum = separable_filtering(target, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + p_sum = separable_filtering(input, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + t2_sum = separable_filtering(t2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + p2_sum = separable_filtering(p2, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) + tp_sum = separable_filtering(tp, kernels=[self.kernel] * self.ndim).sum(1, keepdim=True) # average over kernel t_avg = t_sum / self.kernel_vol diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 48012dfb1c..8ff9464145 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -365,3 +365,9 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) + + +if __name__ == "__main__": + input = torch.ones((1, 3, 3, 3)) + kernels = [torch.ones(1, 3)] * 2 + print(separable_filtering(input, kernels)) diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index beb1a1dca2..fe455b0597 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -83,12 +83,12 @@ -0.06062524, ], [ - {"in_channels": 3, "ndim": 3, "kernel_size": 6, "kernel_type": "triangular"}, + {"in_channels": 3, "ndim": 3, "kernel_size": 5, "kernel_type": "triangular"}, { "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, - -0.9368649, + -0.923356, ], [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "gaussian"}, @@ -96,7 +96,7 @@ "input": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3), "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3) ** 2, }, - -0.50272596, + -1.306177, ], ] @@ -105,7 +105,7 @@ class TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_val): result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-2) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) def test_ill_shape(self): loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) @@ -122,6 +122,8 @@ def test_ill_shape(self): def test_ill_opts(self): input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(input, target) with self.assertRaisesRegex(ValueError, ""): LocalNormalizedCrossCorrelationLoss(in_channels=3, reduction="unknown")(input, target) with self.assertRaisesRegex(ValueError, ""): From f76e3f04bcc4f72eb3fb989261e671bcb5252a43 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 00:49:19 +0000 Subject: [PATCH 07/11] 1412 update Union import route Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index bd9a06e187..093aa2546b 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,13 +8,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss from monai.networks.layers import gaussian_1d, separable_filtering -from monai.utils import LossReduction, Union +from monai.utils import LossReduction class LocalNormalizedCrossCorrelationLoss(_Loss): From 50940c2982da3a97f33abb09b21b140641a47456 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sat, 9 Jan 2021 03:02:08 +0000 Subject: [PATCH 08/11] 1412 fix negative bug and add smooth_nr Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 093aa2546b..9a9a0291dd 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -38,6 +38,7 @@ def __init__( kernel_size: int = 9, kernel_type: str = "rectangular", reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-7, smooth_dr: float = 1e-7, ) -> None: """ @@ -52,6 +53,7 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid zero. smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) @@ -77,6 +79,7 @@ def __init__( ) self.kernel_vol = torch.sum(self.kernel) ** self.ndim + self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) def make_rectangular_kernel(self): @@ -140,11 +143,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: cross = tp_sum - p_avg * t_sum t_var = t2_sum - t_avg * t_sum # std[t] ** 2 p_var = p2_sum - p_avg * p_sum # std[p] ** 2 - ncc: torch.Tensor = (cross * cross + self.smooth_dr) / (t_var * p_var + self.smooth_dr) + ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr) # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return -torch.sum(ncc).neg() # sum over the batch and channel ndims + return torch.sum(ncc).neg() # sum over the batch and channel ndims if self.reduction == LossReduction.NONE.value: return ncc.neg() if self.reduction == LossReduction.MEAN.value: From cab9b0beb4f514c68f62e6b950d428963d209ab7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 10 Jan 2021 15:44:48 +0000 Subject: [PATCH 09/11] remove temp. code Signed-off-by: Wenqi Li --- docs/source/losses.rst | 2 +- monai/networks/layers/simplelayers.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index fc86c9cd05..462c303e65 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -67,6 +67,6 @@ Registration Losses :members: `LocalNormalizedCrossCorrelationLoss` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss :members: diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 55189e3e9b..ba60f4eca4 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -365,9 +365,3 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) - - -if __name__ == "__main__": - input = torch.ones((1, 3, 3, 3)) - kernels = [torch.ones(1, 3)] * 2 - print(separable_filtering(input, kernels)) From 6db25289fd07ae4cfc3b83a042f127acd94e0042 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sun, 10 Jan 2021 17:07:45 +0000 Subject: [PATCH 10/11] 1412 reformat code Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 64 ++++++++++--------- ...local_normalized_cross_correlation_loss.py | 12 ++++ 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 9a9a0291dd..65b53309f5 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Tuple, Union import torch from torch.nn import functional as F @@ -18,6 +18,34 @@ from monai.utils import LossReduction +def make_rectangular_kernel(kernel_size: int) -> torch.Tensor: + return torch.ones(kernel_size) + + +def make_triangular_kernel(kernel_size: int) -> torch.Tensor: + fsize = (kernel_size + 1) // 2 + if fsize % 2 == 0: + fsize -= 1 + f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) + padding = (kernel_size - fsize) // 2 + fsize // 2 + return F.conv1d(f, f, padding=padding).reshape(-1) + + +def make_gaussian_kernel(kernel_size: int) -> torch.Tensor: + sigma = torch.tensor(kernel_size / 3.0) + kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx="sampled", normalize=False) * ( + 2.5066282 * sigma + ) + return kernel[:kernel_size] + + +kernel_dict = { + "rectangular": make_rectangular_kernel, + "triangular": make_triangular_kernel, + "gaussian": make_gaussian_kernel, +} + + class LocalNormalizedCrossCorrelationLoss(_Loss): """ Local squared zero-normalized cross-correlation. @@ -53,7 +81,7 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. - smooth_nr: a small constant added to the numerator to avoid zero. + smooth_nr: a small constant added to the numerator to avoid nan. smooth_dr: a small constant added to the denominator to avoid nan. """ super(LocalNormalizedCrossCorrelationLoss, self).__init__(reduction=LossReduction(reduction).value) @@ -67,39 +95,15 @@ def __init__( if self.kernel_size % 2 == 0: raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") - if kernel_type == "rectangular": - self.kernel = self.make_rectangular_kernel() - elif kernel_type == "triangular": - self.kernel = self.make_triangular_kernel() - elif kernel_type == "gaussian": - self.kernel = self.make_gaussian_kernel() - else: + if kernel_type not in kernel_dict.keys(): raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' ) - + self.kernel = kernel_dict[kernel_type](self.kernel_size) self.kernel_vol = torch.sum(self.kernel) ** self.ndim self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def make_rectangular_kernel(self): - return torch.ones(self.kernel_size) - - def make_triangular_kernel(self): - fsize = (self.kernel_size + 1) // 2 - if fsize % 2 == 0: - fsize -= 1 - f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize) - padding = (self.kernel_size - fsize) // 2 + fsize // 2 - return F.conv1d(f, f, padding=padding).reshape(-1) - - def make_gaussian_kernel(self): - sigma = torch.tensor(self.kernel_size / 3.0) - kernel = gaussian_1d(sigma=sigma, truncated=self.kernel_size // 2, approx="sampled", normalize=False) * ( - 2.5066282 * sigma - ) - return kernel[: self.kernel_size] - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: @@ -147,9 +151,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # shape = (batch, 1, D, H, W) if self.reduction == LossReduction.SUM.value: - return torch.sum(ncc).neg() # sum over the batch and channel ndims + return torch.sum(ncc).neg() # sum over the batch and spatial ndims if self.reduction == LossReduction.NONE.value: return ncc.neg() if self.reduction == LossReduction.MEAN.value: - return torch.mean(ncc).neg() # average over the batch and channel ndims + return torch.mean(ncc).neg() # average over the batch and spatial ndims raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index fe455b0597..cb2f446dfc 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -74,6 +74,14 @@ }, -1.0, ], + [ + {"in_channels": 3, "ndim": 1, "kernel_size": 3, "kernel_type": "gaussian", "reduction": "sum"}, + { + "input": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), + "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(2, 3, 3), + }, + -6.0, + ], [ {"in_channels": 3, "ndim": 3, "kernel_size": 3, "kernel_type": "rectangular"}, { @@ -122,6 +130,10 @@ def test_ill_shape(self): def test_ill_opts(self): input = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type="unknown")(input, target) + with self.assertRaisesRegex(ValueError, ""): + LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_type=None)(input, target) with self.assertRaisesRegex(ValueError, ""): LocalNormalizedCrossCorrelationLoss(in_channels=3, kernel_size=4)(input, target) with self.assertRaisesRegex(ValueError, ""): From af4cab559fc56ed47139f6546b443273fbafc605 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sun, 10 Jan 2021 17:35:07 +0000 Subject: [PATCH 11/11] 1412 remove redundant import Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 65b53309f5..d42303e154 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -8,7 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Union import torch from torch.nn import functional as F