Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
load_submodules,
min_version,
optional_import,
version_leq,
)
from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end
from .state_cacher import StateCacher
83 changes: 34 additions & 49 deletions monai/utils/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,62 +10,27 @@
# limitations under the License.

import inspect
import re
import warnings
from functools import wraps
from threading import Lock
from types import FunctionType
from typing import Optional

from monai.utils.module import version_leq

from .. import __version__

__all__ = ["deprecated", "deprecated_arg", "DeprecatedError"]

warned_set = set()
warned_lock = Lock()


class DeprecatedError(Exception):
pass


def warn_deprecated(obj, msg):
"""
Issue the warning message `msg` only once per process for the given object `obj`. When this function is called
and `obj` is not in `warned_set`, it is added and the warning issued, if it's already present nothing happens.
Issue the warning message `msg`.
"""
if obj not in warned_set: # ensure warning is issued only once per process
warned_set.add(obj)
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)


def version_leq(lhs, rhs):
"""Returns True if version `lhs` is earlier or equal to `rhs`."""

def _try_cast(val):
val = val.strip()
try:
m = re.match("(\\d+)(.*)", val)
if m is not None:
val = m.groups()[0]

return int(val)
except ValueError:
return val

# remove git version suffixes if present
lhs = lhs.split("+", 1)[0]
rhs = rhs.split("+", 1)[0]

# parse the version strings in this basic way to avoid needing the `packaging` package
lhs = map(_try_cast, lhs.split("."))
rhs = map(_try_cast, rhs.split("."))

for l, r in zip(lhs, rhs):
if l != r:
return l < r

return True
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)


def deprecated(
Expand All @@ -89,13 +54,22 @@ def deprecated(
Decorated definition which warns or raises exception when used
"""

is_deprecated = since is not None and version_leq(since, version_val)
is_removed = removed is not None and version_leq(removed, version_val)
if since is not None and removed is not None and not version_leq(since, removed):
raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)

if is_not_yet_deprecated:
# smaller than `since`, do nothing
return lambda obj: obj

if since is None and removed is None:
# raise a DeprecatedError directly
is_removed = True
is_deprecated = True
else:
# compare the numbers
is_deprecated = since is not None and version_leq(since, version_val)
is_removed = removed is not None and version_leq(removed, version_val)

def _decorator(obj):
is_func = isinstance(obj, FunctionType)
call_obj = obj if is_func else obj.__init__
Expand All @@ -115,10 +89,10 @@ def _decorator(obj):

@wraps(call_obj)
def _wrapper(*args, **kwargs):
if is_removed:
raise DeprecatedError(msg)
if is_deprecated:
warn_deprecated(obj, msg)
else:
raise DeprecatedError(msg)

return call_obj(*args, **kwargs)

Expand Down Expand Up @@ -152,10 +126,21 @@ def deprecated_arg(
Returns:
Decorated callable which warns or raises exception when deprecated argument used
"""

is_deprecated = since is not None and version_leq(since, version_val)
is_removed = removed is not None and version_leq(removed, version_val)
if since is not None and removed is not None and not version_leq(since, removed):
raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
if is_not_yet_deprecated:
# smaller than `since`, do nothing
return lambda obj: obj

if since is None and removed is None:
# raise a DeprecatedError directly
is_removed = True
is_deprecated = True
else:
# compare the numbers
is_deprecated = since is not None and version_leq(since, version_val)
is_removed = removed is not None and version_leq(removed, version_val)

if is_not_yet_deprecated:
return lambda obj: obj
Expand Down Expand Up @@ -186,10 +171,10 @@ def _wrapper(*args, **kwargs):
kw_found = "kwargs" in binding and name in binding["kwargs"]

if positional_found or kw_found:
if is_removed:
raise DeprecatedError(msg)
if is_deprecated:
warn_deprecated(argname, msg)
else:
raise DeprecatedError(msg)

return func(*args, **kwargs)

Expand Down
37 changes: 36 additions & 1 deletion monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import torch

from .deprecated import version_leq
from .misc import ensure_tuple

OPTIONAL_IMPORT_MSG_FMT = "{}"
Expand All @@ -37,6 +36,7 @@
"get_package_version",
"get_torch_version_tuple",
"PT_BEFORE_1_7",
"version_leq",
]


Expand Down Expand Up @@ -271,6 +271,41 @@ def get_torch_version_tuple():
return tuple((int(x) for x in torch.__version__.split(".")[:2]))


def version_leq(lhs, rhs):
"""Returns True if version `lhs` is earlier or equal to `rhs`."""

ver, has_ver = optional_import("pkg_resources", name="parse_version")
if has_ver:
return ver(lhs) <= ver(rhs)

def _try_cast(val):
val = val.strip()
try:
m = match("(\\d+)(.*)", val)
if m is not None:
val = m.groups()[0]
return int(val)
return val
except ValueError:
return val

# remove git version suffixes if present
lhs = lhs.split("+", 1)[0]
rhs = rhs.split("+", 1)[0]

# parse the version strings in this basic way without `packaging` package
lhs = map(_try_cast, lhs.split("."))
rhs = map(_try_cast, rhs.split("."))

for l, r in zip(lhs, rhs):
if l != r:
if isinstance(l, int) and isinstance(r, int):
return l < r
return f"{l}" < f"{r}"

return True


try:
PT_BEFORE_1_7 = torch.__version__ != "1.7.0" and version_leq(torch.__version__, "1.7.0")
except (AttributeError, TypeError):
Expand Down
43 changes: 43 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,49 @@
from monai.utils import DeprecatedError, deprecated, deprecated_arg


class TestDeprecatedRC(unittest.TestCase):
def setUp(self):
self.test_version_rc = "0.6.0rc1"
self.test_version = "0.6.0"
self.next_version = "0.7.0"

def test_warning(self):
"""Test deprecated decorator with `since` and `removed` set for an RC version"""

@deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version_rc)
def foo2():
pass

