From 1d22f52421db8401a3667e0a53d93d44df9d39f1 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 29 Jun 2021 20:06:19 +0100 Subject: [PATCH 01/10] enhance version compare Signed-off-by: Wenqi Li --- monai/utils/__init__.py | 1 + monai/utils/deprecated.py | 72 +++++++++++++++++---------------------- monai/utils/module.py | 36 +++++++++++++++++++- tests/test_deprecated.py | 43 +++++++++++++++++++++++ tests/utils.py | 2 +- 5 files changed, 112 insertions(+), 42 deletions(-) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 7026a2ac4d..f854e6fa92 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -72,6 +72,7 @@ load_submodules, min_version, optional_import, + version_leq, ) from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end from .state_cacher import StateCacher diff --git a/monai/utils/deprecated.py b/monai/utils/deprecated.py index 9d3fcd6435..98adec229b 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecated.py @@ -10,13 +10,14 @@ # limitations under the License. import inspect -import re import warnings from functools import wraps from threading import Lock from types import FunctionType from typing import Optional +from monai.utils.module import version_leq + from .. import __version__ __all__ = ["deprecated", "deprecated_arg", "DeprecatedError"] @@ -39,35 +40,6 @@ def warn_deprecated(obj, msg): warnings.warn(msg, category=DeprecationWarning, stacklevel=2) -def version_leq(lhs, rhs): - """Returns True if version `lhs` is earlier or equal to `rhs`.""" - - def _try_cast(val): - val = val.strip() - try: - m = re.match("(\\d+)(.*)", val) - if m is not None: - val = m.groups()[0] - - return int(val) - except ValueError: - return val - - # remove git version suffixes if present - lhs = lhs.split("+", 1)[0] - rhs = rhs.split("+", 1)[0] - - # parse the version strings in this basic way to avoid needing the `packaging` package - lhs = map(_try_cast, lhs.split(".")) - rhs = map(_try_cast, rhs.split(".")) - - for l, r in zip(lhs, rhs): - if l != r: - return l < r - - return True - - def deprecated( since: Optional[str] = None, removed: Optional[str] = None, msg_suffix: str = "", version_val: str = __version__ ): @@ -89,13 +61,22 @@ def deprecated( Decorated definition which warns or raises exception when used """ - is_deprecated = since is not None and version_leq(since, version_val) - is_removed = removed is not None and version_leq(removed, version_val) + if since is not None and removed is not None and not version_leq(since, removed): + raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) - if is_not_yet_deprecated: + # smaller than `since`, do nothing return lambda obj: obj + if since is None and removed is None: + # raise a DeprecatedError directly + is_removed = True + is_deprecated = True + else: + # compare the numbers + is_deprecated = since is not None and version_leq(since, version_val) + is_removed = removed is not None and version_leq(removed, version_val) + def _decorator(obj): is_func = isinstance(obj, FunctionType) call_obj = obj if is_func else obj.__init__ @@ -115,10 +96,10 @@ def _decorator(obj): @wraps(call_obj) def _wrapper(*args, **kwargs): + if is_removed: + raise DeprecatedError(msg) if is_deprecated: warn_deprecated(obj, msg) - else: - raise DeprecatedError(msg) return call_obj(*args, **kwargs) @@ -152,10 +133,21 @@ def deprecated_arg( Returns: Decorated callable which warns or raises exception when deprecated argument used """ - - is_deprecated = since is not None and version_leq(since, version_val) - is_removed = removed is not None and version_leq(removed, version_val) + if since is not None and removed is not None and not version_leq(since, removed): + raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.") is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since) + if is_not_yet_deprecated: + # smaller than `since`, do nothing + return lambda obj: obj + + if since is None and removed is None: + # raise a DeprecatedError directly + is_removed = True + is_deprecated = True + else: + # compare the numbers + is_deprecated = since is not None and version_leq(since, version_val) + is_removed = removed is not None and version_leq(removed, version_val) if is_not_yet_deprecated: return lambda obj: obj @@ -186,10 +178,10 @@ def _wrapper(*args, **kwargs): kw_found = "kwargs" in binding and name in binding["kwargs"] if positional_found or kw_found: + if is_removed: + raise DeprecatedError(msg) if is_deprecated: warn_deprecated(argname, msg) - else: - raise DeprecatedError(msg) return func(*args, **kwargs) diff --git a/monai/utils/module.py b/monai/utils/module.py index 05ba297014..e6ac6b0cfe 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -10,6 +10,7 @@ # limitations under the License. import inspect +import re import sys import warnings from importlib import import_module @@ -19,7 +20,6 @@ import torch -from .deprecated import version_leq from .misc import ensure_tuple OPTIONAL_IMPORT_MSG_FMT = "{}" @@ -37,6 +37,7 @@ "get_package_version", "get_torch_version_tuple", "PT_BEFORE_1_7", + "version_leq", ] @@ -271,6 +272,39 @@ def get_torch_version_tuple(): return tuple((int(x) for x in torch.__version__.split(".")[:2])) +def version_leq(lhs, rhs): + """Returns True if version `lhs` is earlier or equal to `rhs`.""" + + ver, has_ver = optional_import("pkg_resources", name="parse_version") + if has_ver: + return ver(lhs) <= ver(rhs) + + def _try_cast(val): + val = val.strip() + try: + m = re.match("(\\d+)(.*)", val) + if m is not None: + val = m.groups()[0] + + return int(val) + except ValueError: + return val + + # remove git version suffixes if present + lhs = lhs.split("+", 1)[0] + rhs = rhs.split("+", 1)[0] + + # parse the version strings in this basic way without `packaging` package + lhs = map(_try_cast, lhs.split(".")) + rhs = map(_try_cast, rhs.split(".")) + + for l, r in zip(lhs, rhs): + if l != r: + return l < r + + return True + + try: PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0") except (AttributeError, TypeError): diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 7b5461bccf..429d5ee767 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -16,6 +16,49 @@ from monai.utils import DeprecatedError, deprecated, deprecated_arg +class TestDeprecatedRC(unittest.TestCase): + def setUp(self): + self.test_version_rc = "0.6.0rc1" + self.test_version = "0.6.0" + self.next_version = "0.7.0" + + def test_warning(self): + """Test deprecated decorator with `since` and `removed` set for an RC version""" + + @deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version_rc) + def foo2(): + pass + + print(foo2()) + + def test_warning_milestone(self): + """Test deprecated decorator with `since` and `removed` set for a milestone version""" + + @deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version) + def foo2(): + pass + + self.assertWarns(DeprecationWarning, foo2) + + def test_warning_last(self): + """Test deprecated decorator with `since` and `removed` set, for the last version""" + + @deprecated(since=self.test_version, removed=self.next_version, version_val=self.next_version) + def foo3(): + pass + + self.assertRaises(DeprecatedError, foo3) + + def test_warning_beyond(self): + """Test deprecated decorator with `since` and `removed` set, beyond the last version""" + + @deprecated(since=self.test_version_rc, removed=self.test_version, version_val=self.next_version) + def foo3(): + pass + + self.assertRaises(DeprecatedError, foo3) + + class TestDeprecated(unittest.TestCase): def setUp(self): self.test_version = "0.5.3+96.g1fa03c2.dirty" diff --git a/tests/utils.py b/tests/utils.py index ca407924ff..5970e65d9d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,7 +32,7 @@ from monai.config.deviceconfig import USE_COMPILED from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism -from monai.utils.deprecated import version_leq +from monai.utils.module import version_leq nib, _ = optional_import("nibabel") From 15416b9cf2caef468630451ffc2ab228bfbeecba Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 29 Jun 2021 21:06:43 +0100 Subject: [PATCH 02/10] update Signed-off-by: Wenqi Li --- monai/utils/deprecated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/deprecated.py b/monai/utils/deprecated.py index 98adec229b..66793f4908 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecated.py @@ -23,7 +23,6 @@ __all__ = ["deprecated", "deprecated_arg", "DeprecatedError"] warned_set = set() -warned_lock = Lock() class DeprecatedError(Exception): @@ -35,6 +34,7 @@ def warn_deprecated(obj, msg): Issue the warning message `msg` only once per process for the given object `obj`. When this function is called and `obj` is not in `warned_set`, it is added and the warning issued, if it's already present nothing happens. """ + global warned_set if obj not in warned_set: # ensure warning is issued only once per process warned_set.add(obj) warnings.warn(msg, category=DeprecationWarning, stacklevel=2) From b67214b30d6cbfc1c72325725d79909eb9ab5e0a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 29 Jun 2021 21:57:53 +0100 Subject: [PATCH 03/10] remove unused Signed-off-by: Wenqi Li --- monai/utils/deprecated.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/utils/deprecated.py b/monai/utils/deprecated.py index 66793f4908..ca0859eae9 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecated.py @@ -12,7 +12,6 @@ import inspect import warnings from functools import wraps -from threading import Lock from types import FunctionType from typing import Optional From 2a5d54a570e1ea3fae6097cd950541980746fc59 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 29 Jun 2021 22:46:53 +0100 Subject: [PATCH 04/10] remove unused vars Signed-off-by: Wenqi Li --- monai/utils/deprecated.py | 10 ++-------- monai/utils/module.py | 3 +-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/monai/utils/deprecated.py b/monai/utils/deprecated.py index ca0859eae9..1caab6b4ab 100644 --- a/monai/utils/deprecated.py +++ b/monai/utils/deprecated.py @@ -21,8 +21,6 @@ __all__ = ["deprecated", "deprecated_arg", "DeprecatedError"] -warned_set = set() - class DeprecatedError(Exception): pass @@ -30,13 +28,9 @@ class DeprecatedError(Exception): def warn_deprecated(obj, msg): """ - Issue the warning message `msg` only once per process for the given object `obj`. When this function is called - and `obj` is not in `warned_set`, it is added and the warning issued, if it's already present nothing happens. + Issue the warning message `msg`. """ - global warned_set - if obj not in warned_set: # ensure warning is issued only once per process - warned_set.add(obj) - warnings.warn(msg, category=DeprecationWarning, stacklevel=2) + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) def deprecated( diff --git a/monai/utils/module.py b/monai/utils/module.py index e6ac6b0cfe..11db042b16 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -10,7 +10,6 @@ # limitations under the License. import inspect -import re import sys import warnings from importlib import import_module @@ -282,7 +281,7 @@ def version_leq(lhs, rhs): def _try_cast(val): val = val.strip() try: - m = re.match("(\\d+)(.*)", val) + m = match("(\\d+)(.*)", val) if m is not None: val = m.groups()[0] From 1f87607558304954f80966857bcc84a7c0520acd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 30 Jun 2021 09:37:43 +0100 Subject: [PATCH 05/10] update debug info Signed-off-by: Wenqi Li --- monai/networks/nets/torchvision_fc.py | 2 +- monai/utils/module.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 19b973796f..66d905be85 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -68,7 +68,7 @@ def __init__( ) -@deprecated(since="0.6.0", version_val="0.7.0", msg_suffix="please consider using `TorchVisionFCModel` instead.") +@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `TorchVisionFCModel` instead.") class TorchVisionFullyConvModel(TorchVisionFCModel): """ Customize TorchVision models to replace fully connected layer by convolutional layer. diff --git a/monai/utils/module.py b/monai/utils/module.py index 11db042b16..aa3ae70955 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -274,9 +274,10 @@ def get_torch_version_tuple(): def version_leq(lhs, rhs): """Returns True if version `lhs` is earlier or equal to `rhs`.""" - ver, has_ver = optional_import("pkg_resources", name="parse_version") - if has_ver: - return ver(lhs) <= ver(rhs) + print('lhs', lhs, 'rhs', rhs) + # ver, has_ver = optional_import("pkg_resources", name="parse_version") + # if has_ver: + # return ver(lhs) <= ver(rhs) def _try_cast(val): val = val.strip() From 597adb315eef6f42848fa9cbbdf5047581e6eb3b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 30 Jun 2021 09:38:21 +0100 Subject: [PATCH 06/10] temp tests Signed-off-by: Wenqi Li --- .github/workflows/pythonapp.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 9e84862ede..46e639dc5e 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -7,6 +7,7 @@ on: - dev - main - releasing/* + - update-version-cmp pull_request: concurrency: From f6c19e9b35d75be001af08ce21dc4beee3b0836b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 30 Jun 2021 10:07:01 +0100 Subject: [PATCH 07/10] non integer comparison Signed-off-by: Wenqi Li --- monai/utils/module.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index aa3ae70955..f6d7687bb6 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -274,10 +274,9 @@ def get_torch_version_tuple(): def version_leq(lhs, rhs): """Returns True if version `lhs` is earlier or equal to `rhs`.""" - print('lhs', lhs, 'rhs', rhs) - # ver, has_ver = optional_import("pkg_resources", name="parse_version") - # if has_ver: - # return ver(lhs) <= ver(rhs) + ver, has_ver = optional_import("pkg_resources", name="parse_version") + if has_ver: + return ver(lhs) <= ver(rhs) def _try_cast(val): val = val.strip() @@ -285,8 +284,8 @@ def _try_cast(val): m = match("(\\d+)(.*)", val) if m is not None: val = m.groups()[0] - - return int(val) + return int(val) + return val except ValueError: return val @@ -300,7 +299,9 @@ def _try_cast(val): for l, r in zip(lhs, rhs): if l != r: - return l < r + if isinstance(l, int) and isinstance(r, int): + return l < r + return f"{l}" < f"{r}" return True From d462b173c004a52830c9736849cddc49e6b7f27e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 30 Jun 2021 10:49:11 +0100 Subject: [PATCH 08/10] adds tests Signed-off-by: Wenqi Li --- monai/networks/nets/torchvision_fc.py | 2 +- tests/test_version_leq.py | 81 +++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 tests/test_version_leq.py diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py index 66d905be85..19b973796f 100644 --- a/monai/networks/nets/torchvision_fc.py +++ b/monai/networks/nets/torchvision_fc.py @@ -68,7 +68,7 @@ def __init__( ) -@deprecated(since="0.6.0", removed="0.7.0", msg_suffix="Please consider using `TorchVisionFCModel` instead.") +@deprecated(since="0.6.0", version_val="0.7.0", msg_suffix="please consider using `TorchVisionFCModel` instead.") class TorchVisionFullyConvModel(TorchVisionFCModel): """ Customize TorchVision models to replace fully connected layer by convolutional layer. diff --git a/tests/test_version_leq.py b/tests/test_version_leq.py new file mode 100644 index 0000000000..ddc621484b --- /dev/null +++ b/tests/test_version_leq.py @@ -0,0 +1,81 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import unittest + +from parameterized import parameterized + +from monai.utils import version_leq + + +# from pkg_resources +def _pairwise(iterable): + "s -> (s0,s1), (s1,s2), (s2, s3), ..." + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +# from pkg_resources +torture = """ + 0.80.1-3 0.80.1-2 0.80.1-1 0.79.9999+0.80.0pre4-1 + 0.79.9999+0.80.0pre2-3 0.79.9999+0.80.0pre2-2 + 0.77.2-1 0.77.1-1 0.77.0-1 + """ + +TEST_CASES = ( + ("1.6.0", "1.6.0"), + ("1.6.0a0+9907a3e", "1.6.0"), + ("0+unknown", "0.6"), + ("ab", "abc"), + ("0.6rc1", "0.6"), + ("0.6", "0.7"), + ("1.2.a", "1.2a"), + ("1.2-rc1", "1.2rc1"), + ("0.4", "0.4.0"), + ("0.4.0.0", "0.4.0"), + ("0.4.0-0", "0.4-0"), + ("0post1", "0.0post1"), + ("0pre1", "0.0c1"), + ("0.0.0preview1", "0c1"), + ("0.0c1", "0-rc1"), + ("1.2a1", "1.2.a.1"), + ("1.2.a", "1.2a"), + ("2.1", "2.1.1"), + ("2a1", "2b0"), + ("2a1", "2.1"), + ("2.3a1", "2.3"), + ("2.1-1", "2.1-2"), + ("2.1-1", "2.1.1"), + ("2.1", "2.1post4"), + ("2.1a0-20040501", "2.1"), + ("1.1", "02.1"), + ("3.2", "3.2.post0"), + ("3.2post1", "3.2post2"), + ("0.4", "4.0"), + ("0.0.4", "0.4.0"), + ("0post1", "0.4post1"), + ("2.1.0-rc1", "2.1.0"), + ("2.1dev", "2.1a0"), +) + tuple(_pairwise(reversed(torture.split()))) + + +class TestVersioncmp(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_compare(self, a, b, expected=True): + """Test version_leq with `a` and `b`""" + self.assertEqual(version_leq(a, b), expected) + + +if __name__ == "__main__": + unittest.main() From e178b054904244af68325f49101c2db1dca5e414 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 30 Jun 2021 10:49:34 +0100 Subject: [PATCH 09/10] no temp tests Signed-off-by: Wenqi Li --- .github/workflows/pythonapp.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 46e639dc5e..9e84862ede 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -7,7 +7,6 @@ on: - dev - main - releasing/* - - update-version-cmp pull_request: concurrency: From 841858158395d84c06bcfb67cd96575bc233371b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 30 Jun 2021 10:50:47 +0100 Subject: [PATCH 10/10] update test Signed-off-by: Wenqi Li --- tests/test_version_leq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_version_leq.py b/tests/test_version_leq.py index ddc621484b..a1913069d3 100644 --- a/tests/test_version_leq.py +++ b/tests/test_version_leq.py @@ -70,7 +70,7 @@ def _pairwise(iterable): ) + tuple(_pairwise(reversed(torture.split()))) -class TestVersioncmp(unittest.TestCase): +class TestVersionCompare(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_compare(self, a, b, expected=True): """Test version_leq with `a` and `b`"""