diff --git a/pyproject.toml b/pyproject.toml index 3121799cb7..d93cb382ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,9 @@ module = [ "trio._core._io_epoll", "trio._core._io_kqueue", "trio._core._local", - "trio._core._unbounded_queue", + "trio._core._multierror", "trio._core._thread_cache", + "trio._core._unbounded_queue", "trio._core._run", "trio._deprecate", "trio._dtls", diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 3c6ebb789f..6e4cb8b923 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -2,7 +2,9 @@ import sys import warnings -from typing import TYPE_CHECKING +from collections.abc import Callable, Sequence +from types import TracebackType +from typing import TYPE_CHECKING, Any, cast, overload import attr @@ -14,13 +16,15 @@ from traceback import print_exception if TYPE_CHECKING: - from types import TracebackType + from typing_extensions import Self ################################################################ # MultiError ################################################################ -def _filter_impl(handler, root_exc): +def _filter_impl( + handler: Callable[[BaseException], BaseException | None], root_exc: BaseException +) -> BaseException | None: # We have a tree of MultiError's, like: # # MultiError([ @@ -79,7 +83,9 @@ def _filter_impl(handler, root_exc): # Filters a subtree, ignoring tracebacks, while keeping a record of # which MultiErrors were preserved unchanged - def filter_tree(exc, preserved): + def filter_tree( + exc: MultiError | BaseException, preserved: set[int] + ) -> MultiError | BaseException | None: if isinstance(exc, MultiError): new_exceptions = [] changed = False @@ -103,7 +109,9 @@ def filter_tree(exc, preserved): new_exc.__context__ = exc return new_exc - def push_tb_down(tb, exc, preserved): + def push_tb_down( + tb: TracebackType | None, exc: BaseException, preserved: set[int] + ) -> None: if id(exc) in preserved: return new_tb = concat_tb(tb, exc.__traceback__) @@ -114,7 +122,7 @@ def push_tb_down(tb, exc, preserved): else: exc.__traceback__ = new_tb - preserved = set() + preserved: set[int] = set() new_root_exc = filter_tree(root_exc, preserved) push_tb_down(None, root_exc, preserved) # Delete the local functions to avoid a reference cycle (see @@ -130,9 +138,9 @@ def push_tb_down(tb, exc, preserved): # frame show up in the traceback; otherwise, we leave no trace.) @attr.s(frozen=True) class MultiErrorCatcher: - _handler = attr.ib() + _handler: Callable[[BaseException], BaseException | None] = attr.ib() - def __enter__(self): + def __enter__(self) -> None: pass def __exit__( @@ -167,7 +175,13 @@ def __exit__( return False -class MultiError(BaseExceptionGroup): +if TYPE_CHECKING: + _BaseExceptionGroup = BaseExceptionGroup[BaseException] +else: + _BaseExceptionGroup = BaseExceptionGroup + + +class MultiError(_BaseExceptionGroup): """An exception that contains other exceptions; also known as an "inception". @@ -190,7 +204,9 @@ class MultiError(BaseExceptionGroup): """ - def __init__(self, exceptions, *, _collapse=True): + def __init__( + self, exceptions: Sequence[BaseException], *, _collapse: bool = True + ) -> None: self.collapse = _collapse # Avoid double initialization when _collapse is True and exceptions[0] returned @@ -201,7 +217,9 @@ def __init__(self, exceptions, *, _collapse=True): super().__init__("multiple tasks failed", exceptions) - def __new__(cls, exceptions, *, _collapse=True): + def __new__( # type: ignore[misc] # mypy says __new__ must return a class instance + cls, exceptions: Sequence[BaseException], *, _collapse: bool = True + ) -> NonBaseMultiError | Self | BaseException: exceptions = list(exceptions) for exc in exceptions: if not isinstance(exc, BaseException): @@ -218,33 +236,54 @@ def __new__(cls, exceptions, *, _collapse=True): # In an earlier version of the code, we didn't define __init__ and # simply set the `exceptions` attribute directly on the new object. # However, linters expect attributes to be initialized in __init__. + from_class: type[Self] | type[NonBaseMultiError] = cls if all(isinstance(exc, Exception) for exc in exceptions): - cls = NonBaseMultiError + from_class = NonBaseMultiError - return super().__new__(cls, "multiple tasks failed", exceptions) + # Ignoring arg-type: 'Argument 3 to "__new__" of "BaseExceptionGroup" has incompatible type "list[BaseException]"; expected "Sequence[_BaseExceptionT_co]"' + # We have checked that exceptions is indeed a list of BaseException objects, this is fine. + new_obj = super().__new__(from_class, "multiple tasks failed", exceptions) # type: ignore[arg-type] + assert isinstance(new_obj, (cls, NonBaseMultiError)) + return new_obj - def __reduce__(self): + def __reduce__( + self, + ) -> tuple[object, tuple[type[Self], list[BaseException]], dict[str, bool]]: return ( self.__new__, (self.__class__, list(self.exceptions)), {"collapse": self.collapse}, ) - def __str__(self): + def __str__(self) -> str: return ", ".join(repr(exc) for exc in self.exceptions) - def __repr__(self): + def __repr__(self) -> str: return f"" - def derive(self, __excs): + @overload + def derive(self, excs: Sequence[Exception], /) -> NonBaseMultiError: + ... + + @overload + def derive(self, excs: Sequence[BaseException], /) -> MultiError: + ... + + def derive( + self, excs: Sequence[Exception | BaseException], / + ) -> NonBaseMultiError | MultiError: # We use _collapse=False here to get ExceptionGroup semantics, since derive() # is part of the PEP 654 API - exc = MultiError(__excs, _collapse=False) + exc = MultiError(excs, _collapse=False) exc.collapse = self.collapse return exc @classmethod - def filter(cls, handler, root_exc): + def filter( + cls, + handler: Callable[[BaseException], BaseException | None], + root_exc: BaseException, + ) -> BaseException | None: """Apply the given ``handler`` to all the exceptions in ``root_exc``. Args: @@ -268,7 +307,9 @@ def filter(cls, handler, root_exc): return _filter_impl(handler, root_exc) @classmethod - def catch(cls, handler): + def catch( + cls, handler: Callable[[BaseException], BaseException | None] + ) -> MultiErrorCatcher: """Return a context manager that catches and re-throws exceptions after running :meth:`filter` on them. @@ -286,8 +327,14 @@ def catch(cls, handler): return MultiErrorCatcher(handler) -class NonBaseMultiError(MultiError, ExceptionGroup): - pass +if TYPE_CHECKING: + _ExceptionGroup = ExceptionGroup[Exception] +else: + _ExceptionGroup = ExceptionGroup + + +class NonBaseMultiError(MultiError, _ExceptionGroup): + __slots__ = () # Clean up exception printing: @@ -316,30 +363,6 @@ class NonBaseMultiError(MultiError, ExceptionGroup): try: import tputil except ImportError: - have_tproxy = False -else: - have_tproxy = True - -if have_tproxy: - # http://doc.pypy.org/en/latest/objspace-proxies.html - def copy_tb(base_tb, tb_next): - def controller(operation): - # Rationale for pragma: I looked fairly carefully and tried a few - # things, and AFAICT it's not actually possible to get any - # 'opname' that isn't __getattr__ or __getattribute__. So there's - # no missing test we could add, and no value in coverage nagging - # us about adding one. - if operation.opname in [ - "__getattribute__", - "__getattr__", - ]: # pragma: no cover - if operation.args[0] == "tb_next": - return tb_next - return operation.delegate() - - return tputil.make_proxy(controller, type(base_tb), base_tb) - -else: # ctypes it is import ctypes @@ -359,12 +382,13 @@ class CTraceback(ctypes.Structure): ("tb_lineno", ctypes.c_int), ] - def copy_tb(base_tb, tb_next): + def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: # TracebackType has no public constructor, so allocate one the hard way try: raise ValueError except ValueError as exc: new_tb = exc.__traceback__ + assert new_tb is not None c_new_tb = CTraceback.from_address(id(new_tb)) # At the C level, tb_next either pointer to the next traceback or is @@ -377,14 +401,14 @@ def copy_tb(base_tb, tb_next): # which it already is, so we're done. Otherwise, we have to actually # do some work: if tb_next is not None: - _ctypes.Py_INCREF(tb_next) + _ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined] c_new_tb.tb_next = id(tb_next) assert c_new_tb.tb_frame is not None - _ctypes.Py_INCREF(base_tb.tb_frame) + _ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined] old_tb_frame = new_tb.tb_frame c_new_tb.tb_frame = id(base_tb.tb_frame) - _ctypes.Py_DECREF(old_tb_frame) + _ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined] c_new_tb.tb_lasti = base_tb.tb_lasti c_new_tb.tb_lineno = base_tb.tb_lineno @@ -396,8 +420,33 @@ def copy_tb(base_tb, tb_next): # see test_MultiError_catch_doesnt_create_cyclic_garbage del new_tb, old_tb_frame +else: + # http://doc.pypy.org/en/latest/objspace-proxies.html + def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: + # Mypy refuses to believe that ProxyOperation can be imported properly + # TODO: will need no-any-unimported if/when that's toggled on + def controller(operation: tputil.ProxyOperation) -> Any | None: + # Rationale for pragma: I looked fairly carefully and tried a few + # things, and AFAICT it's not actually possible to get any + # 'opname' that isn't __getattr__ or __getattribute__. So there's + # no missing test we could add, and no value in coverage nagging + # us about adding one. + if operation.opname in [ + "__getattribute__", + "__getattr__", + ]: # pragma: no cover + if operation.args[0] == "tb_next": + return tb_next + return operation.delegate() # Deligate is reverting to original behaviour + + return cast( + TracebackType, tputil.make_proxy(controller, type(base_tb), base_tb) + ) # Returns proxy to traceback + -def concat_tb(head, tail): +def concat_tb( + head: TracebackType | None, tail: TracebackType | None +) -> TracebackType | None: # We have to use an iterative algorithm here, because in the worst case # this might be a RecursionError stack that is by definition too deep to # process by recursion! @@ -429,7 +478,13 @@ def concat_tb(head, tail): ) else: - def trio_show_traceback(self, etype, value, tb, tb_offset=None): + def trio_show_traceback( + self: IPython.core.interactiveshell.InteractiveShell, + etype: type[BaseException], + value: BaseException, + tb: TracebackType, + tb_offset: int | None = None, + ) -> None: # XX it would be better to integrate with IPython's fancy # exception formatting stuff (and not ignore tb_offset) print_exception(value) @@ -460,10 +515,14 @@ def trio_show_traceback(self, etype, value, tb, tb_offset=None): assert sys.excepthook is apport_python_hook.apport_excepthook - def replacement_excepthook(etype, value, tb): - sys.stderr.write("".join(format_exception(etype, value, tb))) + def replacement_excepthook( + etype: type[BaseException], value: BaseException, tb: TracebackType | None + ) -> None: + # This does work, it's an overloaded function + sys.stderr.write("".join(format_exception(etype, value, tb))) # type: ignore[arg-type] fake_sys = ModuleType("trio_fake_sys") fake_sys.__dict__.update(sys.__dict__) - fake_sys.__excepthook__ = replacement_excepthook # type: ignore + # Fake does have __excepthook__ after __dict__ update, but type checkers don't recognize this + fake_sys.__excepthook__ = replacement_excepthook # type: ignore[attr-defined] apport_python_hook.sys = fake_sys