Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1223a87
Add generic kernel transform with support for multiple kernels
kbressem Oct 12, 2022
320c4fa
Merge branch 'dev' into 3178-generic-filterkernel-transform
kbressem Nov 14, 2022
5ba3e10
Rewrite ImageFilter
kbressem Nov 14, 2022
e1932f2
Merge branch '3178-generic-filterkernel-transform' of https://github.…
kbressem Nov 14, 2022
32aa261
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2022
3bbd099
Merge branch 'dev' into 3178-generic-filterkernel-transform
kbressem Nov 15, 2022
fc42b96
add missing unit tests
kbressem Dec 30, 2022
8f8cacf
Merge branch 'dev' into 3178-generic-filterkernel-transform
kbressem Dec 30, 2022
4d1ac7c
runtest autofix
kbressem Dec 30, 2022
71ccc10
black
kbressem Dec 30, 2022
f59f7fc
reduce line length in ImageFilter
kbressem Dec 30, 2022
d265c82
fix mypy errors
kbressem Dec 30, 2022
d9f484e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2022
8fd75fd
black
kbressem Dec 30, 2022
7648a6c
fix unit tests and codestyle
kbressem Dec 31, 2022
203fbbe
update transforms.rst in docs
kbressem Dec 31, 2022
f753373
increase length of title underline
kbressem Dec 31, 2022
34d82d6
remove indent in docstring
kbressem Dec 31, 2022
7023298
fix docstring, add ignore for mypy
kbressem Dec 31, 2022
9caf72a
change default padding for SobelGradient when running tests for torch…
kbressem Dec 31, 2022
7fc2064
add tests for ApplyFilter
kbressem Dec 31, 2022
37f7802
remove packaging from imports
kbressem Dec 31, 2022
cc04698
Merge branch 'dev' into 3178-generic-filterkernel-transform
wyli Jan 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,17 @@ Utility
:members:
:special-members: __call__

`ImageFilter`
"""""""""""""
.. autoclass:: ImageFilter
:members:
:special-members: __call__

`RandImageFilter`
"""""""""""""""""
.. autoclass:: RandImageFilter
:members:
:special-members: __call__

Dictionary Transforms
---------------------
Expand Down Expand Up @@ -2124,6 +2135,19 @@ Utility (Dict)
:members:
:special-members: __call__

`ImageFilterd`
""""""""""""""
.. autoclass:: ImageFilterd
:members:
:special-members: __call__

`RandImageFilterd`
""""""""""""""""""
.. autoclass:: RandImageFilterd
:members:
:special-members: __call__


MetaTensor
^^^^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions monai/networks/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
from .gmm import GaussianMixtureModel
from .simplelayers import (
LLTM,
ApplyFilter,
ChannelPad,
EllipticalFilter,
Flatten,
GaussianFilter,
HilbertTransform,
LaplaceFilter,
MeanFilter,
MedianFilter,
Reshape,
SavitzkyGolayFilter,
SharpenFilter,
SkipConnection,
apply_filter,
median_filter,
Expand Down
88 changes: 88 additions & 0 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch import nn
from torch.autograd import Function

from monai.config.type_definitions import NdarrayOrTensor
from monai.networks.layers.convutils import gaussian_1d
from monai.networks.layers.factories import Conv
from monai.utils import (
Expand Down Expand Up @@ -658,3 +659,90 @@ def reset_parameters(self):

def forward(self, input, state):
return LLTMFunction.apply(input, self.weights, self.bias, *state)


class ApplyFilter(nn.Module):
"Wrapper class to apply a filter to an image."

def __init__(self, filter: NdarrayOrTensor) -> None:
super().__init__()

self.filter = convert_to_tensor(filter, dtype=torch.float32)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return apply_filter(x, self.filter) # type: ignore


class MeanFilter(ApplyFilter):
"""
Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image.
The mean filter used, is a `torch.Tensor` of all ones.
"""

def __init__(self, spatial_dims: int, size: int) -> None:
"""
Args:
spatial_dims: `int` of either 2 for 2D images and 3 for 3D images
size: edge length of the filter
"""
filter = torch.ones([size] * spatial_dims)
filter = filter
super().__init__(filter=filter)


class LaplaceFilter(ApplyFilter):
"""
Laplacian filtering for outline detection in images. Can be used to transform labels to contours.
The laplace filter used, is a `torch.Tensor` where all values are -1, except the center value
which is `size` ** `spatial_dims`
"""

def __init__(self, spatial_dims: int, size: int) -> None:
"""
Args:
spatial_dims: `int` of either 2 for 2D images and 3 for 3D images
size: edge length of the filter
"""
filter = torch.zeros([size] * spatial_dims).float() - 1 # make all -1
center_point = tuple([size // 2] * spatial_dims)
filter[center_point] = (size**spatial_dims) - 1
super().__init__(filter=filter)


class EllipticalFilter(ApplyFilter):
"""
Elliptical filter, can be used to dilate labels or label-contours.
The elliptical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1`
"""

def __init__(self, spatial_dims: int, size: int) -> None:
"""
Args:
spatial_dims: `int` of either 2 for 2D images and 3 for 3D images
size: edge length of the filter
"""
radius = size // 2
grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(spatial_dims)])
squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0)
filter = squared_distances <= radius**2
super().__init__(filter=filter)


class SharpenFilter(EllipticalFilter):
"""
Convolutional filter to sharpen a 2D or 3D image.
The filter used contains a circle/sphere of `-1`, with the center value being
the absolute sum of all non-zero elements in the kernel
"""

def __init__(self, spatial_dims: int, size: int) -> None:
"""
Args:
spatial_dims: `int` of either 2 for 2D images and 3 for 3D images
size: edge length of the filter
"""
super().__init__(spatial_dims=spatial_dims, size=size)
center_point = tuple([size // 2] * spatial_dims)
center_value = self.filter.sum()
self.filter *= -1
self.filter[center_point] = center_value
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,13 @@
EnsureType,
FgBgToIndices,
Identity,
ImageFilter,
IntensityStats,
LabelToMask,
Lambda,
MapLabelValue,
RandCuCIM,
RandImageFilter,
RandLambda,
RemoveRepeatedChannel,
RepeatChannel,
Expand Down Expand Up @@ -553,6 +555,9 @@
Identityd,
IdentityD,
IdentityDict,
ImageFilterd,
ImageFilterD,
ImageFilterDict,
IntensityStatsd,
IntensityStatsD,
IntensityStatsDict,
Expand All @@ -568,6 +573,9 @@
RandCuCIMd,
RandCuCIMD,
RandCuCIMDict,
RandImageFilterd,
RandImageFilterD,
RandImageFilterDict,
RandLambdad,
RandLambdaD,
RandLambdaDict,
Expand Down
Loading