Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
PY_REQUIRED_MINOR = 6

version_dict = get_versions()
__version__ = version_dict.get("version", "0+unknown")
__revision_id__ = version_dict.get("full-revisionid")
__version__: str = version_dict.get("version", "0+unknown")
__revision_id__: str = version_dict.get("full-revisionid")
del get_versions, version_dict

__copyright__ = "(c) 2020 - 2021 MONAI Consortium"
Expand Down
21 changes: 19 additions & 2 deletions monai/utils/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import inspect
import re
import warnings
from functools import wraps
from threading import Lock
Expand Down Expand Up @@ -44,6 +45,10 @@ def version_leq(lhs, rhs):
def _try_cast(val):
val = val.strip()
try:
m = re.match("(\\d+)(.*)", val)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should prefer optional_import("pkg_resources", name="parse_version") if it is installed... this version_leq will be used in many cases, better to make it as robust as possible

if m is not None:
val = m.groups()[0]

return int(val)
except ValueError:
return val
Expand All @@ -64,7 +69,7 @@ def _try_cast(val):


def deprecated(
since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val=__version__
since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val: str = __version__
):
"""
Marks a function or class as deprecated. If `since` is given this should be a version at or earlier than the
Expand All @@ -86,6 +91,10 @@ def deprecated(

is_deprecated = since is not None and version_leq(since, version_val)
is_removed = removed is not None and version_leq(removed, version_val)
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)

if is_not_yet_deprecated:
return lambda obj: obj

def _decorator(obj):
is_func = isinstance(obj, FunctionType)
Expand Down Expand Up @@ -123,7 +132,11 @@ def _wrapper(*args, **kwargs):


def deprecated_arg(
name, since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val=__version__
name,
since: Optional[str] = None,
removed: Optional[str] = None,
msg_suffix: str = "",
version_val: str = __version__,
):
"""
Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as
Expand All @@ -142,6 +155,10 @@ def deprecated_arg(

is_deprecated = since is not None and version_leq(since, version_val)
is_removed = removed is not None and version_leq(removed, version_val)
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)

if is_not_yet_deprecated:
return lambda obj: obj

def _decorator(func):
argname = f"{func.__name__}_{name}"
Expand Down
10 changes: 3 additions & 7 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import torch

from .deprecated import version_leq
from .misc import ensure_tuple

OPTIONAL_IMPORT_MSG_FMT = "{}"
Expand Down Expand Up @@ -270,12 +271,7 @@ def get_torch_version_tuple():
return tuple((int(x) for x in torch.__version__.split(".")[:2]))


PT_BEFORE_1_7 = True
ver, has_ver = optional_import("pkg_resources", name="parse_version")
try:
if has_ver:
PT_BEFORE_1_7 = ver(torch.__version__) < ver("1.7")
else:
PT_BEFORE_1_7 = get_torch_version_tuple() < (1, 7)
PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0")
except (AttributeError, TypeError):
pass
PT_BEFORE_1_7 = True
14 changes: 14 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


import unittest
import warnings

from monai.utils import DeprecatedError, deprecated, deprecated_arg

Expand Down Expand Up @@ -165,3 +166,16 @@ def afoo5(a, b=None, c=None):

self.assertWarns(DeprecationWarning, lambda: afoo5(1, 2))
self.assertWarns(DeprecationWarning, lambda: afoo5(1, 2, 3))

def test_future(self):
"""Test deprecated decorator with `since` set to a future version."""

@deprecated(since=self.next_version, version_val=self.test_version)
def future1():
pass

with self.assertWarns(DeprecationWarning) as aw:
future1()
warnings.warn("fake warning", DeprecationWarning)

self.assertEqual(aw.warning.args[0], "fake warning")
15 changes: 5 additions & 10 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@
from monai.config.deviceconfig import USE_COMPILED
from monai.data import create_test_image_2d, create_test_image_3d
from monai.utils import ensure_tuple, optional_import, set_determinism
from monai.utils.module import get_torch_version_tuple
from monai.utils.deprecated import version_leq

nib, _ = optional_import("nibabel")
ver, has_pkg_res = optional_import("pkg_resources", name="parse_version")

quick_test_var = "QUICKTEST"

Expand Down Expand Up @@ -113,10 +112,8 @@ class SkipIfBeforePyTorchVersion:

def __init__(self, pytorch_version_tuple):
self.min_version = pytorch_version_tuple
if has_pkg_res:
self.version_too_old = ver(torch.__version__) < ver(".".join(map(str, self.min_version)))
else:
self.version_too_old = get_torch_version_tuple() < self.min_version
test_ver = ".".join(map(str, self.min_version))
self.version_too_old = torch.__version__ != test_ver and version_leq(torch.__version__, test_ver)

def __call__(self, obj):
return unittest.skipIf(
Expand All @@ -130,10 +127,8 @@ class SkipIfAtLeastPyTorchVersion:

def __init__(self, pytorch_version_tuple):
self.max_version = pytorch_version_tuple
if has_pkg_res:
self.version_too_new = ver(torch.__version__) >= ver(".".join(map(str, self.max_version)))
else:
self.version_too_new = get_torch_version_tuple() >= self.max_version
test_ver = ".".join(map(str, self.max_version))
self.version_too_new = version_leq(test_ver, torch.__version__)

def __call__(self, obj):
return unittest.skipIf(
Expand Down