diff --git a/trio/_core/_run.py b/trio/_core/_run.py index be0497b635..31ff874a40 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1,60 +1,61 @@ +from __future__ import annotations + +import enum import functools +import gc import itertools import random import select import sys import threading -import gc +import warnings from collections import deque +from collections.abc import Callable from contextlib import contextmanager -import warnings -import enum - from contextvars import copy_context +from heapq import heapify, heappop, heappush from math import inf from time import perf_counter -from typing import Callable, TYPE_CHECKING - -from sniffio import current_async_library_cvar +from typing import TYPE_CHECKING, Any, NoReturn, TypeVar import attr -from heapq import heapify, heappop, heappush -from sortedcontainers import SortedDict from outcome import Error, Outcome, Value, capture +from sniffio import current_async_library_cvar +from sortedcontainers import SortedDict +# An unfortunate name collision here with trio._util.Final +from typing_extensions import Final as FinalT + +from .. import _core +from .._util import Final, NoPublicConstructor, coroutine_or_error +from ._asyncgens import AsyncGenerators from ._entry_queue import EntryQueue, TrioToken -from ._exceptions import TrioInternalError, RunFinishedError, Cancelled -from ._ki import ( - LOCALS_KEY_KI_PROTECTION_ENABLED, - KIManager, - enable_ki_protection, -) +from ._exceptions import Cancelled, RunFinishedError, TrioInternalError +from ._instrumentation import Instruments +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection from ._multierror import MultiError, concat_tb +from ._thread_cache import start_thread_soon from ._traps import ( Abort, - wait_task_rescheduled, - cancel_shielded_checkpoint, CancelShieldedCheckpoint, PermanentlyDetachCoroutineObject, WaitTaskRescheduled, + cancel_shielded_checkpoint, + wait_task_rescheduled, ) -from ._asyncgens import AsyncGenerators -from ._thread_cache import start_thread_soon -from ._instrumentation import Instruments -from .. import _core -from .._util import Final, NoPublicConstructor, coroutine_or_error if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup -DEADLINE_HEAP_MIN_PRUNE_THRESHOLD = 1000 +DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 -_NO_SEND = object() +_NO_SEND: FinalT = object() +FnT = TypeVar("FnT", bound="Callable[..., Any]") # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn): +def _public(fn: FnT) -> FnT: return fn @@ -63,15 +64,31 @@ def _public(fn): # variable to True, and registers the Random instance _r for Hypothesis # to manage for each test case, which together should make Trio's task # scheduling loop deterministic. We have a test for that, of course. -_ALLOW_DETERMINISTIC_SCHEDULING = False +_ALLOW_DETERMINISTIC_SCHEDULING: FinalT = False _r = random.Random() -# On CPython, Context.run() is implemented in C and doesn't show up in -# tracebacks. On PyPy, it is implemented in Python and adds 1 frame to tracebacks. -def _count_context_run_tb_frames(): - def function_with_unique_name_xyzzy(): - 1 / 0 +def _count_context_run_tb_frames() -> int: + """Count implementation dependent traceback frames from Context.run() + + On CPython, Context.run() is implemented in C and doesn't show up in + tracebacks. On PyPy, it is implemented in Python and adds 1 frame to + tracebacks. + + Returns: + int: Traceback frame count + + """ + + def function_with_unique_name_xyzzy() -> NoReturn: + try: + 1 / 0 + except ZeroDivisionError: + raise + else: # pragma: no cover + raise TrioInternalError( + "A ZeroDivisionError should have been raised, but it wasn't." + ) ctx = copy_context() try: @@ -79,15 +96,20 @@ def function_with_unique_name_xyzzy(): except ZeroDivisionError as exc: tb = exc.__traceback__ # Skip the frame where we caught it - tb = tb.tb_next + tb = tb.tb_next # type: ignore[union-attr] count = 0 - while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy": - tb = tb.tb_next + while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy": # type: ignore[union-attr] + tb = tb.tb_next # type: ignore[union-attr] count += 1 return count + else: # pragma: no cover + raise TrioInternalError( + f"The purpose of {function_with_unique_name_xyzzy.__name__} is " + "to raise a ZeroDivisionError, but it didn't." + ) -CONTEXT_RUN_TB_FRAMES = _count_context_run_tb_frames() +CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames() @attr.s(frozen=True, slots=True) @@ -95,18 +117,18 @@ class SystemClock: # Add a large random offset to our clock to ensure that if people # accidentally call time.perf_counter() directly or start comparing clocks # between different runs, then they'll notice the bug quickly: - offset = attr.ib(factory=lambda: _r.uniform(10000, 200000)) + offset: float = attr.ib(factory=lambda: _r.uniform(10000, 200000)) - def start_clock(self): + def start_clock(self) -> None: pass # In cPython 3, on every platform except Windows, perf_counter is # exactly the same as time.monotonic; and on Windows, it uses # QueryPerformanceCounter instead of GetTickCount64. - def current_time(self): + def current_time(self) -> float: return self.offset + perf_counter() - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: return deadline - self.current_time() @@ -1119,7 +1141,7 @@ class Task(metaclass=NoPublicConstructor): name = attr.ib() # PEP 567 contextvars context context = attr.ib() - _counter = attr.ib(init=False, factory=itertools.count().__next__) + _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None @@ -1293,7 +1315,7 @@ class RunContext(threading.local): task: Task -GLOBAL_RUN_CONTEXT = RunContext() +GLOBAL_RUN_CONTEXT: FinalT = RunContext() @attr.s(frozen=True) @@ -1380,7 +1402,7 @@ class Runner: # Run-local values, see _local.py _locals = attr.ib(factory=dict) - runq = attr.ib(factory=deque) + runq: deque[Task] = attr.ib(factory=deque) tasks = attr.ib(factory=set) deadlines = attr.ib(factory=Deadlines) @@ -1957,8 +1979,8 @@ def run( *args, clock=None, instruments=(), - restrict_keyboard_interrupt_to_checkpoints=False, - strict_exception_groups=False, + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = False, ): """Run a Trio-flavored async function, and return the result. @@ -2063,11 +2085,11 @@ def start_guest_run( run_sync_soon_threadsafe, done_callback, run_sync_soon_not_threadsafe=None, - host_uses_signal_set_wakeup_fd=False, + host_uses_signal_set_wakeup_fd: bool = False, clock=None, instruments=(), - restrict_keyboard_interrupt_to_checkpoints=False, - strict_exception_groups=False, + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = False, ): """Start a "guest" run of Trio on top of some other "host" event loop. @@ -2147,14 +2169,19 @@ def my_done_callback(run_outcome): # 24 hours is arbitrary, but it avoids issues like people setting timeouts of # 10**20 and then getting integer overflows in the underlying system calls. -_MAX_TIMEOUT = 24 * 60 * 60 +_MAX_TIMEOUT: FinalT = 24 * 60 * 60 # Weird quirk: this is written as a generator in order to support "guest # mode", where our core event loop gets unrolled into a series of callbacks on # the host loop. If you're doing a regular trio.run then this gets run # straight through. -def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False): +def unrolled_run( + runner: Runner, + async_fn, + args, + host_uses_signal_set_wakeup_fd: bool = False, +): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True @@ -2173,7 +2200,7 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False): # here is our event loop: while runner.tasks: if runner.runq: - timeout = 0 + timeout: float = 0 else: deadline = runner.deadlines.next_deadline() timeout = runner.clock.deadline_to_sleep_time(deadline) @@ -2301,8 +2328,10 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False): # frame we always remove, because it's this function # catching it, and then in addition we remove however many # more Context.run adds. - tb = task_exc.__traceback__.tb_next - for _ in range(CONTEXT_RUN_TB_FRAMES): + tb = task_exc.__traceback__ + for _ in range(1 + CONTEXT_RUN_TB_FRAMES): + if tb is None: + break tb = tb.tb_next final_outcome = Error(task_exc.with_traceback(tb)) # Remove local refs so that e.g. cancelled coroutine locals @@ -2397,7 +2426,7 @@ def started(self, value=None): pass -TASK_STATUS_IGNORED = _TaskStatusIgnored() +TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored() def current_task(): @@ -2493,16 +2522,16 @@ async def checkpoint_if_cancelled(): if sys.platform == "win32": - from ._io_windows import WindowsIOManager as TheIOManager from ._generated_io_windows import * + from ._io_windows import WindowsIOManager as TheIOManager elif sys.platform == "linux" or (not TYPE_CHECKING and hasattr(select, "epoll")): - from ._io_epoll import EpollIOManager as TheIOManager from ._generated_io_epoll import * + from ._io_epoll import EpollIOManager as TheIOManager elif TYPE_CHECKING or hasattr(select, "kqueue"): - from ._io_kqueue import KqueueIOManager as TheIOManager from ._generated_io_kqueue import * + from ._io_kqueue import KqueueIOManager as TheIOManager else: # pragma: no cover raise NotImplementedError("unsupported platform") -from ._generated_run import * from ._generated_instrumentation import * +from ._generated_run import *