Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
which python
python -m pip install --upgrade pip wheel
python -m pip uninstall -y torch torchvision
python -m pip install torch==1.7.0 torchvision==0.8.1
python -m pip install torch==1.7.1 torchvision==0.8.2
python -m pip install -r requirements-dev.txt
- name: Run integration tests
run: |
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
python -m pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
# min. requirements for windows instances
python -c "f=open('requirements-dev.txt', 'r'); txt=f.readlines(); f.close(); print(txt); f=open('requirements-dev.txt', 'w'); f.writelines(txt[1:12]); f.close()"
- name: Install the dependencies
run: |
python -m pip install torch==1.7.0
python -m pip install torchvision==0.8.1
python -m pip install torch==1.7.1
python -m pip install torchvision==0.8.2
cat "requirements-dev.txt"
python -m pip install -r requirements-dev.txt
python -m pip list
Expand All @@ -108,7 +108,7 @@ jobs:
fail-fast: false
matrix:
os: [windows-latest, macOS-latest, ubuntu-latest]
timeout-minutes: 20
timeout-minutes: 40
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
Expand All @@ -134,11 +134,11 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
python -m pip install torch==1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install torch==1.7.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install the dependencies
run: |
# min. requirements
python -m pip install torch==1.7.0
python -m pip install torch==1.7.1
python -m pip install -r requirements-min.txt
python -m pip list
BUILD_MONAI=0 python setup.py develop # no compile of extensions
Expand Down Expand Up @@ -173,7 +173,7 @@ jobs:
pytorch: "-h"
base: "nvcr.io/nvidia/pytorch:20.07-py3"
- environment: PT17+CUDA102
pytorch: "torch==1.7.0 torchvision==0.8.1"
pytorch: "torch==1.7.1 torchvision==0.8.2"
base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
- environment: PT17+CUDA110
# we explicitly set pytorch to -h to avoid pip install error
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/setupapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
which python
python -m pip install --upgrade pip wheel
python -m pip uninstall -y torch torchvision
python -m pip install torch==1.7.0 torchvision==0.8.1
python -m pip install torch==1.7.1 torchvision==0.8.2
python -m pip install -r requirements-dev.txt
- name: Run unit tests report coverage
run: |
Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:
- name: Install the dependencies
run: |
python -m pip install --upgrade pip wheel
python -m pip install torch==1.7.0 torchvision==0.8.1
python -m pip install torch==1.7.1 torchvision==0.8.2
python -m pip install -r requirements-dev.txt
- name: Run quick tests CPU ubuntu
run: |
Expand Down Expand Up @@ -151,7 +151,7 @@ jobs:
run: |
docker build -t localhost:5000/local_monai:latest -f Dockerfile .
docker push localhost:5000/local_monai:latest
sed -i '/flake/d' requirements-dev.txt
sed -i '/flake/d' requirements-dev.txt
docker build -t projectmonai/monai:latest -f Dockerfile .
docker login -u projectmonai -p ${{ secrets.DOCKER_PW }}
docker push projectmonai/monai:latest
Expand Down
5 changes: 5 additions & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ Misc
----
.. automodule:: monai.utils.misc
:members:

Profiling
---------
.. automodule:: monai.utils.profiling
:members:
8 changes: 0 additions & 8 deletions monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,6 @@ def set_visible_devices(*dev_inds):
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, dev_inds))


def get_torch_version_tuple():
"""
Returns:
tuple of ints represents the pytorch major/minor version.
"""
return tuple((int(x) for x in torch.__version__.split(".")[:2]))


def _dict_append(in_dict, key, fn):
try:
in_dict[key] = fn()
Expand Down
7 changes: 3 additions & 4 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
from torch import nn
from torch.autograd import Function

from monai.config import get_torch_version_tuple
from monai.networks.layers.convutils import gaussian_1d, same_padding
from monai.utils import InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, optional_import
from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, optional_import

_C, _ = optional_import("monai._C")
if tuple(int(s) for s in torch.__version__.split(".")[0:2]) >= (1, 7):
if not PT_BEFORE_1_7:
fft, _ = optional_import("torch.fft")

