diff --git a/docs/source/conf.py b/docs/source/conf.py index cfac66576b..68a5a22a81 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,10 +62,22 @@ ("py:obj", "trio._abc.SendType"), ("py:obj", "trio._abc.T"), ("py:obj", "trio._abc.T_resource"), + ("py:class", "types.FrameType"), ] autodoc_inherit_docstrings = False default_role = "obj" +# These have incorrect __module__ set in stdlib and give the error +# `py:class reference target not found` +# Some of the nitpick_ignore's above can probably be fixed with this. +# See https://github.com/sphinx-doc/sphinx/issues/8315#issuecomment-751335798 +autodoc_type_aliases = { + # aliasing doesn't actually fix the warning for types.FrameType, but displaying + # "types.FrameType" is more helpful than just "frame" + "FrameType": "types.FrameType", +} + + # XX hack the RTD theme until # https://github.com/rtfd/sphinx_rtd_theme/pull/382 # is shipped (should be in the release after 0.2.4) diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 922ae4680e..4f4f4d62b9 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1096,6 +1096,8 @@ Broadcasting an event with :class:`Event` .. autoclass:: Event :members: +.. autoclass:: EventStatistics + :members: .. _channels: @@ -1456,6 +1458,16 @@ don't have any special access to Trio's internals.) .. autoclass:: Condition :members: +These primitives return statistics objects that can be inspected. + +.. autoclass:: CapacityLimiterStatistics + :members: + +.. autoclass:: LockStatistics + :members: + +.. autoclass:: ConditionStatistics + :members: .. _async-generators: diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 815cff2ddf..bacebff5ad 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -378,6 +378,8 @@ Wait queue abstraction :members: :undoc-members: +.. autoclass:: ParkingLotStatistics + :members: Low-level checkpoint functions ------------------------------ diff --git a/trio/__init__.py b/trio/__init__.py index 42b57e69c0..2b8810504b 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -78,9 +78,13 @@ from ._subprocess import Process as Process, run_process as run_process from ._sync import ( CapacityLimiter as CapacityLimiter, + CapacityLimiterStatistics as CapacityLimiterStatistics, Condition as Condition, + ConditionStatistics as ConditionStatistics, Event as Event, + EventStatistics as EventStatistics, Lock as Lock, + LockStatistics as LockStatistics, Semaphore as Semaphore, StrictFIFOLock as StrictFIFOLock, ) diff --git a/trio/_abc.py b/trio/_abc.py index a01812dae8..2a1721db13 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,8 +1,15 @@ +from __future__ import annotations + from abc import ABCMeta, abstractmethod -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar import trio +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import Self + # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a # __dict__ onto subclasses. @@ -12,7 +19,7 @@ class Clock(metaclass=ABCMeta): __slots__ = () @abstractmethod - def start_clock(self): + def start_clock(self) -> None: """Do any setup this clock might need. Called at the beginning of the run. @@ -20,7 +27,7 @@ def start_clock(self): """ @abstractmethod - def current_time(self): + def current_time(self) -> float: """Return the current time, according to this clock. This is used to implement functions like :func:`trio.current_time` and @@ -32,7 +39,7 @@ def current_time(self): """ @abstractmethod - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: """Compute the real time until the given deadline. This is called before we enter a system-specific wait function like @@ -225,7 +232,7 @@ class AsyncResource(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def aclose(self): + async def aclose(self) -> None: """Close this resource, possibly blocking. IMPORTANT: This method may block in order to perform a "graceful" @@ -253,10 +260,15 @@ async def aclose(self): """ - async def __aenter__(self): + async def __aenter__(self) -> Self: return self - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: await self.aclose() @@ -279,7 +291,7 @@ class SendStream(AsyncResource): __slots__ = () @abstractmethod - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Sends the given data through the stream, blocking if necessary. Args: @@ -305,7 +317,7 @@ async def send_all(self, data): """ @abstractmethod - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Block until it's possible that :meth:`send_all` might not block. This method may return early: it's possible that after it returns, @@ -385,7 +397,7 @@ class ReceiveStream(AsyncResource): __slots__ = () @abstractmethod - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """Wait until there is data available on this stream, and then return some of it. @@ -413,10 +425,10 @@ async def receive_some(self, max_bytes=None): """ - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> bytes | bytearray: data = await self.receive_some() if not data: raise StopAsyncIteration @@ -446,7 +458,7 @@ class HalfCloseableStream(Stream): __slots__ = () @abstractmethod - async def send_eof(self): + async def send_eof(self) -> None: """Send an end-of-file indication on this stream, if possible. The difference between :meth:`send_eof` and @@ -632,7 +644,7 @@ async def receive(self) -> ReceiveType: """ - def __aiter__(self): + def __aiter__(self) -> Self: return self async def __anext__(self) -> ReceiveType: diff --git a/trio/_channel.py b/trio/_channel.py index 2bdec5bd09..7c8ff4660d 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -243,8 +243,8 @@ def __enter__(self: SelfT) -> SelfT: def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.close() @@ -389,8 +389,8 @@ def __enter__(self: SelfT) -> SelfT: def __exit__( self, exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.close() diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index c2991a4048..abd58245e3 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -20,7 +20,7 @@ from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection from ._local import RunVar from ._mock_clock import MockClock -from ._parking_lot import ParkingLot +from ._parking_lot import ParkingLot, ParkingLotStatistics # Imports that always exist from ._run import ( diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index 0eb76b6356..fe35298631 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -62,7 +62,7 @@ class MockClock(Clock, metaclass=Final): """ - def __init__(self, rate=0.0, autojump_threshold=inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. @@ -77,17 +77,17 @@ def __init__(self, rate=0.0, autojump_threshold=inf): self.rate = rate self.autojump_threshold = autojump_threshold - def __repr__(self): + def __repr__(self) -> str: return "".format( self.current_time(), self._rate, id(self) ) @property - def rate(self): + def rate(self) -> float: return self._rate @rate.setter - def rate(self, new_rate): + def rate(self, new_rate: float) -> None: if new_rate < 0: raise ValueError("rate must be >= 0") else: @@ -98,11 +98,11 @@ def rate(self, new_rate): self._rate = float(new_rate) @property - def autojump_threshold(self): + def autojump_threshold(self) -> float: return self._autojump_threshold @autojump_threshold.setter - def autojump_threshold(self, new_autojump_threshold): + def autojump_threshold(self, new_autojump_threshold: float) -> None: self._autojump_threshold = float(new_autojump_threshold) self._try_resync_autojump_threshold() @@ -112,7 +112,7 @@ def autojump_threshold(self, new_autojump_threshold): # API. Discussion: # # https://github.com/python-trio/trio/issues/1587 - def _try_resync_autojump_threshold(self): + def _try_resync_autojump_threshold(self) -> None: try: runner = GLOBAL_RUN_CONTEXT.runner if runner.is_guest: @@ -124,24 +124,24 @@ def _try_resync_autojump_threshold(self): # Invoked by the run loop when runner.clock_autojump_threshold is # exceeded. - def _autojump(self): + def _autojump(self) -> None: statistics = _core.current_statistics() jump = statistics.seconds_to_next_deadline if 0 < jump < inf: self.jump(jump) - def _real_to_virtual(self, real): + def _real_to_virtual(self, real: float) -> float: real_offset = real - self._real_base virtual_offset = self._rate * real_offset return self._virtual_base + virtual_offset - def start_clock(self): + def start_clock(self) -> None: self._try_resync_autojump_threshold() - def current_time(self): + def current_time(self) -> float: return self._real_to_virtual(self._real_clock()) - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: virtual_timeout = deadline - self.current_time() if virtual_timeout <= 0: return 0 @@ -150,7 +150,7 @@ def deadline_to_sleep_time(self, deadline): else: return 999999999 - def jump(self, seconds): + def jump(self, seconds) -> None: """Manually advance the clock by the given number of seconds. Args: diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 9e69928162..3c6ebb789f 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import sys import warnings +from typing import TYPE_CHECKING import attr @@ -10,6 +13,8 @@ else: from traceback import print_exception +if TYPE_CHECKING: + from types import TracebackType ################################################################ # MultiError ################################################################ @@ -130,11 +135,16 @@ class MultiErrorCatcher: def __enter__(self): pass - def __exit__(self, etype, exc, tb): - if exc is not None: - filtered_exc = _filter_impl(self._handler, exc) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + if exc_value is not None: + filtered_exc = _filter_impl(self._handler, exc_value) - if filtered_exc is exc: + if filtered_exc is exc_value: # Let the interpreter re-raise it return False if filtered_exc is None: @@ -154,6 +164,7 @@ def __exit__(self, etype, exc, tb): # delete references from locals to avoid creating cycles # see test_MultiError_catch_doesnt_create_cyclic_garbage del _, filtered_exc, value + return False class MultiError(BaseExceptionGroup): diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 69882c787b..74708433da 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -69,18 +69,34 @@ # unpark is called. # # See: https://github.com/python-trio/trio/issues/53 +from __future__ import annotations +import math from collections import OrderedDict +from collections.abc import Iterator +from typing import TYPE_CHECKING import attr from .. import _core from .._util import Final +if TYPE_CHECKING: + from ._run import Task + @attr.s(frozen=True, slots=True) -class _ParkingLotStatistics: - tasks_waiting = attr.ib() +class ParkingLotStatistics: + """An object containing debugging information for a ParkingLot. + + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this lot's + :meth:`trio.lowlevel.ParkingLot.park` method. + + """ + + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False, slots=True) @@ -99,13 +115,13 @@ class ParkingLot(metaclass=Final): # {task: None}, we just want a deque where we can quickly delete random # items - _parked = attr.ib(factory=OrderedDict, init=False) + _parked: OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False) - def __len__(self): + def __len__(self) -> int: """Returns the number of parked tasks.""" return len(self._parked) - def __bool__(self): + def __bool__(self) -> bool: """True if there are parked tasks, False otherwise.""" return bool(self._parked) @@ -114,7 +130,7 @@ def __bool__(self): # line (for false wakeups), then we could have it return a ticket that # abstracts the "place in line" concept. @_core.enable_ki_protection - async def park(self): + async def park(self) -> None: """Park the current task until woken by a call to :meth:`unpark` or :meth:`unpark_all`. @@ -129,13 +145,20 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) - def _pop_several(self, count): - for _ in range(min(count, len(self._parked))): + def _pop_several(self, count: int | float) -> Iterator[Task]: + if isinstance(count, float): + if math.isinf(count): + count = len(self._parked) + else: + raise ValueError("Cannot pop a non-integer number of tasks.") + else: + count = min(count, len(self._parked)) + for _ in range(count): task, _ = self._parked.popitem(last=False) yield task @_core.enable_ki_protection - def unpark(self, *, count=1): + def unpark(self, *, count: int | float = 1) -> list[Task]: """Unpark one or more tasks. This wakes up ``count`` tasks that are blocked in :meth:`park`. If @@ -143,7 +166,7 @@ def unpark(self, *, count=1): are available and then returns successfully. Args: - count (int): the number of tasks to unpark. + count (int | math.inf): the number of tasks to unpark. """ tasks = list(self._pop_several(count)) @@ -151,12 +174,12 @@ def unpark(self, *, count=1): _core.reschedule(task) return tasks - def unpark_all(self): + def unpark_all(self) -> list[Task]: """Unpark all parked tasks.""" return self.unpark(count=len(self)) @_core.enable_ki_protection - def repark(self, new_lot, *, count=1): + def repark(self, new_lot: ParkingLot, *, count: int | float = 1) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. This dequeues ``count`` tasks from one lot, and requeues them on @@ -186,7 +209,7 @@ async def main(): Args: new_lot (ParkingLot): the parking lot to move tasks to. - count (int): the number of tasks to move. + count (int|math.inf): the number of tasks to move. """ if not isinstance(new_lot, ParkingLot): @@ -195,7 +218,7 @@ async def main(): new_lot._parked[task] = None task.custom_sleep_data = new_lot - def repark_all(self, new_lot): + def repark_all(self, new_lot: ParkingLot) -> None: """Move all parked tasks from one :class:`ParkingLot` object to another. @@ -204,7 +227,7 @@ def repark_all(self, new_lot): """ return self.repark(new_lot, count=len(self)) - def statistics(self): + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -213,4 +236,4 @@ def statistics(self): :meth:`park` method. """ - return _ParkingLotStatistics(tasks_waiting=len(self._parked)) + return ParkingLotStatistics(tasks_waiting=len(self._parked)) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 0b6d326546..585dc4aa41 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -10,7 +10,7 @@ import threading import warnings from collections import deque -from collections.abc import Callable, Iterator +from collections.abc import Callable, Coroutine, Iterator from contextlib import AbstractAsyncContextManager, contextmanager from contextvars import copy_context from heapq import heapify, heappop, heappush @@ -45,7 +45,11 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup +from types import FrameType + if TYPE_CHECKING: + import contextvars + # An unfortunate name collision here with trio._util.Final from typing_extensions import Final as FinalT @@ -272,7 +276,7 @@ class CancelStatus: # Our associated cancel scope. Can be any object with attributes # `deadline`, `shield`, and `cancel_called`, but in current usage # is always a CancelScope object. Must not be None. - _scope = attr.ib() + _scope: CancelScope = attr.ib() # True iff the tasks in self._tasks should receive cancellations # when they checkpoint. Always True when scope.cancel_called is True; @@ -282,31 +286,31 @@ class CancelStatus: # effectively cancelled due to the cancel scope two levels out # becoming cancelled, but then the cancel scope one level out # becomes shielded so we're not effectively cancelled anymore. - effectively_cancelled = attr.ib(default=False) + effectively_cancelled: bool = attr.ib(default=False) # The CancelStatus whose cancellations can propagate to us; we # become effectively cancelled when they do, unless scope.shield # is True. May be None (for the outermost CancelStatus in a call # to trio.run(), briefly during TaskStatus.started(), or during # recovery from mis-nesting of cancel scopes). - _parent = attr.ib(default=None, repr=False) + _parent: CancelStatus | None = attr.ib(default=None, repr=False) # All of the CancelStatuses that have this CancelStatus as their parent. - _children = attr.ib(factory=set, init=False, repr=False) + _children: set[CancelStatus] = attr.ib(factory=set, init=False, repr=False) # Tasks whose cancellation state is currently tied directly to # the cancellation state of this CancelStatus object. Don't modify # this directly; instead, use Task._activate_cancel_status(). # Invariant: all(task._cancel_status is self for task in self._tasks) - _tasks = attr.ib(factory=set, init=False, repr=False) + _tasks: set[Task] = attr.ib(factory=set, init=False, repr=False) # Set to True on still-active cancel statuses that are children # of a cancel status that's been closed. This is used to permit # recovery from mis-nested cancel scopes (well, at least enough # recovery to show a useful traceback). - abandoned_by_misnesting = attr.ib(default=False, init=False, repr=False) + abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self._parent is not None: self._parent._children.add(self) self.recalculate() @@ -314,11 +318,11 @@ def __attrs_post_init__(self): # parent/children/tasks accessors are used by TaskStatus.started() @property - def parent(self): + def parent(self) -> CancelStatus | None: return self._parent @parent.setter - def parent(self, parent): + def parent(self, parent: CancelStatus) -> None: if self._parent is not None: self._parent._children.remove(self) self._parent = parent @@ -327,14 +331,14 @@ def parent(self, parent): self.recalculate() @property - def children(self): + def children(self) -> frozenset[CancelStatus]: return frozenset(self._children) @property - def tasks(self): + def tasks(self) -> frozenset[Task]: return frozenset(self._tasks) - def encloses(self, other): + def encloses(self, other: CancelStatus | None) -> bool: """Returns true if this cancel status is a direct or indirect parent of cancel status *other*, or if *other* is *self*. """ @@ -344,7 +348,7 @@ def encloses(self, other): other = other.parent return False - def close(self): + def close(self) -> None: self.parent = None # now we're not a child of self.parent anymore if self._tasks or self._children: # Cancel scopes weren't exited in opposite order of being @@ -404,7 +408,7 @@ def _mark_abandoned(self): for child in self._children: child._mark_abandoned() - def effective_deadline(self): + def effective_deadline(self) -> float: if self.effectively_cancelled: return -inf if self._parent is None or self._scope.shield: @@ -852,10 +856,10 @@ class NurseryManager: """ - strict_exception_groups = attr.ib(default=False) + strict_exception_groups: bool = attr.ib(default=False) @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self) -> Nursery: self._scope = CancelScope() self._scope.__enter__() self._nursery = Nursery._create( @@ -864,7 +868,12 @@ async def __aenter__(self): return self._nursery @enable_ki_protection - async def __aexit__(self, etype, exc, tb): + async def __aexit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: new_exc = await self._nursery._nested_child_finished(exc) # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context. @@ -887,13 +896,21 @@ async def __aexit__(self, etype, exc, tb): # see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage del _, combined_error_from_nursery, value, new_exc - def __enter__(self): - raise RuntimeError( - "use 'async with open_nursery(...)', not 'with open_nursery(...)'" - ) + # make sure these raise errors in static analysis if called + if not TYPE_CHECKING: + + def __enter__(self) -> NoReturn: + raise RuntimeError( + "use 'async with open_nursery(...)', not 'with open_nursery(...)'" + ) - def __exit__(self): # pragma: no cover - assert False, """Never called, but should be defined""" + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> NoReturn: # pragma: no cover + raise AssertionError("Never called, but should be defined") def open_nursery( @@ -939,7 +956,12 @@ class Nursery(metaclass=NoPublicConstructor): in response to some external event. """ - def __init__(self, parent_task, cancel_scope, strict_exception_groups): + def __init__( + self, + parent_task: Task, + cancel_scope: CancelScope, + strict_exception_groups: bool, + ): self._parent_task = parent_task self._strict_exception_groups = strict_exception_groups parent_task._child_nurseries.append(self) @@ -950,8 +972,8 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups): # children. self.cancel_scope = cancel_scope assert self.cancel_scope._cancel_status is self._cancel_status - self._children = set() - self._pending_excs = [] + self._children: set[Task] = set() + self._pending_excs: list[BaseException] = [] # The "nested child" is how this code refers to the contents of the # nursery's 'async with' block, which acts like a child Task in all # the ways we can make it. @@ -961,17 +983,17 @@ def __init__(self, parent_task, cancel_scope, strict_exception_groups): self._closed = False @property - def child_tasks(self): + def child_tasks(self) -> frozenset[Task]: """(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task` objects which are still running.""" return frozenset(self._children) @property - def parent_task(self): + def parent_task(self) -> Task: "(`~trio.lowlevel.Task`): The Task that opened this nursery." return self._parent_task - def _add_exc(self, exc): + def _add_exc(self, exc: BaseException) -> None: self._pending_excs.append(exc) self.cancel_scope.cancel() @@ -1133,7 +1155,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): self._pending_starts -= 1 self._check_nursery_closed() - def __del__(self): + def __del__(self) -> None: assert not self._children @@ -1144,12 +1166,11 @@ def __del__(self): @attr.s(eq=False, hash=False, repr=False, slots=True) class Task(metaclass=NoPublicConstructor): - _parent_nursery = attr.ib() - coro = attr.ib() + _parent_nursery: Nursery | None = attr.ib() + coro: Coroutine[Any, Outcome[object], Any] = attr.ib() _runner = attr.ib() - name = attr.ib() - # PEP 567 contextvars context - context = attr.ib() + name: str = attr.ib() + context: contextvars.Context = attr.ib() _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: @@ -1165,24 +1186,26 @@ class Task(metaclass=NoPublicConstructor): # Tasks start out unscheduled. _next_send_fn = attr.ib(default=None) _next_send = attr.ib(default=None) - _abort_func = attr.ib(default=None) - custom_sleep_data = attr.ib(default=None) + _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib( + default=None + ) + custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() - _child_nurseries = attr.ib(factory=list) - _eventual_parent_nursery = attr.ib(default=None) + _child_nurseries: list[Nursery] = attr.ib(factory=list) + _eventual_parent_nursery: Nursery | None = attr.ib(default=None) # these are counts of how many cancel/schedule points this task has # executed, for assert{_no,}_checkpoints # XX maybe these should be exposed as part of a statistics() method? - _cancel_points = attr.ib(default=0) - _schedule_points = attr.ib(default=0) + _cancel_points: int = attr.ib(default=0) + _schedule_points: int = attr.ib(default=0) - def __repr__(self): + def __repr__(self) -> str: return f"" @property - def parent_nursery(self): + def parent_nursery(self) -> Nursery | None: """The nursery this task is inside (or None if this is the "init" task). @@ -1193,7 +1216,7 @@ def parent_nursery(self): return self._parent_nursery @property - def eventual_parent_nursery(self): + def eventual_parent_nursery(self) -> Nursery | None: """The nursery this task will be inside after it calls ``task_status.started()``. @@ -1205,7 +1228,7 @@ def eventual_parent_nursery(self): return self._eventual_parent_nursery @property - def child_nurseries(self): + def child_nurseries(self) -> list[Nursery]: """The nurseries this task contains. This is a list, with outer nurseries before inner nurseries. @@ -1213,7 +1236,7 @@ def child_nurseries(self): """ return list(self._child_nurseries) - def iter_await_frames(self): + def iter_await_frames(self) -> Iterator[tuple[FrameType, int]]: """Iterates recursively over the coroutine-like objects this task is waiting on, yielding the frame and line number at each frame. @@ -1233,7 +1256,8 @@ def print_stack_for_task(task): print("".join(ss.format())) """ - coro = self.coro + # ignore static typing as we're doing lots of dynamic introspection + coro: Any = self.coro while coro is not None: if hasattr(coro, "cr_frame"): # A real coroutine @@ -1266,9 +1290,9 @@ def print_stack_for_task(task): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). - _cancel_status = attr.ib(default=None, repr=False) + _cancel_status: CancelStatus = attr.ib(default=None, repr=False) - def _activate_cancel_status(self, cancel_status): + def _activate_cancel_status(self, cancel_status: CancelStatus) -> None: if self._cancel_status is not None: self._cancel_status._tasks.remove(self) self._cancel_status = cancel_status @@ -1277,11 +1301,16 @@ def _activate_cancel_status(self, cancel_status): if self._cancel_status.effectively_cancelled: self._attempt_delivery_of_any_pending_cancel() - def _attempt_abort(self, raise_cancel): + def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None: # Either the abort succeeds, in which case we will reschedule the # task, or else it fails, in which case it will worry about # rescheduling itself (hopefully eventually calling reraise to raise # the given exception, but not necessarily). + + # This is only called by the functions immediately below, which both check + # `self.abort_func is not None`. + assert self._abort_func is not None, "FATAL INTERNAL ERROR" + success = self._abort_func(raise_cancel) if type(success) is not Abort: raise TrioInternalError("abort function must return Abort enum") @@ -1291,7 +1320,7 @@ def _attempt_abort(self, raise_cancel): if success is Abort.SUCCEEDED: self._runner.reschedule(self, capture(raise_cancel)) - def _attempt_delivery_of_any_pending_cancel(self): + def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return if not self._cancel_status.effectively_cancelled: @@ -1302,12 +1331,12 @@ def raise_cancel(): self._attempt_abort(raise_cancel) - def _attempt_delivery_of_pending_ki(self): + def _attempt_delivery_of_pending_ki(self) -> None: assert self._runner.ki_pending if self._abort_func is None: return - def raise_cancel(): + def raise_cancel() -> NoReturn: self._runner.ki_pending = False raise KeyboardInterrupt @@ -2433,17 +2462,17 @@ def unrolled_run( class _TaskStatusIgnored: - def __repr__(self): + def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value=None): + def started(self, value: object = None) -> None: pass TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored() -def current_task(): +def current_task() -> Task: """Return the :class:`Task` object representing the current task. Returns: @@ -2457,7 +2486,7 @@ def current_task(): raise RuntimeError("must be called from async context") from None -def current_effective_deadline(): +def current_effective_deadline() -> float: """Returns the current effective deadline for the current task. This function examines all the cancellation scopes that are currently in @@ -2484,7 +2513,7 @@ def current_effective_deadline(): return current_task()._cancel_status.effective_deadline() -async def checkpoint(): +async def checkpoint() -> None: """A pure :ref:`checkpoint `. This checks for cancellation and allows other tasks to be scheduled, @@ -2511,7 +2540,7 @@ async def checkpoint(): await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """Issue a :ref:`checkpoint ` if the calling context has been cancelled. diff --git a/trio/_core/_tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py index db3fc76709..3f03fdbade 100644 --- a/trio/_core/_tests/test_parking_lot.py +++ b/trio/_core/_tests/test_parking_lot.py @@ -72,6 +72,9 @@ async def waiter(i, lot): ) lot.unpark_all() + with pytest.raises(ValueError): + lot.unpark(count=1.5) + async def cancellable_waiter(name, lot, scopes, record): with _core.CancelScope() as scope: diff --git a/trio/_dtls.py b/trio/_dtls.py index f46fc4fda0..722a9499f8 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -6,6 +6,8 @@ # Hopefully they fix this before implementing DTLS 1.3, because it's a very different # protocol, and it's probably impossible to pull tricks like we do here. +from __future__ import annotations + import enum import errno import hmac @@ -14,12 +16,16 @@ import warnings import weakref from itertools import count +from typing import TYPE_CHECKING import attr import trio from trio._util import Final, NoPublicConstructor +if TYPE_CHECKING: + from types import TracebackType + MAX_UDP_PACKET_SIZE = 65527 @@ -809,7 +815,7 @@ def _check_replaced(self): # DTLS where packets are all independent and can be lost anyway. We do at least need # to handle receiving it properly though, which might be easier if we send it... - def close(self): + def close(self) -> None: """Close this connection. `DTLSChannel`\\s don't actually own any OS-level resources – the @@ -833,8 +839,13 @@ def close(self): def __enter__(self): return self - def __exit__(self, *args): - self.close() + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.close() async def aclose(self): """Close this connection, but asynchronously. @@ -1121,6 +1132,8 @@ def __init__(self, socket, *, incoming_packets_buffer=10): global SSL from OpenSSL import SSL + # TODO: create a `self._initialized` for `__del__`, so self.socket can be typed + # as trio.socket.SocketType and `is not None` checks can be removed. self.socket = None # for __del__, in case the next line raises if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") @@ -1167,12 +1180,16 @@ def __del__(self): f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self ) - def close(self): + def close(self) -> None: """Close this socket, and all associated DTLS connections. This object can also be used as a context manager. """ + # Do nothing if this object was never fully constructed + if self.socket is None: # pragma: no cover + return + self._closed = True self.socket.close() for stream in list(self._streams.values()): @@ -1182,8 +1199,13 @@ def close(self): def __enter__(self): return self - def __exit__(self, *args): - self.close() + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.close() def _check_closed(self): if self._closed: diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index 2ae381c8e2..e1ac378c6a 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,12 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import attr import trio from trio._util import Final +if TYPE_CHECKING: + from .abc import SendStream, ReceiveStream, AsyncResource + from .abc import HalfCloseableStream -async def aclose_forcefully(resource): +async def aclose_forcefully(resource: AsyncResource) -> None: """Close an async resource or async generator immediately, without blocking to do any graceful cleanup. @@ -72,18 +79,18 @@ class StapledStream(HalfCloseableStream, metaclass=Final): """ - send_stream = attr.ib() - receive_stream = attr.ib() + send_stream: SendStream = attr.ib() + receive_stream: ReceiveStream = attr.ib() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Calls ``self.send_stream.send_all``.""" return await self.send_stream.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls ``self.send_stream.wait_send_all_might_not_block``.""" return await self.send_stream.wait_send_all_might_not_block() - async def send_eof(self): + async def send_eof(self) -> None: """Shuts down the send side of the stream. If ``self.send_stream.send_eof`` exists, then calls it. Otherwise, @@ -91,15 +98,18 @@ async def send_eof(self): """ if hasattr(self.send_stream, "send_eof"): - return await self.send_stream.send_eof() + # send_stream.send_eof() is not defined in Trio, this should maybe be + # redesigned so it's possible to type it. + return await self.send_stream.send_eof() # type: ignore[no-any-return] else: return await self.send_stream.aclose() - async def receive_some(self, max_bytes=None): + # we intentionally accept more types from the caller than we support returning + async def receive_some(self, max_bytes: int | None = None) -> bytes: """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) - async def aclose(self): + async def aclose(self) -> None: """Calls ``aclose`` on both underlying streams.""" try: await self.send_stream.aclose() diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index ce23de17d7..ce96153805 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -1,7 +1,9 @@ # "High-level" networking interface +from __future__ import annotations import errno from contextlib import contextmanager +from typing import TYPE_CHECKING import trio @@ -9,6 +11,9 @@ from ._util import ConflictDetector, Final from .abc import HalfCloseableStream, Listener +if TYPE_CHECKING: + from ._socket import _SocketType as SocketType + # XX TODO: this number was picked arbitrarily. We should do experiments to # tune it. (Or make it dynamic -- one idea is to start small and increase it # if we observe single reads filling up the whole buffer, at least within some @@ -58,7 +63,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -109,14 +114,14 @@ async def send_all(self, data): sent = await self.socket.send(remaining) total_sent += sent - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self.socket.fileno() == -1: raise trio.ClosedResourceError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() - async def send_eof(self): + async def send_eof(self) -> None: with self._send_conflict_detector: await trio.lowlevel.checkpoint() # On macOS, calling shutdown a second time raises ENOTCONN, but @@ -126,7 +131,7 @@ async def send_eof(self): with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: @@ -134,7 +139,7 @@ async def receive_some(self, max_bytes=None): with _translate_socket_errors_to_stream_errors(): return await self.socket.recv(max_bytes) - async def aclose(self): + async def aclose(self) -> None: self.socket.close() await trio.lowlevel.checkpoint() @@ -331,7 +336,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -347,7 +352,7 @@ def __init__(self, socket): self.socket = socket - async def accept(self): + async def accept(self) -> SocketStream: """Accept an incoming connection. Returns: @@ -375,7 +380,7 @@ async def accept(self): else: return SocketStream(sock) - async def aclose(self): + async def aclose(self) -> None: """Close this listener and its underlying socket.""" self.socket.close() await trio.lowlevel.checkpoint() diff --git a/trio/_socket.py b/trio/_socket.py index b4ee4a7199..eaf0e04d15 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import select import socket as _stdlib_socket @@ -11,6 +13,12 @@ from . import _core +if TYPE_CHECKING: + from collections.abc import Iterable + from types import TracebackType + + from typing_extensions import Self + # Usage: # @@ -33,8 +41,13 @@ def _is_blocking_io_error(self, exc): async def __aenter__(self): await trio.lowlevel.checkpoint_if_cancelled() - async def __aexit__(self, etype, value, tb): - if value is not None and self._is_blocking_io_error(value): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + if exc_value is not None and self._is_blocking_io_error(exc_value): # Discard the exception and fall through to the code below the # block return True @@ -430,7 +443,7 @@ def __init__(self): class _SocketType(SocketType): - def __init__(self, sock): + def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. @@ -473,44 +486,49 @@ def __getattr__(self, name): return getattr(self._sock, name) raise AttributeError(name) - def __dir__(self): - return super().__dir__() + list(self._forward) + def __dir__(self) -> Iterable[str]: + return [*super().__dir__(), *self._forward] - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, *exc_info): - return self._sock.__exit__(*exc_info) + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self._sock.__exit__(exc_type, exc_value, traceback) @property - def family(self): + def family(self) -> _stdlib_socket.AddressFamily: return self._sock.family @property - def type(self): + def type(self) -> _stdlib_socket.SocketKind: return self._sock.type @property - def proto(self): + def proto(self) -> int: return self._sock.proto @property - def did_shutdown_SHUT_WR(self): + def did_shutdown_SHUT_WR(self) -> bool: return self._did_shutdown_SHUT_WR - def __repr__(self): + def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self): + def dup(self) -> _SocketType: """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) - def close(self): + def close(self) -> None: if self._sock.fileno() != -1: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address): + async def bind(self, address: tuple[object, ...] | str | bytes) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -519,7 +537,8 @@ async def bind(self, address): ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) - return await trio.to_thread.run_sync(self._sock.bind, address) + # remove the `type: ignore` when run.sync is typed. + return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return] else: # POSIX actually says that bind can return EWOULDBLOCK and # complete asynchronously, like connect. But in practice AFAICT @@ -528,14 +547,14 @@ async def bind(self, address): await trio.lowlevel.checkpoint() return self._sock.bind(address) - def shutdown(self, flag): + def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: self._sock.shutdown(flag) # only do this if the call succeeded: if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: self._did_shutdown_SHUT_WR = True - def is_readable(self): + def is_readable(self) -> bool: # use select.select on Windows, and select.poll everywhere else if sys.platform == "win32": rready, _, _ = select.select([self._sock], [], [], 0) @@ -544,7 +563,7 @@ def is_readable(self): p.register(self._sock, select.POLLIN) return bool(p.poll(0)) - async def wait_writable(self): + async def wait_writable(self) -> None: await _core.wait_writable(self._sock) async def _resolve_address_nocp(self, address, *, local): @@ -684,7 +703,13 @@ async def connect(self, address): # recv ################################################################ - recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) + if TYPE_CHECKING: + + async def recv(self, buffersize: int, flags: int = 0) -> bytes: + ... + + else: + recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) ################################################################ # recv_into diff --git a/trio/_sync.py b/trio/_sync.py index 60d7074d9e..5a7f240d5e 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import math +from typing import TYPE_CHECKING import attr @@ -8,10 +11,25 @@ from ._core import ParkingLot, enable_ki_protection from ._util import Final +if TYPE_CHECKING: + from types import TracebackType + + from ._core import Task + from ._core._parking_lot import ParkingLotStatistics + + +@attr.s(frozen=True, slots=True) +class EventStatistics: + """An object containing debugging information. + + Currently the following fields are defined: -@attr.s(frozen=True) -class _EventStatistics: - tasks_waiting = attr.ib() + * ``tasks_waiting``: The number of tasks blocked on this event's + :meth:`trio.Event.wait` method. + + """ + + tasks_waiting: int = attr.ib() @attr.s(repr=False, eq=False, hash=False, slots=True) @@ -41,15 +59,15 @@ class Event(metaclass=Final): """ - _tasks = attr.ib(factory=set, init=False) - _flag = attr.ib(default=False, init=False) + _tasks: set[Task] = attr.ib(factory=set, init=False) + _flag: bool = attr.ib(default=False, init=False) - def is_set(self): + def is_set(self) -> bool: """Return the current value of the internal flag.""" return self._flag @enable_ki_protection - def set(self): + def set(self) -> None: """Set the internal flag value to True, and wake any waiting tasks.""" if not self._flag: self._flag = True @@ -57,7 +75,7 @@ def set(self): _core.reschedule(task) self._tasks.clear() - async def wait(self): + async def wait(self) -> None: """Block until the internal flag value becomes True. If it's already True, then this method returns immediately. @@ -75,7 +93,7 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) - def statistics(self): + def statistics(self) -> EventStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -84,25 +102,49 @@ def statistics(self): :meth:`wait` method. """ - return _EventStatistics(tasks_waiting=len(self._tasks)) + return EventStatistics(tasks_waiting=len(self._tasks)) +# TODO: type this with a Protocol to get rid of type: ignore, see +# https://github.com/python-trio/trio/pull/2682#discussion_r1259097422 class AsyncContextManagerMixin: @enable_ki_protection - async def __aenter__(self): - await self.acquire() + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] @enable_ki_protection - async def __aexit__(self, *args): - self.release() + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.release() # type: ignore[attr-defined] + + +@attr.s(frozen=True, slots=True) +class CapacityLimiterStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``borrowed_tokens``: The number of tokens currently borrowed from + the sack. + * ``total_tokens``: The total number of tokens in the sack. Usually + this will be larger than ``borrowed_tokens``, but it's possibly for + it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased. + * ``borrowers``: A list of all tasks or other entities that currently + hold a token. + * ``tasks_waiting``: The number of tasks blocked on this + :class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or + :meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods. + """ -@attr.s(frozen=True) -class _CapacityLimiterStatistics: - borrowed_tokens = attr.ib() - total_tokens = attr.ib() - borrowers = attr.ib() - tasks_waiting = attr.ib() + borrowed_tokens: int = attr.ib() + total_tokens: int | float = attr.ib() + borrowers: list[Task] = attr.ib() + tasks_waiting: int = attr.ib() class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): @@ -159,22 +201,23 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, total_tokens): + # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing + def __init__(self, total_tokens: int | float): self._lot = ParkingLot() - self._borrowers = set() + self._borrowers: set[Task] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of - self._pending_borrowers = {} + self._pending_borrowers: dict[Task, Task] = {} # invoke the property setter for validation - self.total_tokens = total_tokens + self.total_tokens: int | float = total_tokens assert self._total_tokens == total_tokens - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), len(self._borrowers), self._total_tokens, len(self._lot) ) @property - def total_tokens(self): + def total_tokens(self) -> int | float: """The total capacity available. You can change :attr:`total_tokens` by assigning to this attribute. If @@ -189,7 +232,7 @@ def total_tokens(self): return self._total_tokens @total_tokens.setter - def total_tokens(self, new_total_tokens): + def total_tokens(self, new_total_tokens: int | float) -> None: if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf: raise TypeError("total_tokens must be an int or math.inf") if new_total_tokens < 1: @@ -197,23 +240,23 @@ def total_tokens(self, new_total_tokens): self._total_tokens = new_total_tokens self._wake_waiters() - def _wake_waiters(self): + def _wake_waiters(self) -> None: available = self._total_tokens - len(self._borrowers) for woken in self._lot.unpark(count=available): self._borrowers.add(self._pending_borrowers.pop(woken)) @property - def borrowed_tokens(self): + def borrowed_tokens(self) -> int: """The amount of capacity that's currently in use.""" return len(self._borrowers) @property - def available_tokens(self): + def available_tokens(self) -> int | float: """The amount of capacity that's available to use.""" return self.total_tokens - self.borrowed_tokens @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Borrow a token from the sack, without blocking. Raises: @@ -225,7 +268,7 @@ def acquire_nowait(self): self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower): + def acquire_on_behalf_of_nowait(self, borrower: Task) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. @@ -253,7 +296,7 @@ def acquire_on_behalf_of_nowait(self, borrower): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Borrow a token from the sack, blocking if necessary. Raises: @@ -264,7 +307,7 @@ async def acquire(self): await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower: Task) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -293,7 +336,7 @@ async def acquire_on_behalf_of(self, borrower): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Put a token back into the sack. Raises: @@ -304,7 +347,7 @@ def release(self): self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower: Task) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: @@ -319,7 +362,7 @@ def release_on_behalf_of(self, borrower): self._borrowers.remove(borrower) self._wake_waiters() - def statistics(self): + def statistics(self) -> CapacityLimiterStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -336,7 +379,7 @@ def statistics(self): :meth:`acquire_on_behalf_of` methods. """ - return _CapacityLimiterStatistics( + return CapacityLimiterStatistics( borrowed_tokens=len(self._borrowers), total_tokens=self._total_tokens, # Use a list instead of a frozenset just in case we start to allow @@ -373,7 +416,7 @@ class Semaphore(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, initial_value, *, max_value=None): + def __init__(self, initial_value: int, *, max_value: int | None = None): if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -391,7 +434,7 @@ def __init__(self, initial_value, *, max_value=None): self._value = initial_value self._max_value = max_value - def __repr__(self): + def __repr__(self) -> str: if self._max_value is None: max_value_str = "" else: @@ -401,17 +444,17 @@ def __repr__(self): ) @property - def value(self): + def value(self) -> int: """The current value of the semaphore.""" return self._value @property - def max_value(self): + def max_value(self) -> int | None: """The maximum allowed value. May be None to indicate no limit.""" return self._max_value @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to decrement the semaphore value, without blocking. Raises: @@ -425,7 +468,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Decrement the semaphore value, blocking if necessary to avoid letting it drop below zero. @@ -439,7 +482,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Increment the semaphore value, possibly waking a task blocked in :meth:`acquire`. @@ -456,7 +499,7 @@ def release(self): raise ValueError("semaphore released too many times") self._value += 1 - def statistics(self): + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -468,19 +511,31 @@ def statistics(self): return self._lot.statistics() -@attr.s(frozen=True) -class _LockStatistics: - locked = attr.ib() - owner = attr.ib() - tasks_waiting = attr.ib() +@attr.s(frozen=True, slots=True) +class LockStatistics: + """An object containing debugging information for a Lock. + + Currently the following fields are defined: + + * ``locked`` (boolean): indicating whether the lock is held. + * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, + or None if the lock is not held. + * ``tasks_waiting`` (int): The number of tasks blocked on this lock's + :meth:`trio.Lock.acquire` method. + + """ + + locked: bool = attr.ib() + owner: Task | None = attr.ib() + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False, repr=False) class _LockImpl(AsyncContextManagerMixin): - _lot = attr.ib(factory=ParkingLot, init=False) - _owner = attr.ib(default=None, init=False) + _lot: ParkingLot = attr.ib(factory=ParkingLot, init=False) + _owner: Task | None = attr.ib(default=None, init=False) - def __repr__(self): + def __repr__(self) -> str: if self.locked(): s1 = "locked" s2 = f" with {len(self._lot)} waiters" @@ -491,7 +546,7 @@ def __repr__(self): s1, self.__class__.__name__, id(self), s2 ) - def locked(self): + def locked(self) -> bool: """Check whether the lock is currently held. Returns: @@ -501,7 +556,7 @@ def locked(self): return self._owner is not None @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the lock, without blocking. Raises: @@ -519,7 +574,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Acquire the lock, blocking if necessary.""" await trio.lowlevel.checkpoint_if_cancelled() try: @@ -533,7 +588,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Release the lock. Raises: @@ -548,7 +603,7 @@ def release(self): else: self._owner = None - def statistics(self): + def statistics(self) -> LockStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -560,7 +615,7 @@ def statistics(self): :meth:`acquire` method. """ - return _LockStatistics( + return LockStatistics( locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot) ) @@ -642,10 +697,20 @@ class StrictFIFOLock(_LockImpl, metaclass=Final): """ -@attr.s(frozen=True) -class _ConditionStatistics: - tasks_waiting = attr.ib() - lock_statistics = attr.ib() +@attr.s(frozen=True, slots=True) +class ConditionStatistics: + r"""An object containing debugging information for a Condition. + + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this condition's + :meth:`trio.Condition.wait` method. + * ``lock_statistics``: The result of calling the underlying + :class:`Lock`\s :meth:`~Lock.statistics` method. + + """ + tasks_waiting: int = attr.ib() + lock_statistics: LockStatistics = attr.ib() class Condition(AsyncContextManagerMixin, metaclass=Final): @@ -663,7 +728,7 @@ class Condition(AsyncContextManagerMixin, metaclass=Final): """ - def __init__(self, lock=None): + def __init__(self, lock: Lock | None = None): if lock is None: lock = Lock() if not type(lock) is Lock: @@ -671,7 +736,7 @@ def __init__(self, lock=None): self._lock = lock self._lot = trio.lowlevel.ParkingLot() - def locked(self): + def locked(self) -> bool: """Check whether the underlying lock is currently held. Returns: @@ -680,7 +745,7 @@ def locked(self): """ return self._lock.locked() - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the underlying lock, without blocking. Raises: @@ -689,16 +754,16 @@ def acquire_nowait(self): """ return self._lock.acquire_nowait() - async def acquire(self): + async def acquire(self) -> None: """Acquire the underlying lock, blocking if necessary.""" await self._lock.acquire() - def release(self): + def release(self) -> None: """Release the underlying lock.""" self._lock.release() @enable_ki_protection - async def wait(self): + async def wait(self) -> None: """Wait for another task to call :meth:`notify` or :meth:`notify_all`. @@ -733,7 +798,7 @@ async def wait(self): await self.acquire() raise - def notify(self, n=1): + def notify(self, n: int = 1) -> None: """Wake one or more tasks that are blocked in :meth:`wait`. Args: @@ -747,7 +812,7 @@ def notify(self, n=1): raise RuntimeError("must hold the lock to notify") self._lot.repark(self._lock._lot, count=n) - def notify_all(self): + def notify_all(self) -> None: """Wake all tasks that are currently blocked in :meth:`wait`. Raises: @@ -758,7 +823,7 @@ def notify_all(self): raise RuntimeError("must hold the lock to notify") self._lot.repark_all(self._lock._lot) - def statistics(self): + def statistics(self) -> ConditionStatistics: r"""Return an object containing debugging information. Currently the following fields are defined: @@ -769,6 +834,6 @@ def statistics(self): :class:`Lock`\s :meth:`~Lock.statistics` method. """ - return _ConditionStatistics( + return ConditionStatistics( tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics() ) diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index 3ab0016386..e51bbe31f5 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -492,4 +492,8 @@ def test_classes_are_final(): continue # ... insert other special cases here ... + # don't care about the *Statistics classes + if name.endswith("Statistics"): + continue + assert isinstance(class_, _util.Final) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 57b307d1d9..9d7d7aa912 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8317152103559871, + "completenessScore": 0.8764044943820225, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 514, - "withUnknownType": 103 + "withKnownType": 546, + "withUnknownType": 76 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -45,22 +45,13 @@ } ], "otherSymbolCounts": { - "withAmbiguousType": 14, - "withKnownType": 244, - "withUnknownType": 224 + "withAmbiguousType": 8, + "withKnownType": 433, + "withUnknownType": 135 }, "packageName": "trio", "symbols": [ "trio.__deprecated_attributes__", - "trio._abc.AsyncResource.__aenter__", - "trio._abc.AsyncResource.__aexit__", - "trio._abc.AsyncResource.aclose", - "trio._abc.Channel", - "trio._abc.Clock.current_time", - "trio._abc.Clock.deadline_to_sleep_time", - "trio._abc.Clock.start_clock", - "trio._abc.HalfCloseableStream", - "trio._abc.HalfCloseableStream.send_eof", "trio._abc.HostnameResolver.getaddrinfo", "trio._abc.HostnameResolver.getnameinfo", "trio._abc.Instrument.after_io_wait", @@ -72,58 +63,16 @@ "trio._abc.Instrument.task_exited", "trio._abc.Instrument.task_scheduled", "trio._abc.Instrument.task_spawned", - "trio._abc.Listener", "trio._abc.Listener.accept", - "trio._abc.ReceiveChannel", - "trio._abc.ReceiveChannel.__aiter__", - "trio._abc.ReceiveStream", - "trio._abc.ReceiveStream.__aiter__", - "trio._abc.ReceiveStream.__anext__", - "trio._abc.ReceiveStream.receive_some", - "trio._abc.SendChannel", - "trio._abc.SendStream", - "trio._abc.SendStream.send_all", - "trio._abc.SendStream.wait_send_all_might_not_block", "trio._abc.SocketFactory.socket", - "trio._abc.Stream", - "trio._channel.MemoryReceiveChannel", - "trio._channel.MemorySendChannel", "trio._core._entry_queue.TrioToken.run_sync_soon", "trio._core._local.RunVar.__repr__", "trio._core._local.RunVar.get", "trio._core._local.RunVar.reset", "trio._core._local.RunVar.set", - "trio._core._mock_clock.MockClock", - "trio._core._mock_clock.MockClock.__init__", - "trio._core._mock_clock.MockClock.__repr__", - "trio._core._mock_clock.MockClock.autojump_threshold", - "trio._core._mock_clock.MockClock.current_time", - "trio._core._mock_clock.MockClock.deadline_to_sleep_time", "trio._core._mock_clock.MockClock.jump", - "trio._core._mock_clock.MockClock.rate", - "trio._core._mock_clock.MockClock.start_clock", - "trio._core._parking_lot.ParkingLot.__bool__", - "trio._core._parking_lot.ParkingLot.__len__", - "trio._core._parking_lot.ParkingLot.repark_all", - "trio._core._parking_lot.ParkingLot.statistics", - "trio._core._parking_lot.ParkingLot.unpark_all", - "trio._core._run.Nursery.__del__", - "trio._core._run.Nursery.__init__", - "trio._core._run.Nursery.child_tasks", - "trio._core._run.Nursery.parent_task", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", - "trio._core._run.Task.__repr__", - "trio._core._run.Task.child_nurseries", - "trio._core._run.Task.context", - "trio._core._run.Task.coro", - "trio._core._run.Task.custom_sleep_data", - "trio._core._run.Task.eventual_parent_nursery", - "trio._core._run.Task.iter_await_frames", - "trio._core._run.Task.name", - "trio._core._run.Task.parent_nursery", - "trio._core._run._TaskStatusIgnored.__repr__", - "trio._core._run._TaskStatusIgnored.started", "trio._core._unbounded_queue.UnboundedQueue.__aiter__", "trio._core._unbounded_queue.UnboundedQueue.__anext__", "trio._core._unbounded_queue.UnboundedQueue.__repr__", @@ -132,12 +81,9 @@ "trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait", "trio._core._unbounded_queue.UnboundedQueue.qsize", "trio._core._unbounded_queue.UnboundedQueue.statistics", - "trio._dtls.DTLSChannel", "trio._dtls.DTLSChannel.__enter__", - "trio._dtls.DTLSChannel.__exit__", "trio._dtls.DTLSChannel.__init__", "trio._dtls.DTLSChannel.aclose", - "trio._dtls.DTLSChannel.close", "trio._dtls.DTLSChannel.do_handshake", "trio._dtls.DTLSChannel.get_cleartext_mtu", "trio._dtls.DTLSChannel.receive", @@ -146,34 +92,17 @@ "trio._dtls.DTLSChannel.statistics", "trio._dtls.DTLSEndpoint.__del__", "trio._dtls.DTLSEndpoint.__enter__", - "trio._dtls.DTLSEndpoint.__exit__", "trio._dtls.DTLSEndpoint.__init__", - "trio._dtls.DTLSEndpoint.close", "trio._dtls.DTLSEndpoint.connect", "trio._dtls.DTLSEndpoint.incoming_packets_buffer", "trio._dtls.DTLSEndpoint.serve", "trio._dtls.DTLSEndpoint.socket", - "trio._highlevel_generic.StapledStream", - "trio._highlevel_generic.StapledStream.aclose", - "trio._highlevel_generic.StapledStream.receive_some", - "trio._highlevel_generic.StapledStream.receive_stream", - "trio._highlevel_generic.StapledStream.send_all", - "trio._highlevel_generic.StapledStream.send_eof", - "trio._highlevel_generic.StapledStream.send_stream", - "trio._highlevel_generic.StapledStream.wait_send_all_might_not_block", "trio._highlevel_socket.SocketListener", "trio._highlevel_socket.SocketListener.__init__", - "trio._highlevel_socket.SocketListener.accept", - "trio._highlevel_socket.SocketListener.aclose", - "trio._highlevel_socket.SocketStream", "trio._highlevel_socket.SocketStream.__init__", - "trio._highlevel_socket.SocketStream.aclose", "trio._highlevel_socket.SocketStream.getsockopt", - "trio._highlevel_socket.SocketStream.receive_some", "trio._highlevel_socket.SocketStream.send_all", - "trio._highlevel_socket.SocketStream.send_eof", "trio._highlevel_socket.SocketStream.setsockopt", - "trio._highlevel_socket.SocketStream.wait_send_all_might_not_block", "trio._path.AsyncAutoWrapperType.__init__", "trio._path.AsyncAutoWrapperType.generate_forwards", "trio._path.AsyncAutoWrapperType.generate_iter", @@ -188,11 +117,21 @@ "trio._path.Path.__rtruediv__", "trio._path.Path.__truediv__", "trio._path.Path.open", + "trio._socket._SocketType.__getattr__", + "trio._socket._SocketType.accept", + "trio._socket._SocketType.connect", + "trio._socket._SocketType.recv_into", + "trio._socket._SocketType.recvfrom", + "trio._socket._SocketType.recvfrom_into", + "trio._socket._SocketType.recvmsg", + "trio._socket._SocketType.recvmsg_into", + "trio._socket._SocketType.send", + "trio._socket._SocketType.sendmsg", + "trio._socket._SocketType.sendto", "trio._ssl.SSLListener", "trio._ssl.SSLListener.__init__", "trio._ssl.SSLListener.accept", "trio._ssl.SSLListener.aclose", - "trio._ssl.SSLStream", "trio._ssl.SSLStream.__dir__", "trio._ssl.SSLStream.__getattr__", "trio._ssl.SSLStream.__init__", @@ -204,7 +143,6 @@ "trio._ssl.SSLStream.transport_stream", "trio._ssl.SSLStream.unwrap", "trio._ssl.SSLStream.wait_send_all_might_not_block", - "trio._subprocess.Process", "trio._subprocess.Process.__aenter__", "trio._subprocess.Process.__init__", "trio._subprocess.Process.__repr__", @@ -219,47 +157,14 @@ "trio._subprocess.Process.send_signal", "trio._subprocess.Process.terminate", "trio._subprocess.Process.wait", - "trio._sync.CapacityLimiter.__init__", - "trio._sync.CapacityLimiter.__repr__", - "trio._sync.CapacityLimiter.available_tokens", - "trio._sync.CapacityLimiter.borrowed_tokens", - "trio._sync.CapacityLimiter.statistics", - "trio._sync.CapacityLimiter.total_tokens", - "trio._sync.Condition.__init__", - "trio._sync.Condition.acquire", - "trio._sync.Condition.acquire_nowait", - "trio._sync.Condition.locked", - "trio._sync.Condition.notify", - "trio._sync.Condition.notify_all", - "trio._sync.Condition.release", - "trio._sync.Condition.statistics", - "trio._sync.Event.is_set", - "trio._sync.Event.statistics", - "trio._sync.Event.wait", - "trio._sync.Lock", - "trio._sync.Semaphore.__init__", - "trio._sync.Semaphore.__repr__", - "trio._sync.Semaphore.max_value", - "trio._sync.Semaphore.statistics", - "trio._sync.Semaphore.value", - "trio._sync.StrictFIFOLock", - "trio._sync._LockImpl.__repr__", - "trio._sync._LockImpl.locked", - "trio._sync._LockImpl.statistics", - "trio._unix_pipes.FdStream", - "trio.aclose_forcefully", - "trio.current_effective_deadline", "trio.current_time", "trio.from_thread.run", "trio.from_thread.run_sync", "trio.lowlevel.add_instrument", "trio.lowlevel.cancel_shielded_checkpoint", - "trio.lowlevel.checkpoint", - "trio.lowlevel.checkpoint_if_cancelled", "trio.lowlevel.current_clock", "trio.lowlevel.current_root_task", "trio.lowlevel.current_statistics", - "trio.lowlevel.current_task", "trio.lowlevel.current_trio_token", "trio.lowlevel.currently_ki_protected", "trio.lowlevel.notify_closing", @@ -294,7 +199,6 @@ "trio.socket.set_custom_socket_factory", "trio.socket.socket", "trio.socket.socketpair", - "trio.testing._memory_streams.MemoryReceiveStream", "trio.testing._memory_streams.MemoryReceiveStream.__init__", "trio.testing._memory_streams.MemoryReceiveStream.aclose", "trio.testing._memory_streams.MemoryReceiveStream.close", @@ -303,7 +207,6 @@ "trio.testing._memory_streams.MemoryReceiveStream.put_eof", "trio.testing._memory_streams.MemoryReceiveStream.receive_some", "trio.testing._memory_streams.MemoryReceiveStream.receive_some_hook", - "trio.testing._memory_streams.MemorySendStream", "trio.testing._memory_streams.MemorySendStream.__init__", "trio.testing._memory_streams.MemorySendStream.aclose", "trio.testing._memory_streams.MemorySendStream.close", diff --git a/trio/_util.py b/trio/_util.py index c21cefe71e..0a0795fc15 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -9,6 +9,7 @@ import typing as t from abc import ABCMeta from functools import update_wrapper +from types import TracebackType import trio @@ -188,7 +189,12 @@ def __enter__(self): else: self._held = True - def __exit__(self, *args): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self._held = False diff --git a/trio/lowlevel.py b/trio/lowlevel.py index db8d180181..54f4ef3141 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -11,6 +11,7 @@ from ._core import ( Abort as Abort, ParkingLot as ParkingLot, + ParkingLotStatistics as ParkingLotStatistics, RaiseCancelT as RaiseCancelT, RunVar as RunVar, Task as Task, diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 33d741e670..401b8ef0c2 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -1,13 +1,18 @@ # Generic stream tests +from __future__ import annotations import random from contextlib import contextmanager +from typing import TYPE_CHECKING from .. import _core from .._abc import HalfCloseableStream, ReceiveStream, SendStream, Stream from .._highlevel_generic import aclose_forcefully from ._checkpoints import assert_checkpoints +if TYPE_CHECKING: + from types import TracebackType + class _ForceCloseBoth: def __init__(self, both): @@ -16,7 +21,12 @@ def __init__(self, both): async def __aenter__(self): return self._both - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: try: await aclose_forcefully(self._both[0]) finally: diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index f2d40fb7ff..b3bdfd85c0 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -6,16 +6,21 @@ # - TCP # - UDP broadcast +from __future__ import annotations + import errno import ipaddress import os -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import attr import trio from trio._util import Final, NoPublicConstructor +if TYPE_CHECKING: + from types import TracebackType + IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -338,7 +343,12 @@ def setsockopt(self, level, item, value): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.close() async def send(self, data, flags=0):