diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 7026a2ac4d..f854e6fa92 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -72,6 +72,7 @@ load_submodules, min_version, optional_import, + version_leq, ) from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end from .state_cacher import StateCacher diff --git a/monai/utils/deprecated.py b/monai/utils/deprecated.py index 9d3fcd6435..1caab6b4ab 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecated.py @@ -10,20 +10,17 @@ # limitations under the License. import inspect -import re import warnings from functools import wraps -from threading import Lock from types import FunctionType from typing import Optional +from monai.utils.module import version_leq + from .. import __version__ __all__ = ["deprecated", "deprecated_arg", "DeprecatedError"] -warned_set = set() -warned_lock = Lock() - class DeprecatedError(Exception): pass @@ -31,41 +28,9 @@ class DeprecatedError(Exception): def warn_deprecated(obj, msg): """ - Issue the warning message `msg` only once per process for the given object `obj`. When this function is called - and `obj` is not in `warned_set`, it is added and the warning issued, if it's already present nothing happens. + Issue the warning message `msg`. """ - if obj not in warned_set: # ensure warning is issued only once per process - warned_set.add(obj) - warnings.warn(msg, category=DeprecationWarning, stacklevel=2) - - -def version_leq(lhs, rhs): - """Returns True if version `lhs` is earlier or equal to `rhs`.""" - - def _try_cast(val): - val = val.strip() - try: - m = re.match("(\\d+)(.*)", val) - if m is not None: - val = m.groups()[0] - - return int(val) - except ValueError: - return val - - # remove git version suffixes if present - lhs = lhs.split("+", 1)[0] - rhs = rhs.split("+", 1)[0] - - # parse the version strings in this basic way to avoid needing the `packaging` package - lhs = map(_try_cast, lhs.split(".")) - rhs = map(_try_cast, rhs.split(".")) - - for l, r in zip(lhs, rhs): - if l != r: - return l < r - - return True + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) def deprecated( @@ -89,13 +54,22 @@ def deprecated( Decorated definition which warns or raises exception when used """ - is_deprecated = since is not None and version_leq(since, version_val) - is_removed = removed is not None and version_leq(removed, version_val) + if since is not None and removed is not None and not version_leq(since, removed): + raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) - if is_not_yet_deprecated: + # smaller than `since`, do nothing return lambda obj: obj + if since is None and removed is None: + # raise a DeprecatedError directly + is_removed = True + is_deprecated = True + else: + # compare the numbers + is_deprecated = since is not None and version_leq(since, version_val) + is_removed = removed is not None and version_leq(removed, version_val) + def _decorator(obj): is_func = isinstance(obj, FunctionType) call_obj = obj if is_func else obj.__init__ @@ -115,10 +89,10 @@ def _decorator(obj): @wraps(call_obj) def _wrapper(*args, **kwargs): + if is_removed: + raise DeprecatedError(msg) if is_deprecated: warn_deprecated(obj, msg) - else: - raise DeprecatedError(msg) return call_obj(*args, **kwargs) @@ -152,10 +126,21 @@ def deprecated_arg( Returns: Decorated callable which warns or raises exception when deprecated argument used """ - - is_deprecated = since is not None and version_leq(since, version_val) - is_removed = removed is not None and version_leq(removed, version_val) + if since is not None and removed is not None and not version_leq(since, removed): + raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) + if is_not_yet_deprecated: + # smaller than `since`, do nothing + return lambda obj: obj + + if since is None and removed is None: + # raise a DeprecatedError directly + is_removed = True + is_deprecated = True + else: + # compare the numbers + is_deprecated = since is not None and version_leq(since, version_val) + is_removed = removed is not None and version_leq(removed, version_val) if is_not_yet_deprecated: return lambda obj: obj @@ -186,10 +171,10 @@ def _wrapper(*args, **kwargs): kw_found = "kwargs" in binding and name in binding["kwargs"] if positional_found or kw_found: + if is_removed: + raise DeprecatedError(msg) if is_deprecated: warn_deprecated(argname, msg) - else: - raise DeprecatedError(msg) return func(*args, **kwargs) diff --git a/monai/utils/module.py b/monai/utils/module.py index 05ba297014..f6d7687bb6 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -19,7 +19,6 @@ import torch -from .deprecated import version_leq from .misc import ensure_tuple OPTIONAL_IMPORT_MSG_FMT = "{}" @@ -37,6 +36,7 @@ "get_package_version", "get_torch_version_tuple", "PT_BEFORE_1_7", + "version_leq", ] @@ -271,6 +271,41 @@ def get_torch_version_tuple(): return tuple((int(x) for x in torch.__version__.split(".")[:2])) +def version_leq(lhs, rhs): + """Returns True if version `lhs` is earlier or equal to `rhs`.""" + + ver, has_ver = optional_import("pkg_resources", name="parse_version") + if has_ver: + return ver(lhs) <= ver(rhs) + + def _try_cast(val): + val = val.strip() + try: + m = match("(\\d+)(.*)", val) + if m is not None: + val = m.groups()[0] + return int(val) + return val + except ValueError: + return val + + # remove git version suffixes if present + lhs = lhs.split("+", 1)[0] + rhs = rhs.split("+", 1)[0] + + # parse the version strings in this basic way without `packaging` package + lhs = map(_try_cast, lhs.split(".")) + rhs = map(_try_cast, rhs.split(".")) + + for l, r in zip(lhs, rhs): + if l != r: + if isinstance(l, int) and isinstance(r, int): + return l < r + return f"{l}" < f"{r}" + + return True + + try: PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") except (AttributeError, TypeError): diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 7b5461bccf..429d5ee767 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -16,6 +16,49 @@ from monai.utils import DeprecatedError, deprecated, deprecated_arg +class TestDeprecatedRC(unittest.TestCase): + def setUp(self): + self.test_version_rc = "0.6.0rc1" + self.test_version = "0.6.0" + self.next_version = "0.7.0" + + def test_warning(self): + """Test deprecated decorator with `since` and `removed` set for an RC version""" + + @deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version_rc) + def foo2(): + pass + + print(foo2()) + + def test_warning_milestone(self): + """Test deprecated decorator with `since` and `removed` set for a milestone version""" + + @deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version) + def foo2(): + pass + + self.assertWarns(DeprecationWarning, foo2) + + def test_warning_last(self): + """Test deprecated decorator with `since` and `removed` set, for the last version""" + + @deprecated(since=self.test_version, removed=self.next_version, version_val=self.next_version) + def foo3(): + pass + + self.assertRaises(DeprecatedError, foo3) + + def test_warning_beyond(self): + """Test deprecated decorator with `since` and `removed` set, beyond the last version""" + + @deprecated(since=self.test_version_rc, removed=self.test_version, version_val=self.next_version) + def foo3(): + pass + + self.assertRaises(DeprecatedError, foo3) + + class TestDeprecated(unittest.TestCase): def setUp(self): self.test_version = "0.5.3+96.g1fa03c2.dirty" diff --git a/tests/test_version_leq.py b/tests/test_version_leq.py new file mode 100644 index 0000000000..a1913069d3 --- /dev/null +++ b/tests/test_version_leq.py @@ -0,0 +1,81 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import unittest + +from parameterized import parameterized + +from monai.utils import version_leq + + +# from pkg_resources +def _pairwise(iterable): + "s -> (s0,s1), (s1,s2), (s2, s3), ..." + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +# from pkg_resources +torture = """ + 0.80.1-3 0.80.1-2 0.80.1-1 0.79.9999+0.80.0pre4-1 + 0.79.9999+0.80.0pre2-3 0.79.9999+0.80.0pre2-2 + 0.77.2-1 0.77.1-1 0.77.0-1 + """ + +TEST_CASES = ( + ("1.6.0", "1.6.0"), + ("1.6.0a0+9907a3e", "1.6.0"), + ("0+unknown", "0.6"), + ("ab", "abc"), + ("0.6rc1", "0.6"), + ("0.6", "0.7"), + ("1.2.a", "1.2a"), + ("1.2-rc1", "1.2rc1"), + ("0.4", "0.4.0"), + ("0.4.0.0", "0.4.0"), + ("0.4.0-0", "0.4-0"), + ("0post1", "0.0post1"), + ("0pre1", "0.0c1"), + ("0.0.0preview1", "0c1"), + ("0.0c1", "0-rc1"), + ("1.2a1", "1.2.a.1"), + ("1.2.a", "1.2a"), + ("2.1", "2.1.1"), + ("2a1", "2b0"), + ("2a1", "2.1"), + ("2.3a1", "2.3"), + ("2.1-1", "2.1-2"), + ("2.1-1", "2.1.1"), + ("2.1", "2.1post4"), + ("2.1a0-20040501", "2.1"), + ("1.1", "02.1"), + ("3.2", "3.2.post0"), + ("3.2post1", "3.2post2"), + ("0.4", "4.0"), + ("0.0.4", "0.4.0"), + ("0post1", "0.4post1"), + ("2.1.0-rc1", "2.1.0"), + ("2.1dev", "2.1a0"), +) + tuple(_pairwise(reversed(torture.split()))) + + +class TestVersionCompare(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_compare(self, a, b, expected=True): + """Test version_leq with `a` and `b`""" + self.assertEqual(version_leq(a, b), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index ca407924ff..5970e65d9d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,7 +32,7 @@ from monai.config.deviceconfig import USE_COMPILED from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism -from monai.utils.deprecated import version_leq +from monai.utils.module import version_leq nib, _ = optional_import("nibabel")