diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index c92d6fe46a..c0b8af71e1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1134,6 +1134,17 @@ Utility :members: :special-members: __call__ +`ImageFilter` +""""""""""""" +.. autoclass:: ImageFilter + :members: + :special-members: __call__ + +`RandImageFilter` +""""""""""""""""" +.. autoclass:: RandImageFilter + :members: + :special-members: __call__ Dictionary Transforms --------------------- @@ -2124,6 +2135,19 @@ Utility (Dict) :members: :special-members: __call__ +`ImageFilterd` +"""""""""""""" +.. autoclass:: ImageFilterd + :members: + :special-members: __call__ + +`RandImageFilterd` +"""""""""""""""""" +.. autoclass:: RandImageFilterd + :members: + :special-members: __call__ + + MetaTensor ^^^^^^^^^^ diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 31bd36dd8f..5d91cf66f6 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -16,13 +16,18 @@ from .gmm import GaussianMixtureModel from .simplelayers import ( LLTM, + ApplyFilter, ChannelPad, + EllipticalFilter, Flatten, GaussianFilter, HilbertTransform, + LaplaceFilter, + MeanFilter, MedianFilter, Reshape, SavitzkyGolayFilter, + SharpenFilter, SkipConnection, apply_filter, median_filter, diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index bc72f8d1d3..ca7997e272 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -18,6 +18,7 @@ from torch import nn from torch.autograd import Function +from monai.config.type_definitions import NdarrayOrTensor from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv from monai.utils import ( @@ -658,3 +659,90 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) + + +class ApplyFilter(nn.Module): + "Wrapper class to apply a filter to an image." + + def __init__(self, filter: NdarrayOrTensor) -> None: + super().__init__() + + self.filter = convert_to_tensor(filter, dtype=torch.float32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return apply_filter(x, self.filter) # type: ignore + + +class MeanFilter(ApplyFilter): + """ + Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image. + The mean filter used, is a `torch.Tensor` of all ones. + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + filter = torch.ones([size] * spatial_dims) + filter = filter + super().__init__(filter=filter) + + +class LaplaceFilter(ApplyFilter): + """ + Laplacian filtering for outline detection in images. Can be used to transform labels to contours. 
+ The laplace filter used, is a `torch.Tensor` where all values are -1, except the center value + which is `size` ** `spatial_dims` + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + filter = torch.zeros([size] * spatial_dims).float() - 1 # make all -1 + center_point = tuple([size // 2] * spatial_dims) + filter[center_point] = (size**spatial_dims) - 1 + super().__init__(filter=filter) + + +class EllipticalFilter(ApplyFilter): + """ + Elliptical filter, can be used to dilate labels or label-contours. + The elliptical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1` + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + radius = size // 2 + grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(spatial_dims)]) + squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0) + filter = squared_distances <= radius**2 + super().__init__(filter=filter) + + +class SharpenFilter(EllipticalFilter): + """ + Convolutional filter to sharpen a 2D or 3D image. + The filter used contains a circle/sphere of `-1`, with the center value being + the absolute sum of all non-zero elements in the kernel + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + super().__init__(spatial_dims=spatial_dims, size=size) + center_point = tuple([size // 2] * spatial_dims) + center_value = self.filter.sum() + self.filter *= -1 + self.filter[center_point] = center_value diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ba8702ebd9..1fa03c0317 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -478,11 +478,13 @@ EnsureType, FgBgToIndices, Identity, + ImageFilter, IntensityStats, LabelToMask, Lambda, MapLabelValue, RandCuCIM, + RandImageFilter, RandLambda, RemoveRepeatedChannel, RepeatChannel, @@ -553,6 +555,9 @@ Identityd, IdentityD, IdentityDict, + ImageFilterd, + ImageFilterD, + ImageFilterDict, IntensityStatsd, IntensityStatsD, IntensityStatsDict, @@ -568,6 +573,9 @@ RandCuCIMd, RandCuCIMD, RandCuCIMDict, + RandImageFilterd, + RandImageFilterD, + RandImageFilterDict, RandLambdad, RandLambdaD, RandLambdaDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 27314ab91c..02688d6604 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -18,16 +18,28 @@ import time import warnings from copy import deepcopy +from functools import partial from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch +import torch.nn as nn from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import no_collation +from monai.networks.layers.simplelayers import ( + ApplyFilter, + EllipticalFilter, + GaussianFilter, + LaplaceFilter, + MeanFilter, + SavitzkyGolayFilter, + SharpenFilter, + median_filter, +) from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import Randomizable, RandomizableTrait, 
RandomizableTransform, Transform
 from monai.transforms.utils import (
@@ -92,6 +104,8 @@
     "CuCIM",
     "RandCuCIM",
     "ToCupy",
+    "ImageFilter",
+    "RandImageFilter",
 ]
 
 
@@ -1422,3 +1436,287 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         coord_channels, *_ = convert_to_dst_type(coord_channels, img)  # type: ignore
         coord_channels = coord_channels[list(self.spatial_dims)]
         return concatenate((img, coord_channels), axis=0)
+
+
+class ImageFilter(Transform):
+    """
+    Applies a convolution filter to the input image.
+
+    Args:
+        filter:
+            A string specifying the filter, a custom filter as a ``torch.Tensor`` or ``np.ndarray``, or an ``nn.Module``.
+            Available options for strings are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``,
+            ``median``, ``gauss``, ``savitzky_golay``.
+            See below for short explanations of every filter.
+        filter_size:
+            A single integer value specifying the size of the quadratic or cubic filter.
+            Computational cost scales with ``filter_size ** 2`` for 2D filters and ``filter_size ** 3`` for
+            3D filters, which should be considered when choosing the filter size.
+        kwargs:
+            Additional arguments passed to the filter; required by ``gauss`` (``sigma``) and
+            ``savitzky_golay`` (``order``), optional for ``sobel``. See below for details.
+
+    Raises:
+        ValueError: When ``filter_size`` is not an uneven integer
+        ValueError: When ``filter`` is an array and ``ndim`` is not in [1,2,3]
+        ValueError: When ``filter`` is an array and any dimension has an even shape
+        NotImplementedError: When ``filter`` is a string and not in ``self.supported_filters``
+        KeyError: When necessary ``kwargs`` are not passed to a filter that requires additional arguments.
+
+
+    **Mean Filtering:** ``filter='mean'``
+
+    Mean filtering can smooth edges and remove aliasing artifacts in a segmentation image.
+    See also :py:class:`monai.networks.layers.simplelayers.MeanFilter`
+    Example 2D filter (5 x 5)::
+
+        [[1, 1, 1, 1, 1],
+         [1, 1, 1, 1, 1],
+         [1, 1, 1, 1, 1],
+         [1, 1, 1, 1, 1],
+         [1, 1, 1, 1, 1]]
+
+    If smoothing labels with this filter, ensure they are in one-hot format.
+
+    **Outline Detection:** ``filter='laplace'``
+
+    Laplacian filtering for outline detection in images. Can be used to transform labels to contours.
+    See also :py:class:`monai.networks.layers.simplelayers.LaplaceFilter`
+
+    Example 2D filter (5x5)::
+
+        [[-1., -1., -1., -1., -1.],
+         [-1., -1., -1., -1., -1.],
+         [-1., -1., 24., -1., -1.],
+         [-1., -1., -1., -1., -1.],
+         [-1., -1., -1., -1., -1.]]
+
+
+    **Dilation:** ``filter='elliptical'``
+
+    An elliptical filter can be used to dilate labels or label-contours.
+    Example 2D filter (5x5)::
+
+        [[0., 0., 1., 0., 0.],
+         [1., 1., 1., 1., 1.],
+         [1., 1., 1., 1., 1.],
+         [1., 1., 1., 1., 1.],
+         [0., 0., 1., 0., 0.]]
+
+
+    **Edge Detection:** ``filter='sobel'``
+
+    This filter allows for additional arguments passed as ``kwargs`` during initialization.
+    See also :py:class:`monai.transforms.post.SobelGradients`
+
+    *kwargs*
+
+    * ``spatial_axes``: the axes that define the direction of the gradient to be calculated.
+      The gradient is calculated along each of the provided axes.
+      By default the gradient is calculated for all spatial axes.
+    * ``normalize_kernels``: whether to normalize the Sobel kernel to provide proper gradients. Defaults to ``True``.
+    * ``normalize_gradients``: whether to normalize the output gradient to the range 0 to 1. Defaults to ``False``.
+    * ``padding_mode``: the padding mode of the image when convolving with Sobel kernels. Defaults to ``"reflect"``.
+      Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+      See ``torch.nn.Conv1d()`` for more information.
+    * ``dtype``: kernel data type (torch.dtype). Defaults to ``torch.float32``.
+
+
+    **Sharpening:** ``filter='sharpen'``
+
+    Sharpen an image with a 2D or 3D filter.
+    Example 2D filter (5x5)::
+
+        [[ 0.,  0., -1.,  0.,  0.],
+         [-1., -1., -1., -1., -1.],
+         [-1., -1., 17., -1., -1.],
+         [-1., -1., -1., -1., -1.],
+         [ 0.,  0., -1.,  0.,  0.]]
+
+
+    **Gaussian Smooth:** ``filter='gauss'``
+
+    Blur/smooth an image with a 2D or 3D Gaussian filter.
+    This filter requires additional arguments passed as ``kwargs`` during initialization.
+    See also :py:class:`monai.networks.layers.simplelayers.GaussianFilter`
+
+    *kwargs*
+
+    * ``sigma``: standard deviation(s) of the Gaussian kernel; a single value or one value per spatial dimension.
+    * ``truncated``: how many standard deviations the kernel spreads before it is truncated.
+    * ``approx``: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
+
+
+    **Median Filter:** ``filter='median'``
+
+    Blur an image with a 2D or 3D median filter to remove noise.
+    Useful in image preprocessing to improve results of later processing.
+    See also :py:class:`monai.networks.layers.simplelayers.MedianFilter`
+
+
+    **Savitzky Golay Filter:** ``filter='savitzky_golay'``
+
+    Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.
+    This filter requires additional arguments passed as ``kwargs`` during initialization.
+    See also :py:class:`monai.networks.layers.simplelayers.SavitzkyGolayFilter`
+
+    *kwargs*
+
+    * ``order``: order of the polynomial to fit to each window, must be less than ``window_length``.
+    * ``axis`` (optional): axis along which to apply the filter kernel. Default 2 (first spatial dimension).
+    * ``mode`` (string, optional): padding mode passed to the convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
+      or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.
+ + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + supported_filters = sorted( + ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] + ) + + def __init__( + self, filter: Union[str, NdarrayOrTensor, nn.Module], filter_size: Optional[int] = None, **kwargs + ) -> None: + + self._check_filter_format(filter, filter_size) + self._check_kwargs_are_present(filter, **kwargs) + self.filter = filter + self.filter_size = filter_size + self.additional_args_for_filter = kwargs + + def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Dict] = None) -> NdarrayOrTensor: + """ + Args: + img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] + meta_dict: An optional dictionary with metadata + + Returns: + A MetaTensor with the same shape as `img` and identical metadata + """ + if isinstance(img, MetaTensor): + meta_dict = img.meta + img_, prev_type, device = convert_data_type(img, torch.Tensor) + ndim = img_.ndim - 1 # assumes channel first format + + if isinstance(self.filter, str): + self.filter = self._get_filter_from_string(self.filter, self.filter_size, ndim) # type: ignore + elif isinstance(self.filter, (torch.Tensor, np.ndarray)): + self.filter = ApplyFilter(self.filter) + + img_ = self._apply_filter(img_) + if meta_dict: + img_ = MetaTensor(img_, meta=meta_dict) + else: + img_, *_ = convert_data_type(img_, prev_type, device) + return img_ + + def _check_all_values_uneven(self, x: tuple) -> None: + for value in x: + if value % 2 == 0: + raise ValueError(f"Only uneven filters are supported, but filter size is {x}") + + def _check_filter_format( + self, filter: Union[str, NdarrayOrTensor, nn.Module], filter_size: Optional[int] = None + ) -> None: + if isinstance(filter, str): + if not filter_size: + raise ValueError("`filter_size` must be specified when specifying filters by string.") + if filter_size % 2 == 0: + raise ValueError("`filter_size` should be a single uneven integer.") + if filter not in self.supported_filters: + raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.") + elif isinstance(filter, torch.Tensor) or isinstance(filter, np.ndarray): + if filter.ndim not in [1, 2, 3]: + raise ValueError("Only 1D, 2D, and 3D filters are supported.") + self._check_all_values_uneven(filter.shape) # type: ignore + elif isinstance(filter, (nn.Module, Transform)): + pass + else: + raise TypeError( + f"{type(filter)} is not supported." 
+                " Supported types are `class 'str'`, `class 'torch.Tensor'`, `class 'np.ndarray'`, "
+                "`class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`"
+            )
+
+    def _check_kwargs_are_present(self, filter, **kwargs):
+        if filter == "gauss" and "sigma" not in kwargs.keys():
+            raise KeyError("`filter='gauss'` requires the additional keyword argument `sigma`")
+        if filter == "savitzky_golay" and "order" not in kwargs.keys():
+            raise KeyError("`filter='savitzky_golay'` requires the additional keyword argument `order`")
+
+    def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> Union[nn.Module, Callable]:
+        if filter == "mean":
+            return MeanFilter(ndim, size)
+        elif filter == "laplace":
+            return LaplaceFilter(ndim, size)
+        elif filter == "elliptical":
+            return EllipticalFilter(ndim, size)
+        elif filter == "sobel":
+            from monai.transforms.post.array import SobelGradients  # cannot import at the top because of circular imports
+
+            allowed_keys = SobelGradients.__init__.__annotations__.keys()
+            kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys}
+            return SobelGradients(size, **kwargs)
+        elif filter == "sharpen":
+            return SharpenFilter(ndim, size)
+        elif filter == "gauss":
+            allowed_keys = GaussianFilter.__init__.__annotations__.keys()
+            kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys}
+            return GaussianFilter(ndim, **kwargs)
+        elif filter == "median":
+            return partial(median_filter, kernel_size=size, spatial_dims=ndim)
+        elif filter == "savitzky_golay":
+            allowed_keys = SavitzkyGolayFilter.__init__.__annotations__.keys()
+            kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys}
+            return SavitzkyGolayFilter(size, **kwargs)
+        else:
+            raise NotImplementedError(f"Filter {filter} not implemented")
+
+    def _apply_filter(self, img: torch.Tensor) -> torch.Tensor:
+        if isinstance(self.filter, Transform):
+            img = self.filter(img)
+        else:
+            img = self.filter(img.unsqueeze(0))  # type: ignore
+            img = img[0]  # add and remove batch dim
+        return img
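
# ---------------------------------------------------------------------------
# Editor's note: a minimal usage sketch of the `ImageFilter` transform defined
# above (illustrative only, not part of this patch). It assumes a channel-first
# input of shape [channels, H, W(, D)] and the filter names listed in
# `supported_filters`; filter-specific keyword arguments are passed at
# construction time.

import torch
from monai.transforms import ImageFilter

img = torch.rand(1, 64, 64)  # 2D image, channel-first

smoothed = ImageFilter("mean", filter_size=3)(img)  # named filter with an uneven size
blurred = ImageFilter("gauss", filter_size=3, sigma=1.0)(img)  # `gauss` requires `sigma`
custom_out = ImageFilter(torch.ones(3, 3) / 9.0)(img)  # custom kernels are wrapped in ApplyFilter

# each result keeps the spatial shape of `img`; a MetaTensor is returned when
# the input carries metadata
# ---------------------------------------------------------------------------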
+
+
+class RandImageFilter(RandomizableTransform):
+    """
+    Randomly apply a convolutional filter to the input data.
+
+    Args:
+        filter:
+            A string specifying the filter or a custom filter as a ``torch.Tensor`` or ``np.ndarray``.
+            Available options are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``, ``median``,
+            ``gauss``, ``savitzky_golay``.
+            See :py:class:`monai.transforms.ImageFilter` for short explanations of every filter.
+        filter_size:
+            A single integer value specifying the size of the quadratic or cubic filter.
+            Computational cost scales with ``filter_size ** 2`` for 2D filters and ``filter_size ** 3`` for
+            3D filters, which should be considered when choosing the filter size.
+        prob:
+            Probability that the transform is applied to the data. Defaults to 0.1.
+        kwargs:
+            Additional arguments required by specific filters, e.g. ``sigma`` if ``filter='gauss'``.
+            See :py:class:`monai.transforms.ImageFilter` for more details.
+    """
+
+    backend = ImageFilter.backend
+
+    def __init__(
+        self, filter: Union[str, NdarrayOrTensor], filter_size: Optional[int] = None, prob: float = 0.1, **kwargs
+    ) -> None:
+        super().__init__(prob)
+        self.filter = ImageFilter(filter, filter_size, **kwargs)
+
+    def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor:
+        """
+        Args:
+            img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]
+            meta_dict: An optional dictionary with metadata
+
+        Returns:
+            A MetaTensor with the same shape as `img` and identical metadata
+        """
+        self.randomize(None)
+        if self._do_transform:
+            img = self.filter(img)
+        return img
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index d52fdbe251..c16a456f74 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -43,6 +43,7 @@
     EnsureType,
     FgBgToIndices,
     Identity,
+    ImageFilter,
     IntensityStats,
     LabelToMask,
     Lambda,
@@ -118,6 +119,7 @@
     "IntensityStatsd",
     "IntensityStatsD",
     "IntensityStatsDict",
+    "ImageFilterd",
     "LabelToMaskD",
     "LabelToMaskDict",
     "LabelToMaskd",
@@ -133,6 +135,7 @@
     "RandCuCIMd",
     "RandCuCIMD",
     "RandCuCIMDict",
+    "RandImageFilterd",
     "RandLambdaD",
     "RandLambdaDict",
     "RandLambdad",
@@ -1738,6 +1741,90 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
         return d
 
 
+class ImageFilterd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ImageFilter`.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: monai.transforms.MapTransform
+        kernel:
+            A string specifying the kernel or a custom kernel as a ``torch.Tensor`` or ``np.ndarray``.
+            Available options are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``, ``median``,
+            ``gauss``, ``savitzky_golay``.
+        kernel_size:
+            A single integer value specifying the size of the quadratic or cubic kernel.
+            Computational cost scales with ``kernel_size ** 2`` for 2D kernels and ``kernel_size ** 3`` for
+            3D kernels, which should be considered when choosing the kernel size.
+        allow_missing_keys:
+            Don't raise exception if key is missing.
+    """
+
+    backend = ImageFilter.backend
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        kernel: Union[str, NdarrayOrTensor],
+        kernel_size: Optional[int] = None,
+        allow_missing_keys: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.filter = ImageFilter(kernel, kernel_size, **kwargs)
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.filter(d[key])
+        return d
+
+
+class RandImageFilterd(MapTransform, RandomizableTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.RandImageFilter`.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: monai.transforms.MapTransform
+        kernel:
+            A string specifying the kernel or a custom kernel as a ``torch.Tensor`` or ``np.ndarray``.
+            Available options are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``, ``median``,
+            ``gauss``, ``savitzky_golay``.
+        kernel_size:
+            A single integer value specifying the size of the quadratic or cubic kernel.
+            Computational cost scales with ``kernel_size ** 2`` for 2D kernels and ``kernel_size ** 3`` for
+            3D kernels, which should be considered when choosing the kernel size.
+        prob:
+            Probability that the transform is applied to the data. Defaults to 0.1.
+        allow_missing_keys:
+            Don't raise exception if key is missing.
+ """ + + backend = ImageFilter.backend + + def __init__( + self, + keys: KeysCollection, + kernel: Union[str, NdarrayOrTensor], + kernel_size: Optional[int] = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + **kwargs, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.filter = ImageFilter(kernel, kernel_size, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if self._do_transform: + for key in self.key_iterator(d): + d[key] = self.filter(d[key]) + return d + + +RandImageFilterD = RandImageFilterDict = RandImageFilterd +ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py new file mode 100644 index 0000000000..007b5e8e2c --- /dev/null +++ b/tests/test_image_filter.py @@ -0,0 +1,230 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.simplelayers import GaussianFilter +from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd + +EXPECTED_FILTERS = { + "mean": torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).float(), + "laplace": torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float(), + "elliptical": torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]]).float(), + "sharpen": torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]).float(), +} + +SUPPORTED_FILTERS = ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] +SAMPLE_IMAGE_2D = torch.randn(1, 10, 10) +SAMPLE_IMAGE_3D = torch.randn(1, 10, 10, 10) +SAMPLE_DICT = {"image_2d": SAMPLE_IMAGE_2D, "image_3d": SAMPLE_IMAGE_3D} + +# Sobel filter uses reflect pad as default which is not implemented for 3d in torch 1.8.1 or 1.9.1 +ADDITIONAL_ARGUMENTS = {"order": 1, "sigma": 1, "padding_mode": "zeros"} + + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + 1 + + +class TestNotAModuleOrTransform: + pass + + +class TestImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string(self, filter_name): + "Test init from string" + _ = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + + def test_init_raises(self): + with self.assertRaises(Exception) as context: + _ = ImageFilter("mean") + self.assertTrue("`filter_size` must be specified when specifying filters by string." in str(context.output)) + with self.assertRaises(Exception) as context: + _ = ImageFilter("mean") + self.assertTrue("`filter_size` should be a single uneven integer." 
in str(context.exception))
+        with self.assertRaises(Exception) as context:
+            _ = ImageFilter("gauss", 3)
+        self.assertTrue("`filter='gauss'` requires the additional keyword argument `sigma`" in str(context.exception))
+        with self.assertRaises(Exception) as context:
+            _ = ImageFilter("savitzky_golay", 3)
+        self.assertTrue(
+            "`filter='savitzky_golay'` requires the additional keyword argument `order`" in str(context.exception)
+        )
+
+    def test_init_from_array(self):
+        "Test init with custom filter and assert wrong filter shape throws an error"
+        _ = ImageFilter(torch.ones(3, 3))
+        _ = ImageFilter(torch.ones(3, 3, 3))
+        _ = ImageFilter(np.ones((3, 3)))
+        _ = ImageFilter(np.ones((3, 3, 3)))
+
+        with self.assertRaises(Exception) as context:
+            _ = ImageFilter(torch.ones(3, 3, 3, 3))
+        self.assertTrue("Only 1D, 2D, and 3D filters are supported." in str(context.exception))
+
+    def test_init_from_module(self):
+        filter = ImageFilter(TestModule())
+        out = filter(torch.zeros(1, 3, 3, 3))
+        torch.testing.assert_allclose(torch.ones(1, 3, 3, 3), out)
+
+    def test_init_from_transform(self):
+        _ = ImageFilter(GaussianFilter(3, sigma=2))
+
+    def test_init_from_wrong_type_fails(self):
+        with self.assertRaises(Exception) as context:
+            _ = ImageFilter(TestNotAModuleOrTransform())
+        self.assertTrue(" is not supported." in str(context.exception))
+
+    @parameterized.expand(EXPECTED_FILTERS.keys())
+    def test_2d_filter_correctness(self, filter_name):
+        "Test correctness of filters (2d only)"
+        tfm = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)
+        filter = tfm._get_filter_from_string(filter_name, size=3, ndim=2).filter.squeeze()
+        torch.testing.assert_allclose(filter, EXPECTED_FILTERS[filter_name])
+
+    @parameterized.expand(SUPPORTED_FILTERS)
+    def test_call_2d(self, filter_name):
+        "Test function `__call__` for 2d images"
+        filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)
+        out_tensor = filter(SAMPLE_IMAGE_2D)
+        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])
+
+    @parameterized.expand(SUPPORTED_FILTERS)
+    def test_call_3d(self, filter_name):
+        "Test function `__call__` for 3d images"
+        filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)
+        out_tensor = filter(SAMPLE_IMAGE_3D)
+        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:])
+
+
+class TestImageFilterDict(unittest.TestCase):
+    @parameterized.expand(SUPPORTED_FILTERS)
+    def test_init_from_string_dict(self, filter_name):
+        "Test init from string and assert an error is thrown if no size is passed"
+        _ = ImageFilterd("image", filter_name, 3, **ADDITIONAL_ARGUMENTS)
+        with self.assertRaises(Exception) as _:
+            _ = ImageFilterd("image", filter_name)
+
+    def test_init_from_array_dict(self):
+        "Test init with custom filter and assert wrong filter shape throws an error"
+        _ = ImageFilterd("image", torch.ones(3, 3))
+        with self.assertRaises(Exception) as _:
+            _ = ImageFilterd("image", torch.ones(3, 3, 3, 3))
+
+    @parameterized.expand(SUPPORTED_FILTERS)
+    def test_call_2d(self, filter_name):
+        "Test function `__call__` for 2d images"
+        filter = ImageFilterd("image_2d", filter_name, 3, **ADDITIONAL_ARGUMENTS)
+        out_tensor = filter(SAMPLE_DICT)
+        self.assertEqual(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:])
+
+    @parameterized.expand(SUPPORTED_FILTERS)
+    def test_call_3d(self, filter_name):
+        "Test function `__call__` for 3d images"
+        filter = ImageFilterd("image_3d", filter_name, 3, **ADDITIONAL_ARGUMENTS)
+        out_tensor = filter(SAMPLE_DICT)
+        self.assertEqual(out_tensor["image_3d"].shape[1:],
SAMPLE_IMAGE_3D.shape[1:]) + + +class TestRandImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string(self, filter_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = RandImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = RandImageFilter(filter_name) + + def test_init_from_array(self): + "Test init with custom filter and assert wrong filter shape throws an error" + _ = RandImageFilter(torch.ones(3, 3)) + with self.assertRaises(Exception) as _: + _ = RandImageFilter(torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_1(self, filter_name): + "Text function `__call__` for 2d images" + filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_1(self, filter_name): + "Text function `__call__` for 3d images" + filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_3D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_0(self, filter_name): + "Text function `__call__` for 2d images" + filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_2D) + torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_2D) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_0(self, filter_name): + "Text function `__call__` for 3d images" + filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_3D) + torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_3D) + + +class TestRandImageFilterDict(unittest.TestCase): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string_dict(self, filter_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = RandImageFilterd("image", filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = RandImageFilterd("image", filter_name) + + def test_init_from_array_dict(self): + "Test init with custom filter and assert wrong filter shape throws an error" + _ = RandImageFilterd("image", torch.ones(3, 3)) + with self.assertRaises(Exception) as _: + _ = RandImageFilterd("image", torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_1(self, filter_name): + filter = RandImageFilterd("image_2d", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_1(self, filter_name): + filter = RandImageFilterd("image_3d", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_0(self, filter_name): + filter = RandImageFilterd("image_2d", filter_name, 3, 0.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + torch.testing.assert_allclose(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_0(self, filter_name): + filter = RandImageFilterd("image_3d", filter_name, 3, 0.0, 
**ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + torch.testing.assert_allclose(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_preset_filters.py b/tests/test_preset_filters.py new file mode 100644 index 0000000000..0a6a4e8c50 --- /dev/null +++ b/tests/test_preset_filters.py @@ -0,0 +1,130 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import ApplyFilter, EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter + +TEST_CASES_MEAN = [(3, 3, torch.ones(3, 3, 3)), (2, 5, torch.ones(5, 5))] + + +TEST_CASES_LAPLACE = [ + ( + 3, + 3, + torch.Tensor( + [ + [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]], + [[-1, -1, -1], [-1, 26, -1], [-1, -1, -1]], + [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]], + ] + ), + ), + (2, 3, torch.Tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])), +] + +TEST_CASES_ELLIPTICAL = [ + ( + 3, + 3, + torch.Tensor( + [[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, 1, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]] + ), + ), + (2, 3, torch.Tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])), +] + + +TEST_CASES_SHARPEN = [ + ( + 3, + 3, + torch.Tensor( + [ + [[0, 0, 0], [0, -1, 0], [0, 0, 0]], + [[0, -1, 0], [-1, 7, -1], [0, -1, 0]], + [[0, 0, 0], [0, -1, 0], [0, 0, 0]], + ] + ), + ), + (2, 3, torch.Tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])), +] + + +class _TestFilter: + def test_init(self, spatial_dims, size, expected): + test_filter = self.filter_class(spatial_dims=spatial_dims, size=size) + torch.testing.assert_allclose(expected, test_filter.filter) + self.assertIsInstance(test_filter, torch.nn.Module) + + def test_forward(self): + test_filter = self.filter_class(spatial_dims=2, size=3) + input = torch.ones(1, 1, 5, 5) + _ = test_filter(input) + + +class TestApplyFilter(unittest.TestCase): + def test_init_and_forward_2d(self): + filter_2d = torch.ones(3, 3) + image_2d = torch.ones(1, 3, 3) + apply_filter_2d = ApplyFilter(filter_2d) + out = apply_filter_2d(image_2d) + self.assertEqual(out.shape, image_2d.shape) + + def test_init_and_forward_3d(self): + filter_2d = torch.ones(3, 3, 3) + image_2d = torch.ones(1, 3, 3, 3) + apply_filter_2d = ApplyFilter(filter_2d) + out = apply_filter_2d(image_2d) + self.assertEqual(out.shape, image_2d.shape) + + +class MeanFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = MeanFilter + + @parameterized.expand(TEST_CASES_MEAN) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class LaplaceFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = LaplaceFilter + + @parameterized.expand(TEST_CASES_LAPLACE) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class EllipticalTestCase(_TestFilter, unittest.TestCase): + def 
setUp(self) -> None: + self.filter_class = EllipticalFilter + + @parameterized.expand(TEST_CASES_ELLIPTICAL) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class SharpenTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = SharpenFilter + + @parameterized.expand(TEST_CASES_SHARPEN) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +if __name__ == "__main__": + unittest.main()
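
Editor's note: to complement the tests above, a short end-to-end sketch of how the new dictionary transforms and preset filter layers added in this patch might be combined (illustrative only; names follow the code in this diff, behaviour is assumed rather than verified here)::

    import torch

    from monai.networks.layers import LaplaceFilter
    from monai.transforms import Compose, ImageFilterd, RandImageFilterd

    data = {"image": torch.rand(1, 32, 32), "label": torch.rand(1, 32, 32)}

    pipeline = Compose(
        [
            # deterministically smooth the image with a 3x3 mean kernel
            ImageFilterd(keys="image", kernel="mean", kernel_size=3),
            # with probability 0.5, turn the label into an outline via a laplace kernel
            RandImageFilterd(keys="label", kernel="laplace", kernel_size=3, prob=0.5),
        ]
    )
    out = pipeline(data)  # each value keeps its input shape

    # the preset layers are plain torch modules and expect (B, C, spatial...) input
    laplace = LaplaceFilter(spatial_dims=2, size=3)
    outline = laplace(torch.rand(1, 1, 32, 32))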