print(foo2())

def test_warning_milestone(self):
"""Test deprecated decorator with `since` and `removed` set for a milestone version"""

@deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version)
def foo2():
pass

self.assertWarns(DeprecationWarning, foo2)

def test_warning_last(self):
"""Test deprecated decorator with `since` and `removed` set, for the last version"""

@deprecated(since=self.test_version, removed=self.next_version, version_val=self.next_version)
def foo3():
pass

self.assertRaises(DeprecatedError, foo3)

def test_warning_beyond(self):
"""Test deprecated decorator with `since` and `removed` set, beyond the last version"""

@deprecated(since=self.test_version_rc, removed=self.test_version, version_val=self.next_version)
def foo3():
pass

self.assertRaises(DeprecatedError, foo3)


class TestDeprecated(unittest.TestCase):
def setUp(self):
self.test_version = "0.5.3+96.g1fa03c2.dirty"
Expand Down
81 changes: 81 additions & 0 deletions tests/test_version_leq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import itertools
import unittest

from parameterized import parameterized

from monai.utils import version_leq


# from pkg_resources
def _pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
return zip(a, b)


# from pkg_resources
torture = """
0.80.1-3 0.80.1-2 0.80.1-1 0.79.9999+0.80.0pre4-1
0.79.9999+0.80.0pre2-3 0.79.9999+0.80.0pre2-2
0.77.2-1 0.77.1-1 0.77.0-1
"""

TEST_CASES = (
("1.6.0", "1.6.0"),
("1.6.0a0+9907a3e", "1.6.0"),
("0+unknown", "0.6"),
("ab", "abc"),
("0.6rc1", "0.6"),
("0.6", "0.7"),
("1.2.a", "1.2a"),
("1.2-rc1", "1.2rc1"),
("0.4", "0.4.0"),
("0.4.0.0", "0.4.0"),
("0.4.0-0", "0.4-0"),
("0post1", "0.0post1"),
("0pre1", "0.0c1"),
("0.0.0preview1", "0c1"),
("0.0c1", "0-rc1"),
("1.2a1", "1.2.a.1"),
("1.2.a", "1.2a"),
("2.1", "2.1.1"),
("2a1", "2b0"),
("2a1", "2.1"),
("2.3a1", "2.3"),
("2.1-1", "2.1-2"),
("2.1-1", "2.1.1"),
("2.1", "2.1post4"),
("2.1a0-20040501", "2.1"),
("1.1", "02.1"),
("3.2", "3.2.post0"),
("3.2post1", "3.2post2"),
("0.4", "4.0"),
("0.0.4", "0.4.0"),
("0post1", "0.4post1"),
("2.1.0-rc1", "2.1.0"),
("2.1dev", "2.1a0"),
) + tuple(_pairwise(reversed(torture.split())))


class TestVersionCompare(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_compare(self, a, b, expected=True):
"""Test version_leq with `a` and `b`"""
self.assertEqual(version_leq(a, b), expected)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from monai.config.deviceconfig import USE_COMPILED
from monai.data import create_test_image_2d, create_test_image_3d
from monai.utils import ensure_tuple, optional_import, set_determinism
from monai.utils.deprecated import version_leq
from monai.utils.module import version_leq

nib, _ = optional_import("nibabel")

Expand Down