diff --git a/monai/__init__.py b/monai/__init__.py index 3bb89cc348..2c7c920162 100644 --- a/monai/__init__.py +++ b/monai/__init__.py @@ -18,8 +18,8 @@ PY_REQUIRED_MINOR = 6 version_dict = get_versions() -__version__ = version_dict.get("version", "0+unknown") -__revision_id__ = version_dict.get("full-revisionid") +__version__: str = version_dict.get("version", "0+unknown") +__revision_id__: str = version_dict.get("full-revisionid") del get_versions, version_dict __copyright__ = "(c) 2020 - 2021 MONAI Consortium" diff --git a/monai/utils/deprecated.py b/monai/utils/deprecated.py index d7fc3bba2b..9d3fcd6435 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecated.py @@ -10,6 +10,7 @@ # limitations under the License. import inspect +import re import warnings from functools import wraps from threading import Lock @@ -44,6 +45,10 @@ def version_leq(lhs, rhs): def _try_cast(val): val = val.strip() try: + m = re.match("(\\d+)(.*)", val) + if m is not None: + val = m.groups()[0] + return int(val) except ValueError: return val @@ -64,7 +69,7 @@ def _try_cast(val): def deprecated( - since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val=__version__ + since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val: str = __version__ ): """ Marks a function or class as deprecated. If `since` is given this should be a version at or earlier than the @@ -86,6 +91,10 @@ def deprecated( is_deprecated = since is not None and version_leq(since, version_val) is_removed = removed is not None and version_leq(removed, version_val) + is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) + + if is_not_yet_deprecated: + return lambda obj: obj def _decorator(obj): is_func = isinstance(obj, FunctionType) @@ -123,7 +132,11 @@ def _wrapper(*args, **kwargs): def deprecated_arg( - name, since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val=__version__ + name, + since: Optional[str] = None, + removed: Optional[str] = None, + msg_suffix: str = "", + version_val: str = __version__, ): """ Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as @@ -142,6 +155,10 @@ def deprecated_arg( is_deprecated = since is not None and version_leq(since, version_val) is_removed = removed is not None and version_leq(removed, version_val) + is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) + + if is_not_yet_deprecated: + return lambda obj: obj def _decorator(func): argname = f"{func.__name__}_{name}" diff --git a/monai/utils/module.py b/monai/utils/module.py index b51b2820a8..05ba297014 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -19,6 +19,7 @@ import torch +from .deprecated import version_leq from .misc import ensure_tuple OPTIONAL_IMPORT_MSG_FMT = "{}" @@ -270,12 +271,7 @@ def get_torch_version_tuple(): return tuple((int(x) for x in torch.__version__.split(".")[:2])) -PT_BEFORE_1_7 = True -ver, has_ver = optional_import("pkg_resources", name="parse_version") try: - if has_ver: - PT_BEFORE_1_7 = ver(torch.__version__) < ver("1.7") - else: - PT_BEFORE_1_7 = get_torch_version_tuple() < (1, 7) + PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") except (AttributeError, TypeError): - pass + PT_BEFORE_1_7 = True diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index e3ff8719eb..7b5461bccf 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -11,6 +11,7 @@ import unittest +import warnings from monai.utils import DeprecatedError, deprecated, deprecated_arg @@ -165,3 +166,16 @@ def afoo5(a, b=None, c=None): self.assertWarns(DeprecationWarning, lambda: afoo5(1, 2)) self.assertWarns(DeprecationWarning, lambda: afoo5(1, 2, 3)) + + def test_future(self): + """Test deprecated decorator with `since` set to a future version.""" + + @deprecated(since=self.next_version, version_val=self.test_version) + def future1(): + pass + + with self.assertWarns(DeprecationWarning) as aw: + future1() + warnings.warn("fake warning", DeprecationWarning) + + self.assertEqual(aw.warning.args[0], "fake warning") diff --git a/tests/utils.py b/tests/utils.py index 7f17a64b54..ca407924ff 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,10 +32,9 @@ from monai.config.deviceconfig import USE_COMPILED from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism -from monai.utils.module import get_torch_version_tuple +from monai.utils.deprecated import version_leq nib, _ = optional_import("nibabel") -ver, has_pkg_res = optional_import("pkg_resources", name="parse_version") quick_test_var = "QUICKTEST" @@ -113,10 +112,8 @@ class SkipIfBeforePyTorchVersion: def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple - if has_pkg_res: - self.version_too_old = ver(torch.__version__) < ver(".".join(map(str, self.min_version))) - else: - self.version_too_old = get_torch_version_tuple() < self.min_version + test_ver = ".".join(map(str, self.min_version)) + self.version_too_old = torch.__version__ != test_ver and version_leq(torch.__version__, test_ver) def __call__(self, obj): return unittest.skipIf( @@ -130,10 +127,8 @@ class SkipIfAtLeastPyTorchVersion: def __init__(self, pytorch_version_tuple): self.max_version = pytorch_version_tuple - if has_pkg_res: - self.version_too_new = ver(torch.__version__) >= ver(".".join(map(str, self.max_version))) - else: - self.version_too_new = get_torch_version_tuple() >= self.max_version + test_ver = ".".join(map(str, self.max_version)) + self.version_too_new = version_leq(test_ver, torch.__version__) def __call__(self, obj): return unittest.skipIf(