Skip to content
14 changes: 14 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,12 @@ Post-processing
.. autoclass:: ProbNMS
:members:

`SobelGradients`
""""""""""""""""
.. autoclass:: SobelGradients
:members:
:special-members: __call__

`VoteEnsemble`
""""""""""""""
.. autoclass:: VoteEnsemble
Expand Down Expand Up @@ -1593,6 +1599,14 @@ Post-processing (Dict)
:members:
:special-members: __call__


`SobelGradientsd`
"""""""""""""""""
.. autoclass:: SobelGradientsd
:members:
:special-members: __call__


Spatial (Dict)
^^^^^^^^^^^^^^

Expand Down
3 changes: 3 additions & 0 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
else:
# even-sized kernels are not supported
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
elif kwargs["padding"] == "same" and not pytorch_after(1, 10):
# even-sized kernels are not supported
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]

if "stride" not in kwargs:
kwargs["stride"] = 1
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@
MeanEnsemble,
ProbNMS,
RemoveSmallObjects,
SobelGradients,
VoteEnsemble,
)
from .post.dictionary import (
Expand Down Expand Up @@ -307,6 +308,9 @@
SaveClassificationD,
SaveClassificationd,
SaveClassificationDict,
SobelGradientsd,
SobelGradientsD,
SobelGradientsDict,
VoteEnsembleD,
VoteEnsembled,
VoteEnsembleDict,
Expand Down
51 changes: 51 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"LabelToContour",
"MeanEnsemble",
"ProbNMS",
"SobelGradients",
"VoteEnsemble",
"Invert",
]
Expand Down Expand Up @@ -852,3 +853,53 @@ def __call__(self, data):
inverted = self.transform.inverse(data)
inverted = self.post_func(inverted.to(self.device))
return inverted


class SobelGradients(Transform):
"""Calculate Sobel horizontal and vertical gradients

Args:
kernel_size: the size of the Sobel kernel. Defaults to 3.
padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
device: the device to create the kernel on. Defaults to `"cpu"`.

"""

backend = [TransformBackends.TORCH]

def __init__(
self,
kernel_size: int = 3,
padding: Union[int, str] = "same",
dtype: torch.dtype = torch.float32,
device: Union[torch.device, int, str] = "cpu",
) -> None:
super().__init__()
self.kernel: torch.Tensor = self._get_kernel(kernel_size, dtype, device)
self.padding = padding

def _get_kernel(self, size, dtype, device) -> torch.Tensor:
if size % 2 == 0:
raise ValueError(f"Sobel kernel size should be an odd number. {size} was given.")
if not dtype.is_floating_point:
raise ValueError(f"`dtype` for Sobel kernel should be floating point. {dtype} was given.")

numerator: torch.Tensor = torch.arange(
-size // 2 + 1, size // 2 + 1, dtype=dtype, device=device, requires_grad=False
).expand(size, size)
denominator = numerator * numerator
denominator = denominator + denominator.T
denominator[:, size // 2] = 1.0 # to avoid division by zero
kernel = numerator / denominator
return kernel

def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
kernel_v = self.kernel.to(image_tensor.device)
kernel_h = kernel_v.T
grad_v = apply_filter(image_tensor, kernel_v, padding=self.padding)
grad_h = apply_filter(image_tensor, kernel_h, padding=self.padding)
grad = torch.cat([grad_h, grad_v])

return grad
40 changes: 40 additions & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
MeanEnsemble,
ProbNMS,
RemoveSmallObjects,
SobelGradients,
VoteEnsemble,
)
from monai.transforms.transform import MapTransform
Expand Down Expand Up @@ -795,6 +796,44 @@ def get_saver(self):
return self.saver


class SobelGradientsd(MapTransform):
"""Calculate Sobel horizontal and vertical gradients.

Args:
keys: keys of the corresponding items to model output.
kernel_size: the size of the Sobel kernel. Defaults to 3.
padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
device: the device to create the kernel on. Defaults to `"cpu"`.
new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of
key intact. By default not prefix is set and the corresponding array to the key will be replaced.
allow_missing_keys: don't raise exception if key is missing.

"""