__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering", "HilbertTransform"]
Expand Down Expand Up @@ -145,7 +144,7 @@ class HilbertTransform(nn.Module):

def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None:

if get_torch_version_tuple() < (1, 7):
if PT_BEFORE_1_7:
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)

super().__init__()
Expand Down
5 changes: 2 additions & 3 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
import numpy as np
import torch

from monai.config import get_torch_version_tuple
from monai.networks.layers import GaussianFilter, HilbertTransform
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import rescale_array
from monai.utils import InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size
from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size


class RandGaussianNoise(Randomizable, Transform):
Expand Down Expand Up @@ -524,7 +523,7 @@ class DetectEnvelope(Transform):

def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None:

if get_torch_version_tuple() < (1, 7):
if PT_BEFORE_1_7:
raise InvalidPyTorchVersionError("1.7.0", self.__class__.__name__)

if axis < 0:
Expand Down
21 changes: 0 additions & 21 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import collections.abc
import itertools
import random
import time
from ast import literal_eval
from distutils.util import strtobool
from typing import Any, Callable, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -287,23 +286,3 @@ def dtype_torch_to_numpy(dtype):
def dtype_numpy_to_torch(dtype):
"""Convert a numpy dtype to its torch equivalent."""
return _np_to_torch_dtype[dtype]


class PerfContext:
"""
Context manager for tracking how much time is spent within context blocks. This uses `time.perf_counter` to
accumulate the total amount of time in seconds in the attribute `total_time` over however many context blocks
the object is used in.
"""

def __init__(self):
self.total_time = 0
self.start_time = None

def __enter__(self):
self.start_time = time.perf_counter()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.total_time += time.perf_counter() - self.start_time
self.start_time = None
23 changes: 23 additions & 0 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from re import match
from typing import Any, Callable, List, Sequence, Tuple, Union

import torch

from .misc import ensure_tuple

OPTIONAL_IMPORT_MSG_FMT = "{}"
Expand All @@ -31,6 +33,8 @@
"get_full_type_name",
"has_option",
"get_package_version",
"get_torch_version_tuple",
"PT_BEFORE_1_7",
]


Expand Down Expand Up @@ -264,3 +268,22 @@ def get_package_version(dep_name, default="NOT INSTALLED or UNKNOWN VERSION."):
del dep
del sys.modules[dep_name]
return dep_ver


def get_torch_version_tuple():
"""
Returns:
tuple of ints represents the pytorch major/minor version.
"""
return tuple((int(x) for x in torch.__version__.split(".")[:2]))


PT_BEFORE_1_7 = True
ver, has_ver = optional_import("pkg_resources", name="parse_version")
try:
if has_ver:
PT_BEFORE_1_7 = ver(torch.__version__) < ver("1.7")
else:
PT_BEFORE_1_7 = get_torch_version_tuple() < (1, 7)
except (AttributeError, TypeError):
pass
22 changes: 21 additions & 1 deletion monai/utils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import torch

__all__ = ["torch_profiler_full", "torch_profiler_time_cpu_gpu", "torch_profiler_time_end_to_end"]
__all__ = ["torch_profiler_full", "torch_profiler_time_cpu_gpu", "torch_profiler_time_end_to_end", "PerfContext"]


def torch_profiler_full(func):
Expand Down Expand Up @@ -88,3 +88,23 @@ def wrapper(*args, **kwargs):
return result

return wrapper


class PerfContext:
"""
Context manager for tracking how much time is spent within context blocks. This uses `time.perf_counter` to
accumulate the total amount of time in seconds in the attribute `total_time` over however many context blocks
the object is used in.
"""

def __init__(self):
self.total_time = 0
self.start_time = None

