From dff8750ae282545df482b4b22682a3c0cfef736a Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Tue, 14 Jan 2025 19:41:07 +0100 Subject: [PATCH 1/4] Accept generic ExceptionGroups for raises Closes #13115 --- AUTHORS | 1 + changelog/13115.improvement.rst | 1 + src/_pytest/python_api.py | 45 +++++++++++++++++++++++++++++---- testing/code/test_excinfo.py | 27 ++++++++++++++++++++ 4 files changed, 69 insertions(+), 5 deletions(-) create mode 100644 changelog/13115.improvement.rst diff --git a/AUTHORS b/AUTHORS index 8a1a7d183a3..8600735c8b5 100644 --- a/AUTHORS +++ b/AUTHORS @@ -435,6 +435,7 @@ Tim Hoffmann Tim Strazny TJ Bruno Tobias Diez +Tobias Petersen Tom Dalton Tom Viner Tomáš Gavenčiak diff --git a/changelog/13115.improvement.rst b/changelog/13115.improvement.rst new file mode 100644 index 00000000000..c77383c154f --- /dev/null +++ b/changelog/13115.improvement.rst @@ -0,0 +1 @@ +Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on ExcInfo. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 25cf9f04d61..97e673f43b3 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -12,10 +12,13 @@ from numbers import Complex import pprint import re +import sys from types import TracebackType from typing import Any from typing import cast from typing import final +from typing import get_args +from typing import get_origin from typing import overload from typing import TYPE_CHECKING from typing import TypeVar @@ -24,6 +27,10 @@ from _pytest.outcomes import fail +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + from exceptiongroup import ExceptionGroup + if TYPE_CHECKING: from numpy import ndarray @@ -954,15 +961,43 @@ def raises( f"Raising exceptions is already understood as failing the test, so you don't need " f"any special code to say 'this should never raise an exception'." ) + + expected_exceptions: tuple[type[E], ...] + origin_exc: type[E] | None = get_origin(expected_exception) if isinstance(expected_exception, type): - expected_exceptions: tuple[type[E], ...] = (expected_exception,) + expected_exceptions = (expected_exception,) + elif origin_exc and issubclass(origin_exc, BaseExceptionGroup): + expected_exceptions = (cast(type[E], expected_exception),) else: expected_exceptions = expected_exception - for exc in expected_exceptions: - if not isinstance(exc, type) or not issubclass(exc, BaseException): + + def validate_exc(exc: type[E]) -> type[E]: + origin_exc: type[E] | None = get_origin(exc) + if origin_exc and issubclass(origin_exc, BaseExceptionGroup): + exc_type = get_args(exc)[0] + if issubclass(origin_exc, ExceptionGroup) and exc_type is Exception: + return cast(type[E], origin_exc) + elif ( + issubclass(origin_exc, BaseExceptionGroup) and exc_type is BaseException + ): + return cast(type[E], origin_exc) + else: + raise ValueError( + f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` " + f"are accepted as generic types but got `{exc}`. " + f"As `raises` will catch all instances of the specified group regardless of the " + f"generic argument specific nested exceptions has to be checked " + f"with `ExceptionInfo.group_contains()`" + ) + + elif not isinstance(exc, type) or not issubclass(exc, BaseException): msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable] not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__ raise TypeError(msg.format(not_a)) + else: + return exc + + expected_exceptions = tuple(validate_exc(exc) for exc in expected_exceptions) message = f"DID NOT RAISE {expected_exception}" @@ -973,14 +1008,14 @@ def raises( msg += ", ".join(sorted(kwargs)) msg += "\nUse context-manager form instead?" raise TypeError(msg) - return RaisesContext(expected_exception, message, match) + return RaisesContext(expected_exceptions, message, match) else: func = args[0] if not callable(func): raise TypeError(f"{func!r} object (type: {type(func)}) must be callable") try: func(*args[1:], **kwargs) - except expected_exception as e: + except expected_exceptions as e: return _pytest._code.ExceptionInfo.from_exception(e) fail(message) diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index 22e695977e1..ae2f2084179 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -31,6 +31,7 @@ from _pytest._code.code import TracebackStyle if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup from exceptiongroup import ExceptionGroup @@ -453,6 +454,32 @@ def test_division_zero(): result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match]) +def test_raises_accepts_generic_group() -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(ExceptionGroup[Exception]) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + +def test_raises_accepts_generic_base_group() -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises(BaseExceptionGroup[BaseException]) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + +def test_raises_rejects_specific_generic_group() -> None: + with pytest.raises(ValueError): + pytest.raises(ExceptionGroup[RuntimeError]) + + +def test_raises_accepts_generic_group_in_tuple() -> None: + exc_group = ExceptionGroup("", [RuntimeError()]) + with pytest.raises((ValueError, ExceptionGroup[Exception])) as exc_info: + raise exc_group + assert exc_info.group_contains(RuntimeError) + + class TestGroupContains: def test_contains_exception_type(self) -> None: exc_group = ExceptionGroup("", [RuntimeError()]) From 01b2c0b4f64fe98630e28a871279e93257024063 Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Mon, 20 Jan 2025 02:05:57 +0100 Subject: [PATCH 2/4] Fix review suggestions --- changelog/13115.improvement.rst | 1 + src/_pytest/python_api.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/changelog/13115.improvement.rst b/changelog/13115.improvement.rst index c77383c154f..20b9715a10f 100644 --- a/changelog/13115.improvement.rst +++ b/changelog/13115.improvement.rst @@ -1 +1,2 @@ Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on ExcInfo. +Parametrizing with other element types remains an error - we do not check the types of child exceptions and thus do not permit code that might look like we do. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 97e673f43b3..f79b975d632 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -975,10 +975,11 @@ def validate_exc(exc: type[E]) -> type[E]: origin_exc: type[E] | None = get_origin(exc) if origin_exc and issubclass(origin_exc, BaseExceptionGroup): exc_type = get_args(exc)[0] - if issubclass(origin_exc, ExceptionGroup) and exc_type is Exception: + if issubclass(origin_exc, ExceptionGroup) and exc_type in (Exception, Any): return cast(type[E], origin_exc) - elif ( - issubclass(origin_exc, BaseExceptionGroup) and exc_type is BaseException + elif issubclass(origin_exc, BaseExceptionGroup) and exc_type in ( + BaseException, + Any, ): return cast(type[E], origin_exc) else: From 840ce4fe163784b25dbf850a43c31802f61a5f8b Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Tue, 21 Jan 2025 07:40:00 -0300 Subject: [PATCH 3/4] Add extra test, changelog improvement --- changelog/13115.improvement.rst | 10 ++++++++-- src/_pytest/python_api.py | 1 + testing/code/test_excinfo.py | 19 +++++++++++++------ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/changelog/13115.improvement.rst b/changelog/13115.improvement.rst index 20b9715a10f..9ac45820917 100644 --- a/changelog/13115.improvement.rst +++ b/changelog/13115.improvement.rst @@ -1,2 +1,8 @@ -Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on ExcInfo. -Parametrizing with other element types remains an error - we do not check the types of child exceptions and thus do not permit code that might look like we do. +Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on :class:`ExceptionInfo `: + +.. code-block:: python + + with pytest.raises(ExceptionGroup[Exception]) as exc_info: + some_function() + +Parametrizing with other exception types remains an error - we do not check the types of child exceptions and thus do not permit code that might look like we do. diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index f79b975d632..9fb9e25529b 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -972,6 +972,7 @@ def raises( expected_exceptions = expected_exception def validate_exc(exc: type[E]) -> type[E]: + __tracebackhide__ = True origin_exc: type[E] | None = get_origin(exc) if origin_exc and issubclass(origin_exc, BaseExceptionGroup): exc_type = get_args(exc)[0] diff --git a/testing/code/test_excinfo.py b/testing/code/test_excinfo.py index ae2f2084179..2e1d4bdf6dc 100644 --- a/testing/code/test_excinfo.py +++ b/testing/code/test_excinfo.py @@ -455,16 +455,14 @@ def test_division_zero(): def test_raises_accepts_generic_group() -> None: - exc_group = ExceptionGroup("", [RuntimeError()]) with pytest.raises(ExceptionGroup[Exception]) as exc_info: - raise exc_group + raise ExceptionGroup("", [RuntimeError()]) assert exc_info.group_contains(RuntimeError) def test_raises_accepts_generic_base_group() -> None: - exc_group = ExceptionGroup("", [RuntimeError()]) with pytest.raises(BaseExceptionGroup[BaseException]) as exc_info: - raise exc_group + raise ExceptionGroup("", [RuntimeError()]) assert exc_info.group_contains(RuntimeError) @@ -474,12 +472,21 @@ def test_raises_rejects_specific_generic_group() -> None: def test_raises_accepts_generic_group_in_tuple() -> None: - exc_group = ExceptionGroup("", [RuntimeError()]) with pytest.raises((ValueError, ExceptionGroup[Exception])) as exc_info: - raise exc_group + raise ExceptionGroup("", [RuntimeError()]) assert exc_info.group_contains(RuntimeError) +def test_raises_exception_escapes_generic_group() -> None: + try: + with pytest.raises(ExceptionGroup[Exception]): + raise ValueError("my value error") + except ValueError as e: + assert str(e) == "my value error" + else: + pytest.fail("Expected ValueError to be raised") + + class TestGroupContains: def test_contains_exception_type(self) -> None: exc_group = ExceptionGroup("", [RuntimeError()]) From 4020739af7426f1fc44a054e5a271ef4e6373904 Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Wed, 22 Jan 2025 11:53:10 +0100 Subject: [PATCH 4/4] Minor suggested refactor of if clause (review comment) --- src/_pytest/python_api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 9fb9e25529b..ddbf9b87251 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -976,11 +976,11 @@ def validate_exc(exc: type[E]) -> type[E]: origin_exc: type[E] | None = get_origin(exc) if origin_exc and issubclass(origin_exc, BaseExceptionGroup): exc_type = get_args(exc)[0] - if issubclass(origin_exc, ExceptionGroup) and exc_type in (Exception, Any): - return cast(type[E], origin_exc) - elif issubclass(origin_exc, BaseExceptionGroup) and exc_type in ( - BaseException, - Any, + if ( + issubclass(origin_exc, ExceptionGroup) and exc_type in (Exception, Any) + ) or ( + issubclass(origin_exc, BaseExceptionGroup) + and exc_type in (BaseException, Any) ): return cast(type[E], origin_exc) else: