Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
3503cc2
Update SobelGradients to include gradient direction
bhashemian Sep 21, 2022
c3dcb0d
Add more checking to raise errors
bhashemian Sep 21, 2022
7b26dd5
Add test cases
bhashemian Sep 21, 2022
852eba8
Update SobelGradientsd
bhashemian Sep 21, 2022
a1483dc
Add test cases for SobelGradientsD
bhashemian Sep 21, 2022
f3335c8
Merge branch 'dev' into update-sobel-direction
bhashemian Sep 21, 2022
2b2b28e
Update docstring
bhashemian Sep 21, 2022
48125e6
Address comments
bhashemian Sep 22, 2022
a667fc4
Merge branch 'dev' into update-sobel-direction
bhashemian Sep 23, 2022
54694b4
Type checking
bhashemian Sep 23, 2022
85e8881
Merge branch 'update-sobel-direction' of github.com:drbeh/MONAI into …
bhashemian Sep 23, 2022
b782753
Merge branch 'dev' into update-sobel-direction
bhashemian Sep 26, 2022
93464a4
Merge branch 'dev' into update-sobel-direction
bhashemian Sep 27, 2022
71018ed
Merge dev
bhashemian Oct 21, 2022
7a1c176
Merge branch 'dev' of github.com:Project-MONAI/MONAI into update-sobe…
bhashemian Oct 25, 2022
49b94d4
Merge branch 'dev' of github.com:Project-MONAI/MONAI into update-sobe…
bhashemian Oct 26, 2022
0c16ceb
Reimplementation of sobel with separable kernels
bhashemian Oct 26, 2022
5bfa852
Remove Direction from init
bhashemian Oct 26, 2022
a1cf946
Update unittests for sobel
bhashemian Oct 26, 2022
43e845a
Change arguments and add additional test case
bhashemian Oct 26, 2022
6f29bf3
Update sobel grad dict and related unittests
bhashemian Oct 26, 2022
a26e340
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2022
feedaf5
Remove unused imports
bhashemian Oct 26, 2022
695cdcc
Merge branch 'update-sobel-direction' of github.com:behxyz/MONAI into…
bhashemian Oct 26, 2022
ab3b2d2
Minor renaming
bhashemian Oct 26, 2022
be0570d
formatting
bhashemian Oct 26, 2022
69ff390
formatting
bhashemian Oct 26, 2022
d97c579
Merge branch 'dev' into update-sobel-direction
bhashemian Oct 27, 2022
2983edb
Reverse gradient direction
bhashemian Oct 27, 2022
8841685
Merge branch 'update-sobel-direction' of github.com:behxyz/MONAI into…
bhashemian Oct 27, 2022
7eb6ae5
Update sobel unittests
bhashemian Oct 27, 2022
a09c2ac
Add normalize kernels and normalize gradients
bhashemian Oct 27, 2022
6848a38
Update unitests and add new test cases
bhashemian Oct 27, 2022
3977784
Update hovernet unittests
bhashemian Oct 27, 2022
9146714
Merge branch 'dev' into update-sobel-direction
bhashemian Oct 28, 2022
45c0cea
formatting
bhashemian Oct 28, 2022
8f23120
Make less call to min and max
bhashemian Oct 28, 2022
8a9c4da
Merge branch 'dev' into update-sobel-direction
bhashemian Oct 28, 2022
e399fb4
Merge branch 'dev' into update-sobel-direction
Nic-Ma Oct 31, 2022
d4d8f4a
Merge branch 'dev' into update-sobel-direction
wyli Oct 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 90 additions & 30 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
"""

import warnings
from typing import Callable, Iterable, Optional, Sequence, Union
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.networks import one_hot
from monai.networks.layers import GaussianFilter, apply_filter
from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import Transform
from monai.transforms.utils import (
Expand Down Expand Up @@ -821,13 +822,18 @@ def __call__(self, data):


class SobelGradients(Transform):
"""Calculate Sobel horizontal and vertical gradients
"""Calculate Sobel gradients of a grayscale image with the shape of (CxH[xWxDx...]).

Args:
kernel_size: the size of the Sobel kernel. Defaults to 3.
padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
along each of the provide axis. By default it calculate the gradient for all spatial axes.
normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
See ``torch.nn.Conv1d()`` for more information.
dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
device: the device to create the kernel on. Defaults to `"cpu"`.

"""

Expand All @@ -836,36 +842,90 @@ class SobelGradients(Transform):
def __init__(
self,
kernel_size: int = 3,
padding: Union[int, str] = "same",
spatial_axes: Optional[Union[Sequence[int], int]] = None,
normalize_kernels: bool = True,
normalize_gradients: bool = False,
padding_mode: str = "reflect",
dtype: torch.dtype = torch.float32,
device: Union[torch.device, int, str] = "cpu",
) -> None:
super().__init__()
self.kernel: torch.Tensor = self._get_kernel(kernel_size, dtype, device)
self.padding = padding

def _get_kernel(self, size, dtype, device) -> torch.Tensor:
self.padding = padding_mode
self.spatial_axes = spatial_axes
self.normalize_kernels = normalize_kernels
self.normalize_gradients = normalize_gradients
self.kernel_diff, self.kernel_smooth = self._get_kernel(kernel_size, dtype)

