Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ Layers
.. autoclass:: GaussianFilter
:members:

`HilbertTransform`
~~~~~~~~~~~~~~~~~~
.. autoclass:: HilbertTransform
:members:

`Affine Transform`
~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.networks.layers.AffineTransform
Expand Down
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ Intensity
:members:
:special-members: __call__

`DetectEnvelope`
"""""""""""""""""""""
.. autoclass:: DetectEnvelope
:members:
:special-members: __call__

IO
^^

Expand Down
73 changes: 71 additions & 2 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
from torch import nn
from torch.autograd import Function

from monai.config import get_torch_version_tuple
from monai.networks.layers.convutils import gaussian_1d, same_padding
from monai.utils import SkipMode, ensure_tuple_rep, optional_import
from monai.utils import InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, optional_import

_C, _ = optional_import("monai._C")
if tuple(int(s) for s in torch.__version__.split(".")[0:2]) >= (1, 7):
fft, _ = optional_import("torch.fft")

__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering"]
__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering", "HilbertTransform"]


class SkipConnection(nn.Module):
Expand Down Expand Up @@ -130,6 +133,72 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
return _conv(x, spatial_dims - 1)


class HilbertTransform(nn.Module):
"""
Determine the analytical signal of a Tensor along a particular axis.
Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10).

Args:
axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension).
N: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``.
"""

def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None:

if get_torch_version_tuple() < (1, 7):
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)

super().__init__()
self.axis = axis
self.n = n

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``.
Returns:
torch.Tensor: Analytical signal of ``x``, transformed along axis specified in ``self.axis`` using
FFT of size ``self.N``. The absolute value of ``x_ht`` relates to the envelope of ``x`` along axis ``self.axis``.
"""

# Make input a real tensor
x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None)
if torch.is_complex(x):
raise ValueError("x must be real.")
else:
x = x.to(dtype=torch.float)

if (self.axis < 0) or (self.axis > len(x.shape) - 1):
raise ValueError("Invalid axis for shape of x.")

n = x.shape[self.axis] if self.n is None else self.n
if n <= 0:
raise ValueError("N must be positive.")
x = torch.as_tensor(x, dtype=torch.complex64)
# Create frequency axis
f = torch.cat(
[
torch.true_divide(torch.arange(0, (n - 1) // 2 + 1, device=x.device), float(n)),
torch.true_divide(torch.arange(-(n // 2), 0, device=x.device), float(n)),
]
)
xf = fft.fft(x, n=n, dim=self.axis)
# Create step function
u = torch.heaviside(f, torch.tensor([0.5], device=f.device))
u = torch.as_tensor(u, dtype=x.dtype, device=u.device)
new_dims_before = self.axis
new_dims_after = len(xf.shape) - self.axis - 1
for _ in range(new_dims_before):
u.unsqueeze_(0)
for _ in range(new_dims_after):
u.unsqueeze_(-1)

ht = fft.ifft(xf * 2 * u, dim=self.axis)

# Apply transform
return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype)


class GaussianFilter(nn.Module):
def __init__(
self,
Expand Down
45 changes: 43 additions & 2 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import numpy as np
import torch

from monai.networks.layers import GaussianFilter
from monai.config import get_torch_version_tuple
from monai.networks.layers import GaussianFilter, HilbertTransform
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import rescale_array
from monai.utils import dtype_torch_to_numpy, ensure_tuple_size
from monai.utils import InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size


class RandGaussianNoise(Randomizable, Transform):
Expand Down Expand Up @@ -509,6 +510,46 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n
return img * mask_data_


class DetectEnvelope(Transform):
"""
Find the envelope of the input data along the requested axis using a Hilbert transform.
Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10).

Args:
axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension.
N: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension
``axis``.

"""

def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None:

if get_torch_version_tuple() < (1, 7):
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)

if axis < 0:
raise ValueError("axis must be zero or positive.")

self.axis = axis
self.n = n

def __call__(self, img: np.ndarray) -> np.ndarray:
"""

Args:
img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...].

Returns:
np.ndarray containing envelope of data in img along the specified axis.

"""
# add one to transform axis because a batch axis will be added at dimension 0
hilbert_transform = HilbertTransform(self.axis + 1, self.n)
# convert to Tensor and add Batch axis expected by HilbertTransform
input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0)
return np.abs(hilbert_transform(input_data).squeeze(0).numpy())


class GaussianSmooth(Transform):
"""
Apply Gaussian smooth to the input data based on specified `sigma` parameter.
Expand Down
12 changes: 12 additions & 0 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OPTIONAL_IMPORT_MSG_FMT = "{}"

__all__ = [
"InvalidPyTorchVersionError",
"OptionalImportError",
"exact_version",
"export",
Expand Down Expand Up @@ -105,6 +106,17 @@ def exact_version(the_module, version_str: str = "") -> bool:
return bool(the_module.__version__ == version_str)


class InvalidPyTorchVersionError(Exception):
"""
Raised when called function or method requires a more recent
PyTorch version than that installed.
"""

def __init__(self, required_version, name):
message = f"{name} requires PyTorch version {required_version} or later"
super().__init__(message)


class OptionalImportError(ImportError):
"""
Could not import APIs from an optional dependency.
Expand Down
165 changes: 165 additions & 0 deletions tests/test_detect_envelope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import DetectEnvelope
from monai.utils import InvalidPyTorchVersionError, OptionalImportError
from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, SkipIfModule, SkipIfNoModule

n_samples = 500
hann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples)

# SINGLE-CHANNEL VALUE TESTS
# using np.expand_dims() to add length 1 channel dimension at dimension 0

TEST_CASE_1D_SINE = [
{}, # args (empty, so use default)
np.expand_dims(hann_windowed_sine, 0), # Input data: Hann windowed sine wave
np.expand_dims(np.hanning(n_samples), 0), # Expected output: the Hann window
1e-4, # absolute tolerance
]

TEST_CASE_2D_SINE = [
{}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1)
# Create 10 identical windowed sine waves as a 2D numpy array
np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0),
# Expected output: Set of 10 identical Hann windows
np.expand_dims(np.stack([np.hanning(n_samples)] * 10, axis=1), 0),
1e-4, # absolute tolerance
]

