From 1223a8771f94177ff754eba28f8adff3a004e202 Mon Sep 17 00:00:00 2001 From: kbressem Date: Wed, 12 Oct 2022 20:17:22 +0200 Subject: [PATCH 01/18] Add generic kernel transform with support for multiple kernels Signed-off-by: kbressem --- monai/transforms/__init__.py | 8 + monai/transforms/utility/array.py | 248 +++++++++++++++++++++++++ monai/transforms/utility/dictionary.py | 85 +++++++++ 3 files changed, 341 insertions(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 713f848f86..afe9678c4c 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -437,10 +437,12 @@ FgBgToIndices, Identity, IntensityStats, + KernelTransform, LabelToMask, Lambda, MapLabelValue, RandCuCIM, + RandKernelTransform, RandLambda, RemoveRepeatedChannel, RepeatChannel, @@ -511,6 +513,9 @@ IntensityStatsd, IntensityStatsD, IntensityStatsDict, + KernelTransformd, + KernelTransformD, + KernelTransformDict, LabelToMaskd, LabelToMaskD, LabelToMaskDict, @@ -523,6 +528,9 @@ RandCuCIMd, RandCuCIMD, RandCuCIMDict, + RandKernelTransformd, + RandKernelTransformD, + RandKernelTransformDict, RandLambdad, RandLambdaD, RandLambdaDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 8d9eb374db..ebea0ec09e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -28,6 +28,7 @@ from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import no_collation +from monai.networks.layers import apply_filter from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( @@ -93,6 +94,8 @@ "CuCIM", "RandCuCIM", "ToCupy", + "KernelTransform", + "RandKernelTransform", ] @@ -1410,3 +1413,248 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: coord_channels, *_ = convert_to_dst_type(coord_channels, img) # type: ignore coord_channels = coord_channels[list(self.spatial_dims)] return concatenate((img, coord_channels), axis=0) + + +class KernelTransform(Transform): + """ + Applies a kernel transformation to the input image. + + Args: + kernel: + A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. + Available options are: `mean`, `laplacian`, `elliptical`, `gaussian`` + See below for short explanations on every kernel. + kernel_size: + A single integer value specifying the size of the quadratic or cubic kernel. + Computational complexity scales to the power of 2 (2D kernel) or 3 (3D kernel), which + should be considered when choosing kernel size. + + Raises: + AssertionError: When `kernel` is a string and `kernel_size` is not specified + AssertionError: When `kernel_size` is not an uneven integer + AssertionError: When `kernel` is an array and `ndim` is not in [1,2,3] + AssertionError: When `kernel` is an array and any dimension has an even shape + NotImplementedError: When `kernel` is a string and not in `self.supported_kernels` + + + ## Mean kernel + > `kernel='mean'` + + Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image. + Example 2D kernel (5 x 5): + + [1, 1, 1, 1, 1] + [1, 1, 1, 1, 1] + [1, 1, 1, 1, 1] + [1, 1, 1, 1, 1] + [1, 1, 1, 1, 1] + + If smoothing labels with this kernel, ensure they are in one-hot format. + + ## Laplacian kernel + > `kernel='laplacian'` + + Laplacian filtering for outline detection in images. Can be used to transform labels to contours. 
+ Example 2D kernel (5x5): + + [-1., -1., -1., -1., -1.] + [-1., -1., -1., -1., -1.] + [-1., -1., 24., -1., -1.] + [-1., -1., -1., -1., -1.] + [-1., -1., -1., -1., -1.] + + ## Elliptical kernel + > `kernel='elliptical'` + + An elliptical kernel can be used to dilate labels or label-contours. + Example 2D kernel (5x5): + + [0., 0., 1., 0., 0.] + [1., 1., 1., 1., 1.] + [1., 1., 1., 1., 1.] + [1., 1., 1., 1., 1.] + [0., 0., 1., 0., 0.] + + ## Sobel kernel + > `kernel={'sobel_h', 'sobel_w', ''sobel_d}` + + Edge detection with sobel kernel, along the h,w or d axis of tensor. + Example 2D kernel (5x5) for `sobel_w`: + + [-0.25, -0.20, 0.00, 0.20, 0.25] + [-0.40, -0.50, 0.00, 0.50, 0.40] + [-0.50, -1.00, 0.00, 1.00, 0.50] + [-0.40, -0.50, 0.00, 0.50, 0.40] + [-0.25, -0.20, 0.00, 0.20, 0.25] + + ## Sharpen kernel + > `kernel="sharpen"` + + Sharpen an image with a 2D or 3D kernel. + Example 2D kernel (5x5): + + [ 0., 0., -1., 0., 0.] + [-1., -1., -1., -1., -1.] + [-1., -1., 17., -1., -1.] + [-1., -1., -1., -1., -1.] + [ 0., 0., -1., 0., 0.] + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + supported_kernels = sorted(["mean", "laplacian", "elliptical", "sobel_w", "sobel_h", "sobel_d", "sharpen"]) + + def __init__(self, kernel: Union[str, NdarrayOrTensor], kernel_size: Optional[int] = None) -> None: + + if isinstance(kernel, str): + assert kernel_size, "`kernel_size` must be specified when specifying kernels by string." + assert kernel_size % 2 == 1, "`kernel_size` should be a single uneven integer." + if kernel not in self.supported_kernels: + raise NotImplementedError(f"{kernel}. Supported kernels are {self.supported_kernels}.") + else: + assert kernel.ndim in [1, 2, 3], "Only 1D, 2D, and 3D kernels are supported" + kernel = convert_to_tensor(kernel, dtype=torch.float32) + self._assert_all_values_uneven(kernel.shape) + + self.kernel = kernel + self.kernel_size = kernel_size + + def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: + """ + Args: + img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] + meta_dict: An optional dictionary with metadata + + Returns: + A MetaTensor with the same shape as `img` and identical metadata + """ + if isinstance(img, MetaTensor): + meta_dict = img.meta + img_, prev_type, device = convert_data_type(img, torch.Tensor) + ndim = img_.ndim - 1 # assumes channel first format + if isinstance(self.kernel, str): + self.kernel = self._create_kernel_from_string(self.kernel, self.kernel_size, ndim) + img_ = img_.unsqueeze(0) + img_ = apply_filter(img_, self.kernel) # batch, channels, H[, W, D] is required for img_ + img_ = img_[0] + if meta_dict: + img_ = MetaTensor(img_, meta_dict) + else: + img_, *_ = convert_data_type(img_, prev_type, device) + return img_ + + def _assert_all_values_uneven(self, x: tuple) -> None: + for value in x: + assert value % 2 == 1, f"Only uneven kernels are supported, but kernel size is {x}" + + def _create_kernel_from_string(self, name, size, ndim) -> torch.Tensor: + """Create an `ndim` kernel of size `(size, ) * ndim`.""" + func = getattr(self, f"_create_{name}_kernel") + kernel = func(size, ndim) + return kernel.to(torch.float32) + + def _create_mean_kernel(self, size, ndim) -> torch.Tensor: + """Create a torch.Tensor with shape (size, ) * ndim with all values equal to `1`""" + return torch.ones([size] * ndim) + + def _create_laplacian_kernel(self, size, ndim) -> torch.Tensor: + """Create a torch.Tensor with shape (size, ) * ndim. 
+ All values are `-1` except the center value which is size**ndim - 1 + """ + kernel = torch.zeros([size] * ndim).float() - 1 # make all -1 + center_point = tuple([size // 2] * ndim) + kernel[center_point] = (size**ndim) - 1 + return kernel + + def _create_elliptical_kernel(self, size: int, ndim: int) -> torch.Tensor: + """Create a torch.Tensor with shape (size, ) * ndim containing a circle/sphere of `1`""" + radius = size // 2 + grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(ndim)]) + squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0) + kernel = squared_distances <= radius**2 + return kernel + + def _sobel_2d(self, size) -> torch.Tensor: + """Create a generic 2d sobel kernel""" + numerator = torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32).unsqueeze(0) + denominator = numerator * numerator + denominator = denominator + denominator.T + denominator[:, size // 2] = 1.0 # to avoid division by zero + return numerator / denominator + + def _sobel_3d(self, size) -> torch.Tensor: + """Create a generic 3d sobel kernel""" + kernel_2d = self._sobel_2d(size) + kernel_3d = torch.stack((kernel_2d,) * size, -1) + adapter = (size // 2) - torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32).abs() + adapter = adapter / adapter.max() + 1 # scale between 1 - 2 + return kernel_3d * adapter + + def _create_sobel_w_kernel(self, size, ndim) -> torch.Tensor: + """Sobel kernel in x/w direction for Tensor in shape (B,C)[WH[D]]""" + if ndim == 2: + kernel = self._sobel_2d(size) + elif ndim == 3: + kernel = self._sobel_3d(size) + else: + raise ValueError(f"Only 2 or 3 dimensional kernels are supported. Got {ndim}") + return kernel + + def _create_sobel_h_kernel(self, size, ndim) -> torch.Tensor: + """Sobel kernel in y/h direction for Tensor in shape (B,C)[WH[D]]""" + kernel = self._create_sobel_w_kernel(size, ndim).transpose(0, 1) + return kernel + + def _create_sobel_d_kernel(self, size, ndim) -> torch.Tensor: + """Sobel kernel in z/d direction for Tensor in shape (B,C)[WHD]]""" + assert ndim == 3, "Only 3 dimensional kernels are supported for `sobel_d`" + return self._sobel_3d(size).transpose(1, 2) + + def _create_sharpen_kernel(self, size, ndim) -> torch.Tensor: + """Create a torch.Tensor with shape (size, ) * ndim. + The kernel contains a circle/sphere of `-1`, with the center value beeing + the absolut sum of all non-zero elements in the kernel + """ + kernel = self._create_elliptical_kernel(size, ndim) + center_point = tuple([size // 2] * ndim) + center_value = kernel.sum() + kernel = kernel * -1 + kernel[center_point] = center_value + return kernel + + +class RandKernelTransform(RandomizableTransform): + """Randomly apply a Filterkernel to the input data. + Args: + kernel: + A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. + Available options are: `mean`, `laplacian`, `elliptical`, `gaussian`` + See below for short explanations on every kernel. + kernel_size: + A single integer value specifying the size of the quadratic or cubic kernel. + Computational complexity scales to the power of 2 (2D kernel) or 3 (3D kernel), which + should be considered when choosing kernel size. 
+ prob: + Probability the transform is applied to the data + """ + + backend = KernelTransform.backend + + def __init__( + self, kernel: Union[str, NdarrayOrTensor], kernel_size: Optional[int] = None, prob: float = 0.1 + ) -> None: + super().__init__(prob) + self.filter = KernelTransform(kernel, kernel_size) + + def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: + """ + Args: + img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] + meta_dict: An optional dictionary with metadata + + Returns: + A MetaTensor with the same shape as `img` and identical metadata + """ + self.randomize(None) + if self._do_transform: + img = self.filter(img) + return img diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 073e50a3be..92c1aebbf4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -44,6 +44,7 @@ FgBgToIndices, Identity, IntensityStats, + KernelTransform, LabelToMask, Lambda, MapLabelValue, @@ -118,6 +119,7 @@ "IntensityStatsd", "IntensityStatsD", "IntensityStatsDict", + "KernelTransformd", "LabelToMaskD", "LabelToMaskDict", "LabelToMaskd", @@ -130,6 +132,7 @@ "RandCuCIMd", "RandCuCIMD", "RandCuCIMDict", + "RandKernelTransformd", "RandLambdaD", "RandLambdaDict", "RandLambdad", @@ -1657,6 +1660,88 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +class KernelTransformd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.KernelTransform`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + kernel: + A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. + Available options are: `mean`, `laplacian`, `elliptical`, `sobel_{w,h,d}`` + kernel_size: + A single integer value specifying the size of the quadratic or cubic kernel. + Computational complexity increases exponentially with kernel_size, which + should be considered when choosing the kernel size. + allow_missing_keys: + Don't raise exception if key is missing. + """ + + backend = KernelTransform.backend + + def __init__( + self, + keys: KeysCollection, + kernel: Union[str, NdarrayOrTensor], + kernel_size: Optional[int] = None, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.filter = KernelTransform(kernel, kernel_size) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.filter(d[key]) + return d + + +class RandKernelTransformd(MapTransform, RandomizableTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandomFilterKernel`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + kernel: + A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. + Available options are: `mean`, `laplacian`, `elliptical`, `sobel_{w,h,d}`` + kernel_size: + A single integer value specifying the size of the quadratic or cubic kernel. + Computational complexity increases exponentially with kernel_size, which + should be considered when choosing the kernel size. + prob: + Probability the transform is applied to the data + allow_missing_keys: + Don't raise exception if key is missing. 
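
    A minimal usage sketch (illustrative only; the key ``"label"``, the data shape, and the
    chosen kernel are arbitrary examples)::

        import torch
        data = {"label": torch.rand(1, 16, 16)}  # channel-first tensor under an example key
        contour = RandKernelTransformd(keys="label", kernel="laplacian", kernel_size=3, prob=1.0)
        out = contour(data)  # with prob=1.0 the laplacian kernel is always applied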
+ """ + + backend = KernelTransform.backend + + def __init__( + self, + keys: KeysCollection, + kernel: Union[str, NdarrayOrTensor], + kernel_size: Optional[int] = None, + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.filter = KernelTransform(kernel, kernel_size) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if self._do_transform: + for key in self.key_iterator(d): + d[key] = self.filter(d[key]) + return d + + +RandKernelTransformD = RandKernelTransformDict = RandKernelTransformd +KernelTransformD = KernelTransformDict = KernelTransformd IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd From 5ba3e10ba5000795d2acbc9bfa56d97bf55cf561 Mon Sep 17 00:00:00 2001 From: kbressem Date: Mon, 14 Nov 2022 20:50:06 +0100 Subject: [PATCH 02/18] Rewrite ImageFilter - rename class to `ImageFilter` - `ImageFilter` now accepts strings (preset filters), `torch.Tensor` or `np.ndarray`, `nn.Module` and `monai.transforms.Transform` - The filter is created just once, to make `__call__` faster - Additional `nn.Modules` for simple filters have been created to make them easier available - add tests for `ImageFilter` and `MeanFilter` - Extend tests for ImageFilter - Maybe extend ImageFilter to accept multiple arguments? - Add Tests for other nn.Modules Signed-off-by: kbressem --- monai/networks/layers/simplelayers.py | 234 +++++++++++- monai/transforms/__init__.py | 60 ++- monai/transforms/utility/array.py | 492 ++++++++++++++----------- monai/transforms/utility/dictionary.py | 81 ++-- tests/test_apply_filter.py | 1 - tests/test_image_filter.py | 241 ++++++++++++ tests/test_mean_filter.py | 40 ++ 7 files changed, 897 insertions(+), 252 deletions(-) create mode 100644 tests/test_image_filter.py create mode 100644 tests/test_mean_filter.py diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 3de4e75766..00cfea34d3 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -11,17 +11,26 @@ import math from copy import deepcopy -from typing import List, Sequence, Union +from typing import List, Optional, Sequence, Union import torch import torch.nn.functional as F from torch import nn from torch.autograd import Function +from monai.config.type_definitions import NdarrayOrTensor from monai.networks.layers.convutils import gaussian_1d from monai.networks.layers.factories import Conv -from monai.utils import ChannelMatching, SkipMode, look_up_option, optional_import, pytorch_after -from monai.utils.misc import issequenceiterable +from monai.utils import ( + ChannelMatching, + SkipMode, + convert_to_tensor, + ensure_tuple_rep, + issequenceiterable, + look_up_option, + optional_import, + pytorch_after, +) _C, _ = optional_import("monai._C") fft, _ = optional_import("torch.fft") @@ -32,10 +41,12 @@ "GaussianFilter", "HilbertTransform", "LLTM", + "MedianFilter", "Reshape", "SavitzkyGolayFilter", "SkipConnection", "apply_filter", + "median_filter", "separable_filtering", ] @@ -168,7 +179,6 @@ def _separable_filtering_conv( paddings: List[int], num_channels: int, ) -> torch.Tensor: - if d < 0: return input_ @@ -290,6 +300,9 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso else: # even-sized kernels 
are not supported kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]] + elif kwargs["padding"] == "same" and not pytorch_after(1, 10): + # even-sized kernels are not supported + kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]] if "stride" not in kwargs: kwargs["stride"] = 1 @@ -363,7 +376,11 @@ def _make_coeffs(window_length, order): a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1) y = torch.zeros(order + 1, dtype=torch.float, device="cpu") y[0] = 1.0 - return torch.lstsq(y, a).solution.squeeze() + return ( + torch.lstsq(y, a).solution.squeeze() + if not pytorch_after(1, 11) + else torch.linalg.lstsq(a, y).solution.squeeze() + ) class HilbertTransform(nn.Module): @@ -427,6 +444,118 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype) +def get_binary_kernel(window_size: Sequence[int], dtype=torch.float, device=None) -> torch.Tensor: + """ + Create a binary kernel to extract the patches. + The window size HxWxD will create a (H*W*D)xHxWxD kernel. + """ + win_size = convert_to_tensor(window_size, int, wrap_sequence=True) + prod = torch.prod(win_size) + s = [prod, 1, *win_size] + return torch.diag(torch.ones(prod, dtype=dtype, device=device)).view(s) # type: ignore + + +def median_filter( + in_tensor: torch.Tensor, + kernel_size: Sequence[int] = (3, 3, 3), + spatial_dims: int = 3, + kernel: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + """ + Apply median filter to an image. + + Args: + in_tensor: input tensor; median filtering will be applied to the last `spatial_dims` dimensions. + kernel_size: the convolution kernel size. + spatial_dims: number of spatial dimensions to apply median filtering. + kernel: an optional customized kernel. + kwargs: additional parameters to the `conv`. + + Returns: + the filtered input tensor, shape remains the same as ``in_tensor`` + + Example:: + + >>> from monai.networks.layers import median_filter + >>> import torch + >>> x = torch.rand(4, 5, 7, 6) + >>> output = median_filter(x, (3, 3, 3)) + >>> output.shape + torch.Size([4, 5, 7, 6]) + + """ + if not isinstance(in_tensor, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(in_tensor)}") + + original_shape = in_tensor.shape + oshape, sshape = original_shape[: len(original_shape) - spatial_dims], original_shape[-spatial_dims:] + oprod = torch.prod(convert_to_tensor(oshape, int, wrap_sequence=True)) + # prepare kernel + if kernel is None: + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + kernel = get_binary_kernel(kernel_size, in_tensor.dtype, in_tensor.device) + else: + kernel = kernel.to(in_tensor) + # map the local window to single vector + conv = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] + reshaped_input: torch.Tensor = in_tensor.reshape(oprod, 1, *sshape) # type: ignore + + # even-sized kernels are not supported + padding = [(k - 1) // 2 for k in reversed(kernel.shape[2:]) for _ in range(2)] + padded_input: torch.Tensor = F.pad(reshaped_input, pad=padding, mode="replicate") + features: torch.Tensor = conv(padded_input, kernel, padding=0, stride=1, **kwargs) + + features = features.view(oprod, -1, *sshape) # type: ignore + + # compute the median along the feature axis + median: torch.Tensor = torch.median(features, dim=1)[0] + median = median.reshape(original_shape) + + return median + + +class MedianFilter(nn.Module): + """ + Apply median filter to an image. 
+ + Args: + radius: the blurring kernel radius (radius of 1 corresponds to 3x3x3 kernel when spatial_dims=3). + + Returns: + filtered input tensor. + + Example:: + + >>> from monai.networks.layers import MedianFilter + >>> import torch + >>> in_tensor = torch.rand(4, 5, 7, 6) + >>> blur = MedianFilter([1, 1, 1]) # 3x3x3 kernel + >>> output = blur(in_tensor) + >>> output.shape + torch.Size([4, 5, 7, 6]) + + """ + + def __init__(self, radius: Union[Sequence[int], int], spatial_dims: int = 3, device="cpu") -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.radius: Sequence[int] = ensure_tuple_rep(radius, spatial_dims) + self.window: Sequence[int] = [1 + 2 * deepcopy(r) for r in self.radius] + self.kernel = get_binary_kernel(self.window, device=device) + + def forward(self, in_tensor: torch.Tensor, number_of_passes=1) -> torch.Tensor: + """ + Args: + in_tensor: input tensor, median filtering will be applied to the last `spatial_dims` dimensions. + number_of_passes: median filtering will be repeated this many times + """ + x = in_tensor + for _ in range(number_of_passes): + x = median_filter(x, kernel=self.kernel, spatial_dims=self.spatial_dims) + return x + + class GaussianFilter(nn.Module): def __init__( self, @@ -530,3 +659,98 @@ def reset_parameters(self): def forward(self, input, state): return LLTMFunction.apply(input, self.weights, self.bias, *state) + + +class ApplyFilter(nn.Module): + "Apply a convolutional filter to an image" + + def __init__(self, filter: NdarrayOrTensor) -> None: + super().__init__() + + filter = convert_to_tensor(filter, dtype=torch.float32) + self.filter = filter + + def forward(self, x: torch.Tensor): + """ + Args: + x: in shape B, C, [H, W, [D]] + """ + return apply_filter(x, self.filter) + + +class MeanFilter(ApplyFilter): + """ + Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image. + The mean filter used, is a `torch.Tensor` of all ones. + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + filter = torch.ones([size] * spatial_dims) + super().__init__(filter=filter) + + +class LaplaceFilter(ApplyFilter): + """ + Laplacian filtering for outline detection in images. Can be used to transform labels to contours. + The laplace filter used, is a `torch.Tensor` where all values are -1, except the center value + which is `size` ** `spatial_dims` + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + filter = torch.zeros([size] * spatial_dims).float() - 1 # make all -1 + center_point = tuple([size // 2] * spatial_dims) + filter[center_point] = (size**spatial_dims) - 1 + super().__init__(filter=filter) + + +class EllipticalFilter(ApplyFilter): + """ + Elliptical filter, can be used to dilate labels or label-contours. 
+ The ellipical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1` + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + radius = size // 2 + grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(spatial_dims)]) + squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0) + filter = squared_distances <= radius**2 + super().__init__(filter=filter) + + +class SharpenFilter(ApplyFilter): + """ + Convolutional filter to sharpen a 2D or 3D image. + The filter used contains a circle/sphere of `-1`, with the center value beeing + the absolut sum of all non-zero elements in the kernel + """ + + def __init__(self, spatial_dims: int, size: int) -> None: + """ + Args: + spatial_dims: `int` of either 2 for 2D images and 3 for 3D images + size: edge length of the filter + """ + radius = size // 2 + grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(spatial_dims)]) + squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0) + filter = squared_distances <= radius**2 + center_point = tuple([size // 2] * spatial_dims) + center_value = filter.sum() + filter = filter * -1 + filter[center_point] = center_value + super().__init__(filter=filter) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index afe9678c4c..9892e71455 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -89,6 +89,7 @@ ) from .intensity.array import ( AdjustContrast, + ComputeHoVerMaps, DetectEnvelope, ForegroundMask, GaussianSharpen, @@ -98,6 +99,7 @@ IntensityRemap, KSpaceSpikeNoise, MaskIntensity, + MedianSmooth, NormalizeIntensity, RandAdjustContrast, RandBiasField, @@ -127,6 +129,9 @@ AdjustContrastd, AdjustContrastD, AdjustContrastDict, + ComputeHoVerMapsd, + ComputeHoVerMapsD, + ComputeHoVerMapsDict, ForegroundMaskd, ForegroundMaskD, ForegroundMaskDict, @@ -148,6 +153,9 @@ MaskIntensityd, MaskIntensityD, MaskIntensityDict, + MedianSmoothd, + MedianSmoothD, + MedianSmoothDict, NormalizeIntensityd, NormalizeIntensityD, NormalizeIntensityDict, @@ -263,6 +271,8 @@ LabelToContour, MeanEnsemble, ProbNMS, + RemoveSmallObjects, + SobelGradients, VoteEnsemble, ) from .post.dictionary import ( @@ -296,13 +306,32 @@ ProbNMSD, ProbNMSd, ProbNMSDict, + RemoveSmallObjectsD, + RemoveSmallObjectsd, + RemoveSmallObjectsDict, SaveClassificationD, SaveClassificationd, SaveClassificationDict, + SobelGradientsd, + SobelGradientsD, + SobelGradientsDict, VoteEnsembleD, VoteEnsembled, VoteEnsembleDict, ) +from .signal.array import ( + SignalContinuousWavelet, + SignalFillEmpty, + SignalRandAddGaussianNoise, + SignalRandAddSine, + SignalRandAddSinePartial, + SignalRandAddSquarePulse, + SignalRandAddSquarePulsePartial, + SignalRandDrop, + SignalRandScale, + SignalRandShift, + SignalRemoveFrequency, +) from .smooth_field.array import ( RandSmoothDeform, RandSmoothFieldAdjustContrast, @@ -420,7 +449,18 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + LazyTrait, + LazyTransform, + MapTransform, + MultiSampleTrait, + Randomizable, + RandomizableTrait, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddCoordinateChannels, @@ -436,13 +476,13 @@ EnsureType, FgBgToIndices, Identity, + 
ImageFilter, IntensityStats, - KernelTransform, LabelToMask, Lambda, MapLabelValue, RandCuCIM, - RandKernelTransform, + RandImageFilter, RandLambda, RemoveRepeatedChannel, RepeatChannel, @@ -513,9 +553,9 @@ IntensityStatsd, IntensityStatsD, IntensityStatsDict, - KernelTransformd, - KernelTransformD, - KernelTransformDict, + ImageFilterd, + ImageFilterD, + ImageFilterDict, LabelToMaskd, LabelToMaskD, LabelToMaskDict, @@ -528,9 +568,9 @@ RandCuCIMd, RandCuCIMD, RandCuCIMDict, - RandKernelTransformd, - RandKernelTransformD, - RandKernelTransformDict, + RandImageFilterd, + RandImageFilterD, + RandImageFilterDict, RandLambdad, RandLambdaD, RandLambdaDict, @@ -613,9 +653,11 @@ map_spatial_axes, print_transform_backends, rand_choice, + remove_small_objects, rescale_array, rescale_array_int_max, rescale_instance_array, + reset_ops_id, resize_center, sync_meta_info, weighted_patch_samples, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ebea0ec09e..3a297f8ca5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -13,6 +13,7 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +from functools import partial import logging import sys import time @@ -22,13 +23,23 @@ import numpy as np import torch +import torch.nn as nn from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import no_collation -from monai.networks.layers import apply_filter +from monai.networks.layers.simplelayers import ( + ApplyFilter, + EllipticalFilter, + GaussianFilter, + LaplaceFilter, + MeanFilter, + SavitzkyGolayFilter, + SharpenFilter, + median_filter, +) from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( @@ -59,7 +70,6 @@ pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") cp, has_cp = optional_import("cupy") - __all__ = [ "Identity", "AsChannelFirst", @@ -94,8 +104,8 @@ "CuCIM", "RandCuCIM", "ToCupy", - "KernelTransform", - "RandKernelTransform", + "ImageFilter", + "RandImageFilter", ] @@ -204,45 +214,64 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: class EnsureChannelFirst(Transform): """ - Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape. - It extracts the `original_channel_dim` info from provided meta_data dictionary. - Typical values of `original_channel_dim` can be: "no_channel", 0, -1. - Convert the data to `channel_first` based on the `original_channel_dim` information. + Adjust or add the channel dimension of input data to ensure `channel_first` shape. + + This extracts the `original_channel_dim` info from provided meta_data dictionary or MetaTensor input. This value + should state which dimension is the channel dimension so that it can be moved forward, or contain "no_channel" to + state no dimension is the channel and so a 1-size first dimension is to be added. + + Args: + strict_check: whether to raise an error when the meta information is insufficient. + channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array. + It overrides the `original_channel_dim` from provided MetaTensor input. + If the input array doesn't have a channel dim, this value should be ``'no_channel'``. 
+ If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, strict_check: bool = True): - """ - Args: - strict_check: whether to raise an error when the meta information is insufficient. - """ + def __init__(self, strict_check: bool = True, channel_dim: Union[None, str, int] = None): self.strict_check = strict_check - self.add_channel = AddChannel() + self.input_channel_dim = channel_dim def __call__(self, img: torch.Tensor, meta_dict: Optional[Mapping] = None) -> torch.Tensor: """ Apply the transform to `img`. """ if not isinstance(img, MetaTensor) and not isinstance(meta_dict, Mapping): - msg = "metadata not available, EnsureChannelFirst is not in use." - if self.strict_check: - raise ValueError(msg) - warnings.warn(msg) - return img + if self.input_channel_dim is None: + msg = "Metadata not available and channel_dim=None, EnsureChannelFirst is not in use." + if self.strict_check: + raise ValueError(msg) + warnings.warn(msg) + return img + else: + img = MetaTensor(img) + if isinstance(img, MetaTensor): meta_dict = img.meta - channel_dim = meta_dict.get("original_channel_dim") # type: ignore + + channel_dim = meta_dict.get("original_channel_dim", None) if isinstance(meta_dict, Mapping) else None + if self.input_channel_dim is not None: + channel_dim = self.input_channel_dim if channel_dim is None: - msg = "Unknown original_channel_dim in the meta_dict, EnsureChannelFirst is not in use." + msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`." if self.strict_check: raise ValueError(msg) warnings.warn(msg) return img + + # track the original channel dim + if isinstance(meta_dict, dict): + meta_dict["original_channel_dim"] = channel_dim + if channel_dim == "no_channel": - return self.add_channel(img) # type: ignore - return AsChannelFirst(channel_dim=channel_dim)(img) # type: ignore + result = img[None] + else: + result = moveaxis(img, channel_dim, 0) # type: ignore + + return convert_to_tensor(result, track_meta=get_track_meta()) # type: ignore class RepeatChannel(Transform): @@ -255,7 +284,7 @@ class RepeatChannel(Transform): repeats: the number of repetitions for each element. """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__(self, repeats: int) -> None: if repeats <= 0: @@ -405,19 +434,19 @@ class ToTensor(Transform): device: target device to put the converted Tensor data. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. - track_meta: whether to convert to `MetaTensor`, default to `False`, output type will be `torch.Tensor`. - if `None`, use the return value of ``get_track_meta``. + track_meta: whether to convert to `MetaTensor` or regular tensor, default to `None`, + use the return value of ``get_track_meta``. """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__( self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = True, - track_meta: Optional[bool] = False, + track_meta: Optional[bool] = None, ) -> None: super().__init__() self.dtype = dtype @@ -449,8 +478,8 @@ class EnsureType(Transform): device: for Tensor data type, specify the target device. 
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. - track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor". - If False, the output data type will be `torch.Tensor`. Default to the return value of ``get_track_meta``. + track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``, + if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`. """ @@ -505,7 +534,7 @@ class ToNumpy(Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.NUMPY] def __init__(self, dtype: DtypeLike = None, wrap_sequence: bool = True) -> None: super().__init__() @@ -532,7 +561,7 @@ class ToCupy(Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.CUPY] def __init__(self, dtype: Optional[np.dtype] = None, wrap_sequence: bool = True) -> None: super().__init__() @@ -551,7 +580,7 @@ class ToPIL(Transform): Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.NUMPY] def __call__(self, img): """ @@ -569,7 +598,7 @@ class Transpose(Transform): Transposes the input image based on the given `indices` dimension ordering. """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__(self, indices: Optional[Sequence[int]]) -> None: self.indices = None if indices is None else tuple(indices) @@ -589,11 +618,12 @@ class SqueezeDim(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, dim: Optional[int] = 0) -> None: + def __init__(self, dim: Optional[int] = 0, update_meta=True) -> None: """ Args: dim: dimension to be squeezed. Default = 0 "None" works when the input is numpy array. + update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``. Raises: TypeError: When ``dim`` is not an ``Optional[int]``. 
@@ -602,6 +632,7 @@ def __init__(self, dim: Optional[int] = 0) -> None: if dim is not None and not isinstance(dim, int): raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.") self.dim = dim + self.update_meta = update_meta def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ @@ -610,11 +641,25 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) if self.dim is None: + if self.update_meta: + warnings.warn("update_meta=True is ignored when dim=None.") return img.squeeze() + dim = (self.dim + len(img.shape)) if self.dim < 0 else self.dim # for pytorch/numpy unification - if img.shape[self.dim] != 1: - raise ValueError(f"Can only squeeze singleton dimension, got shape {img.shape}.") - return img.squeeze(self.dim) + if img.shape[dim] != 1: + raise ValueError(f"Can only squeeze singleton dimension, got shape {img.shape[dim]} of {img.shape}.") + img = img.squeeze(dim) + if self.update_meta and isinstance(img, MetaTensor) and dim > 0 and len(img.affine.shape) == 2: + h, w = img.affine.shape + affine, device = img.affine, img.affine.device if isinstance(img.affine, torch.Tensor) else None + if h > dim: + affine = affine[torch.arange(0, h, device=device) != dim - 1] + if w > dim: + affine = affine[:, torch.arange(0, w, device=device) != dim - 1] + if (affine.shape[0] == affine.shape[1]) and not np.linalg.det(convert_to_numpy(affine, wrap_sequence=True)): + warnings.warn(f"After SqueezeDim, img.affine is ill-posed: \n{img.affine}.") + img.affine = affine + return img class DataStats(Transform): @@ -673,8 +718,7 @@ def __init__( if logging.root.getEffectiveLevel() > logging.INFO: # Avoid duplicate stream handlers to be added when multiple DataStats are used in a chain. has_console_handler = any( - hasattr(h, "is_data_stats_handler") and h.is_data_stats_handler # type:ignore[attr-defined] - for h in _logger.handlers + hasattr(h, "is_data_stats_handler") and h.is_data_stats_handler for h in _logger.handlers ) if not has_console_handler: # if the root log level is higher than INFO, set a separate stream handler to record @@ -1021,7 +1065,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ Convert labels to multi channels based on brats18 classes: label 1 is the necrotic and non-enhancing tumor core - label 2 is the the peritumoral edema + label 2 is the peritumoral edema label 4 is the GD-enhancing tumor The possible classes are TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor). @@ -1060,7 +1104,7 @@ class AddExtremePointsChannel(Randomizable, Transform): ValueError: When label image is not single channel. """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.TORCH] def __init__(self, background: int = 0, pert: float = 0.0) -> None: self._background = background @@ -1392,7 +1436,7 @@ class AddCoordinateChannels(Transform): """ - backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + backend = [TransformBackends.NUMPY] @deprecated_arg( name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead." @@ -1415,108 +1459,139 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return concatenate((img, coord_channels), axis=0) -class KernelTransform(Transform): +class ImageFilter(Transform): """ - Applies a kernel transformation to the input image. + Applies a convolution filter to the input image. 
Args: - kernel: - A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. - Available options are: `mean`, `laplacian`, `elliptical`, `gaussian`` - See below for short explanations on every kernel. - kernel_size: - A single integer value specifying the size of the quadratic or cubic kernel. - Computational complexity scales to the power of 2 (2D kernel) or 3 (3D kernel), which - should be considered when choosing kernel size. + filter: + A string specifying the filter, a custom filter as `torch.Tenor` or `np.ndarray` or a `nn.Module`. + Available options for string are: `mean`, `laplace`, `elliptical`, `sobel`, `sharpen`, `median`, `gauss` + See below for short explanations on every filter. + filter_size: + A single integer value specifying the size of the quadratic or cubic filter. + Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which + should be considered when choosing filter size. + kwargs: + Additional arguments passed to filter function, required by `sobel` and `gauss`. + See below for details. Raises: - AssertionError: When `kernel` is a string and `kernel_size` is not specified - AssertionError: When `kernel_size` is not an uneven integer - AssertionError: When `kernel` is an array and `ndim` is not in [1,2,3] - AssertionError: When `kernel` is an array and any dimension has an even shape - NotImplementedError: When `kernel` is a string and not in `self.supported_kernels` + ValueError: When `filter_size` is not an uneven integer + ValueError: When `filter` is an array and `ndim` is not in [1,2,3] + ValueError: When `filter` is an array and any dimension has an even shape + NotImplementedError: When `filter` is a string and not in `self.supported_filters` + KeyError: When necessary `kwargs` are not passed to a filter that requires additional arguments. - ## Mean kernel - > `kernel='mean'` + **Mean Filtering:** ``filter='mean'`` Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image. - Example 2D kernel (5 x 5): + See also py:func:`monai.networks.layers.simplelayers.MeanFilter` + Example 2D filter (5 x 5):: - [1, 1, 1, 1, 1] - [1, 1, 1, 1, 1] - [1, 1, 1, 1, 1] - [1, 1, 1, 1, 1] - [1, 1, 1, 1, 1] + [[1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1]] - If smoothing labels with this kernel, ensure they are in one-hot format. + If smoothing labels with this filter, ensure they are in one-hot format. - ## Laplacian kernel - > `kernel='laplacian'` + **Outline Detection:** ``filter='laplace'`` Laplacian filtering for outline detection in images. Can be used to transform labels to contours. - Example 2D kernel (5x5): + See also py:func:`monai.networks.layers.simplelayers.LaplaceFilter` + + Example 2D filter (5x5):: + + [[-1., -1., -1., -1., -1.], + [-1., -1., -1., -1., -1.], + [-1., -1., 24., -1., -1.], + [-1., -1., -1., -1., -1.], + [-1., -1., -1., -1., -1.]] + + **Dilation:** ``filter='elliptical'`` + + An elliptical filter can be used to dilate labels or label-contours. + Example 2D filter (5x5):: + + [[0., 0., 1., 0., 0.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [0., 0., 1., 0., 0.]] - [-1., -1., -1., -1., -1.] - [-1., -1., -1., -1., -1.] - [-1., -1., 24., -1., -1.] - [-1., -1., -1., -1., -1.] - [-1., -1., -1., -1., -1.] + **Edge Detection:** ``filter='sobel'`` - ## Elliptical kernel - > `kernel='elliptical'` + This filter allows for additional arguments passed as ``kwargs`` during initialization. 
+ See also py:func:`monai.transforms.post.SobelGradients` + ::kwargs:: + spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient + along each of the provide axis. By default it calculate the gradient for all spatial axes. + normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True. + normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False. + padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`. + Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + See ``torch.nn.Conv1d()`` for more information. + dtype: kernel data type (torch.dtype). Defaults to `torch.float32`. - An elliptical kernel can be used to dilate labels or label-contours. - Example 2D kernel (5x5): + **Sharpening:** ``filter='sharpen'`` - [0., 0., 1., 0., 0.] - [1., 1., 1., 1., 1.] - [1., 1., 1., 1., 1.] - [1., 1., 1., 1., 1.] - [0., 0., 1., 0., 0.] + Sharpen an image with a 2D or 3D filter. + Example 2D filter (5x5):: - ## Sobel kernel - > `kernel={'sobel_h', 'sobel_w', ''sobel_d}` + [[ 0., 0., -1., 0., 0.], + [-1., -1., -1., -1., -1.], + [-1., -1., 17., -1., -1.], + [-1., -1., -1., -1., -1.], + [ 0., 0., -1., 0., 0.]] - Edge detection with sobel kernel, along the h,w or d axis of tensor. - Example 2D kernel (5x5) for `sobel_w`: + **Gaussian Smooth:** ``filter='gauss'`` - [-0.25, -0.20, 0.00, 0.20, 0.25] - [-0.40, -0.50, 0.00, 0.50, 0.40] - [-0.50, -1.00, 0.00, 1.00, 0.50] - [-0.40, -0.50, 0.00, 0.50, 0.40] - [-0.25, -0.20, 0.00, 0.20, 0.25] + Blur/smooth an image with 2D or 3D gaussian filter. + This filter requires additional arguments passed as ``kwargs`` during initialization. + See also py:func:`monai.networks.layers.simplelayers.GaussianFilter` + ::kwargs:: + sigma: std. could be a single value, or spatial_dims number of values. + truncated: spreads how many stds. + approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". + - `erf` approximation interpolates the error function; + - `sampled` uses a sampled Gaussian kernel; + - `scalespace` corresponds to - ## Sharpen kernel - > `kernel="sharpen"` + **Median Filter:** ``filter='median'`` - Sharpen an image with a 2D or 3D kernel. - Example 2D kernel (5x5): + Blur an image with 2D or 3D median filter to remove noise. + Useful in image preprocessing to improve results of later processing. + See also py:func:`monai.networks.layers.simplelayers.MedianFilter` + + + **Savitzky Golay Filter:** ``filter = 'savitzky_golay'`` + + Convolve a Tensor along a particular axis with a Savitzky-Golay kernel. + This filter requires additional arguments passed as ``kwargs`` during initialization. + See also py:func:`monai.networks.layers.simplelayers.SavitzkyGolayFilter` + ::kwargs:: + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension). + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or + ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. - [ 0., 0., -1., 0., 0.] - [-1., -1., -1., -1., -1.] - [-1., -1., 17., -1., -1.] - [-1., -1., -1., -1., -1.] - [ 0., 0., -1., 0., 0.] 
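
    A minimal usage sketch (illustrative only; the shapes, filter names, and keyword values below
    are arbitrary examples)::

        import torch
        img = torch.rand(1, 32, 32)                                    # channel-first 2D image
        smoothed = ImageFilter("mean", filter_size=3)(img)             # preset filter selected by name
        custom = ImageFilter(torch.ones(3, 3))(img)                    # custom kernel, uneven size required
        blurred = ImageFilter("gauss", filter_size=3, sigma=1.0)(img)  # 'gauss' needs the extra `sigma` kwarg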
""" backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - supported_kernels = sorted(["mean", "laplacian", "elliptical", "sobel_w", "sobel_h", "sobel_d", "sharpen"]) - - def __init__(self, kernel: Union[str, NdarrayOrTensor], kernel_size: Optional[int] = None) -> None: + supported_filters = sorted(["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"]) - if isinstance(kernel, str): - assert kernel_size, "`kernel_size` must be specified when specifying kernels by string." - assert kernel_size % 2 == 1, "`kernel_size` should be a single uneven integer." - if kernel not in self.supported_kernels: - raise NotImplementedError(f"{kernel}. Supported kernels are {self.supported_kernels}.") - else: - assert kernel.ndim in [1, 2, 3], "Only 1D, 2D, and 3D kernels are supported" - kernel = convert_to_tensor(kernel, dtype=torch.float32) - self._assert_all_values_uneven(kernel.shape) + def __init__( + self, filter: Union[str, NdarrayOrTensor, nn.Module], filter_size: Optional[int] = None, **kwargs + ) -> None: - self.kernel = kernel - self.kernel_size = kernel_size + self._check_filter_format(filter, filter_size) + self._check_kwargs_are_present(filter, **kwargs) + self.filter = filter + self.filter_size = filter_size + self.additional_args_for_filter = kwargs def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ @@ -1531,119 +1606,112 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> meta_dict = img.meta img_, prev_type, device = convert_data_type(img, torch.Tensor) ndim = img_.ndim - 1 # assumes channel first format - if isinstance(self.kernel, str): - self.kernel = self._create_kernel_from_string(self.kernel, self.kernel_size, ndim) - img_ = img_.unsqueeze(0) - img_ = apply_filter(img_, self.kernel) # batch, channels, H[, W, D] is required for img_ - img_ = img_[0] + if not callable(self.filter): + self.filter = self._create_filter(self.filter, self.filter_size, ndim) + img_ = self._apply_filter(img_) if meta_dict: img_ = MetaTensor(img_, meta_dict) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ - def _assert_all_values_uneven(self, x: tuple) -> None: + def _check_all_values_uneven(self, x: tuple) -> None: for value in x: - assert value % 2 == 1, f"Only uneven kernels are supported, but kernel size is {x}" - - def _create_kernel_from_string(self, name, size, ndim) -> torch.Tensor: - """Create an `ndim` kernel of size `(size, ) * ndim`.""" - func = getattr(self, f"_create_{name}_kernel") - kernel = func(size, ndim) - return kernel.to(torch.float32) - - def _create_mean_kernel(self, size, ndim) -> torch.Tensor: - """Create a torch.Tensor with shape (size, ) * ndim with all values equal to `1`""" - return torch.ones([size] * ndim) - - def _create_laplacian_kernel(self, size, ndim) -> torch.Tensor: - """Create a torch.Tensor with shape (size, ) * ndim. 
- All values are `-1` except the center value which is size**ndim - 1 - """ - kernel = torch.zeros([size] * ndim).float() - 1 # make all -1 - center_point = tuple([size // 2] * ndim) - kernel[center_point] = (size**ndim) - 1 - return kernel - - def _create_elliptical_kernel(self, size: int, ndim: int) -> torch.Tensor: - """Create a torch.Tensor with shape (size, ) * ndim containing a circle/sphere of `1`""" - radius = size // 2 - grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(ndim)]) - squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0) - kernel = squared_distances <= radius**2 - return kernel - - def _sobel_2d(self, size) -> torch.Tensor: - """Create a generic 2d sobel kernel""" - numerator = torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32).unsqueeze(0) - denominator = numerator * numerator - denominator = denominator + denominator.T - denominator[:, size // 2] = 1.0 # to avoid division by zero - return numerator / denominator - - def _sobel_3d(self, size) -> torch.Tensor: - """Create a generic 3d sobel kernel""" - kernel_2d = self._sobel_2d(size) - kernel_3d = torch.stack((kernel_2d,) * size, -1) - adapter = (size // 2) - torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32).abs() - adapter = adapter / adapter.max() + 1 # scale between 1 - 2 - return kernel_3d * adapter - - def _create_sobel_w_kernel(self, size, ndim) -> torch.Tensor: - """Sobel kernel in x/w direction for Tensor in shape (B,C)[WH[D]]""" - if ndim == 2: - kernel = self._sobel_2d(size) - elif ndim == 3: - kernel = self._sobel_3d(size) + if value % 2 == 0: + raise ValueError(f"Only uneven filters are supported, but filter size is {x}") + + def _check_filter_format( + self, filter: Union[str, NdarrayOrTensor, nn.Module], filter_size: Optional[int] = None + ) -> None: + if isinstance(filter, str): + if not filter_size: + raise ValueError("`filter_size` must be specified when specifying filters by string.") + if filter_size % 2 == 0: + raise ValueError("`filter_size` should be a single uneven integer.") + if filter not in self.supported_filters: + raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.") + elif isinstance(filter, torch.Tensor) or isinstance(filter, np.ndarray): + if filter.ndim not in [1, 2, 3]: + raise ValueError("Only 1D, 2D, and 3D filters are supported.") + self._check_all_values_uneven(filter.shape) # type: ignore + elif isinstance(filter, (nn.Module, Transform)): + pass else: - raise ValueError(f"Only 2 or 3 dimensional kernels are supported. Got {ndim}") - return kernel - - def _create_sobel_h_kernel(self, size, ndim) -> torch.Tensor: - """Sobel kernel in y/h direction for Tensor in shape (B,C)[WH[D]]""" - kernel = self._create_sobel_w_kernel(size, ndim).transpose(0, 1) - return kernel - - def _create_sobel_d_kernel(self, size, ndim) -> torch.Tensor: - """Sobel kernel in z/d direction for Tensor in shape (B,C)[WHD]]""" - assert ndim == 3, "Only 3 dimensional kernels are supported for `sobel_d`" - return self._sobel_3d(size).transpose(1, 2) - - def _create_sharpen_kernel(self, size, ndim) -> torch.Tensor: - """Create a torch.Tensor with shape (size, ) * ndim. 
- The kernel contains a circle/sphere of `-1`, with the center value beeing - the absolut sum of all non-zero elements in the kernel - """ - kernel = self._create_elliptical_kernel(size, ndim) - center_point = tuple([size // 2] * ndim) - center_value = kernel.sum() - kernel = kernel * -1 - kernel[center_point] = center_value - return kernel - - -class RandKernelTransform(RandomizableTransform): - """Randomly apply a Filterkernel to the input data. + raise TypeError( + f"{type(filter)} is not supported." + "Supported types are `class 'str'`, `class 'torch.Tensor'`, `class 'np.ndarray'`, `class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`" + ) + + def _check_kwargs_are_present(self, filter, **kwargs): + if filter == "gauss" and "sigma" not in kwargs.keys(): + raise KeyError("`filter='gauss', requires the additonal keyword argument `sigma`") + if filter == "savitzky_golay" and "order" not in kwargs.keys(): + raise KeyError("`filter='savitzky_golay', requires the additonal keyword argument `order`") + + def _create_filter(self, filter: Union[str, NdarrayOrTensor], size: int, ndim: int) -> nn.Module: + """Create an `ndim` filter of size `(size, ) * ndim`.""" + if isinstance(filter, str): + filter = self._get_filter_from_string(filter, size, ndim) + else: + filter = ApplyFilter(filter) + return filter + + def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Module: + if filter == "mean": # this would be a great future use for match/case + return MeanFilter(ndim, size) + elif filter == "laplace": + return LaplaceFilter(ndim, size) + elif filter == "elliptical": + return EllipticalFilter(ndim, size) + elif filter == "sobel": + from monai.transforms.post.array import SobelGradients # cannot import on top because of circular imports + allowed_keys = SobelGradients.__init__.__annotations__.keys() + kwargs = {k: v for k,v in self.additional_args_for_filter.items() if k in allowed_keys} + return SobelGradients(size, **kwargs) + elif filter == "sharpen": + return SharpenFilter(ndim, size) + elif filter == "gauss": + allowed_keys = GaussianFilter.__init__.__annotations__.keys() + kwargs = {k: v for k,v in self.additional_args_for_filter.items() if k in allowed_keys} + return GaussianFilter(ndim, **kwargs) + elif filter == "median": + return partial(median_filter, kernel_size = size, spatial_dims = ndim) + elif filter == "savitzky_golay": + allowed_keys = SavitzkyGolayFilter.__init__.__annotations__.keys() + kwargs = {k: v for k,v in self.additional_args_for_filter.items() if k in allowed_keys} + return SavitzkyGolayFilter(size, **kwargs) + + def _apply_filter(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + if isinstance(self.filter, Transform): + return self.filter(img) + else: + return self.filter(img.unsqueeze(0))[0] # add and remove batch dim + + +class RandImageFilter(RandomizableTransform): + """ + Randomly apply a convolutional filter to the input data. + Args: - kernel: - A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`. - Available options are: `mean`, `laplacian`, `elliptical`, `gaussian`` - See below for short explanations on every kernel. - kernel_size: - A single integer value specifying the size of the quadratic or cubic kernel. - Computational complexity scales to the power of 2 (2D kernel) or 3 (3D kernel), which - should be considered when choosing kernel size. + filter: + A string specifying the filter or a custom filter as `torch.Tenor` or `np.ndarray`. 
+ Available options are: `mean`, `laplace`, `elliptical`, `gaussian`` + See below for short explanations on every filter. + filter_size: + A single integer value specifying the size of the quadratic or cubic filter. + Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which + should be considered when choosing filter size. prob: Probability the transform is applied to the data """ - backend = KernelTransform.backend + backend = ImageFilter.backend def __init__( - self, kernel: Union[str, NdarrayOrTensor], kernel_size: Optional[int] = None, prob: float = 0.1 + self, filter: Union[str, NdarrayOrTensor], filter_size: Optional[int] = None, prob: float = 0.1 ) -> None: super().__init__(prob) - self.filter = KernelTransform(kernel, kernel_size) + self.filter = ImageFilter(filter, filter_size) def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 92c1aebbf4..609d6c6a66 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -32,6 +32,7 @@ AddChannel, AddCoordinateChannels, AddExtremePointsChannel, + ImageFilter, AsChannelFirst, AsChannelLast, CastToType, @@ -44,7 +45,6 @@ FgBgToIndices, Identity, IntensityStats, - KernelTransform, LabelToMask, Lambda, MapLabelValue, @@ -119,7 +119,7 @@ "IntensityStatsd", "IntensityStatsD", "IntensityStatsDict", - "KernelTransformd", + "ImageFilterd", "LabelToMaskD", "LabelToMaskDict", "LabelToMaskd", @@ -132,7 +132,7 @@ "RandCuCIMd", "RandCuCIMD", "RandCuCIMDict", - "RandKernelTransformd", + "RandImageFilterd", "RandLambdaD", "RandLambdaDict", "RandLambdad", @@ -304,6 +304,7 @@ def __init__( meta_key_postfix: str = DEFAULT_POST_FIX, strict_check: bool = True, allow_missing_keys: bool = False, + channel_dim=None, ) -> None: """ Args: @@ -311,9 +312,13 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` strict_check: whether to raise an error when the meta information is insufficient. allow_missing_keys: don't raise exception if key is missing. + channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array. + It overrides the `original_channel_dim` from provided MetaTensor input. + If the input array doesn't have a channel dim, this value should be ``'no_channel'``. + If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension. """ super().__init__(keys, allow_missing_keys) - self.adjuster = EnsureChannelFirst(strict_check=strict_check) + self.adjuster = EnsureChannelFirst(strict_check=strict_check, channel_dim=channel_dim) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) @@ -375,6 +380,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N class SplitDimd(MapTransform): + + backend = SplitDim.backend + def __init__( self, keys: KeysCollection, @@ -382,6 +390,7 @@ def __init__( dim: int = 0, keepdim: bool = True, update_meta: bool = True, + list_output: bool = False, allow_missing_keys: bool = False, ) -> None: """ @@ -397,15 +406,34 @@ def __init__( dimension will be squeezed. update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to reflect the cropped image + list_output: it `True`, the output will be a list of dictionaries with the same keys as original. 
allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) self.output_postfixes = output_postfixes self.splitter = SplitDim(dim, keepdim, update_meta) + self.list_output = list_output + if self.list_output is None and self.output_postfixes is not None: + raise ValueError("`output_postfixes` should not be provided when `list_output` is `True`.") - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__( + self, data: Mapping[Hashable, torch.Tensor] + ) -> Union[Dict[Hashable, torch.Tensor], List[Dict[Hashable, torch.Tensor]]]: d = dict(data) - for key in self.key_iterator(d): + all_keys = list(set(self.key_iterator(d))) + + if self.list_output: + output = [] + results = [self.splitter(d[key]) for key in all_keys] + for row in zip(*results): + new_dict = dict(zip(all_keys, row)) + # fill in the extra keys with unmodified data + for k in set(d.keys()).difference(set(all_keys)): + new_dict[k] = deepcopy(d[k]) + output.append(new_dict) + return output + + for key in all_keys: rets = self.splitter(d[key]) postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes if len(postfixes) != len(rets): @@ -489,7 +517,7 @@ def __init__( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = True, - track_meta: Optional[bool] = False, + track_meta: Optional[bool] = None, allow_missing_keys: bool = False, ) -> None: """ @@ -500,8 +528,8 @@ def __init__( device: specify the target device to put the Tensor data. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. - track_meta: whether to convert to `MetaTensor`, default to `False`, output type will be `torch.Tensor`. - if `None`, use the return value of ``get_track_meta``. + track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``, + if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`. allow_missing_keys: don't raise exception if key is missing. """ @@ -511,19 +539,19 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): - self.push_transform(d, key) d[key] = self.converter(d[key]) + self.push_transform(d, key) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): + # Remove the applied transform + self.pop_transform(d, key) # Create inverse transform inverse_transform = ToNumpy() # Apply inverse d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) return d @@ -721,7 +749,7 @@ def __init__(self, keys: KeysCollection, sep: str = ".", use_re: Union[Sequence[ See also: :py:class:`monai.transforms.compose.MapTransform` sep: the separator tag to define nested dictionary keys, default to ".". use_re: whether the specified key is a regular expression, it also can be - a list of bool values, map the to keys. + a list of bool values, mapping them to `keys`. 
""" super().__init__(keys) self.sep = sep @@ -733,7 +761,7 @@ def _delete_item(keys, d, use_re: bool = False): if len(keys) > 1: d[key] = _delete_item(keys[1:], d[key], use_re) return d - return {k: v for k, v in d.items() if (use_re and not re.search(key, k)) or (not use_re and k != key)} + return {k: v for k, v in d.items() if (use_re and not re.search(key, f"{k}")) or (not use_re and k != key)} d = dict(data) for key, use_re in zip(self.keys, self.use_re): @@ -761,16 +789,19 @@ class SqueezeDimd(MapTransform): backend = SqueezeDim.backend - def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, dim: int = 0, update_meta: bool = True, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` dim: dimension to be squeezed. Default: 0 (the first dimension) + update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = SqueezeDim(dim=dim) + self.converter = SqueezeDim(dim=dim, update_meta=update_meta) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -1660,9 +1691,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d -class KernelTransformd(MapTransform): +class ImageFilterd(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.KernelTransform`. + Dictionary-based wrapper of :py:class:`monai.transforms.ImageFilter`. Args: keys: keys of the corresponding items to be transformed. @@ -1678,7 +1709,7 @@ class KernelTransformd(MapTransform): Don't raise exception if key is missing. """ - backend = KernelTransform.backend + backend = ImageFilter.backend def __init__( self, @@ -1688,7 +1719,7 @@ def __init__( allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.filter = KernelTransform(kernel, kernel_size) + self.filter = ImageFilter(kernel, kernel_size) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -1697,7 +1728,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d -class RandKernelTransformd(MapTransform, RandomizableTransform): +class RandImageFilterd(MapTransform, RandomizableTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandomFilterKernel`. @@ -1717,7 +1748,7 @@ class RandKernelTransformd(MapTransform, RandomizableTransform): Don't raise exception if key is missing. 
""" - backend = KernelTransform.backend + backend = ImageFilter.backend def __init__( self, @@ -1729,7 +1760,7 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.filter = KernelTransform(kernel, kernel_size) + self.filter = ImageFilter(kernel, kernel_size) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -1740,8 +1771,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d -RandKernelTransformD = RandKernelTransformDict = RandKernelTransformd -KernelTransformD = KernelTransformDict = KernelTransformd +RandImageFilterD = RandImageFilterDict = RandImageFilterd +ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py index 3174211f34..62372516a5 100644 --- a/tests/test_apply_filter.py +++ b/tests/test_apply_filter.py @@ -64,7 +64,6 @@ def test_3d(self): ], ] ) - expected = expected # testing shapes k = torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) for kernel in (k, k[None], k[None][None]): diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py new file mode 100644 index 0000000000..89b4131a88 --- /dev/null +++ b/tests/test_image_filter.py @@ -0,0 +1,241 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +import torch +import numpy as np +from parameterized import parameterized +from monai.networks.layers.simplelayers import GaussianFilter + +from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd + +EXPECTED_KERNELS = { + "mean": torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).float(), + "laplacian": torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float(), + "elliptical": torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]]).float(), + "sharpen": torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]).float(), +} + +SUPPORTED_KERNELS = ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] +SAMPLE_IMAGE_2D = torch.randn(1, 10, 10) +SAMPLE_IMAGE_3D = torch.randn(1, 10, 10, 10) +SAMPLE_DICT = {"image_2d": SAMPLE_IMAGE_2D, "image_3d": SAMPLE_IMAGE_3D} + +ADDITIONAL_ARGUMENTS = { + "order": 1, + "sigma": 1 + } + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, x): + return x + 1 + +class TestNotAModuleOrTransform: + pass + +class TestImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_KERNELS) + def test_init_from_string(self, kernel_name): + "Test init from string" + _ = ImageFilter(kernel_name, 3, **ADDITIONAL_ARGUMENTS) + + def test_init_raises(self): + with self.assertRaises(Exception) as context: + _ = ImageFilter("mean") + self.assertTrue( + "`filter_size` must be specified when specifying filters by string." in str(context.output) + ) + with self.assertRaises(Exception) as context: + _ = ImageFilter("mean") + self.assertTrue( + "`filter_size` should be a single uneven integer." in str(context.output) + ) + with self.assertRaises(Exception) as context: + _ = ImageFilter("gauss", 3) + self.assertTrue( + "`filter='gauss', requires the additonal keyword argument `sigma`" in str(context.output) + ) + with self.assertRaises(Exception) as context: + _ = ImageFilter("savitzky_golay", 3) + self.assertTrue( + "`filter='savitzky_golay', requires the additonal keyword argument `order`" in str(context.output) + ) + + def test_init_from_array(self): + "Test init with custom kernel and assert wrong kernel shape throws an error" + _ = ImageFilter(torch.ones(3, 3)) + _ = ImageFilter(torch.ones(3, 3, 3)) + _ = ImageFilter(np.ones((3, 3))) + _ = ImageFilter(np.ones((3, 3, 3))) + + with self.assertRaises(Exception) as context: + _ = ImageFilter(torch.ones(3, 3, 3, 3)) + self.assertTrue( + "Only 1D, 2D, and 3D filters are supported." in str(context.output) + ) + + def test_init_from_module(self): + filter = ImageFilter(TestModule()) + out = filter(torch.zeros(1,3,3,3)) + torch.testing.assert_allclose(torch.ones(1,3,3,3), out) + + def test_init_from_transform(self): + _ = ImageFilter(GaussianFilter(3, sigma = 2)) + + def test_init_from_wrong_type_fails(self): + with self.assertRaises(Exception) as context: + _ = ImageFilter(TestNotAModuleOrTransform()) + self.assertTrue( + " is not supported." 
in str(context.output) + ) + + @parameterized.expand(EXPECTED_KERNELS.keys()) + def test_2d_kernel_correctness(self, kernel_name): + "Test correctness of kernels (2d only)" + tfm = ImageFilter(kernel_name, kernel_size=3) + kernel = tfm._create_kernel_from_string(kernel_name, size=3, ndim=2).squeeze() + torch.testing.assert_allclose(kernel, EXPECTED_KERNELS[kernel_name]) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_2d(self, kernel_name): + "Text function `__call__` for 2d images" + filter = ImageFilter(kernel_name, 3) + if kernel_name != "sobel_d": # sobel_d does not support 2d + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_2D.shape) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_3d(self, kernel_name): + "Text function `__call__` for 3d images" + filter = ImageFilter(kernel_name, 3) + out_tensor = filter(SAMPLE_IMAGE_3D) + self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_3D.shape) + + +class TestImageFilterDict(unittest.TestCase): + @parameterized.expand(SUPPORTED_KERNELS) + def test_init_from_string_dict(self, kernel_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = ImageFilterd("image", kernel_name, 3) + with self.assertRaises(Exception) as context: # noqa F841 + _ = ImageFilterd(self.image_key, kernel_name) + + def test_init_from_array_dict(self): + "Test init with custom kernel and assert wrong kernel shape throws an error" + _ = ImageFilterd("image", torch.ones(3, 3)) + with self.assertRaises(Exception) as context: # noqa F841 + _ = ImageFilterd(self.image_key, torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_2d(self, kernel_name): + "Text function `__call__` for 2d images" + filter = ImageFilterd("image_2d", kernel_name, 3) + if kernel_name != "sobel_d": # sobel_d does not support 2d + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_2d"].shape, SAMPLE_IMAGE_2D.shape) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_3d(self, kernel_name): + "Text function `__call__` for 3d images" + filter = ImageFilterd("image_3d", kernel_name, 3) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) + + +class TestRandImageFilter(unittest.TestCase): + @parameterized.expand(SUPPORTED_KERNELS) + def test_init_from_string(self, kernel_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = RandImageFilter(kernel_name, 3) + with self.assertRaises(Exception) as context: # noqa F841 + _ = RandImageFilter(kernel_name) + + def test_init_from_array(self): + "Test init with custom kernel and assert wrong kernel shape throws an error" + _ = RandImageFilter(torch.ones(3, 3)) + with self.assertRaises(Exception) as context: # noqa F841 + _ = RandImageFilter(torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_2d_prob_1(self, kernel_name): + "Text function `__call__` for 2d images" + filter = RandImageFilter(kernel_name, 3, 1) + if kernel_name != "sobel_d": # sobel_d does not support 2d + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_2D.shape) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_3d_prob_1(self, kernel_name): + "Text function `__call__` for 3d images" + filter = RandImageFilter(kernel_name, 3, 1) + out_tensor = filter(SAMPLE_IMAGE_3D) + self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_3D.shape) + + @parameterized.expand(SUPPORTED_KERNELS) + def 
test_call_2d_prob_0(self, kernel_name): + "Text function `__call__` for 2d images" + filter = RandImageFilter(kernel_name, 3, 0) + if kernel_name != "sobel_d": # sobel_d does not support 2d + out_tensor = filter(SAMPLE_IMAGE_2D) + torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_2D) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_3d_prob_0(self, kernel_name): + "Text function `__call__` for 3d images" + filter = RandImageFilter(kernel_name, 3, 0) + out_tensor = filter(SAMPLE_IMAGE_3D) + torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_3D) + + +class TestRandImageFilterDict(unittest.TestCase): + @parameterized.expand(SUPPORTED_KERNELS) + def test_init_from_string_dict(self, kernel_name): + "Test init from string and assert an error is thrown if no size is passed" + _ = RandImageFilterd("image", kernel_name, 3) + with self.assertRaises(Exception) as context: # noqa F841 + _ = RandImageFilterd("image", kernel_name) + + def test_init_from_array_dict(self): + "Test init with custom kernel and assert wrong kernel shape throws an error" + _ = RandImageFilterd("image", torch.ones(3, 3)) + with self.assertRaises(Exception) as context: # noqa F841 + _ = RandImageFilterd("image", torch.ones(3, 3, 3, 3)) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_2d_prob_1(self, kernel_name): + filter = RandImageFilterd("image_2d", kernel_name, 3, 1.0) + if kernel_name != "sobel_d": # sobel_d does not support 2d + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_2d"].shape, SAMPLE_IMAGE_2D.shape) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_3d_prob_1(self, kernel_name): + filter = RandImageFilterd("image_3d", kernel_name, 3, 1.0) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_2d_prob_0(self, kernel_name): + filter = RandImageFilterd("image_2d", kernel_name, 3, 0.0) + if kernel_name != "sobel_d": # sobel_d does not support 2d + out_tensor = filter(SAMPLE_DICT) + torch.testing.assert_allclose(out_tensor["image_2d"].shape, SAMPLE_IMAGE_2D.shape) + + @parameterized.expand(SUPPORTED_KERNELS) + def test_call_3d_prob_0(self, kernel_name): + filter = RandImageFilterd("image_3d", kernel_name, 3, 0.0) + out_tensor = filter(SAMPLE_DICT) + torch.testing.assert_allclose(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) \ No newline at end of file diff --git a/tests/test_mean_filter.py b/tests/test_mean_filter.py new file mode 100644 index 0000000000..2cf934397a --- /dev/null +++ b/tests/test_mean_filter.py @@ -0,0 +1,40 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import MeanFilter + +TEST_CASES = [ + {"spatial_dims": 3, "size": 3, "expected": torch.ones(3, 3, 3)}, + {"spatial_dims": 2, "size": 5, "expected": torch.ones(5, 5)}, +] + + +class MedianFilterTestCase(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_init(self, spatial_dims, size, expected): + mean_filter = MeanFilter(spatial_dims=spatial_dims, size=size) + self.assertEqual(expected, mean_filter.filter) + self.assertIsInstance(mean_filter, torch.nn.Module) + + def test_forward(self): + mean_filter = MeanFilter(spatial_dims=2, size=3) + input = torch.ones(1, 1, 5, 5) + output = mean_filter(input) + self.assertEqual(input, output) + + +if __name__ == "__main__": + unittest.main() From 32aa2618cddbb16c203c23b47c6cc30374e8245d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Nov 2022 19:57:57 +0000 Subject: [PATCH 03/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/utility/array.py | 4 ++-- tests/test_image_filter.py | 30 +++++++++++++++--------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3a297f8ca5..11e30dd804 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1682,9 +1682,9 @@ def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Modul return SavitzkyGolayFilter(size, **kwargs) def _apply_filter(self, img: NdarrayOrTensor) -> NdarrayOrTensor: - if isinstance(self.filter, Transform): + if isinstance(self.filter, Transform): return self.filter(img) - else: + else: return self.filter(img.unsqueeze(0))[0] # add and remove batch dim diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 89b4131a88..4c03b26c28 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -32,18 +32,18 @@ SAMPLE_DICT = {"image_2d": SAMPLE_IMAGE_2D, "image_3d": SAMPLE_IMAGE_3D} ADDITIONAL_ARGUMENTS = { - "order": 1, + "order": 1, "sigma": 1 } class TestModule(torch.nn.Module): - def __init__(self): - super(TestModule, self).__init__() - - def forward(self, x): + def __init__(self): + super().__init__() + + def forward(self, x): return x + 1 -class TestNotAModuleOrTransform: +class TestNotAModuleOrTransform: pass class TestImageFilter(unittest.TestCase): @@ -53,22 +53,22 @@ def test_init_from_string(self, kernel_name): _ = ImageFilter(kernel_name, 3, **ADDITIONAL_ARGUMENTS) def test_init_raises(self): - with self.assertRaises(Exception) as context: + with self.assertRaises(Exception) as context: _ = ImageFilter("mean") self.assertTrue( "`filter_size` must be specified when specifying filters by string." in str(context.output) ) - with self.assertRaises(Exception) as context: + with self.assertRaises(Exception) as context: _ = ImageFilter("mean") self.assertTrue( "`filter_size` should be a single uneven integer." 
in str(context.output) ) - with self.assertRaises(Exception) as context: + with self.assertRaises(Exception) as context: _ = ImageFilter("gauss", 3) self.assertTrue( "`filter='gauss', requires the additonal keyword argument `sigma`" in str(context.output) ) - with self.assertRaises(Exception) as context: + with self.assertRaises(Exception) as context: _ = ImageFilter("savitzky_golay", 3) self.assertTrue( "`filter='savitzky_golay', requires the additonal keyword argument `order`" in str(context.output) @@ -80,13 +80,13 @@ def test_init_from_array(self): _ = ImageFilter(torch.ones(3, 3, 3)) _ = ImageFilter(np.ones((3, 3))) _ = ImageFilter(np.ones((3, 3, 3))) - + with self.assertRaises(Exception) as context: _ = ImageFilter(torch.ones(3, 3, 3, 3)) self.assertTrue( "Only 1D, 2D, and 3D filters are supported." in str(context.output) ) - + def test_init_from_module(self): filter = ImageFilter(TestModule()) out = filter(torch.zeros(1,3,3,3)) @@ -95,8 +95,8 @@ def test_init_from_module(self): def test_init_from_transform(self): _ = ImageFilter(GaussianFilter(3, sigma = 2)) - def test_init_from_wrong_type_fails(self): - with self.assertRaises(Exception) as context: + def test_init_from_wrong_type_fails(self): + with self.assertRaises(Exception) as context: _ = ImageFilter(TestNotAModuleOrTransform()) self.assertTrue( " is not supported." in str(context.output) @@ -238,4 +238,4 @@ def test_call_2d_prob_0(self, kernel_name): def test_call_3d_prob_0(self, kernel_name): filter = RandImageFilterd("image_3d", kernel_name, 3, 0.0) out_tensor = filter(SAMPLE_DICT) - torch.testing.assert_allclose(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) \ No newline at end of file + torch.testing.assert_allclose(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) From fc42b96db125c287ad0a9e4c3fd0efd303480226 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 30 Dec 2022 17:40:56 +0100 Subject: [PATCH 04/18] add missing unit tests Signed-off-by: kbressem --- monai/networks/layers/__init__.py | 5 + monai/networks/layers/simplelayers.py | 29 ++---- monai/transforms/utility/array.py | 18 ++-- monai/transforms/utility/dictionary.py | 2 +- tests/test_image_filter.py | 38 +++---- tests/test_mean_filter.py | 40 ------- tests/test_preset_filters.py | 138 +++++++++++++++++++++++++ 7 files changed, 179 insertions(+), 91 deletions(-) delete mode 100644 tests/test_mean_filter.py create mode 100644 tests/test_preset_filters.py diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 31bd36dd8f..5d91cf66f6 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -16,13 +16,18 @@ from .gmm import GaussianMixtureModel from .simplelayers import ( LLTM, + ApplyFilter, ChannelPad, + EllipticalFilter, Flatten, GaussianFilter, HilbertTransform, + LaplaceFilter, + MeanFilter, MedianFilter, Reshape, SavitzkyGolayFilter, + SharpenFilter, SkipConnection, apply_filter, median_filter, diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 00cfea34d3..840c977e83 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -662,7 +662,7 @@ def forward(self, input, state): class ApplyFilter(nn.Module): - "Apply a convolutional filter to an image" + "Wrapper class to apply a filter to an image." 
def __init__(self, filter: NdarrayOrTensor) -> None: super().__init__() @@ -670,11 +670,7 @@ def __init__(self, filter: NdarrayOrTensor) -> None: filter = convert_to_tensor(filter, dtype=torch.float32) self.filter = filter - def forward(self, x: torch.Tensor): - """ - Args: - x: in shape B, C, [H, W, [D]] - """ + def forward(self, x: torch.Tensor) -> torch.Tensor: return apply_filter(x, self.filter) @@ -691,6 +687,7 @@ def __init__(self, spatial_dims: int, size: int) -> None: size: edge length of the filter """ filter = torch.ones([size] * spatial_dims) + filter = filter super().__init__(filter=filter) @@ -716,7 +713,7 @@ def __init__(self, spatial_dims: int, size: int) -> None: class EllipticalFilter(ApplyFilter): """ Elliptical filter, can be used to dilate labels or label-contours. - The ellipical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1` + The elliptical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1` """ def __init__(self, spatial_dims: int, size: int) -> None: @@ -732,11 +729,11 @@ def __init__(self, spatial_dims: int, size: int) -> None: super().__init__(filter=filter) -class SharpenFilter(ApplyFilter): +class SharpenFilter(EllipticalFilter): """ Convolutional filter to sharpen a 2D or 3D image. - The filter used contains a circle/sphere of `-1`, with the center value beeing - the absolut sum of all non-zero elements in the kernel + The filter used contains a circle/sphere of `-1`, with the center value being + the absolute sum of all non-zero elements in the kernel """ def __init__(self, spatial_dims: int, size: int) -> None: @@ -745,12 +742,8 @@ def __init__(self, spatial_dims: int, size: int) -> None: spatial_dims: `int` of either 2 for 2D images and 3 for 3D images size: edge length of the filter """ - radius = size // 2 - grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(spatial_dims)]) - squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0) - filter = squared_distances <= radius**2 + super().__init__(spatial_dims=spatial_dims, size=size) center_point = tuple([size // 2] * spatial_dims) - center_value = filter.sum() - filter = filter * -1 - filter[center_point] = center_value - super().__init__(filter=filter) + center_value = self.filter.sum() + self.filter *= -1 + self.filter[center_point] = center_value diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 11e30dd804..bc20483e14 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -13,12 +13,12 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -from functools import partial import logging import sys import time import warnings from copy import deepcopy +from functools import partial from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -31,7 +31,6 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import no_collation from monai.networks.layers.simplelayers import ( - ApplyFilter, EllipticalFilter, GaussianFilter, LaplaceFilter, @@ -1581,7 +1580,9 @@ class ImageFilter(Transform): """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - supported_filters = sorted(["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"]) + supported_filters = sorted( + ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] + ) def __init__( self, filter: Union[str, 
NdarrayOrTensor, nn.Module], filter_size: Optional[int] = None, **kwargs @@ -1657,7 +1658,7 @@ def _create_filter(self, filter: Union[str, NdarrayOrTensor], size: int, ndim: i return filter def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Module: - if filter == "mean": # this would be a great future use for match/case + if filter == "mean": return MeanFilter(ndim, size) elif filter == "laplace": return LaplaceFilter(ndim, size) @@ -1665,20 +1666,21 @@ def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Modul return EllipticalFilter(ndim, size) elif filter == "sobel": from monai.transforms.post.array import SobelGradients # cannot import on top because of circular imports + allowed_keys = SobelGradients.__init__.__annotations__.keys() - kwargs = {k: v for k,v in self.additional_args_for_filter.items() if k in allowed_keys} + kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} return SobelGradients(size, **kwargs) elif filter == "sharpen": return SharpenFilter(ndim, size) elif filter == "gauss": allowed_keys = GaussianFilter.__init__.__annotations__.keys() - kwargs = {k: v for k,v in self.additional_args_for_filter.items() if k in allowed_keys} + kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} return GaussianFilter(ndim, **kwargs) elif filter == "median": - return partial(median_filter, kernel_size = size, spatial_dims = ndim) + return partial(median_filter, kernel_size=size, spatial_dims=ndim) elif filter == "savitzky_golay": allowed_keys = SavitzkyGolayFilter.__init__.__annotations__.keys() - kwargs = {k: v for k,v in self.additional_args_for_filter.items() if k in allowed_keys} + kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} return SavitzkyGolayFilter(size, **kwargs) def _apply_filter(self, img: NdarrayOrTensor) -> NdarrayOrTensor: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 609d6c6a66..1c24d5af7b 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -32,7 +32,6 @@ AddChannel, AddCoordinateChannels, AddExtremePointsChannel, - ImageFilter, AsChannelFirst, AsChannelLast, CastToType, @@ -44,6 +43,7 @@ EnsureType, FgBgToIndices, Identity, + ImageFilter, IntensityStats, LabelToMask, Lambda, diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 4c03b26c28..d8e9f35627 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -12,11 +12,11 @@ import unittest -import torch import numpy as np +import torch from parameterized import parameterized -from monai.networks.layers.simplelayers import GaussianFilter +from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd EXPECTED_KERNELS = { @@ -31,10 +31,8 @@ SAMPLE_IMAGE_3D = torch.randn(1, 10, 10, 10) SAMPLE_DICT = {"image_2d": SAMPLE_IMAGE_2D, "image_3d": SAMPLE_IMAGE_3D} -ADDITIONAL_ARGUMENTS = { - "order": 1, - "sigma": 1 - } +ADDITIONAL_ARGUMENTS = {"order": 1, "sigma": 1} + class TestModule(torch.nn.Module): def __init__(self): @@ -43,9 +41,11 @@ def __init__(self): def forward(self, x): return x + 1 + class TestNotAModuleOrTransform: pass + class TestImageFilter(unittest.TestCase): @parameterized.expand(SUPPORTED_KERNELS) def test_init_from_string(self, kernel_name): @@ -55,19 +55,13 @@ def test_init_from_string(self, kernel_name): def 
test_init_raises(self): with self.assertRaises(Exception) as context: _ = ImageFilter("mean") - self.assertTrue( - "`filter_size` must be specified when specifying filters by string." in str(context.output) - ) + self.assertTrue("`filter_size` must be specified when specifying filters by string." in str(context.output)) with self.assertRaises(Exception) as context: _ = ImageFilter("mean") - self.assertTrue( - "`filter_size` should be a single uneven integer." in str(context.output) - ) + self.assertTrue("`filter_size` should be a single uneven integer." in str(context.output)) with self.assertRaises(Exception) as context: _ = ImageFilter("gauss", 3) - self.assertTrue( - "`filter='gauss', requires the additonal keyword argument `sigma`" in str(context.output) - ) + self.assertTrue("`filter='gauss', requires the additonal keyword argument `sigma`" in str(context.output)) with self.assertRaises(Exception) as context: _ = ImageFilter("savitzky_golay", 3) self.assertTrue( @@ -83,24 +77,20 @@ def test_init_from_array(self): with self.assertRaises(Exception) as context: _ = ImageFilter(torch.ones(3, 3, 3, 3)) - self.assertTrue( - "Only 1D, 2D, and 3D filters are supported." in str(context.output) - ) + self.assertTrue("Only 1D, 2D, and 3D filters are supported." in str(context.output)) def test_init_from_module(self): filter = ImageFilter(TestModule()) - out = filter(torch.zeros(1,3,3,3)) - torch.testing.assert_allclose(torch.ones(1,3,3,3), out) + out = filter(torch.zeros(1, 3, 3, 3)) + torch.testing.assert_allclose(torch.ones(1, 3, 3, 3), out) def test_init_from_transform(self): - _ = ImageFilter(GaussianFilter(3, sigma = 2)) + _ = ImageFilter(GaussianFilter(3, sigma=2)) def test_init_from_wrong_type_fails(self): with self.assertRaises(Exception) as context: _ = ImageFilter(TestNotAModuleOrTransform()) - self.assertTrue( - " is not supported." in str(context.output) - ) + self.assertTrue(" is not supported." in str(context.output)) @parameterized.expand(EXPECTED_KERNELS.keys()) def test_2d_kernel_correctness(self, kernel_name): diff --git a/tests/test_mean_filter.py b/tests/test_mean_filter.py deleted file mode 100644 index 2cf934397a..0000000000 --- a/tests/test_mean_filter.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import torch -from parameterized import parameterized - -from monai.networks.layers import MeanFilter - -TEST_CASES = [ - {"spatial_dims": 3, "size": 3, "expected": torch.ones(3, 3, 3)}, - {"spatial_dims": 2, "size": 5, "expected": torch.ones(5, 5)}, -] - - -class MedianFilterTestCase(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_init(self, spatial_dims, size, expected): - mean_filter = MeanFilter(spatial_dims=spatial_dims, size=size) - self.assertEqual(expected, mean_filter.filter) - self.assertIsInstance(mean_filter, torch.nn.Module) - - def test_forward(self): - mean_filter = MeanFilter(spatial_dims=2, size=3) - input = torch.ones(1, 1, 5, 5) - output = mean_filter(input) - self.assertEqual(input, output) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_preset_filters.py b/tests/test_preset_filters.py new file mode 100644 index 0000000000..88eca1044f --- /dev/null +++ b/tests/test_preset_filters.py @@ -0,0 +1,138 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter + +TEST_CASES_MEAN = [ + (3, 3, torch.ones(3, 3, 3)), + (2, 5, torch.ones(5, 5)), +] + + +TEST_CASES_LAPLACE = [ + ( + 3, + 3, + torch.Tensor( + [ + [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]], + [[-1, -1, -1], [-1, 26, -1], [-1, -1, -1]], + [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]], + ] + ), + ), + ( + 2, + 3, + torch.Tensor( + [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], + ), + ), +] + +TEST_CASES_ELLIPTICAL = [ + ( + 3, + 3, + torch.Tensor( + [ + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], + [[0, 1, 0], [1, 1, 1], [0, 1, 0]], + [[0, 0, 0], [0, 1, 0], [0, 0, 0]], + ] + ), + ), + ( + 2, + 3, + torch.Tensor( + [[0, 1, 0], [1, 1, 1], [0, 1, 0]], + ), + ), +] + + +TEST_CASES_SHARPEN = [ + ( + 3, + 3, + torch.Tensor( + [ + [[0, 0, 0], [0, -1, 0], [0, 0, 0]], + [[0, -1, 0], [-1, 7, -1], [0, -1, 0]], + [[0, 0, 0], [0, -1, 0], [0, 0, 0]], + ] + ), + ), + ( + 2, + 3, + torch.Tensor( + [[0, -1, 0], [-1, 5, -1], [0, -1, 0]], + ), + ), +] + + +class _TestFilter: + def test_init(self, spatial_dims, size, expected): + test_filter = self.filter_class(spatial_dims=spatial_dims, size=size) + torch.testing.assert_allclose(expected, test_filter.filter) + self.assertIsInstance(test_filter, torch.nn.Module) + + def test_forward(self): + test_filter = self.filter_class(spatial_dims=2, size=3) + input = torch.ones(1, 1, 5, 5) + _ = test_filter(input) + + +class MeanFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = MeanFilter + + @parameterized.expand(TEST_CASES_MEAN) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class LaplaceFilterTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = LaplaceFilter + + @parameterized.expand(TEST_CASES_LAPLACE) + def 
test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class EllipticalTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = EllipticalFilter + + @parameterized.expand(TEST_CASES_ELLIPTICAL) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + + +class SharpenTestCase(_TestFilter, unittest.TestCase): + def setUp(self) -> None: + self.filter_class = SharpenFilter + + @parameterized.expand(TEST_CASES_SHARPEN) + def test_init(self, spatial_dims, size, expected): + super().test_init(spatial_dims, size, expected) + +if __name__ == "__main__": + unittest.main() From 4d1ac7c1d25f66ce08994c0bed171012e7819054 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 30 Dec 2022 20:48:39 +0100 Subject: [PATCH 05/18] runtest autofix Signed-off-by: kbressem --- monai/transforms/__init__.py | 6 +++--- monai/transforms/utility/array.py | 5 +++-- tests/test_preset_filters.py | 36 ++++++------------------------- 3 files changed, 12 insertions(+), 35 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index a4d9f8265c..1fa03c0317 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -555,12 +555,12 @@ Identityd, IdentityD, IdentityDict, - IntensityStatsd, - IntensityStatsD, - IntensityStatsDict, ImageFilterd, ImageFilterD, ImageFilterDict, + IntensityStatsd, + IntensityStatsD, + IntensityStatsDict, LabelToMaskd, LabelToMaskD, LabelToMaskDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 0276c92f49..abd7d1b9a5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,6 +31,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import no_collation from monai.networks.layers.simplelayers import ( + ApplyFilter, EllipticalFilter, GaussianFilter, LaplaceFilter, @@ -1590,7 +1591,7 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> self.filter = self._create_filter(self.filter, self.filter_size, ndim) img_ = self._apply_filter(img_) if meta_dict: - img_ = MetaTensor(img_, meta_dict) + img_ = MetaTensor(img_, meta = meta_dict) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ @@ -1628,7 +1629,7 @@ def _check_kwargs_are_present(self, filter, **kwargs): if filter == "savitzky_golay" and "order" not in kwargs.keys(): raise KeyError("`filter='savitzky_golay', requires the additonal keyword argument `order`") - def _create_filter(self, filter: Union[str, NdarrayOrTensor], size: int, ndim: int) -> nn.Module: + def _create_filter(self, filter: Union[str, NdarrayOrTensor], size: Optional[int], ndim: Optional[int]) -> nn.Module: """Create an `ndim` filter of size `(size, ) * ndim`.""" if isinstance(filter, str): filter = self._get_filter_from_string(filter, size, ndim) diff --git a/tests/test_preset_filters.py b/tests/test_preset_filters.py index 88eca1044f..e2dc8e761e 100644 --- a/tests/test_preset_filters.py +++ b/tests/test_preset_filters.py @@ -16,10 +16,7 @@ from monai.networks.layers import EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter -TEST_CASES_MEAN = [ - (3, 3, torch.ones(3, 3, 3)), - (2, 5, torch.ones(5, 5)), -] +TEST_CASES_MEAN = [(3, 3, torch.ones(3, 3, 3)), (2, 5, torch.ones(5, 5))] TEST_CASES_LAPLACE = [ @@ -34,13 +31,7 @@ ] ), ), - ( - 2, - 3, - torch.Tensor( - [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], - ), - ), + (2, 3, torch.Tensor([[-1, -1, -1], [-1, 8, 
-1], [-1, -1, -1]])), ] TEST_CASES_ELLIPTICAL = [ @@ -48,20 +39,10 @@ 3, 3, torch.Tensor( - [ - [[0, 0, 0], [0, 1, 0], [0, 0, 0]], - [[0, 1, 0], [1, 1, 1], [0, 1, 0]], - [[0, 0, 0], [0, 1, 0], [0, 0, 0]], - ] - ), - ), - ( - 2, - 3, - torch.Tensor( - [[0, 1, 0], [1, 1, 1], [0, 1, 0]], + [[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, 1, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]] ), ), + (2, 3, torch.Tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])), ] @@ -77,13 +58,7 @@ ] ), ), - ( - 2, - 3, - torch.Tensor( - [[0, -1, 0], [-1, 5, -1], [0, -1, 0]], - ), - ), + (2, 3, torch.Tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])), ] @@ -134,5 +109,6 @@ def setUp(self) -> None: def test_init(self, spatial_dims, size, expected): super().test_init(spatial_dims, size, expected) + if __name__ == "__main__": unittest.main() From 71ccc10c5bc52b8b7ddd951165c978823f3bd884 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 30 Dec 2022 22:26:08 +0100 Subject: [PATCH 06/18] black Signed-off-by: kbressem --- monai/transforms/utility/array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index abd7d1b9a5..65bba5ae2a 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1591,7 +1591,7 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> self.filter = self._create_filter(self.filter, self.filter_size, ndim) img_ = self._apply_filter(img_) if meta_dict: - img_ = MetaTensor(img_, meta = meta_dict) + img_ = MetaTensor(img_, meta=meta_dict) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ @@ -1629,7 +1629,9 @@ def _check_kwargs_are_present(self, filter, **kwargs): if filter == "savitzky_golay" and "order" not in kwargs.keys(): raise KeyError("`filter='savitzky_golay', requires the additonal keyword argument `order`") - def _create_filter(self, filter: Union[str, NdarrayOrTensor], size: Optional[int], ndim: Optional[int]) -> nn.Module: + def _create_filter( + self, filter: Union[str, NdarrayOrTensor], size: Optional[int], ndim: Optional[int] + ) -> nn.Module: """Create an `ndim` filter of size `(size, ) * ndim`.""" if isinstance(filter, str): filter = self._get_filter_from_string(filter, size, ndim) From f59f7fc4636a847878be67acd898f66d369864d8 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 30 Dec 2022 22:29:13 +0100 Subject: [PATCH 07/18] reduce line length in ImageFilter Signed-off-by: kbressem --- monai/transforms/utility/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 65bba5ae2a..4f783837de 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1620,7 +1620,8 @@ def _check_filter_format( else: raise TypeError( f"{type(filter)} is not supported." 
- "Supported types are `class 'str'`, `class 'torch.Tensor'`, `class 'np.ndarray'`, `class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`" + "Supported types are `class 'str'`, `class 'torch.Tensor'`, `class 'np.ndarray'`, " + "`class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`" ) def _check_kwargs_are_present(self, filter, **kwargs): From d265c82aaad3c6a57c4a52b585743d29337afe5f Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 30 Dec 2022 23:42:01 +0100 Subject: [PATCH 08/18] fix mypy errors Signed-off-by: kbressem --- monai/transforms/utility/array.py | 33 ++++++++++++++----------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 4f783837de..e3ac124eef 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1574,7 +1574,7 @@ def __init__( self.filter_size = filter_size self.additional_args_for_filter = kwargs - def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: + def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Dict] = None) -> NdarrayOrTensor: """ Args: img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] @@ -1587,8 +1587,12 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> meta_dict = img.meta img_, prev_type, device = convert_data_type(img, torch.Tensor) ndim = img_.ndim - 1 # assumes channel first format - if not callable(self.filter): - self.filter = self._create_filter(self.filter, self.filter_size, ndim) + + if isinstance(self.filter, str): + self.filter = self._get_filter_from_string(self.filter, self.filter_size, ndim) # type: ignore + elif isinstance(self.filter, (torch.Tensor, np.ndarray)): + self.filter = ApplyFilter(self.filter) + img_ = self._apply_filter(img_) if meta_dict: img_ = MetaTensor(img_, meta=meta_dict) @@ -1630,17 +1634,7 @@ def _check_kwargs_are_present(self, filter, **kwargs): if filter == "savitzky_golay" and "order" not in kwargs.keys(): raise KeyError("`filter='savitzky_golay', requires the additonal keyword argument `order`") - def _create_filter( - self, filter: Union[str, NdarrayOrTensor], size: Optional[int], ndim: Optional[int] - ) -> nn.Module: - """Create an `ndim` filter of size `(size, ) * ndim`.""" - if isinstance(filter, str): - filter = self._get_filter_from_string(filter, size, ndim) - else: - filter = ApplyFilter(filter) - return filter - - def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Module: + def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> Union[nn.Module, Callable]: if filter == "mean": return MeanFilter(ndim, size) elif filter == "laplace": @@ -1649,7 +1643,6 @@ def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Modul return EllipticalFilter(ndim, size) elif filter == "sobel": from monai.transforms.post.array import SobelGradients # cannot import on top because of circular imports - allowed_keys = SobelGradients.__init__.__annotations__.keys() kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} return SobelGradients(size, **kwargs) @@ -1665,12 +1658,16 @@ def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Modul allowed_keys = SavitzkyGolayFilter.__init__.__annotations__.keys() kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} return 
SavitzkyGolayFilter(size, **kwargs) + else: + raise NotImplementedError(f"Filter {filter} not implemented") - def _apply_filter(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def _apply_filter(self, img: torch.Tensor) -> torch.Tensor: if isinstance(self.filter, Transform): - return self.filter(img) + img = self.filter(img) else: - return self.filter(img.unsqueeze(0))[0] # add and remove batch dim + img = self.filter(img.unsqueeze(0)) # type: ignore + img = img[0] # add and remove batch dim + return img class RandImageFilter(RandomizableTransform): From d9f484ecd6fa8c0aff87d44ca60e5148b2363425 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Dec 2022 22:42:33 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/utility/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index e3ac124eef..ad48d22ad6 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1590,7 +1590,7 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Dict] = None) -> Nd if isinstance(self.filter, str): self.filter = self._get_filter_from_string(self.filter, self.filter_size, ndim) # type: ignore - elif isinstance(self.filter, (torch.Tensor, np.ndarray)): + elif isinstance(self.filter, (torch.Tensor, np.ndarray)): self.filter = ApplyFilter(self.filter) img_ = self._apply_filter(img_) @@ -1658,7 +1658,7 @@ def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> Union[nn allowed_keys = SavitzkyGolayFilter.__init__.__annotations__.keys() kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} return SavitzkyGolayFilter(size, **kwargs) - else: + else: raise NotImplementedError(f"Filter {filter} not implemented") def _apply_filter(self, img: torch.Tensor) -> torch.Tensor: From 8fd75fd8e715c0cbb6ea0098cdc410aefa45a671 Mon Sep 17 00:00:00 2001 From: kbressem Date: Fri, 30 Dec 2022 23:43:52 +0100 Subject: [PATCH 10/18] black Signed-off-by: kbressem --- monai/transforms/utility/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ad48d22ad6..1d0fa6de83 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1643,6 +1643,7 @@ def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> Union[nn return EllipticalFilter(ndim, size) elif filter == "sobel": from monai.transforms.post.array import SobelGradients # cannot import on top because of circular imports + allowed_keys = SobelGradients.__init__.__annotations__.keys() kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys} return SobelGradients(size, **kwargs) From 7648a6cf68f0210a464f31cfe7663f928ef8fe63 Mon Sep 17 00:00:00 2001 From: kbressem Date: Sat, 31 Dec 2022 12:09:10 +0100 Subject: [PATCH 11/18] fix unit tests and codestyle Signed-off-by: kbressem --- monai/transforms/utility/array.py | 6 +- monai/transforms/utility/dictionary.py | 6 +- tests/test_image_filter.py | 188 ++++++++++++------------- 3 files changed, 101 insertions(+), 99 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 1d0fa6de83..9e1c047b57 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1691,16 
+1691,18 @@ class RandImageFilter(RandomizableTransform): backend = ImageFilter.backend def __init__( - self, filter: Union[str, NdarrayOrTensor], filter_size: Optional[int] = None, prob: float = 0.1 + self, filter: Union[str, NdarrayOrTensor], filter_size: Optional[int] = None, prob: float = 0.1, **kwargs ) -> None: super().__init__(prob) - self.filter = ImageFilter(filter, filter_size) + self.filter = ImageFilter(filter, filter_size, **kwargs) def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ Args: img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] meta_dict: An optional dictionary with metadata + kwargs: optional arguments required by specific filters. E.g. `sigma`if filter is `gauss`. + see py:func:`monai.transforms.utility.array.ImageFilter` for more details Returns: A MetaTensor with the same shape as `img` and identical metadata diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index a6ac2c1437..c16a456f74 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1767,9 +1767,10 @@ def __init__( kernel: Union[str, NdarrayOrTensor], kernel_size: Optional[int] = None, allow_missing_keys: bool = False, + **kwargs, ) -> None: super().__init__(keys, allow_missing_keys) - self.filter = ImageFilter(kernel, kernel_size) + self.filter = ImageFilter(kernel, kernel_size, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -1807,10 +1808,11 @@ def __init__( kernel_size: Optional[int] = None, prob: float = 0.1, allow_missing_keys: bool = False, + **kwargs, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.filter = ImageFilter(kernel, kernel_size) + self.filter = ImageFilter(kernel, kernel_size, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index d8e9f35627..6df9ea03c9 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -19,14 +19,14 @@ from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd -EXPECTED_KERNELS = { +EXPECTED_FILTERS = { "mean": torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).float(), - "laplacian": torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float(), + "laplace": torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float(), "elliptical": torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]]).float(), "sharpen": torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]).float(), } -SUPPORTED_KERNELS = ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] +SUPPORTED_FILTERS = ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] SAMPLE_IMAGE_2D = torch.randn(1, 10, 10) SAMPLE_IMAGE_3D = torch.randn(1, 10, 10, 10) SAMPLE_DICT = {"image_2d": SAMPLE_IMAGE_2D, "image_3d": SAMPLE_IMAGE_3D} @@ -47,10 +47,10 @@ class TestNotAModuleOrTransform: class TestImageFilter(unittest.TestCase): - @parameterized.expand(SUPPORTED_KERNELS) - def test_init_from_string(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string(self, filter_name): "Test init from string" - _ = ImageFilter(kernel_name, 3, **ADDITIONAL_ARGUMENTS) + _ = 
ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) def test_init_raises(self): with self.assertRaises(Exception) as context: @@ -69,7 +69,7 @@ def test_init_raises(self): ) def test_init_from_array(self): - "Test init with custom kernel and assert wrong kernel shape throws an error" + "Test init with custom filter and assert wrong filter shape throws an error" _ = ImageFilter(torch.ones(3, 3)) _ = ImageFilter(torch.ones(3, 3, 3)) _ = ImageFilter(np.ones((3, 3))) @@ -92,140 +92,138 @@ def test_init_from_wrong_type_fails(self): _ = ImageFilter(TestNotAModuleOrTransform()) self.assertTrue(" is not supported." in str(context.output)) - @parameterized.expand(EXPECTED_KERNELS.keys()) - def test_2d_kernel_correctness(self, kernel_name): - "Test correctness of kernels (2d only)" - tfm = ImageFilter(kernel_name, kernel_size=3) - kernel = tfm._create_kernel_from_string(kernel_name, size=3, ndim=2).squeeze() - torch.testing.assert_allclose(kernel, EXPECTED_KERNELS[kernel_name]) + @parameterized.expand(EXPECTED_FILTERS.keys()) + def test_2d_filter_correctness(self, filter_name): + "Test correctness of filters (2d only)" + tfm = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + filter = tfm._get_filter_from_string(filter_name, size=3, ndim=2).filter.squeeze() + torch.testing.assert_allclose(filter, EXPECTED_FILTERS[filter_name]) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_2d(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d(self, filter_name): "Text function `__call__` for 2d images" - filter = ImageFilter(kernel_name, 3) - if kernel_name != "sobel_d": # sobel_d does not support 2d - out_tensor = filter(SAMPLE_IMAGE_2D) - self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_2D.shape) + filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:]) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_3d(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d(self, filter_name): "Text function `__call__` for 3d images" - filter = ImageFilter(kernel_name, 3) + filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) out_tensor = filter(SAMPLE_IMAGE_3D) - self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_3D.shape) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) class TestImageFilterDict(unittest.TestCase): - @parameterized.expand(SUPPORTED_KERNELS) - def test_init_from_string_dict(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string_dict(self, filter_name): "Test init from string and assert an error is thrown if no size is passed" - _ = ImageFilterd("image", kernel_name, 3) - with self.assertRaises(Exception) as context: # noqa F841 - _ = ImageFilterd(self.image_key, kernel_name) + _ = ImageFilterd("image", filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = ImageFilterd(self.image_key, filter_name) def test_init_from_array_dict(self): - "Test init with custom kernel and assert wrong kernel shape throws an error" + "Test init with custom filter and assert wrong filter shape throws an error" _ = ImageFilterd("image", torch.ones(3, 3)) - with self.assertRaises(Exception) as context: # noqa F841 + with self.assertRaises(Exception) as _: _ = ImageFilterd(self.image_key, torch.ones(3, 3, 3, 3)) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_2d(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def 
test_call_2d(self, filter_name): "Text function `__call__` for 2d images" - filter = ImageFilterd("image_2d", kernel_name, 3) - if kernel_name != "sobel_d": # sobel_d does not support 2d - out_tensor = filter(SAMPLE_DICT) - self.assertEqual(out_tensor["image_2d"].shape, SAMPLE_IMAGE_2D.shape) + filter = ImageFilterd("image_2d", filter_name, 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_3d(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d(self, filter_name): "Text function `__call__` for 3d images" - filter = ImageFilterd("image_3d", kernel_name, 3) + filter = ImageFilterd("image_3d", filter_name, 3, **ADDITIONAL_ARGUMENTS) out_tensor = filter(SAMPLE_DICT) - self.assertEqual(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) + self.assertEqual(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) class TestRandImageFilter(unittest.TestCase): - @parameterized.expand(SUPPORTED_KERNELS) - def test_init_from_string(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string(self, filter_name): "Test init from string and assert an error is thrown if no size is passed" - _ = RandImageFilter(kernel_name, 3) - with self.assertRaises(Exception) as context: # noqa F841 - _ = RandImageFilter(kernel_name) + _ = RandImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = RandImageFilter(filter_name) def test_init_from_array(self): - "Test init with custom kernel and assert wrong kernel shape throws an error" + "Test init with custom filter and assert wrong filter shape throws an error" _ = RandImageFilter(torch.ones(3, 3)) - with self.assertRaises(Exception) as context: # noqa F841 + with self.assertRaises(Exception) as _: _ = RandImageFilter(torch.ones(3, 3, 3, 3)) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_2d_prob_1(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_1(self, filter_name): "Text function `__call__` for 2d images" - filter = RandImageFilter(kernel_name, 3, 1) - if kernel_name != "sobel_d": # sobel_d does not support 2d - out_tensor = filter(SAMPLE_IMAGE_2D) - self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_2D.shape) + filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_IMAGE_2D) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:]) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_3d_prob_1(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_1(self, filter_name): "Text function `__call__` for 3d images" - filter = RandImageFilter(kernel_name, 3, 1) + filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS) out_tensor = filter(SAMPLE_IMAGE_3D) - self.assertEqual(out_tensor.shape, SAMPLE_IMAGE_3D.shape) + self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_2d_prob_0(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_0(self, filter_name): "Text function `__call__` for 2d images" - filter = RandImageFilter(kernel_name, 3, 0) - if kernel_name != "sobel_d": # sobel_d does not support 2d - out_tensor = filter(SAMPLE_IMAGE_2D) - torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_2D) + filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS) + 
out_tensor = filter(SAMPLE_IMAGE_2D) + torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_2D) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_3d_prob_0(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_0(self, filter_name): "Text function `__call__` for 3d images" - filter = RandImageFilter(kernel_name, 3, 0) + filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS) out_tensor = filter(SAMPLE_IMAGE_3D) torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_3D) class TestRandImageFilterDict(unittest.TestCase): - @parameterized.expand(SUPPORTED_KERNELS) - def test_init_from_string_dict(self, kernel_name): + @parameterized.expand(SUPPORTED_FILTERS) + def test_init_from_string_dict(self, filter_name): "Test init from string and assert an error is thrown if no size is passed" - _ = RandImageFilterd("image", kernel_name, 3) - with self.assertRaises(Exception) as context: # noqa F841 - _ = RandImageFilterd("image", kernel_name) + _ = RandImageFilterd("image", filter_name, 3, **ADDITIONAL_ARGUMENTS) + with self.assertRaises(Exception) as _: + _ = RandImageFilterd("image", filter_name) def test_init_from_array_dict(self): - "Test init with custom kernel and assert wrong kernel shape throws an error" + "Test init with custom filter and assert wrong filter shape throws an error" _ = RandImageFilterd("image", torch.ones(3, 3)) - with self.assertRaises(Exception) as context: # noqa F841 + with self.assertRaises(Exception) as _: _ = RandImageFilterd("image", torch.ones(3, 3, 3, 3)) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_2d_prob_1(self, kernel_name): - filter = RandImageFilterd("image_2d", kernel_name, 3, 1.0) - if kernel_name != "sobel_d": # sobel_d does not support 2d - out_tensor = filter(SAMPLE_DICT) - self.assertEqual(out_tensor["image_2d"].shape, SAMPLE_IMAGE_2D.shape) + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_1(self, filter_name): + filter = RandImageFilterd("image_2d", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + self.assertEqual(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_3d_prob_1(self, kernel_name): - filter = RandImageFilterd("image_3d", kernel_name, 3, 1.0) + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_1(self, filter_name): + filter = RandImageFilterd("image_3d", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS) out_tensor = filter(SAMPLE_DICT) - self.assertEqual(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) - - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_2d_prob_0(self, kernel_name): - filter = RandImageFilterd("image_2d", kernel_name, 3, 0.0) - if kernel_name != "sobel_d": # sobel_d does not support 2d - out_tensor = filter(SAMPLE_DICT) - torch.testing.assert_allclose(out_tensor["image_2d"].shape, SAMPLE_IMAGE_2D.shape) - - @parameterized.expand(SUPPORTED_KERNELS) - def test_call_3d_prob_0(self, kernel_name): - filter = RandImageFilterd("image_3d", kernel_name, 3, 0.0) + self.assertEqual(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + @parameterized.expand(SUPPORTED_FILTERS) + def test_call_2d_prob_0(self, filter_name): + filter = RandImageFilterd("image_2d", filter_name, 3, 0.0, **ADDITIONAL_ARGUMENTS) out_tensor = filter(SAMPLE_DICT) - torch.testing.assert_allclose(out_tensor["image_3d"].shape, SAMPLE_IMAGE_3D.shape) + torch.testing.assert_allclose(out_tensor["image_2d"].shape[1:], SAMPLE_IMAGE_2D.shape[1:]) + + 
@parameterized.expand(SUPPORTED_FILTERS) + def test_call_3d_prob_0(self, filter_name): + filter = RandImageFilterd("image_3d", filter_name, 3, 0.0, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(SAMPLE_DICT) + torch.testing.assert_allclose(out_tensor["image_3d"].shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + + +if __name__ == "__main__": + unittest.main() From 203fbbec5684bf6fac58a71193bc1a6dec54c73a Mon Sep 17 00:00:00 2001 From: kbressem Date: Sat, 31 Dec 2022 12:32:01 +0100 Subject: [PATCH 12/18] update transforms.rst in docs Signed-off-by: kbressem --- docs/source/transforms.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index c92d6fe46a..650fb84e14 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1134,6 +1134,17 @@ Utility :members: :special-members: __call__ +`ImageFilter` +""""""""""""" +.. autoclass:: ImageFilter + :members: + :special-members: __call__ + +`RandImageFilter` +""""""""""""" +.. autoclass:: RandImageFilter + :members: + :special-members: __call__ Dictionary Transforms --------------------- @@ -2124,6 +2135,19 @@ Utility (Dict) :members: :special-members: __call__ +`ImageFilterd` +"""""""""""""""""""""""" +.. autoclass:: ImageFilterd + :members: + :special-members: __call__ + +`RandImageFilterd` +"""""""""""""""""""""""" +.. autoclass:: RandImageFilterd + :members: + :special-members: __call__ + + MetaTensor ^^^^^^^^^^ From f753373fe3c73de8eccb726f1e8f64df5f8853de Mon Sep 17 00:00:00 2001 From: kbressem Date: Sat, 31 Dec 2022 12:50:10 +0100 Subject: [PATCH 13/18] increase length of title underline Signed-off-by: kbressem --- docs/source/transforms.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 650fb84e14..c0b8af71e1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1141,7 +1141,7 @@ Utility :special-members: __call__ `RandImageFilter` -""""""""""""" +""""""""""""""""" .. autoclass:: RandImageFilter :members: :special-members: __call__ @@ -2136,13 +2136,13 @@ Utility (Dict) :special-members: __call__ `ImageFilterd` -"""""""""""""""""""""""" +"""""""""""""" .. autoclass:: ImageFilterd :members: :special-members: __call__ `RandImageFilterd` -"""""""""""""""""""""""" +"""""""""""""""""" .. autoclass:: RandImageFilterd :members: :special-members: __call__ From 34d82d6e3a3ae2bcfcfff0bf2cc271685744e65e Mon Sep 17 00:00:00 2001 From: kbressem Date: Sat, 31 Dec 2022 13:21:18 +0100 Subject: [PATCH 14/18] remove indent in docstring Signed-off-by: kbressem --- monai/transforms/utility/array.py | 36 +++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9e1c047b57..da2978b102 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1506,14 +1506,14 @@ class ImageFilter(Transform): This filter allows for additional arguments passed as ``kwargs`` during initialization. See also py:func:`monai.transforms.post.SobelGradients` ::kwargs:: - spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient - along each of the provide axis. By default it calculate the gradient for all spatial axes. - normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True. - normalize_gradients: if normalize the output gradient to 0 and 1. 
Defaults to False. - padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`. - Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. - See ``torch.nn.Conv1d()`` for more information. - dtype: kernel data type (torch.dtype). Defaults to `torch.float32`. + spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient + along each of the provide axis. By default it calculate the gradient for all spatial axes. + normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True. + normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False. + padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`. + Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + See ``torch.nn.Conv1d()`` for more information. + dtype: kernel data type (torch.dtype). Defaults to `torch.float32`. **Sharpening:** ``filter='sharpen'`` @@ -1532,12 +1532,12 @@ class ImageFilter(Transform): This filter requires additional arguments passed as ``kwargs`` during initialization. See also py:func:`monai.networks.layers.simplelayers.GaussianFilter` ::kwargs:: - sigma: std. could be a single value, or spatial_dims number of values. - truncated: spreads how many stds. - approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". - - `erf` approximation interpolates the error function; - - `sampled` uses a sampled Gaussian kernel; - - `scalespace` corresponds to + sigma: std. could be a single value, or spatial_dims number of values. + truncated: spreads how many stds. + approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". + - `erf` approximation interpolates the error function; + - `sampled` uses a sampled Gaussian kernel; + - `scalespace` corresponds to **Median Filter:** ``filter='median'`` @@ -1552,10 +1552,10 @@ class ImageFilter(Transform): This filter requires additional arguments passed as ``kwargs`` during initialization. See also py:func:`monai.networks.layers.simplelayers.SavitzkyGolayFilter` ::kwargs:: - order: Order of the polynomial to fit to each window, must be less than ``window_length``. - axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension). - mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or - ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension). + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or + ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. 
""" From 702329854741923267a4cf50cfdcbd6691c9d88d Mon Sep 17 00:00:00 2001 From: kbressem Date: Sat, 31 Dec 2022 16:51:23 +0100 Subject: [PATCH 15/18] fix docstring, add ignore for mypy Signed-off-by: kbressem --- monai/networks/layers/simplelayers.py | 5 +- monai/transforms/utility/array.py | 67 +++++++++++++++------------ 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 83b02e6746..ca7997e272 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -667,11 +667,10 @@ class ApplyFilter(nn.Module): def __init__(self, filter: NdarrayOrTensor) -> None: super().__init__() - filter = convert_to_tensor(filter, dtype=torch.float32) - self.filter = filter + self.filter = convert_to_tensor(filter, dtype=torch.float32) def forward(self, x: torch.Tensor) -> torch.Tensor: - return apply_filter(x, self.filter) + return apply_filter(x, self.filter) # type: ignore class MeanFilter(ApplyFilter): diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index da2978b102..02688d6604 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1444,23 +1444,23 @@ class ImageFilter(Transform): Args: filter: - A string specifying the filter, a custom filter as `torch.Tenor` or `np.ndarray` or a `nn.Module`. - Available options for string are: `mean`, `laplace`, `elliptical`, `sobel`, `sharpen`, `median`, `gauss` + A string specifying the filter, a custom filter as ``torch.Tenor`` or ``np.ndarray`` or a ``nn.Module``. + Available options for string are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``, ``median``, ``gauss`` See below for short explanations on every filter. filter_size: A single integer value specifying the size of the quadratic or cubic filter. Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which should be considered when choosing filter size. kwargs: - Additional arguments passed to filter function, required by `sobel` and `gauss`. + Additional arguments passed to filter function, required by ``sobel`` and ``gauss``. See below for details. Raises: - ValueError: When `filter_size` is not an uneven integer - ValueError: When `filter` is an array and `ndim` is not in [1,2,3] - ValueError: When `filter` is an array and any dimension has an even shape - NotImplementedError: When `filter` is a string and not in `self.supported_filters` - KeyError: When necessary `kwargs` are not passed to a filter that requires additional arguments. + ValueError: When ``filter_size`` is not an uneven integer + ValueError: When ``filter`` is an array and ``ndim`` is not in [1,2,3] + ValueError: When ``filter`` is an array and any dimension has an even shape + NotImplementedError: When ``filter`` is a string and not in ``self.supported_filters`` + KeyError: When necessary ``kwargs`` are not passed to a filter that requires additional arguments. **Mean Filtering:** ``filter='mean'`` @@ -1490,6 +1490,7 @@ class ImageFilter(Transform): [-1., -1., -1., -1., -1.], [-1., -1., -1., -1., -1.]] + **Dilation:** ``filter='elliptical'`` An elliptical filter can be used to dilate labels or label-contours. @@ -1501,19 +1502,24 @@ class ImageFilter(Transform): [1., 1., 1., 1., 1.], [0., 0., 1., 0., 0.]] + **Edge Detection:** ``filter='sobel'`` This filter allows for additional arguments passed as ``kwargs`` during initialization. 
See also py:func:`monai.transforms.post.SobelGradients`
-    ::kwargs::
-        spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
-            along each of the provide axis. By default it calculate the gradient for all spatial axes.
-        normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
-        normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
-        padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
-            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
-            See ``torch.nn.Conv1d()`` for more information.
-        dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
+
+    *kwargs*
+
+    * ``spatial_axes``: the axes that define the direction of the gradient to be calculated.
+      It calculates the gradient along each of the provided axes.
+      By default it calculates the gradient for all spatial axes.
+    * ``normalize_kernels``: whether to normalize the Sobel kernel to provide proper gradients. Defaults to True.
+    * ``normalize_gradients``: whether to normalize the output gradient to the range [0, 1]. Defaults to False.
+    * ``padding_mode``: the padding mode of the image when convolving with Sobel kernels. Defaults to ``"reflect"``.
+      Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+      See ``torch.nn.Conv1d()`` for more information.
+    * ``dtype``: kernel data type (torch.dtype). Defaults to ``torch.float32``.
+

     **Sharpening:** ``filter='sharpen'``

@@ -1526,18 +1532,19 @@ class ImageFilter(Transform):
         [-1., -1., -1., -1., -1.],
         [ 0.,  0., -1.,  0.,  0.]]

+
     **Gaussian Smooth:** ``filter='gauss'``

     Blur/smooth an image with 2D or 3D gaussian filter.
     This filter requires additional arguments passed as ``kwargs`` during initialization.
     See also py:func:`monai.networks.layers.simplelayers.GaussianFilter`
-    ::kwargs::
-        sigma: std. could be a single value, or spatial_dims number of values.
-        truncated: spreads how many stds.
-        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
-            - `erf` approximation interpolates the error function;
-            - `sampled` uses a sampled Gaussian kernel;
-            - `scalespace` corresponds to
+
+    *kwargs*
+
+    * ``sigma``: standard deviation; can be a single value or ``spatial_dims`` number of values.
+    * ``truncated``: how many standard deviations the kernel spreads.
+    * ``approx``: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
+

     **Median Filter:** ``filter='median'``

@@ -1551,11 +1558,13 @@ class ImageFilter(Transform):
     Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.
     This filter requires additional arguments passed as ``kwargs`` during initialization.
     See also py:func:`monai.networks.layers.simplelayers.SavitzkyGolayFilter`
-    ::kwargs::
-        order: Order of the polynomial to fit to each window, must be less than ``window_length``.
-        axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
-        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
-            ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.
+
+    *kwargs*
+
+    * ``order``: Order of the polynomial to fit to each window, must be less than ``window_length``.
+    * ``axis`` (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
+    * ``mode`` (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
+      ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.

     """


From 9caf72ae50e5ec6c19edab6a2fe7df2e65e6da76 Mon Sep 17 00:00:00 2001
From: kbressem
Date: Sat, 31 Dec 2022 23:03:43 +0100
Subject: [PATCH 16/18] change default padding for SobelGradient when running tests for torch 1.8 or 1.9

Signed-off-by: kbressem
---
 tests/test_image_filter.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py
index 6df9ea03c9..a0819c840b 100644
--- a/tests/test_image_filter.py
+++ b/tests/test_image_filter.py
@@ -14,6 +14,7 @@

 import numpy as np
 import torch
+from packaging import version
 from parameterized import parameterized

 from monai.networks.layers.simplelayers import GaussianFilter
@@ -33,6 +34,10 @@

 ADDITIONAL_ARGUMENTS = {"order": 1, "sigma": 1}

+if version.parse(torch.__version__) < version.parse("1.10"):
+    # Sobel filter uses reflect as default which is not implemented for 3d in torch 1.8.1 or 1.9.1
+    ADDITIONAL_ARGUMENTS["padding_mode"] = "zeros"
+

 class TestModule(torch.nn.Module):
     def __init__(self):

From 7fc2064744eec6e448ca796950e86f9c8dffb544 Mon Sep 17 00:00:00 2001
From: kbressem
Date: Sat, 31 Dec 2022 23:15:26 +0100
Subject: [PATCH 17/18] add tests for ApplyFilter

Signed-off-by: kbressem
---
 tests/test_preset_filters.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/tests/test_preset_filters.py b/tests/test_preset_filters.py
index e2dc8e761e..0a6a4e8c50 100644
--- a/tests/test_preset_filters.py
+++ b/tests/test_preset_filters.py
@@ -14,7 +14,7 @@
 import torch
 from parameterized import parameterized

-from monai.networks.layers import EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter
+from monai.networks.layers import ApplyFilter, EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter

 TEST_CASES_MEAN = [(3, 3, torch.ones(3, 3, 3)), (2, 5, torch.ones(5, 5))]

@@ -74,6 +74,22 @@ def test_forward(self):
             _ = test_filter(input)


+class TestApplyFilter(unittest.TestCase):
+    def test_init_and_forward_2d(self):
+        filter_2d = torch.ones(3, 3)
+        image_2d = torch.ones(1, 3, 3)
+        apply_filter_2d = ApplyFilter(filter_2d)
+        out = apply_filter_2d(image_2d)
+        self.assertEqual(out.shape, image_2d.shape)
+
+    def test_init_and_forward_3d(self):
+        filter_3d = torch.ones(3, 3, 3)
+        image_3d = torch.ones(1, 3, 3, 3)
+        apply_filter_3d = ApplyFilter(filter_3d)
+        out = apply_filter_3d(image_3d)
+        self.assertEqual(out.shape, image_3d.shape)
+
+
 class MeanFilterTestCase(_TestFilter, unittest.TestCase):
     def setUp(self) -> None:
         self.filter_class = MeanFilter

From 37f7802d7bd4cf94dea5c379a542b635a7a23f31 Mon Sep 17 00:00:00 2001
From: kbressem
Date: Sat, 31 Dec 2022 23:19:45 +0100
Subject: [PATCH 18/18] remove packaging from imports

Signed-off-by: kbressem
---
 tests/test_image_filter.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py
index a0819c840b..007b5e8e2c 100644
--- a/tests/test_image_filter.py
+++ b/tests/test_image_filter.py
@@ -14,7 +14,6 @@

 import numpy as np
 import torch
-from packaging import version
 from parameterized import parameterized

 from monai.networks.layers.simplelayers import GaussianFilter
@@ -32,11 +31,8 @@
 SAMPLE_IMAGE_3D = torch.randn(1, 10, 10, 10)
 SAMPLE_DICT = {"image_2d": SAMPLE_IMAGE_2D, "image_3d": SAMPLE_IMAGE_3D}

-ADDITIONAL_ARGUMENTS = {"order": 1, "sigma": 1}
-
-if version.parse(torch.__version__) < version.parse("1.10"):
-    # Sobel filter uses reflect as default which is not implemented for 3d in torch 1.8.1 or 1.9.1
-    ADDITIONAL_ARGUMENTS["padding_mode"] = "zeros"
+# The Sobel filter uses reflect padding by default, which is not implemented for 3D in torch 1.8.1 or 1.9.1
+ADDITIONAL_ARGUMENTS = {"order": 1, "sigma": 1, "padding_mode": "zeros"}


 class TestModule(torch.nn.Module):
     def __init__(self):
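
A minimal usage sketch of the transforms added above (a sketch only, assuming the `ImageFilter`, `ImageFilterd` and `RandImageFilter` signatures shown in the diffs; exact argument names may differ in the merged version):

    import torch
    from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter

    # channel-first 3D image: (channels, H, W, D)
    img = torch.randn(1, 10, 10, 10)

    # mean smoothing with a 3x3x3 kernel
    smoothed = ImageFilter("mean", 3)(img)

    # gaussian smoothing takes the extra `sigma` keyword described in the docstring
    blurred = ImageFilter("gauss", 3, sigma=1.0)(img)

    # dictionary variant, applied only to the "image" key
    sample = {"image": img, "label": torch.ones(1, 10, 10, 10)}
    sample = ImageFilterd("image", "mean", 3)(sample)

    # randomized variant, applied with probability 0.5
    maybe_smoothed = RandImageFilter("mean", 3, prob=0.5)(img)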