def __init__(
self,
keys: KeysCollection,
kernel_size: int = 3,
padding: Union[int, str] = "same",
dtype: torch.dtype = torch.float32,
device: Union[torch.device, int, str] = "cpu",
new_key_prefix: Optional[str] = None,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.transform = SobelGradients(kernel_size=kernel_size, padding=padding, dtype=dtype, device=device)
self.new_key_prefix = new_key_prefix

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
new_key = key if self.new_key_prefix is None else self.new_key_prefix + key
d[new_key] = self.transform(d[key])

return d


ActivationsD = ActivationsDict = Activationsd
AsDiscreteD = AsDiscreteDict = AsDiscreted
FillHolesD = FillHolesDict = FillHolesd
Expand All @@ -808,3 +847,4 @@ def get_saver(self):
SaveClassificationD = SaveClassificationDict = SaveClassificationd
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
EnsembleD = EnsembleDict = Ensembled
SobelGradientsD = SobelGradientsDict = SobelGradientsd
93 changes: 93 additions & 0 deletions tests/test_sobel_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import SobelGradients
from tests.utils import assert_allclose

IMAGE = torch.zeros(1, 1, 16, 16, dtype=torch.float32)
IMAGE[0, 0, 8, :] = 1
OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
OUTPUT_3x3[0, 7, :] = 2.0
OUTPUT_3x3[0, 9, :] = -2.0
OUTPUT_3x3[0, 7, 0] = OUTPUT_3x3[0, 7, -1] = 1.5
OUTPUT_3x3[0, 9, 0] = OUTPUT_3x3[0, 9, -1] = -1.5
OUTPUT_3x3[1, 7, 0] = OUTPUT_3x3[1, 9, 0] = 0.5
OUTPUT_3x3[1, 8, 0] = 1.0
OUTPUT_3x3[1, 8, -1] = -1.0
OUTPUT_3x3[1, 7, -1] = OUTPUT_3x3[1, 9, -1] = -0.5
OUTPUT_3x3 = OUTPUT_3x3.unsqueeze(1)

TEST_CASE_0 = [IMAGE, {"kernel_size": 3, "dtype": torch.float32}, OUTPUT_3x3]
TEST_CASE_1 = [IMAGE, {"kernel_size": 3, "dtype": torch.float64}, OUTPUT_3x3]

TEST_CASE_KERNEL_0 = [
{"kernel_size": 3, "dtype": torch.float64},
torch.tensor([[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5]], dtype=torch.float64),
]
TEST_CASE_KERNEL_1 = [
{"kernel_size": 5, "dtype": torch.float64},
torch.tensor(
[
[-0.25, -0.2, 0.0, 0.2, 0.25],
[-0.4, -0.5, 0.0, 0.5, 0.4],
[-0.5, -1.0, 0.0, 1.0, 0.5],
[-0.4, -0.5, 0.0, 0.5, 0.4],
[-0.25, -0.2, 0.0, 0.2, 0.25],
],
dtype=torch.float64,
),
]
TEST_CASE_KERNEL_2 = [
{"kernel_size": 7, "dtype": torch.float64},
torch.tensor(
[
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
[-3.0 / 9.0, -2.0 / 4.0, -1.0 / 1.0, 0.0, 1.0 / 1.0, 2.0 / 4.0, 3.0 / 9.0],
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
],
dtype=torch.float64,
),
]
TEST_CASE_ERROR_0 = [{"kernel_size": 2, "dtype": torch.float32}]


class SobelGradientTests(unittest.TestCase):
backend = None

@parameterized.expand([TEST_CASE_0])
def test_sobel_gradients(self, image, arguments, expected_grad):
sobel = SobelGradients(**arguments)
grad = sobel(image)
assert_allclose(grad, expected_grad)

@parameterized.expand([TEST_CASE_KERNEL_0, TEST_CASE_KERNEL_1, TEST_CASE_KERNEL_2])
def test_sobel_kernels(self, arguments, expected_kernel):
sobel = SobelGradients(**arguments)
self.assertTrue(sobel.kernel.dtype == expected_kernel.dtype)
assert_allclose(sobel.kernel, expected_kernel)

@parameterized.expand([TEST_CASE_ERROR_0])
def test_sobel_gradients_error(self, arguments):
with self.assertRaises(ValueError):
SobelGradients(**arguments)


if __name__ == "__main__":
unittest.main()
99 changes: 99 additions & 0 deletions tests/test_sobel_gradientd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import SobelGradientsd
from tests.utils import assert_allclose

IMAGE = torch.zeros(1, 1, 16, 16, dtype=torch.float32)
IMAGE[0, 0, 8, :] = 1
OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
OUTPUT_3x3[0, 7, :] = 2.0
OUTPUT_3x3[0, 9, :] = -2.0
OUTPUT_3x3[0, 7, 0] = OUTPUT_3x3[0, 7, -1] = 1.5
OUTPUT_3x3[0, 9, 0] = OUTPUT_3x3[0, 9, -1] = -1.5
OUTPUT_3x3[1, 7, 0] = OUTPUT_3x3[1, 9, 0] = 0.5
OUTPUT_3x3[1, 8, 0] = 1.0
OUTPUT_3x3[1, 8, -1] = -1.0
OUTPUT_3x3[1, 7, -1] = OUTPUT_3x3[1, 9, -1] = -0.5
OUTPUT_3x3 = OUTPUT_3x3.unsqueeze(1)

TEST_CASE_0 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float32}, {"image": OUTPUT_3x3}]
TEST_CASE_1 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float64}, {"image": OUTPUT_3x3}]
TEST_CASE_2 = [
{"image": IMAGE},
{"keys": "image", "kernel_size": 3, "dtype": torch.float32, "new_key_prefix": "sobel_"},
{"sobel_image": OUTPUT_3x3},
]