TEST_CASE_3D_SINE = [
{}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1)
# Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array
np.expand_dims(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), 0),
# Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array
np.expand_dims(np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2), 0),
1e-4, # absolute tolerance
]

TEST_CASE_2D_SINE_AXIS_1 = [
{"axis": 2}, # set axis argument to 1
# Create 10 identical windowed sine waves as a 2D numpy array
np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0),
# Expected output: absolute value of each sample of the waveform, repeated (i.e. flat envelopes)
np.expand_dims(np.abs(np.repeat(hann_windowed_sine, 10).reshape((n_samples, 10))), 0),
1e-4, # absolute tolerance
]

TEST_CASE_1D_SINE_PADDING_N = [
{"n": 512}, # args (empty, so use default)
np.expand_dims(hann_windowed_sine, 0), # Input data: Hann windowed sine wave
np.expand_dims(np.concatenate([np.hanning(500), np.zeros(12)]), 0), # Expected output: the Hann window
1e-3, # absolute tolerance
]

# MULTI-CHANNEL VALUE TEST

TEST_CASE_2_CHAN_3D_SINE = [
{}, # args (empty, so use default (i.e. process along first spatial dimension, axis=1)
# Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array, twice (2 channels)
np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0),
# Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array, twice (2 channels)
np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0),
1e-4, # absolute tolerance
]

# EXCEPTION TESTS

TEST_CASE_INVALID_AXIS_1 = [
{"axis": 3}, # set axis argument to 3 when only 3 dimensions (1 channel + 2 spatial)
np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset
"__call__", # method expected to raise exception
]

TEST_CASE_INVALID_AXIS_2 = [
{"axis": -1}, # set axis argument negative
np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset
"__init__", # method expected to raise exception
]

TEST_CASE_INVALID_N = [
{"n": 0}, # set FFT length to zero
np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0), # Create 2D dataset
"__call__", # method expected to raise exception
]

TEST_CASE_INVALID_DTYPE = [
{},
np.expand_dims(np.array(hann_windowed_sine, dtype=np.complex), 0), # complex numbers are invalid
"__call__", # method expected to raise exception
]

TEST_CASE_INVALID_IMG_LEN = [
{},
np.expand_dims(np.array([]), 0), # empty array is invalid
"__call__", # method expected to raise exception
]

TEST_CASE_INVALID_OBJ = [{}, "a string", "__call__"] # method expected to raise exception


@SkipIfBeforePyTorchVersion((1, 7))
@SkipIfNoModule("torch.fft")
class TestDetectEnvelope(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_1D_SINE,
TEST_CASE_2D_SINE,
TEST_CASE_3D_SINE,
TEST_CASE_2D_SINE_AXIS_1,
TEST_CASE_1D_SINE_PADDING_N,
TEST_CASE_2_CHAN_3D_SINE,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = DetectEnvelope(**arguments)(image)
np.testing.assert_allclose(result, expected_data, atol=atol)

@parameterized.expand(
[
TEST_CASE_INVALID_AXIS_1,
TEST_CASE_INVALID_AXIS_2,
TEST_CASE_INVALID_N,
TEST_CASE_INVALID_DTYPE,
TEST_CASE_INVALID_IMG_LEN,
]
)
def test_value_error(self, arguments, image, method):
if method == "__init__":
self.assertRaises(ValueError, DetectEnvelope, **arguments)
elif method == "__call__":
self.assertRaises(ValueError, DetectEnvelope(**arguments), image)
else:
raise ValueError("Expected raising method invalid. Should be __init__ or __call__.")


@SkipIfBeforePyTorchVersion((1, 7))
@SkipIfModule("torch.fft")
class TestHilbertTransformNoFFTMod(unittest.TestCase):
def test_no_fft_module_error(self):
self.assertRaises(OptionalImportError, DetectEnvelope(), np.random.rand(1, 10))


@SkipIfAtLeastPyTorchVersion((1, 7))
class TestDetectEnvelopeInvalidPyTorch(unittest.TestCase):
def test_invalid_pytorch_error(self):
with self.assertRaises(InvalidPyTorchVersionError) as cm:
DetectEnvelope()
self.assertEqual("DetectEnvelope requires PyTorch version 1.7.0 or later", str(cm.exception))


if __name__ == "__main__":
unittest.main()
Loading