Conversation

@crnbaker
Contributor

Fixes #1210

Description

A new class HilbertTransform has been implemented in monai/networks/layers/simplelayers.py using the torch.fft module. Another new class, DetectEnvelope, which wraps HilbertTransform, has been implemented in monai/transforms/intensity/array.py. Unit tests for DetectEnvelope are implemented in tests/test_detect_envelope.py. Documentation has been updated and tested locally. Tests pass locally. Commits have been squashed.

Status

Ready

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh --codeformat --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Details

Running on a Google Colab VM with a GPU, HilbertTransform is faster than the SciPy equivalent for tensors with more than about 10^4 elements:
[Plot: HilbertTransform vs. SciPy runtime against total tensor size]
Notebook demonstrating this is here.
I also implemented a Hilbert transform using torch.nn.functional.conv1d but its performance was poor compared to the FFT implementation.
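For readers unfamiliar with the technique, the FFT-based approach can be sketched as follows. This is a minimal illustration of the classic analytic-signal recipe, not the merged MONAI implementation, and it assumes PyTorch 1.7+ so that torch.fft is available:

```python
import torch


def hilbert_1d(x: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Return the analytic signal of `x` along `axis` via the FFT."""
    n = x.shape[axis]
    f = torch.fft.fft(x, n=n, dim=axis)
    # Classic Hilbert recipe: zero the negative frequencies and
    # double the positive ones (DC and Nyquist bins kept as-is).
    h = torch.zeros(n, dtype=f.dtype, device=f.device)
    if n % 2 == 0:
        h[0] = h[n // 2] = 1
        h[1 : n // 2] = 2
    else:
        h[0] = 1
        h[1 : (n + 1) // 2] = 2
    shape = [1] * x.ndim
    shape[axis] = n  # broadcast the filter along the transformed axis only
    return torch.fft.ifft(f * h.reshape(shape), dim=axis)


# The envelope is then the magnitude of the analytic signal:
# envelope = hilbert_1d(signal).abs()
```

For a pure sinusoid the recovered envelope is flat (equal to the amplitude), apart from small edge effects when the record is not exactly periodic.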

@crnbaker crnbaker force-pushed the 1210-hilbert-transform branch 2 times, most recently from 1a08092 to 770b58c on November 24, 2020, 20:00
@wyli
Contributor

wyli commented Nov 24, 2020

looks like torch.fft has some breaking changes in torch1.5 - 1.7, perhaps you can use this pattern as a workaround:

if get_torch_version_tuple() >= (1, 5):
    # additional argument since torch 1.5 (to avoid warnings)
    def _torch_interp(**kwargs):
        return torch.nn.functional.interpolate(recompute_scale_factor=True, **kwargs)
else:
    _torch_interp = torch.nn.functional.interpolate

@crnbaker
Contributor Author

looks like torch.fft has some breaking changes in torch1.5 - 1.7, perhaps you can use this pattern as a workaround:

if get_torch_version_tuple() >= (1, 5):
    # additional argument since torch 1.5 (to avoid warnings)
    def _torch_interp(**kwargs):
        return torch.nn.functional.interpolate(recompute_scale_factor=True, **kwargs)
else:
    _torch_interp = torch.nn.functional.interpolate

Thanks @wyli, I'll try that. There are a couple of mypy issues to fix too. Once all the tests (apart from DCO) have passed I'll rebase to a single signed commit.

@tvercaut
Member

@crnbaker Nice graphs! From a quick look at it, the 10^4 performance crossing point refers to the number of pixels across the full image / tensor, rather than the size along the transformed dimension, right?

Looking at the performance crossing point in terms of number of points within an A-line, does the number of lines one has in the image/tensor influence it much?

If the crossing point is always around 200 samples / A-line, I guess it means the GPU implementation should be the best performing one in most real-life applications. Correct?

@crnbaker
Contributor Author

Yes, the x-axis is the number of pixels in the whole tensor. That dataset contained images consisting of 50 A-lines, each image with a different number of samples per line.

It doesn't appear that the shape of the image affects the results. In the plot below, the "2D" data is the same as the data above, and the "1D" data consists of tensors containing only a single A-line (the number of samples in each line being higher so that the total size of the tensors is similar to those in the 2D dataset):
[Plot: runtimes for the 1D and 2D datasets, showing similar crossing points]

I was expecting that the PyTorch FFT module would be doing parallelisation under the hood, so that multiple FFTs would be calculated simultaneously. That doesn't appear to be the case, but maybe this can be achieved using batching. Or PyTorch is parallelising the FFT algorithm itself, which I think is possible but unlikely.

Here's a graph of the (1D) SciPy times divided by the (1D) MONAI GPU times:
[Plot: ratio of SciPy (1D) to MONAI GPU (1D) runtimes against tensor size]

Our FOH images have approximately 380,000 elements, so we should expect the MONAI Hilbert transform to run approximately 40 times faster than the SciPy one (at least for the Google Colab CPU/GPU).

@crnbaker
Contributor Author

crnbaker commented Nov 25, 2020

@wyli I've used the torch.fft.fft module which is new in PyTorch 1.7. Pre v1.7 there was the torch.fft function, which will be removed in v1.8. I can relatively easily support <1.7 by importing the torch.fft function and using that instead, but this would require me to import get_torch_version_tuple before the torch modules, which would fail Flake8! Any ideas?

EDIT: I think I can just use torch.__version__ instead of get_torch_version_tuple; I'll try that.

@wyli
Contributor

wyli commented Nov 25, 2020

I think you can declare the type like this one

_torch_interp: Callable[..., torch.Tensor]

you could also put # pytype: disable= or # type: ignore or # noqa to mute the errors for now

* In Docker, torch.__version__ will return some verbose result like 1.7.0a0+7036e91
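This means any check built on torch.__version__ has to tolerate pre-release and build suffixes. A small sketch of such a tolerant parse (the helper name parse_torch_version is hypothetical, not a MONAI function):

```python
import re


def parse_torch_version(version: str) -> tuple:
    """Extract (major, minor) from strings like '1.7.0' or '1.7.0a0+7036e91'."""
    # Match only the leading "major.minor" digits; ignore any suffix.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(g) for g in match.groups())
```

Note that parse_torch_version("1.7.0a0+7036e91") gives (1, 7), the same tuple as the real 1.7.0 release, so a plain tuple comparison cannot distinguish the pre-release Docker build from the final release.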

@crnbaker crnbaker force-pushed the 1210-hilbert-transform branch 7 times, most recently from 0ab2a30 to 14bdaea on November 25, 2020, 22:56
@wyli
Contributor

wyli commented Nov 26, 2020

the current PR works fine with torch 1.6 + CUDA 10.2; for torch 1.7 + CUDA 11.1 it works only if I change the version conditions in the utils to

if tuple(int(s) for s in torch.__version__.split(".")[0:2]) <= (1, 7):  # line 22
if get_torch_version_tuple() > (1, 7):   # line 248

@crnbaker
Contributor Author

crnbaker commented Nov 26, 2020

Is it possible for me to get the Dockerfiles so that I can set up the containers that are failing and debug on my system?

@wyli
Contributor

wyli commented Nov 26, 2020

that's not yet implemented; for now perhaps you could refer to the configurations, which are essentially bash commands

btw the fft API of pytorch is a mess, and apparently the documentation is not quite correct ...

@crnbaker
Contributor Author

Thanks. I'm starting to realise that! Also, complex datatypes have only recently been introduced.

@crnbaker
Contributor Author

crnbaker commented Dec 9, 2020

I had another look at this at the end of last week. To debug, I need a system with PyTorch 1.6 and CUDA 11.0. There are no Windows binaries available for this combination via pip or conda. There is a Docker image available, but Docker on Windows doesn't support CUDA. So I need a Linux system. I'd use WSL, but that doesn't work with CUDA without installing a preview build of Windows... so I think I need to set up dual boot, which I'd like to do anyway.

Or as @tvercaut suggested, could we make this feature only available on PyTorch 1.7 somehow?

@wyli
Contributor

wyli commented Dec 9, 2020

I think we should go for PyTorch 1.7+; in the unit tests we need a new decorator similar to this one, to skip the tests on torch < 1.7:

def skip_if_no_cuda(obj):
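A version-gated decorator in the same spirit might look like this. skip_if_before_version is a hypothetical sketch (the decorators eventually added to tests/utils.py may differ), written to take the version string explicitly so it is easy to test:

```python
import unittest


def skip_if_before_version(installed: str, required: tuple):
    """Return a unittest skip decorator when `installed` is older than `required`.

    `installed` is a dotted version string, e.g. torch.__version__.
    """
    # Drop any "+build" suffix, then keep only the major.minor components.
    parsed = tuple(int(s) for s in installed.split("+")[0].split(".")[:2])
    return unittest.skipIf(
        parsed < required,
        f"requires version >= {'.'.join(map(str, required))}",
    )
```

A test class would then be gated as @skip_if_before_version(torch.__version__, (1, 7)).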

@crnbaker
Contributor Author

crnbaker commented Dec 9, 2020

Great, thanks @wyli, I'm nearly there with this now. Any suggestions on what exception should be raised if the function is called with PyTorch < 1.7.0? I could define a new exception, something like:

class InvalidPyTorchVersion(Exception):
    """Raised when a called function or method requires a more recent PyTorch version than that installed."""
    def __init__(self, required_version, name):
        message = f"{name} requires PyTorch version {required_version} or later"
        super().__init__(message)

class HilbertTransform(nn.Module):
    def __init__(self):
        super().__init__()  # initialise nn.Module before the version check
        if get_torch_version_tuple() < (1, 7):
            raise InvalidPyTorchVersion("1.7.0", self.__class__.__name__)

Is there anywhere in particular this should go?

@wyli
Contributor

wyli commented Dec 9, 2020

cool, the exception could be in the module util, next to this one:

class OptionalImportError(ImportError):

or perhaps reuse this example:

>>> torch, flag = optional_import('torch', '1.1')

I'm fine with either
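The optional_import route is attractive because it folds both concerns (is the module importable at all, and is the version new enough to provide it) into one flag check. A generic sketch of the pattern, not MONAI's actual optional_import (which also does version checking and deferred error reporting):

```python
import importlib
from types import ModuleType
from typing import Optional, Tuple


def optional_import(name: str) -> Tuple[Optional[ModuleType], bool]:
    """Try to import `name`; return (module or None, success flag)."""
    try:
        return importlib.import_module(name), True
    except ImportError:
        # Covers both a missing package and a missing submodule
        # (e.g. torch.fft on torch < 1.7).
        return None, False
```

With this, `fft, has_fft = optional_import("torch.fft")` yields has_fft == False on torch < 1.7, so the feature and its tests can be gated on the flag rather than on a version number.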

@crnbaker crnbaker force-pushed the 1210-hilbert-transform branch from 14bdaea to 1e7013e on December 9, 2020, 22:16
@crnbaker
Contributor Author

crnbaker commented Dec 10, 2020

It turns out the PT17+CU110 check uses an NVIDIA container image which is based on a pre-release version of PyTorch 1.7 (commit 7036e91 from 15th September) that doesn't yet have the FFT module. Apparently the FFT module was added sometime between then and the 1.7.0 release on 27th October.

I suppose others may also be using this NVIDIA image, so I'll need to add in something that not only checks the PyTorch version, but also whether the FFT module is available.

EDIT: I'll use optional_import() and @SkipIfNoModule

@crnbaker crnbaker force-pushed the 1210-hilbert-transform branch 2 times, most recently from 45eff8f to 41788f7 on December 10, 2020, 11:54
@crnbaker
Contributor Author

All sorted. Just failing the flake8 check for some reason... seems to be raising an exception related to the multiprocessing module. Have you seen that before @wyli?

@wyli
Contributor

wyli commented Dec 10, 2020

All sorted. Just failing the flake8 check for some reason... seems to be raising an exception related to the multiprocessing module. Have you seen that before @wyli?

interesting, I've never seen this before; I'm looking into it...

@wyli
Contributor

wyli commented Dec 10, 2020

not sure if this is relevant, but I got these errors on Mac; perhaps remove the shebang line in the file?

flake8
3.8.4 (flake8-bugbear: 20.1.4, flake8-comprehensions: 3.3.0,
flake8-executable: 2.0.4, flake8-pyi: 20.10.0, mccabe: 0.6.1, naming: 0.11.1,
pycodestyle: 2.6.0, pyflakes: 2.2.0) CPython 3.6.10 on Darwin
/Users/wenqili/Documents/MONAI/tests/test_hilbert_transform.py:1:1: EXE001 Shebang is present but the file is not executable.
/Users/wenqili/Documents/MONAI/tests/test_hilbert_transform.py:32:6: N806 variable 'X' in function should be lowercase
/Users/wenqili/Documents/MONAI/tests/test_hilbert_transform.py:37:6: N806 variable 'U' in function should be lowercase
/Users/wenqili/Documents/MONAI/tests/test_hilbert_transform.py:41:10: N806 variable 'U' in function should be lowercase
/Users/wenqili/Documents/MONAI/tests/test_hilbert_transform.py:43:10: N806 variable 'U' in function should be lowercase
1     EXE001 Shebang is present but the file is not executable.
4     N806 variable 'X' in function should be lowercase
5
Check failed!

@crnbaker crnbaker force-pushed the 1210-hilbert-transform branch 2 times, most recently from 8116943 to d1b1ca0 on December 10, 2020, 13:53
@wyli
Contributor

wyli commented Dec 10, 2020

/integration-test I think it's a bug from the flake8 plugin xuhdev/flake8-executable#15; need to make the .py files not executable: chmod -x

@crnbaker crnbaker force-pushed the 1210-hilbert-transform branch from e576d8d to ba5bcd9 on December 10, 2020, 14:28
…ed. Project-MONAI#1210. Requires torch.fft module and PyTorch 1.7.0+.

Signed-off-by: Christian Baker <christian.baker@kcl.ac.uk>

* New tests/utils.py unittest decorators added: SkipIfModule(),
SkipIfBeforePyTorchVersion() and SkipIfAtLeastPyTorchVersion()

Signed-off-by: Christian Baker <christian.baker@kcl.ac.uk>

* New unit tests added test_hilbert_transform.py and
test_detect_envelope.py

Signed-off-by: Christian Baker <christian.baker@kcl.ac.uk>

* New exception InvalidPyTorchVersionError implemented in monai/utils/module.py

Signed-off-by: Christian Baker <christian.baker@kcl.ac.uk>

* New class HilbertTransform implemented in
monai/networks/layers/simplelayers.py

Signed-off-by: Christian Baker <christian.baker@kcl.ac.uk>

* New class DetectEnvelope implemented in
monai/transforms/intensity/array.py, uses HilbertTransform

Signed-off-by: Christian Baker <christian.baker@kcl.ac.uk>
@crnbaker crnbaker force-pushed the 1210-hilbert-transform branch from ba5bcd9 to aa75c5b on December 10, 2020, 15:39
@crnbaker
Contributor Author

Good to go!

Contributor

@wyli wyli left a comment


thank you!

@wyli wyli merged commit 453d40d into Project-MONAI:master Dec 10, 2020
@wyli
Contributor

wyli commented Dec 10, 2020

It turns out the PT17+CU110 check uses an NVIDIA container image which is based on a pre-release version of PyTorch 1.7 (commit 7036e91 from 15th September) that doesn't yet have the FFT module. Apparently the FFT module was added sometime between then and the 1.7.0 release on 27th October.

I suppose others may also be using this NVIDIA image, so I'll need to add in something that not only checks the PyTorch version, but also whether the FFT module is available.

EDIT: I'll use optional_import() and @SkipIfNoModule

I wasn't aware of this issue earlier; MONAI's Docker image also uses PyTorch version 1.7.0a0+7036e91...

@crnbaker
Contributor Author

It turns out the PT17+CU110 check uses an NVIDIA container image which is based on a pre-release version of PyTorch 1.7 (commit 7036e91 from 15th September) that doesn't yet have the FFT module. Apparently the FFT module was added sometime between then and the 1.7.0 release on 27th October.
I suppose others may also be using this NVIDIA image, so I'll need to add in something that not only checks the PyTorch version, but also whether the FFT module is available.
EDIT: I'll use optional_import() and @SkipIfNoModule

I wasn't aware of this issue earlier; MONAI's Docker image also uses PyTorch version 1.7.0a0+7036e91...

Yes and I've just seen it's failed the Docker check, which seems bizarre considering it passed with the same version of PyTorch here... I'll look into it.

@wyli
Contributor

wyli commented Dec 10, 2020

no problem, I pushed a quick fix:

MONAI/monai/utils/module.py

Lines 281 to 287 in 5c75dcb

PT_BEFORE_1_7 = True
ver, has_ver = optional_import("pkg_resources", name="parse_version")
try:
    if has_ver:
        PT_BEFORE_1_7 = ver(torch.__version__) < ver("1.7")
    else:
        PT_BEFORE_1_7 = get_torch_version_tuple() < (1, 7)
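The reason parse_version is preferred when pkg_resources is available: the Docker build's pre-release string collapses to the same (major, minor) tuple as the real release, so a tuple comparison cannot tell them apart. A stdlib-only illustration:

```python
# The PyTorch build shipped in the NVIDIA/MONAI Docker images:
ver = "1.7.0a0+7036e91"

# Naive tuple parsing keeps only the leading digits...
parsed = tuple(int(p) for p in ver.split(".")[:2])
assert parsed == (1, 7)

# ...so `parsed < (1, 7)` is False and this pre-release (which lacks
# torch.fft) is wrongly treated as a full 1.7 release. parse_version,
# by contrast, orders "1.7.0a0" strictly before "1.7".
assert not (parsed < (1, 7))
```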


Development

Successfully merging this pull request may close these issues.

Hilbert transform for envelope detection
