diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 839e6826b8..34709daa42 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -14,16 +14,17 @@
 """
 
 import warnings
-from typing import Callable, Iterable, Optional, Sequence, Union
+from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.networks import one_hot
-from monai.networks.layers import GaussianFilter, apply_filter
+from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import Transform
 from monai.transforms.utils import (
@@ -821,13 +822,18 @@ def __call__(self, data):
 
 
 class SobelGradients(Transform):
-    """Calculate Sobel horizontal and vertical gradients
+    """Calculate Sobel gradients of a grayscale image with the shape of (CxH[xWxDx...]).
 
     Args:
         kernel_size: the size of the Sobel kernel. Defaults to 3.
-        padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
+        spatial_axes: the axes that define the directions along which the gradient is calculated. The gradient is
+            calculated along each of the provided axes. By default the gradient is calculated along all spatial axes.
+        normalize_kernels: whether to normalize the Sobel kernels so that they provide proper gradients. Defaults to True.
+        normalize_gradients: whether to normalize the output gradients to the range [0, 1]. Defaults to False.
+        padding_mode: the padding mode of the image when convolving with the Sobel kernels. Defaults to `"reflect"`.
+            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+            See ``torch.nn.Conv1d()`` for more information.
         dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
-        device: the device to create the kernel on. Defaults to `"cpu"`.
 
     """
@@ -836,36 +842,90 @@ class SobelGradients(Transform):
     def __init__(
         self,
         kernel_size: int = 3,
-        padding: Union[int, str] = "same",
+        spatial_axes: Optional[Union[Sequence[int], int]] = None,
+        normalize_kernels: bool = True,
+        normalize_gradients: bool = False,
+        padding_mode: str = "reflect",
         dtype: torch.dtype = torch.float32,
-        device: Union[torch.device, int, str] = "cpu",
     ) -> None:
         super().__init__()
-        self.kernel: torch.Tensor = self._get_kernel(kernel_size, dtype, device)
-        self.padding = padding
-
-    def _get_kernel(self, size, dtype, device) -> torch.Tensor:
+        self.padding = padding_mode
+        self.spatial_axes = spatial_axes
+        self.normalize_kernels = normalize_kernels
+        self.normalize_gradients = normalize_gradients
+        self.kernel_diff, self.kernel_smooth = self._get_kernel(kernel_size, dtype)
+
+    def _get_kernel(self, size, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+        if size < 3:
+            raise ValueError(f"Sobel kernel size should be at least three. {size} was given.")
         if size % 2 == 0:
             raise ValueError(f"Sobel kernel size should be an odd number. {size} was given.")
-        if not dtype.is_floating_point:
-            raise ValueError(f"`dtype` for Sobel kernel should be floating point. {dtype} was given.")
-
-        numerator: torch.Tensor = torch.arange(
-            -size // 2 + 1, size // 2 + 1, dtype=dtype, device=device, requires_grad=False
-        ).expand(size, size)
-        denominator = numerator * numerator
-        denominator = denominator + denominator.T
-        denominator[:, size // 2] = 1.0  # to avoid division by zero
-        kernel = numerator / denominator
-        return kernel
+
+        kernel_diff = torch.tensor([[[-1, 0, 1]]], dtype=dtype)
+        kernel_smooth = torch.tensor([[[1, 2, 1]]], dtype=dtype)
+        kernel_expansion = torch.tensor([[[1, 2, 1]]], dtype=dtype)
+
+        if self.normalize_kernels:
+            if not dtype.is_floating_point:
+                raise ValueError(
+                    f"`dtype` for Sobel kernel should be floating point when `normalize_kernels==True`. {dtype} was given."
+                )
+            kernel_diff /= 2.0
+            kernel_smooth /= 4.0
+            kernel_expansion /= 4.0
+
+        # Expand the kernels to sizes larger than 3
+        expand = (size - 3) // 2
+        for _ in range(expand):
+            kernel_diff = F.conv1d(kernel_diff, kernel_expansion, padding=2)
+            kernel_smooth = F.conv1d(kernel_smooth, kernel_expansion, padding=2)
+
+        return kernel_diff.squeeze(), kernel_smooth.squeeze()
 
     def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
         image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
-        kernel_v = self.kernel.to(image_tensor.device)
-        kernel_h = kernel_v.T
-        image_tensor = image_tensor.unsqueeze(0)  # adds a batch dim
-        grad_v = apply_filter(image_tensor, kernel_v, padding=self.padding)
-        grad_h = apply_filter(image_tensor, kernel_h, padding=self.padding)
-        grad = torch.cat([grad_h, grad_v], dim=1)
-        grad, *_ = convert_to_dst_type(grad.squeeze(0), image_tensor)
-        return grad
+
+        # Check/set spatial axes
+        n_spatial_dims = image_tensor.ndim - 1  # excluding the channel dimension
+        valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))
+
+        # Check that the requested gradient axes are valid
+        if self.spatial_axes is None:
+            spatial_axes = list(range(n_spatial_dims))
+        else:
+            invalid_axis = set(ensure_tuple(self.spatial_axes)) - set(valid_spatial_axes)
+            if invalid_axis:
+                raise ValueError(
+                    f"The provided axes to calculate the gradient are not valid: {invalid_axis}. "
+                    f"The image has {n_spatial_dims} spatial dimensions so it should be: {valid_spatial_axes}."
+                )
+            spatial_axes = [ax % n_spatial_dims if ax < 0 else ax for ax in ensure_tuple(self.spatial_axes)]
+
+        # Add batch dimension for separable_filtering
+        image_tensor = image_tensor.unsqueeze(0)
+
+        # Get the Sobel kernels
+        kernel_diff = self.kernel_diff.to(image_tensor.device)
+        kernel_smooth = self.kernel_smooth.to(image_tensor.device)
+
+        # Calculate gradient
+        grad_list = []
+        for ax in spatial_axes:
+            kernels = [kernel_smooth] * n_spatial_dims
+            kernels[ax - 1] = kernel_diff
+            grad = separable_filtering(image_tensor, kernels, mode=self.padding)
+            if self.normalize_gradients:
+                grad_min = grad.min()
+                if grad_min != grad.max():
+                    grad -= grad_min
+                grad_max = grad.max()
+                if grad_max > 0:
+                    grad /= grad_max
+            grad_list.append(grad)
+
+        grads = torch.cat(grad_list, dim=1)
+
+        # Remove the batch dimension and convert the gradients to the same type as the input image
+        grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]
+
+        return grads
+ """Calculate Sobel horizontal and vertical gradients of a grayscale image. Args: keys: keys of the corresponding items to model output. kernel_size: the size of the Sobel kernel. Defaults to 3. - padding: the padding for the convolution to apply the kernel. Defaults to `"same"`. + spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient + along each of the provide axis. By default it calculate the gradient for all spatial axes. + normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True. + normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False. + padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`. + Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + See ``torch.nn.Conv1d()`` for more information. dtype: kernel data type (torch.dtype). Defaults to `torch.float32`. - device: the device to create the kernel on. Defaults to `"cpu"`. new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of key intact. By default not prefix is set and the corresponding array to the key will be replaced. allow_missing_keys: don't raise exception if key is missing. @@ -814,15 +819,26 @@ def __init__( self, keys: KeysCollection, kernel_size: int = 3, - padding: Union[int, str] = "same", + spatial_axes: Optional[Union[Sequence[int], int]] = None, + normalize_kernels: bool = True, + normalize_gradients: bool = False, + padding_mode: str = "reflect", dtype: torch.dtype = torch.float32, - device: Union[torch.device, int, str] = "cpu", new_key_prefix: Optional[str] = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = SobelGradients(kernel_size=kernel_size, padding=padding, dtype=dtype, device=device) + self.transform = SobelGradients( + kernel_size=kernel_size, + spatial_axes=spatial_axes, + normalize_kernels=normalize_kernels, + normalize_gradients=normalize_gradients, + padding_mode=padding_mode, + dtype=dtype, + ) self.new_key_prefix = new_key_prefix + self.kernel_diff = self.transform.kernel_diff + self.kernel_smooth = self.transform.kernel_smooth def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) diff --git a/tests/test_hovernet_loss.py b/tests/test_hovernet_loss.py index 0bdabdef70..c2c888804a 100644 --- a/tests/test_hovernet_loss.py +++ b/tests/test_hovernet_loss.py @@ -141,17 +141,17 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w TEST_CASE_3 = [ # batch size of 2, 3 classes with minor rotation of nuclear prediction {"prediction": inputs_test[3].inputs, "target": inputs_test[3].targets}, - 6.5777, + 3.6169, ] TEST_CASE_4 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction {"prediction": inputs_test[4].inputs, "target": inputs_test[4].targets}, - 8.5143, + 4.5079, ] TEST_CASE_5 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction {"prediction": inputs_test[5].inputs, "target": inputs_test[5].targets}, - 10.1705, + 5.4663, ] CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5] diff --git a/tests/test_sobel_gradient.py b/tests/test_sobel_gradient.py index 17f6bffdfb..d8366f576a 100644 --- a/tests/test_sobel_gradient.py +++ b/tests/test_sobel_gradient.py @@ -19,73 +19,169 @@ IMAGE = torch.zeros(1, 16, 16, 
dtype=torch.float32) IMAGE[0, 8, :] = 1 + +# Output with reflect padding OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32) -OUTPUT_3x3[0, 7, :] = 2.0 -OUTPUT_3x3[0, 9, :] = -2.0 -OUTPUT_3x3[0, 7, 0] = OUTPUT_3x3[0, 7, -1] = 1.5 -OUTPUT_3x3[0, 9, 0] = OUTPUT_3x3[0, 9, -1] = -1.5 -OUTPUT_3x3[1, 7, 0] = OUTPUT_3x3[1, 9, 0] = 0.5 -OUTPUT_3x3[1, 8, 0] = 1.0 -OUTPUT_3x3[1, 8, -1] = -1.0 -OUTPUT_3x3[1, 7, -1] = OUTPUT_3x3[1, 9, -1] = -0.5 +OUTPUT_3x3[1, 7, :] = 0.5 +OUTPUT_3x3[1, 9, :] = -0.5 + +# Output with zero padding +OUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone() +OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, 0] = 0.125 +OUTPUT_3x3_ZERO_PAD[0, 8, 0] = 0.25 +OUTPUT_3x3_ZERO_PAD[0, 7, -1] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -0.125 +OUTPUT_3x3_ZERO_PAD[0, 8, -1] = -0.25 +OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 7, -1] = 3.0 / 8.0 +OUTPUT_3x3_ZERO_PAD[1, 9, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -3.0 / 8.0 TEST_CASE_0 = [IMAGE, {"kernel_size": 3, "dtype": torch.float32}, OUTPUT_3x3] TEST_CASE_1 = [IMAGE, {"kernel_size": 3, "dtype": torch.float64}, OUTPUT_3x3] +TEST_CASE_2 = [IMAGE, {"kernel_size": 3, "spatial_axes": 0, "dtype": torch.float64}, OUTPUT_3x3[0:1]] +TEST_CASE_3 = [IMAGE, {"kernel_size": 3, "spatial_axes": 1, "dtype": torch.float64}, OUTPUT_3x3[1:2]] +TEST_CASE_4 = [IMAGE, {"kernel_size": 3, "spatial_axes": [1], "dtype": torch.float64}, OUTPUT_3x3[1:2]] +TEST_CASE_5 = [ + IMAGE, + {"kernel_size": 3, "spatial_axes": [0, 1], "normalize_kernels": True, "dtype": torch.float64}, + OUTPUT_3x3, +] +TEST_CASE_6 = [ + IMAGE, + {"kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "reflect", "dtype": torch.float64}, + OUTPUT_3x3, +] +TEST_CASE_7 = [ + IMAGE, + {"kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "zeros", "dtype": torch.float64}, + OUTPUT_3x3_ZERO_PAD, +] +TEST_CASE_8 = [ # Non-normalized kernels + IMAGE, + {"kernel_size": 3, "normalize_kernels": False, "dtype": torch.float32}, + OUTPUT_3x3 * 8.0, +] +TEST_CASE_9 = [ # Normalized gradients and normalized kernels + IMAGE, + { + "kernel_size": 3, + "normalize_kernels": True, + "normalize_gradients": True, + "spatial_axes": (0, 1), + "dtype": torch.float64, + }, + torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]), +] +TEST_CASE_10 = [ # Normalized gradients but non-normalized kernels + IMAGE, + { + "kernel_size": 3, + "normalize_kernels": False, + "normalize_gradients": True, + "spatial_axes": (0, 1), + "dtype": torch.float64, + }, + torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]), +] + TEST_CASE_KERNEL_0 = [ {"kernel_size": 3, "dtype": torch.float64}, - torch.tensor([[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5]], dtype=torch.float64), + (torch.tensor([-0.5, 0.0, 0.5], dtype=torch.float64), torch.tensor([0.25, 0.5, 0.25], dtype=torch.float64)), ] TEST_CASE_KERNEL_1 = [ {"kernel_size": 5, "dtype": torch.float64}, - torch.tensor( - [ - [-0.25, -0.2, 0.0, 0.2, 0.25], - [-0.4, -0.5, 0.0, 0.5, 0.4], - [-0.5, -1.0, 0.0, 1.0, 0.5], - [-0.4, -0.5, 0.0, 0.5, 0.4], - [-0.25, -0.2, 0.0, 0.2, 0.25], - ], - dtype=torch.float64, + ( + torch.tensor([-0.1250, -0.2500, 0.0000, 0.2500, 0.1250], dtype=torch.float64), + torch.tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625], dtype=torch.float64), ), ] TEST_CASE_KERNEL_2 = [ {"kernel_size": 7, "dtype": torch.float64}, - torch.tensor( - [ - [-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0], - [-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0], - [-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 
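A corresponding sketch for the dictionary version, showing the `new_key_prefix` behaviour exercised by the tests further below (again, only the input dictionary is made up):

```python
import torch

from monai.transforms.post.dictionary import SobelGradientsd

# A made-up data dictionary with a single-channel 2D image
data = {"image": torch.zeros(1, 16, 16, dtype=torch.float32)}
data["image"][0, 8, :] = 1.0

# Write the gradients under a new key and keep the original "image" entry intact
transform = SobelGradientsd(keys="image", kernel_size=3, new_key_prefix="sobel_")
out = transform(data)

print(sorted(out.keys()))        # ['image', 'sobel_image']
print(out["sobel_image"].shape)  # (2, 16, 16)
```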
diff --git a/tests/test_hovernet_loss.py b/tests/test_hovernet_loss.py
index 0bdabdef70..c2c888804a 100644
--- a/tests/test_hovernet_loss.py
+++ b/tests/test_hovernet_loss.py
@@ -141,17 +141,17 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w
 TEST_CASE_3 = [  # batch size of 2, 3 classes with minor rotation of nuclear prediction
     {"prediction": inputs_test[3].inputs, "target": inputs_test[3].targets},
-    6.5777,
+    3.6169,
 ]
 
 TEST_CASE_4 = [  # batch size of 2, 3 classes with medium rotation of nuclear prediction
     {"prediction": inputs_test[4].inputs, "target": inputs_test[4].targets},
-    8.5143,
+    4.5079,
 ]
 
 TEST_CASE_5 = [  # batch size of 2, 3 classes with medium rotation of nuclear prediction
     {"prediction": inputs_test[5].inputs, "target": inputs_test[5].targets},
-    10.1705,
+    5.4663,
 ]
 
 CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]
diff --git a/tests/test_sobel_gradient.py b/tests/test_sobel_gradient.py
index 17f6bffdfb..d8366f576a 100644
--- a/tests/test_sobel_gradient.py
+++ b/tests/test_sobel_gradient.py
@@ -19,73 +19,169 @@
 
 IMAGE = torch.zeros(1, 16, 16, dtype=torch.float32)
 IMAGE[0, 8, :] = 1
+
+# Output with reflect padding
 OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
-OUTPUT_3x3[0, 7, :] = 2.0
-OUTPUT_3x3[0, 9, :] = -2.0
-OUTPUT_3x3[0, 7, 0] = OUTPUT_3x3[0, 7, -1] = 1.5
-OUTPUT_3x3[0, 9, 0] = OUTPUT_3x3[0, 9, -1] = -1.5
-OUTPUT_3x3[1, 7, 0] = OUTPUT_3x3[1, 9, 0] = 0.5
-OUTPUT_3x3[1, 8, 0] = 1.0
-OUTPUT_3x3[1, 8, -1] = -1.0
-OUTPUT_3x3[1, 7, -1] = OUTPUT_3x3[1, 9, -1] = -0.5
+OUTPUT_3x3[1, 7, :] = 0.5
+OUTPUT_3x3[1, 9, :] = -0.5
+
+# Output with zero padding
+OUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone()
+OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, 0] = 0.125
+OUTPUT_3x3_ZERO_PAD[0, 8, 0] = 0.25
+OUTPUT_3x3_ZERO_PAD[0, 7, -1] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -0.125
+OUTPUT_3x3_ZERO_PAD[0, 8, -1] = -0.25
+OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 7, -1] = 3.0 / 8.0
+OUTPUT_3x3_ZERO_PAD[1, 9, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -3.0 / 8.0
 
 TEST_CASE_0 = [IMAGE, {"kernel_size": 3, "dtype": torch.float32}, OUTPUT_3x3]
 TEST_CASE_1 = [IMAGE, {"kernel_size": 3, "dtype": torch.float64}, OUTPUT_3x3]
+TEST_CASE_2 = [IMAGE, {"kernel_size": 3, "spatial_axes": 0, "dtype": torch.float64}, OUTPUT_3x3[0:1]]
+TEST_CASE_3 = [IMAGE, {"kernel_size": 3, "spatial_axes": 1, "dtype": torch.float64}, OUTPUT_3x3[1:2]]
+TEST_CASE_4 = [IMAGE, {"kernel_size": 3, "spatial_axes": [1], "dtype": torch.float64}, OUTPUT_3x3[1:2]]
+TEST_CASE_5 = [
+    IMAGE,
+    {"kernel_size": 3, "spatial_axes": [0, 1], "normalize_kernels": True, "dtype": torch.float64},
+    OUTPUT_3x3,
+]
+TEST_CASE_6 = [
+    IMAGE,
+    {"kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "reflect", "dtype": torch.float64},
+    OUTPUT_3x3,
+]
+TEST_CASE_7 = [
+    IMAGE,
+    {"kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "zeros", "dtype": torch.float64},
+    OUTPUT_3x3_ZERO_PAD,
+]
+TEST_CASE_8 = [  # Non-normalized kernels
+    IMAGE,
+    {"kernel_size": 3, "normalize_kernels": False, "dtype": torch.float32},
+    OUTPUT_3x3 * 8.0,
+]
+TEST_CASE_9 = [  # Normalized gradients and normalized kernels
+    IMAGE,
+    {
+        "kernel_size": 3,
+        "normalize_kernels": True,
+        "normalize_gradients": True,
+        "spatial_axes": (0, 1),
+        "dtype": torch.float64,
+    },
+    torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]),
+]
+TEST_CASE_10 = [  # Normalized gradients but non-normalized kernels
+    IMAGE,
+    {
+        "kernel_size": 3,
+        "normalize_kernels": False,
+        "normalize_gradients": True,
+        "spatial_axes": (0, 1),
+        "dtype": torch.float64,
+    },
+    torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]),
+]
+
 TEST_CASE_KERNEL_0 = [
     {"kernel_size": 3, "dtype": torch.float64},
-    torch.tensor([[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5]], dtype=torch.float64),
+    (torch.tensor([-0.5, 0.0, 0.5], dtype=torch.float64), torch.tensor([0.25, 0.5, 0.25], dtype=torch.float64)),
 ]
 TEST_CASE_KERNEL_1 = [
     {"kernel_size": 5, "dtype": torch.float64},
-    torch.tensor(
-        [
-            [-0.25, -0.2, 0.0, 0.2, 0.25],
-            [-0.4, -0.5, 0.0, 0.5, 0.4],
-            [-0.5, -1.0, 0.0, 1.0, 0.5],
-            [-0.4, -0.5, 0.0, 0.5, 0.4],
-            [-0.25, -0.2, 0.0, 0.2, 0.25],
-        ],
-        dtype=torch.float64,
+    (
+        torch.tensor([-0.1250, -0.2500, 0.0000, 0.2500, 0.1250], dtype=torch.float64),
+        torch.tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625], dtype=torch.float64),
     ),
 ]
 TEST_CASE_KERNEL_2 = [
     {"kernel_size": 7, "dtype": torch.float64},
-    torch.tensor(
-        [
-            [-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
-            [-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
-            [-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
-            [-3.0 / 9.0, -2.0 / 4.0, -1.0 / 1.0, 0.0, 1.0 / 1.0, 2.0 / 4.0, 3.0 / 9.0],
-            [-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
-            [-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
-            [-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
-        ],
-        dtype=torch.float64,
+    (
+        torch.tensor([-0.03125, -0.125, -0.15625, 0.0, 0.15625, 0.125, 0.03125], dtype=torch.float64),
+        torch.tensor([0.015625, 0.09375, 0.234375, 0.3125, 0.234375, 0.09375, 0.015625], dtype=torch.float64),
+    ),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_0 = [
+    {"kernel_size": 3, "normalize_kernels": False, "dtype": torch.float64},
+    (torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64), torch.tensor([1.0, 2.0, 1.0], dtype=torch.float64)),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_1 = [
+    {"kernel_size": 5, "normalize_kernels": False, "dtype": torch.float64},
+    (
+        torch.tensor([-1.0, -2.0, 0.0, 2.0, 1.0], dtype=torch.float64),
+        torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64),
+    ),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_2 = [
+    {"kernel_size": 7, "normalize_kernels": False, "dtype": torch.float64},
+    (
+        torch.tensor([-1.0, -4.0, -5.0, 0.0, 5.0, 4.0, 1.0], dtype=torch.float64),
+        torch.tensor([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0], dtype=torch.float64),
     ),
 ]
 
-TEST_CASE_ERROR_0 = [{"kernel_size": 2, "dtype": torch.float32}]
+TEST_CASE_ERROR_0 = [IMAGE, {"kernel_size": 1}]  # kernel size less than 3
+TEST_CASE_ERROR_1 = [IMAGE, {"kernel_size": 4}]  # even kernel size
+TEST_CASE_ERROR_2 = [IMAGE, {"spatial_axes": "horizontal"}]  # wrong type for spatial_axes
+TEST_CASE_ERROR_3 = [IMAGE, {"spatial_axes": 3}]  # invalid spatial axis
+TEST_CASE_ERROR_4 = [IMAGE, {"spatial_axes": [3]}]  # invalid spatial axis in a list
+TEST_CASE_ERROR_5 = [IMAGE, {"spatial_axes": [0, 4]}]  # valid and invalid spatial axes in a list
 
 
 class SobelGradientTests(unittest.TestCase):
 
     backend = None
 
-    @parameterized.expand([TEST_CASE_0])
+    @parameterized.expand(
+        [
+            TEST_CASE_0,
+            TEST_CASE_1,
+            TEST_CASE_2,
+            TEST_CASE_3,
+            TEST_CASE_4,
+            TEST_CASE_5,
+            TEST_CASE_6,
+            TEST_CASE_7,
+            TEST_CASE_8,
+            TEST_CASE_9,
+            TEST_CASE_10,
+        ]
+    )
     def test_sobel_gradients(self, image, arguments, expected_grad):
         sobel = SobelGradients(**arguments)
         grad = sobel(image)
         assert_allclose(grad, expected_grad)
 
-    @parameterized.expand([TEST_CASE_KERNEL_0, TEST_CASE_KERNEL_1, TEST_CASE_KERNEL_2])
-    def test_sobel_kernels(self, arguments, expected_kernel):
+    @parameterized.expand(
+        [
+            TEST_CASE_KERNEL_0,
+            TEST_CASE_KERNEL_1,
+            TEST_CASE_KERNEL_2,
+            TEST_CASE_KERNEL_NON_NORMALIZED_0,
+            TEST_CASE_KERNEL_NON_NORMALIZED_1,
+            TEST_CASE_KERNEL_NON_NORMALIZED_2,
+        ]
+    )
+    def test_sobel_kernels(self, arguments, expected_kernels):
         sobel = SobelGradients(**arguments)
-        self.assertTrue(sobel.kernel.dtype == expected_kernel.dtype)
-        assert_allclose(sobel.kernel, expected_kernel)
+        self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype)
+        self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype)
+        assert_allclose(sobel.kernel_diff, expected_kernels[0])
+        assert_allclose(sobel.kernel_smooth, expected_kernels[1])
 
-    @parameterized.expand([TEST_CASE_ERROR_0])
-    def test_sobel_gradients_error(self, arguments):
+    @parameterized.expand(
+        [
+            TEST_CASE_ERROR_0,
+            TEST_CASE_ERROR_1,
+            TEST_CASE_ERROR_2,
+            TEST_CASE_ERROR_3,
+            TEST_CASE_ERROR_4,
+            TEST_CASE_ERROR_5,
+        ]
+    )
+    def test_sobel_gradients_error(self, image, arguments):
         with self.assertRaises(ValueError):
-            SobelGradients(**arguments)
+            sobel = SobelGradients(**arguments)
+            sobel(image)
 
 
 if __name__ == "__main__":
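The rewritten `OUTPUT_3x3` expectations can be checked by hand: with the normalized 3x3 kernels, cross-correlating the test image (a row of ones at index 8) with the difference kernel `[-0.5, 0, 0.5]` across rows gives +0.5 one row above the line and -0.5 one row below, and smoothing with `[0.25, 0.5, 0.25]` along the columns leaves those constant rows unchanged. A small MONAI-independent sketch (plain `torch.nn.functional` with reflect padding, which should be equivalent to the separable filtering used by the transform for this image) reproduces the nonzero channel:

```python
import torch
import torch.nn.functional as F

# Test image from the diff: a single row of ones at index 8
image = torch.zeros(1, 1, 16, 16)
image[0, 0, 8, :] = 1.0

# Normalized separable 3x3 Sobel kernel: difference across rows, smoothing across columns
kernel = torch.tensor([[-0.5], [0.0], [0.5]]) @ torch.tensor([[0.25, 0.5, 0.25]])

# Reflect-pad by one pixel and cross-correlate with the 3x3 kernel
grad = F.conv2d(F.pad(image, (1, 1, 1, 1), mode="reflect"), kernel[None, None])

print(grad[0, 0, 7, :4])  # tensor([0.5, 0.5, 0.5, 0.5])     -> OUTPUT_3x3[1, 7, :]
print(grad[0, 0, 9, :4])  # tensor([-0.5, -0.5, -0.5, -0.5]) -> OUTPUT_3x3[1, 9, :]
```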
3, "dtype": torch.float64}, - torch.tensor([[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5]], dtype=torch.float64), + (torch.tensor([-0.5, 0.0, 0.5], dtype=torch.float64), torch.tensor([0.25, 0.5, 0.25], dtype=torch.float64)), ] TEST_CASE_KERNEL_1 = [ {"keys": "image", "kernel_size": 5, "dtype": torch.float64}, - torch.tensor( - [ - [-0.25, -0.2, 0.0, 0.2, 0.25], - [-0.4, -0.5, 0.0, 0.5, 0.4], - [-0.5, -1.0, 0.0, 1.0, 0.5], - [-0.4, -0.5, 0.0, 0.5, 0.4], - [-0.25, -0.2, 0.0, 0.2, 0.25], - ], - dtype=torch.float64, + ( + torch.tensor([-0.1250, -0.2500, 0.0000, 0.2500, 0.1250], dtype=torch.float64), + torch.tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625], dtype=torch.float64), ), ] TEST_CASE_KERNEL_2 = [ {"keys": "image", "kernel_size": 7, "dtype": torch.float64}, - torch.tensor( - [ - [-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0], - [-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0], - [-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0], - [-3.0 / 9.0, -2.0 / 4.0, -1.0 / 1.0, 0.0, 1.0 / 1.0, 2.0 / 4.0, 3.0 / 9.0], - [-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0], - [-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0], - [-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0], - ], - dtype=torch.float64, + ( + torch.tensor([-0.03125, -0.125, -0.15625, 0.0, 0.15625, 0.125, 0.03125], dtype=torch.float64), + torch.tensor([0.015625, 0.09375, 0.234375, 0.3125, 0.234375, 0.09375, 0.015625], dtype=torch.float64), + ), +] +TEST_CASE_KERNEL_NON_NORMALIZED_0 = [ + {"keys": "image", "kernel_size": 3, "normalize_kernels": False, "dtype": torch.float64}, + (torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64), torch.tensor([1.0, 2.0, 1.0], dtype=torch.float64)), +] +TEST_CASE_KERNEL_NON_NORMALIZED_1 = [ + {"keys": "image", "kernel_size": 5, "normalize_kernels": False, "dtype": torch.float64}, + ( + torch.tensor([-1.0, -2.0, 0.0, 2.0, 1.0], dtype=torch.float64), + torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64), + ), +] +TEST_CASE_KERNEL_NON_NORMALIZED_2 = [ + {"keys": "image", "kernel_size": 7, "normalize_kernels": False, "dtype": torch.float64}, + ( + torch.tensor([-1.0, -4.0, -5.0, 0.0, 5.0, 4.0, 1.0], dtype=torch.float64), + torch.tensor([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0], dtype=torch.float64), ), ] -TEST_CASE_ERROR_0 = [{"keys": "image", "kernel_size": 2, "dtype": torch.float32}] +TEST_CASE_ERROR_0 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 1}] # kernel size less than 3 +TEST_CASE_ERROR_1 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 4}] # even kernel size +TEST_CASE_ERROR_2 = [{"image": IMAGE}, {"keys": "image", "spatial_axes": "horizontal"}] # wrong type direction +TEST_CASE_ERROR_3 = [{"image": IMAGE}, {"keys": "image", "spatial_axes": 3}] # wrong direction +TEST_CASE_ERROR_4 = [{"image": IMAGE}, {"keys": "image", "spatial_axes": [3]}] # wrong direction in a list +TEST_CASE_ERROR_5 = [ + {"image": IMAGE}, + {"keys": "image", "spatial_axes": [0, 4]}, +] # correct and wrong direction in a list class SobelGradientTests(unittest.TestCase): backend = None - @parameterized.expand([TEST_CASE_0]) + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + ] + ) def test_sobel_gradients(self, image_dict, arguments, expected_grad): sobel = 
SobelGradientsd(**arguments) grad = sobel(image_dict) key = "image" if "new_key_prefix" not in arguments else arguments["new_key_prefix"] + arguments["keys"] assert_allclose(grad[key], expected_grad[key]) - @parameterized.expand([TEST_CASE_KERNEL_0, TEST_CASE_KERNEL_1, TEST_CASE_KERNEL_2]) - def test_sobel_kernels(self, arguments, expected_kernel): + @parameterized.expand( + [ + TEST_CASE_KERNEL_0, + TEST_CASE_KERNEL_1, + TEST_CASE_KERNEL_2, + TEST_CASE_KERNEL_NON_NORMALIZED_0, + TEST_CASE_KERNEL_NON_NORMALIZED_1, + TEST_CASE_KERNEL_NON_NORMALIZED_2, + ] + ) + def test_sobel_kernels(self, arguments, expected_kernels): sobel = SobelGradientsd(**arguments) - self.assertTrue(sobel.transform.kernel.dtype == expected_kernel.dtype) - assert_allclose(sobel.transform.kernel, expected_kernel) + self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype) + self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype) + assert_allclose(sobel.kernel_diff, expected_kernels[0]) + assert_allclose(sobel.kernel_smooth, expected_kernels[1]) - @parameterized.expand([TEST_CASE_ERROR_0]) - def test_sobel_gradients_error(self, arguments): + @parameterized.expand( + [ + TEST_CASE_ERROR_0, + TEST_CASE_ERROR_1, + TEST_CASE_ERROR_2, + TEST_CASE_ERROR_3, + TEST_CASE_ERROR_4, + TEST_CASE_ERROR_5, + ] + ) + def test_sobel_gradients_error(self, image_dict, arguments): with self.assertRaises(ValueError): - SobelGradientsd(**arguments) + sobel = SobelGradientsd(**arguments) + sobel(image_dict) if __name__ == "__main__":
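Likewise, the expected kernels in the `*_NON_NORMALIZED_*` cases follow from repeatedly convolving the 3-tap base kernels with the `[1, 2, 1]` smoothing kernel, mirroring the expansion loop in `_get_kernel`. A quick numpy check (not part of the patch; it relies on `[1, 2, 1]` being symmetric so that convolution matches the cross-correlation performed by `F.conv1d`):

```python
import numpy as np

# Non-normalized base kernels from the diff
diff3, smooth3, expansion = [-1, 0, 1], [1, 2, 1], [1, 2, 1]

# One expansion step per extra two taps of kernel size
diff5, smooth5 = np.convolve(diff3, expansion), np.convolve(smooth3, expansion)
diff7, smooth7 = np.convolve(diff5, expansion), np.convolve(smooth5, expansion)

print(diff5, smooth5)  # [-1 -2  0  2  1] [1 4 6 4 1]                  -> ..._NON_NORMALIZED_1
print(diff7, smooth7)  # [-1 -4 -5  0  5  4  1] [ 1  6 15 20 15  6  1] -> ..._NON_NORMALIZED_2
```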