def _get_kernel(self, size, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
if size < 3:
raise ValueError(f"Sobel kernel size should be at least three. {size} was given.")
if size % 2 == 0:
raise ValueError(f"Sobel kernel size should be an odd number. {size} was given.")
if not dtype.is_floating_point:
raise ValueError(f"`dtype` for Sobel kernel should be floating point. {dtype} was given.")

numerator: torch.Tensor = torch.arange(
-size // 2 + 1, size // 2 + 1, dtype=dtype, device=device, requires_grad=False
).expand(size, size)
denominator = numerator * numerator
denominator = denominator + denominator.T
denominator[:, size // 2] = 1.0 # to avoid division by zero
kernel = numerator / denominator
return kernel

kernel_diff = torch.tensor([[[-1, 0, 1]]], dtype=dtype)
kernel_smooth = torch.tensor([[[1, 2, 1]]], dtype=dtype)
kernel_expansion = torch.tensor([[[1, 2, 1]]], dtype=dtype)

if self.normalize_kernels:
if not dtype.is_floating_point:
raise ValueError(
f"`dtype` for Sobel kernel should be floating point when `normalize_kernel==True`. {dtype} was given."
)
kernel_diff /= 2.0
kernel_smooth /= 4.0
kernel_expansion /= 4.0

# Expand the kernel to larger size than 3
expand = (size - 3) // 2
for _ in range(expand):
kernel_diff = F.conv1d(kernel_diff, kernel_expansion, padding=2)
kernel_smooth = F.conv1d(kernel_smooth, kernel_expansion, padding=2)

return kernel_diff.squeeze(), kernel_smooth.squeeze()

def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
kernel_v = self.kernel.to(image_tensor.device)
kernel_h = kernel_v.T
image_tensor = image_tensor.unsqueeze(0) # adds a batch dim
grad_v = apply_filter(image_tensor, kernel_v, padding=self.padding)
grad_h = apply_filter(image_tensor, kernel_h, padding=self.padding)
grad = torch.cat([grad_h, grad_v], dim=1)
grad, *_ = convert_to_dst_type(grad.squeeze(0), image_tensor)
return grad

# Check/set spatial axes
n_spatial_dims = image_tensor.ndim - 1 # excluding the channel dimension
valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))

# Check gradient axes to be valid
if self.spatial_axes is None:
spatial_axes = list(range(n_spatial_dims))
else:
invalid_axis = set(ensure_tuple(self.spatial_axes)) - set(valid_spatial_axes)
if invalid_axis:
raise ValueError(
f"The provide axes to calculate gradient is not valid: {invalid_axis}. "
f"The image has {n_spatial_dims} spatial dimensions so it should be: {valid_spatial_axes}."
)
spatial_axes = [ax % n_spatial_dims if ax < 0 else ax for ax in ensure_tuple(self.spatial_axes)]

# Add batch dimension for separable_filtering
image_tensor = image_tensor.unsqueeze(0)

# Get the Sobel kernels
kernel_diff = self.kernel_diff.to(image_tensor.device)
kernel_smooth = self.kernel_smooth.to(image_tensor.device)

# Calculate gradient
grad_list = []
for ax in spatial_axes:
kernels = [kernel_smooth] * n_spatial_dims
kernels[ax - 1] = kernel_diff
grad = separable_filtering(image_tensor, kernels, mode=self.padding)
if self.normalize_gradients:
grad_min = grad.min()
if grad_min != grad.max():
grad -= grad_min
grad_max = grad.max()
if grad_max > 0:
grad /= grad_max
grad_list.append(grad)

grads = torch.cat(grad_list, dim=1)

# Remove batch dimension and convert the gradient type to be the same as input image
grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]

return grads
28 changes: 22 additions & 6 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,14 +794,19 @@ def get_saver(self):


class SobelGradientsd(MapTransform):
"""Calculate Sobel horizontal and vertical gradients.
"""Calculate Sobel horizontal and vertical gradients of a grayscale image.

Args:
keys: keys of the corresponding items to model output.
kernel_size: the size of the Sobel kernel. Defaults to 3.
padding: the padding for the convolution to apply the kernel. Defaults to `"same"`.
spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
along each of the provide axis. By default it calculate the gradient for all spatial axes.
normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
See ``torch.nn.Conv1d()`` for more information.
dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
device: the device to create the kernel on. Defaults to `"cpu"`.
new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of
key intact. By default not prefix is set and the corresponding array to the key will be replaced.
allow_missing_keys: don't raise exception if key is missing.
Expand All @@ -814,15 +819,26 @@ def __init__(
self,
keys: KeysCollection,
kernel_size: int = 3,
padding: Union[int, str] = "same",
spatial_axes: Optional[Union[Sequence[int], int]] = None,
normalize_kernels: bool = True,
normalize_gradients: bool = False,
padding_mode: str = "reflect",
dtype: torch.dtype = torch.float32,
device: Union[torch.device, int, str] = "cpu",
new_key_prefix: Optional[str] = None,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.transform = SobelGradients(kernel_size=kernel_size, padding=padding, dtype=dtype, device=device)
self.transform = SobelGradients(
kernel_size=kernel_size,
spatial_axes=spatial_axes,
normalize_kernels=normalize_kernels,
normalize_gradients=normalize_gradients,
padding_mode=padding_mode,
dtype=dtype,
)
self.new_key_prefix = new_key_prefix
self.kernel_diff = self.transform.kernel_diff
self.kernel_smooth = self.transform.kernel_smooth

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_hovernet_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,17 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w

TEST_CASE_3 = [ # batch size of 2, 3 classes with minor rotation of nuclear prediction
{"prediction": inputs_test[3].inputs, "target": inputs_test[3].targets},
6.5777,
3.6169,
]

TEST_CASE_4 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction
{"prediction": inputs_test[4].inputs, "target": inputs_test[4].targets},
8.5143,
4.5079,
]

TEST_CASE_5 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction
{"prediction": inputs_test[5].inputs, "target": inputs_test[5].targets},
10.1705,
5.4663,
]

CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]
Expand Down
Loading