From fd7f260e62790dee596b1461ffe65cdb2021e25a Mon Sep 17 00:00:00 2001 From: Jakob van Santen Date: Sun, 13 Dec 2020 19:25:07 +0100 Subject: [PATCH 1/2] python_api: handle array-like args in approx() This treats objects that expose an ndarray via the __array__ interface the same as direct subclasses of ndarray. Fixes #8132. --- changelog/8132.bugfix.rst | 10 ++++++++++ src/_pytest/python_api.py | 37 +++++++++++++++++++++++++++++++++---- testing/python/approx.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 changelog/8132.bugfix.rst diff --git a/changelog/8132.bugfix.rst b/changelog/8132.bugfix.rst new file mode 100644 index 00000000000..5be5e567491 --- /dev/null +++ b/changelog/8132.bugfix.rst @@ -0,0 +1,10 @@ +Fixed regression in ``approx``: in 6.2.0 ``approx`` no longer raises +``TypeError`` when dealing with non-numeric types, falling back to normal comparison. +Before 6.2.0, array types like tf.DeviceArray fell through to the scalar case, +and happened to compare correctly to a scalar if they had only one element. +After 6.2.0, these types began failing, because they inherited neither from +standard Python number hierarchy nor from ``numpy.ndarray``. + +``approx`` now converts arguments to ``numpy.ndarray`` if they expose the array +protocol and are not scalars. This treats array-like objects like numpy arrays, +regardless of size. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index bae2076892b..5e01c0a57b6 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -15,9 +15,14 @@ from typing import Pattern from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +if TYPE_CHECKING: + from numpy import ndarray + + import _pytest._code from _pytest.compat import final from _pytest.compat import STRING_TYPES @@ -232,10 +237,11 @@ def __repr__(self) -> str: def __eq__(self, actual) -> bool: """Return whether the given value is equal to the expected value within the pre-specified tolerance.""" - if _is_numpy_array(actual): + asarray = _as_numpy_array(actual) + if asarray is not None: # Call ``__eq__()`` manually to prevent infinite-recursion with # numpy<1.13. See #3748. - return all(self.__eq__(a) for a in actual.flat) + return all(self.__eq__(a) for a in asarray.flat) # Short-circuit exact equality. if actual == self.expected: @@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase: elif isinstance(expected, Mapping): cls = ApproxMapping elif _is_numpy_array(expected): + expected = _as_numpy_array(expected) cls = ApproxNumpy elif ( isinstance(expected, Iterable) @@ -536,7 +543,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase: def _is_numpy_array(obj: object) -> bool: - """Return true if the given object is a numpy array. + """Return true if the given object is implicitly convertible to numpy array. A special effort is made to avoid importing numpy unless it's really necessary. """ @@ -544,10 +551,32 @@ def _is_numpy_array(obj: object) -> bool: np: Any = sys.modules.get("numpy") if np is not None: - return isinstance(obj, np.ndarray) + # avoid infinite recursion on numpy scalars, which have __array__ + if np.isscalar(obj): + return False + elif isinstance(obj, np.ndarray): + return True + elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"): + return True return False +def _as_numpy_array(obj: object) -> Optional["ndarray"]: + """Return an ndarray if obj is implicitly convertible, and numpy is already imported.""" + import sys + + np: Any = sys.modules.get("numpy") + if np is not None: + # avoid infinite recursion on numpy scalars, which have __array__ + if np.isscalar(obj): + return None + elif isinstance(obj, np.ndarray): + return obj + elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"): + return np.asarray(obj) + return None + + # builtin pytest.raises helper _E = TypeVar("_E", bound=BaseException) diff --git a/testing/python/approx.py b/testing/python/approx.py index 91c1f3f85de..e76d6b774d6 100644 --- a/testing/python/approx.py +++ b/testing/python/approx.py @@ -447,6 +447,36 @@ def test_numpy_array_wrong_shape(self): assert a12 != approx(a21) assert a21 != approx(a12) + def test_numpy_array_protocol(self): + """ + array-like objects such as tensorflow's DeviceArray are handled like ndarray. + See issue #8132 + """ + np = pytest.importorskip("numpy") + + class DeviceArray: + def __init__(self, value, size): + self.value = value + self.size = size + + def __array__(self): + return self.value * np.ones(self.size) + + class DeviceScalar: + def __init__(self, value): + self.value = value + + def __array__(self): + return np.array(self.value) + + expected = 1 + actual = 1 + 1e-6 + assert approx(expected) == DeviceArray(actual, size=1) + assert approx(expected) == DeviceArray(actual, size=2) + assert approx(expected) == DeviceScalar(actual) + assert approx(DeviceScalar(expected)) == actual + assert approx(DeviceScalar(expected)) == DeviceScalar(actual) + def test_doctests(self, mocked_doctest_runner) -> None: import doctest From 95cae0baa42521f77a3b39978e9f88dfad025b37 Mon Sep 17 00:00:00 2001 From: Jakob van Santen Date: Sun, 13 Dec 2020 23:07:37 +0100 Subject: [PATCH 2/2] python_api: reduce duplication in _is_numpy_array --- src/_pytest/python_api.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 5e01c0a57b6..81ce4f89539 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -543,26 +543,18 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase: def _is_numpy_array(obj: object) -> bool: - """Return true if the given object is implicitly convertible to numpy array. - - A special effort is made to avoid importing numpy unless it's really necessary. """ - import sys - - np: Any = sys.modules.get("numpy") - if np is not None: - # avoid infinite recursion on numpy scalars, which have __array__ - if np.isscalar(obj): - return False - elif isinstance(obj, np.ndarray): - return True - elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"): - return True - return False + Return true if the given object is implicitly convertible to ndarray, + and numpy is already imported. + """ + return _as_numpy_array(obj) is not None def _as_numpy_array(obj: object) -> Optional["ndarray"]: - """Return an ndarray if obj is implicitly convertible, and numpy is already imported.""" + """ + Return an ndarray if the given object is implicitly convertible to ndarray, + and numpy is already imported, otherwise None. + """ import sys np: Any = sys.modules.get("numpy")