13 changes: 13 additions & 0 deletions docs/source/networks.rst
@@ -256,6 +256,19 @@ Layers
.. autoclass:: Flatten
:members:

`Reshape`
~~~~~~~~~
.. autoclass:: Reshape
:members:

`separable_filtering`
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: separable_filtering

`apply_filter`
~~~~~~~~~~~~~~
.. autofunction:: apply_filter

`GaussianFilter`
~~~~~~~~~~~~~~~~
.. autoclass:: GaussianFilter
1 change: 1 addition & 0 deletions monai/networks/layers/__init__.py
@@ -22,6 +22,7 @@
Reshape,
SavitzkyGolayFilter,
SkipConnection,
apply_filter,
separable_filtering,
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
88 changes: 81 additions & 7 deletions monai/networks/layers/simplelayers.py
@@ -27,6 +27,7 @@
SkipMode,
look_up_option,
optional_import,
version_leq,
)
from monai.utils.misc import issequenceiterable

@@ -35,15 +36,16 @@
fft, _ = optional_import("torch.fft")

__all__ = [
"SkipConnection",
"ChannelPad",
"Flatten",
"GaussianFilter",
"HilbertTransform",
"LLTM",
"Reshape",
"separable_filtering",
"SavitzkyGolayFilter",
"HilbertTransform",
"ChannelPad",
"SkipConnection",
"apply_filter",
"separable_filtering",
]


@@ -211,25 +213,97 @@ def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str
Args:
x: the input image. must have shape (batch, channels, H[, W, ...]).
kernels: kernel along each spatial dimension.
could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels.
could be a single kernel (duplicated for all spatial dimensions), or
a list of `spatial_dims` number of kernels.
mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See
torch.nn.Conv1d() for more information.

Raises:
TypeError: When ``x`` is not a ``torch.Tensor``.

Examples:

.. code-block:: python

>>> import torch
>>> from monai.networks.layers import separable_filtering
>>> img = torch.randn(2, 4, 32, 32) # batch_size 2, channels 4, 32x32 2D images
# applying a [-1, 0, 1] filter along each of the spatial dimensions.
# the output shape is the same as the input shape.
>>> out = separable_filtering(img, torch.tensor((-1., 0., 1.)))
# applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively.
# the output shape is the same as the input shape.
>>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))])

"""

if not isinstance(x, torch.Tensor):
raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

spatial_dims = len(x.shape) - 2
_kernels = [s.float() for s in kernels]
if isinstance(kernels, torch.Tensor):
kernels = [kernels] * spatial_dims
_kernels = [s.to(x) for s in kernels]
_paddings = [(k.shape[0] - 1) // 2 for k in _kernels]
n_chs = x.shape[1]
pad_mode = "constant" if mode == "zeros" else mode

return _separable_filtering_conv(x, kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs)
return _separable_filtering_conv(x, _kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs)


def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Filtering `x` with `kernel` independently for each batch and channel.

Args:
x: the input image, must have shape (batch, channels, H[, W, D]).
kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]).
`kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`.
kwargs: keyword arguments passed to `conv*d()` functions.

Returns:
The filtered `x`.

Examples:

.. code-block:: python

>>> import torch
>>> from monai.networks.layers import apply_filter
>>> img = torch.rand(2, 5, 10, 10) # batch_size 2, channels 5, 10x10 2D images
>>> out = apply_filter(img, torch.rand(3, 3)) # spatial kernel
>>> out = apply_filter(img, torch.rand(5, 3, 3)) # channel-wise kernels
>>> out = apply_filter(img, torch.rand(2, 5, 3, 3)) # batch-, channel-wise kernels

"""
if not isinstance(x, torch.Tensor):
raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")
batch, chns, *spatials = x.shape
n_spatial = len(spatials)
if n_spatial > 3:
raise NotImplementedError(f"Only spatial dimensions up to 3 are supported but got {n_spatial}.")
k_size = len(kernel.shape)
if k_size < n_spatial or k_size > n_spatial + 2:
raise ValueError(
f"kernel must have {n_spatial} ~ {n_spatial + 2} dimensions to match the input shape {x.shape}."
)
kernel = kernel.to(x)
# broadcast kernel to shape (batch, chns, *spatial_kernel_shape)
kernel = kernel.expand(batch, chns, *kernel.shape[(k_size - n_spatial) :])
kernel = kernel.reshape(-1, 1, *kernel.shape[2:]) # group=1
x = x.view(1, kernel.shape[0], *spatials)
conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
if "padding" not in kwargs:
if version_leq(torch.__version__, "1.10.0b"):
# even-sized kernels are not supported
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
else:
kwargs["padding"] = "same"
if "stride" not in kwargs:
kwargs["stride"] = 1
output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)
return output.view(batch, chns, *output.shape[2:])


class SavitzkyGolayFilter(nn.Module):
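A minimal usage sketch for the new `apply_filter` helper, separate from the patch itself; the tensor sizes and the edge-detection kernel below are illustrative assumptions, not code from this PR.

```python
import torch
from monai.networks.layers import apply_filter

img = torch.rand(2, 3, 16, 16)  # batch 2, channels 3, 16x16 2D images

# a single spatial kernel is broadcast to every batch item and channel
laplacian = torch.tensor([[-1., -1., -1.],
                          [-1.,  8., -1.],
                          [-1., -1., -1.]])
edges = apply_filter(img, laplacian)
assert edges.shape == img.shape  # default padding keeps the spatial size

# per-channel kernels: leading dims must broadcast with the channel dim
per_channel = torch.rand(3, 5, 5)
out = apply_filter(img, per_channel)

# extra keyword arguments are forwarded to torch.nn.functional.conv2d,
# e.g. disabling the default padding shrinks the output
valid = apply_filter(img, laplacian, padding=0)
assert valid.shape == (2, 3, 14, 14)
```

On recent PyTorch the default padding is `"same"`; on older versions the function falls back to symmetric integer padding via the `version_leq` check above, so odd-sized kernels behave the same in both cases.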
85 changes: 46 additions & 39 deletions monai/transforms/post/array.py
@@ -18,11 +18,10 @@

import numpy as np
import torch
import torch.nn.functional as F

from monai.config.type_definitions import NdarrayOrTensor
from monai.networks import one_hot
from monai.networks.layers import GaussianFilter
from monai.networks.layers import GaussianFilter, apply_filter
from monai.transforms.transform import Transform
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
@@ -70,11 +69,11 @@ def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional

def __call__(
self,
img: torch.Tensor,
img: NdarrayOrTensor,
sigmoid: Optional[bool] = None,
softmax: Optional[bool] = None,
other: Optional[Callable] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
sigmoid: whether to execute sigmoid function on model output before transform.
@@ -96,17 +95,18 @@ def __call__(
raise TypeError(f"other must be None or callable but is {type(other).__name__}.")

# convert to float as activation must operate on float tensor
img = img.float()
img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) # type: ignore
if sigmoid or self.sigmoid:
img = torch.sigmoid(img)
img_t = torch.sigmoid(img_t)
if softmax or self.softmax:
img = torch.softmax(img, dim=0)
img_t = torch.softmax(img_t, dim=0)

act_func = self.other if other is None else other
if act_func is not None:
img = act_func(img)

return img
img_t = act_func(img_t)
out, *_ = convert_to_dst_type(img_t, img)
return out


class AsDiscrete(Transform):
@@ -164,15 +164,15 @@ def __init__(
@deprecated_arg("n_classes", since="0.6")
def __call__(
self,
img: torch.Tensor,
img: NdarrayOrTensor,
argmax: Optional[bool] = None,
to_onehot: Optional[bool] = None,
num_classes: Optional[int] = None,
threshold_values: Optional[bool] = None,
logit_thresh: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
) -> torch.Tensor:
) -> NdarrayOrTensor:
"""
Args:
img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
@@ -197,24 +197,27 @@ def __call__(
# in case the new num_classes is default but you still call deprecated n_classes
if n_classes is not None and num_classes is None:
num_classes = n_classes
img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore
if argmax or self.argmax:
img = torch.argmax(img, dim=0, keepdim=True)
img_t = torch.argmax(img_t, dim=0, keepdim=True)

if to_onehot or self.to_onehot:
_nclasses = self.num_classes if num_classes is None else num_classes
if not isinstance(_nclasses, int):
raise AssertionError("One of self.num_classes or num_classes must be an integer")
img = one_hot(img, num_classes=_nclasses, dim=0)
img_t = one_hot(img_t, num_classes=_nclasses, dim=0)

if threshold_values or self.threshold_values:
img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)
img_t = img_t >= (self.logit_thresh if logit_thresh is None else logit_thresh)

rounding = self.rounding if rounding is None else rounding
if rounding is not None:
look_up_option(rounding, ["torchrounding"])
img = torch.round(img)
img_t = torch.round(img_t)

return img.float()
img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
return img


class KeepLargestConnectedComponent(Transform):
@@ -275,7 +278,7 @@ def __init__(
If the data is in one-hot format, this is used to determine which channels to apply.
independent: whether to treat ``applied_labels`` as a union of foreground labels.
If ``True``, the connected component analysis will be performed on each foreground label independently
and return the intersection of the largest component.
and return the intersection of the largest components.
If ``False``, the analysis will be performed on the union of foreground labels.
default is `True`.
connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
@@ -368,7 +371,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if isinstance(img, torch.Tensor):
if hasattr(torch, "isin"):
appl_lbls = torch.as_tensor(self.applied_labels, device=img.device)
return torch.where(torch.isin(img, appl_lbls), img, 0)
return torch.where(torch.isin(img, appl_lbls), img, torch.tensor(0.0).to(img))
else:
out = self(img.detach().cpu().numpy())
out, *_ = convert_to_dst_type(out, img)
@@ -460,7 +463,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

class LabelToContour(Transform):
"""
Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel
Return the contour of binary input images that only compose of 0 and 1, with Laplacian kernel
set as default for edge detection. Typical usage is to plot the edge of label or segmentation output.

Args:
@@ -471,12 +474,14 @@ class LabelToContour(Transform):

"""

backend = [TransformBackends.TORCH]

def __init__(self, kernel_type: str = "Laplace") -> None:
if kernel_type != "Laplace":
raise NotImplementedError('Currently only kernel_type="Laplace" is supported.')
self.kernel_type = kernel_type

def __call__(self, img: torch.Tensor) -> torch.Tensor:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]]
@@ -492,22 +497,20 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
ideally the edge should be thin enough, but now it has a thickness.

"""
channels = img.shape[0]
img_ = img.unsqueeze(0)
if img.ndimension() == 3:
kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img.device)
kernel = kernel.repeat(channels, 1, 1, 1)
contour_img = F.conv2d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels)
elif img.ndimension() == 4:
kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img.device)
kernel[1, 1, 1] = 26
kernel = kernel.repeat(channels, 1, 1, 1, 1)
contour_img = F.conv3d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels)
img_: torch.Tensor = convert_data_type(img, torch.Tensor)[0] # type: ignore
spatial_dims = len(img_.shape) - 1
img_ = img_.unsqueeze(0) # adds a batch dim
if spatial_dims == 2:
kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
elif spatial_dims == 3:
kernel = -1.0 * torch.ones(3, 3, 3, dtype=torch.float32)
kernel[1, 1, 1] = 26.0
else:
raise ValueError(f"Unsupported img dimension: {img.ndimension()}, available options are [4, 5].")

raise ValueError(f"{self.__class__} can only handle 2D or 3D images.")
contour_img = apply_filter(img_, kernel)
contour_img.clamp_(min=0.0, max=1.0)
return contour_img.squeeze(0)
output, *_ = convert_to_dst_type(contour_img.squeeze(0), img)
return output


class Ensemble:
@@ -528,7 +531,7 @@ def post_convert(img: torch.Tensor, orig_img: Union[Sequence[NdarrayOrTensor], N
return out


class MeanEnsemble(Ensemble):
class MeanEnsemble(Ensemble, Transform):
"""
Execute mean ensemble on the input data.
The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],
@@ -551,6 +554,8 @@ class MeanEnsemble(Ensemble):

"""

backend = [TransformBackends.TORCH]

def __init__(self, weights: Optional[Union[Sequence[float], NdarrayOrTensor]] = None) -> None:
self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None

@@ -569,7 +574,7 @@ def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> Nd
return self.post_convert(out_pt, img)


class VoteEnsemble(Ensemble):
class VoteEnsemble(Ensemble, Transform):
"""
Execute vote ensemble on the input data.
The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],
Expand All @@ -589,6 +594,8 @@ class VoteEnsemble(Ensemble):

"""

backend = [TransformBackends.TORCH]

def __init__(self, num_classes: Optional[int] = None) -> None:
self.num_classes = num_classes

@@ -665,9 +672,9 @@ def __init__(
self.prob_threshold = prob_threshold
if isinstance(box_size, int):
self.box_size = np.asarray([box_size] * spatial_dims)
elif len(box_size) != spatial_dims:
raise ValueError("the sequence length of box_size should be the same as spatial_dims.")
else:
if len(box_size) != spatial_dims:
raise ValueError("the sequence length of box_size should be the same as spatial_dims.")
self.box_size = np.asarray(box_size)
if self.box_size.min() <= 0:
raise ValueError("box_size should be larger than 0.")