TEST_CASE_KERNEL_0 = [
{"keys": "image", "kernel_size": 3, "dtype": torch.float64},
torch.tensor([[-0.5, 0.0, 0.5], [-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5]], dtype=torch.float64),
]
TEST_CASE_KERNEL_1 = [
{"keys": "image", "kernel_size": 5, "dtype": torch.float64},
torch.tensor(
[
[-0.25, -0.2, 0.0, 0.2, 0.25],
[-0.4, -0.5, 0.0, 0.5, 0.4],
[-0.5, -1.0, 0.0, 1.0, 0.5],
[-0.4, -0.5, 0.0, 0.5, 0.4],
[-0.25, -0.2, 0.0, 0.2, 0.25],
],
dtype=torch.float64,
),
]
TEST_CASE_KERNEL_2 = [
{"keys": "image", "kernel_size": 7, "dtype": torch.float64},
torch.tensor(
[
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
[-3.0 / 9.0, -2.0 / 4.0, -1.0 / 1.0, 0.0, 1.0 / 1.0, 2.0 / 4.0, 3.0 / 9.0],
[-3.0 / 10.0, -2.0 / 5.0, -1.0 / 2.0, 0.0, 1.0 / 2.0, 2.0 / 5.0, 3.0 / 10.0],
[-3.0 / 13.0, -2.0 / 8.0, -1.0 / 5.0, 0.0, 1.0 / 5.0, 2.0 / 8.0, 3.0 / 13.0],
[-3.0 / 18.0, -2.0 / 13.0, -1.0 / 10.0, 0.0, 1.0 / 10.0, 2.0 / 13.0, 3.0 / 18.0],
],
dtype=torch.float64,
),
]
TEST_CASE_ERROR_0 = [{"keys": "image", "kernel_size": 2, "dtype": torch.float32}]


class SobelGradientTests(unittest.TestCase):
backend = None

@parameterized.expand([TEST_CASE_0])
def test_sobel_gradients(self, image_dict, arguments, expected_grad):
sobel = SobelGradientsd(**arguments)
grad = sobel(image_dict)
key = "image" if "new_key_prefix" not in arguments else arguments["new_key_prefix"] + arguments["keys"]
assert_allclose(grad[key], expected_grad[key])

@parameterized.expand([TEST_CASE_KERNEL_0, TEST_CASE_KERNEL_1, TEST_CASE_KERNEL_2])
def test_sobel_kernels(self, arguments, expected_kernel):
sobel = SobelGradientsd(**arguments)
self.assertTrue(sobel.transform.kernel.dtype == expected_kernel.dtype)
assert_allclose(sobel.transform.kernel, expected_kernel)

@parameterized.expand([TEST_CASE_ERROR_0])
def test_sobel_gradients_error(self, arguments):
with self.assertRaises(ValueError):
SobelGradientsd(**arguments)


if __name__ == "__main__":
unittest.main()