def __enter__(self):
self.start_time = time.perf_counter()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.total_time += time.perf_counter() - self.start_time
self.start_time = None
2 changes: 2 additions & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def run_testsuit():
"test_cachedataset",
"test_cachedataset_parallel",
"test_dataset",
"test_detect_envelope",
"test_iterable_dataset",
"test_ensemble_evaluator",
"test_handler_checkpoint_loader",
Expand All @@ -51,6 +52,7 @@ def run_testsuit():
"test_handler_validation",
"test_hausdorff_distance",
"test_header_correct",
"test_hilbert_transform",
"test_img2tensorboard",
"test_integration_segmentation_3d",
"test_integration_sliding_window",
Expand Down
3 changes: 1 addition & 2 deletions tests/test_detect_envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,8 @@ def test_no_fft_module_error(self):
@SkipIfAtLeastPyTorchVersion((1, 7))
class TestDetectEnvelopeInvalidPyTorch(unittest.TestCase):
def test_invalid_pytorch_error(self):
with self.assertRaises(InvalidPyTorchVersionError) as cm:
with self.assertRaisesRegexp(InvalidPyTorchVersionError, "version"):
DetectEnvelope()
self.assertEqual("DetectEnvelope requires PyTorch version 1.7.0 or later", str(cm.exception))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_highresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_shape(self, input_param, input_shape, expected_shape):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@TimedCall(seconds=100, force_quit=True)
@TimedCall(seconds=200, force_quit=True)
def test_script(self):
input_param, input_shape, expected_shape = TEST_CASE_1
net = HighResNet(**input_param)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_hilbert_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,8 @@ def test_no_fft_module_error(self):
@SkipIfAtLeastPyTorchVersion((1, 7))
class TestHilbertTransformInvalidPyTorch(unittest.TestCase):
def test_invalid_pytorch_error(self):
with self.assertRaises(InvalidPyTorchVersionError) as cm:
with self.assertRaisesRegex(InvalidPyTorchVersionError, "version"):
HilbertTransform()
self.assertEqual("HilbertTransform requires PyTorch version 1.7.0 or later", str(cm.exception))


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def train_and_infer(self, idx=0):

def test_training(self):
repeated = []
test_rounds = 3 if monai.config.get_torch_version_tuple() >= (1, 6) else 2
test_rounds = 3 if monai.utils.module.get_torch_version_tuple() >= (1, 6) else 2
for i in range(test_rounds):
results = self.train_and_infer(idx=i)
repeated.append(results)
Expand All @@ -308,7 +308,7 @@ def test_training(self):
daemon=False,
)
def test_timing(self):
if monai.config.get_torch_version_tuple() >= (1, 6):
if monai.utils.module.get_torch_version_tuple() >= (1, 6):
self.train_and_infer(idx=2)


Expand Down
13 changes: 10 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
import torch
import torch.distributed as dist

from monai.config import get_torch_version_tuple
from monai.data import create_test_image_2d, create_test_image_3d
from monai.utils import ensure_tuple, optional_import, set_determinism
from monai.utils.module import get_torch_version_tuple

nib, _ = optional_import("nibabel")
ver, has_pkg_res = optional_import("pkg_resources", name="parse_version")

quick_test_var = "QUICKTEST"

Expand Down Expand Up @@ -99,7 +100,10 @@ class SkipIfBeforePyTorchVersion(object):

def __init__(self, pytorch_version_tuple):
self.min_version = pytorch_version_tuple
self.version_too_old = get_torch_version_tuple() < self.min_version
if has_pkg_res:
self.version_too_old = ver(torch.__version__) < ver(".".join(map(str, self.min_version)))
else:
self.version_too_old = get_torch_version_tuple() < self.min_version

def __call__(self, obj):
return unittest.skipIf(
Expand All @@ -113,7 +117,10 @@ class SkipIfAtLeastPyTorchVersion(object):

def __init__(self, pytorch_version_tuple):
self.max_version = pytorch_version_tuple
self.version_too_new = get_torch_version_tuple() >= self.max_version
if has_pkg_res:
self.version_too_new = ver(torch.__version__) >= ver(".".join(map(str, self.max_version)))
else:
self.version_too_new = get_torch_version_tuple() >= self.max_version

def __call__(self, obj):
return unittest.skipIf(
Expand Down