From 8dfdd56e88ff4578a4a2813fc0216e9fca907622 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 11 Jul 2023 15:07:42 +0200 Subject: [PATCH 01/49] various small type fixes, disallow_incomplete_defs = true, except for trio._core._run --- pyproject.toml | 4 ++++ trio/_core/_io_epoll.py | 12 +++++----- trio/_core/_mock_clock.py | 2 +- trio/_core/_run.py | 16 +++++++------- trio/_core/_thread_cache.py | 18 ++++++++++----- trio/_dtls.py | 41 ++++++++++++++++++++--------------- trio/_socket.py | 5 +++-- trio/_subprocess.py | 27 ++++++++++++++++++----- trio/_sync.py | 12 +++++----- trio/_tests/verify_types.json | 17 +++++---------- trio/_threads.py | 24 +++++++++++++++----- 11 files changed, 111 insertions(+), 67 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cfb4060ee7..954e21e2d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,10 @@ disallow_untyped_defs = false # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. +[[tool.mypy.overrides]] +disallow_incomplete_defs = false +module = "trio._core._run" + [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] faulthandler_timeout = 60 diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 376dd18a4e..fbeb454c7d 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import select import sys from collections import defaultdict -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, DefaultDict, Dict import attr @@ -187,13 +189,13 @@ class EpollWaiters: @attr.s(slots=True, eq=False, hash=False) class EpollIOManager: - _epoll = attr.ib(factory=select.epoll) + _epoll: select.epoll = attr.ib(factory=select.epoll) # {fd: EpollWaiters} - _registered = attr.ib( + _registered: DefaultDict[int, EpollWaiters] = attr.ib( factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters] ) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: int | None = attr.ib(default=None) def __attrs_post_init__(self): self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index fe35298631..27a5829076 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -150,7 +150,7 @@ def deadline_to_sleep_time(self, deadline: float) -> float: else: return 999999999 - def jump(self, seconds) -> None: + def jump(self, seconds: float) -> None: """Manually advance the clock by the given number of seconds. Args: diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 585dc4aa41..723370afd8 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1168,7 +1168,7 @@ def __del__(self) -> None: class Task(metaclass=NoPublicConstructor): _parent_nursery: Nursery | None = attr.ib() coro: Coroutine[Any, Outcome[object], Any] = attr.ib() - _runner = attr.ib() + _runner: Runner = attr.ib() name: str = attr.ib() context: contextvars.Context = attr.ib() _counter: int = attr.ib(init=False, factory=itertools.count().__next__) @@ -1184,8 +1184,8 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn = attr.ib(default=None) - _next_send = attr.ib(default=None) + _next_send_fn: Callable[[Outcome | None], None] = attr.ib(default=None) + _next_send: Outcome | None = attr.ib(default=None) _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib( default=None ) @@ -1386,13 +1386,13 @@ class _RunStatistics: # worker thread. @attr.s(eq=False, hash=False, slots=True) class GuestState: - runner = attr.ib() - run_sync_soon_threadsafe = attr.ib() - run_sync_soon_not_threadsafe = attr.ib() - done_callback = attr.ib() + runner: Runner = attr.ib() + run_sync_soon_threadsafe: Callable = attr.ib() + run_sync_soon_not_threadsafe: Callable = attr.ib() + done_callback: Callable = attr.ib() unrolled_run_gen = attr.ib() _value_factory: Callable[[], Value] = lambda: Value(None) - unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) + unrolled_run_next_send: Outcome = attr.ib(factory=_value_factory, type=Outcome) def guest_tick(self): try: diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py index cc272fc92c..3e27ce6a32 100644 --- a/trio/_core/_thread_cache.py +++ b/trio/_core/_thread_cache.py @@ -18,7 +18,9 @@ def _to_os_thread_name(name: str) -> bytes: # used to construct the method used to set os thread name, or None, depending on platform. # called once on import def get_os_thread_name_func() -> Optional[Callable[[Optional[int], str], None]]: - def namefunc(setname: Callable[[int, bytes], int], ident: Optional[int], name: str): + def namefunc( + setname: Callable[[int, bytes], int], ident: Optional[int], name: str + ) -> None: # Thread.ident is None "if it has not been started". Unclear if that can happen # with current usage. if ident is not None: # pragma: no cover @@ -28,7 +30,7 @@ def namefunc(setname: Callable[[int, bytes], int], ident: Optional[int], name: s # so the caller don't need to care about platform. def darwin_namefunc( setname: Callable[[bytes], int], ident: Optional[int], name: str - ): + ) -> None: # I don't know if Mac can rename threads that hasn't been started, but default # to no to be on the safe side. if ident is not None: # pragma: no cover @@ -111,7 +113,9 @@ def darwin_namefunc( class WorkerThread: - def __init__(self, thread_cache): + def __init__(self, thread_cache: ThreadCache): + # deliver (the second value) can probably be Callable[[outcome.Value], None] ? + # should generate stubs for outcome self._job: Optional[Tuple[Callable, Callable, str]] = None self._thread_cache = thread_cache # This Lock is used in an unconventional way. @@ -188,7 +192,9 @@ class ThreadCache: def __init__(self): self._idle_workers = {} - def start_thread_soon(self, fn, deliver, name: Optional[str] = None): + def start_thread_soon( + self, fn: Callable, deliver: Callable, name: Optional[str] = None + ) -> None: try: worker, _ = self._idle_workers.popitem() except KeyError: @@ -200,7 +206,9 @@ def start_thread_soon(self, fn, deliver, name: Optional[str] = None): THREAD_CACHE = ThreadCache() -def start_thread_soon(fn, deliver, name: Optional[str] = None): +def start_thread_soon( + fn: Callable, deliver: Callable, name: Optional[str] = None +) -> None: """Runs ``deliver(outcome.capture(fn))`` in a worker thread. Generally ``fn`` does some blocking work, and ``deliver`` delivers the diff --git a/trio/_dtls.py b/trio/_dtls.py index 722a9499f8..a1551fda0e 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -16,7 +16,7 @@ import warnings import weakref from itertools import count -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable, Callable import attr @@ -26,6 +26,11 @@ if TYPE_CHECKING: from types import TracebackType + from OpenSSL.SSL import Context + from typing_extensions import Self + + from trio._socket import _SocketType + MAX_UDP_PACKET_SIZE = 65527 @@ -1126,17 +1131,17 @@ class DTLSEndpoint(metaclass=Final): """ - def __init__(self, socket, *, incoming_packets_buffer=10): + def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL from OpenSSL import SSL - # TODO: create a `self._initialized` for `__del__`, so self.socket can be typed - # as trio.socket.SocketType and `is not None` checks can be removed. - self.socket = None # for __del__, in case the next line raises + # for __del__, in case the next line raises + self._initialized: bool = False if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") + self._initialized = True self.socket = socket self.incoming_packets_buffer = incoming_packets_buffer @@ -1146,8 +1151,8 @@ def __init__(self, socket, *, incoming_packets_buffer=10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams = weakref.WeakValueDictionary() - self._listening_context = None + self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self._listening_context: Context | None = None self._listening_key = None self._incoming_connections_q = _Queue(float("inf")) self._send_lock = trio.Lock() @@ -1164,9 +1169,9 @@ def _ensure_receive_loop(self): ) self._receive_loop_spawned = True - def __del__(self): + def __del__(self) -> None: # Do nothing if this object was never fully constructed - if self.socket is None: + if not self._initialized: return # Close the socket in Trio context (if our Trio context still exists), so that # the background task gets notified about the closure and can exit. @@ -1186,17 +1191,13 @@ def close(self) -> None: This object can also be used as a context manager. """ - # Do nothing if this object was never fully constructed - if self.socket is None: # pragma: no cover - return - self._closed = True self.socket.close() for stream in list(self._streams.values()): stream.close() self._incoming_connections_q.s.close() - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__( @@ -1207,13 +1208,17 @@ def __exit__( ) -> None: return self.close() - def _check_closed(self): + def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError - async def serve( - self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED - ): + async def serve( # type: ignore[no-untyped-def] + self, + ssl_context: Context, + async_fn: Callable[..., Awaitable], + *args, + task_status=trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ??? + ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. diff --git a/trio/_socket.py b/trio/_socket.py index eaf0e04d15..e1a8c5562a 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -434,6 +434,8 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo return normed +# TODO: stopping users from initializing this type should be done in a different way, +# so SocketType can be used as a type. class SocketType: def __init__(self): raise TypeError( @@ -537,8 +539,7 @@ async def bind(self, address: tuple[object, ...] | str | bytes) -> None: ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) - # remove the `type: ignore` when run.sync is typed. - return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return] + return await trio.to_thread.run_sync(self._sock.bind, address) else: # POSIX actually says that bind can return EWOULDBLOCK and # complete asynchronously, like connect. But in practice AFAICT diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 1f8d0a8253..ab592d0e9a 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess import sys @@ -117,11 +119,17 @@ class Process(AsyncResource, metaclass=NoPublicConstructor): # arbitrarily many threads if wait() keeps getting cancelled. _wait_for_exit_data = None - def __init__(self, popen, stdin, stdout, stderr): + def __init__( + self, + popen: subprocess.Popen, + stdin: Optional[SendStream], + stdout: Optional[ReceiveStream], + stderr: Optional[ReceiveStream], + ): self._proc = popen - self.stdin: Optional[SendStream] = stdin - self.stdout: Optional[ReceiveStream] = stdout - self.stderr: Optional[ReceiveStream] = stderr + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr self.stdio: Optional[StapledStream] = None if self.stdin is not None and self.stdout is not None: @@ -294,8 +302,17 @@ def kill(self): self._proc.kill() +from typing import Any + + +# TODO: replace Any with a ParamSpec from Popen?? Or just type them out async def open_process( - command, *, stdin=None, stdout=None, stderr=None, **options + command: list[str] | str, + *, + stdin: int | None = None, + stdout: int | None = None, + stderr: int | None = None, + **options: Any, ) -> Process: r"""Execute a child program in a new process. diff --git a/trio/_sync.py b/trio/_sync.py index 5a7f240d5e..0f05dd458c 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -143,7 +143,7 @@ class CapacityLimiterStatistics: borrowed_tokens: int = attr.ib() total_tokens: int | float = attr.ib() - borrowers: list[Task] = attr.ib() + borrowers: list[object] = attr.ib() tasks_waiting: int = attr.ib() @@ -204,9 +204,9 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing def __init__(self, total_tokens: int | float): self._lot = ParkingLot() - self._borrowers: set[Task] = set() + self._borrowers: set[object] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of - self._pending_borrowers: dict[Task, Task] = {} + self._pending_borrowers: dict[Task, object] = {} # invoke the property setter for validation self.total_tokens: int | float = total_tokens assert self._total_tokens == total_tokens @@ -268,7 +268,7 @@ def acquire_nowait(self) -> None: self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower: Task) -> None: + def acquire_on_behalf_of_nowait(self, borrower: object) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. @@ -307,7 +307,7 @@ async def acquire(self) -> None: await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower: Task) -> None: + async def acquire_on_behalf_of(self, borrower: object) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -347,7 +347,7 @@ def release(self) -> None: self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower: Task) -> None: + def release_on_behalf_of(self, borrower: object) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 9d7d7aa912..147fa6253a 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8764044943820225, + "completenessScore": 0.8812199036918138, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 546, - "withUnknownType": 76 + "withKnownType": 549, + "withUnknownType": 73 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 433, - "withUnknownType": 135 + "withKnownType": 438, + "withUnknownType": 130 }, "packageName": "trio", "symbols": [ @@ -70,7 +70,6 @@ "trio._core._local.RunVar.get", "trio._core._local.RunVar.reset", "trio._core._local.RunVar.set", - "trio._core._mock_clock.MockClock.jump", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", "trio._core._unbounded_queue.UnboundedQueue.__aiter__", @@ -90,13 +89,9 @@ "trio._dtls.DTLSChannel.send", "trio._dtls.DTLSChannel.set_ciphertext_mtu", "trio._dtls.DTLSChannel.statistics", - "trio._dtls.DTLSEndpoint.__del__", - "trio._dtls.DTLSEndpoint.__enter__", "trio._dtls.DTLSEndpoint.__init__", "trio._dtls.DTLSEndpoint.connect", - "trio._dtls.DTLSEndpoint.incoming_packets_buffer", "trio._dtls.DTLSEndpoint.serve", - "trio._dtls.DTLSEndpoint.socket", "trio._highlevel_socket.SocketListener", "trio._highlevel_socket.SocketListener.__init__", "trio._highlevel_socket.SocketStream.__init__", @@ -168,14 +163,12 @@ "trio.lowlevel.current_trio_token", "trio.lowlevel.currently_ki_protected", "trio.lowlevel.notify_closing", - "trio.lowlevel.open_process", "trio.lowlevel.permanently_detach_coroutine_object", "trio.lowlevel.reattach_detached_coroutine_object", "trio.lowlevel.remove_instrument", "trio.lowlevel.reschedule", "trio.lowlevel.spawn_system_task", "trio.lowlevel.start_guest_run", - "trio.lowlevel.start_thread_soon", "trio.lowlevel.temporarily_detach_coroutine_object", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", diff --git a/trio/_threads.py b/trio/_threads.py index 807212e0f9..52c742d588 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextvars import functools import inspect @@ -57,10 +59,19 @@ class ThreadPlaceholder: name = attr.ib() +from typing import Any, Callable, TypeVar + +T = TypeVar("T") + + @enable_ki_protection async def to_thread_run_sync( - sync_fn, *args, thread_name: Optional[str] = None, cancellable=False, limiter=None -): + sync_fn: Callable[..., T], + *args: Any, + thread_name: Optional[str] = None, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T: """Convert a blocking operation into an async operation using a thread. These two lines are equivalent:: @@ -152,7 +163,7 @@ async def to_thread_run_sync( # Holds a reference to the task that's blocked in this function waiting # for the result – or None if this function was cancelled and we should # discard the result. - task_register = [trio.lowlevel.current_task()] + task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()] name = f"trio.to_thread.run_sync-{next(_thread_counter)}" placeholder = ThreadPlaceholder(name) @@ -217,14 +228,17 @@ def deliver_worker_fn_result(result): limiter.release_on_behalf_of(placeholder) raise - def abort(_): + from trio._core._traps import RaiseCancelT + + def abort(_: RaiseCancelT) -> trio.lowlevel.Abort: if cancellable: task_register[0] = None return trio.lowlevel.Abort.SUCCEEDED else: return trio.lowlevel.Abort.FAILED - return await trio.lowlevel.wait_task_rescheduled(abort) + # wait_task_rescheduled return value cannot be typed + return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[no-any-return] def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None): From 46ac4e852bce69bcdff2cacff2932626d45d5110 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 12 Jul 2023 13:33:35 +0200 Subject: [PATCH 02/49] stuff --- pyproject.toml | 2 +- trio/__init__.py | 5 +- trio/_abc.py | 54 +++++++++++++----- trio/_core/_local.py | 46 ++++++++------- trio/_core/_run.py | 4 +- trio/_core/_unbounded_queue.py | 51 +++++++++++------ trio/_deprecate.py | 13 ++++- trio/_dtls.py | 100 ++++++++++++++++++++------------- trio/_socket.py | 65 +++++++++++++++------ trio/_tests/verify_types.json | 59 ++----------------- trio/_threads.py | 2 +- trio/_util.py | 2 +- trio/testing/_fake_net.py | 23 +++++++- 13 files changed, 255 insertions(+), 171 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 954e21e2d3..121a398234 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,8 @@ disallow_untyped_defs = false # downstream and users have to deal with them. [[tool.mypy.overrides]] -disallow_incomplete_defs = false module = "trio._core._run" +disallow_incomplete_defs = false [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] diff --git a/trio/__init__.py b/trio/__init__.py index 2b8810504b..35dc3e133f 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """Trio - A friendly Python library for async concurrency and I/O """ @@ -15,6 +17,7 @@ # # Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) + # must be imported early to avoid circular import from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: skip @@ -112,7 +115,7 @@ _deprecate.enable_attribute_deprecations(__name__) -__deprecated_attributes__ = { +__deprecated_attributes__: dict[str, _deprecate.DeprecatedAttribute] = { "open_process": _deprecate.DeprecatedAttribute( value=lowlevel.open_process, version="0.20.0", diff --git a/trio/_abc.py b/trio/_abc.py index 2a1721db13..0bb49e207d 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -6,10 +6,15 @@ import trio if TYPE_CHECKING: + import socket from types import TracebackType from typing_extensions import Self + from trio.lowlevel import Task + + from ._socket import _SocketType + # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a # __dict__ onto subclasses. @@ -73,13 +78,13 @@ class Instrument(metaclass=ABCMeta): __slots__ = () - def before_run(self): + def before_run(self) -> None: """Called at the beginning of :func:`trio.run`.""" - def after_run(self): + def after_run(self) -> None: """Called just before :func:`trio.run` returns.""" - def task_spawned(self, task): + def task_spawned(self, task: Task) -> None: """Called when the given task is created. Args: @@ -87,7 +92,7 @@ def task_spawned(self, task): """ - def task_scheduled(self, task): + def task_scheduled(self, task: Task) -> None: """Called when the given task becomes runnable. It may still be some time before it actually runs, if there are other @@ -98,7 +103,7 @@ def task_scheduled(self, task): """ - def before_task_step(self, task): + def before_task_step(self, task: Task) -> None: """Called immediately before we resume running the given task. Args: @@ -106,7 +111,7 @@ def before_task_step(self, task): """ - def after_task_step(self, task): + def after_task_step(self, task: Task) -> None: """Called when we return to the main run loop after a task has yielded. Args: @@ -114,7 +119,7 @@ def after_task_step(self, task): """ - def task_exited(self, task): + def task_exited(self, task: Task) -> None: """Called when the given task exits. Args: @@ -122,7 +127,7 @@ def task_exited(self, task): """ - def before_io_wait(self, timeout): + def before_io_wait(self, timeout: float) -> None: """Called before blocking to wait for I/O readiness. Args: @@ -130,7 +135,7 @@ def before_io_wait(self, timeout): """ - def after_io_wait(self, timeout): + def after_io_wait(self, timeout: float) -> None: """Called after handling pending I/O. Args: @@ -152,7 +157,23 @@ class HostnameResolver(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: """A custom implementation of :func:`~trio.socket.getaddrinfo`. Called by :func:`trio.socket.getaddrinfo`. @@ -169,7 +190,9 @@ async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): """ @abstractmethod - async def getnameinfo(self, sockaddr, flags): + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: """A custom implementation of :func:`~trio.socket.getnameinfo`. Called by :func:`trio.socket.getnameinfo`. @@ -186,7 +209,12 @@ class SocketFactory(metaclass=ABCMeta): """ @abstractmethod - def socket(self, family=None, type=None, proto=None): + def socket( + self, + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, + ) -> _SocketType: """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, @@ -537,7 +565,7 @@ class Listener(AsyncResource, Generic[T_resource]): __slots__ = () @abstractmethod - async def accept(self): + async def accept(self) -> AsyncResource: """Wait until an incoming connection arrives, and then return it. Returns: diff --git a/trio/_core/_local.py b/trio/_core/_local.py index a54f424fdf..89ccf93e95 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,25 +1,32 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + # Runvar implementations import attr from .._util import Final from . import _run +T = TypeVar("T") +C = TypeVar("C", bound="_RunVarToken") + @attr.s(eq=False, hash=False, slots=True) -class _RunVarToken: - _no_value = object() +class _RunVarToken(Generic[T]): + _no_value = None - _var = attr.ib() - previous_value = attr.ib(default=_no_value) - redeemed = attr.ib(default=False, init=False) + _var: RunVar[T] = attr.ib() + previous_value: T | None = attr.ib(default=_no_value) + redeemed: bool = attr.ib(default=False, init=False) @classmethod - def empty(cls, var): + def empty(cls: type[C], var: RunVar[T]) -> C: return cls(var) @attr.s(eq=False, hash=False, slots=True) -class RunVar(metaclass=Final): +class RunVar(Generic[T], metaclass=Final): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, @@ -28,14 +35,15 @@ class RunVar(metaclass=Final): """ - _NO_DEFAULT = object() - _name = attr.ib() - _default = attr.ib(default=_NO_DEFAULT) + _NO_DEFAULT = None + _name: str = attr.ib() + _default: T | None = attr.ib(default=_NO_DEFAULT) - def get(self, default=_NO_DEFAULT): + def get(self, default: T | None = _NO_DEFAULT) -> T | None: """Gets the value of this :class:`RunVar` for the current run call.""" try: - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] + # not typed yet + return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: @@ -48,7 +56,7 @@ def get(self, default=_NO_DEFAULT): raise LookupError(self) from None - def set(self, value): + def set(self, value: T) -> _RunVarToken[T]: """Sets the value of this :class:`RunVar` for this current run call. @@ -56,16 +64,16 @@ def set(self, value): try: old_value = self.get() except LookupError: - token = _RunVarToken.empty(self) + token: _RunVarToken[T] = _RunVarToken.empty(self) else: token = _RunVarToken(self, old_value) # This can't fail, because if we weren't in Trio context then the # get() above would have failed. - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index] return token - def reset(self, token): + def reset(self, token: _RunVarToken[T]) -> None: """Resets the value of this :class:`RunVar` to what it was previously specified by the token. @@ -82,13 +90,13 @@ def reset(self, token): previous = token.previous_value try: if previous is _RunVarToken._no_value: - _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) + _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type] else: - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") token.redeemed = True - def __repr__(self): + def __repr__(self) -> str: return f"" diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 723370afd8..804a958714 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -789,10 +789,10 @@ class _TaskStatus: _called_started = attr.ib(default=False) _value = attr.ib(default=None) - def __repr__(self): + def __repr__(self) -> str: return f"" - def started(self, value=None): + def started(self, value: Any = None) -> None: if self._called_started: raise RuntimeError("called 'started' twice on the same task status") self._called_started = True diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index 9c747749b4..cbcf10cf89 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -1,17 +1,34 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + import attr +from typing_extensions import Self from .. import _core from .._deprecate import deprecated from .._util import Final +T = TypeVar("T") + @attr.s(frozen=True) -class _UnboundedQueueStats: - qsize = attr.ib() - tasks_waiting = attr.ib() +class UnboundedQueueStats: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``qsize``: The number of items currently in the queue. + * ``tasks_waiting``: The number of tasks blocked on this queue's + :meth:`get_batch` method. + + """ + + qsize: int = attr.ib() + tasks_waiting: int = attr.ib() -class UnboundedQueue(metaclass=Final): +class UnboundedQueue(Generic[T], metaclass=Final): """An unbounded queue suitable for certain unusual forms of inter-task communication. @@ -47,20 +64,20 @@ class UnboundedQueue(metaclass=Final): thing="trio.lowlevel.UnboundedQueue", instead="trio.open_memory_channel(math.inf)", ) - def __init__(self): + def __init__(self) -> None: self._lot = _core.ParkingLot() - self._data = [] + self._data: list[T] = [] # used to allow handoff from put to the first task in the lot self._can_get = False - def __repr__(self): + def __repr__(self) -> str: return f"" - def qsize(self): + def qsize(self) -> int: """Returns the number of items currently in the queue.""" return len(self._data) - def empty(self): + def empty(self) -> bool: """Returns True if the queue is empty, False otherwise. There is some subtlety to interpreting this method's return value: see @@ -70,7 +87,7 @@ def empty(self): return not self._data @_core.enable_ki_protection - def put_nowait(self, obj): + def put_nowait(self, obj: T) -> None: """Put an object into the queue, without blocking. This always succeeds, because the queue is unbounded. We don't provide @@ -88,13 +105,13 @@ def put_nowait(self, obj): self._can_get = True self._data.append(obj) - def _get_batch_protected(self): + def _get_batch_protected(self) -> list[T]: data = self._data.copy() self._data.clear() self._can_get = False return data - def get_batch_nowait(self): + def get_batch_nowait(self) -> list[T]: """Attempt to get the next batch from the queue, without blocking. Returns: @@ -110,7 +127,7 @@ def get_batch_nowait(self): raise _core.WouldBlock return self._get_batch_protected() - async def get_batch(self): + async def get_batch(self) -> list[T]: """Get the next batch from the queue, blocking as necessary. Returns: @@ -128,7 +145,7 @@ async def get_batch(self): finally: await _core.cancel_shielded_checkpoint() - def statistics(self): + def statistics(self) -> UnboundedQueueStats: """Return an object containing debugging information. Currently the following fields are defined: @@ -138,12 +155,12 @@ def statistics(self): :meth:`get_batch` method. """ - return _UnboundedQueueStats( + return UnboundedQueueStats( qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting ) - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> list[T]: return await self.get_batch() diff --git a/trio/_deprecate.py b/trio/_deprecate.py index fe00192583..aeebe80722 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import sys import warnings from functools import wraps from types import ModuleType +from typing import Callable, TypeVar import attr +T = TypeVar("T") + # We want our warnings to be visible by default (at least for now), but we # also want it to be possible to override that using the -W switch. AFAICT @@ -53,7 +58,9 @@ def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): # @deprecated("0.2.0", issue=..., instead=...) # def ... -def deprecated(version, *, thing=None, issue, instead): +def deprecated( + version: str, *, thing: str | None = None, issue: int, instead: str +) -> Callable[[T], T]: def do_wrap(fn): nonlocal thing @@ -124,10 +131,10 @@ def __getattr__(self, name): raise AttributeError(msg.format(self.__name__, name)) -def enable_attribute_deprecations(module_name): +def enable_attribute_deprecations(module_name: str) -> None: module = sys.modules[module_name] module.__class__ = _ModuleWithDeprecations # Make sure that this is always defined so that # _ModuleWithDeprecations.__getattr__ can access it without jumping # through hoops or risking infinite recursion. - module.__deprecated_attributes__ = {} + module.__deprecated_attributes__ = {} # type: ignore[attr-defined] diff --git a/trio/_dtls.py b/trio/_dtls.py index a1551fda0e..aea15be735 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -16,9 +16,10 @@ import warnings import weakref from itertools import count -from typing import TYPE_CHECKING, Awaitable, Callable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Iterator, cast import attr +from OpenSSL import SSL import trio from trio._util import Final, NoPublicConstructor @@ -31,24 +32,26 @@ from trio._socket import _SocketType + from ._core._run import _TaskStatus + MAX_UDP_PACKET_SIZE = 65527 -def packet_header_overhead(sock): +def packet_header_overhead(sock: _SocketType) -> int: if sock.family == trio.socket.AF_INET: return 28 else: return 48 -def worst_case_mtu(sock): +def worst_case_mtu(sock: _SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: return 1280 - packet_header_overhead(sock) -def best_guess_mtu(sock): +def best_guess_mtu(sock: _SocketType) -> int: return 1500 - packet_header_overhead(sock) @@ -110,14 +113,14 @@ class BadPacket(Exception): # ChangeCipherSpec is used during the handshake but has its own ContentType. # # Cannot fail. -def part_of_handshake_untrusted(packet): +def part_of_handshake_untrusted(packet: bytes) -> bool: # If the packet is too short, then slicing will successfully return a # short string, which will necessarily fail to match. return packet[3:5] == b"\x00\x00" # Cannot fail -def is_client_hello_untrusted(packet): +def is_client_hello_untrusted(packet: bytes) -> bool: try: return ( packet[0] == ContentType.handshake @@ -152,7 +155,7 @@ class Record: payload: bytes = attr.ib(repr=to_hex) -def records_untrusted(packet): +def records_untrusted(packet: bytes) -> Iterator[Record]: i = 0 while i < len(packet): try: @@ -170,7 +173,7 @@ def records_untrusted(packet): yield Record(ct, version, epoch_seqno, payload) -def encode_record(record): +def encode_record(record: Record) -> bytes: header = RECORD_HEADER.pack( record.content_type, record.version, @@ -199,7 +202,7 @@ class HandshakeFragment: frag: bytes = attr.ib(repr=to_hex) -def decode_handshake_fragment_untrusted(payload): +def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment: # Raises BadPacket if decoding fails try: ( @@ -229,7 +232,7 @@ def decode_handshake_fragment_untrusted(payload): ) -def encode_handshake_fragment(hsf): +def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes: hs_header = HANDSHAKE_MESSAGE_HEADER.pack( hsf.msg_type, hsf.msg_len.to_bytes(3, "big"), @@ -240,7 +243,7 @@ def encode_handshake_fragment(hsf): return hs_header + hsf.frag -def decode_client_hello_untrusted(packet): +def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]: # Raises BadPacket if parsing fails # Returns (record epoch_seqno, cookie from the packet, data that should be # hashed into cookie) @@ -340,8 +343,12 @@ class OpaqueHandshakeMessage: # reconstructs the handshake messages inside it, so that we can repack them # into records while retransmitting. So the data ought to be well-behaved -- # it's not coming from the network. -def decode_volley_trusted(volley): - messages = [] +def decode_volley_trusted( + volley: bytes, +) -> list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage]: + messages: list[ + HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage + ] = [] messages_by_seq = {} for record in records_untrusted(volley): # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. @@ -388,10 +395,16 @@ class RecordEncoder: def __init__(self): self._record_seq = count() - def set_first_record_number(self, n): + def set_first_record_number(self, n: int) -> None: self._record_seq = count(n) - def encode_volley(self, messages, mtu): + def encode_volley( + self, + messages: Iterable[ + HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage + ], + mtu: int, + ) -> list[bytearray]: packets = [] packet = bytearray() for message in messages: @@ -523,13 +536,13 @@ def encode_volley(self, messages, mtu): COOKIE_LENGTH = 32 -def _current_cookie_tick(): +def _current_cookie_tick() -> int: return int(trio.current_time() / COOKIE_REFRESH_INTERVAL) # Simple deterministic and invertible serializer -- i.e., a useful tool for converting # structured data into something we can cryptographically sign. -def _signable(*fields): +def _signable(*fields: bytes) -> bytes: out = [] for field in fields: out.append(struct.pack("!Q", len(field))) @@ -618,7 +631,7 @@ def __init__(self, incoming_packets_buffer): self.s, self.r = trio.open_memory_channel(incoming_packets_buffer) -def _read_loop(read_fn): +def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: chunks = [] while True: try: @@ -778,7 +791,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint, peer_address, ctx): + def __init__(self, endpoint: DTLSEndpoint, peer_address: str, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -789,9 +802,9 @@ def __init__(self, endpoint, peer_address, ctx): # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to # support and isn't useful anyway -- especially for DTLS where it's equivalent # to just performing a new handshake. - ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) + ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined] self._ssl = SSL.Connection(ctx) - self._handshake_mtu = None + self._handshake_mtu: int | None = None # This calls self._ssl.set_ciphertext_mtu, which is important, because if you # don't call it then openssl doesn't work. self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) @@ -841,7 +854,7 @@ def close(self) -> None: # ClosedResourceError self._q.r.close() - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__( @@ -852,7 +865,7 @@ def __exit__( ) -> None: return self.close() - async def aclose(self): + async def aclose(self) -> None: """Close this connection, but asynchronously. This is included to satisfy the `trio.abc.Channel` contract. It's @@ -873,7 +886,7 @@ async def _send_volley(self, volley_messages): async def _resend_final_volley(self): await self._send_volley(self._final_volley) - async def do_handshake(self, *, initial_retransmit_timeout=1.0): + async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None: """Perform the handshake. Calling this is optional – if you don't, then it will be automatically called @@ -906,17 +919,23 @@ async def do_handshake(self, *, initial_retransmit_timeout=1.0): return timeout = initial_retransmit_timeout - volley_messages = [] + volley_messages: list[ + HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage + ] = [] volley_failed_sends = 0 - def read_volley(): + def read_volley() -> ( + list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage] + ): volley_bytes = _read_loop(self._ssl.bio_read) new_volley_messages = decode_volley_trusted(volley_bytes) if ( new_volley_messages and volley_messages and isinstance(new_volley_messages[0], HandshakeMessage) - and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq + # TODO: add isinstance or do a cast? + and new_volley_messages[0].msg_seq + == cast(HandshakeMessage, volley_messages[0]).msg_seq ): # openssl decided to retransmit; discard because we handle # retransmits ourselves @@ -1000,10 +1019,13 @@ def read_volley(): # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. self._handshake_mtu = min( - self._handshake_mtu, worst_case_mtu(self.endpoint.socket) + self._handshake_mtu or 0, + worst_case_mtu(self.endpoint.socket), ) - async def send(self, data): + async def send( + self, data: bytes + ) -> None: # or str? SendChannel defines it as bytes """Send a packet of data, securely.""" if self._closed: @@ -1019,7 +1041,7 @@ async def send(self, data): _read_loop(self._ssl.bio_read), self.peer_address ) - async def receive(self): + async def receive(self) -> bytes: # or str? """Fetch the next packet of data from this connection's peer, waiting if necessary. @@ -1045,7 +1067,7 @@ async def receive(self): if cleartext: return cleartext - def set_ciphertext_mtu(self, new_mtu): + def set_ciphertext_mtu(self, new_mtu: int) -> None: """Tells Trio the `largest amount of data that can be sent in a single packet to this peer `__. @@ -1080,7 +1102,7 @@ def set_ciphertext_mtu(self, new_mtu): self._handshake_mtu = new_mtu self._ssl.set_ciphertext_mtu(new_mtu) - def get_cleartext_mtu(self): + def get_cleartext_mtu(self) -> int: """Returns the largest number of bytes that you can pass in a single call to `send` while still fitting within the network-level MTU. @@ -1089,9 +1111,9 @@ def get_cleartext_mtu(self): """ if not self._did_handshake: raise trio.NeedHandshakeError - return self._ssl.get_cleartext_mtu() + return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return] - def statistics(self): + def statistics(self) -> DTLSChannelStatistics: """Returns an object with statistics about this connection. Currently this has only one attribute: @@ -1142,7 +1164,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") self._initialized = True - self.socket = socket + self.socket: _SocketType = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() @@ -1212,12 +1234,12 @@ def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError - async def serve( # type: ignore[no-untyped-def] + async def serve( self, ssl_context: Context, async_fn: Callable[..., Awaitable], - *args, - task_status=trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ??? + *args: Any, + task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ??? ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. @@ -1272,7 +1294,7 @@ async def handler_wrapper(stream): finally: self._listening_context = None - def connect(self, address, ssl_context): + def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel: """Initiate an outgoing DTLS connection. Notice that this is a synchronous method. That's because it doesn't actually diff --git a/trio/_socket.py b/trio/_socket.py index e1a8c5562a..0f5aa75fd2 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -5,7 +5,7 @@ import socket as _stdlib_socket import sys from functools import wraps as _wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, SupportsIndex import idna as _idna @@ -19,6 +19,8 @@ from typing_extensions import Self + from ._abc import HostnameResolver, SocketFactory + # Usage: # @@ -73,11 +75,13 @@ async def __aexit__( # Overrides ################################################################ -_resolver = _core.RunVar("hostname_resolver") -_socket_factory = _core.RunVar("socket_factory") +_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver") +_socket_factory: _core.RunVar[SocketFactory] = _core.RunVar("socket_factory") -def set_custom_hostname_resolver(hostname_resolver): +def set_custom_hostname_resolver( + hostname_resolver: HostnameResolver | None, +) -> HostnameResolver | None: """Set a custom hostname resolver. By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions @@ -143,7 +147,22 @@ def set_custom_socket_factory(socket_factory): _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV -async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): +async def getaddrinfo( + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[ + tuple[ + _stdlib_socket.AddressFamily, + _stdlib_socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: """Look up a numeric address given a name. Arguments and return values are identical to :func:`socket.getaddrinfo`, @@ -190,7 +209,7 @@ def numeric_only_failure(exc): # idna.encode will error out if the hostname has Capital Letters # in it; with uts46=True it will lowercase them instead. host = _idna.encode(host, uts46=True) - hr = _resolver.get(None) + hr: HostnameResolver | None = _resolver.get(None) if hr is not None: return await hr.getaddrinfo(host, port, family, type, proto, flags) else: @@ -206,7 +225,9 @@ def numeric_only_failure(exc): ) -async def getnameinfo(sockaddr, flags): +async def getnameinfo( + sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int +) -> tuple[str, str]: """Look up a name given a numeric address. Arguments and return values are identical to :func:`socket.getnameinfo`, @@ -244,7 +265,7 @@ async def getprotobyname(name): ################################################################ -def from_stdlib_socket(sock): +def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType: """Convert a standard library :class:`socket.socket` object into a Trio socket object. @@ -253,7 +274,12 @@ def from_stdlib_socket(sock): @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd(fd, family, type, proto=0): +def fromfd( + fd: SupportsIndex, + family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET, + type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, +) -> _SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -280,11 +306,11 @@ def socketpair(*args, **kwargs): @_wraps(_stdlib_socket.socket, assigned=(), updated=()) def socket( - family=_stdlib_socket.AF_INET, - type=_stdlib_socket.SOCK_STREAM, - proto=0, - fileno=None, -): + family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET, + type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, +) -> _SocketType: """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using @@ -483,7 +509,7 @@ def __init__(self, sock: _stdlib_socket.socket): "share", } - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in self._forward: return getattr(self._sock, name) raise AttributeError(name) @@ -619,9 +645,11 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # accept ################################################################ - _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable) + _accept: Callable[ + [], Awaitable[tuple[_stdlib_socket.socket, object]] + ] = _make_simple_sock_method_wrapper("accept", _core.wait_readable) - async def accept(self): + async def accept(self) -> tuple[_SocketType, object]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -630,7 +658,8 @@ async def accept(self): # connect ################################################################ - async def connect(self, address): + # TODO: typing addresses is ... a pain + async def connect(self, address: str) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 147fa6253a..00a964787a 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8812199036918138, + "completenessScore": 0.9020866773675762, "exportedSymbolCounts": { - "withAmbiguousType": 1, - "withKnownType": 549, - "withUnknownType": 73 + "withAmbiguousType": 0, + "withKnownType": 562, + "withUnknownType": 61 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,53 +46,17 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 438, - "withUnknownType": 130 + "withKnownType": 523, + "withUnknownType": 89 }, "packageName": "trio", "symbols": [ - "trio.__deprecated_attributes__", - "trio._abc.HostnameResolver.getaddrinfo", - "trio._abc.HostnameResolver.getnameinfo", - "trio._abc.Instrument.after_io_wait", - "trio._abc.Instrument.after_run", - "trio._abc.Instrument.after_task_step", - "trio._abc.Instrument.before_io_wait", - "trio._abc.Instrument.before_run", - "trio._abc.Instrument.before_task_step", - "trio._abc.Instrument.task_exited", - "trio._abc.Instrument.task_scheduled", - "trio._abc.Instrument.task_spawned", - "trio._abc.Listener.accept", "trio._abc.SocketFactory.socket", "trio._core._entry_queue.TrioToken.run_sync_soon", - "trio._core._local.RunVar.__repr__", - "trio._core._local.RunVar.get", - "trio._core._local.RunVar.reset", - "trio._core._local.RunVar.set", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", - "trio._core._unbounded_queue.UnboundedQueue.__aiter__", - "trio._core._unbounded_queue.UnboundedQueue.__anext__", - "trio._core._unbounded_queue.UnboundedQueue.__repr__", - "trio._core._unbounded_queue.UnboundedQueue.empty", - "trio._core._unbounded_queue.UnboundedQueue.get_batch", - "trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait", - "trio._core._unbounded_queue.UnboundedQueue.qsize", - "trio._core._unbounded_queue.UnboundedQueue.statistics", - "trio._dtls.DTLSChannel.__enter__", "trio._dtls.DTLSChannel.__init__", - "trio._dtls.DTLSChannel.aclose", - "trio._dtls.DTLSChannel.do_handshake", - "trio._dtls.DTLSChannel.get_cleartext_mtu", - "trio._dtls.DTLSChannel.receive", - "trio._dtls.DTLSChannel.send", - "trio._dtls.DTLSChannel.set_ciphertext_mtu", - "trio._dtls.DTLSChannel.statistics", "trio._dtls.DTLSEndpoint.__init__", - "trio._dtls.DTLSEndpoint.connect", - "trio._dtls.DTLSEndpoint.serve", - "trio._highlevel_socket.SocketListener", "trio._highlevel_socket.SocketListener.__init__", "trio._highlevel_socket.SocketStream.__init__", "trio._highlevel_socket.SocketStream.getsockopt", @@ -112,9 +76,6 @@ "trio._path.Path.__rtruediv__", "trio._path.Path.__truediv__", "trio._path.Path.open", - "trio._socket._SocketType.__getattr__", - "trio._socket._SocketType.accept", - "trio._socket._SocketType.connect", "trio._socket._SocketType.recv_into", "trio._socket._SocketType.recvfrom", "trio._socket._SocketType.recvfrom_into", @@ -123,7 +84,6 @@ "trio._socket._SocketType.send", "trio._socket._SocketType.sendmsg", "trio._socket._SocketType.sendto", - "trio._ssl.SSLListener", "trio._ssl.SSLListener.__init__", "trio._ssl.SSLListener.accept", "trio._ssl.SSLListener.aclose", @@ -155,7 +115,6 @@ "trio.current_time", "trio.from_thread.run", "trio.from_thread.run_sync", - "trio.lowlevel.add_instrument", "trio.lowlevel.cancel_shielded_checkpoint", "trio.lowlevel.current_clock", "trio.lowlevel.current_root_task", @@ -165,7 +124,6 @@ "trio.lowlevel.notify_closing", "trio.lowlevel.permanently_detach_coroutine_object", "trio.lowlevel.reattach_detached_coroutine_object", - "trio.lowlevel.remove_instrument", "trio.lowlevel.reschedule", "trio.lowlevel.spawn_system_task", "trio.lowlevel.start_guest_run", @@ -184,13 +142,8 @@ "trio.serve_ssl_over_tcp", "trio.serve_tcp", "trio.socket.from_stdlib_socket", - "trio.socket.fromfd", - "trio.socket.getaddrinfo", - "trio.socket.getnameinfo", "trio.socket.getprotobyname", - "trio.socket.set_custom_hostname_resolver", "trio.socket.set_custom_socket_factory", - "trio.socket.socket", "trio.socket.socketpair", "trio.testing._memory_streams.MemoryReceiveStream.__init__", "trio.testing._memory_streams.MemoryReceiveStream.aclose", diff --git a/trio/_threads.py b/trio/_threads.py index 52c742d588..45a416249e 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -27,7 +27,7 @@ # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() -_limiter_local = RunVar("limiter") +_limiter_local: RunVar[CapacityLimiter] = RunVar("limiter") # I pulled this number out of the air; it isn't based on anything. Probably we # should make some kind of measurements to pick a good value. DEFAULT_LIMIT = 40 diff --git a/trio/_util.py b/trio/_util.py index 0a0795fc15..0f73ff19e9 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -216,7 +216,7 @@ def decorator(func): return decorator -def fixup_module_metadata(module_name, namespace): +def fixup_module_metadata(module_name: str, namespace: dict[str, object]) -> None: seen_ids = set() def fix_one(qualname, name, obj): diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index b3bdfd85c0..fdb4d45102 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -19,6 +19,7 @@ from trio._util import Final, NoPublicConstructor if TYPE_CHECKING: + import socket from types import TracebackType IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -113,11 +114,27 @@ class FakeHostnameResolver(trio.abc.HostnameResolver): fake_net: "FakeNet" async def getaddrinfo( - self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0 - ): + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: raise NotImplementedError("FakeNet doesn't do fake DNS yet") - async def getnameinfo(self, sockaddr, flags: int): + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: raise NotImplementedError("FakeNet doesn't do fake DNS yet") From 4678d44ede954b8e3431b56d0c973bdebacd683a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 13 Jul 2023 16:12:23 +0200 Subject: [PATCH 03/49] typecheck trio/_dtls.py --- pyproject.toml | 7 ++ trio/_channel.py | 10 +- trio/_dtls.py | 195 +++++++++++++++++++++------------- trio/_tests/verify_types.json | 21 +--- 4 files changed, 140 insertions(+), 93 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cfb4060ee7..ee0c019af7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,13 @@ disallow_untyped_defs = false # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. +[[tool.mypy.overrides]] +module = [ + "trio._dtls" +] +disallow_incomplete_defs = true +disallow_untyped_defs = true + [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] faulthandler_timeout = 60 diff --git a/trio/_channel.py b/trio/_channel.py index 7c8ff4660d..df596adddd 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -20,7 +20,7 @@ def _open_memory_channel( - max_buffer_size: int, + max_buffer_size: int | float, ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: """Open a channel for passing objects between tasks within a process. @@ -92,11 +92,11 @@ def _open_memory_channel( # Need to use Tuple instead of tuple due to CI check running on 3.8 class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]): def __new__( # type: ignore[misc] # "must return a subtype" - cls, max_buffer_size: int + cls, max_buffer_size: int | float ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: return _open_memory_channel(max_buffer_size) - def __init__(self, max_buffer_size: int): + def __init__(self, max_buffer_size: int | float): ... else: @@ -108,7 +108,7 @@ def __init__(self, max_buffer_size: int): @attr.s(frozen=True, slots=True) class MemoryChannelStats: current_buffer_used: int = attr.ib() - max_buffer_size: int = attr.ib() + max_buffer_size: int | float = attr.ib() open_send_channels: int = attr.ib() open_receive_channels: int = attr.ib() tasks_waiting_send: int = attr.ib() @@ -117,7 +117,7 @@ class MemoryChannelStats: @attr.s(slots=True) class MemoryChannelState(Generic[T]): - max_buffer_size: int = attr.ib() + max_buffer_size: int | float = attr.ib() data: deque[T] = attr.ib(factory=deque) # Counts of open endpoints using this state open_send_channels: int = attr.ib(default=0) diff --git a/trio/_dtls.py b/trio/_dtls.py index 722a9499f8..acc5f950eb 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -16,34 +16,53 @@ import warnings import weakref from itertools import count -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Iterable, + Iterator, + TypeVar, + Union, + cast, +) import attr +from OpenSSL import SSL import trio -from trio._util import Final, NoPublicConstructor + +from ._util import Final, NoPublicConstructor if TYPE_CHECKING: from types import TracebackType + from OpenSSL.SSL import Context + from typing_extensions import Self, TypeAlias + + from ._core._run import _TaskStatus + from ._socket import _SocketType + MAX_UDP_PACKET_SIZE = 65527 -def packet_header_overhead(sock): +def packet_header_overhead(sock: _SocketType) -> int: if sock.family == trio.socket.AF_INET: return 28 else: return 48 -def worst_case_mtu(sock): +def worst_case_mtu(sock: _SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: return 1280 - packet_header_overhead(sock) -def best_guess_mtu(sock): +def best_guess_mtu(sock: _SocketType) -> int: return 1500 - packet_header_overhead(sock) @@ -105,14 +124,14 @@ class BadPacket(Exception): # ChangeCipherSpec is used during the handshake but has its own ContentType. # # Cannot fail. -def part_of_handshake_untrusted(packet): +def part_of_handshake_untrusted(packet: bytes) -> bool: # If the packet is too short, then slicing will successfully return a # short string, which will necessarily fail to match. return packet[3:5] == b"\x00\x00" # Cannot fail -def is_client_hello_untrusted(packet): +def is_client_hello_untrusted(packet: bytes) -> bool: try: return ( packet[0] == ContentType.handshake @@ -147,7 +166,7 @@ class Record: payload: bytes = attr.ib(repr=to_hex) -def records_untrusted(packet): +def records_untrusted(packet: bytes) -> Iterator[Record]: i = 0 while i < len(packet): try: @@ -165,7 +184,7 @@ def records_untrusted(packet): yield Record(ct, version, epoch_seqno, payload) -def encode_record(record): +def encode_record(record: Record) -> bytes: header = RECORD_HEADER.pack( record.content_type, record.version, @@ -194,7 +213,7 @@ class HandshakeFragment: frag: bytes = attr.ib(repr=to_hex) -def decode_handshake_fragment_untrusted(payload): +def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment: # Raises BadPacket if decoding fails try: ( @@ -224,7 +243,7 @@ def decode_handshake_fragment_untrusted(payload): ) -def encode_handshake_fragment(hsf): +def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes: hs_header = HANDSHAKE_MESSAGE_HEADER.pack( hsf.msg_type, hsf.msg_len.to_bytes(3, "big"), @@ -235,7 +254,7 @@ def encode_handshake_fragment(hsf): return hs_header + hsf.frag -def decode_client_hello_untrusted(packet): +def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]: # Raises BadPacket if parsing fails # Returns (record epoch_seqno, cookie from the packet, data that should be # hashed into cookie) @@ -331,12 +350,20 @@ class OpaqueHandshakeMessage: record: Record +# for some reason doesn't work with | +_AnyHandshakeMessage: TypeAlias = Union[ + HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage +] + + # This takes a raw outgoing handshake volley that openssl generated, and # reconstructs the handshake messages inside it, so that we can repack them # into records while retransmitting. So the data ought to be well-behaved -- # it's not coming from the network. -def decode_volley_trusted(volley): - messages = [] +def decode_volley_trusted( + volley: bytes, +) -> list[_AnyHandshakeMessage]: + messages: list[_AnyHandshakeMessage] = [] messages_by_seq = {} for record in records_untrusted(volley): # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. @@ -380,13 +407,17 @@ def decode_volley_trusted(volley): class RecordEncoder: - def __init__(self): + def __init__(self) -> None: self._record_seq = count() - def set_first_record_number(self, n): + def set_first_record_number(self, n: int) -> None: self._record_seq = count(n) - def encode_volley(self, messages, mtu): + def encode_volley( + self, + messages: Iterable[_AnyHandshakeMessage], + mtu: int, + ) -> list[bytearray]: packets = [] packet = bytearray() for message in messages: @@ -518,13 +549,13 @@ def encode_volley(self, messages, mtu): COOKIE_LENGTH = 32 -def _current_cookie_tick(): +def _current_cookie_tick() -> int: return int(trio.current_time() / COOKIE_REFRESH_INTERVAL) # Simple deterministic and invertible serializer -- i.e., a useful tool for converting # structured data into something we can cryptographically sign. -def _signable(*fields): +def _signable(*fields: bytes) -> bytes: out = [] for field in fields: out.append(struct.pack("!Q", len(field))) @@ -532,7 +563,9 @@ def _signable(*fields): return b"".join(out) -def _make_cookie(key, salt, tick, address, client_hello_bits): +def _make_cookie( + key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes +) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -548,7 +581,9 @@ def _make_cookie(key, salt, tick, address, client_hello_bits): return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] -def valid_cookie(key, cookie, address, client_hello_bits): +def valid_cookie( + key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes +) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -568,7 +603,9 @@ def valid_cookie(key, cookie, address, client_hello_bits): return False -def challenge_for(key, address, epoch_seqno, client_hello_bits): +def challenge_for( + key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes +) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() cookie = _make_cookie(key, salt, tick, address, client_hello_bits) @@ -608,12 +645,15 @@ def challenge_for(key, address, epoch_seqno, client_hello_bits): return packet -class _Queue: - def __init__(self, incoming_packets_buffer): - self.s, self.r = trio.open_memory_channel(incoming_packets_buffer) +T = TypeVar("T") + +class _Queue(Generic[T]): + def __init__(self, incoming_packets_buffer: int | float): + self.s, self.r = trio.open_memory_channel[T](incoming_packets_buffer) -def _read_loop(read_fn): + +def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: chunks = [] while True: try: @@ -624,7 +664,9 @@ def _read_loop(read_fn): return b"".join(chunks) -async def handle_client_hello_untrusted(endpoint, address, packet): +async def handle_client_hello_untrusted( + endpoint: DTLSEndpoint, address: Any, packet: bytes +) -> None: if endpoint._listening_context is None: return @@ -697,7 +739,9 @@ async def handle_client_hello_untrusted(endpoint, address, packet): endpoint._incoming_connections_q.s.send_nowait(stream) -async def dtls_receive_loop(endpoint_ref, sock): +async def dtls_receive_loop( + endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType +) -> None: try: while True: try: @@ -773,7 +817,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint, peer_address, ctx): + def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -784,25 +828,27 @@ def __init__(self, endpoint, peer_address, ctx): # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to # support and isn't useful anyway -- especially for DTLS where it's equivalent # to just performing a new handshake. - ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) + ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined] self._ssl = SSL.Connection(ctx) - self._handshake_mtu = None + self._handshake_mtu = 0 # This calls self._ssl.set_ciphertext_mtu, which is important, because if you # don't call it then openssl doesn't work. self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) self._replaced = False self._closed = False - self._q = _Queue(endpoint.incoming_packets_buffer) + self._q = _Queue[bytes](endpoint.incoming_packets_buffer) self._handshake_lock = trio.Lock() - self._record_encoder = RecordEncoder() + self._record_encoder: RecordEncoder = RecordEncoder() + + self._final_volley: list[_AnyHandshakeMessage] = [] - def _set_replaced(self): + def _set_replaced(self) -> None: self._replaced = True # Any packets we already received could maybe possibly still be processed, but # there are no more coming. So we close this on the sender side. self._q.s.close() - def _check_replaced(self): + def _check_replaced(self) -> None: if self._replaced: raise trio.BrokenResourceError( "peer tore down this connection to start a new one" @@ -836,7 +882,7 @@ def close(self) -> None: # ClosedResourceError self._q.r.close() - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__( @@ -847,7 +893,7 @@ def __exit__( ) -> None: return self.close() - async def aclose(self): + async def aclose(self) -> None: """Close this connection, but asynchronously. This is included to satisfy the `trio.abc.Channel` contract. It's @@ -857,7 +903,7 @@ async def aclose(self): self.close() await trio.lowlevel.checkpoint() - async def _send_volley(self, volley_messages): + async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None: packets = self._record_encoder.encode_volley( volley_messages, self._handshake_mtu ) @@ -865,10 +911,10 @@ async def _send_volley(self, volley_messages): async with self.endpoint._send_lock: await self.endpoint.socket.sendto(packet, self.peer_address) - async def _resend_final_volley(self): + async def _resend_final_volley(self) -> None: await self._send_volley(self._final_volley) - async def do_handshake(self, *, initial_retransmit_timeout=1.0): + async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None: """Perform the handshake. Calling this is optional – if you don't, then it will be automatically called @@ -901,17 +947,19 @@ async def do_handshake(self, *, initial_retransmit_timeout=1.0): return timeout = initial_retransmit_timeout - volley_messages = [] + volley_messages: list[_AnyHandshakeMessage] = [] volley_failed_sends = 0 - def read_volley(): + def read_volley() -> list[_AnyHandshakeMessage]: volley_bytes = _read_loop(self._ssl.bio_read) new_volley_messages = decode_volley_trusted(volley_bytes) if ( new_volley_messages and volley_messages and isinstance(new_volley_messages[0], HandshakeMessage) - and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq + # TODO: add isinstance or do a cast? + and new_volley_messages[0].msg_seq + == cast(HandshakeMessage, volley_messages[0]).msg_seq ): # openssl decided to retransmit; discard because we handle # retransmits ourselves @@ -995,10 +1043,13 @@ def read_volley(): # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. self._handshake_mtu = min( - self._handshake_mtu, worst_case_mtu(self.endpoint.socket) + self._handshake_mtu, + worst_case_mtu(self.endpoint.socket), ) - async def send(self, data): + async def send( + self, data: bytes + ) -> None: # or str? SendChannel defines it as bytes """Send a packet of data, securely.""" if self._closed: @@ -1014,7 +1065,7 @@ async def send(self, data): _read_loop(self._ssl.bio_read), self.peer_address ) - async def receive(self): + async def receive(self) -> bytes: # or str? """Fetch the next packet of data from this connection's peer, waiting if necessary. @@ -1040,7 +1091,7 @@ async def receive(self): if cleartext: return cleartext - def set_ciphertext_mtu(self, new_mtu): + def set_ciphertext_mtu(self, new_mtu: int) -> None: """Tells Trio the `largest amount of data that can be sent in a single packet to this peer `__. @@ -1075,7 +1126,7 @@ def set_ciphertext_mtu(self, new_mtu): self._handshake_mtu = new_mtu self._ssl.set_ciphertext_mtu(new_mtu) - def get_cleartext_mtu(self): + def get_cleartext_mtu(self) -> int: """Returns the largest number of bytes that you can pass in a single call to `send` while still fitting within the network-level MTU. @@ -1084,9 +1135,9 @@ def get_cleartext_mtu(self): """ if not self._did_handshake: raise trio.NeedHandshakeError - return self._ssl.get_cleartext_mtu() + return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return] - def statistics(self): + def statistics(self) -> DTLSChannelStatistics: """Returns an object with statistics about this connection. Currently this has only one attribute: @@ -1126,18 +1177,18 @@ class DTLSEndpoint(metaclass=Final): """ - def __init__(self, socket, *, incoming_packets_buffer=10): + def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL from OpenSSL import SSL - # TODO: create a `self._initialized` for `__del__`, so self.socket can be typed - # as trio.socket.SocketType and `is not None` checks can be removed. - self.socket = None # for __del__, in case the next line raises + # for __del__, in case the next line raises + self._initialized: bool = False if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") - self.socket = socket + self._initialized = True + self.socket: _SocketType = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() @@ -1146,15 +1197,15 @@ def __init__(self, socket, *, incoming_packets_buffer=10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams = weakref.WeakValueDictionary() - self._listening_context = None - self._listening_key = None - self._incoming_connections_q = _Queue(float("inf")) + self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self._listening_context: Context | None = None + self._listening_key: bytes | None = None + self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) self._send_lock = trio.Lock() self._closed = False self._receive_loop_spawned = False - def _ensure_receive_loop(self): + def _ensure_receive_loop(self) -> None: # We have to spawn this lazily, because on Windows it will immediately error out # if the socket isn't already bound -- which for clients might not happen until # after we send our first packet. @@ -1164,9 +1215,9 @@ def _ensure_receive_loop(self): ) self._receive_loop_spawned = True - def __del__(self): + def __del__(self) -> None: # Do nothing if this object was never fully constructed - if self.socket is None: + if not self._initialized: return # Close the socket in Trio context (if our Trio context still exists), so that # the background task gets notified about the closure and can exit. @@ -1186,17 +1237,13 @@ def close(self) -> None: This object can also be used as a context manager. """ - # Do nothing if this object was never fully constructed - if self.socket is None: # pragma: no cover - return - self._closed = True self.socket.close() for stream in list(self._streams.values()): stream.close() self._incoming_connections_q.s.close() - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__( @@ -1207,13 +1254,17 @@ def __exit__( ) -> None: return self.close() - def _check_closed(self): + def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError async def serve( - self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED - ): + self, + ssl_context: Context, + async_fn: Callable[[DTLSChannel], Awaitable], + *args: Any, + task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, + ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. @@ -1257,7 +1308,7 @@ async def handler(dtls_channel): self._listening_context = ssl_context task_status.started() - async def handler_wrapper(stream): + async def handler_wrapper(stream: DTLSChannel) -> None: with stream: await async_fn(stream, *args) @@ -1267,7 +1318,7 @@ async def handler_wrapper(stream): finally: self._listening_context = None - def connect(self, address, ssl_context): + def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel: """Initiate an outgoing DTLS connection. Notice that this is a synchronous method. That's because it doesn't actually diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 9d7d7aa912..cf1e7eccfb 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -45,9 +45,9 @@ } ], "otherSymbolCounts": { - "withAmbiguousType": 8, - "withKnownType": 433, - "withUnknownType": 135 + "withAmbiguousType": 6, + "withKnownType": 460, + "withUnknownType": 129 }, "packageName": "trio", "symbols": [ @@ -73,6 +73,8 @@ "trio._core._mock_clock.MockClock.jump", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", + "trio._core._run._TaskStatus.__repr__", + "trio._core._run._TaskStatus.started", "trio._core._unbounded_queue.UnboundedQueue.__aiter__", "trio._core._unbounded_queue.UnboundedQueue.__anext__", "trio._core._unbounded_queue.UnboundedQueue.__repr__", @@ -81,22 +83,9 @@ "trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait", "trio._core._unbounded_queue.UnboundedQueue.qsize", "trio._core._unbounded_queue.UnboundedQueue.statistics", - "trio._dtls.DTLSChannel.__enter__", "trio._dtls.DTLSChannel.__init__", - "trio._dtls.DTLSChannel.aclose", - "trio._dtls.DTLSChannel.do_handshake", - "trio._dtls.DTLSChannel.get_cleartext_mtu", - "trio._dtls.DTLSChannel.receive", - "trio._dtls.DTLSChannel.send", - "trio._dtls.DTLSChannel.set_ciphertext_mtu", - "trio._dtls.DTLSChannel.statistics", - "trio._dtls.DTLSEndpoint.__del__", - "trio._dtls.DTLSEndpoint.__enter__", "trio._dtls.DTLSEndpoint.__init__", - "trio._dtls.DTLSEndpoint.connect", - "trio._dtls.DTLSEndpoint.incoming_packets_buffer", "trio._dtls.DTLSEndpoint.serve", - "trio._dtls.DTLSEndpoint.socket", "trio._highlevel_socket.SocketListener", "trio._highlevel_socket.SocketListener.__init__", "trio._highlevel_socket.SocketStream.__init__", From 2522bc529ca95602da1cb9b8aa1862dab3de69b2 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 13 Jul 2023 16:38:09 +0200 Subject: [PATCH 04/49] incorporate _abc --- trio/_abc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index 0bb49e207d..402bb78b27 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,19 +1,19 @@ from __future__ import annotations +import socket from abc import ABCMeta, abstractmethod from typing import TYPE_CHECKING, Generic, TypeVar import trio if TYPE_CHECKING: - import socket from types import TracebackType from typing_extensions import Self - from trio.lowlevel import Task - + # both of these introduce circular imports if outside a TYPE_CHECKING guard from ._socket import _SocketType + from .lowlevel import Task # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a From 909ac67fba675f036fcea18b9ba7b1eacc4d4ccb Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 13 Jul 2023 16:42:30 +0200 Subject: [PATCH 05/49] incorporate _dtls --- pyproject.toml | 9 +++- trio/_channel.py | 10 ++--- trio/_dtls.py | 108 +++++++++++++++++++++++++++++------------------ 3 files changed, 79 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 121a398234..fa46d76bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ warn_return_any = true # Avoid subtle backsliding #disallow_any_decorated = true -#disallow_incomplete_defs = true +disallow_incomplete_defs = true #disallow_subclassing_any = true # Enable gradually / for new modules @@ -48,6 +48,13 @@ disallow_untyped_defs = false module = "trio._core._run" disallow_incomplete_defs = false +[[tool.mypy.overrides]] +module = [ + "trio._abc", + "trio._dtls" +] +disallow_untyped_defs = true + [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] faulthandler_timeout = 60 diff --git a/trio/_channel.py b/trio/_channel.py index 7c8ff4660d..df596adddd 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -20,7 +20,7 @@ def _open_memory_channel( - max_buffer_size: int, + max_buffer_size: int | float, ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: """Open a channel for passing objects between tasks within a process. @@ -92,11 +92,11 @@ def _open_memory_channel( # Need to use Tuple instead of tuple due to CI check running on 3.8 class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]): def __new__( # type: ignore[misc] # "must return a subtype" - cls, max_buffer_size: int + cls, max_buffer_size: int | float ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: return _open_memory_channel(max_buffer_size) - def __init__(self, max_buffer_size: int): + def __init__(self, max_buffer_size: int | float): ... else: @@ -108,7 +108,7 @@ def __init__(self, max_buffer_size: int): @attr.s(frozen=True, slots=True) class MemoryChannelStats: current_buffer_used: int = attr.ib() - max_buffer_size: int = attr.ib() + max_buffer_size: int | float = attr.ib() open_send_channels: int = attr.ib() open_receive_channels: int = attr.ib() tasks_waiting_send: int = attr.ib() @@ -117,7 +117,7 @@ class MemoryChannelStats: @attr.s(slots=True) class MemoryChannelState(Generic[T]): - max_buffer_size: int = attr.ib() + max_buffer_size: int | float = attr.ib() data: deque[T] = attr.ib(factory=deque) # Counts of open endpoints using this state open_send_channels: int = attr.ib(default=0) diff --git a/trio/_dtls.py b/trio/_dtls.py index aea15be735..8aede92bc8 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -16,23 +16,34 @@ import warnings import weakref from itertools import count -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Iterator, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Iterable, + Iterator, + TypeVar, + Union, + cast, +) import attr from OpenSSL import SSL import trio -from trio._util import Final, NoPublicConstructor + +from ._util import Final, NoPublicConstructor if TYPE_CHECKING: from types import TracebackType from OpenSSL.SSL import Context - from typing_extensions import Self - - from trio._socket import _SocketType + from typing_extensions import Self, TypeAlias from ._core._run import _TaskStatus + from ._socket import _SocketType MAX_UDP_PACKET_SIZE = 65527 @@ -339,16 +350,20 @@ class OpaqueHandshakeMessage: record: Record +# for some reason doesn't work with | +_AnyHandshakeMessage: TypeAlias = Union[ + HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage +] + + # This takes a raw outgoing handshake volley that openssl generated, and # reconstructs the handshake messages inside it, so that we can repack them # into records while retransmitting. So the data ought to be well-behaved -- # it's not coming from the network. def decode_volley_trusted( volley: bytes, -) -> list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage]: - messages: list[ - HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage - ] = [] +) -> list[_AnyHandshakeMessage]: + messages: list[_AnyHandshakeMessage] = [] messages_by_seq = {} for record in records_untrusted(volley): # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. @@ -392,7 +407,7 @@ def decode_volley_trusted( class RecordEncoder: - def __init__(self): + def __init__(self) -> None: self._record_seq = count() def set_first_record_number(self, n: int) -> None: @@ -400,9 +415,7 @@ def set_first_record_number(self, n: int) -> None: def encode_volley( self, - messages: Iterable[ - HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage - ], + messages: Iterable[_AnyHandshakeMessage], mtu: int, ) -> list[bytearray]: packets = [] @@ -550,7 +563,9 @@ def _signable(*fields: bytes) -> bytes: return b"".join(out) -def _make_cookie(key, salt, tick, address, client_hello_bits): +def _make_cookie( + key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes +) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -566,7 +581,9 @@ def _make_cookie(key, salt, tick, address, client_hello_bits): return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] -def valid_cookie(key, cookie, address, client_hello_bits): +def valid_cookie( + key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes +) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -586,7 +603,9 @@ def valid_cookie(key, cookie, address, client_hello_bits): return False -def challenge_for(key, address, epoch_seqno, client_hello_bits): +def challenge_for( + key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes +) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() cookie = _make_cookie(key, salt, tick, address, client_hello_bits) @@ -626,9 +645,12 @@ def challenge_for(key, address, epoch_seqno, client_hello_bits): return packet -class _Queue: - def __init__(self, incoming_packets_buffer): - self.s, self.r = trio.open_memory_channel(incoming_packets_buffer) +T = TypeVar("T") + + +class _Queue(Generic[T]): + def __init__(self, incoming_packets_buffer: int | float): + self.s, self.r = trio.open_memory_channel[T](incoming_packets_buffer) def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: @@ -642,7 +664,9 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: return b"".join(chunks) -async def handle_client_hello_untrusted(endpoint, address, packet): +async def handle_client_hello_untrusted( + endpoint: DTLSEndpoint, address: Any, packet: bytes +) -> None: if endpoint._listening_context is None: return @@ -715,7 +739,9 @@ async def handle_client_hello_untrusted(endpoint, address, packet): endpoint._incoming_connections_q.s.send_nowait(stream) -async def dtls_receive_loop(endpoint_ref, sock): +async def dtls_receive_loop( + endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType +) -> None: try: while True: try: @@ -791,7 +817,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint: DTLSEndpoint, peer_address: str, ctx: Context): + def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -804,23 +830,25 @@ def __init__(self, endpoint: DTLSEndpoint, peer_address: str, ctx: Context): # to just performing a new handshake. ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined] self._ssl = SSL.Connection(ctx) - self._handshake_mtu: int | None = None + self._handshake_mtu = 0 # This calls self._ssl.set_ciphertext_mtu, which is important, because if you # don't call it then openssl doesn't work. self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) self._replaced = False self._closed = False - self._q = _Queue(endpoint.incoming_packets_buffer) + self._q = _Queue[bytes](endpoint.incoming_packets_buffer) self._handshake_lock = trio.Lock() - self._record_encoder = RecordEncoder() + self._record_encoder: RecordEncoder = RecordEncoder() + + self._final_volley: list[_AnyHandshakeMessage] = [] - def _set_replaced(self): + def _set_replaced(self) -> None: self._replaced = True # Any packets we already received could maybe possibly still be processed, but # there are no more coming. So we close this on the sender side. self._q.s.close() - def _check_replaced(self): + def _check_replaced(self) -> None: if self._replaced: raise trio.BrokenResourceError( "peer tore down this connection to start a new one" @@ -875,7 +903,7 @@ async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() - async def _send_volley(self, volley_messages): + async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None: packets = self._record_encoder.encode_volley( volley_messages, self._handshake_mtu ) @@ -883,7 +911,7 @@ async def _send_volley(self, volley_messages): async with self.endpoint._send_lock: await self.endpoint.socket.sendto(packet, self.peer_address) - async def _resend_final_volley(self): + async def _resend_final_volley(self) -> None: await self._send_volley(self._final_volley) async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None: @@ -919,14 +947,10 @@ async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None return timeout = initial_retransmit_timeout - volley_messages: list[ - HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage - ] = [] + volley_messages: list[_AnyHandshakeMessage] = [] volley_failed_sends = 0 - def read_volley() -> ( - list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage] - ): + def read_volley() -> list[_AnyHandshakeMessage]: volley_bytes = _read_loop(self._ssl.bio_read) new_volley_messages = decode_volley_trusted(volley_bytes) if ( @@ -1019,7 +1043,7 @@ def read_volley() -> ( # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. self._handshake_mtu = min( - self._handshake_mtu or 0, + self._handshake_mtu, worst_case_mtu(self.endpoint.socket), ) @@ -1175,13 +1199,13 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # {remote address: DTLSChannel} self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary() self._listening_context: Context | None = None - self._listening_key = None - self._incoming_connections_q = _Queue(float("inf")) + self._listening_key: bytes | None = None + self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) self._send_lock = trio.Lock() self._closed = False self._receive_loop_spawned = False - def _ensure_receive_loop(self): + def _ensure_receive_loop(self) -> None: # We have to spawn this lazily, because on Windows it will immediately error out # if the socket isn't already bound -- which for clients might not happen until # after we send our first packet. @@ -1237,9 +1261,9 @@ def _check_closed(self) -> None: async def serve( self, ssl_context: Context, - async_fn: Callable[..., Awaitable], + async_fn: Callable[[DTLSChannel], Awaitable], *args: Any, - task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ??? + task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. @@ -1284,7 +1308,7 @@ async def handler(dtls_channel): self._listening_context = ssl_context task_status.started() - async def handler_wrapper(stream): + async def handler_wrapper(stream: DTLSChannel) -> None: with stream: await async_fn(stream, *args) From 40ff89f65c22397c4ecae52be49fe5b7d5974b34 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 13 Jul 2023 16:51:33 +0200 Subject: [PATCH 06/49] ignore weird error with TASK_STATUS_IGNORED, add pyOpenSSL to docs-requirements --- docs-requirements.in | 3 +++ docs-requirements.txt | 8 ++++++++ trio/_dtls.py | 15 ++++++++++++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs-requirements.in b/docs-requirements.in index 98d5030bc5..d6214ec1d0 100644 --- a/docs-requirements.in +++ b/docs-requirements.in @@ -19,3 +19,6 @@ exceptiongroup >= 1.0.0rc9 # See note in test-requirements.in immutables >= 0.6 + +# types used in annotations +pyOpenSSL diff --git a/docs-requirements.txt b/docs-requirements.txt index 06136fd765..c607f4f186 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -16,6 +16,8 @@ babel==2.12.1 # via sphinx certifi==2023.5.7 # via requests +cffi==1.15.1 + # via cryptography charset-normalizer==3.1.0 # via requests click==8.1.3 @@ -24,6 +26,8 @@ click==8.1.3 # towncrier click-default-group==1.2.2 # via towncrier +cryptography==41.0.2 + # via pyopenssl docutils==0.18.1 # via # sphinx @@ -55,8 +59,12 @@ outcome==1.2.0 # via -r docs-requirements.in packaging==23.1 # via sphinx +pycparser==2.21 + # via cffi pygments==2.15.1 # via sphinx +pyopenssl==23.2.0 + # via -r docs-requirements.in pytz==2023.3 # via babel requests==2.31.0 diff --git a/trio/_dtls.py b/trio/_dtls.py index acc5f950eb..0af8340732 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -798,6 +798,19 @@ async def dtls_receive_loop( @attr.frozen class DTLSChannelStatistics: + """An object with statistics about this connection. + + Currently this has only one attribute: + + - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of + incoming packets from this peer that Trio successfully received from the + network, but then got dropped because the internal channel buffer was full. If + this is non-zero, then you might want to call ``receive`` more often, or use a + larger ``incoming_packets_buffer``, or just not worry about it because your + UDP-based protocol should be able to handle the occasional lost packet, right? + + """ + incoming_packets_dropped_in_trio: int @@ -1263,7 +1276,7 @@ async def serve( ssl_context: Context, async_fn: Callable[[DTLSChannel], Awaitable], *args: Any, - task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, + task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. From f18abdd9c92759528fcb1ae746ba16034c926c2d Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 14 Jul 2023 16:35:03 +0200 Subject: [PATCH 07/49] socket is done - other than getting rid of the _SocketType <-> SocketType distinction --- pyproject.toml | 51 ++++- trio/_core/_io_epoll.py | 4 +- trio/_core/_local.py | 24 ++- trio/_core/_run.py | 2 +- trio/_core/_thread_cache.py | 2 + trio/_socket.py | 299 ++++++++++++++++++++------- trio/_subprocess_platform/windows.py | 2 +- trio/_tests/verify_types.json | 27 +-- 8 files changed, 307 insertions(+), 104 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fa46d76bc7..aa1644b443 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ disallow_incomplete_defs = true # Enable gradually / for new modules check_untyped_defs = false disallow_untyped_calls = false -disallow_untyped_defs = false +disallow_untyped_defs = true # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. @@ -47,14 +47,61 @@ disallow_untyped_defs = false [[tool.mypy.overrides]] module = "trio._core._run" disallow_incomplete_defs = false +disallow_untyped_defs = false [[tool.mypy.overrides]] module = [ "trio._abc", - "trio._dtls" + "trio._dtls", + "trio._socket", ] disallow_untyped_defs = true +[[tool.mypy.overrides]] +module = [ +"trio/_core/_asyncgens", +"trio/_core/_entry_queue", +"trio/_core/_generated_io_epoll", +"trio/_core/_generated_io_kqueue", +"trio/_core/_generated_io_windows", +"trio/_core/_generated_run", +"trio/_core/_io_common", +"trio/_core/_io_epoll", +"trio/_core/_io_kqueue", +"trio/_core/_io_windows", +"trio/_core/_ki", +"trio/_core/_multierror", +"trio/_core/_parking_lot", +"trio/_core/_thread_cache", +"trio/_core/_traps", +"trio/_core/_wakeup_socketpair", +"trio/_core/_windows_cffi", +"trio/_deprecate", +"trio/_file_io", +"trio/_highlevel_open_tcp_listeners", +"trio/_highlevel_open_tcp_stream", +"trio/_highlevel_open_unix_stream", +"trio/_highlevel_serve_listeners", +"trio/_highlevel_socket", +"trio/_highlevel_ssl_helpers", +"trio/_path", +"trio/_signals", +"trio/_ssl", +"trio/_subprocess", +"trio/_subprocess_platform/kqueue", +"trio/_subprocess_platform/waitid", +"trio/_sync", +"trio/_threads", +"trio/_util", +"trio/_wait_for_object", +"trio/testing/_check_streams", +"trio/testing/_checkpoints", +"trio/testing/_memory_streams", +"trio/testing/_network", +"trio/testing/_trio_test", +] +disallow_untyped_defs = false + [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] faulthandler_timeout = 60 diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index fbeb454c7d..9d7b250785 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -3,7 +3,7 @@ import select import sys from collections import defaultdict -from typing import TYPE_CHECKING, DefaultDict, Dict +from typing import TYPE_CHECKING, DefaultDict import attr @@ -192,7 +192,7 @@ class EpollIOManager: _epoll: select.epoll = attr.ib(factory=select.epoll) # {fd: EpollWaiters} _registered: DefaultDict[int, EpollWaiters] = attr.ib( - factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters] + factory=lambda: defaultdict(EpollWaiters) ) _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) _force_wakeup_fd: int | None = attr.ib(default=None) diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 89ccf93e95..fe509ca7ad 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, TypeVar +from typing import Generic, TypeVar, overload # Runvar implementations import attr @@ -12,12 +12,16 @@ C = TypeVar("C", bound="_RunVarToken") +class NoValue(object): + ... + + @attr.s(eq=False, hash=False, slots=True) class _RunVarToken(Generic[T]): - _no_value = None + _no_value = NoValue() _var: RunVar[T] = attr.ib() - previous_value: T | None = attr.ib(default=_no_value) + previous_value: T | NoValue = attr.ib(default=_no_value) redeemed: bool = attr.ib(default=False, init=False) @classmethod @@ -35,11 +39,19 @@ class RunVar(Generic[T], metaclass=Final): """ - _NO_DEFAULT = None + _NO_DEFAULT = NoValue() _name: str = attr.ib() - _default: T | None = attr.ib(default=_NO_DEFAULT) + _default: T | NoValue = attr.ib(default=_NO_DEFAULT) + + @overload + def get(self, default: T) -> T: + ... + + @overload + def get(self, default: NoValue = _NO_DEFAULT) -> T | NoValue: + ... - def get(self, default: T | None = _NO_DEFAULT) -> T | None: + def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue: """Gets the value of this :class:`RunVar` for the current run call.""" try: # not typed yet diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 804a958714..ecc9138b23 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1392,7 +1392,7 @@ class GuestState: done_callback: Callable = attr.ib() unrolled_run_gen = attr.ib() _value_factory: Callable[[], Value] = lambda: Value(None) - unrolled_run_next_send: Outcome = attr.ib(factory=_value_factory, type=Outcome) + unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) def guest_tick(self): try: diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py index 3e27ce6a32..157f14c5a1 100644 --- a/trio/_core/_thread_cache.py +++ b/trio/_core/_thread_cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ctypes import ctypes.util import sys diff --git a/trio/_socket.py b/trio/_socket.py index 0f5aa75fd2..d492bbc41f 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -5,9 +5,22 @@ import socket as _stdlib_socket import sys from functools import wraps as _wraps -from typing import TYPE_CHECKING, Any, Awaitable, Callable, SupportsIndex +from socket import AddressFamily, SocketKind +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + NoReturn, + SupportsIndex, + Tuple, + TypeVar, + Union, + overload, +) import idna as _idna +from typing_extensions import Concatenate, ParamSpec import trio @@ -17,11 +30,20 @@ from collections.abc import Iterable from types import TracebackType - from typing_extensions import Self + from typing_extensions import Buffer, Self, TypeAlias from ._abc import HostnameResolver, SocketFactory +T = TypeVar("T") +P = ParamSpec("P") + +# must use old-style typing for TypeAlias +Address: TypeAlias = Union[ + str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int] +] + + # Usage: # # async with _try_sync(): @@ -31,16 +53,18 @@ # return await do_it_properly_with_a_check_point() # class _try_sync: - def __init__(self, blocking_exc_override=None): + def __init__( + self, blocking_exc_override: Callable[[BaseException], bool] | None = None + ): self._blocking_exc_override = blocking_exc_override - def _is_blocking_io_error(self, exc): + def _is_blocking_io_error(self, exc: BaseException) -> bool: if self._blocking_exc_override is None: return isinstance(exc, BlockingIOError) else: return self._blocking_exc_override(exc) - async def __aenter__(self): + async def __aenter__(self) -> None: await trio.lowlevel.checkpoint_if_cancelled() async def __aexit__( @@ -76,7 +100,7 @@ async def __aexit__( ################################################################ _resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver") -_socket_factory: _core.RunVar[SocketFactory] = _core.RunVar("socket_factory") +_socket_factory: _core.RunVar[SocketFactory | None] = _core.RunVar("socket_factory") def set_custom_hostname_resolver( @@ -113,7 +137,9 @@ def set_custom_hostname_resolver( return old -def set_custom_socket_factory(socket_factory): +def set_custom_socket_factory( + socket_factory: SocketFactory | None, +) -> SocketFactory | None: """Set a custom socket object factory. This function allows you to replace Trio's normal socket class with a @@ -147,6 +173,7 @@ def set_custom_socket_factory(socket_factory): _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV +# It would be possible to @overload the return value depending on Literal[AddressFamily.INET/6], but should probably be added in typeshed first async def getaddrinfo( host: bytes | str | None, port: bytes | str | int | None, @@ -156,8 +183,8 @@ async def getaddrinfo( flags: int = 0, ) -> list[ tuple[ - _stdlib_socket.AddressFamily, - _stdlib_socket.SocketKind, + AddressFamily, + SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int], @@ -183,7 +210,7 @@ async def getaddrinfo( # skip the whole thread thing, which seems worthwhile. So we try first # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that # fails with EAI_NONAME: - def numeric_only_failure(exc): + def numeric_only_failure(exc: BaseException) -> bool: return ( isinstance(exc, _stdlib_socket.gaierror) and exc.errno == _stdlib_socket.EAI_NONAME @@ -246,7 +273,7 @@ async def getnameinfo( ) -async def getprotobyname(name): +async def getprotobyname(name: str) -> int: """Look up a protocol number by name. (Rarely used.) Like :func:`socket.getprotobyname`, but async. @@ -276,12 +303,12 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType: @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) def fromfd( fd: SupportsIndex, - family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET, - type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM, + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, ) -> _SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" - family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd) + family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, int(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -290,24 +317,38 @@ def fromfd( ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(*args, **kwargs): - return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) + def fromshare(info: bytes) -> _SocketType: + return from_stdlib_socket(_stdlib_socket.fromshare(info)) + + +if sys.platform == "win32": + FamilyT = int + TypeT = int + FamilyDefault = _stdlib_socket.AF_INET +else: + FamilyDefault = None + FamilyT = Union[int, AddressFamily, None] + TypeT = Union[_stdlib_socket.socket, int] @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair(*args, **kwargs): +def socketpair( + family: FamilyT = FamilyDefault, + type: TypeT = SocketKind.SOCK_STREAM, + proto: int = 0, +) -> tuple[_SocketType, _SocketType]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. """ - left, right = _stdlib_socket.socketpair(*args, **kwargs) + left, right = _stdlib_socket.socketpair(family, type, proto) return (from_stdlib_socket(left), from_stdlib_socket(right)) @_wraps(_stdlib_socket.socket, assigned=(), updated=()) def socket( - family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET, - type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM, + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, fileno: int | None = None, ) -> _SocketType: @@ -327,14 +368,24 @@ def socket( return from_stdlib_socket(stdlib_socket) -def _sniff_sockopts_for_fileno(family, type, proto, fileno): +def _sniff_sockopts_for_fileno( + family: AddressFamily | int, + type: SocketKind | int, + proto: int, + fileno: int | None, +) -> tuple[AddressFamily | int, SocketKind | int, int]: """Correct SOCKOPTS for given fileno, falling back to provided values.""" # Wrap the raw fileno into a Python socket object # This object might have the wrong metadata, but it lets us easily call getsockopt # and then we'll throw it away and construct a new one with the correct metadata. if sys.platform != "linux": return family, type, proto - from socket import SO_DOMAIN, SO_PROTOCOL, SO_TYPE, SOL_SOCKET + from socket import ( # type: ignore[attr-defined] + SO_DOMAIN, + SO_PROTOCOL, + SO_TYPE, + SOL_SOCKET, + ) sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno) try: @@ -364,19 +415,21 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): ) -def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False): - fn = getattr(_stdlib_socket.socket, methname) - +def _make_simple_sock_method_wrapper( + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + wait_fn: Callable, + maybe_avail: bool = False, +) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: @_wraps(fn, assigned=("__name__",), updated=()) - async def wrapper(self, *args, **kwargs): - return await self._nonblocking_helper(fn, args, kwargs, wait_fn) + async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: + return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs) - wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async. + wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async. """ if maybe_avail: wrapper.__doc__ += ( - f"Only available on platforms where :meth:`socket.socket.{methname}` is " + f"Only available on platforms where :meth:`socket.socket.{fn.__name__}` is " "available." ) return wrapper @@ -395,8 +448,21 @@ async def wrapper(self, *args, **kwargs): # local=False means that the address is being used with connect() or sendto() or # similar. # + + +# Using a TypeVar to indicate we return the same type of address appears to give errors +# when passed a union of address types. +# @overload likely works, but is extremely verbose. # NOTE: this function does not always checkpoint -async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local): +async def _resolve_address_nocp( + type: int, + family: AddressFamily, + proto: int, + *, + ipv6_v6only: bool | int, + address: Address, + local: bool, +) -> Address: # Do some pre-checking (or exit early for non-IP sockets) if family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -406,13 +472,15 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo raise ValueError( "address should be a (host, port, [flowinfo, [scopeid]]) tuple" ) - elif family == _stdlib_socket.AF_UNIX: + elif family == getattr(_stdlib_socket, "AF_UNIX"): # unwrap path-likes + assert isinstance(address, (str, bytes)) return os.fspath(address) else: return address # -- From here on we know we have IPv4 or IPV6 -- + host: str | None host, port, *_ = address # Fast path for the simple case: already-resolved IP address, # already-resolved port. This is particularly important for UDP, since @@ -450,20 +518,20 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo # The above ignored any flowid and scopeid in the passed-in address, # so restore them if present: if family == _stdlib_socket.AF_INET6: - normed = list(normed) + list_normed = list(normed) assert len(normed) == 4 if len(address) >= 3: - normed[2] = address[2] + list_normed[2] = address[2] # type: ignore if len(address) >= 4: - normed[3] = address[3] - normed = tuple(normed) + list_normed[3] = address[3] # type: ignore + return tuple(list_normed) # type: ignore return normed # TODO: stopping users from initializing this type should be done in a different way, # so SocketType can be used as a type. class SocketType: - def __init__(self): + def __init__(self) -> NoReturn: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" @@ -529,11 +597,11 @@ def __exit__( return self._sock.__exit__(exc_type, exc_value, traceback) @property - def family(self) -> _stdlib_socket.AddressFamily: + def family(self) -> AddressFamily: return self._sock.family @property - def type(self) -> _stdlib_socket.SocketKind: + def type(self) -> SocketKind: return self._sock.type @property @@ -556,7 +624,7 @@ def close(self) -> None: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address: tuple[object, ...] | str | bytes) -> None: + async def bind(self, address: Address) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -593,7 +661,12 @@ def is_readable(self) -> bool: async def wait_writable(self) -> None: await _core.wait_writable(self._sock) - async def _resolve_address_nocp(self, address, *, local): + async def _resolve_address_nocp( + self, + address: Address, + *, + local: bool, + ) -> Address: if self.family == _stdlib_socket.AF_INET6: ipv6_v6only = self._sock.getsockopt( IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY @@ -609,7 +682,19 @@ async def _resolve_address_nocp(self, address, *, local): local=local, ) - async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): + # args and kwargs must be starred, otherwise pyright complains: + # '"args" member of ParamSpec is valid only when used with *args parameter' + # '"kwargs" member of ParamSpec is valid only when used with **kwargs parameter' + # wait_fn and fn must also be first in the signature + # 'Keyword parameter cannot appear in signature after ParamSpec args parameter' + + async def _nonblocking_helper( + self, + wait_fn: Callable[[_stdlib_socket.socket], Awaitable], + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: # We have to reconcile two conflicting goals: # - We want to make it look like we always blocked in doing these # operations. The obvious way is to always do an IO wait before @@ -645,9 +730,9 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # accept ################################################################ - _accept: Callable[ - [], Awaitable[tuple[_stdlib_socket.socket, object]] - ] = _make_simple_sock_method_wrapper("accept", _core.wait_readable) + _accept = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.accept, _core.wait_readable + ) async def accept(self) -> tuple[_SocketType, object]: """Like :meth:`socket.socket.accept`, but async.""" @@ -658,8 +743,7 @@ async def accept(self) -> tuple[_SocketType, object]: # connect ################################################################ - # TODO: typing addresses is ... a pain - async def connect(self, address: str) -> None: + async def connect(self, address: Address) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -727,38 +811,69 @@ async def connect(self, address: str) -> None: # Okay, the connect finished, but it might have failed: err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR) if err != 0: - raise OSError(err, f"Error connecting to {address}: {os.strerror(err)}") + raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}") ################################################################ # recv ################################################################ + # Not possible to typecheck with a Callable (due to DefaultArg), nor with a + # callback Protocol (https://github.com/python/typing/discussions/1040) + # but this seems to work. If not explicitly defined then pyright --verifytypes will + # complain about AmbiguousType if TYPE_CHECKING: - async def recv(self, buffersize: int, flags: int = 0) -> bytes: + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: ... - else: - recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) + # _make_simple_sock_method_wrapper is typed, so this check that the above is correct + recv = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv, _core.wait_readable + ) ################################################################ # recv_into ################################################################ - recv_into = _make_simple_sock_method_wrapper("recv_into", _core.wait_readable) + if TYPE_CHECKING: + + def recv_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + ... + + recv_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv_into, _core.wait_readable + ) ################################################################ # recvfrom ################################################################ - recvfrom = _make_simple_sock_method_wrapper("recvfrom", _core.wait_readable) + if TYPE_CHECKING: + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: + ... + + recvfrom = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom, _core.wait_readable + ) ################################################################ # recvfrom_into ################################################################ - recvfrom_into = _make_simple_sock_method_wrapper( - "recvfrom_into", _core.wait_readable + if TYPE_CHECKING: + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, Address]]: + ... + + recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom_into, _core.wait_readable ) ################################################################ @@ -766,8 +881,15 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes: ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg"): - recvmsg = _make_simple_sock_method_wrapper( - "recvmsg", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg( + __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True ) ################################################################ @@ -775,29 +897,58 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes: ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg_into"): - recvmsg_into = _make_simple_sock_method_wrapper( - "recvmsg_into", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg_into( + __self, + __buffers: Iterable[Buffer], + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True ) ################################################################ # send ################################################################ - send = _make_simple_sock_method_wrapper("send", _core.wait_writable) + if TYPE_CHECKING: + + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + ... + + send = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.send, _core.wait_writable + ) ################################################################ # sendto ################################################################ + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( + self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) - async def sendto(self, *args): + async def sendto(self, *args: Any) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address) # and kwargs are not accepted - args = list(args) - args[-1] = await self._resolve_address_nocp(args[-1], local=False) + args_list = list(args) + args_list[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendto, args, {}, _core.wait_writable + _core.wait_writable, _stdlib_socket.socket.sendto, *args_list ) ################################################################ @@ -809,20 +960,28 @@ async def sendto(self, *args): ): @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg(self, *args): + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: Address | None = None, + ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. Only available on platforms where :meth:`socket.socket.sendmsg` is available. """ - # args is: buffers[, ancdata[, flags[, address]]] - # and kwargs are not accepted - if len(args) == 4 and args[-1] is not None: - args = list(args) - args[-1] = await self._resolve_address_nocp(args[-1], local=False) + if __address is not None: + __address = await self._resolve_address_nocp(__address, local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable + _core.wait_writable, + _stdlib_socket.socket.sendmsg, + __buffers, + __ancdata, + __flags, + __address, ) ################################################################ diff --git a/trio/_subprocess_platform/windows.py b/trio/_subprocess_platform/windows.py index 958be8675c..816da4b203 100644 --- a/trio/_subprocess_platform/windows.py +++ b/trio/_subprocess_platform/windows.py @@ -3,4 +3,4 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: - await WaitForSingleObject(int(process._proc._handle)) + await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined] diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 00a964787a..5d19fe6729 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9020866773675762, + "completenessScore": 0.9149277688603531, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 562, - "withUnknownType": 61 + "withKnownType": 570, + "withUnknownType": 53 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,19 +46,14 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 523, - "withUnknownType": 89 + "withKnownType": 546, + "withUnknownType": 67 }, "packageName": "trio", "symbols": [ - "trio._abc.SocketFactory.socket", "trio._core._entry_queue.TrioToken.run_sync_soon", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", - "trio._dtls.DTLSChannel.__init__", - "trio._dtls.DTLSEndpoint.__init__", - "trio._highlevel_socket.SocketListener.__init__", - "trio._highlevel_socket.SocketStream.__init__", "trio._highlevel_socket.SocketStream.getsockopt", "trio._highlevel_socket.SocketStream.send_all", "trio._highlevel_socket.SocketStream.setsockopt", @@ -76,14 +71,6 @@ "trio._path.Path.__rtruediv__", "trio._path.Path.__truediv__", "trio._path.Path.open", - "trio._socket._SocketType.recv_into", - "trio._socket._SocketType.recvfrom", - "trio._socket._SocketType.recvfrom_into", - "trio._socket._SocketType.recvmsg", - "trio._socket._SocketType.recvmsg_into", - "trio._socket._SocketType.send", - "trio._socket._SocketType.sendmsg", - "trio._socket._SocketType.sendto", "trio._ssl.SSLListener.__init__", "trio._ssl.SSLListener.accept", "trio._ssl.SSLListener.aclose", @@ -141,10 +128,6 @@ "trio.serve_listeners", "trio.serve_ssl_over_tcp", "trio.serve_tcp", - "trio.socket.from_stdlib_socket", - "trio.socket.getprotobyname", - "trio.socket.set_custom_socket_factory", - "trio.socket.socketpair", "trio.testing._memory_streams.MemoryReceiveStream.__init__", "trio.testing._memory_streams.MemoryReceiveStream.aclose", "trio.testing._memory_streams.MemoryReceiveStream.close", From 128a8d26945ee24ac6780b336c13a06c7846991e Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 20 Jul 2023 13:48:44 +0200 Subject: [PATCH 08/49] fix RTD, and export DTLSChannelStatistics and TaskStatus --- docs/source/conf.py | 4 ++++ docs/source/reference-core.rst | 1 + docs/source/reference-io.rst | 2 ++ trio/__init__.py | 7 ++++++- trio/_core/__init__.py | 1 + trio/_core/_run.py | 8 ++++---- trio/_dtls.py | 4 ++-- trio/_tests/verify_types.json | 10 +++++----- 8 files changed, 25 insertions(+), 12 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 68a5a22a81..8efc702f02 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -63,6 +63,10 @@ ("py:obj", "trio._abc.T"), ("py:obj", "trio._abc.T_resource"), ("py:class", "types.FrameType"), + # TODO: figure out if you can link this to SSL + ("py:class", "Context"), + # TODO: temporary type + ("py:class", "_SocketType"), ] autodoc_inherit_docstrings = False default_role = "obj" diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index f571d23294..434b8c8b5b 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -922,6 +922,7 @@ The nursery API See :meth:`~Nursery.start`. +.. autoclass:: TaskStatus .. _task-local-storage: diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index a3291ef2ae..3ae66699e1 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -304,6 +304,8 @@ unfortunately that's not yet possible. .. automethod:: statistics +.. autoclass:: DTLSChannelStatistics + .. module:: trio.socket Low-level networking with :mod:`trio.socket` diff --git a/trio/__init__.py b/trio/__init__.py index 2b8810504b..ac0687f529 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -34,6 +34,7 @@ EndOfChannel as EndOfChannel, Nursery as Nursery, RunFinishedError as RunFinishedError, + TaskStatus as TaskStatus, TrioInternalError as TrioInternalError, WouldBlock as WouldBlock, current_effective_deadline as current_effective_deadline, @@ -46,7 +47,11 @@ NonBaseMultiError as _NonBaseMultiError, ) from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning -from ._dtls import DTLSChannel as DTLSChannel, DTLSEndpoint as DTLSEndpoint +from ._dtls import ( + DTLSChannel as DTLSChannel, + DTLSChannelStatistics as DTLSChannelStatistics, + DTLSEndpoint as DTLSEndpoint, +) from ._file_io import open_file as open_file, wrap_file as wrap_file from ._highlevel_generic import ( StapledStream as StapledStream, diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index abd58245e3..aa898fffe0 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -28,6 +28,7 @@ CancelScope, Nursery, Task, + TaskStatus, add_instrument, checkpoint, checkpoint_if_cancelled, diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 4f90889c5f..279061f872 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -783,7 +783,7 @@ def cancel_called(self) -> bool: # This code needs to be read alongside the code from Nursery.start to make # sense. @attr.s(eq=False, hash=False, repr=False) -class _TaskStatus: +class TaskStatus: _old_nursery = attr.ib() _new_nursery = attr.ib() _called_started = attr.ib(default=False) @@ -1137,16 +1137,16 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): try: self._pending_starts += 1 async with open_nursery() as old_nursery: - task_status = _TaskStatus(old_nursery, self) + task_status = TaskStatus(old_nursery, self) thunk = functools.partial(async_fn, task_status=task_status) task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( thunk, args, old_nursery, name ) task._eventual_parent_nursery = self - # Wait for either _TaskStatus.started or an exception to + # Wait for either TaskStatus.started or an exception to # cancel this nursery: # If we get here, then the child either got reparented or exited - # normally. The complicated logic is all in _TaskStatus.started(). + # normally. The complicated logic is all in TaskStatus.started(). # (Any exceptions propagate directly out of the above.) if not task_status._called_started: raise RuntimeError("child exited without calling task_status.started()") diff --git a/trio/_dtls.py b/trio/_dtls.py index 0af8340732..3b15c83b7e 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -42,7 +42,7 @@ from OpenSSL.SSL import Context from typing_extensions import Self, TypeAlias - from ._core._run import _TaskStatus + from ._core._run import TaskStatus from ._socket import _SocketType MAX_UDP_PACKET_SIZE = 65527 @@ -1276,7 +1276,7 @@ async def serve( ssl_context: Context, async_fn: Callable[[DTLSChannel], Awaitable], *args: Any, - task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] + task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index cf1e7eccfb..65bfd6a301 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8764044943820225, + "completenessScore": 0.8752, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 546, - "withUnknownType": 76 + "withKnownType": 547, + "withUnknownType": 77 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -73,8 +73,8 @@ "trio._core._mock_clock.MockClock.jump", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", - "trio._core._run._TaskStatus.__repr__", - "trio._core._run._TaskStatus.started", + "trio._core._run.TaskStatus.__repr__", + "trio._core._run.TaskStatus.started", "trio._core._unbounded_queue.UnboundedQueue.__aiter__", "trio._core._unbounded_queue.UnboundedQueue.__anext__", "trio._core._unbounded_queue.UnboundedQueue.__repr__", From 0e9aedd9de39bd3f413096ff416f6ec13e568c8e Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 21 Jul 2023 12:38:56 +0200 Subject: [PATCH 09/49] update .gitattributes --- .gitattributes | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitattributes b/.gitattributes index 991065e069..3fd55705b6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2,3 +2,4 @@ trio/_core/_generated* linguist-generated=true # Treat generated files as binary in git diff trio/_core/_generated* -diff +trio/_tests/verify_types.json merge=binary From c5b43a0d7d02b45782b007aecb967777a43903a1 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 22 Jul 2023 00:26:06 +0200 Subject: [PATCH 10/49] fixes after review from ZacHD --- docs/source/reference-core.rst | 1 + docs/source/reference-io.rst | 1 + trio/_dtls.py | 28 ++++++++-------------------- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 5f9381cbfc..980a3106e5 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -923,6 +923,7 @@ The nursery API See :meth:`~Nursery.start`. .. autoclass:: TaskStatus + :members: .. _task-local-storage: diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 3ae66699e1..9ad11b2c5a 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -305,6 +305,7 @@ unfortunately that's not yet possible. .. automethod:: statistics .. autoclass:: DTLSChannelStatistics + :members: .. module:: trio.socket diff --git a/trio/_dtls.py b/trio/_dtls.py index 3b15c83b7e..7885b1ff21 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -798,9 +798,7 @@ async def dtls_receive_loop( @attr.frozen class DTLSChannelStatistics: - """An object with statistics about this connection. - - Currently this has only one attribute: + """Currently this has only one attribute: - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of incoming packets from this peer that Trio successfully received from the @@ -1060,9 +1058,7 @@ def read_volley() -> list[_AnyHandshakeMessage]: worst_case_mtu(self.endpoint.socket), ) - async def send( - self, data: bytes - ) -> None: # or str? SendChannel defines it as bytes + async def send(self, data: bytes) -> None: """Send a packet of data, securely.""" if self._closed: @@ -1078,7 +1074,7 @@ async def send( _read_loop(self._ssl.bio_read), self.peer_address ) - async def receive(self) -> bytes: # or str? + async def receive(self) -> bytes: """Fetch the next packet of data from this connection's peer, waiting if necessary. @@ -1151,18 +1147,7 @@ def get_cleartext_mtu(self) -> int: return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return] def statistics(self) -> DTLSChannelStatistics: - """Returns an object with statistics about this connection. - - Currently this has only one attribute: - - - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of - incoming packets from this peer that Trio successfully received from the - network, but then got dropped because the internal channel buffer was full. If - this is non-zero, then you might want to call ``receive`` more often, or use a - larger ``incoming_packets_buffer``, or just not worry about it because your - UDP-based protocol should be able to handle the occasional lost packet, right? - - """ + """Returns a `DTLSChannelStatistics` object with statistics about this connection.""" return DTLSChannelStatistics(self._packets_dropped_in_trio) @@ -1271,10 +1256,13 @@ def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError + # async_fn cannot be typed with ParamSpec, since we don't accept + # kwargs. Can be typed with TypeVarTuple once it's fully supported + # in mypy. async def serve( self, ssl_context: Context, - async_fn: Callable[[DTLSChannel], Awaitable], + async_fn: Callable[[...], Awaitable], *args: Any, task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] ) -> None: From 58e4412e9c7f4e4b56dddc3f2d61b3d122e3b83a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 23 Jul 2023 23:26:06 +0200 Subject: [PATCH 11/49] fixes after review by a5rocks --- pyproject.toml | 3 +++ trio/_dtls.py | 37 ++++++++++++++++++++----------------- trio/_socket.py | 9 +++++++-- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4fe96b06b6..3b14a075da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,9 @@ module = [ ] disallow_incomplete_defs = true disallow_untyped_defs = true +disallow_any_generics = true +disallow_any_decorated = true +disallow_subclassing_any = true [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] diff --git a/trio/_dtls.py b/trio/_dtls.py index 7885b1ff21..6ba5f7931e 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -26,7 +26,6 @@ Iterator, TypeVar, Union, - cast, ) import attr @@ -43,7 +42,7 @@ from typing_extensions import Self, TypeAlias from ._core._run import TaskStatus - from ._socket import _SocketType + from ._socket import Address, _SocketType MAX_UDP_PACKET_SIZE = 65527 @@ -350,7 +349,7 @@ class OpaqueHandshakeMessage: record: Record -# for some reason doesn't work with | +# Needs Union until <3.10 is dropped _AnyHandshakeMessage: TypeAlias = Union[ HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage ] @@ -564,7 +563,7 @@ def _signable(*fields: bytes) -> bytes: def _make_cookie( - key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes + key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes ) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -582,7 +581,7 @@ def _make_cookie( def valid_cookie( - key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes + key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes ) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -604,7 +603,7 @@ def valid_cookie( def challenge_for( - key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes + key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes ) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() @@ -665,7 +664,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: async def handle_client_hello_untrusted( - endpoint: DTLSEndpoint, address: Any, packet: bytes + endpoint: DTLSEndpoint, address: Address, packet: bytes ) -> None: if endpoint._listening_context is None: return @@ -776,7 +775,8 @@ async def dtls_receive_loop( await stream._resend_final_volley() else: try: - stream._q.s.send_nowait(packet) + # mypy for some reason cannot determine type of _q + stream._q.s.send_nowait(packet) # type:ignore[has-type] except trio.WouldBlock: stream._packets_dropped_in_trio += 1 else: @@ -828,7 +828,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): + def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -839,7 +839,12 @@ def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to # support and isn't useful anyway -- especially for DTLS where it's equivalent # to just performing a new handshake. - ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined] + ctx.set_options( + ( + SSL.OP_NO_QUERY_MTU + | SSL.OP_NO_RENEGOTIATION # type: ignore[attr-defined] + ) + ) self._ssl = SSL.Connection(ctx) self._handshake_mtu = 0 # This calls self._ssl.set_ciphertext_mtu, which is important, because if you @@ -968,9 +973,8 @@ def read_volley() -> list[_AnyHandshakeMessage]: new_volley_messages and volley_messages and isinstance(new_volley_messages[0], HandshakeMessage) - # TODO: add isinstance or do a cast? - and new_volley_messages[0].msg_seq - == cast(HandshakeMessage, volley_messages[0]).msg_seq + and isinstance(volley_messages[0], HandshakeMessage) + and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq ): # openssl decided to retransmit; discard because we handle # retransmits ourselves @@ -1054,8 +1058,7 @@ def read_volley() -> list[_AnyHandshakeMessage]: # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. self._handshake_mtu = min( - self._handshake_mtu, - worst_case_mtu(self.endpoint.socket), + self._handshake_mtu, worst_case_mtu(self.endpoint.socket) ) async def send(self, data: bytes) -> None: @@ -1195,7 +1198,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self._streams = weakref.WeakValueDictionary[Address, DTLSChannel]() self._listening_context: Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) @@ -1262,7 +1265,7 @@ def _check_closed(self) -> None: async def serve( self, ssl_context: Context, - async_fn: Callable[[...], Awaitable], + async_fn: Callable[..., Awaitable[object]], *args: Any, task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] ) -> None: diff --git a/trio/_socket.py b/trio/_socket.py index 659f844078..26b03fc3e0 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -5,7 +5,7 @@ import socket as _stdlib_socket import sys from functools import wraps as _wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple, Union import idna as _idna @@ -17,7 +17,12 @@ from collections.abc import Iterable from types import TracebackType - from typing_extensions import Self + from typing_extensions import Self, TypeAlias + +# must use old-style typing because it's evaluated at runtime +Address: TypeAlias = Union[ + str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int] +] # Usage: From eda238b1feb919c01740d701831a7568e6a50173 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 23 Jul 2023 23:40:33 +0200 Subject: [PATCH 12/49] oopsies --- trio/_dtls.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 6ba5f7931e..e8888d7871 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -27,6 +27,7 @@ TypeVar, Union, ) +from weakref import ReferenceType, WeakValueDictionary import attr from OpenSSL import SSL @@ -739,7 +740,7 @@ async def handle_client_hello_untrusted( async def dtls_receive_loop( - endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType + endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType ) -> None: try: while True: @@ -1198,7 +1199,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams = weakref.WeakValueDictionary[Address, DTLSChannel]() + self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary() self._listening_context: Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) From c437821d74b54ccad30982fa45b65c26826bb104 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 00:20:19 +0200 Subject: [PATCH 13/49] aoeu --- pyproject.toml | 75 +++++++++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e44b75e9f..a67da04e11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ disallow_incomplete_defs = true check_untyped_defs = false disallow_untyped_calls = false disallow_untyped_defs = true +disallow_any_generics = true # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. @@ -48,6 +49,7 @@ disallow_untyped_defs = true module = "trio._core._run" disallow_incomplete_defs = false disallow_untyped_defs = false +disallow_any_generics = false [[tool.mypy.overrides]] module = [ @@ -57,52 +59,51 @@ module = [ ] disallow_untyped_defs = true disallow_incomplete_defs = true -disallow_generic_any = true +disallow_any_generics = true [[tool.mypy.overrides]] module = [ -"trio/_core/_asyncgens", -"trio/_core/_entry_queue", -"trio/_core/_generated_io_epoll", -"trio/_core/_generated_io_kqueue", +"trio/_core/_asyncgens", # 10 +"trio/_core/_entry_queue", # 16 +"trio/_core/_generated_io_epoll", # 3 "trio/_core/_generated_io_windows", -"trio/_core/_generated_run", -"trio/_core/_io_common", -"trio/_core/_io_epoll", -"trio/_core/_io_kqueue", +"trio/_core/_generated_run", # 8 +"trio/_core/_io_common", # 1 +"trio/_core/_io_epoll", # 21 +"trio/_core/_io_kqueue", # 16 "trio/_core/_io_windows", -"trio/_core/_ki", -"trio/_core/_multierror", -"trio/_core/_parking_lot", -"trio/_core/_thread_cache", -"trio/_core/_traps", -"trio/_core/_wakeup_socketpair", +"trio/_core/_ki", # 14 +"trio/_core/_multierror", # 19 +"trio/_core/_parking_lot", # 1 +"trio/_core/_thread_cache", # 6 +"trio/_core/_traps", # 7 +"trio/_core/_wakeup_socketpair", # 12 "trio/_core/_windows_cffi", -"trio/_deprecate", -"trio/_file_io", -"trio/_highlevel_open_tcp_listeners", -"trio/_highlevel_open_tcp_stream", -"trio/_highlevel_open_unix_stream", -"trio/_highlevel_serve_listeners", -"trio/_highlevel_socket", -"trio/_highlevel_ssl_helpers", -"trio/_path", -"trio/_signals", -"trio/_ssl", -"trio/_subprocess", -"trio/_subprocess_platform/kqueue", -"trio/_subprocess_platform/waitid", -"trio/_sync", -"trio/_threads", -"trio/_util", +"trio/_deprecate", # 12 +"trio/_file_io", # 13 +"trio/_highlevel_open_tcp_listeners", # 3 +"trio/_highlevel_open_tcp_stream", # 5 +"trio/_highlevel_open_unix_stream", # 2 +"trio/_highlevel_serve_listeners", # 3 +"trio/_highlevel_socket", # 4 +"trio/_highlevel_ssl_helpers", # 3 +"trio/_path", # 21 +"trio/_signals", # 13 +"trio/_ssl", # 26 +"trio/_subprocess", # 21 +"trio/_subprocess_platform/waitid", # 2 +"trio/_sync", # 1 +"trio/_threads", # 15 +"trio/_util", # 13 "trio/_wait_for_object", -"trio/testing/_check_streams", -"trio/testing/_checkpoints", -"trio/testing/_memory_streams", -"trio/testing/_network", -"trio/testing/_trio_test", +"trio/testing/_check_streams", # 27 +"trio/testing/_checkpoints", # 3 +"trio/testing/_memory_streams", # 66 +"trio/testing/_network", # 1 +"trio/testing/_trio_test", # 2 ] disallow_untyped_defs = false +disallow_any_generics = false [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] From 0fbb87f87846c61ceeaac8a990cf4d79f5e87d49 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 00:22:01 +0200 Subject: [PATCH 14/49] merge _dtls --- trio/_dtls.py | 77 +++++++++++++++++++---------------- trio/_tests/verify_types.json | 2 +- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 8aede92bc8..e8888d7871 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -26,8 +26,8 @@ Iterator, TypeVar, Union, - cast, ) +from weakref import ReferenceType, WeakValueDictionary import attr from OpenSSL import SSL @@ -42,8 +42,8 @@ from OpenSSL.SSL import Context from typing_extensions import Self, TypeAlias - from ._core._run import _TaskStatus - from ._socket import _SocketType + from ._core._run import TaskStatus + from ._socket import Address, _SocketType MAX_UDP_PACKET_SIZE = 65527 @@ -350,7 +350,7 @@ class OpaqueHandshakeMessage: record: Record -# for some reason doesn't work with | +# Needs Union until <3.10 is dropped _AnyHandshakeMessage: TypeAlias = Union[ HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage ] @@ -564,7 +564,7 @@ def _signable(*fields: bytes) -> bytes: def _make_cookie( - key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes + key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes ) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -582,7 +582,7 @@ def _make_cookie( def valid_cookie( - key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes + key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes ) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -604,7 +604,7 @@ def valid_cookie( def challenge_for( - key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes + key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes ) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() @@ -665,7 +665,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: async def handle_client_hello_untrusted( - endpoint: DTLSEndpoint, address: Any, packet: bytes + endpoint: DTLSEndpoint, address: Address, packet: bytes ) -> None: if endpoint._listening_context is None: return @@ -740,7 +740,7 @@ async def handle_client_hello_untrusted( async def dtls_receive_loop( - endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType + endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType ) -> None: try: while True: @@ -776,7 +776,8 @@ async def dtls_receive_loop( await stream._resend_final_volley() else: try: - stream._q.s.send_nowait(packet) + # mypy for some reason cannot determine type of _q + stream._q.s.send_nowait(packet) # type:ignore[has-type] except trio.WouldBlock: stream._packets_dropped_in_trio += 1 else: @@ -798,6 +799,17 @@ async def dtls_receive_loop( @attr.frozen class DTLSChannelStatistics: + """Currently this has only one attribute: + + - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of + incoming packets from this peer that Trio successfully received from the + network, but then got dropped because the internal channel buffer was full. If + this is non-zero, then you might want to call ``receive`` more often, or use a + larger ``incoming_packets_buffer``, or just not worry about it because your + UDP-based protocol should be able to handle the occasional lost packet, right? + + """ + incoming_packets_dropped_in_trio: int @@ -817,7 +829,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): + def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -828,7 +840,12 @@ def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to # support and isn't useful anyway -- especially for DTLS where it's equivalent # to just performing a new handshake. - ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined] + ctx.set_options( + ( + SSL.OP_NO_QUERY_MTU + | SSL.OP_NO_RENEGOTIATION # type: ignore[attr-defined] + ) + ) self._ssl = SSL.Connection(ctx) self._handshake_mtu = 0 # This calls self._ssl.set_ciphertext_mtu, which is important, because if you @@ -957,9 +974,8 @@ def read_volley() -> list[_AnyHandshakeMessage]: new_volley_messages and volley_messages and isinstance(new_volley_messages[0], HandshakeMessage) - # TODO: add isinstance or do a cast? - and new_volley_messages[0].msg_seq - == cast(HandshakeMessage, volley_messages[0]).msg_seq + and isinstance(volley_messages[0], HandshakeMessage) + and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq ): # openssl decided to retransmit; discard because we handle # retransmits ourselves @@ -1043,13 +1059,10 @@ def read_volley() -> list[_AnyHandshakeMessage]: # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. self._handshake_mtu = min( - self._handshake_mtu, - worst_case_mtu(self.endpoint.socket), + self._handshake_mtu, worst_case_mtu(self.endpoint.socket) ) - async def send( - self, data: bytes - ) -> None: # or str? SendChannel defines it as bytes + async def send(self, data: bytes) -> None: """Send a packet of data, securely.""" if self._closed: @@ -1065,7 +1078,7 @@ async def send( _read_loop(self._ssl.bio_read), self.peer_address ) - async def receive(self) -> bytes: # or str? + async def receive(self) -> bytes: """Fetch the next packet of data from this connection's peer, waiting if necessary. @@ -1138,18 +1151,7 @@ def get_cleartext_mtu(self) -> int: return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return] def statistics(self) -> DTLSChannelStatistics: - """Returns an object with statistics about this connection. - - Currently this has only one attribute: - - - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of - incoming packets from this peer that Trio successfully received from the - network, but then got dropped because the internal channel buffer was full. If - this is non-zero, then you might want to call ``receive`` more often, or use a - larger ``incoming_packets_buffer``, or just not worry about it because your - UDP-based protocol should be able to handle the occasional lost packet, right? - - """ + """Returns a `DTLSChannelStatistics` object with statistics about this connection.""" return DTLSChannelStatistics(self._packets_dropped_in_trio) @@ -1197,7 +1199,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary() self._listening_context: Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) @@ -1258,12 +1260,15 @@ def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError + # async_fn cannot be typed with ParamSpec, since we don't accept + # kwargs. Can be typed with TypeVarTuple once it's fully supported + # in mypy. async def serve( self, ssl_context: Context, - async_fn: Callable[[DTLSChannel], Awaitable], + async_fn: Callable[..., Awaitable[object]], *args: Any, - task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] + task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] ) -> None: """Listen for incoming connections, and spawn a handler for each using an internal nursery. diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 5d19fe6729..be02c5203b 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -46,7 +46,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 546, + "withKnownType": 535, "withUnknownType": 67 }, "packageName": "trio", From 3d92de4c0d6bd87affb37cd0323935e0a2a66b96 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 00:31:48 +0200 Subject: [PATCH 15/49] merge-ish _socket --- pyproject.toml | 11 --- trio/_core/_local.py | 53 ++++++-------- trio/_core/_parking_lot.py | 2 +- trio/_socket.py | 127 ++++++++++++++++++++++------------ trio/_tests/verify_types.json | 2 +- 5 files changed, 108 insertions(+), 87 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a67da04e11..e6f61a698e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,16 +51,6 @@ disallow_incomplete_defs = false disallow_untyped_defs = false disallow_any_generics = false -[[tool.mypy.overrides]] -module = [ - "trio._abc", - "trio._dtls", - "trio._socket", -] -disallow_untyped_defs = true -disallow_incomplete_defs = true -disallow_any_generics = true - [[tool.mypy.overrides]] module = [ "trio/_core/_asyncgens", # 10 @@ -74,7 +64,6 @@ module = [ "trio/_core/_io_windows", "trio/_core/_ki", # 14 "trio/_core/_multierror", # 19 -"trio/_core/_parking_lot", # 1 "trio/_core/_thread_cache", # 6 "trio/_core/_traps", # 7 "trio/_core/_wakeup_socketpair", # 12 diff --git a/trio/_core/_local.py b/trio/_core/_local.py index fe509ca7ad..b9dada64fe 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,32 +1,32 @@ from __future__ import annotations -from typing import Generic, TypeVar, overload +from typing import Generic, TypeVar, final # Runvar implementations import attr -from .._util import Final +from .._util import Final, NoPublicConstructor from . import _run +# `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released + T = TypeVar("T") -C = TypeVar("C", bound="_RunVarToken") -class NoValue(object): +@final +class _NoValue: ... -@attr.s(eq=False, hash=False, slots=True) -class _RunVarToken(Generic[T]): - _no_value = NoValue() - +@attr.s(eq=False, hash=False, slots=False) +class RunVarToken(Generic[T], metaclass=NoPublicConstructor): _var: RunVar[T] = attr.ib() - previous_value: T | NoValue = attr.ib(default=_no_value) + previous_value: T | type[_NoValue] = attr.ib(default=_NoValue) redeemed: bool = attr.ib(default=False, init=False) @classmethod - def empty(cls: type[C], var: RunVar[T]) -> C: - return cls(var) + def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: + return cls._create(var) @attr.s(eq=False, hash=False, slots=True) @@ -39,19 +39,10 @@ class RunVar(Generic[T], metaclass=Final): """ - _NO_DEFAULT = NoValue() _name: str = attr.ib() - _default: T | NoValue = attr.ib(default=_NO_DEFAULT) - - @overload - def get(self, default: T) -> T: - ... - - @overload - def get(self, default: NoValue = _NO_DEFAULT) -> T | NoValue: - ... + _default: T | type[_NoValue] = attr.ib(default=_NoValue) - def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue: + def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: # not typed yet @@ -60,15 +51,15 @@ def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: # contextvars consistency - if default is not self._NO_DEFAULT: - return default + if default is not _NoValue: + return default # type: ignore[return-value] - if self._default is not self._NO_DEFAULT: - return self._default + if self._default is not _NoValue: + return self._default # type: ignore[return-value] raise LookupError(self) from None - def set(self, value: T) -> _RunVarToken[T]: + def set(self, value: T) -> RunVarToken[T]: """Sets the value of this :class:`RunVar` for this current run call. @@ -76,16 +67,16 @@ def set(self, value: T) -> _RunVarToken[T]: try: old_value = self.get() except LookupError: - token: _RunVarToken[T] = _RunVarToken.empty(self) + token = RunVarToken._empty(self) else: - token = _RunVarToken(self, old_value) + token = RunVarToken[T]._create(self, old_value) # This can't fail, because if we weren't in Trio context then the # get() above would have failed. _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index] return token - def reset(self, token: _RunVarToken[T]) -> None: + def reset(self, token: RunVarToken[T]) -> None: """Resets the value of this :class:`RunVar` to what it was previously specified by the token. @@ -101,7 +92,7 @@ def reset(self, token: _RunVarToken[T]) -> None: previous = token.previous_value try: - if previous is _RunVarToken._no_value: + if previous is _NoValue: _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type] else: _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment] diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 74708433da..6510745e5b 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -139,7 +139,7 @@ async def park(self) -> None: self._parked[task] = None task.custom_sleep_data = self - def abort_fn(_): + def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: del task.custom_sleep_data._parked[task] return _core.Abort.SUCCEEDED diff --git a/trio/_socket.py b/trio/_socket.py index 72498a5482..e9fa8f3537 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -5,6 +5,7 @@ import socket as _stdlib_socket import sys from functools import wraps as _wraps +from operator import index from socket import AddressFamily, SocketKind from typing import ( TYPE_CHECKING, @@ -16,11 +17,11 @@ Tuple, TypeVar, Union, + cast, overload, ) import idna as _idna -from typing_extensions import Concatenate, ParamSpec import trio @@ -30,13 +31,14 @@ from collections.abc import Iterable from types import TracebackType - from typing_extensions import Buffer, Self, TypeAlias + from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias from ._abc import HostnameResolver, SocketFactory + P = ParamSpec("P") + T = TypeVar("T") -P = ParamSpec("P") # must use old-style typing because it's evaluated at runtime Address: TypeAlias = Union[ @@ -224,7 +226,7 @@ def numeric_only_failure(exc: BaseException) -> bool: # idna.encode will error out if the hostname has Capital Letters # in it; with uts46=True it will lowercase them instead. host = _idna.encode(host, uts46=True) - hr: HostnameResolver | None = _resolver.get(None) + hr = _resolver.get(None) if hr is not None: return await hr.getaddrinfo(host, port, family, type, proto, flags) else: @@ -296,7 +298,7 @@ def fromfd( proto: int = 0, ) -> _SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" - family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, int(fd)) + family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -310,13 +312,13 @@ def fromshare(info: bytes) -> _SocketType: if sys.platform == "win32": - FamilyT = int - TypeT = int + FamilyT: TypeAlias = int + TypeT: TypeAlias = int FamilyDefault = _stdlib_socket.AF_INET else: FamilyDefault = None - FamilyT = Union[int, AddressFamily, None] - TypeT = Union[_stdlib_socket.socket, int] + FamilyT: TypeAlias = Union[int, AddressFamily, None] + TypeT: TypeAlias = Union[_stdlib_socket.socket, int] @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) @@ -405,7 +407,7 @@ def _sniff_sockopts_for_fileno( def _make_simple_sock_method_wrapper( fn: Callable[Concatenate[_stdlib_socket.socket, P], T], - wait_fn: Callable, + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], maybe_avail: bool = False, ) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: @_wraps(fn, assigned=("__name__",), updated=()) @@ -508,6 +510,8 @@ async def _resolve_address_nocp( if family == _stdlib_socket.AF_INET6: list_normed = list(normed) assert len(normed) == 4 + # typechecking certainly doesn't like this logic, but given just how broad + # Address is, it's quite cumbersome to write the below without type: ignore if len(address) >= 3: list_normed[2] = address[2] # type: ignore if len(address) >= 4: @@ -517,7 +521,9 @@ async def _resolve_address_nocp( # TODO: stopping users from initializing this type should be done in a different way, -# so SocketType can be used as a type. +# so SocketType can be used as a type. Note that this is *far* from trivial without +# breaking subclasses of SocketType. Should maybe just add abstract methods to SocketType, +# or rename _SocketType. class SocketType: def __init__(self) -> NoReturn: raise TypeError( @@ -542,36 +548,69 @@ def __init__(self, sock: _stdlib_socket.socket): # Simple + portable methods and attributes ################################################################ - # NB this doesn't work because for loops don't create a scope - # for _name in [ - # ]: - # _meth = getattr(_stdlib_socket.socket, _name) - # @_wraps(_meth, assigned=("__name__", "__doc__"), updated=()) - # def _wrapped(self, *args, **kwargs): - # return getattr(self._sock, _meth)(*args, **kwargs) - # locals()[_meth] = _wrapped - # del _name, _meth, _wrapped - - _forward = { - "detach", - "get_inheritable", - "set_inheritable", - "fileno", - "getpeername", - "getsockname", - "getsockopt", - "setsockopt", - "listen", - "share", - } - - def __getattr__(self, name: str) -> Any: - if name in self._forward: - return getattr(self._sock, name) - raise AttributeError(name) - - def __dir__(self) -> Iterable[str]: - return [*super().__dir__(), *self._forward] + # forwarded methods + def detach(self) -> int: + return self._sock.detach() + + def fileno(self) -> int: + return self._sock.fileno() + + def getpeername(self) -> Any: + return self._sock.getpeername() + + def getsockname(self) -> Any: + return self._sock.getsockname() + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + if buflen is None: + return self._sock.getsockopt(level, optname) + return self._sock.getsockopt(level, optname, buflen) + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + if optlen is None: + return self._sock.setsockopt(level, optname, cast("int|Buffer", value)) + return self._sock.setsockopt(level, optname, cast(None, value), optlen) + + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + return self._sock.listen(backlog) + + def get_inheritable(self) -> bool: + return self._sock.get_inheritable() + + def set_inheritable(self, inheritable: bool) -> None: + return self._sock.set_inheritable(inheritable) + + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, /, process_id: int) -> bytes: + return self._sock.share(process_id) def __enter__(self) -> Self: return self @@ -678,7 +717,7 @@ async def _resolve_address_nocp( async def _nonblocking_helper( self, - wait_fn: Callable[[_stdlib_socket.socket], Awaitable], + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], fn: Callable[Concatenate[_stdlib_socket.socket, P], T], *args: P.args, **kwargs: P.kwargs, @@ -814,7 +853,9 @@ async def connect(self, address: Address) -> None: def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: ... - # _make_simple_sock_method_wrapper is typed, so this check that the above is correct + # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct + # this requires that we refrain from using `/` to specify pos-only + # args, or mypy thinks the signature differs from typeshed. recv = _make_simple_sock_method_wrapper( # noqa: F811 _stdlib_socket.socket.recv, _core.wait_readable ) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 56ca3ace4d..ad044eb65f 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -46,7 +46,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 546, + "withKnownType": 552, "withUnknownType": 67 }, "packageName": "trio", From 2a51953d22919bffdc81972d56c18809b0565777 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 00:34:16 +0200 Subject: [PATCH 16/49] _sync --- pyproject.toml | 1 - trio/_sync.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6f61a698e..73131c90ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,6 @@ module = [ "trio/_ssl", # 26 "trio/_subprocess", # 21 "trio/_subprocess_platform/waitid", # 2 -"trio/_sync", # 1 "trio/_threads", # 15 "trio/_util", # 13 "trio/_wait_for_object", diff --git a/trio/_sync.py b/trio/_sync.py index 0f05dd458c..9764ddce2d 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -8,7 +8,7 @@ import trio from . import _core -from ._core import ParkingLot, enable_ki_protection +from ._core import Abort, ParkingLot, RaiseCancelT, enable_ki_protection from ._util import Final if TYPE_CHECKING: @@ -87,7 +87,7 @@ async def wait(self) -> None: task = _core.current_task() self._tasks.add(task) - def abort_fn(_): + def abort_fn(_: RaiseCancelT) -> Abort: self._tasks.remove(task) return _core.Abort.SUCCEEDED @@ -143,7 +143,7 @@ class CapacityLimiterStatistics: borrowed_tokens: int = attr.ib() total_tokens: int | float = attr.ib() - borrowers: list[object] = attr.ib() + borrowers: list[Task | object] = attr.ib() tasks_waiting: int = attr.ib() @@ -204,9 +204,9 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing def __init__(self, total_tokens: int | float): self._lot = ParkingLot() - self._borrowers: set[object] = set() + self._borrowers: set[Task | object] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of - self._pending_borrowers: dict[Task, object] = {} + self._pending_borrowers: dict[Task, Task | object] = {} # invoke the property setter for validation self.total_tokens: int | float = total_tokens assert self._total_tokens == total_tokens @@ -268,7 +268,7 @@ def acquire_nowait(self) -> None: self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower: object) -> None: + def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. @@ -307,7 +307,7 @@ async def acquire(self) -> None: await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower: object) -> None: + async def acquire_on_behalf_of(self, borrower: Task | object) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -347,7 +347,7 @@ def release(self) -> None: self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower: object) -> None: + def release_on_behalf_of(self, borrower: Task | object) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: From 4b513bdab79050e154643838712cab1d15fb6f60 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 11:16:08 +0200 Subject: [PATCH 17/49] type _io_epoll and stuff --- pyproject.toml | 77 ++++++++++++++++------------- trio/_core/_generated_io_epoll.py | 6 +-- trio/_core/_io_common.py | 8 ++- trio/_core/_io_epoll.py | 39 ++++++++------- trio/_core/_io_kqueue.py | 8 +-- trio/_core/_run.py | 7 +-- trio/_core/_tests/test_ki.py | 10 +++- trio/_ssl.py | 9 ++-- trio/_subprocess_platform/kqueue.py | 7 ++- trio/_tests/verify_types.json | 9 ++-- trio/tests.py | 2 + 11 files changed, 108 insertions(+), 74 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73131c90ad..dd3bafd078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,42 +53,49 @@ disallow_any_generics = false [[tool.mypy.overrides]] module = [ -"trio/_core/_asyncgens", # 10 -"trio/_core/_entry_queue", # 16 -"trio/_core/_generated_io_epoll", # 3 -"trio/_core/_generated_io_windows", -"trio/_core/_generated_run", # 8 -"trio/_core/_io_common", # 1 -"trio/_core/_io_epoll", # 21 -"trio/_core/_io_kqueue", # 16 -"trio/_core/_io_windows", -"trio/_core/_ki", # 14 -"trio/_core/_multierror", # 19 -"trio/_core/_thread_cache", # 6 -"trio/_core/_traps", # 7 +#"trio/_core/_io_common", # 1, 24 +"trio/_core/_windows_cffi", # 2, 324 +#"trio/_core/_generated_io_epoll", # 3, 36 +"trio/_core/_thread_cache", # 6, 273 +"trio/_core/_generated_io_kqueue", # 6 (darwin), 60 +"trio/_core/_traps", # 7, 276 +"trio/_core/_generated_run", # 8, 242 +"trio/_core/_generated_io_windows", # 9 (win32), 84 +"trio/_core/_asyncgens", # 10, 194 + "trio/_core/_wakeup_socketpair", # 12 -"trio/_core/_windows_cffi", -"trio/_deprecate", # 12 -"trio/_file_io", # 13 -"trio/_highlevel_open_tcp_listeners", # 3 -"trio/_highlevel_open_tcp_stream", # 5 -"trio/_highlevel_open_unix_stream", # 2 -"trio/_highlevel_serve_listeners", # 3 -"trio/_highlevel_socket", # 4 -"trio/_highlevel_ssl_helpers", # 3 -"trio/_path", # 21 -"trio/_signals", # 13 -"trio/_ssl", # 26 -"trio/_subprocess", # 21 -"trio/_subprocess_platform/waitid", # 2 -"trio/_threads", # 15 -"trio/_util", # 13 -"trio/_wait_for_object", -"trio/testing/_check_streams", # 27 -"trio/testing/_checkpoints", # 3 -"trio/testing/_memory_streams", # 66 -"trio/testing/_network", # 1 -"trio/testing/_trio_test", # 2 +"trio/_core/_ki", # 14, 210 +"trio/_core/_entry_queue", # 16, 195 +"trio/_core/_io_kqueue", # 16, 198 +"trio/_core/_multierror", # 19, 469 + +#"trio/_core/_io_epoll", # 21, 323 +"trio/_core/_io_windows", # 47 (win32), 867 + + +"trio/testing/_network", # 1, 34 +"trio/testing/_trio_test", # 2, 29 +"trio/testing/_checkpoints", # 3, 62 +"trio/testing/_check_streams", # 27, 522 +"trio/testing/_memory_streams", # 66, 590 + +"trio/_highlevel_open_unix_stream", # 2, 49 lines +"trio/_highlevel_open_tcp_listeners", # 3, 227 lines +"trio/_highlevel_serve_listeners", # 3, 121 lines +"trio/_highlevel_ssl_helpers", # 3, 155 lines +"trio/_highlevel_socket", # 4, 386 lines +"trio/_highlevel_open_tcp_stream", # 5, 379 lines + +"trio/_subprocess_platform/waitid", # 2, 107 lines +"trio/_wait_for_object", # 2 (windows) +"trio/_deprecate", # 12, 140lines +"trio/_util", # 13, 348 lines +"trio/_file_io", # 13, 191 lines +"trio/_signals", # 13, 168 lines +"trio/_threads", # 15, 398 lines +"trio/_path", # 21, 295 lines +"trio/_subprocess", # 21, 759 lines +"trio/_ssl", # 26, 929 lines ] disallow_untyped_defs = false disallow_any_generics = false diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index 02fb3bc348..1de66b0a8a 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -9,7 +9,7 @@ # fmt: off -async def wait_readable(fd): +async def wait_readable(fd: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -17,7 +17,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -25,7 +25,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_io_common.py b/trio/_core/_io_common.py index b141474fda..c1af293278 100644 --- a/trio/_core/_io_common.py +++ b/trio/_core/_io_common.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import copy +from typing import TYPE_CHECKING import outcome from .. import _core +if TYPE_CHECKING: + from ._io_epoll import EpollWaiters + # Utility function shared between _io_epoll and _io_windows -def wake_all(waiters, exc): +def wake_all(waiters: EpollWaiters, exc: BaseException) -> None: try: current_task = _core.current_task() except RuntimeError: diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 9d7b250785..31c49ca230 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -12,14 +12,17 @@ from ._run import _public from ._wakeup_socketpair import WakeupSocketpair +if TYPE_CHECKING: + from .._core import Abort, RaiseCancelT + assert not TYPE_CHECKING or sys.platform == "linux" @attr.s(slots=True, eq=False, frozen=True) class _EpollStatistics: - tasks_waiting_read = attr.ib() - tasks_waiting_write = attr.ib() - backend = attr.ib(default="epoll") + tasks_waiting_read: int = attr.ib() + tasks_waiting_write: int = attr.ib() + backend: str = attr.ib(default="epoll") # Some facts about epoll @@ -182,9 +185,9 @@ class _EpollStatistics: @attr.s(slots=True, eq=False) class EpollWaiters: - read_task = attr.ib(default=None) - write_task = attr.ib(default=None) - current_flags = attr.ib(default=0) + read_task: None = attr.ib(default=None) + write_task: None = attr.ib(default=None) + current_flags: int = attr.ib(default=0) @attr.s(slots=True, eq=False, hash=False) @@ -197,11 +200,11 @@ class EpollIOManager: _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) _force_wakeup_fd: int | None = attr.ib(default=None) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() - def statistics(self): + def statistics(self) -> _EpollStatistics: tasks_waiting_read = 0 tasks_waiting_write = 0 for waiter in self._registered.values(): @@ -214,24 +217,24 @@ def statistics(self): tasks_waiting_write=tasks_waiting_write, ) - def close(self): + def close(self) -> None: self._epoll.close() self._force_wakeup.close() - def force_wakeup(self): + def force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() # Return value must be False-y IFF the timeout expired, NOT if any I/O # happened or force_wakeup was called. Otherwise it can be anything; gets # passed straight through to process_events. - def get_events(self, timeout): + def get_events(self, timeout: float) -> list[tuple[int, int]]: # max_events must be > 0 or epoll gets cranky # accessing self._registered from a thread looks dangerous, but it's # OK because it doesn't matter if our value is a little bit off. max_events = max(1, len(self._registered)) return self._epoll.poll(timeout, max_events) - def process_events(self, events): + def process_events(self, events: list[tuple[int, int]]) -> None: for fd, flags in events: if fd == self._force_wakeup_fd: self._force_wakeup.drain() @@ -250,7 +253,7 @@ def process_events(self, events): waiters.read_task = None self._update_registrations(fd) - def _update_registrations(self, fd): + def _update_registrations(self, fd: int) -> None: waiters = self._registered[fd] wanted_flags = 0 if waiters.read_task is not None: @@ -279,7 +282,7 @@ def _update_registrations(self, fd): if not wanted_flags: del self._registered[fd] - async def _epoll_wait(self, fd, attr_name): + async def _epoll_wait(self, fd: int, attr_name: str) -> None: if not isinstance(fd, int): fd = fd.fileno() waiters = self._registered[fd] @@ -290,7 +293,7 @@ async def _epoll_wait(self, fd, attr_name): setattr(waiters, attr_name, _core.current_task()) self._update_registrations(fd) - def abort(_): + def abort(_: RaiseCancelT) -> Abort: setattr(waiters, attr_name, None) self._update_registrations(fd) return _core.Abort.SUCCEEDED @@ -298,15 +301,15 @@ def abort(_): await _core.wait_task_rescheduled(abort) @_public - async def wait_readable(self, fd): + async def wait_readable(self, fd: int) -> None: await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: int) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd): + def notify_closing(self, fd: int) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index d1151843e8..5ce5a609ed 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -11,6 +11,8 @@ from ._run import _public from ._wakeup_socketpair import WakeupSocketpair +if TYPE_CHECKING: + from .._core import Abort, RaiseCancelT assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") @@ -123,11 +125,11 @@ async def wait_kevent(self, ident, filter, abort_func): ) self._registered[key] = _core.current_task() - def abort(raise_cancel): + def abort(raise_cancel: RaiseCancelT) -> Abort: r = abort_func(raise_cancel) if r is _core.Abort.SUCCEEDED: del self._registered[key] - return r + return r # type: ignore[no-any-return] return await _core.wait_task_rescheduled(abort) @@ -138,7 +140,7 @@ async def _wait_common(self, fd, filter): event = select.kevent(fd, filter, flags) self._kqueue.control([event], 0) - def abort(_): + def abort(_: RaiseCancelT) -> Abort: event = select.kevent(fd, filter, select.KQ_EV_DELETE) try: self._kqueue.control([event], 0) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 5daf08f462..49c0bcef67 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -37,6 +37,7 @@ Abort, CancelShieldedCheckpoint, PermanentlyDetachCoroutineObject, + RaiseCancelT, WaitTaskRescheduled, cancel_shielded_checkpoint, wait_task_rescheduled, @@ -1022,7 +1023,7 @@ async def _nested_child_finished(self, nested_child_exc): # If we get cancelled (or have an exception injected, like # KeyboardInterrupt), then save that, but still wait until our # children finish. - def aborted(raise_cancel): + def abort(raise_cancel: RaiseCancelT) -> Abort: self._add_exc(capture(raise_cancel).error) return Abort.FAILED @@ -1433,7 +1434,7 @@ def in_main_thread(): class Runner: clock = attr.ib() instruments: Instruments = attr.ib() - io_manager = attr.ib() + io_manager: TheIOManager = attr.ib() ki_manager = attr.ib() strict_exception_groups = attr.ib() @@ -1905,7 +1906,7 @@ async def test_lock_fairness(): key = (cushion, id(task)) self.waiting_for_idle[key] = task - def abort(_): + def abort(_: RaiseCancelT) -> Abort: del self.waiting_for_idle[key] return Abort.SUCCEEDED diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py index fdbada4624..b6eef68e22 100644 --- a/trio/_core/_tests/test_ki.py +++ b/trio/_core/_tests/test_ki.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import contextlib import inspect import signal import threading +from typing import TYPE_CHECKING import outcome import pytest @@ -16,6 +19,9 @@ from ..._util import signal_raise from ...testing import wait_all_tasks_blocked +if TYPE_CHECKING: + from ..._core import Abort, RaiseCancelT + def ki_self(): signal_raise(signal.SIGINT) @@ -375,7 +381,7 @@ async def main(): ki_self() task = _core.current_task() - def abort(_): + def abort(_: RaiseCancelT) -> Abort: _core.reschedule(task, outcome.Value(1)) return _core.Abort.FAILED @@ -394,7 +400,7 @@ async def main(): ki_self() task = _core.current_task() - def abort(raise_cancel): + def abort(raise_cancel: RaiseCancelT) -> Abort: result = outcome.capture(raise_cancel) _core.reschedule(task, result) return _core.Abort.FAILED diff --git a/trio/_ssl.py b/trio/_ssl.py index bd8b3b06b6..352f95edaf 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -148,10 +148,12 @@ # stream) # docs will need to make very clear that this is different from all the other # cancellations in core Trio +from __future__ import annotations import operator as _operator import ssl as _stdlib_ssl from enum import Enum as _Enum +from typing import Any, Awaitable, Callable import trio @@ -209,13 +211,14 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn, *args): + # needs TypeVarTuple + def __init__(self, afn: Callable[..., Awaitable[object]], *args: Any): self._afn = afn self._args = args self.started = False self._done = _sync.Event() - async def ensure(self, *, checkpoint): + async def ensure(self, *, checkpoint: bool) -> None: if not self.started: self.started = True await self._afn(*self._args) @@ -226,7 +229,7 @@ async def ensure(self, *, checkpoint): await self._done.wait() @property - def done(self): + def done(self) -> bool: return self._done.is_set() diff --git a/trio/_subprocess_platform/kqueue.py b/trio/_subprocess_platform/kqueue.py index 9839fd046b..b40db75953 100644 --- a/trio/_subprocess_platform/kqueue.py +++ b/trio/_subprocess_platform/kqueue.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import select import sys from typing import TYPE_CHECKING from .. import _core, _subprocess +if TYPE_CHECKING: + from .._core import Abort, RaiseCancelT + assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING @@ -35,7 +40,7 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: # in Chromium it seems we should still keep the check. return - def abort(_): + def abort(_: RaiseCancelT) -> Abort: kqueue.control([make_event(select.KQ_EV_DELETE)], 0) return _core.Abort.SUCCEEDED diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index ad044eb65f..c454c4d3b9 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,14 +7,14 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9152, + "completenessScore": 0.92, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 572, - "withUnknownType": 53 + "withKnownType": 575, + "withUnknownType": 50 }, "ignoreUnknownTypesFromImports": true, - "missingClassDocStringCount": 1, + "missingClassDocStringCount": 0, "missingDefaultParamCount": 0, "missingFunctionDocStringCount": 4, "moduleName": "trio", @@ -159,7 +159,6 @@ "trio.testing.open_stream_to_socket_listener", "trio.testing.trio_test", "trio.testing.wait_all_tasks_blocked", - "trio.tests.TestsDeprecationWrapper", "trio.to_thread.current_default_thread_limiter", "trio.wrap_file" ] diff --git a/trio/tests.py b/trio/tests.py index 573a076da8..472befb1ce 100644 --- a/trio/tests.py +++ b/trio/tests.py @@ -16,6 +16,8 @@ # This won't give deprecation warning on import, but will give a warning on use of any # attribute in tests, and static analysis tools will also not see any content inside. class TestsDeprecationWrapper: + """trio.tests is deprecated, use trio._tests""" + __name__ = "trio.tests" def __getattr__(self, attr: str) -> Any: From 1831eaef19ede65303add9ab287bd46e436b5ec8 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jul 2023 13:22:03 +0200 Subject: [PATCH 18/49] aborted -> abort --- trio/_core/_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 49c0bcef67..4fb1c78048 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1028,7 +1028,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: return Abort.FAILED self._parent_waiting_in_aexit = True - await wait_task_rescheduled(aborted) + await wait_task_rescheduled(abort) else: # Nothing to wait for, so just execute a checkpoint -- but we # still need to mix any exception (e.g. from an external From 091063b2a08a5e001681785d096f5cdb016d0965 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 11:59:03 +0200 Subject: [PATCH 19/49] merge _socket, more stuff --- .coveragerc | 1 + docs/source/conf.py | 2 ++ docs/source/reference-io.rst | 8 ++++++++ pyproject.toml | 8 ++++++++ trio/_core/_generated_instrumentation.py | 7 +++++++ trio/_core/_generated_io_epoll.py | 13 ++++++++++--- trio/_core/_generated_io_kqueue.py | 7 +++++++ trio/_core/_generated_io_windows.py | 7 +++++++ trio/_core/_generated_run.py | 7 +++++++ trio/_core/_io_epoll.py | 10 ++++++---- trio/_threads.py | 12 ++++-------- trio/_tools/gen_exports.py | 5 +++++ trio/socket.py | 1 + 13 files changed, 73 insertions(+), 15 deletions(-) diff --git a/.coveragerc b/.coveragerc index 98f923bd8e..d577aa8adf 100644 --- a/.coveragerc +++ b/.coveragerc @@ -21,6 +21,7 @@ exclude_lines = abc.abstractmethod if TYPE_CHECKING: if _t.TYPE_CHECKING: + @overload partial_branches = pragma: no branch diff --git a/docs/source/conf.py b/docs/source/conf.py index 91ce7d884c..7e8626c20d 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,6 +62,8 @@ ("py:obj", "trio._abc.SendType"), ("py:obj", "trio._abc.T"), ("py:obj", "trio._abc.T_resource"), + ("py:class", "trio._threads.T"), + # why aren't these found in stdlib? ("py:class", "types.FrameType"), # TODO: figure out if you can link this to SSL ("py:class", "Context"), diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 9ad11b2c5a..0669eb5323 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -504,6 +504,14 @@ Socket objects * :meth:`~socket.socket.set_inheritable` * :meth:`~socket.socket.get_inheritable` +The internal SocketType +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: _SocketType +.. + TODO: adding `:members:` here gives error due to overload+_wraps on `sendto` + TODO: rewrite ... all of the above when fixing _SocketType vs SocketType + + .. currentmodule:: trio diff --git a/pyproject.toml b/pyproject.toml index dd3bafd078..2b2678a366 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,14 @@ disallow_any_generics = true # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. +[[tool.mypy.overrides]] +module = [ + "trio._socket", + "trio._core._local", + "trio._sync", +] +disallow_untyped_defs = true +disallow_any_generics = true [[tool.mypy.overrides]] module = "trio._core._run" diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index 30c2f26b4e..e38df6c1ad 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -2,10 +2,17 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* # isort: skip +from __future__ import annotations + +from typing import TYPE_CHECKING + from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from socket import socket + # fmt: off diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index 1de66b0a8a..a6ea291d91 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -2,14 +2,21 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* # isort: skip +from __future__ import annotations + +from typing import TYPE_CHECKING + from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from socket import socket + # fmt: off -async def wait_readable(fd: int) ->None: +async def wait_readable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -17,7 +24,7 @@ async def wait_readable(fd: int) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(fd: int) ->None: +async def wait_writable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -25,7 +32,7 @@ async def wait_writable(fd: int) ->None: raise RuntimeError("must be called from async context") -def notify_closing(fd: int) ->None: +def notify_closing(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 94e819769c..5179f150c6 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -2,10 +2,17 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* # isort: skip +from __future__ import annotations + +from typing import TYPE_CHECKING + from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from socket import socket + # fmt: off diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 26b4da697d..71172ef4df 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -2,10 +2,17 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* # isort: skip +from __future__ import annotations + +from typing import TYPE_CHECKING + from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from socket import socket + # fmt: off diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index d1e74a93f4..e3f08a49e3 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -2,10 +2,17 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* # isort: skip +from __future__ import annotations + +from typing import TYPE_CHECKING + from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from socket import socket + # fmt: off diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 31c49ca230..750f85fabb 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -13,6 +13,8 @@ from ._wakeup_socketpair import WakeupSocketpair if TYPE_CHECKING: + from socket import socket + from .._core import Abort, RaiseCancelT assert not TYPE_CHECKING or sys.platform == "linux" @@ -282,7 +284,7 @@ def _update_registrations(self, fd: int) -> None: if not wanted_flags: del self._registered[fd] - async def _epoll_wait(self, fd: int, attr_name: str) -> None: + async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None: if not isinstance(fd, int): fd = fd.fileno() waiters = self._registered[fd] @@ -301,15 +303,15 @@ def abort(_: RaiseCancelT) -> Abort: await _core.wait_task_rescheduled(abort) @_public - async def wait_readable(self, fd: int) -> None: + async def wait_readable(self, fd: int | socket) -> None: await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd: int) -> None: + async def wait_writable(self, fd: int | socket) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd: int) -> None: + def notify_closing(self, fd: int | socket) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_threads.py b/trio/_threads.py index 45a416249e..3fbab05750 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -6,13 +6,14 @@ import queue as stdlib_queue import threading from itertools import count -from typing import Optional +from typing import Any, Callable, Optional, TypeVar import attr import outcome from sniffio import current_async_library_cvar import trio +from trio._core._traps import RaiseCancelT from ._core import ( RunVar, @@ -24,6 +25,8 @@ from ._sync import CapacityLimiter from ._util import coroutine_or_error +T = TypeVar("T") + # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() @@ -59,11 +62,6 @@ class ThreadPlaceholder: name = attr.ib() -from typing import Any, Callable, TypeVar - -T = TypeVar("T") - - @enable_ki_protection async def to_thread_run_sync( sync_fn: Callable[..., T], @@ -228,8 +226,6 @@ def deliver_worker_fn_result(result): limiter.release_on_behalf_of(placeholder) raise - from trio._core._traps import RaiseCancelT - def abort(_: RaiseCancelT) -> trio.lowlevel.Abort: if cancellable: task_register[0] = None diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index a5d8529b53..9c9d91f413 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -18,9 +18,14 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* # isort: skip +from __future__ import annotations from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from socket import socket # fmt: off """ diff --git a/trio/socket.py b/trio/socket.py index a9e276c782..f6aebb6a6e 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -35,6 +35,7 @@ # import the overwrites from ._socket import ( SocketType as SocketType, + _SocketType as _SocketType, from_stdlib_socket as from_stdlib_socket, fromfd as fromfd, getaddrinfo as getaddrinfo, From 46d9e9596d4528e86fb70770c48da9a903804e28 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 12:35:25 +0200 Subject: [PATCH 20/49] regen --- trio/_tools/gen_exports.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 9c9d91f413..95b92cedae 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -18,11 +18,13 @@ # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* # isort: skip + +from typing import TYPE_CHECKING + from __future__ import annotations from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT -from typing import TYPE_CHECKING if TYPE_CHECKING: from socket import socket From 156c94f99652760bf8b67ffa8878008640a13eb9 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 24 Jul 2023 13:56:21 +0200 Subject: [PATCH 21/49] progress --- pyproject.toml | 4 +- trio/_core/_generated_instrumentation.py | 10 +++- trio/_core/_generated_io_epoll.py | 10 +++- trio/_core/_generated_io_kqueue.py | 24 +++++--- trio/_core/_generated_io_windows.py | 10 +++- trio/_core/_generated_run.py | 10 +++- trio/_core/_io_epoll.py | 17 +++--- trio/_core/_io_kqueue.py | 75 +++++++++++++++--------- trio/_tools/gen_exports.py | 12 +++- 9 files changed, 114 insertions(+), 58 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2b2678a366..f2422f8eea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ module = [ "trio/_core/_windows_cffi", # 2, 324 #"trio/_core/_generated_io_epoll", # 3, 36 "trio/_core/_thread_cache", # 6, 273 -"trio/_core/_generated_io_kqueue", # 6 (darwin), 60 +#"trio/_core/_generated_io_kqueue", # 6 (darwin), 60 "trio/_core/_traps", # 7, 276 "trio/_core/_generated_run", # 8, 242 "trio/_core/_generated_io_windows", # 9 (win32), 84 @@ -74,7 +74,7 @@ module = [ "trio/_core/_wakeup_socketpair", # 12 "trio/_core/_ki", # 14, 210 "trio/_core/_entry_queue", # 16, 195 -"trio/_core/_io_kqueue", # 16, 198 +#"trio/_core/_io_kqueue", # 16, 198 "trio/_core/_multierror", # 19, 469 #"trio/_core/_io_epoll", # 21, 323 diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index e38df6c1ad..c783452bfc 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -1,18 +1,24 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip + from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + import select from socket import socket + from _contextlib import _GeneratorContextManager + from _core import Abort, RaiseCancelT + + from .. import _core + # fmt: off diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index a6ea291d91..f35b927737 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -1,18 +1,24 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip + from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + import select from socket import socket + from _contextlib import _GeneratorContextManager + from _core import Abort, RaiseCancelT + + from .. import _core + # fmt: off diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 5179f150c6..f4b14cd500 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -1,22 +1,28 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip + from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + import select from socket import socket + from _contextlib import _GeneratorContextManager + from _core import Abort, RaiseCancelT + + from .. import _core + # fmt: off -def current_kqueue(): +def current_kqueue() ->select.kqueue: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() @@ -24,7 +30,8 @@ def current_kqueue(): raise RuntimeError("must be called from async context") -def monitor_kevent(ident, filter): +def monitor_kevent(ident: int, filter: int) ->_GeneratorContextManager[_core + .UnboundedQueue[select.kevent]]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) @@ -32,7 +39,8 @@ def monitor_kevent(ident, filter): raise RuntimeError("must be called from async context") -async def wait_kevent(ident, filter, abort_func): +async def wait_kevent(ident: int, filter: int, abort_func: Callable[[ + RaiseCancelT], Abort]) ->Abort: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func) @@ -40,7 +48,7 @@ async def wait_kevent(ident, filter, abort_func): raise RuntimeError("must be called from async context") -async def wait_readable(fd): +async def wait_readable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -48,7 +56,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -56,7 +64,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 71172ef4df..90d7ce0d70 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -1,18 +1,24 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip + from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + import select from socket import socket + from _contextlib import _GeneratorContextManager + from _core import Abort, RaiseCancelT + + from .. import _core + # fmt: off diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index e3f08a49e3..e644de78fb 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -1,18 +1,24 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip + from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + import select from socket import socket + from _contextlib import _GeneratorContextManager + from _core import Abort, RaiseCancelT + + from .. import _core + # fmt: off diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 750f85fabb..cc38d9d537 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -17,7 +17,15 @@ from .._core import Abort, RaiseCancelT -assert not TYPE_CHECKING or sys.platform == "linux" + +@attr.s(slots=True, eq=False) +class EpollWaiters: + read_task: None = attr.ib(default=None) + write_task: None = attr.ib(default=None) + current_flags: int = attr.ib(default=0) + + +assert not TYPE_CHECKING or sys.platform == "linux" or sys.platform == "darwin" @attr.s(slots=True, eq=False, frozen=True) @@ -185,13 +193,6 @@ class _EpollStatistics: # wanted to about how epoll works. -@attr.s(slots=True, eq=False) -class EpollWaiters: - read_task: None = attr.ib(default=None) - write_task: None = attr.ib(default=None) - current_flags: int = attr.ib(default=0) - - @attr.s(slots=True, eq=False, hash=False) class EpollIOManager: _epoll: select.epoll = attr.ib(factory=select.epoll) diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 5ce5a609ed..ee25a748f7 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import errno import select import sys -from contextlib import contextmanager -from typing import TYPE_CHECKING +from contextlib import _GeneratorContextManager, contextmanager +from typing import TYPE_CHECKING, Callable, Iterator import attr import outcome @@ -12,33 +14,38 @@ from ._wakeup_socketpair import WakeupSocketpair if TYPE_CHECKING: - from .._core import Abort, RaiseCancelT + from socket import socket + + from .._core import Abort, RaiseCancelT, Task, UnboundedQueue assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") @attr.s(slots=True, eq=False, frozen=True) class _KqueueStatistics: - tasks_waiting = attr.ib() - monitors = attr.ib() - backend = attr.ib(default="kqueue") + tasks_waiting: int = attr.ib() + monitors: int = attr.ib() + backend: str = attr.ib(default="kqueue") @attr.s(slots=True, eq=False) class KqueueIOManager: - _kqueue = attr.ib(factory=select.kqueue) + _kqueue: select.kqueue = attr.ib(factory=select.kqueue) # {(ident, filter): Task or UnboundedQueue} - _registered = attr.ib(factory=dict) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) - - def __attrs_post_init__(self): + # TODO: int, int? + _registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = attr.ib( + factory=dict + ) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: None = attr.ib(default=None) + + def __attrs_post_init__(self) -> None: force_wakeup_event = select.kevent( self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD ) self._kqueue.control([force_wakeup_event], 0) self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() - def statistics(self): + def statistics(self) -> _KqueueStatistics: tasks_waiting = 0 monitors = 0 for receiver in self._registered.values(): @@ -48,14 +55,14 @@ def statistics(self): monitors += 1 return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors) - def close(self): + def close(self) -> None: self._kqueue.close() self._force_wakeup.close() - def force_wakeup(self): + def force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() - def get_events(self, timeout): + def get_events(self, timeout: float) -> list[select.kevent]: # max_events must be > 0 or kqueue gets cranky # and we generally want this to be strictly larger than the actual # number of events we get, so that we can tell that we've gotten @@ -72,7 +79,7 @@ def get_events(self, timeout): # and loop back to the start return events - def process_events(self, events): + def process_events(self, events: list[select.kevent]) -> None: for event in events: key = (event.ident, event.filter) if event.ident == self._force_wakeup_fd: @@ -81,7 +88,7 @@ def process_events(self, events): receiver = self._registered[key] if event.flags & select.KQ_EV_ONESHOT: del self._registered[key] - if type(receiver) is _core.Task: + if isinstance(receiver, _core.Task): _core.reschedule(receiver, outcome.Value(event)) else: receiver.put_nowait(event) @@ -98,18 +105,25 @@ def process_events(self, events): # be more ergonomic... @_public - def current_kqueue(self): + def current_kqueue(self) -> select.kqueue: return self._kqueue - @contextmanager @_public - def monitor_kevent(self, ident, filter): + def monitor_kevent( + self, ident: int, filter: int + ) -> _GeneratorContextManager[_core.UnboundedQueue[select.kevent]]: + return self._monitor_kevent(ident, filter) + + @contextmanager + def _monitor_kevent( + self, ident: int, filter: int + ) -> Iterator[_core.UnboundedQueue[select.kevent]]: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( "attempt to register multiple listeners for same ident/filter pair" ) - q = _core.UnboundedQueue() + q = _core.UnboundedQueue[select.kevent]() self._registered[key] = q try: yield q @@ -117,7 +131,9 @@ def monitor_kevent(self, ident, filter): del self._registered[key] @_public - async def wait_kevent(self, ident, filter, abort_func): + async def wait_kevent( + self, ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort] + ) -> Abort: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( @@ -129,11 +145,12 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: r = abort_func(raise_cancel) if r is _core.Abort.SUCCEEDED: del self._registered[key] - return r # type: ignore[no-any-return] + return r - return await _core.wait_task_rescheduled(abort) + # wait_task_rescheduled does not have its return type typed + return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return] - async def _wait_common(self, fd, filter): + async def _wait_common(self, fd: int | socket, filter: int) -> None: if not isinstance(fd, int): fd = fd.fileno() flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT @@ -165,15 +182,15 @@ def abort(_: RaiseCancelT) -> Abort: await self.wait_kevent(fd, filter, abort) @_public - async def wait_readable(self, fd): + async def wait_readable(self, fd: int | socket) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: int | socket) -> None: await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public - def notify_closing(self, fd): + def notify_closing(self, fd: int | socket) -> None: if not isinstance(fd, int): fd = fd.fileno() diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 95b92cedae..2996cfaaad 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -17,18 +17,24 @@ HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip - -from typing import TYPE_CHECKING from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Iterator + from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + import select from socket import socket + from _contextlib import _GeneratorContextManager + from _core import Abort, RaiseCancelT + + from .. import _core + # fmt: off """ From 12b5af701737cf40f9bc4386edf6fd6285d43fb2 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 25 Jul 2023 15:45:30 +0200 Subject: [PATCH 22/49] mypy can now be run on trio/ --- pyproject.toml | 22 ++++++- trio/_core/__init__.py | 1 + trio/_core/_generated_instrumentation.py | 8 ++- trio/_core/_generated_io_epoll.py | 8 ++- trio/_core/_generated_io_kqueue.py | 8 ++- trio/_core/_generated_io_windows.py | 8 ++- trio/_core/_generated_run.py | 25 ++++---- trio/_core/_io_epoll.py | 2 +- trio/_core/_io_kqueue.py | 1 + trio/_core/_run.py | 59 ++++++++++++------- trio/_core/_tests/test_io.py | 13 ++-- trio/_core/_tests/test_multierror.py | 2 +- .../apport_excepthook.py | 2 +- .../ipython_custom_exc.py | 2 +- .../simple_excepthook.py | 2 +- trio/_deprecate.py | 2 +- trio/_subprocess_platform/waitid.py | 4 +- trio/_tests/check_type_completeness.py | 2 + trio/_tests/test_contextvars.py | 4 +- trio/_tests/test_dtls.py | 4 +- trio/_tests/test_exports.py | 10 +++- trio/_tests/test_highlevel_serve_listeners.py | 2 +- trio/_tests/test_subprocess.py | 14 ++++- trio/_tests/test_threads.py | 12 ++-- trio/_tests/test_tracing.py | 10 ++-- trio/_tests/test_unix_pipes.py | 8 ++- trio/_tests/verify_types.json | 14 +---- trio/_tools/gen_exports.py | 30 +++++++--- trio/_unix_pipes.py | 3 + trio/testing/_fake_net.py | 18 +++--- 30 files changed, 199 insertions(+), 101 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f2422f8eea..74d9108644 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,15 +59,31 @@ disallow_incomplete_defs = false disallow_untyped_defs = false disallow_any_generics = false +# TODO: gen_exports add platform checks to specific files +[[tool.mypy.overrides]] +module = "trio/_core/_generated_run" +disable_error_code = ['has-type'] +[[tool.mypy.overrides]] +module = "trio/_core/_generated_io_kqueue" +disable_error_code = ['name-defined', 'attr-defined', 'no-any-return'] +[[tool.mypy.overrides]] +module = "trio/_core/_generated_io_epoll" +disable_error_code = ['no-any-return'] + [[tool.mypy.overrides]] module = [ +"trio/_core/_tests/*", +"trio/_tests/*", + + +"trio/_windows_pipes", #"trio/_core/_io_common", # 1, 24 "trio/_core/_windows_cffi", # 2, 324 #"trio/_core/_generated_io_epoll", # 3, 36 "trio/_core/_thread_cache", # 6, 273 -#"trio/_core/_generated_io_kqueue", # 6 (darwin), 60 +"trio/_core/_generated_io_kqueue", # 6 (darwin), 60 "trio/_core/_traps", # 7, 276 -"trio/_core/_generated_run", # 8, 242 +#"trio/_core/_generated_run", # 8, 242 "trio/_core/_generated_io_windows", # 9 (win32), 84 "trio/_core/_asyncgens", # 10, 194 @@ -85,6 +101,7 @@ module = [ "trio/testing/_trio_test", # 2, 29 "trio/testing/_checkpoints", # 3, 62 "trio/testing/_check_streams", # 27, 522 +"trio/testing/_fake_net", # 30 "trio/testing/_memory_streams", # 66, 590 "trio/_highlevel_open_unix_stream", # 2, 49 lines @@ -107,6 +124,7 @@ module = [ ] disallow_untyped_defs = false disallow_any_generics = false +disallow_incomplete_defs = false [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index aa898fffe0..26f6f04e7c 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -27,6 +27,7 @@ TASK_STATUS_IGNORED, CancelScope, Nursery, + RunStatistics, Task, TaskStatus, add_instrument, diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index c783452bfc..a1d38519d9 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterator +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED @@ -12,12 +12,16 @@ if TYPE_CHECKING: import select + import sys + from contextvars import Context from socket import socket from _contextlib import _GeneratorContextManager - from _core import Abort, RaiseCancelT + from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken + from outcome import Outcome from .. import _core + from .._abc import Clock # fmt: off diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index f35b927737..ea30ddf1fc 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterator +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED @@ -12,12 +12,16 @@ if TYPE_CHECKING: import select + import sys + from contextvars import Context from socket import socket from _contextlib import _GeneratorContextManager - from _core import Abort, RaiseCancelT + from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken + from outcome import Outcome from .. import _core + from .._abc import Clock # fmt: off diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index f4b14cd500..57fbd6f423 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterator +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED @@ -12,12 +12,16 @@ if TYPE_CHECKING: import select + import sys + from contextvars import Context from socket import socket from _contextlib import _GeneratorContextManager - from _core import Abort, RaiseCancelT + from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken + from outcome import Outcome from .. import _core + from .._abc import Clock # fmt: off diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 90d7ce0d70..4c92661c50 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterator +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED @@ -12,12 +12,16 @@ if TYPE_CHECKING: import select + import sys + from contextvars import Context from socket import socket from _contextlib import _GeneratorContextManager - from _core import Abort, RaiseCancelT + from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken + from outcome import Outcome from .. import _core + from .._abc import Clock # fmt: off diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index e644de78fb..62d6eaff68 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterator +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED @@ -12,17 +12,21 @@ if TYPE_CHECKING: import select + import sys + from contextvars import Context from socket import socket from _contextlib import _GeneratorContextManager - from _core import Abort, RaiseCancelT + from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken + from outcome import Outcome from .. import _core + from .._abc import Clock # fmt: off -def current_statistics(): +def current_statistics() ->RunStatistics: """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -52,7 +56,7 @@ def current_statistics(): raise RuntimeError("must be called from async context") -def current_time(): +def current_time() ->float: """Returns the current time according to Trio's internal clock. Returns: @@ -69,7 +73,7 @@ def current_time(): raise RuntimeError("must be called from async context") -def current_clock(): +def current_clock() ->(SystemClock | Clock): """Returns the current :class:`~trio.abc.Clock`.""" locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: @@ -78,7 +82,7 @@ def current_clock(): raise RuntimeError("must be called from async context") -def current_root_task(): +def current_root_task() ->(Task | None): """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -91,7 +95,7 @@ def current_root_task(): raise RuntimeError("must be called from async context") -def reschedule(task, next_send=_NO_SEND): +def reschedule(task: Task, next_send: Outcome=_NO_SEND) ->None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -116,7 +120,8 @@ def reschedule(task, next_send=_NO_SEND): raise RuntimeError("must be called from async context") -def spawn_system_task(async_fn, *args, name=None, context=None): +def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args: + Any, name: (str | None)=None, context: (Context | None)=None) ->Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -175,7 +180,7 @@ def spawn_system_task(async_fn, *args, name=None, context=None): raise RuntimeError("must be called from async context") -def current_trio_token(): +def current_trio_token() ->TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -187,7 +192,7 @@ def current_trio_token(): raise RuntimeError("must be called from async context") -async def wait_all_tasks_blocked(cushion=0.0): +async def wait_all_tasks_blocked(cushion: float=0.0) ->None: """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index cc38d9d537..130403df73 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -25,7 +25,7 @@ class EpollWaiters: current_flags: int = attr.ib(default=0) -assert not TYPE_CHECKING or sys.platform == "linux" or sys.platform == "darwin" +assert not TYPE_CHECKING or sys.platform == "linux" @attr.s(slots=True, eq=False, frozen=True) diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index ee25a748f7..0014ed9ac7 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -17,6 +17,7 @@ from socket import socket from .._core import Abort, RaiseCancelT, Task, UnboundedQueue + assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 4fb1c78048..79ed7ba8e3 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -17,7 +17,7 @@ from math import inf from time import perf_counter from types import TracebackType -from typing import TYPE_CHECKING, Any, NoReturn, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Iterable, NoReturn, TypeVar import attr from outcome import Error, Outcome, Value, capture @@ -49,11 +49,14 @@ from types import FrameType if TYPE_CHECKING: - import contextvars + from contextvars import Context # An unfortunate name collision here with trio._util.Final from typing import Final as FinalT + from .._abc import Clock + from ._mock_clock import MockClock + DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 _NO_SEND: FinalT = object() @@ -120,6 +123,7 @@ def function_with_unique_name_xyzzy() -> NoReturn: CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames() +# Why doesn't this inherit from abc.Clock? @attr.s(frozen=True, slots=True) class SystemClock: # Add a large random offset to our clock to ensure that if people @@ -1171,7 +1175,7 @@ class Task(metaclass=NoPublicConstructor): coro: Coroutine[Any, Outcome[object], Any] = attr.ib() _runner: Runner = attr.ib() name: str = attr.ib() - context: contextvars.Context = attr.ib() + context: Context = attr.ib() _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: @@ -1358,7 +1362,7 @@ class RunContext(threading.local): @attr.s(frozen=True) -class _RunStatistics: +class RunStatistics: tasks_living = attr.ib() tasks_runnable = attr.ib() seconds_to_next_deadline = attr.ib() @@ -1432,7 +1436,7 @@ def in_main_thread(): @attr.s(eq=False, hash=False, slots=True) class Runner: - clock = attr.ib() + clock: SystemClock | Clock | MockClock = attr.ib() instruments: Instruments = attr.ib() io_manager: TheIOManager = attr.ib() ki_manager = attr.ib() @@ -1442,18 +1446,18 @@ class Runner: _locals = attr.ib(factory=dict) runq: deque[Task] = attr.ib(factory=deque) - tasks = attr.ib(factory=set) + tasks: set[Task] = attr.ib(factory=set) deadlines = attr.ib(factory=Deadlines) - init_task = attr.ib(default=None) + init_task: Task | None = attr.ib(default=None) system_nursery = attr.ib(default=None) system_context = attr.ib(default=None) main_task = attr.ib(default=None) main_task_outcome = attr.ib(default=None) entry_queue = attr.ib(factory=EntryQueue) - trio_token = attr.ib(default=None) + trio_token: TrioToken | None = attr.ib(default=None) asyncgens = attr.ib(factory=AsyncGenerators) # If everything goes idle for this long, we call clock._autojump() @@ -1479,7 +1483,7 @@ def close(self): self.ki_manager.close() @_public - def current_statistics(self): + def current_statistics(self) -> RunStatistics: """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -1503,7 +1507,7 @@ def current_statistics(self): """ seconds_to_next_deadline = self.deadlines.next_deadline() - self.current_time() - return _RunStatistics( + return RunStatistics( tasks_living=len(self.tasks), tasks_runnable=len(self.runq), seconds_to_next_deadline=seconds_to_next_deadline, @@ -1512,7 +1516,7 @@ def current_statistics(self): ) @_public - def current_time(self): + def current_time(self) -> float: """Returns the current time according to Trio's internal clock. Returns: @@ -1524,13 +1528,15 @@ def current_time(self): """ return self.clock.current_time() + # TODO: abc.Clock or SystemClock? (the latter which doesn't inherit + # from abc.Clock) @_public - def current_clock(self): + def current_clock(self) -> SystemClock | Clock: """Returns the current :class:`~trio.abc.Clock`.""" return self.clock @_public - def current_root_task(self): + def current_root_task(self) -> Task | None: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -1543,7 +1549,7 @@ def current_root_task(self): ################ @_public - def reschedule(self, task, next_send=_NO_SEND): + def reschedule(self, task: Task, next_send: Outcome = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1577,8 +1583,15 @@ def reschedule(self, task, next_send=_NO_SEND): self.instruments.call("task_scheduled", task) def spawn_impl( - self, async_fn, args, nursery, name, *, system_task=False, context=None - ): + self, + async_fn: Callable[..., Awaitable[object]], + args: Iterable[Any], + nursery: Nursery | None, + name: str | functools.partial | Callable[..., Awaitable[object]] | None, + *, + system_task: bool = False, + context: Context | None = None, + ) -> Task: ###### # Make sure the nursery is in working order ###### @@ -1696,7 +1709,13 @@ def task_exited(self, task, outcome): ################ @_public - def spawn_system_task(self, async_fn, *args, name=None, context=None): + def spawn_system_task( + self, + async_fn: Callable[..., Awaitable[object]], + *args: Any, + name: str | None = None, + context: Context | None = None, + ) -> Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -1795,7 +1814,7 @@ async def init(self, async_fn, args): ################ @_public - def current_trio_token(self): + def current_trio_token(self) -> TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -1844,7 +1863,7 @@ def _deliver_ki_cb(self): waiting_for_idle = attr.ib(factory=SortedDict) @_public - async def wait_all_tasks_blocked(self, cushion=0.0): + async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -2311,7 +2330,7 @@ def unrolled_run( break else: assert idle_primed is IdlePrimedTypes.AUTOJUMP_CLOCK - runner.clock._autojump() + runner.clock._autojump() # type: ignore[union-attr] # Process all runnable tasks, but only the ones that are already # runnable now. Anything that becomes runnable during this cycle diff --git a/trio/_core/_tests/test_io.py b/trio/_core/_tests/test_io.py index 21a954941c..2205c83976 100644 --- a/trio/_core/_tests/test_io.py +++ b/trio/_core/_tests/test_io.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import random import socket as stdlib_socket from contextlib import suppress +from typing import Callable import pytest @@ -47,15 +50,15 @@ def fileno_wrapper(fileobj): return fileno_wrapper -wait_readable_options = [trio.lowlevel.wait_readable] -wait_writable_options = [trio.lowlevel.wait_writable] -notify_closing_options = [trio.lowlevel.notify_closing] +wait_readable_options: list[Callable] = [trio.lowlevel.wait_readable] +wait_writable_options: list[Callable] = [trio.lowlevel.wait_writable] +notify_closing_options: list[Callable] = [trio.lowlevel.notify_closing] -for options_list in [ +for options_list in ( wait_readable_options, wait_writable_options, notify_closing_options, -]: +): options_list += [using_fileno(f) for f in options_list] # Decorators that feed in different settings for wait_readable / wait_writable diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py index 7a8bd2f9a8..52e5e39d1b 100644 --- a/trio/_core/_tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -555,7 +555,7 @@ def test_apport_excepthook_monkeypatch_interaction(): @pytest.mark.parametrize("protocol", range(0, pickle.HIGHEST_PROTOCOL + 1)) -def test_pickle_multierror(protocol) -> None: +def test_pickle_multierror(protocol: int) -> None: # use trio.MultiError to make sure that pickle works through the deprecation layer import trio diff --git a/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py index 3e1d23ca8e..e51b8cdca0 100644 --- a/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py +++ b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py @@ -12,4 +12,4 @@ import trio -raise trio.MultiError([KeyError("key_error"), ValueError("value_error")]) +raise trio.MultiError([KeyError("key_error"), ValueError("value_error")]) # type: ignore[attr-defined] diff --git a/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py index 80e42b6a2c..c8086d3a0e 100644 --- a/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py +++ b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py @@ -33,4 +33,4 @@ def custom_exc_hook(etype, value, tb, tb_offset=None): # The custom excepthook should run, because Trio was polite and didn't # override it -raise trio.MultiError([ValueError(), KeyError()]) +raise trio.MultiError([ValueError(), KeyError()]) # type: ignore[attr-defined] diff --git a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py index 94004525db..c2297df400 100644 --- a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py +++ b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py @@ -18,4 +18,4 @@ def exc2_fn(): # This should be printed nicely, because Trio overrode sys.excepthook -raise trio.MultiError([exc1_fn(), exc2_fn()]) +raise trio.MultiError([exc1_fn(), exc2_fn()]) # type: ignore[attr-defined] diff --git a/trio/_deprecate.py b/trio/_deprecate.py index aeebe80722..7deecb7042 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -59,7 +59,7 @@ def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): # @deprecated("0.2.0", issue=..., instead=...) # def ... def deprecated( - version: str, *, thing: str | None = None, issue: int, instead: str + version: str, *, thing: str | None = None, issue: int | None, instead: object ) -> Callable[[T], T]: def do_wrap(fn): nonlocal thing diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index ad69017219..f90a3f5b65 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -2,13 +2,15 @@ import math import os import sys +from typing import TYPE_CHECKING from .. import _core, _subprocess from .._sync import CapacityLimiter, Event from .._threads import to_thread_run_sync try: - from os import waitid + if not TYPE_CHECKING or sys.platform == "unix": + from os import waitid def sync_wait_reapable(pid): waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT) diff --git a/trio/_tests/check_type_completeness.py b/trio/_tests/check_type_completeness.py index 7a65a4249e..abaabcf785 100755 --- a/trio/_tests/check_type_completeness.py +++ b/trio/_tests/check_type_completeness.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + # this file is not run as part of the tests, instead it's run standalone from check.sh import argparse import json diff --git a/trio/_tests/test_contextvars.py b/trio/_tests/test_contextvars.py index 63853f5171..0ff13435cf 100644 --- a/trio/_tests/test_contextvars.py +++ b/trio/_tests/test_contextvars.py @@ -2,7 +2,9 @@ from .. import _core -trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar") +trio_testing_contextvar: contextvars.ContextVar = contextvars.ContextVar( + "trio_testing_contextvar" +) async def test_contextvars_default(): diff --git a/trio/_tests/test_dtls.py b/trio/_tests/test_dtls.py index b8c32c6d5f..8cb06ccb3d 100644 --- a/trio/_tests/test_dtls.py +++ b/trio/_tests/test_dtls.py @@ -17,10 +17,10 @@ ca = trustme.CA() server_cert = ca.issue_cert("example.com") -server_ctx = SSL.Context(SSL.DTLS_METHOD) +server_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined] server_cert.configure_cert(server_ctx) -client_ctx = SSL.Context(SSL.DTLS_METHOD) +client_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined] ca.configure_trust(client_ctx) diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index 20635b0022..6e65a39316 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum import functools import importlib @@ -175,7 +177,7 @@ def no_underscores(symbols): if modname == "trio": static_names.add("testing") - # these are hidden behind `if sys.plaftorm != "win32" or not TYPE_CHECKING` + # these are hidden behind `if sys.platform != "win32" or not TYPE_CHECKING` # so presumably pyright is parsing that if statement, in which case we don't # care about them being missing. if modname == "trio.socket" and sys.platform == "win32": @@ -226,7 +228,9 @@ def no_underscores(symbols): ) @pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES) @pytest.mark.parametrize("tool", ["jedi", "mypy"]) -def test_static_tool_sees_class_members(tool, module_name, tmpdir) -> None: +def test_static_tool_sees_class_members( + tool: str, module_name: str, tmpdir: Path +) -> None: module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)] # ignore hidden, but not dunder, symbols @@ -483,7 +487,7 @@ def test_classes_are_final(): continue # These are classes that are conceptually abstract, but # inspect.isabstract returns False for boring reasons. - if class_ in {trio.abc.Instrument, trio.socket.SocketType}: + if class_ in (trio.abc.Instrument, trio.socket.SocketType): continue # Enums have their own metaclass, so we can't use our metaclasses. # And I don't think there's a lot of risk from people subclassing diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py index 4385263899..67e2eecbc8 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -12,7 +12,7 @@ class MemoryListener(trio.abc.Listener): closed = attr.ib(default=False) accepted_streams = attr.ib(factory=list) - queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1))) + queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel[object](1))) accept_hook = attr.ib(default=None) async def connect(self): diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py index 4dfaef4c7f..7986dfd71e 100644 --- a/trio/_tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import random import signal @@ -6,6 +8,7 @@ from contextlib import asynccontextmanager from functools import partial from pathlib import Path as SyncPath +from typing import TYPE_CHECKING import pytest @@ -24,8 +27,15 @@ from ..lowlevel import open_process from ..testing import assert_no_checkpoints, wait_all_tasks_blocked +if TYPE_CHECKING: + ... + from signal import Signals + posix = os.name == "posix" -if posix: +SIGKILL: Signals | None +SIGTERM: Signals | None +SIGUSR1: Signals | None +if (not TYPE_CHECKING and posix) or sys.platform != "win32": from signal import SIGKILL, SIGTERM, SIGUSR1 else: SIGKILL, SIGTERM, SIGUSR1 = None, None, None @@ -574,7 +584,7 @@ async def test_for_leaking_fds(): async def test_subprocess_pidfd_unnotified(): noticed_exit = None - async def wait_and_tell(proc) -> None: + async def wait_and_tell(proc: Process) -> None: nonlocal noticed_exit noticed_exit = Event() await proc.wait() diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 21eb7b12e8..9149f43037 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -170,7 +170,7 @@ async def main(): async def test_named_thread(): ending = " from trio._tests.test_threads.test_named_thread" - def inner(name="inner" + ending) -> threading.Thread: + def inner(name: str = "inner" + ending) -> threading.Thread: assert threading.current_thread().name == name return threading.current_thread() @@ -185,7 +185,7 @@ def f(name: str) -> Callable[[None], threading.Thread]: await to_thread_run_sync(f("None" + ending)) # test that you can set a custom name, and that it's reset afterwards - async def test_thread_name(name: str): + async def test_thread_name(name: str) -> None: thread = await to_thread_run_sync(f(name), thread_name=name) assert re.match("Trio thread [0-9]*", thread.name) @@ -235,7 +235,7 @@ def _get_thread_name(ident: Optional[int] = None) -> Optional[str]: # and most mac machines. So unless the platform is linux it will just skip # in case it fails to fetch the os thread name. async def test_named_thread_os(): - def inner(name) -> threading.Thread: + def inner(name: str) -> threading.Thread: os_thread_name = _get_thread_name() if os_thread_name is None and sys.platform != "linux": pytest.skip(f"no pthread OS support on {sys.platform}") @@ -253,7 +253,7 @@ def f(name: str) -> Callable[[None], threading.Thread]: await to_thread_run_sync(f(default), thread_name=None) # test that you can set a custom name, and that it's reset afterwards - async def test_thread_name(name: str, expected: Optional[str] = None): + async def test_thread_name(name: str, expected: Optional[str] = None) -> None: if expected is None: expected = name thread = await to_thread_run_sync(f(expected), thread_name=name) @@ -584,7 +584,9 @@ async def async_fn(): # pragma: no cover await to_thread_run_sync(async_fn) -trio_test_contextvar = contextvars.ContextVar("trio_test_contextvar") +trio_test_contextvar: contextvars.ContextVar = contextvars.ContextVar( + "trio_test_contextvar" +) async def test_trio_to_thread_run_sync_contextvars(): diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py index 07d1ff7609..e5110eaff3 100644 --- a/trio/_tests/test_tracing.py +++ b/trio/_tests/test_tracing.py @@ -1,26 +1,26 @@ import trio -async def coro1(event: trio.Event): +async def coro1(event: trio.Event) -> None: event.set() await trio.sleep_forever() -async def coro2(event: trio.Event): +async def coro2(event: trio.Event) -> None: await coro1(event) -async def coro3(event: trio.Event): +async def coro3(event: trio.Event) -> None: await coro2(event) -async def coro2_async_gen(event: trio.Event): +async def coro2_async_gen(event): yield await trio.lowlevel.checkpoint() yield await coro1(event) yield await trio.lowlevel.checkpoint() -async def coro3_async_gen(event: trio.Event): +async def coro3_async_gen(event: trio.Event) -> None: async for x in coro2_async_gen(event): pass diff --git a/trio/_tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py index acee75aafb..0b0d2ceb23 100644 --- a/trio/_tests/test_unix_pipes.py +++ b/trio/_tests/test_unix_pipes.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import errno import os import select import sys +from typing import TYPE_CHECKING import pytest @@ -11,6 +14,9 @@ posix = os.name == "posix" pytestmark = pytest.mark.skipif(not posix, reason="posix only") + +assert not TYPE_CHECKING or sys.platform == "unix" + if posix: from .._unix_pipes import FdStream else: @@ -19,7 +25,7 @@ # Have to use quoted types so import doesn't crash on windows -async def make_pipe() -> "Tuple[FdStream, FdStream]": +async def make_pipe() -> "tuple[FdStream, FdStream]": """Makes a new pair of pipes.""" (r, w) = os.pipe() return FdStream(w), FdStream(r) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index c454c4d3b9..1824df59e1 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.92, + "completenessScore": 0.9328, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 575, - "withUnknownType": 50 + "withKnownType": 583, + "withUnknownType": 42 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 0, @@ -99,20 +99,13 @@ "trio._subprocess.Process.send_signal", "trio._subprocess.Process.terminate", "trio._subprocess.Process.wait", - "trio.current_time", "trio.from_thread.run", "trio.from_thread.run_sync", "trio.lowlevel.cancel_shielded_checkpoint", - "trio.lowlevel.current_clock", - "trio.lowlevel.current_root_task", - "trio.lowlevel.current_statistics", - "trio.lowlevel.current_trio_token", "trio.lowlevel.currently_ki_protected", "trio.lowlevel.notify_closing", "trio.lowlevel.permanently_detach_coroutine_object", "trio.lowlevel.reattach_detached_coroutine_object", - "trio.lowlevel.reschedule", - "trio.lowlevel.spawn_system_task", "trio.lowlevel.start_guest_run", "trio.lowlevel.temporarily_detach_coroutine_object", "trio.lowlevel.wait_readable", @@ -158,7 +151,6 @@ "trio.testing.memory_stream_pump", "trio.testing.open_stream_to_socket_listener", "trio.testing.trio_test", - "trio.testing.wait_all_tasks_blocked", "trio.to_thread.current_default_thread_limiter", "trio.wrap_file" ] diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 2996cfaaad..0e2dfb7b42 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -3,15 +3,21 @@ Code generation script for class methods to be exported as public API """ +from __future__ import annotations + import argparse import ast import os import sys from pathlib import Path from textwrap import indent +from typing import TYPE_CHECKING, Iterator, Union import astor +if TYPE_CHECKING: + from typing_extensions import TypeAlias, TypeGuard + PREFIX = "_generated" HEADER = """# *********************************************************** @@ -20,18 +26,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterator +from typing import TYPE_CHECKING, Callable, Iterator, Awaitable, Any from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT if TYPE_CHECKING: + import sys import select from socket import socket from _contextlib import _GeneratorContextManager - from _core import Abort, RaiseCancelT + from contextvars import Context + from _core import Abort, RaiseCancelT, RunStatistics, Task, SystemClock, TrioToken + from outcome import Outcome + from .._abc import Clock from .. import _core @@ -48,8 +58,10 @@ raise RuntimeError("must be called from async context") """ +AstFun: TypeAlias = Union[ast.FunctionDef, ast.AsyncFunctionDef] + -def is_function(node): +def is_function(node: ast.AST) -> TypeGuard[AstFun]: """Check if the AST node is either a function or an async function """ @@ -58,7 +70,7 @@ def is_function(node): return False -def is_public(node): +def is_public(node: ast.AST) -> TypeGuard[AstFun]: """Check if the AST node has a _public decorator""" if not is_function(node): return False @@ -68,7 +80,7 @@ def is_public(node): return False -def get_public_methods(tree): +def get_public_methods(tree: ast.AST) -> Iterator[AstFun]: """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked @@ -79,7 +91,7 @@ def get_public_methods(tree): yield node -def create_passthrough_args(funcdef): +def create_passthrough_args(funcdef: AstFun) -> str: """Given a function definition, create a string that represents taking all the arguments from the function, and passing them through to another invocation of the same function. @@ -143,7 +155,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: return "\n\n".join(generated) -def matches_disk_files(new_files): +def matches_disk_files(new_files: dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False @@ -154,7 +166,7 @@ def matches_disk_files(new_files): return True -def process(sources_and_lookups, *, do_test): +def process(sources_and_lookups: list[tuple[Path, str]], *, do_test: bool) -> None: new_files = {} for source_path, lookup_path in sources_and_lookups: print("Scanning:", source_path) @@ -177,7 +189,7 @@ def process(sources_and_lookups, *, do_test): # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. -def main(): # pragma: no cover +def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( description="Generate python code for public api wrappers" ) diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 716550790e..1a389e12dd 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -2,6 +2,7 @@ import errno import os +import sys from typing import TYPE_CHECKING import trio @@ -12,6 +13,8 @@ if TYPE_CHECKING: from typing import Final as FinalType +assert not TYPE_CHECKING or sys.platform != "win32" + if os.name != "posix": # We raise an error here rather than gating the import in lowlevel.py # in order to keep jedi static analysis happy. diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index fdb4d45102..9befedf21b 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -19,7 +19,7 @@ from trio._util import Final, NoPublicConstructor if TYPE_CHECKING: - import socket + from socket import AddressFamily, SocketKind from types import TracebackType IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -105,7 +105,7 @@ def reply(self, payload): class FakeSocketFactory(trio.abc.SocketFactory): fake_net: "FakeNet" - def socket(self, family: int, type: int, proto: int) -> "FakeSocket": + def socket(self, family: int, type: int, proto: int) -> FakeSocket: # type: ignore[override] return FakeSocket._create(self.fake_net, family, type, proto) @@ -123,8 +123,8 @@ async def getaddrinfo( flags: int = 0, ) -> list[ tuple[ - socket.AddressFamily, - socket.SocketKind, + AddressFamily, + SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int], @@ -139,13 +139,13 @@ async def getnameinfo( class FakeNet(metaclass=Final): - def __init__(self): + def __init__(self) -> None: # When we need to pick an arbitrary unique ip address/port, use these: self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() - self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() + self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() # type: ignore[assignment] self._auto_port_iter = iter(range(50000, 65535)) - self._bound: Dict[UDPBinding, FakeSocket] = {} + self._bound: dict[UDPBinding, FakeSocket] = {} self.route_packet = None @@ -193,7 +193,7 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): self._closed = False - self._packet_sender, self._packet_receiver = trio.open_memory_channel( + self._packet_sender, self._packet_receiver = trio.open_memory_channel[object]( float("inf") ) @@ -223,7 +223,7 @@ async def _resolve_address_nocp(self, address, *, local): local=local, ) - def _deliver_packet(self, packet: UDPPacket): + def _deliver_packet(self, packet: UDPPacket) -> None: try: self._packet_sender.send_nowait(packet) except trio.BrokenResourceError: From 2490873beee68557babcc87f10ae7a3232dd6c1b Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 25 Jul 2023 16:13:59 +0200 Subject: [PATCH 23/49] started playing around with windows but mostly gave up. Some types added though --- pyproject.toml | 24 ++++--- trio/_core/_generated_instrumentation.py | 1 + trio/_core/_generated_io_epoll.py | 1 + trio/_core/_generated_io_kqueue.py | 1 + trio/_core/_generated_io_windows.py | 14 ++-- trio/_core/_generated_run.py | 1 + trio/_core/_io_epoll.py | 1 + trio/_core/_io_windows.py | 89 ++++++++++++++---------- trio/_core/_windows_cffi.py | 3 +- trio/_tools/gen_exports.py | 1 + 10 files changed, 83 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 74d9108644..51a1bd8178 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,11 @@ disallow_any_generics = false module = "trio/_core/_generated_run" disable_error_code = ['has-type'] [[tool.mypy.overrides]] -module = "trio/_core/_generated_io_kqueue" +module = [ + "trio/_core/_generated_io_kqueue", + "trio/_core/_generated_io_windows", + ] + disable_error_code = ['name-defined', 'attr-defined', 'no-any-return'] [[tool.mypy.overrides]] module = "trio/_core/_generated_io_epoll" @@ -75,26 +79,28 @@ module = [ "trio/_core/_tests/*", "trio/_tests/*", - +# windows "trio/_windows_pipes", -#"trio/_core/_io_common", # 1, 24 "trio/_core/_windows_cffi", # 2, 324 +"trio/_core/_generated_io_windows", # 9 (win32), 84 +"trio/_core/_io_windows", # 47 (win32), 867 + +#"trio/_core/_io_common", # 1, 24 #"trio/_core/_generated_io_epoll", # 3, 36 +#"trio/_core/_generated_io_kqueue", # 6 (darwin), 60 +#"trio/_core/_generated_run", # 8, 242 +#"trio/_core/_io_kqueue", # 16, 198 +#"trio/_core/_io_epoll", # 21, 323 + "trio/_core/_thread_cache", # 6, 273 -"trio/_core/_generated_io_kqueue", # 6 (darwin), 60 "trio/_core/_traps", # 7, 276 -#"trio/_core/_generated_run", # 8, 242 -"trio/_core/_generated_io_windows", # 9 (win32), 84 "trio/_core/_asyncgens", # 10, 194 "trio/_core/_wakeup_socketpair", # 12 "trio/_core/_ki", # 14, 210 "trio/_core/_entry_queue", # 16, 195 -#"trio/_core/_io_kqueue", # 16, 198 "trio/_core/_multierror", # 19, 469 -#"trio/_core/_io_epoll", # 21, 323 -"trio/_core/_io_windows", # 47 (win32), 867 "trio/testing/_network", # 1, 34 diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index a1d38519d9..bb19c2fbe5 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -22,6 +22,7 @@ from .. import _core from .._abc import Clock + from ._unbounded_queue import UnboundedQueue # fmt: off diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index ea30ddf1fc..17d712726b 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -22,6 +22,7 @@ from .. import _core from .._abc import Clock + from ._unbounded_queue import UnboundedQueue # fmt: off diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 57fbd6f423..bfba052792 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -22,6 +22,7 @@ from .. import _core from .._abc import Clock + from ._unbounded_queue import UnboundedQueue # fmt: off diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 4c92661c50..7072f12aaf 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -22,11 +22,12 @@ from .. import _core from .._abc import Clock + from ._unbounded_queue import UnboundedQueue # fmt: off -async def wait_readable(sock): +async def wait_readable(sock) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) @@ -34,7 +35,7 @@ async def wait_readable(sock): raise RuntimeError("must be called from async context") -async def wait_writable(sock): +async def wait_writable(sock) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) @@ -42,7 +43,7 @@ async def wait_writable(sock): raise RuntimeError("must be called from async context") -def notify_closing(handle): +def notify_closing(handle) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) @@ -50,7 +51,7 @@ def notify_closing(handle): raise RuntimeError("must be called from async context") -def register_with_iocp(handle): +def register_with_iocp(handle) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) @@ -82,7 +83,7 @@ async def readinto_overlapped(handle, buffer, file_offset=0): raise RuntimeError("must be called from async context") -def current_iocp(): +def current_iocp() ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() @@ -90,7 +91,8 @@ def current_iocp(): raise RuntimeError("must be called from async context") -def monitor_completion_key(): +def monitor_completion_key() ->_GeneratorContextManager[tuple[int, + UnboundedQueue[object]]]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 62d6eaff68..999ce9d1e5 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -22,6 +22,7 @@ from .. import _core from .._abc import Clock + from ._unbounded_queue import UnboundedQueue # fmt: off diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 130403df73..dfa979ac82 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -20,6 +20,7 @@ @attr.s(slots=True, eq=False) class EpollWaiters: + # TODO: why is nobody complaining about this? read_task: None = attr.ib(default=None) write_task: None = attr.ib(default=None) current_flags: int = attr.ib(default=0) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 4084f72b6e..66bf6a9349 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import enum import itertools import socket import sys from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterator, Literal import attr from outcome import Value @@ -29,6 +31,12 @@ assert not TYPE_CHECKING or sys.platform == "win32" +if TYPE_CHECKING: + from _contextlib import _GeneratorContextManager + + from ._traps import Abort, RaiseCancelT + from ._unbouded_queue import UnboundedQueue + # There's a lot to be said about the overall design of a Windows event # loop. See # @@ -179,13 +187,15 @@ class CKeys(enum.IntEnum): USER_DEFINED = 4 # and above -def _check(success): +def _check(success: bool) -> Literal[True]: if not success: raise_winerror() return success -def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): +def _get_underlying_socket( + sock: socket.socket | int, *, which=WSAIoctls.SIO_BASE_HANDLE +): if hasattr(sock, "fileno"): sock = sock.fileno() base_ptr = ffi.new("HANDLE *") @@ -330,9 +340,9 @@ def _afd_helper_handle(): # operation and start a new one. @attr.s(slots=True, eq=False) class AFDWaiters: - read_task = attr.ib(default=None) - write_task = attr.ib(default=None) - current_op = attr.ib(default=None) + read_task: None = attr.ib(default=None) + write_task: None = attr.ib(default=None) + current_op: None = attr.ib(default=None) # We also need to bundle up all the info for a single op into a standalone @@ -340,10 +350,10 @@ class AFDWaiters: # finishes, even if we're throwing it away. @attr.s(slots=True, eq=False, frozen=True) class AFDPollOp: - lpOverlapped = attr.ib() - poll_info = attr.ib() - waiters = attr.ib() - afd_group = attr.ib() + lpOverlapped: None = attr.ib() + poll_info: None = attr.ib() + waiters: None = attr.ib() + afd_group: None = attr.ib() # The Windows kernel has a weird issue when using AFD handles. If you have N @@ -359,17 +369,17 @@ class AFDPollOp: @attr.s(slots=True, eq=False) class AFDGroup: - size = attr.ib() - handle = attr.ib() + size: int = attr.ib() + handle: None = attr.ib() @attr.s(slots=True, eq=False, frozen=True) class _WindowsStatistics: - tasks_waiting_read = attr.ib() - tasks_waiting_write = attr.ib() - tasks_waiting_overlapped = attr.ib() - completion_key_monitors = attr.ib() - backend = attr.ib(default="windows") + tasks_waiting_read: int = attr.ib() + tasks_waiting_write: int = attr.ib() + tasks_waiting_overlapped: int = attr.ib() + completion_key_monitors: int = attr.ib() + backend: str = attr.ib(default="windows") # Maximum number of events to dequeue from the completion port on each pass @@ -381,8 +391,8 @@ class _WindowsStatistics: @attr.s(frozen=True) class CompletionKeyEventInfo: - lpOverlapped = attr.ib() - dwNumberOfBytesTransferred = attr.ib() + lpOverlapped: None = attr.ib() + dwNumberOfBytesTransferred: int = attr.ib() class WindowsIOManager: @@ -449,7 +459,7 @@ def __init__(self): "netsh winsock show catalog" ) - def close(self): + def close(self) -> None: try: if self._iocp is not None: iocp = self._iocp @@ -460,10 +470,10 @@ def close(self): afd_handle = self._all_afd_handles.pop() _check(kernel32.CloseHandle(afd_handle)) - def __del__(self): + def __del__(self) -> None: self.close() - def statistics(self): + def statistics(self) -> _WindowsStatistics: tasks_waiting_read = 0 tasks_waiting_write = 0 for waiter in self._afd_waiters.values(): @@ -478,7 +488,7 @@ def statistics(self): completion_key_monitors=len(self._completion_key_queues), ) - def force_wakeup(self): + def force_wakeup(self) -> None: _check( kernel32.PostQueuedCompletionStatus( self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL @@ -502,7 +512,7 @@ def get_events(self, timeout): return 0 return received[0] - def process_events(self, received): + def process_events(self, received: int) -> None: for i in range(received): entry = self._events[i] if entry.lpCompletionKey == CKeys.AFD_POLL: @@ -582,7 +592,7 @@ def process_events(self, received): ) queue.put_nowait(info) - def _register_with_iocp(self, handle, completion_key): + def _register_with_iocp(self, handle, completion_key) -> None: handle = _handle(handle) _check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0)) # Supposedly this makes things slightly faster, by disabling the @@ -599,7 +609,7 @@ def _register_with_iocp(self, handle, completion_key): # AFD stuff ################################################################ - def _refresh_afd(self, base_handle): + def _refresh_afd(self, base_handle) -> None: waiters = self._afd_waiters[base_handle] if waiters.current_op is not None: afd_group = waiters.current_op.afd_group @@ -675,7 +685,7 @@ def _refresh_afd(self, base_handle): if afd_group.size >= MAX_AFD_GROUP_SIZE: self._vacant_afd_groups.remove(afd_group) - async def _afd_poll(self, sock, mode): + async def _afd_poll(self, sock, mode) -> None: base_handle = _get_base_socket(sock) waiters = self._afd_waiters.get(base_handle) if waiters is None: @@ -688,7 +698,7 @@ async def _afd_poll(self, sock, mode): # we let it escape. self._refresh_afd(base_handle) - def abort_fn(_): + def abort_fn(_: RaiseCancelT) -> Abort: setattr(waiters, mode, None) self._refresh_afd(base_handle) return _core.Abort.SUCCEEDED @@ -696,15 +706,15 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) @_public - async def wait_readable(self, sock): + async def wait_readable(self, sock) -> None: await self._afd_poll(sock, "read_task") @_public - async def wait_writable(self, sock): + async def wait_writable(self, sock) -> None: await self._afd_poll(sock, "write_task") @_public - def notify_closing(self, handle): + def notify_closing(self, handle) -> None: handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: @@ -716,7 +726,7 @@ def notify_closing(self, handle): ################################################################ @_public - def register_with_iocp(self, handle): + def register_with_iocp(self, handle) -> None: self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) @_public @@ -732,7 +742,7 @@ async def wait_overlapped(self, handle, lpOverlapped): self._overlapped_waiters[lpOverlapped] = task raise_cancel = None - def abort(raise_cancel_): + def abort(raise_cancel_: RaiseCancelT) -> Abort: nonlocal raise_cancel raise_cancel = raise_cancel_ try: @@ -852,14 +862,19 @@ def submit_read(lpOverlapped): ################################################################ @_public - def current_iocp(self): + def current_iocp(self) -> int: return int(ffi.cast("uintptr_t", self._iocp)) - @contextmanager @_public - def monitor_completion_key(self): + def monitor_completion_key( + self, + ) -> _GeneratorContextManager[tuple[int, UnboundedQueue[object]]]: + return self._monitor_completion_key() + + @contextmanager + def _monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]: key = next(self._completion_key_counter) - queue = _core.UnboundedQueue() + queue = _core.UnboundedQueue[object]() self._completion_key_queues[key] = queue try: yield (key, queue) diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index 639e75b50e..50d598c2be 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -1,5 +1,6 @@ import enum import re +from typing import NoReturn import cffi @@ -315,7 +316,7 @@ def _handle(obj): return obj -def raise_winerror(winerror=None, *, filename=None, filename2=None): +def raise_winerror(winerror=None, *, filename=None, filename2=None) -> NoReturn: if winerror is None: winerror, msg = ffi.getwinerror() else: diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 0e2dfb7b42..8dd2d4fae3 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -37,6 +37,7 @@ import select from socket import socket + from ._unbounded_queue import UnboundedQueue from _contextlib import _GeneratorContextManager from contextvars import Context from _core import Abort, RaiseCancelT, RunStatistics, Task, SystemClock, TrioToken From a349d998dec5846ee56e3655c9c3dff34e71baf4 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 25 Jul 2023 16:28:52 +0200 Subject: [PATCH 24/49] _core/_thread_cache --- pyproject.toml | 2 +- trio/_core/_thread_cache.py | 27 ++++++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 51a1bd8178..f839b5b78e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ module = [ #"trio/_core/_io_kqueue", # 16, 198 #"trio/_core/_io_epoll", # 21, 323 -"trio/_core/_thread_cache", # 6, 273 +#"trio/_core/_thread_cache", # 6, 273 "trio/_core/_traps", # 7, 276 "trio/_core/_asyncgens", # 10, 194 diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py index 157f14c5a1..823d22a10a 100644 --- a/trio/_core/_thread_cache.py +++ b/trio/_core/_thread_cache.py @@ -7,10 +7,13 @@ from functools import partial from itertools import count from threading import Lock, Thread -from typing import Callable, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple import outcome +if TYPE_CHECKING: + from outcome import Value + def _to_os_thread_name(name: str) -> bytes: # ctypes handles the trailing \00 @@ -116,9 +119,10 @@ def darwin_namefunc( class WorkerThread: def __init__(self, thread_cache: ThreadCache): - # deliver (the second value) can probably be Callable[[outcome.Value], None] ? # should generate stubs for outcome - self._job: Optional[Tuple[Callable, Callable, str]] = None + self._job: Optional[ + Tuple[Callable[[None], None], Callable[[Value], None], str | None] + ] = None self._thread_cache = thread_cache # This Lock is used in an unconventional way. # @@ -136,7 +140,7 @@ def __init__(self, thread_cache: ThreadCache): set_os_thread_name(self._thread.ident, self._default_name) self._thread.start() - def _handle_job(self): + def _handle_job(self) -> None: # Handle job in a separate method to ensure user-created # objects are cleaned up in a consistent manner. assert self._job is not None @@ -167,7 +171,7 @@ def _handle_job(self): print("Exception while delivering result of thread", file=sys.stderr) traceback.print_exception(type(e), e, e.__traceback__) - def _work(self): + def _work(self) -> None: while True: if self._worker_lock.acquire(timeout=IDLE_TIMEOUT): # We got a job @@ -191,11 +195,14 @@ def _work(self): class ThreadCache: - def __init__(self): - self._idle_workers = {} + def __init__(self) -> None: + self._idle_workers: dict[WorkerThread, None] = {} def start_thread_soon( - self, fn: Callable, deliver: Callable, name: Optional[str] = None + self, + fn: Callable[[None], Any] | partial[Any], + deliver: Callable[[Value], None], + name: Optional[str] = None, ) -> None: try: worker, _ = self._idle_workers.popitem() @@ -209,7 +216,9 @@ def start_thread_soon( def start_thread_soon( - fn: Callable, deliver: Callable, name: Optional[str] = None + fn: Callable[[None], None] | partial[Any], + deliver: Callable[[Value], None], + name: Optional[str] = None, ) -> None: """Runs ``deliver(outcome.capture(fn))`` in a worker thread. From 7ed3a7b7d16286d92e247b67c18f99f71390c558 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 25 Jul 2023 16:39:20 +0200 Subject: [PATCH 25/49] type _traps as far as it's possible ... but tests are failing --- pyproject.toml | 2 +- trio/_core/_traps.py | 33 +++++++++++++++++++++++---------- trio/_tests/verify_types.json | 10 +++------- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f839b5b78e..9e156299cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ module = [ #"trio/_core/_io_epoll", # 21, 323 #"trio/_core/_thread_cache", # 6, 273 -"trio/_core/_traps", # 7, 276 +#"trio/_core/_traps", # 7, 276 "trio/_core/_asyncgens", # 10, 194 "trio/_core/_wakeup_socketpair", # 12 diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 08a8ceac01..77fc9966fe 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -1,14 +1,25 @@ -# These are the only functions that ever yield back to the task runner. +from __future__ import annotations import enum import types -from typing import Any, Callable, NoReturn +from typing import TYPE_CHECKING, Any, Callable, Iterator, NoReturn, TypeVar import attr import outcome from . import _run +# These are the only functions that ever yield back to the task runner. + + +if TYPE_CHECKING: + from outcome import Outcome + from typing_extensions import TypeAlias + + from ._run import Task + +T = TypeVar("T") + # Helper for the bottommost 'yield'. You can't use 'yield' inside an async # function, but you can inside a generator, and if you decorate your generator @@ -18,7 +29,7 @@ # tracking machinery. Since our traps are public APIs, we make them real async # functions, and then this helper takes care of the actual yield: @types.coroutine -def _async_yield(obj): +def _async_yield(obj: T) -> Iterator[T]: return (yield obj) @@ -28,7 +39,7 @@ class CancelShieldedCheckpoint: pass -async def cancel_shielded_checkpoint(): +async def cancel_shielded_checkpoint() -> Any: """Introduce a schedule point, but not a cancel point. This is *not* a :ref:`checkpoint `, but it is half of a @@ -62,10 +73,10 @@ class Abort(enum.Enum): # Not exported in the trio._core namespace, but imported directly by _run. @attr.s(frozen=True) class WaitTaskRescheduled: - abort_func = attr.ib() + abort_func: Callable[[RaiseCancelT], Abort] = attr.ib() -RaiseCancelT = Callable[[], NoReturn] # TypeAlias +RaiseCancelT: TypeAlias = Callable[[], NoReturn] # Should always return the type a Task "expects", unless you willfully reschedule it @@ -175,10 +186,10 @@ def abort(inner_raise_cancel): # Not exported in the trio._core namespace, but imported directly by _run. @attr.s(frozen=True) class PermanentlyDetachCoroutineObject: - final_outcome = attr.ib() + final_outcome: Outcome = attr.ib() -async def permanently_detach_coroutine_object(final_outcome): +async def permanently_detach_coroutine_object(final_outcome: Outcome) -> Any: """Permanently detach the current task from the Trio scheduler. Normally, a Trio task doesn't exit until its coroutine object exits. When @@ -209,7 +220,9 @@ async def permanently_detach_coroutine_object(final_outcome): return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) -async def temporarily_detach_coroutine_object(abort_func): +async def temporarily_detach_coroutine_object( + abort_func: Callable[[RaiseCancelT], Abort] +) -> Any: """Temporarily detach the current coroutine object from the Trio scheduler. @@ -245,7 +258,7 @@ async def temporarily_detach_coroutine_object(abort_func): return await _async_yield(WaitTaskRescheduled(abort_func)) -async def reattach_detached_coroutine_object(task, yield_value): +async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None: """Reattach a coroutine object that was detached using :func:`temporarily_detach_coroutine_object`. diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 1824df59e1..a61c417781 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9328, + "completenessScore": 0.9392, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 583, - "withUnknownType": 42 + "withKnownType": 587, + "withUnknownType": 38 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 0, @@ -101,13 +101,9 @@ "trio._subprocess.Process.wait", "trio.from_thread.run", "trio.from_thread.run_sync", - "trio.lowlevel.cancel_shielded_checkpoint", "trio.lowlevel.currently_ki_protected", "trio.lowlevel.notify_closing", - "trio.lowlevel.permanently_detach_coroutine_object", - "trio.lowlevel.reattach_detached_coroutine_object", "trio.lowlevel.start_guest_run", - "trio.lowlevel.temporarily_detach_coroutine_object", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", "trio.open_file", From 1974a5cdc67028a5a9ab6e0317d20a4b9a4ade39 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jul 2023 13:17:50 +0200 Subject: [PATCH 26/49] fix after nit from a5 --- trio/_dtls.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index e8888d7871..9795357f94 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -645,12 +645,12 @@ def challenge_for( return packet -T = TypeVar("T") +_T = TypeVar("_T") -class _Queue(Generic[T]): +class _Queue(Generic[_T]): def __init__(self, incoming_packets_buffer: int | float): - self.s, self.r = trio.open_memory_channel[T](incoming_packets_buffer) + self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer) def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: From d0fae93eaf09f41a45987cfcd3846b14dde8220e Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jul 2023 13:32:36 +0200 Subject: [PATCH 27/49] fix test errors --- trio/_core/__init__.py | 1 - trio/_core/_run.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index 26f6f04e7c..aa898fffe0 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -27,7 +27,6 @@ TASK_STATUS_IGNORED, CancelScope, Nursery, - RunStatistics, Task, TaskStatus, add_instrument, diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 79ed7ba8e3..463e6a7a1d 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1362,7 +1362,7 @@ class RunContext(threading.local): @attr.s(frozen=True) -class RunStatistics: +class _RunStatistics: tasks_living = attr.ib() tasks_runnable = attr.ib() seconds_to_next_deadline = attr.ib() @@ -1483,7 +1483,7 @@ def close(self): self.ki_manager.close() @_public - def current_statistics(self) -> RunStatistics: + def current_statistics(self) -> _RunStatistics: """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -1507,7 +1507,7 @@ def current_statistics(self) -> RunStatistics: """ seconds_to_next_deadline = self.deadlines.next_deadline() - self.current_time() - return RunStatistics( + return _RunStatistics( tasks_living=len(self.tasks), tasks_runnable=len(self.runq), seconds_to_next_deadline=seconds_to_next_deadline, From f353058e7364d4c6554f54700089b9dd0205fbb3 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jul 2023 14:29:49 +0200 Subject: [PATCH 28/49] _asyncgens --- pyproject.toml | 4 ++-- trio/_core/_asyncgens.py | 36 ++++++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9e156299cc..543b8587d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ module = [ "trio/_core/_windows_cffi", # 2, 324 "trio/_core/_generated_io_windows", # 9 (win32), 84 "trio/_core/_io_windows", # 47 (win32), 867 +"trio/_wait_for_object", # 2 (windows) #"trio/_core/_io_common", # 1, 24 #"trio/_core/_generated_io_epoll", # 3, 36 @@ -94,7 +95,7 @@ module = [ #"trio/_core/_thread_cache", # 6, 273 #"trio/_core/_traps", # 7, 276 -"trio/_core/_asyncgens", # 10, 194 +#"trio/_core/_asyncgens", # 10, 194 "trio/_core/_wakeup_socketpair", # 12 "trio/_core/_ki", # 14, 210 @@ -118,7 +119,6 @@ module = [ "trio/_highlevel_open_tcp_stream", # 5, 379 lines "trio/_subprocess_platform/waitid", # 2, 107 lines -"trio/_wait_for_object", # 2 (windows) "trio/_deprecate", # 12, 140lines "trio/_util", # 13, 348 lines "trio/_file_io", # 13, 191 lines diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py index 5f02ebe76d..eacdbb4923 100644 --- a/trio/_core/_asyncgens.py +++ b/trio/_core/_asyncgens.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import sys import warnings @@ -12,6 +14,16 @@ # Used to log exceptions in async generator finalizers ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors") +from typing import TYPE_CHECKING, AsyncGenerator + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from ._run import Runner + +# can this be typed more strictly in any way? +AGenT: TypeAlias = AsyncGenerator[object, object] + @attr.s(eq=False, slots=True) class AsyncGenerators: @@ -22,17 +34,17 @@ class AsyncGenerators: # asyncgens after the system nursery has been closed, it's a # regular set so we don't have to deal with GC firing at # unexpected times. - alive = attr.ib(factory=weakref.WeakSet) + alive: weakref.WeakSet[AGenT] = attr.ib(factory=weakref.WeakSet) # This collects async generators that get garbage collected during # the one-tick window between the system nursery closing and the # init task starting end-of-run asyncgen finalization. - trailing_needs_finalize = attr.ib(factory=set) + trailing_needs_finalize: set[AGenT] = attr.ib(factory=set) prev_hooks = attr.ib(init=False) - def install_hooks(self, runner): - def firstiter(agen): + def install_hooks(self, runner: Runner) -> None: + def firstiter(agen: AGenT) -> None: if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"): self.alive.add(agen) else: @@ -46,7 +58,7 @@ def firstiter(agen): if self.prev_hooks.firstiter is not None: self.prev_hooks.firstiter(agen) - def finalize_in_trio_context(agen, agen_name): + def finalize_in_trio_context(agen: AGenT, agen_name: str) -> None: try: runner.spawn_system_task( self._finalize_one, @@ -61,7 +73,7 @@ def finalize_in_trio_context(agen, agen_name): # have hit it. self.trailing_needs_finalize.add(agen) - def finalizer(agen): + def finalizer(agen: AGenT) -> None: agen_name = name_asyncgen(agen) try: is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen") @@ -99,7 +111,7 @@ def finalizer(agen): try: # If the next thing is a yield, this will raise RuntimeError # which we allow to propagate - closer.send(None) + closer.send(None) # type: ignore[attr-defined] except StopIteration: pass else: @@ -114,7 +126,7 @@ def finalizer(agen): self.prev_hooks = sys.get_asyncgen_hooks() sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer) - async def finalize_remaining(self, runner): + async def finalize_remaining(self, runner: Runner) -> None: # This is called from init after shutting down the system nursery. # The only tasks running at this point are init and # the run_sync_soon task, and since the system nursery is closed, @@ -125,7 +137,7 @@ async def finalize_remaining(self, runner): # To make async generator finalization easier to reason # about, we'll shut down asyncgen garbage collection by turning # the alive WeakSet into a regular set. - self.alive = set(self.alive) + self.alive = set(self.alive) # type: ignore # Process all pending run_sync_soon callbacks, in case one of # them was an asyncgen finalizer that snuck in under the wire. @@ -170,14 +182,14 @@ async def finalize_remaining(self, runner): # all are gone. while self.alive: batch = self.alive - self.alive = set() + self.alive = set() # type: ignore for agen in batch: await self._finalize_one(agen, name_asyncgen(agen)) - def close(self): + def close(self) -> None: sys.set_asyncgen_hooks(*self.prev_hooks) - async def _finalize_one(self, agen, name): + async def _finalize_one(self, agen: AGenT, name: str) -> None: try: # This shield ensures that finalize_asyncgen never exits # with an exception, not even a Cancelled. The inside From 766efdcf9302505232dc1454372edf7409ca3ecd Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jul 2023 14:32:56 +0200 Subject: [PATCH 29/49] _wakeup_socketpair --- pyproject.toml | 2 +- trio/_core/_wakeup_socketpair.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 543b8587d9..b866125bc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,7 @@ module = [ #"trio/_core/_traps", # 7, 276 #"trio/_core/_asyncgens", # 10, 194 -"trio/_core/_wakeup_socketpair", # 12 +#"trio/_core/_wakeup_socketpair", # 12 "trio/_core/_ki", # 14, 210 "trio/_core/_entry_queue", # 16, 195 "trio/_core/_multierror", # 19, 469 diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index 51a80ef024..2ad1a023fe 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import signal import socket import warnings @@ -7,7 +9,7 @@ class WakeupSocketpair: - def __init__(self): + def __init__(self) -> None: self.wakeup_sock, self.write_sock = socket.socketpair() self.wakeup_sock.setblocking(False) self.write_sock.setblocking(False) @@ -27,26 +29,26 @@ def __init__(self): self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) except OSError: pass - self.old_wakeup_fd = None + self.old_wakeup_fd: int | None = None - def wakeup_thread_and_signal_safe(self): + def wakeup_thread_and_signal_safe(self) -> None: try: self.write_sock.send(b"\x00") except BlockingIOError: pass - async def wait_woken(self): + async def wait_woken(self) -> None: await _core.wait_readable(self.wakeup_sock) self.drain() - def drain(self): + def drain(self) -> None: try: while True: self.wakeup_sock.recv(2**16) except BlockingIOError: pass - def wakeup_on_signals(self): + def wakeup_on_signals(self) -> None: assert self.old_wakeup_fd is None if not is_main_thread(): return @@ -64,7 +66,7 @@ def wakeup_on_signals(self): ) ) - def close(self): + def close(self) -> None: self.wakeup_sock.close() self.write_sock.close() if self.old_wakeup_fd is not None: From 6a13de7a657839da7255735a3bfd3b0278328f1a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jul 2023 15:51:20 +0200 Subject: [PATCH 30/49] trio/_core/_ki --- pyproject.toml | 2 +- trio/_core/_ki.py | 33 +++++++++++++++++++-------------- trio/_tests/verify_types.json | 7 +++---- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b866125bc8..6f49c6c1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ module = [ #"trio/_core/_asyncgens", # 10, 194 #"trio/_core/_wakeup_socketpair", # 12 -"trio/_core/_ki", # 14, 210 +#"trio/_core/_ki", # 14, 210 "trio/_core/_entry_queue", # 16, 195 "trio/_core/_multierror", # 19, 469 diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index cc05ef9177..aca321bce9 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -11,6 +11,7 @@ from .._util import is_main_thread if TYPE_CHECKING: + from types import FrameType from typing import Any, Callable, TypeVar F = TypeVar("F", bound=Callable[..., Any]) @@ -85,17 +86,17 @@ # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: -def ki_protection_enabled(frame): +def ki_protection_enabled(frame: FrameType | None) -> bool: while frame is not None: if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals: - return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] + return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] # type: ignore[no-any-return] if frame.f_code.co_name == "__del__": return True frame = frame.f_back return True -def currently_ki_protected(): +def currently_ki_protected() -> bool: r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection enabled. @@ -115,19 +116,19 @@ def currently_ki_protected(): # functions decorated @async_generator are given this magic property that's a # reference to the object itself # see python-trio/async_generator/async_generator/_impl.py -def legacy_isasyncgenfunction(obj): +def legacy_isasyncgenfunction(obj: object) -> bool: return getattr(obj, "_async_gen_function", None) == id(obj) -def _ki_protection_decorator(enabled): - def decorator(fn): +def _ki_protection_decorator(enabled: bool) -> Callable[[F], F]: + def decorator(fn): # type: ignore[no-untyped-def] # In some version of Python, isgeneratorfunction returns true for # coroutine functions, so we have to check for coroutine functions # first. if inspect.iscoroutinefunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] # See the comment for regular generators below coro = fn(*args, **kwargs) coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled @@ -137,7 +138,7 @@ def wrapper(*args, **kwargs): elif inspect.isgeneratorfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] # It's important that we inject this directly into the # generator's locals, as opposed to setting it here and then # doing 'yield from'. The reason is, if a generator is @@ -154,7 +155,7 @@ def wrapper(*args, **kwargs): elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] # See the comment for regular generators above agen = fn(*args, **kwargs) agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled @@ -164,7 +165,7 @@ def wrapper(*args, **kwargs): else: @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return fn(*args, **kwargs) @@ -182,9 +183,13 @@ def wrapper(*args, **kwargs): @attr.s class KIManager: - handler = attr.ib(default=None) + handler: Callable[[int, FrameType | None], None] | None = attr.ib(default=None) - def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): + def install( + self, + deliver_cb: Callable[[], None], + restrict_keyboard_interrupt_to_checkpoints: bool, + ) -> None: assert self.handler is None if ( not is_main_thread() @@ -192,7 +197,7 @@ def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): ): return - def handler(signum, frame): + def handler(signum: int, frame: FrameType | None) -> None: assert signum == signal.SIGINT protection_enabled = ki_protection_enabled(frame) if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: @@ -203,7 +208,7 @@ def handler(signum, frame): self.handler = handler signal.signal(signal.SIGINT, handler) - def close(self): + def close(self) -> None: if self.handler is not None: if signal.getsignal(signal.SIGINT) is self.handler: signal.signal(signal.SIGINT, signal.default_int_handler) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index a61c417781..b8eb3d5dcd 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9392, + "completenessScore": 0.9408, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 587, - "withUnknownType": 38 + "withKnownType": 588, + "withUnknownType": 37 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 0, @@ -101,7 +101,6 @@ "trio._subprocess.Process.wait", "trio.from_thread.run", "trio.from_thread.run_sync", - "trio.lowlevel.currently_ki_protected", "trio.lowlevel.notify_closing", "trio.lowlevel.start_guest_run", "trio.lowlevel.wait_readable", From 420e98b0aab3e34ed0516b102c4ae45487f87f48 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jul 2023 16:01:51 +0200 Subject: [PATCH 31/49] trio/_core/_entry_queue --- pyproject.toml | 2 +- trio/_core/_entry_queue.py | 41 ++++++++++++++++++++++------------- trio/_tests/verify_types.json | 11 +++++----- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f49c6c1ae..867d4df1a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ module = [ #"trio/_core/_wakeup_socketpair", # 12 #"trio/_core/_ki", # 14, 210 -"trio/_core/_entry_queue", # 16, 195 +#"trio/_core/_entry_queue", # 16, 195 "trio/_core/_multierror", # 19, 469 diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py index 878506bb2b..553143ceb1 100644 --- a/trio/_core/_entry_queue.py +++ b/trio/_core/_entry_queue.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import threading from collections import deque +from typing import Any, Callable, Iterable, Literal, NoReturn import attr @@ -17,11 +20,13 @@ class EntryQueue: # atomic WRT signal delivery (signal handlers can run on either side, but # not *during* a deque operation). dict makes similar guarantees - and # it's even ordered! - queue = attr.ib(factory=deque) - idempotent_queue = attr.ib(factory=dict) + queue: deque[tuple[Callable[..., Any], Iterable[Any]]] = attr.ib(factory=deque) + idempotent_queue: dict[tuple[Callable[..., Any], Iterable[Any]], None] = attr.ib( + factory=dict + ) - wakeup = attr.ib(factory=WakeupSocketpair) - done = attr.ib(default=False) + wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + done: bool = attr.ib(default=False) # Must be a reentrant lock, because it's acquired from signal handlers. # RLock is signal-safe as of cpython 3.2. NB that this does mean that the # lock is effectively *disabled* when we enter from signal context. The @@ -30,9 +35,9 @@ class EntryQueue: # main thread -- it just might happen at some inconvenient place. But if # you look at the one place where the main thread holds the lock, it's # just to make 1 assignment, so that's atomic WRT a signal anyway. - lock = attr.ib(factory=threading.RLock) + lock: threading.RLock = attr.ib(factory=threading.RLock) - async def task(self): + async def task(self) -> None: assert _core.currently_ki_protected() # RLock has two implementations: a signal-safe version in _thread, and # and signal-UNsafe version in threading. We need the signal safe @@ -43,7 +48,7 @@ async def task(self): # https://bugs.python.org/issue13697#msg237140 assert self.lock.__class__.__module__ == "_thread" - def run_cb(job): + def run_cb(job: tuple[Callable[..., object], Iterable[Any]]) -> Literal[True]: # We run this with KI protection enabled; it's the callback's # job to disable it if it wants it disabled. Exceptions are # treated like system task exceptions (i.e., converted into @@ -53,7 +58,7 @@ def run_cb(job): sync_fn(*args) except BaseException as exc: - async def kill_everything(exc): + async def kill_everything(exc: BaseException) -> NoReturn: raise exc try: @@ -63,14 +68,16 @@ async def kill_everything(exc): # system nursery is already closed. # TODO(2020-06): this is a gross hack and should # be fixed soon when we address #1607. - _core.current_task().parent_nursery.start_soon(kill_everything, exc) + parent_nursery = _core.current_task().parent_nursery + assert parent_nursery is not None + parent_nursery.start_soon(kill_everything, exc) return True # This has to be carefully written to be safe in the face of new items # being queued while we iterate, and to do a bounded amount of work on # each pass: - def run_all_bounded(): + def run_all_bounded() -> None: for _ in range(len(self.queue)): run_cb(self.queue.popleft()) for job in list(self.idempotent_queue): @@ -104,13 +111,15 @@ def run_all_bounded(): assert not self.queue assert not self.idempotent_queue - def close(self): + def close(self) -> None: self.wakeup.close() - def size(self): + def size(self) -> int: return len(self.queue) + len(self.idempotent_queue) - def run_sync_soon(self, sync_fn, *args, idempotent=False): + def run_sync_soon( + self, sync_fn: Callable[..., object], *args: object, idempotent: bool = False + ) -> None: with self.lock: if self.done: raise _core.RunFinishedError("run() has exited") @@ -146,9 +155,11 @@ class TrioToken(metaclass=NoPublicConstructor): """ - _reentry_queue = attr.ib() + _reentry_queue: EntryQueue = attr.ib() - def run_sync_soon(self, sync_fn, *args, idempotent=False): + def run_sync_soon( + self, sync_fn: Callable[..., object], *args: object, idempotent: bool = False + ) -> None: """Schedule a call to ``sync_fn(*args)`` to occur in the context of a Trio task. diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index b8eb3d5dcd..e6e074802c 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9408, + "completenessScore": 0.9424, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 588, - "withUnknownType": 37 + "withKnownType": 589, + "withUnknownType": 36 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 0, @@ -46,12 +46,11 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 552, - "withUnknownType": 67 + "withKnownType": 554, + "withUnknownType": 65 }, "packageName": "trio", "symbols": [ - "trio._core._entry_queue.TrioToken.run_sync_soon", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", "trio._highlevel_socket.SocketStream.getsockopt", From 2370c3fb0a45ca58d2fcd574e7ce5aa9369aae02 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 27 Jul 2023 15:08:59 +0200 Subject: [PATCH 32/49] type trio/_core/_run --- pyproject.toml | 10 +- trio/_core/_local.py | 9 +- trio/_core/_run.py | 245 +++++++++++++++++++++------------- trio/_core/_thread_cache.py | 8 +- trio/_tests/verify_types.json | 14 +- 5 files changed, 172 insertions(+), 114 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 867d4df1a2..975d1f5fda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,11 +53,11 @@ module = [ disallow_untyped_defs = true disallow_any_generics = true -[[tool.mypy.overrides]] -module = "trio._core._run" -disallow_incomplete_defs = false -disallow_untyped_defs = false -disallow_any_generics = false +#[[tool.mypy.overrides]] +#module = "trio._core._run" +#disallow_incomplete_defs = false +#disallow_untyped_defs = false +#disallow_any_generics = false # TODO: gen_exports add platform checks to specific files [[tool.mypy.overrides]] diff --git a/trio/_core/_local.py b/trio/_core/_local.py index b9dada64fe..1965d44eb1 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -45,8 +45,7 @@ class RunVar(Generic[T], metaclass=Final): def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: - # not typed yet - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index] + return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[no-any-return] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: @@ -73,7 +72,7 @@ def set(self, value: T) -> RunVarToken[T]: # This can't fail, because if we weren't in Trio context then the # get() above would have failed. - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index] + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value return token def reset(self, token: RunVarToken[T]) -> None: @@ -93,9 +92,9 @@ def reset(self, token: RunVarToken[T]) -> None: previous = token.previous_value try: if previous is _NoValue: - _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type] + _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) else: - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment] + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous except AttributeError: raise RuntimeError("Cannot be used outside of a run context") diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 463e6a7a1d..8c7b89ad2a 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -17,7 +17,16 @@ from math import inf from time import perf_counter from types import TracebackType -from typing import TYPE_CHECKING, Any, Awaitable, Iterable, NoReturn, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Generator, + Iterable, + NoReturn, + Sequence, + TypeVar, +) import attr from outcome import Error, Outcome, Value, capture @@ -54,15 +63,27 @@ # An unfortunate name collision here with trio._util.Final from typing import Final as FinalT + from typing_extensions import Self, TypeAlias + from .._abc import Clock + from ._local import RunVar from ._mock_clock import MockClock + if sys.platform == "win32": + from ._io_windows import _WindowsStatistics + elif sys.platform == "darwin": + from ._io_kqueue import _KqueueStatistics + elif sys.platform == "linux": + from ._io_epoll import _EpollStatistics + DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 _NO_SEND: FinalT = object() FnT = TypeVar("FnT", bound="Callable[..., Any]") +T = TypeVar("T") + # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. @@ -154,7 +175,9 @@ class IdlePrimedTypes(enum.Enum): ################################################################ -def collapse_exception_group(excgroup): +def collapse_exception_group( + excgroup: BaseExceptionGroup[BaseException], +) -> BaseExceptionGroup[BaseException] | BaseException: """Recursively collapse any single-exception groups into that single contained exception. @@ -174,7 +197,7 @@ def collapse_exception_group(excgroup): ) return exceptions[0] elif modified: - return excgroup.derive(exceptions) + return excgroup.derive(exceptions) # type: ignore[no-any-return] else: return excgroup @@ -189,18 +212,18 @@ class Deadlines: """ # Heap of (deadline, id(CancelScope), CancelScope) - _heap = attr.ib(factory=list) + _heap: list[tuple[float, int, CancelScope]] = attr.ib(factory=list) # Count of active deadlines (those that haven't been changed) - _active = attr.ib(default=0) + _active: int = attr.ib(default=0) - def add(self, deadline, cancel_scope): + def add(self, deadline: float, cancel_scope: CancelScope) -> None: heappush(self._heap, (deadline, id(cancel_scope), cancel_scope)) self._active += 1 - def remove(self, deadline, cancel_scope): + def remove(self, deadline: float, cancel_scope: CancelScope) -> None: self._active -= 1 - def next_deadline(self): + def next_deadline(self) -> float: while self._heap: deadline, _, cancel_scope = self._heap[0] if deadline == cancel_scope._registered_deadline: @@ -210,7 +233,7 @@ def next_deadline(self): heappop(self._heap) return inf - def _prune(self): + def _prune(self) -> None: # In principle, it's possible for a cancel scope to toggle back and # forth repeatedly between the same two deadlines, and end up with # lots of stale entries that *look* like they're still active, because @@ -231,7 +254,7 @@ def _prune(self): heapify(pruned_heap) self._heap = pruned_heap - def expire(self, now): + def expire(self, now: float) -> bool: did_something = False while self._heap and self._heap[0][0] <= now: deadline, _, cancel_scope = heappop(self._heap) @@ -382,14 +405,14 @@ def close(self) -> None: child.recalculate() @property - def parent_cancellation_is_visible_to_us(self): + def parent_cancellation_is_visible_to_us(self) -> bool: return ( self._parent is not None and not self._scope.shield and self._parent.effectively_cancelled ) - def recalculate(self): + def recalculate(self) -> None: # This does a depth-first traversal over this and descendent cancel # statuses, to ensure their state is up-to-date. It's basically a # recursive algorithm, but we use an explicit stack to avoid any @@ -408,7 +431,7 @@ def recalculate(self): task._attempt_delivery_of_any_pending_cancel() todo.extend(current._children) - def _mark_abandoned(self): + def _mark_abandoned(self) -> None: self.abandoned_by_misnesting = True for child in self._children: child._mark_abandoned() @@ -496,7 +519,7 @@ class CancelScope(metaclass=Final): _shield: bool = attr.ib(default=False, kw_only=True) @enable_ki_protection - def __enter__(self): + def __enter__(self) -> Self: task = _core.current_task() if self._has_been_entered: raise RuntimeError( @@ -510,7 +533,7 @@ def __enter__(self): task._activate_cancel_status(self._cancel_status) return self - def _close(self, exc): + def _close(self, exc: BaseException | None) -> BaseException | None: if self._cancel_status is None: new_exc = RuntimeError( "Cancel scope stack corrupted: attempted to exit {!r} " @@ -548,6 +571,7 @@ def _close(self, exc): # CancelStatus.close() will take care of the plumbing; # we just need to make sure we don't let the error # pass silently. + assert scope_task._cancel_status is not None new_exc = RuntimeError( "Cancel scope stack corrupted: attempted to exit {!r} " "in {!r} that's still within its child {!r}\n{}".format( @@ -789,10 +813,10 @@ def cancel_called(self) -> bool: # sense. @attr.s(eq=False, hash=False, repr=False) class TaskStatus(metaclass=Final): - _old_nursery = attr.ib() - _new_nursery = attr.ib() - _called_started = attr.ib(default=False) - _value = attr.ib(default=None) + _old_nursery: Nursery = attr.ib() + _new_nursery: Nursery = attr.ib() + _called_started: bool = attr.ib(default=False) + _value: Any = attr.ib(default=None) def __repr__(self) -> str: return f"" @@ -807,6 +831,7 @@ def started(self, value: Any = None) -> None: # will eventually exit on its own, and we don't want to risk moving # children that might have propagating Cancelled exceptions into # a place with no cancelled cancel scopes to catch them. + assert self._old_nursery._cancel_status is not None if self._old_nursery._cancel_status.effectively_cancelled: return @@ -830,6 +855,7 @@ def started(self, value: Any = None) -> None: # do something evil like cancel the old nursery. We thus break # everything off from the old nursery before we start attaching # anything to the new. + assert self._old_nursery._cancel_status is not None cancel_status_children = self._old_nursery._cancel_status.children cancel_status_tasks = set(self._old_nursery._cancel_status.tasks) cancel_status_tasks.discard(self._old_nursery._parent_task) @@ -1002,20 +1028,22 @@ def _add_exc(self, exc: BaseException) -> None: self._pending_excs.append(exc) self.cancel_scope.cancel() - def _check_nursery_closed(self): + def _check_nursery_closed(self) -> None: if not any([self._nested_child_running, self._children, self._pending_starts]): self._closed = True if self._parent_waiting_in_aexit: self._parent_waiting_in_aexit = False GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task) - def _child_finished(self, task, outcome): + def _child_finished(self, task: Task, outcome: Outcome) -> None: self._children.remove(task) if isinstance(outcome, Error): self._add_exc(outcome.error) self._check_nursery_closed() - async def _nested_child_finished(self, nested_child_exc): + async def _nested_child_finished( + self, nested_child_exc: BaseException | None + ) -> MultiError | None: # Returns MultiError instance (or any exception if the nursery is in loose mode # and there is just one contained exception) if there are pending exceptions if nested_child_exc is not None: @@ -1053,8 +1081,14 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: # avoid a garbage cycle # (see test_nursery_cancel_doesnt_create_cyclic_garbage) del self._pending_excs + return None - def start_soon(self, async_fn, *args, name=None): + def start_soon( + self, + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: str | None = None, + ) -> None: """Creates a child task, scheduling ``await async_fn(*args)``. If you want to run a function and immediately wait for its result, @@ -1096,7 +1130,12 @@ def start_soon(self, async_fn, *args, name=None): """ GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) - async def start(self, async_fn, *args, name=None): + async def start( + self, + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: str | None = None, + ) -> Value: r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1295,9 +1334,9 @@ def print_stack_for_task(task): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). - _cancel_status: CancelStatus = attr.ib(default=None, repr=False) + _cancel_status: CancelStatus | None = attr.ib(default=None, repr=False) - def _activate_cancel_status(self, cancel_status: CancelStatus) -> None: + def _activate_cancel_status(self, cancel_status: CancelStatus | None) -> None: if self._cancel_status is not None: self._cancel_status._tasks.remove(self) self._cancel_status = cancel_status @@ -1328,10 +1367,11 @@ def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None: def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return + assert self._cancel_status is not None if not self._cancel_status.effectively_cancelled: return - def raise_cancel(): + def raise_cancel() -> NoReturn: raise Cancelled._create() self._attempt_abort(raise_cancel) @@ -1360,14 +1400,24 @@ class RunContext(threading.local): GLOBAL_RUN_CONTEXT: FinalT = RunContext() +if TYPE_CHECKING: + if sys.platform == "win32": + IO_STATISTICS_TYPE: TypeAlias = _WindowsStatistics + elif sys.platform == "darwin": + IO_STATISTICS_TYPE: TypeAlias = _KqueueStatistics + elif sys.platform == "linux": + IO_STATISTICS_TYPE: TypeAlias = _EpollStatistics +else: + IO_STATISTICS_TYPE = None + @attr.s(frozen=True) class _RunStatistics: - tasks_living = attr.ib() - tasks_runnable = attr.ib() - seconds_to_next_deadline = attr.ib() - io_statistics = attr.ib() - run_sync_soon_queue_size = attr.ib() + tasks_living: int = attr.ib() + tasks_runnable: int = attr.ib() + seconds_to_next_deadline: float = attr.ib() + io_statistics: IO_STATISTICS_TYPE = attr.ib() + run_sync_soon_queue_size: int = attr.ib() # This holds all the state that gets trampolined back and forth between @@ -1392,14 +1442,14 @@ class _RunStatistics: @attr.s(eq=False, hash=False, slots=True) class GuestState: runner: Runner = attr.ib() - run_sync_soon_threadsafe: Callable = attr.ib() - run_sync_soon_not_threadsafe: Callable = attr.ib() - done_callback: Callable = attr.ib() - unrolled_run_gen = attr.ib() + run_sync_soon_threadsafe: Callable[[Callable[[], None]], None] = attr.ib() + run_sync_soon_not_threadsafe: Callable[[Callable[[], None]], None] = attr.ib() + done_callback: Callable[[Outcome], None] = attr.ib() + unrolled_run_gen: Generator[float, list[tuple[int, int]], None] = attr.ib() _value_factory: Callable[[], Value] = lambda: Value(None) unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) - def guest_tick(self): + def guest_tick(self) -> None: try: timeout = self.unrolled_run_next_send.send(self.unrolled_run_gen) except StopIteration: @@ -1420,11 +1470,11 @@ def guest_tick(self): # Need to go into the thread and call get_events() there self.runner.guest_tick_scheduled = False - def get_events(): + def get_events() -> list[tuple[int, int]]: return self.runner.io_manager.get_events(timeout) - def deliver(events_outcome): - def in_main_thread(): + def deliver(events_outcome: Outcome) -> None: + def in_main_thread() -> None: self.unrolled_run_next_send = events_outcome self.runner.guest_tick_scheduled = True self.guest_tick() @@ -1439,41 +1489,41 @@ class Runner: clock: SystemClock | Clock | MockClock = attr.ib() instruments: Instruments = attr.ib() io_manager: TheIOManager = attr.ib() - ki_manager = attr.ib() - strict_exception_groups = attr.ib() + ki_manager: KIManager = attr.ib() + strict_exception_groups: bool = attr.ib() # Run-local values, see _local.py - _locals = attr.ib(factory=dict) + _locals: dict[RunVar[Any], Any] = attr.ib(factory=dict) runq: deque[Task] = attr.ib(factory=deque) tasks: set[Task] = attr.ib(factory=set) - deadlines = attr.ib(factory=Deadlines) + deadlines: Deadlines = attr.ib(factory=Deadlines) init_task: Task | None = attr.ib(default=None) - system_nursery = attr.ib(default=None) - system_context = attr.ib(default=None) - main_task = attr.ib(default=None) - main_task_outcome = attr.ib(default=None) + system_nursery: Nursery | None = attr.ib(default=None) + system_context: Context | None = attr.ib(default=None) + main_task: Task | None = attr.ib(default=None) + main_task_outcome: Outcome | None = attr.ib(default=None) - entry_queue = attr.ib(factory=EntryQueue) + entry_queue: EntryQueue = attr.ib(factory=EntryQueue) trio_token: TrioToken | None = attr.ib(default=None) - asyncgens = attr.ib(factory=AsyncGenerators) + asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators) # If everything goes idle for this long, we call clock._autojump() - clock_autojump_threshold = attr.ib(default=inf) + clock_autojump_threshold: float = attr.ib(default=inf) # Guest mode stuff - is_guest = attr.ib(default=False) - guest_tick_scheduled = attr.ib(default=False) + is_guest: bool = attr.ib(default=False) + guest_tick_scheduled: bool = attr.ib(default=False) - def force_guest_tick_asap(self): + def force_guest_tick_asap(self) -> None: if self.guest_tick_scheduled: return self.guest_tick_scheduled = True self.io_manager.force_wakeup() - def close(self): + def close(self) -> None: self.io_manager.close() self.entry_queue.close() self.asyncgens.close() @@ -1587,7 +1637,7 @@ def spawn_impl( async_fn: Callable[..., Awaitable[object]], args: Iterable[Any], nursery: Nursery | None, - name: str | functools.partial | Callable[..., Awaitable[object]] | None, + name: str | functools.partial[Any] | Callable[..., Awaitable[object]] | None, *, system_task: bool = False, context: Context | None = None, @@ -1609,6 +1659,7 @@ def spawn_impl( ###### if context is None: if system_task: + assert self.system_context is not None context = self.system_context.copy() else: context = copy_context() @@ -1635,7 +1686,7 @@ def spawn_impl( if not hasattr(coro, "cr_frame"): # This async function is implemented in C or Cython - async def python_wrapper(orig_coro): + async def python_wrapper(orig_coro: Awaitable[T]) -> T: return await orig_coro coro = python_wrapper(coro) @@ -1660,7 +1711,7 @@ async def python_wrapper(orig_coro): self.reschedule(task, None) return task - def task_exited(self, task, outcome): + def task_exited(self, task: Task, outcome: Outcome) -> None: if ( task._cancel_status is not None and task._cancel_status.abandoned_by_misnesting @@ -1699,6 +1750,7 @@ def task_exited(self, task, outcome): if task is self.main_task: self.main_task_outcome = outcome outcome = Value(None) + assert task._parent_nursery is not None task._parent_nursery._child_finished(task, outcome) if "task_exited" in self.instruments: @@ -1776,7 +1828,9 @@ def spawn_system_task( context=context, ) - async def init(self, async_fn, args): + async def init( + self, async_fn: Callable[..., Awaitable[object]], args: Iterable[object] + ) -> None: # run_sync_soon task runs here: async with open_nursery() as run_sync_soon_nursery: # All other system tasks run here: @@ -1827,7 +1881,7 @@ def current_trio_token(self) -> TrioToken: # KI handling ################ - ki_pending = attr.ib(default=False) + ki_pending: bool = attr.ib(default=False) # deliver_ki is broke. Maybe move all the actual logic and state into # RunToken, and we'll only have one instance per runner? But then we can't @@ -1836,14 +1890,14 @@ def current_trio_token(self) -> TrioToken: # keep the class public so people can isinstance() it if they want. # This gets called from signal context - def deliver_ki(self): + def deliver_ki(self) -> None: self.ki_pending = True try: self.entry_queue.run_sync_soon(self._deliver_ki_cb) except RunFinishedError: pass - def _deliver_ki_cb(self): + def _deliver_ki_cb(self) -> None: if not self.ki_pending: return # Can't happen because main_task and run_sync_soon_task are created at @@ -1860,7 +1914,7 @@ def _deliver_ki_cb(self): # Quiescing ################ - waiting_for_idle = attr.ib(factory=SortedDict) + waiting_for_idle: SortedDict = attr.ib(factory=SortedDict) @_public async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: @@ -2000,11 +2054,11 @@ def abort(_: RaiseCancelT) -> Abort: def setup_runner( - clock, - instruments, - restrict_keyboard_interrupt_to_checkpoints, - strict_exception_groups, -): + clock: Clock | None, + instruments: Sequence[Instrument], + restrict_keyboard_interrupt_to_checkpoints: bool, + strict_exception_groups: bool, +) -> Runner: """Create a Runner object and install it as the GLOBAL_RUN_CONTEXT.""" # It wouldn't be *hard* to support nested calls to run(), but I can't # think of a single good reason for it, so let's be conservative for @@ -2013,15 +2067,17 @@ def setup_runner( raise RuntimeError("Attempted to call run() from inside a run()") if clock is None: - clock = SystemClock() - instruments = Instruments(instruments) + _clock: Clock | SystemClock = SystemClock() + else: + _clock = clock + _instruments = Instruments(instruments) io_manager = TheIOManager() system_context = copy_context() ki_manager = KIManager() runner = Runner( - clock=clock, - instruments=instruments, + clock=_clock, + instruments=_instruments, io_manager=io_manager, system_context=system_context, ki_manager=ki_manager, @@ -2038,13 +2094,13 @@ def setup_runner( def run( - async_fn, - *args, - clock=None, - instruments=(), + async_fn: Callable[..., Awaitable[T]], + *args: object, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, strict_exception_groups: bool = False, -): +) -> T: """Run a Trio-flavored async function, and return the result. Calling:: @@ -2130,30 +2186,32 @@ def run( next_send = None while True: try: - timeout = gen.send(next_send) + # sending next_send==None here ... should not work?? + timeout = gen.send(next_send) # type: ignore[arg-type] except StopIteration: break next_send = runner.io_manager.get_events(timeout) # Inlined copy of runner.main_task_outcome.unwrap() to avoid # cluttering every single Trio traceback with an extra frame. if isinstance(runner.main_task_outcome, Value): - return runner.main_task_outcome.value + return runner.main_task_outcome.value # type: ignore[no-any-return] else: + assert runner.main_task_outcome is not None raise runner.main_task_outcome.error def start_guest_run( - async_fn, - *args, - run_sync_soon_threadsafe, - done_callback, - run_sync_soon_not_threadsafe=None, + async_fn: Callable[..., Awaitable[object]], + *args: object, + run_sync_soon_threadsafe: Callable[[Callable[..., None]], None], + done_callback: Callable[[Outcome], None], + run_sync_soon_not_threadsafe: Callable[[Callable[..., None]], None] | None = None, host_uses_signal_set_wakeup_fd: bool = False, - clock=None, - instruments=(), + clock: Clock | None = None, + instruments: tuple[Instrument, ...] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, strict_exception_groups: bool = False, -): +) -> None: """Start a "guest" run of Trio on top of some other "host" event loop. Each host loop can only have one guest run at a time. @@ -2241,10 +2299,10 @@ def my_done_callback(run_outcome): # straight through. def unrolled_run( runner: Runner, - async_fn, - args, + async_fn: Callable[..., object], + args: Iterable[object], host_uses_signal_set_wakeup_fd: bool = False, -): +) -> Generator[float, list[tuple[int, int]], None]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True @@ -2530,7 +2588,10 @@ def current_effective_deadline() -> float: float: the effective deadline, as an absolute time. """ - return current_task()._cancel_status.effective_deadline() + curr_cancel_status = current_task()._cancel_status + assert curr_cancel_status is not None + return curr_cancel_status.effective_deadline() + # return current_task()._cancel_status.effective_deadline() async def checkpoint() -> None: @@ -2553,6 +2614,7 @@ async def checkpoint() -> None: await cancel_shielded_checkpoint() task = current_task() task._cancel_points += 1 + assert task._cancel_status is not None if task._cancel_status.effectively_cancelled or ( task is task._runner.main_task and task._runner.ki_pending ): @@ -2576,6 +2638,7 @@ async def checkpoint_if_cancelled() -> None: """ task = current_task() + assert task._cancel_status is not None if task._cancel_status.effectively_cancelled or ( task is task._runner.main_task and task._runner.ki_pending ): diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py index 823d22a10a..d66bf7c05d 100644 --- a/trio/_core/_thread_cache.py +++ b/trio/_core/_thread_cache.py @@ -7,7 +7,7 @@ from functools import partial from itertools import count from threading import Lock, Thread -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Optional, Tuple import outcome @@ -121,7 +121,7 @@ class WorkerThread: def __init__(self, thread_cache: ThreadCache): # should generate stubs for outcome self._job: Optional[ - Tuple[Callable[[None], None], Callable[[Value], None], str | None] + Tuple[Callable[[], object], Callable[[Value], None], str | None] ] = None self._thread_cache = thread_cache # This Lock is used in an unconventional way. @@ -200,7 +200,7 @@ def __init__(self) -> None: def start_thread_soon( self, - fn: Callable[[None], Any] | partial[Any], + fn: Callable[[], object] | partial[object], deliver: Callable[[Value], None], name: Optional[str] = None, ) -> None: @@ -216,7 +216,7 @@ def start_thread_soon( def start_thread_soon( - fn: Callable[[None], None] | partial[Any], + fn: Callable[[], object] | partial[object], deliver: Callable[[Value], None], name: Optional[str] = None, ) -> None: diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index e6e074802c..8f75bb13b2 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9424, + "completenessScore": 0.9472, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 589, - "withUnknownType": 36 + "withKnownType": 592, + "withUnknownType": 33 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 0, @@ -46,13 +46,11 @@ ], "otherSymbolCounts": { "withAmbiguousType": 8, - "withKnownType": 554, - "withUnknownType": 65 + "withKnownType": 557, + "withUnknownType": 62 }, "packageName": "trio", "symbols": [ - "trio._core._run.Nursery.start", - "trio._core._run.Nursery.start_soon", "trio._highlevel_socket.SocketStream.getsockopt", "trio._highlevel_socket.SocketStream.send_all", "trio._highlevel_socket.SocketStream.setsockopt", @@ -101,7 +99,6 @@ "trio.from_thread.run", "trio.from_thread.run_sync", "trio.lowlevel.notify_closing", - "trio.lowlevel.start_guest_run", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", "trio.open_file", @@ -110,7 +107,6 @@ "trio.open_tcp_listeners", "trio.open_tcp_stream", "trio.open_unix_socket", - "trio.run", "trio.run_process", "trio.serve_listeners", "trio.serve_ssl_over_tcp", From ae32ea903249f20ff473edb1b810fddcd3a7f919 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 28 Jul 2023 14:25:02 +0200 Subject: [PATCH 33/49] typecheck trio/_core/_unbounded_queue --- pyproject.toml | 3 +- trio/_core/__init__.py | 2 +- trio/_core/_unbounded_queue.py | 62 ++++++++++++++++++++-------------- trio/_tests/verify_types.json | 18 +++------- 4 files changed, 44 insertions(+), 41 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d479442c7a..1a95434ec5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ disallow_untyped_defs = true [[tool.mypy.overrides]] module = [ "trio._dtls", - "trio._abc" + "trio._abc", + "trio._core._unbounded_queue", ] disallow_incomplete_defs = true disallow_untyped_defs = true diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index aa898fffe0..0325572376 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -62,7 +62,7 @@ temporarily_detach_coroutine_object, wait_task_rescheduled, ) -from ._unbounded_queue import UnboundedQueue +from ._unbounded_queue import UnboundedQueue, UnboundedQueueStats # Windows imports if sys.platform == "win32": diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index 9c747749b4..27d36c965d 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -1,17 +1,34 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + import attr +from typing_extensions import Self from .. import _core from .._deprecate import deprecated from .._util import Final +T = TypeVar("T") + @attr.s(frozen=True) -class _UnboundedQueueStats: - qsize = attr.ib() - tasks_waiting = attr.ib() +class UnboundedQueueStats: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``qsize``: The number of items currently in the queue. + * ``tasks_waiting``: The number of tasks blocked on this queue's + :meth:`get_batch` method. + + """ + + qsize: int = attr.ib() + tasks_waiting: int = attr.ib() -class UnboundedQueue(metaclass=Final): +class UnboundedQueue(Generic[T], metaclass=Final): """An unbounded queue suitable for certain unusual forms of inter-task communication. @@ -41,26 +58,27 @@ class UnboundedQueue(metaclass=Final): """ + # deprecated is not typed @deprecated( "0.9.0", issue=497, thing="trio.lowlevel.UnboundedQueue", instead="trio.open_memory_channel(math.inf)", ) - def __init__(self): + def __init__(self) -> None: # type: ignore[misc] self._lot = _core.ParkingLot() - self._data = [] + self._data: list[T] = [] # used to allow handoff from put to the first task in the lot self._can_get = False - def __repr__(self): + def __repr__(self) -> str: return f"" - def qsize(self): + def qsize(self) -> int: """Returns the number of items currently in the queue.""" return len(self._data) - def empty(self): + def empty(self) -> bool: """Returns True if the queue is empty, False otherwise. There is some subtlety to interpreting this method's return value: see @@ -70,7 +88,7 @@ def empty(self): return not self._data @_core.enable_ki_protection - def put_nowait(self, obj): + def put_nowait(self, obj: T) -> None: """Put an object into the queue, without blocking. This always succeeds, because the queue is unbounded. We don't provide @@ -88,13 +106,13 @@ def put_nowait(self, obj): self._can_get = True self._data.append(obj) - def _get_batch_protected(self): + def _get_batch_protected(self) -> list[T]: data = self._data.copy() self._data.clear() self._can_get = False return data - def get_batch_nowait(self): + def get_batch_nowait(self) -> list[T]: """Attempt to get the next batch from the queue, without blocking. Returns: @@ -110,7 +128,7 @@ def get_batch_nowait(self): raise _core.WouldBlock return self._get_batch_protected() - async def get_batch(self): + async def get_batch(self) -> list[T]: """Get the next batch from the queue, blocking as necessary. Returns: @@ -128,22 +146,14 @@ async def get_batch(self): finally: await _core.cancel_shielded_checkpoint() - def statistics(self): - """Return an object containing debugging information. - - Currently the following fields are defined: - - * ``qsize``: The number of items currently in the queue. - * ``tasks_waiting``: The number of tasks blocked on this queue's - :meth:`get_batch` method. - - """ - return _UnboundedQueueStats( + def statistics(self) -> UnboundedQueueStats: + """Return an UnboundedQueueStats object containing debugging information.""" + return UnboundedQueueStats( qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting ) - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> list[T]: return await self.get_batch() diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index d08c03060c..f1276037a2 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.888, + "completenessScore": 0.8896, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 555, - "withUnknownType": 69 + "withKnownType": 556, + "withUnknownType": 68 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 3, - "withKnownType": 529, - "withUnknownType": 102 + "withKnownType": 551, + "withUnknownType": 93 }, "packageName": "trio", "symbols": [ @@ -63,14 +63,6 @@ "trio._core._run.Nursery.start_soon", "trio._core._run.TaskStatus.__repr__", "trio._core._run.TaskStatus.started", - "trio._core._unbounded_queue.UnboundedQueue.__aiter__", - "trio._core._unbounded_queue.UnboundedQueue.__anext__", - "trio._core._unbounded_queue.UnboundedQueue.__repr__", - "trio._core._unbounded_queue.UnboundedQueue.empty", - "trio._core._unbounded_queue.UnboundedQueue.get_batch", - "trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait", - "trio._core._unbounded_queue.UnboundedQueue.qsize", - "trio._core._unbounded_queue.UnboundedQueue.statistics", "trio._dtls.DTLSChannel.__init__", "trio._dtls.DTLSEndpoint.__init__", "trio._dtls.DTLSEndpoint.serve", From ea1998894ab26fa4658515b2a2ba414bd7031d41 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 28 Jul 2023 14:40:06 +0200 Subject: [PATCH 34/49] fix CI --- trio/_core/__init__.py | 2 +- trio/_core/_unbounded_queue.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index 0325572376..8e42d2743b 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -62,7 +62,7 @@ temporarily_detach_coroutine_object, wait_task_rescheduled, ) -from ._unbounded_queue import UnboundedQueue, UnboundedQueueStats +from ._unbounded_queue import UnboundedQueue, UnboundedQueueStatistics # Windows imports if sys.platform == "win32": diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index 27d36c965d..7659845a0f 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar import attr -from typing_extensions import Self from .. import _core from .._deprecate import deprecated @@ -11,9 +10,12 @@ T = TypeVar("T") +if TYPE_CHECKING: + from typing_extensions import Self + @attr.s(frozen=True) -class UnboundedQueueStats: +class UnboundedQueueStatistics: """An object containing debugging information. Currently the following fields are defined: @@ -146,9 +148,9 @@ async def get_batch(self) -> list[T]: finally: await _core.cancel_shielded_checkpoint() - def statistics(self) -> UnboundedQueueStats: - """Return an UnboundedQueueStats object containing debugging information.""" - return UnboundedQueueStats( + def statistics(self) -> UnboundedQueueStatistics: + """Return an UnboundedQueueStatistics object containing debugging information.""" + return UnboundedQueueStatistics( qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting ) From b4621e550d81e37c1f1509273d7537cb537bda2b Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 28 Jul 2023 14:46:11 +0200 Subject: [PATCH 35/49] fix test --- trio/_tests/verify_types.json | 4 ++-- trio/lowlevel.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index f1276037a2..af3283c2ef 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,10 +7,10 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8896, + "completenessScore": 0.889776357827476, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 556, + "withKnownType": 557, "withUnknownType": 68 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/lowlevel.py b/trio/lowlevel.py index 54f4ef3141..36d23d5955 100644 --- a/trio/lowlevel.py +++ b/trio/lowlevel.py @@ -17,6 +17,7 @@ Task as Task, TrioToken as TrioToken, UnboundedQueue as UnboundedQueue, + UnboundedQueueStatistics as UnboundedQueueStatistics, add_instrument as add_instrument, cancel_shielded_checkpoint as cancel_shielded_checkpoint, checkpoint as checkpoint, From d2511967de12feae0868940b8eec3b932a8b571d Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 28 Jul 2023 15:44:37 +0200 Subject: [PATCH 36/49] fix Statistics not defining slots breaking tests --- trio/_core/_unbounded_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index 7659845a0f..94348bfc26 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -14,7 +14,7 @@ from typing_extensions import Self -@attr.s(frozen=True) +@attr.s(slots=True, frozen=True) class UnboundedQueueStatistics: """An object containing debugging information. From 34341c2c3592e4d109aaf0ecad6fabce5a54f6a2 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 28 Jul 2023 15:49:13 +0200 Subject: [PATCH 37/49] Any -> object in _entry_queue --- trio/_core/_entry_queue.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py index 553143ceb1..68e1a89180 100644 --- a/trio/_core/_entry_queue.py +++ b/trio/_core/_entry_queue.py @@ -2,7 +2,7 @@ import threading from collections import deque -from typing import Any, Callable, Iterable, Literal, NoReturn +from typing import Callable, Iterable, Literal, NoReturn import attr @@ -20,10 +20,12 @@ class EntryQueue: # atomic WRT signal delivery (signal handlers can run on either side, but # not *during* a deque operation). dict makes similar guarantees - and # it's even ordered! - queue: deque[tuple[Callable[..., Any], Iterable[Any]]] = attr.ib(factory=deque) - idempotent_queue: dict[tuple[Callable[..., Any], Iterable[Any]], None] = attr.ib( - factory=dict + queue: deque[tuple[Callable[..., object], Iterable[object]]] = attr.ib( + factory=deque ) + idempotent_queue: dict[ + tuple[Callable[..., object], Iterable[object]], None + ] = attr.ib(factory=dict) wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) done: bool = attr.ib(default=False) @@ -48,7 +50,9 @@ async def task(self) -> None: # https://bugs.python.org/issue13697#msg237140 assert self.lock.__class__.__module__ == "_thread" - def run_cb(job: tuple[Callable[..., object], Iterable[Any]]) -> Literal[True]: + def run_cb( + job: tuple[Callable[..., object], Iterable[object]] + ) -> Literal[True]: # We run this with KI protection enabled; it's the callback's # job to disable it if it wants it disabled. Exceptions are # treated like system task exceptions (i.e., converted into From cf20b631a3a6bd9ea0ab20e7d094a256e4a541d2 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 28 Jul 2023 16:34:42 +0200 Subject: [PATCH 38/49] adding py.typed, for the fun of it --- trio/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 trio/py.typed diff --git a/trio/py.typed b/trio/py.typed new file mode 100644 index 0000000000..e69de29bb2 From dc8c18cc1bb02f6a9e454a9adbec15178df1d75c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 6 Aug 2023 11:50:22 +0200 Subject: [PATCH 39/49] . --- pyproject.toml | 1 - trio/__init__.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dbdf1bff5a..c0619d58e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ disallow_any_unimported = false # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. [[tool.mypy.overrides]] -# Fully typed, enable stricter checks module = [ "trio/_core/_tests/*", "trio/_tests/*", diff --git a/trio/__init__.py b/trio/__init__.py index c193fe58e3..8db5439d70 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -1,5 +1,3 @@ -from __future__ import annotations - """Trio - A friendly Python library for async concurrency and I/O """ from __future__ import annotations From 9bde0a1b2f1351972c0b94a2fc3bbfbe258ad241 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 10 Aug 2023 16:46:06 +0200 Subject: [PATCH 40/49] reset _core/_run in expectation of #2733 --- pyproject.toml | 46 +++-- trio/_core/_generated_run.py | 17 +- trio/_core/_local.py | 8 +- trio/_core/_run.py | 330 ++++++++++++---------------------- trio/_tests/verify_types.json | 21 ++- 5 files changed, 174 insertions(+), 248 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 93e5e71ea9..037e0d4403 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ disallow_untyped_defs = true # Enable gradually / for new modules check_untyped_defs = false disallow_untyped_calls = false -disallow_any_unimported = false +disallow_any_unimported = true @@ -55,6 +55,24 @@ module = [ "trio/_core/_tests/*", "trio/_tests/*", +# 2749 +"trio/_threads", # 15, 398 lines +# 2747 +"trio/testing/_network", # 1, 34 +"trio/testing/_trio_test", # 2, 29 +"trio/testing/_checkpoints", # 3, 62 +"trio/testing/_check_streams", # 27, 522 +"trio/testing/_memory_streams", # 66, 590 +# 2745 +"trio/_ssl", # 26, 929 lines +# 2742 +"trio/_core/_multierror", # 19, 469 +# 2735 trio/_core/_asyncgens + +# 2733 +"trio/_core/_run", +"trio/_core/_generated_run", + # windows "trio/_windows_pipes", "trio/_core/_windows_cffi", # 2, 324 @@ -62,15 +80,8 @@ module = [ "trio/_core/_io_windows", # 47 (win32), 867 "trio/_wait_for_object", # 2 (windows) -"trio/_core/_multierror", # 19, 469 - -"trio/testing/_network", # 1, 34 -"trio/testing/_trio_test", # 2, 29 -"trio/testing/_checkpoints", # 3, 62 -"trio/testing/_check_streams", # 27, 522 "trio/testing/_fake_net", # 30 -"trio/testing/_memory_streams", # 66, 590 "trio/_highlevel_open_unix_stream", # 2, 49 lines "trio/_highlevel_open_tcp_listeners", # 3, 227 lines @@ -80,26 +91,27 @@ module = [ "trio/_subprocess_platform/waitid", # 2, 107 lines "trio/_signals", # 13, 168 lines -"trio/_threads", # 15, 398 lines "trio/_subprocess", # 21, 759 lines -"trio/_ssl", # 26, 929 lines ] -disallow_untyped_defs = false +disallow_any_decorated = false disallow_any_generics = false +disallow_any_unimported = false disallow_incomplete_defs = false -disallow_any_decorated = false +disallow_untyped_defs = false + +[[tool.mypy.overrides]] +# awaiting typing of OutCome +module = [ + "trio._core._traps", +] +disallow_any_unimported = false [[tool.mypy.overrides]] # Needs to use Any due to some complex introspection. module = [ "trio._path", ] -disallow_incomplete_defs = true -disallow_untyped_defs = true disallow_any_generics = false -disallow_any_decorated = true -disallow_any_unimported = true -disallow_subclassing_any = true [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 038ef0e5e2..674c86aaec 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -11,7 +11,7 @@ from ._run import _NO_SEND -def current_statistics() ->_RunStatistics: +def current_statistics(): """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -41,7 +41,7 @@ def current_statistics() ->_RunStatistics: raise RuntimeError("must be called from async context") -def current_time() ->float: +def current_time(): """Returns the current time according to Trio's internal clock. Returns: @@ -58,7 +58,7 @@ def current_time() ->float: raise RuntimeError("must be called from async context") -def current_clock() ->(SystemClock | Clock): +def current_clock(): """Returns the current :class:`~trio.abc.Clock`.""" locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: @@ -67,7 +67,7 @@ def current_clock() ->(SystemClock | Clock): raise RuntimeError("must be called from async context") -def current_root_task() ->(Task | None): +def current_root_task(): """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -80,7 +80,7 @@ def current_root_task() ->(Task | None): raise RuntimeError("must be called from async context") -def reschedule(task: Task, next_send: Outcome=_NO_SEND) ->None: +def reschedule(task, next_send=_NO_SEND): """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -105,8 +105,7 @@ def reschedule(task: Task, next_send: Outcome=_NO_SEND) ->None: raise RuntimeError("must be called from async context") -def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args: - Any, name: (str | None)=None, context: (Context | None)=None) ->Task: +def spawn_system_task(async_fn, *args, name=None, context=None): """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -165,7 +164,7 @@ def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args: raise RuntimeError("must be called from async context") -def current_trio_token() ->TrioToken: +def current_trio_token(): """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -177,7 +176,7 @@ def current_trio_token() ->TrioToken: raise RuntimeError("must be called from async context") -async def wait_all_tasks_blocked(cushion: float=0.0) ->None: +async def wait_all_tasks_blocked(cushion=0.0): """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a diff --git a/trio/_core/_local.py b/trio/_core/_local.py index ee74f36c49..4f267ba006 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -44,7 +44,7 @@ def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: # not typed yet - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[no-any-return] + return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: @@ -72,7 +72,7 @@ def set(self, value: T) -> RunVarToken[T]: # This can't fail, because if we weren't in Trio context then the # get() above would have failed. - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[index,assignment] return token def reset(self, token: RunVarToken[T]) -> None: @@ -92,9 +92,9 @@ def reset(self, token: RunVarToken[T]) -> None: previous = token.previous_value try: if previous is _NoValue: - _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) + _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type] else: - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index,assignment] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 3361dbe375..7d247a2738 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -17,16 +17,7 @@ from math import inf from time import perf_counter from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Generator, - Iterable, - NoReturn, - Sequence, - TypeVar, -) +from typing import TYPE_CHECKING, Any, NoReturn, TypeVar import attr from outcome import Error, Outcome, Value, capture @@ -46,7 +37,6 @@ Abort, CancelShieldedCheckpoint, PermanentlyDetachCoroutineObject, - RaiseCancelT, WaitTaskRescheduled, cancel_shielded_checkpoint, wait_task_rescheduled, @@ -58,34 +48,17 @@ from types import FrameType if TYPE_CHECKING: - from contextvars import Context + import contextvars # An unfortunate name collision here with trio._util.Final from typing import Final as FinalT - from typing_extensions import Self, TypeAlias - - from .._abc import Clock - from ._local import RunVar - from ._mock_clock import MockClock - - if sys.platform == "win32": - from ._io_windows import _WindowsStatistics - elif sys.platform == "darwin": - from select import kevent - - from ._io_kqueue import _KqueueStatistics - elif sys.platform == "linux": - from ._io_epoll import _EpollStatistics - DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 _NO_SEND: FinalT = object() FnT = TypeVar("FnT", bound="Callable[..., Any]") -T = TypeVar("T") - # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. @@ -146,7 +119,6 @@ def function_with_unique_name_xyzzy() -> NoReturn: CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames() -# Why doesn't this inherit from abc.Clock? @attr.s(frozen=True, slots=True) class SystemClock: # Add a large random offset to our clock to ensure that if people @@ -177,9 +149,7 @@ class IdlePrimedTypes(enum.Enum): ################################################################ -def collapse_exception_group( - excgroup: BaseExceptionGroup[BaseException], -) -> BaseExceptionGroup[BaseException] | BaseException: +def collapse_exception_group(excgroup): """Recursively collapse any single-exception groups into that single contained exception. @@ -199,7 +169,7 @@ def collapse_exception_group( ) return exceptions[0] elif modified: - return excgroup.derive(exceptions) # type: ignore[no-any-return] + return excgroup.derive(exceptions) else: return excgroup @@ -214,18 +184,18 @@ class Deadlines: """ # Heap of (deadline, id(CancelScope), CancelScope) - _heap: list[tuple[float, int, CancelScope]] = attr.ib(factory=list) + _heap = attr.ib(factory=list) # Count of active deadlines (those that haven't been changed) - _active: int = attr.ib(default=0) + _active = attr.ib(default=0) - def add(self, deadline: float, cancel_scope: CancelScope) -> None: + def add(self, deadline, cancel_scope): heappush(self._heap, (deadline, id(cancel_scope), cancel_scope)) self._active += 1 - def remove(self, deadline: float, cancel_scope: CancelScope) -> None: + def remove(self, deadline, cancel_scope): self._active -= 1 - def next_deadline(self) -> float: + def next_deadline(self): while self._heap: deadline, _, cancel_scope = self._heap[0] if deadline == cancel_scope._registered_deadline: @@ -235,7 +205,7 @@ def next_deadline(self) -> float: heappop(self._heap) return inf - def _prune(self) -> None: + def _prune(self): # In principle, it's possible for a cancel scope to toggle back and # forth repeatedly between the same two deadlines, and end up with # lots of stale entries that *look* like they're still active, because @@ -256,7 +226,7 @@ def _prune(self) -> None: heapify(pruned_heap) self._heap = pruned_heap - def expire(self, now: float) -> bool: + def expire(self, now): did_something = False while self._heap and self._heap[0][0] <= now: deadline, _, cancel_scope = heappop(self._heap) @@ -407,14 +377,14 @@ def close(self) -> None: child.recalculate() @property - def parent_cancellation_is_visible_to_us(self) -> bool: + def parent_cancellation_is_visible_to_us(self): return ( self._parent is not None and not self._scope.shield and self._parent.effectively_cancelled ) - def recalculate(self) -> None: + def recalculate(self): # This does a depth-first traversal over this and descendent cancel # statuses, to ensure their state is up-to-date. It's basically a # recursive algorithm, but we use an explicit stack to avoid any @@ -433,7 +403,7 @@ def recalculate(self) -> None: task._attempt_delivery_of_any_pending_cancel() todo.extend(current._children) - def _mark_abandoned(self) -> None: + def _mark_abandoned(self): self.abandoned_by_misnesting = True for child in self._children: child._mark_abandoned() @@ -521,7 +491,7 @@ class CancelScope(metaclass=Final): _shield: bool = attr.ib(default=False, kw_only=True) @enable_ki_protection - def __enter__(self) -> Self: + def __enter__(self): task = _core.current_task() if self._has_been_entered: raise RuntimeError( @@ -535,7 +505,7 @@ def __enter__(self) -> Self: task._activate_cancel_status(self._cancel_status) return self - def _close(self, exc: BaseException | None) -> BaseException | None: + def _close(self, exc): if self._cancel_status is None: new_exc = RuntimeError( "Cancel scope stack corrupted: attempted to exit {!r} " @@ -573,7 +543,6 @@ def _close(self, exc: BaseException | None) -> BaseException | None: # CancelStatus.close() will take care of the plumbing; # we just need to make sure we don't let the error # pass silently. - assert scope_task._cancel_status is not None new_exc = RuntimeError( "Cancel scope stack corrupted: attempted to exit {!r} " "in {!r} that's still within its child {!r}\n{}".format( @@ -815,15 +784,15 @@ def cancel_called(self) -> bool: # sense. @attr.s(eq=False, hash=False, repr=False) class TaskStatus(metaclass=Final): - _old_nursery: Nursery = attr.ib() - _new_nursery: Nursery = attr.ib() - _called_started: bool = attr.ib(default=False) - _value: Any = attr.ib(default=None) + _old_nursery = attr.ib() + _new_nursery = attr.ib() + _called_started = attr.ib(default=False) + _value = attr.ib(default=None) - def __repr__(self) -> str: + def __repr__(self): return f"" - def started(self, value: Any = None) -> None: + def started(self, value=None): if self._called_started: raise RuntimeError("called 'started' twice on the same task status") self._called_started = True @@ -833,7 +802,6 @@ def started(self, value: Any = None) -> None: # will eventually exit on its own, and we don't want to risk moving # children that might have propagating Cancelled exceptions into # a place with no cancelled cancel scopes to catch them. - assert self._old_nursery._cancel_status is not None if self._old_nursery._cancel_status.effectively_cancelled: return @@ -857,7 +825,6 @@ def started(self, value: Any = None) -> None: # do something evil like cancel the old nursery. We thus break # everything off from the old nursery before we start attaching # anything to the new. - assert self._old_nursery._cancel_status is not None cancel_status_children = self._old_nursery._cancel_status.children cancel_status_tasks = set(self._old_nursery._cancel_status.tasks) cancel_status_tasks.discard(self._old_nursery._parent_task) @@ -1030,22 +997,20 @@ def _add_exc(self, exc: BaseException) -> None: self._pending_excs.append(exc) self.cancel_scope.cancel() - def _check_nursery_closed(self) -> None: + def _check_nursery_closed(self): if not any([self._nested_child_running, self._children, self._pending_starts]): self._closed = True if self._parent_waiting_in_aexit: self._parent_waiting_in_aexit = False GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task) - def _child_finished(self, task: Task, outcome: Outcome) -> None: + def _child_finished(self, task, outcome): self._children.remove(task) if isinstance(outcome, Error): self._add_exc(outcome.error) self._check_nursery_closed() - async def _nested_child_finished( - self, nested_child_exc: BaseException | None - ) -> MultiError | None: + async def _nested_child_finished(self, nested_child_exc): # Returns MultiError instance (or any exception if the nursery is in loose mode # and there is just one contained exception) if there are pending exceptions if nested_child_exc is not None: @@ -1057,12 +1022,12 @@ async def _nested_child_finished( # If we get cancelled (or have an exception injected, like # KeyboardInterrupt), then save that, but still wait until our # children finish. - def abort(raise_cancel: RaiseCancelT) -> Abort: + def aborted(raise_cancel): self._add_exc(capture(raise_cancel).error) return Abort.FAILED self._parent_waiting_in_aexit = True - await wait_task_rescheduled(abort) + await wait_task_rescheduled(aborted) else: # Nothing to wait for, so just execute a checkpoint -- but we # still need to mix any exception (e.g. from an external @@ -1083,14 +1048,8 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: # avoid a garbage cycle # (see test_nursery_cancel_doesnt_create_cyclic_garbage) del self._pending_excs - return None - def start_soon( - self, - async_fn: Callable[..., Awaitable[object]], - *args: object, - name: str | None = None, - ) -> None: + def start_soon(self, async_fn, *args, name=None): """Creates a child task, scheduling ``await async_fn(*args)``. If you want to run a function and immediately wait for its result, @@ -1132,12 +1091,7 @@ def start_soon( """ GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) - async def start( - self, - async_fn: Callable[..., Awaitable[object]], - *args: object, - name: str | None = None, - ) -> Value: + async def start(self, async_fn, *args, name=None): r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1214,9 +1168,9 @@ def __del__(self) -> None: class Task(metaclass=NoPublicConstructor): _parent_nursery: Nursery | None = attr.ib() coro: Coroutine[Any, Outcome[object], Any] = attr.ib() - _runner: Runner = attr.ib() + _runner = attr.ib() name: str = attr.ib() - context: Context = attr.ib() + context: contextvars.Context = attr.ib() _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: @@ -1230,8 +1184,8 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn: Callable[[Outcome | None], None] = attr.ib(default=None) - _next_send: Outcome | None = attr.ib(default=None) + _next_send_fn = attr.ib(default=None) + _next_send = attr.ib(default=None) _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib( default=None ) @@ -1336,9 +1290,9 @@ def print_stack_for_task(task): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). - _cancel_status: CancelStatus | None = attr.ib(default=None, repr=False) + _cancel_status: CancelStatus = attr.ib(default=None, repr=False) - def _activate_cancel_status(self, cancel_status: CancelStatus | None) -> None: + def _activate_cancel_status(self, cancel_status: CancelStatus) -> None: if self._cancel_status is not None: self._cancel_status._tasks.remove(self) self._cancel_status = cancel_status @@ -1369,11 +1323,10 @@ def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None: def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return - assert self._cancel_status is not None if not self._cancel_status.effectively_cancelled: return - def raise_cancel() -> NoReturn: + def raise_cancel(): raise Cancelled._create() self._attempt_abort(raise_cancel) @@ -1402,32 +1355,14 @@ class RunContext(threading.local): GLOBAL_RUN_CONTEXT: FinalT = RunContext() -if TYPE_CHECKING: - if sys.platform == "win32": - IO_STATISTICS_TYPE: TypeAlias = _WindowsStatistics - elif sys.platform == "darwin": - IO_STATISTICS_TYPE: TypeAlias = _KqueueStatistics - elif sys.platform == "linux": - IO_STATISTICS_TYPE: TypeAlias = _EpollStatistics -else: - IO_STATISTICS_TYPE = None - @attr.s(frozen=True) class _RunStatistics: - tasks_living: int = attr.ib() - tasks_runnable: int = attr.ib() - seconds_to_next_deadline: float = attr.ib() - io_statistics: IO_STATISTICS_TYPE = attr.ib() - run_sync_soon_queue_size: int = attr.ib() - - -if sys.platform == "linux": - GetEventsT: TypeAlias = "list[tuple[int, int]]" -elif sys.platform == "darwin": - GetEventsT: TypeAlias = "list[kevent]" -else: - GetEventsT: TypeAlias = int + tasks_living = attr.ib() + tasks_runnable = attr.ib() + seconds_to_next_deadline = attr.ib() + io_statistics = attr.ib() + run_sync_soon_queue_size = attr.ib() # This holds all the state that gets trampolined back and forth between @@ -1451,15 +1386,15 @@ class _RunStatistics: # worker thread. @attr.s(eq=False, hash=False, slots=True) class GuestState: - runner: Runner = attr.ib() - run_sync_soon_threadsafe: Callable[[Callable[[], None]], None] = attr.ib() - run_sync_soon_not_threadsafe: Callable[[Callable[[], None]], None] = attr.ib() - done_callback: Callable[[Outcome], None] = attr.ib() - unrolled_run_gen: Generator[float, GetEventsT, None] = attr.ib() + runner = attr.ib() + run_sync_soon_threadsafe = attr.ib() + run_sync_soon_not_threadsafe = attr.ib() + done_callback = attr.ib() + unrolled_run_gen = attr.ib() _value_factory: Callable[[], Value] = lambda: Value(None) unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) - def guest_tick(self) -> None: + def guest_tick(self): try: timeout = self.unrolled_run_next_send.send(self.unrolled_run_gen) except StopIteration: @@ -1480,11 +1415,11 @@ def guest_tick(self) -> None: # Need to go into the thread and call get_events() there self.runner.guest_tick_scheduled = False - def get_events() -> GetEventsT: + def get_events(): return self.runner.io_manager.get_events(timeout) - def deliver(events_outcome: Outcome) -> None: - def in_main_thread() -> None: + def deliver(events_outcome): + def in_main_thread(): self.unrolled_run_next_send = events_outcome self.runner.guest_tick_scheduled = True self.guest_tick() @@ -1496,44 +1431,44 @@ def in_main_thread() -> None: @attr.s(eq=False, hash=False, slots=True) class Runner: - clock: SystemClock | Clock | MockClock = attr.ib() + clock = attr.ib() instruments: Instruments = attr.ib() io_manager: TheIOManager = attr.ib() - ki_manager: KIManager = attr.ib() - strict_exception_groups: bool = attr.ib() + ki_manager = attr.ib() + strict_exception_groups = attr.ib() # Run-local values, see _local.py - _locals: dict[RunVar[Any], Any] = attr.ib(factory=dict) + _locals = attr.ib(factory=dict) runq: deque[Task] = attr.ib(factory=deque) - tasks: set[Task] = attr.ib(factory=set) + tasks = attr.ib(factory=set) - deadlines: Deadlines = attr.ib(factory=Deadlines) + deadlines = attr.ib(factory=Deadlines) - init_task: Task | None = attr.ib(default=None) - system_nursery: Nursery | None = attr.ib(default=None) - system_context: Context | None = attr.ib(default=None) - main_task: Task | None = attr.ib(default=None) - main_task_outcome: Outcome | None = attr.ib(default=None) + init_task = attr.ib(default=None) + system_nursery = attr.ib(default=None) + system_context = attr.ib(default=None) + main_task = attr.ib(default=None) + main_task_outcome = attr.ib(default=None) - entry_queue: EntryQueue = attr.ib(factory=EntryQueue) - trio_token: TrioToken | None = attr.ib(default=None) - asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators) + entry_queue = attr.ib(factory=EntryQueue) + trio_token = attr.ib(default=None) + asyncgens = attr.ib(factory=AsyncGenerators) # If everything goes idle for this long, we call clock._autojump() - clock_autojump_threshold: float = attr.ib(default=inf) + clock_autojump_threshold = attr.ib(default=inf) # Guest mode stuff - is_guest: bool = attr.ib(default=False) - guest_tick_scheduled: bool = attr.ib(default=False) + is_guest = attr.ib(default=False) + guest_tick_scheduled = attr.ib(default=False) - def force_guest_tick_asap(self) -> None: + def force_guest_tick_asap(self): if self.guest_tick_scheduled: return self.guest_tick_scheduled = True self.io_manager.force_wakeup() - def close(self) -> None: + def close(self): self.io_manager.close() self.entry_queue.close() self.asyncgens.close() @@ -1543,7 +1478,7 @@ def close(self) -> None: self.ki_manager.close() @_public - def current_statistics(self) -> _RunStatistics: + def current_statistics(self): """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -1576,7 +1511,7 @@ def current_statistics(self) -> _RunStatistics: ) @_public - def current_time(self) -> float: + def current_time(self): """Returns the current time according to Trio's internal clock. Returns: @@ -1588,15 +1523,13 @@ def current_time(self) -> float: """ return self.clock.current_time() - # TODO: abc.Clock or SystemClock? (the latter which doesn't inherit - # from abc.Clock) @_public - def current_clock(self) -> SystemClock | Clock: + def current_clock(self): """Returns the current :class:`~trio.abc.Clock`.""" return self.clock @_public - def current_root_task(self) -> Task | None: + def current_root_task(self): """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -1608,9 +1541,8 @@ def current_root_task(self) -> Task | None: # Core task handling primitives ################ - # Outcome is not typed @_public - def reschedule(self, task: Task, next_send: Outcome = _NO_SEND) -> None: # type: ignore[misc] + def reschedule(self, task, next_send=_NO_SEND): """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1644,15 +1576,8 @@ def reschedule(self, task: Task, next_send: Outcome = _NO_SEND) -> None: # type self.instruments.call("task_scheduled", task) def spawn_impl( - self, - async_fn: Callable[..., Awaitable[object]], - args: Iterable[Any], - nursery: Nursery | None, - name: str | functools.partial[Any] | Callable[..., Awaitable[object]] | None, - *, - system_task: bool = False, - context: Context | None = None, - ) -> Task: + self, async_fn, args, nursery, name, *, system_task=False, context=None + ): ###### # Make sure the nursery is in working order ###### @@ -1670,7 +1595,6 @@ def spawn_impl( ###### if context is None: if system_task: - assert self.system_context is not None context = self.system_context.copy() else: context = copy_context() @@ -1683,8 +1607,7 @@ def spawn_impl( # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. ###### - # TODO: ?? - coro = context.run(coroutine_or_error, async_fn, *args) # type: ignore[arg-type] + coro = context.run(coroutine_or_error, async_fn, *args) if name is None: name = async_fn @@ -1698,7 +1621,7 @@ def spawn_impl( if not hasattr(coro, "cr_frame"): # This async function is implemented in C or Cython - async def python_wrapper(orig_coro: Awaitable[T]) -> T: + async def python_wrapper(orig_coro): return await orig_coro coro = python_wrapper(coro) @@ -1723,7 +1646,7 @@ async def python_wrapper(orig_coro: Awaitable[T]) -> T: self.reschedule(task, None) return task - def task_exited(self, task: Task, outcome: Outcome) -> None: + def task_exited(self, task, outcome): if ( task._cancel_status is not None and task._cancel_status.abandoned_by_misnesting @@ -1762,7 +1685,6 @@ def task_exited(self, task: Task, outcome: Outcome) -> None: if task is self.main_task: self.main_task_outcome = outcome outcome = Value(None) - assert task._parent_nursery is not None task._parent_nursery._child_finished(task, outcome) if "task_exited" in self.instruments: @@ -1772,15 +1694,8 @@ def task_exited(self, task: Task, outcome: Outcome) -> None: # System tasks and init ################ - # TODO: [misc]typed with Any @_public - def spawn_system_task( # type: ignore[misc] - self, - async_fn: Callable[..., Awaitable[object]], - *args: Any, - name: str | None = None, - context: Context | None = None, - ) -> Task: + def spawn_system_task(self, async_fn, *args, name=None, context=None): """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -1841,9 +1756,7 @@ def spawn_system_task( # type: ignore[misc] context=context, ) - async def init( - self, async_fn: Callable[..., Awaitable[object]], args: Iterable[object] - ) -> None: + async def init(self, async_fn, args): # run_sync_soon task runs here: async with open_nursery() as run_sync_soon_nursery: # All other system tasks run here: @@ -1881,7 +1794,7 @@ async def init( ################ @_public - def current_trio_token(self) -> TrioToken: + def current_trio_token(self): """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -1894,7 +1807,7 @@ def current_trio_token(self) -> TrioToken: # KI handling ################ - ki_pending: bool = attr.ib(default=False) + ki_pending = attr.ib(default=False) # deliver_ki is broke. Maybe move all the actual logic and state into # RunToken, and we'll only have one instance per runner? But then we can't @@ -1903,14 +1816,14 @@ def current_trio_token(self) -> TrioToken: # keep the class public so people can isinstance() it if they want. # This gets called from signal context - def deliver_ki(self) -> None: + def deliver_ki(self): self.ki_pending = True try: self.entry_queue.run_sync_soon(self._deliver_ki_cb) except RunFinishedError: pass - def _deliver_ki_cb(self) -> None: + def _deliver_ki_cb(self): if not self.ki_pending: return # Can't happen because main_task and run_sync_soon_task are created at @@ -1927,10 +1840,10 @@ def _deliver_ki_cb(self) -> None: # Quiescing ################ - waiting_for_idle: SortedDict = attr.ib(factory=SortedDict) + waiting_for_idle = attr.ib(factory=SortedDict) @_public - async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: + async def wait_all_tasks_blocked(self, cushion=0.0): """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -1992,7 +1905,7 @@ async def test_lock_fairness(): key = (cushion, id(task)) self.waiting_for_idle[key] = task - def abort(_: RaiseCancelT) -> Abort: + def abort(_): del self.waiting_for_idle[key] return Abort.SUCCEEDED @@ -2067,11 +1980,11 @@ def abort(_: RaiseCancelT) -> Abort: def setup_runner( - clock: Clock | None, - instruments: Sequence[Instrument], - restrict_keyboard_interrupt_to_checkpoints: bool, - strict_exception_groups: bool, -) -> Runner: + clock, + instruments, + restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups, +): """Create a Runner object and install it as the GLOBAL_RUN_CONTEXT.""" # It wouldn't be *hard* to support nested calls to run(), but I can't # think of a single good reason for it, so let's be conservative for @@ -2080,17 +1993,15 @@ def setup_runner( raise RuntimeError("Attempted to call run() from inside a run()") if clock is None: - _clock: Clock | SystemClock = SystemClock() - else: - _clock = clock - _instruments = Instruments(instruments) + clock = SystemClock() + instruments = Instruments(instruments) io_manager = TheIOManager() system_context = copy_context() ki_manager = KIManager() runner = Runner( - clock=_clock, - instruments=_instruments, + clock=clock, + instruments=instruments, io_manager=io_manager, system_context=system_context, ki_manager=ki_manager, @@ -2107,13 +2018,13 @@ def setup_runner( def run( - async_fn: Callable[..., Awaitable[T]], - *args: object, - clock: Clock | None = None, - instruments: Sequence[Instrument] = (), + async_fn, + *args, + clock=None, + instruments=(), restrict_keyboard_interrupt_to_checkpoints: bool = False, strict_exception_groups: bool = False, -) -> T: +): """Run a Trio-flavored async function, and return the result. Calling:: @@ -2199,32 +2110,30 @@ def run( next_send = None while True: try: - # sending next_send==None here ... should not work?? - timeout = gen.send(next_send) # type: ignore[arg-type] + timeout = gen.send(next_send) except StopIteration: break next_send = runner.io_manager.get_events(timeout) # Inlined copy of runner.main_task_outcome.unwrap() to avoid # cluttering every single Trio traceback with an extra frame. if isinstance(runner.main_task_outcome, Value): - return runner.main_task_outcome.value # type: ignore[no-any-return] + return runner.main_task_outcome.value else: - assert runner.main_task_outcome is not None raise runner.main_task_outcome.error def start_guest_run( - async_fn: Callable[..., Awaitable[object]], - *args: object, - run_sync_soon_threadsafe: Callable[[Callable[..., None]], None], - done_callback: Callable[[Outcome], None], - run_sync_soon_not_threadsafe: Callable[[Callable[..., None]], None] | None = None, + async_fn, + *args, + run_sync_soon_threadsafe, + done_callback, + run_sync_soon_not_threadsafe=None, host_uses_signal_set_wakeup_fd: bool = False, - clock: Clock | None = None, - instruments: tuple[Instrument, ...] = (), + clock=None, + instruments=(), restrict_keyboard_interrupt_to_checkpoints: bool = False, strict_exception_groups: bool = False, -) -> None: +): """Start a "guest" run of Trio on top of some other "host" event loop. Each host loop can only have one guest run at a time. @@ -2312,10 +2221,10 @@ def my_done_callback(run_outcome): # straight through. def unrolled_run( runner: Runner, - async_fn: Callable[..., object], - args: Iterable[object], + async_fn, + args, host_uses_signal_set_wakeup_fd: bool = False, -) -> Generator[float, GetEventsT, None]: +): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True @@ -2401,7 +2310,7 @@ def unrolled_run( break else: assert idle_primed is IdlePrimedTypes.AUTOJUMP_CLOCK - runner.clock._autojump() # type: ignore[union-attr] + runner.clock._autojump() # Process all runnable tasks, but only the ones that are already # runnable now. Anything that becomes runnable during this cycle @@ -2601,10 +2510,7 @@ def current_effective_deadline() -> float: float: the effective deadline, as an absolute time. """ - curr_cancel_status = current_task()._cancel_status - assert curr_cancel_status is not None - return curr_cancel_status.effective_deadline() - # return current_task()._cancel_status.effective_deadline() + return current_task()._cancel_status.effective_deadline() async def checkpoint() -> None: @@ -2627,7 +2533,6 @@ async def checkpoint() -> None: await cancel_shielded_checkpoint() task = current_task() task._cancel_points += 1 - assert task._cancel_status is not None if task._cancel_status.effectively_cancelled or ( task is task._runner.main_task and task._runner.ki_pending ): @@ -2651,7 +2556,6 @@ async def checkpoint_if_cancelled() -> None: """ task = current_task() - assert task._cancel_status is not None if task._cancel_status.effectively_cancelled or ( task is task._runner.main_task and task._runner.ki_pending ): diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index ce82e1ec65..cf108c4bff 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9473684210526315, + "completenessScore": 0.9330143540669856, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 594, - "withUnknownType": 33 + "withKnownType": 585, + "withUnknownType": 42 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 0, @@ -46,11 +46,17 @@ ], "otherSymbolCounts": { "withAmbiguousType": 5, - "withKnownType": 613, - "withUnknownType": 48 + "withKnownType": 602, + "withUnknownType": 59 }, "packageName": "trio", "symbols": [ + "trio._core._run.Nursery.start", + "trio._core._run.Nursery.start_soon", + "trio._core._run.TaskStatus.__repr__", + "trio._core._run.TaskStatus.started", + "trio._dtls.DTLSChannel.__init__", + "trio._dtls.DTLSEndpoint.serve", "trio._highlevel_socket.SocketStream.getsockopt", "trio._highlevel_socket.SocketStream.send_all", "trio._highlevel_socket.SocketStream.setsockopt", @@ -80,6 +86,7 @@ "trio._subprocess.Process.send_signal", "trio._subprocess.Process.terminate", "trio._subprocess.Process.wait", + "trio.current_time", "trio.from_thread.run", "trio.from_thread.run_sync", "trio.lowlevel.current_clock", @@ -89,14 +96,17 @@ "trio.lowlevel.notify_closing", "trio.lowlevel.reschedule", "trio.lowlevel.spawn_system_task", + "trio.lowlevel.start_guest_run", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", "trio.open_ssl_over_tcp_listeners", "trio.open_ssl_over_tcp_stream", "trio.open_unix_socket", + "trio.run", "trio.run_process", "trio.serve_listeners", "trio.serve_ssl_over_tcp", + "trio.serve_tcp", "trio.testing._memory_streams.MemoryReceiveStream.__init__", "trio.testing._memory_streams.MemoryReceiveStream.aclose", "trio.testing._memory_streams.MemoryReceiveStream.close", @@ -127,6 +137,7 @@ "trio.testing.memory_stream_pump", "trio.testing.open_stream_to_socket_listener", "trio.testing.trio_test", + "trio.testing.wait_all_tasks_blocked", "trio.to_thread.current_default_thread_limiter" ] } From 048b48a4db86323189654c1ca2d907d8f9ab241c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 10 Aug 2023 17:13:46 +0200 Subject: [PATCH 41/49] rearrange files in toml --- pyproject.toml | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 037e0d4403..ebbe88f5ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,8 +52,6 @@ disallow_any_unimported = true # downstream and users have to deal with them. [[tool.mypy.overrides]] module = [ -"trio/_core/_tests/*", -"trio/_tests/*", # 2749 "trio/_threads", # 15, 398 lines @@ -67,31 +65,36 @@ module = [ "trio/_ssl", # 26, 929 lines # 2742 "trio/_core/_multierror", # 19, 469 -# 2735 trio/_core/_asyncgens - # 2733 "trio/_core/_run", "trio/_core/_generated_run", +# 2724 +"trio/_highlevel_open_tcp_listeners", # 3, 227 lines +# 2735 trio/_core/_asyncgens -# windows -"trio/_windows_pipes", -"trio/_core/_windows_cffi", # 2, 324 -"trio/_core/_generated_io_windows", # 9 (win32), 84 -"trio/_core/_io_windows", # 47 (win32), 867 -"trio/_wait_for_object", # 2 (windows) - - -"trio/testing/_fake_net", # 30 - +# exported API "trio/_highlevel_open_unix_stream", # 2, 49 lines -"trio/_highlevel_open_tcp_listeners", # 3, 227 lines "trio/_highlevel_serve_listeners", # 3, 121 lines "trio/_highlevel_ssl_helpers", # 3, 155 lines "trio/_highlevel_socket", # 4, 386 lines -"trio/_subprocess_platform/waitid", # 2, 107 lines "trio/_signals", # 13, 168 lines "trio/_subprocess", # 21, 759 lines + +# windows API +"trio/_core/_generated_io_windows", # 9 (win32), 84 +"trio/_core/_io_windows", # 47 (win32), 867 +"trio/_wait_for_object", # 2 (windows) + +# internal +"trio/_windows_pipes", +"trio/_core/_windows_cffi", # 2, 324 + +# tests +"trio/_subprocess_platform/waitid", # 2, 107 lines +"trio/_core/_tests/*", +"trio/_tests/*", +"trio/testing/_fake_net", # 30 ] disallow_any_decorated = false disallow_any_generics = false From 5c0a25cbe651b9fdd2cdb0d91db380cc4352425a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 10 Aug 2023 17:21:39 +0200 Subject: [PATCH 42/49] remove unnecessary/incorrect diffs relative to master --- .coveragerc | 1 - trio/__init__.py | 1 - trio/_core/_asyncgens.py | 8 +++++--- trio/_socket.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.coveragerc b/.coveragerc index a7e309c11b..431a02971b 100644 --- a/.coveragerc +++ b/.coveragerc @@ -10,7 +10,6 @@ omit= */trio/_core/_generated_* # Script used to check type completeness that isn't run in tests */trio/_tests/check_type_completeness.py - # The test suite spawns subprocesses to test some stuff, so make sure # this doesn't corrupt the coverage files parallel=True diff --git a/trio/__init__.py b/trio/__init__.py index d147012b0a..277baa5339 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -16,7 +16,6 @@ # # Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) - # must be imported early to avoid circular import from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: split diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py index 975db63555..4261328278 100644 --- a/trio/_core/_asyncgens.py +++ b/trio/_core/_asyncgens.py @@ -140,7 +140,7 @@ async def finalize_remaining(self, runner: _run.Runner) -> None: # To make async generator finalization easier to reason # about, we'll shut down asyncgen garbage collection by turning # the alive WeakSet into a regular set. - self.alive = set(self.alive) # type: ignore + self.alive = set(self.alive) # Process all pending run_sync_soon callbacks, in case one of # them was an asyncgen finalizer that snuck in under the wire. @@ -185,14 +185,16 @@ async def finalize_remaining(self, runner: _run.Runner) -> None: # all are gone. while self.alive: batch = self.alive - self.alive = set() # type: ignore + self.alive = _ASYNC_GEN_SET() for agen in batch: await self._finalize_one(agen, name_asyncgen(agen)) def close(self) -> None: sys.set_asyncgen_hooks(*self.prev_hooks) - async def _finalize_one(self, agen: AGenT, name: str) -> None: + async def _finalize_one( + self, agen: AsyncGeneratorType[object, NoReturn], name: object + ) -> None: try: # This shield ensures that finalize_asyncgen never exits # with an exception, not even a Cancelled. The inside diff --git a/trio/_socket.py b/trio/_socket.py index cb739bb903..b0ec1d480d 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -979,8 +979,8 @@ async def sendto( ) -> int: ... - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) - async def sendto(self, *args: Any) -> int: # type: ignore[misc] # Any + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] + async def sendto(self, *args: Any) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address) # and kwargs are not accepted From 9f5370564cf78fa216b048d1515d1c3516e84d7a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 10 Aug 2023 17:34:16 +0200 Subject: [PATCH 43/49] _traps cleanup --- trio/_core/_traps.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 77fc9966fe..e0f40f8ad6 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -1,3 +1,4 @@ +# These are the only functions that ever yield back to the task runner. from __future__ import annotations import enum @@ -9,9 +10,6 @@ from . import _run -# These are the only functions that ever yield back to the task runner. - - if TYPE_CHECKING: from outcome import Outcome from typing_extensions import TypeAlias From 63a8de7ef38285cddae1cf5653a1cda4cdc46781 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 10 Aug 2023 17:54:06 +0200 Subject: [PATCH 44/49] cleanup --- trio/_core/_generated_io_windows.py | 8 +++++--- trio/_core/_io_windows.py | 15 ++++----------- trio/_core/_windows_cffi.py | 3 +-- trio/_ssl.py | 9 +++------ trio/_tools/gen_exports.py | 14 +++++++++++++- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 29f4eb56db..301573c6ee 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -8,7 +8,10 @@ from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import GLOBAL_RUN_CONTEXT -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ContextManager + +if TYPE_CHECKING: + from ._unbounded_queue import UnboundedQueue import sys assert not TYPE_CHECKING or sys.platform=="win32" @@ -78,8 +81,7 @@ def current_iocp() ->int: raise RuntimeError("must be called from async context") -def monitor_completion_key() ->_GeneratorContextManager[tuple[int, - UnboundedQueue[object]]]: +def monitor_completion_key() ->ContextManager[tuple[int, UnboundedQueue[object]]]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 7a9f827b0c..74c5ff1552 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -32,10 +32,8 @@ assert not TYPE_CHECKING or sys.platform == "win32" if TYPE_CHECKING: - from _contextlib import _GeneratorContextManager - from ._traps import Abort, RaiseCancelT - from ._unbouded_queue import UnboundedQueue + from ._unbounded_queue import UnboundedQueue # There's a lot to be said about the overall design of a Windows event # loop. See @@ -190,7 +188,7 @@ class CKeys(enum.IntEnum): def _check(success: bool) -> Literal[True]: if not success: raise_winerror() - return success + return True def _get_underlying_socket( @@ -865,14 +863,9 @@ def submit_read(lpOverlapped): def current_iocp(self) -> int: return int(ffi.cast("uintptr_t", self._iocp)) - @_public - def monitor_completion_key( - self, - ) -> _GeneratorContextManager[tuple[int, UnboundedQueue[object]]]: - return self._monitor_completion_key() - @contextmanager - def _monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]: + @_public + def monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]: key = next(self._completion_key_counter) queue = _core.UnboundedQueue[object]() self._completion_key_queues[key] = queue diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index 50d598c2be..639e75b50e 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -1,6 +1,5 @@ import enum import re -from typing import NoReturn import cffi @@ -316,7 +315,7 @@ def _handle(obj): return obj -def raise_winerror(winerror=None, *, filename=None, filename2=None) -> NoReturn: +def raise_winerror(winerror=None, *, filename=None, filename2=None): if winerror is None: winerror, msg = ffi.getwinerror() else: diff --git a/trio/_ssl.py b/trio/_ssl.py index 352f95edaf..bd8b3b06b6 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -148,12 +148,10 @@ # stream) # docs will need to make very clear that this is different from all the other # cancellations in core Trio -from __future__ import annotations import operator as _operator import ssl as _stdlib_ssl from enum import Enum as _Enum -from typing import Any, Awaitable, Callable import trio @@ -211,14 +209,13 @@ class NeedHandshakeError(Exception): class _Once: - # needs TypeVarTuple - def __init__(self, afn: Callable[..., Awaitable[object]], *args: Any): + def __init__(self, afn, *args): self._afn = afn self._args = args self.started = False self._done = _sync.Event() - async def ensure(self, *, checkpoint: bool) -> None: + async def ensure(self, *, checkpoint): if not self.started: self.started = True await self._afn(*self._args) @@ -229,7 +226,7 @@ async def ensure(self, *, checkpoint: bool) -> None: await self._done.wait() @property - def done(self) -> bool: + def done(self): return self._done.is_set() diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index f3ed2e26e7..4517eb7bf9 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -228,7 +228,12 @@ def main() -> None: # pragma: no cover "runner.instruments", imports=IMPORTS_INSTRUMENT, ), - File(core / "_io_windows.py", "runner.io_manager", platform="win32"), + File( + core / "_io_windows.py", + "runner.io_manager", + platform="win32", + imports=IMPORTS_WINDOWS, + ), File( core / "_io_epoll.py", "runner.io_manager", @@ -270,6 +275,13 @@ def main() -> None: # pragma: no cover """ +IMPORTS_WINDOWS = """\ +from typing import TYPE_CHECKING, ContextManager + +if TYPE_CHECKING: + from ._unbounded_queue import UnboundedQueue +""" + if __name__ == "__main__": # pragma: no cover main() From 8a4bc3349fbc7dc359c30f1c6d2fe3e99a99e247 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 13:38:28 +0200 Subject: [PATCH 45/49] revert changes to pyproject.toml, properly merge stuff so Success: no issues found in 129 source files works again --- pyproject.toml | 49 +++++------------------------------ trio/_core/_generated_run.py | 2 +- trio/_core/_local.py | 4 +-- trio/_core/_tests/test_run.py | 4 +-- trio/_tests/test_exports.py | 4 +-- trio/_tests/verify_types.json | 18 ++++--------- trio/_tools/gen_exports.py | 7 ----- 7 files changed, 18 insertions(+), 70 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eec4402e73..d4418ba053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ disallow_untyped_defs = true # Enable gradually / for new modules check_untyped_defs = false disallow_untyped_calls = false -disallow_any_unimported = true +disallow_any_unimported = false # awaiting Outcome @@ -63,11 +63,6 @@ module = [ "trio/testing/_memory_streams", # 66, 590 # 2745 "trio/_ssl", # 26, 929 lines -# 2742 -"trio/_core/_multierror", # 19, 469 -# 2733 -"trio/_core/_run", -"trio/_core/_generated_run", # 2724 "trio/_highlevel_open_tcp_listeners", # 3, 227 lines @@ -94,37 +89,12 @@ module = [ "trio/_core/_tests/*", "trio/_tests/*", "trio/testing/_fake_net", # 30 - - "trio._abc", - "trio._core._asyncgens", - "trio._core._entry_queue", - "trio._core._generated_run", - "trio._core._generated_io_epoll", - "trio._core._generated_io_kqueue", - "trio._core._io_epoll", - "trio._core._io_kqueue", - "trio._core._local", - "trio._core._multierror", - "trio._core._thread_cache", - "trio._core._unbounded_queue", - "trio._core._run", - "trio._deprecate", - "trio._dtls", - "trio._file_io", - "trio._highlevel_open_tcp_stream", - "trio._ki", - "trio._socket", - "trio._sync", - "trio._tools.gen_exports", - "trio._util", ] -disallow_incomplete_defs = true -disallow_untyped_defs = true -disallow_untyped_decorators = true -disallow_any_generics = true -disallow_any_decorated = true -disallow_any_unimported = false # Enable once outcome has stubs. -disallow_subclassing_any = true +disallow_any_decorated = false +disallow_any_generics = false +disallow_any_unimported = false +disallow_incomplete_defs = false +disallow_untyped_defs = false [[tool.mypy.overrides]] # Needs to use Any due to some complex introspection. @@ -133,13 +103,6 @@ module = [ ] disallow_any_generics = false -[[tool.mypy.overrides]] -# awaiting typing of OutCome -module = [ - "trio._core._traps", -] -disallow_any_unimported = false - [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] faulthandler_timeout = 60 diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 35ecd45a1b..bd5abbd639 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -88,7 +88,7 @@ def current_root_task() ->(Task | None): raise RuntimeError("must be called from async context") -def reschedule(task: Task, next_send: Outcome[Any]=_NO_SEND) ->None: # type: ignore[has-type] +def reschedule(task: Task, next_send: Outcome[Any]=_NO_SEND) ->None: """Reschedule the given task with the given :class:`outcome.Outcome`. diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 83826fc63f..8286a5578f 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -71,7 +71,7 @@ def set(self, value: T) -> RunVarToken[T]: # This can't fail, because if we weren't in Trio context then the # get() above would have failed. - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[index,assignment] + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value return token def reset(self, token: RunVarToken[T]) -> None: @@ -93,7 +93,7 @@ def reset(self, token: RunVarToken[T]) -> None: if previous is _NoValue: _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) else: - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index,assignment] + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous except AttributeError: raise RuntimeError("Cannot be used outside of a run context") diff --git a/trio/_core/_tests/test_run.py b/trio/_core/_tests/test_run.py index 81c3b73cc4..6d34d8f223 100644 --- a/trio/_core/_tests/test_run.py +++ b/trio/_core/_tests/test_run.py @@ -1954,7 +1954,7 @@ async def test_Nursery_private_init(): def test_Nursery_subclass(): with pytest.raises(TypeError): - class Subclass(_core._run.Nursery): + class Subclass(_core._run.Nursery): # type: ignore[misc] pass @@ -1984,7 +1984,7 @@ class Subclass(_core.Cancelled): def test_CancelScope_subclass(): with pytest.raises(TypeError): - class Subclass(_core.CancelScope): + class Subclass(_core.CancelScope): # type: ignore[misc] pass diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py index 1553474c7a..1a145f397e 100644 --- a/trio/_tests/test_exports.py +++ b/trio/_tests/test_exports.py @@ -9,7 +9,7 @@ import sys from pathlib import Path from types import ModuleType -from typing import Protocol +from typing import Dict, Protocol import attrs import pytest @@ -27,7 +27,7 @@ try: # If installed, check both versions of this class. from typing_extensions import Protocol as Protocol_ext except ImportError: # pragma: no cover - Protocol_ext = Protocol + Protocol_ext = Protocol # type: ignore[assignment] def _ensure_mypy_cache_updated(): diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index ac5ab46812..28288652c9 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9330143540669856, + "completenessScore": 0.9570063694267515, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 585, - "withUnknownType": 42 + "withKnownType": 601, + "withUnknownType": 27 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 0, @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 5, - "withKnownType": 602, - "withUnknownType": 59 + "withKnownType": 627, + "withUnknownType": 48 }, "packageName": "trio", "symbols": [ @@ -82,14 +82,7 @@ "trio._subprocess.Process.wait", "trio.from_thread.run", "trio.from_thread.run_sync", - "trio.lowlevel.current_clock", - "trio.lowlevel.current_root_task", - "trio.lowlevel.current_statistics", - "trio.lowlevel.current_trio_token", "trio.lowlevel.notify_closing", - "trio.lowlevel.reschedule", - "trio.lowlevel.spawn_system_task", - "trio.lowlevel.start_guest_run", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", "trio.open_ssl_over_tcp_listeners", @@ -128,7 +121,6 @@ "trio.testing.memory_stream_pump", "trio.testing.open_stream_to_socket_listener", "trio.testing.trio_test", - "trio.testing.wait_all_tasks_blocked", "trio.to_thread.current_default_thread_limiter" ] } diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 43ed8b8bb8..0730c16684 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -158,13 +158,6 @@ def gen_public_wrappers_source(file: File) -> str: if is_cm: # pragma: no cover func = func.replace("->Iterator", "->ContextManager") - # TODO: hacky workaround until we run mypy without `-m`, which breaks imports - # enough that it cannot figure out the type of _NO_SEND - if file.path.stem == "_run" and func.startswith( - "def reschedule" - ): # pragma: no cover - func = func.replace("None:\n", "None: # type: ignore[has-type]\n") - # Create export function body template = TEMPLATE.format( " await " if isinstance(method, ast.AsyncFunctionDef) else " ", From f0e18ad490c40e5094e0ac5a7f26ceca7292bc1c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 14:04:39 +0200 Subject: [PATCH 46/49] _io_windows changes moved to #2761 --- trio/_core/_generated_io_windows.py | 17 +++----- trio/_core/_io_windows.py | 66 ++++++++++++++--------------- trio/_tools/gen_exports.py | 14 +----- 3 files changed, 39 insertions(+), 58 deletions(-) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 301573c6ee..7fa6fd5126 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -8,16 +8,13 @@ from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import GLOBAL_RUN_CONTEXT -from typing import TYPE_CHECKING, ContextManager - -if TYPE_CHECKING: - from ._unbounded_queue import UnboundedQueue +from typing import TYPE_CHECKING import sys assert not TYPE_CHECKING or sys.platform=="win32" -async def wait_readable(sock) ->None: +async def wait_readable(sock): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) @@ -25,7 +22,7 @@ async def wait_readable(sock) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(sock) ->None: +async def wait_writable(sock): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) @@ -33,7 +30,7 @@ async def wait_writable(sock) ->None: raise RuntimeError("must be called from async context") -def notify_closing(handle) ->None: +def notify_closing(handle): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) @@ -41,7 +38,7 @@ def notify_closing(handle) ->None: raise RuntimeError("must be called from async context") -def register_with_iocp(handle) ->None: +def register_with_iocp(handle): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) @@ -73,7 +70,7 @@ async def readinto_overlapped(handle, buffer, file_offset=0): raise RuntimeError("must be called from async context") -def current_iocp() ->int: +def current_iocp(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() @@ -81,7 +78,7 @@ def current_iocp() ->int: raise RuntimeError("must be called from async context") -def monitor_completion_key() ->ContextManager[tuple[int, UnboundedQueue[object]]]: +def monitor_completion_key(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index ba2f78cfe2..9757d25b5f 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -5,7 +5,7 @@ import socket import sys from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator, Literal +from typing import TYPE_CHECKING, Literal import attr from outcome import Value @@ -32,8 +32,6 @@ assert not TYPE_CHECKING or sys.platform == "win32" if TYPE_CHECKING: - from ._traps import Abort, RaiseCancelT - from ._unbounded_queue import UnboundedQueue from typing_extensions import TypeAlias EventResult: TypeAlias = int @@ -187,15 +185,13 @@ class CKeys(enum.IntEnum): USER_DEFINED = 4 # and above -def _check(success: bool) -> Literal[True]: +def _check(success): if not success: raise_winerror() - return True + return success -def _get_underlying_socket( - sock: socket.socket | int, *, which=WSAIoctls.SIO_BASE_HANDLE -): +def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): if hasattr(sock, "fileno"): sock = sock.fileno() base_ptr = ffi.new("HANDLE *") @@ -340,9 +336,9 @@ def _afd_helper_handle(): # operation and start a new one. @attr.s(slots=True, eq=False) class AFDWaiters: - read_task: None = attr.ib(default=None) - write_task: None = attr.ib(default=None) - current_op: None = attr.ib(default=None) + read_task = attr.ib(default=None) + write_task = attr.ib(default=None) + current_op = attr.ib(default=None) # We also need to bundle up all the info for a single op into a standalone @@ -350,10 +346,10 @@ class AFDWaiters: # finishes, even if we're throwing it away. @attr.s(slots=True, eq=False, frozen=True) class AFDPollOp: - lpOverlapped: None = attr.ib() - poll_info: None = attr.ib() - waiters: None = attr.ib() - afd_group: None = attr.ib() + lpOverlapped = attr.ib() + poll_info = attr.ib() + waiters = attr.ib() + afd_group = attr.ib() # The Windows kernel has a weird issue when using AFD handles. If you have N @@ -369,8 +365,8 @@ class AFDPollOp: @attr.s(slots=True, eq=False) class AFDGroup: - size: int = attr.ib() - handle: None = attr.ib() + size = attr.ib() + handle = attr.ib() @attr.s(slots=True, eq=False, frozen=True) @@ -391,8 +387,8 @@ class _WindowsStatistics: @attr.s(frozen=True) class CompletionKeyEventInfo: - lpOverlapped: None = attr.ib() - dwNumberOfBytesTransferred: int = attr.ib() + lpOverlapped = attr.ib() + dwNumberOfBytesTransferred = attr.ib() class WindowsIOManager: @@ -459,7 +455,7 @@ def __init__(self): "netsh winsock show catalog" ) - def close(self) -> None: + def close(self): try: if self._iocp is not None: iocp = self._iocp @@ -470,10 +466,10 @@ def close(self) -> None: afd_handle = self._all_afd_handles.pop() _check(kernel32.CloseHandle(afd_handle)) - def __del__(self) -> None: + def __del__(self): self.close() - def statistics(self) -> _WindowsStatistics: + def statistics(self): tasks_waiting_read = 0 tasks_waiting_write = 0 for waiter in self._afd_waiters.values(): @@ -488,7 +484,7 @@ def statistics(self) -> _WindowsStatistics: completion_key_monitors=len(self._completion_key_queues), ) - def force_wakeup(self) -> None: + def force_wakeup(self): _check( kernel32.PostQueuedCompletionStatus( self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL @@ -594,7 +590,7 @@ def process_events(self, received: EventResult) -> None: ) queue.put_nowait(info) - def _register_with_iocp(self, handle, completion_key) -> None: + def _register_with_iocp(self, handle, completion_key): handle = _handle(handle) _check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0)) # Supposedly this makes things slightly faster, by disabling the @@ -611,7 +607,7 @@ def _register_with_iocp(self, handle, completion_key) -> None: # AFD stuff ################################################################ - def _refresh_afd(self, base_handle) -> None: + def _refresh_afd(self, base_handle): waiters = self._afd_waiters[base_handle] if waiters.current_op is not None: afd_group = waiters.current_op.afd_group @@ -687,7 +683,7 @@ def _refresh_afd(self, base_handle) -> None: if afd_group.size >= MAX_AFD_GROUP_SIZE: self._vacant_afd_groups.remove(afd_group) - async def _afd_poll(self, sock, mode) -> None: + async def _afd_poll(self, sock, mode): base_handle = _get_base_socket(sock) waiters = self._afd_waiters.get(base_handle) if waiters is None: @@ -700,7 +696,7 @@ async def _afd_poll(self, sock, mode) -> None: # we let it escape. self._refresh_afd(base_handle) - def abort_fn(_: RaiseCancelT) -> Abort: + def abort_fn(_): setattr(waiters, mode, None) self._refresh_afd(base_handle) return _core.Abort.SUCCEEDED @@ -708,15 +704,15 @@ def abort_fn(_: RaiseCancelT) -> Abort: await _core.wait_task_rescheduled(abort_fn) @_public - async def wait_readable(self, sock) -> None: + async def wait_readable(self, sock): await self._afd_poll(sock, "read_task") @_public - async def wait_writable(self, sock) -> None: + async def wait_writable(self, sock): await self._afd_poll(sock, "write_task") @_public - def notify_closing(self, handle) -> None: + def notify_closing(self, handle): handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: @@ -728,7 +724,7 @@ def notify_closing(self, handle) -> None: ################################################################ @_public - def register_with_iocp(self, handle) -> None: + def register_with_iocp(self, handle): self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) @_public @@ -744,7 +740,7 @@ async def wait_overlapped(self, handle, lpOverlapped): self._overlapped_waiters[lpOverlapped] = task raise_cancel = None - def abort(raise_cancel_: RaiseCancelT) -> Abort: + def abort(raise_cancel_): nonlocal raise_cancel raise_cancel = raise_cancel_ try: @@ -864,14 +860,14 @@ def submit_read(lpOverlapped): ################################################################ @_public - def current_iocp(self) -> int: + def current_iocp(self): return int(ffi.cast("uintptr_t", self._iocp)) @contextmanager @_public - def monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]: + def monitor_completion_key(self): key = next(self._completion_key_counter) - queue = _core.UnboundedQueue[object]() + queue = _core.UnboundedQueue() self._completion_key_queues[key] = queue try: yield (key, queue) diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 0730c16684..3c598e8eae 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -228,12 +228,7 @@ def main() -> None: # pragma: no cover "runner.instruments", imports=IMPORTS_INSTRUMENT, ), - File( - core / "_io_windows.py", - "runner.io_manager", - platform="win32", - imports=IMPORTS_WINDOWS, - ), + File(core / "_io_windows.py", "runner.io_manager", platform="win32"), File( core / "_io_epoll.py", "runner.io_manager", @@ -283,13 +278,6 @@ def main() -> None: # pragma: no cover """ -IMPORTS_WINDOWS = """\ -from typing import TYPE_CHECKING, ContextManager - -if TYPE_CHECKING: - from ._unbounded_queue import UnboundedQueue -""" - if __name__ == "__main__": # pragma: no cover main() From cd4e907fe45ec6ee95d87604c480cc922b2e4687 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 14:11:22 +0200 Subject: [PATCH 47/49] move around in pyproject.toml --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 51ed378b49..4526571fc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,15 +69,15 @@ module = [ # 2755 "trio/_core/_windows_cffi", # 2, 324 "trio/_wait_for_object", # 2 (windows) +# 2761 +"trio/_core/_generated_io_windows", # 9 (win32), 84 +"trio/_core/_io_windows", # 47 (win32), 867 "trio/_signals", # 13, 168 lines -# windows API -"trio/_core/_generated_io_windows", # 9 (win32), 84 -"trio/_core/_io_windows", # 47 (win32), 867 # internal "trio/_windows_pipes", From 69b7e77d5a1d493198283fdd63fa0fbe715f48e5 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 14:58:06 +0200 Subject: [PATCH 48/49] cleanup --- pyproject.toml | 49 +++++++++++++---------------------- trio/_path.py | 6 ++--- trio/_tests/verify_types.json | 5 ++-- trio/tests.py | 2 -- 4 files changed, 24 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4526571fc3..a212393452 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,48 +36,42 @@ warn_return_any = true # Avoid subtle backsliding disallow_any_decorated = true disallow_any_generics = true +disallow_any_unimported = false # Enable once Outcome has stubs. disallow_incomplete_defs = true disallow_subclassing_any = true disallow_untyped_decorators = true disallow_untyped_defs = true -# Enable gradually / for new modules +# Enable once other problems are dealt with check_untyped_defs = false disallow_untyped_calls = false -disallow_any_unimported = false # awaiting Outcome - -# DO NOT use `ignore_errors`; it doesn't apply -# downstream and users have to deal with them. +# files not yet fully typed [[tool.mypy.overrides]] module = [ - # 2747 -"trio/testing/_network", # 1, 34 -"trio/testing/_trio_test", # 2, 29 -"trio/testing/_checkpoints", # 3, 62 -"trio/testing/_check_streams", # 27, 522 -"trio/testing/_memory_streams", # 66, 590 +"trio/testing/_network", +"trio/testing/_trio_test", +"trio/testing/_checkpoints", +"trio/testing/_check_streams", +"trio/testing/_memory_streams", # 2745 -"trio/_ssl", # 26, 929 lines +"trio/_ssl", # 2756 -"trio/_highlevel_open_unix_stream", # 2, 49 lines -"trio/_highlevel_serve_listeners", # 3, 121 lines -"trio/_highlevel_ssl_helpers", # 3, 155 lines -"trio/_highlevel_socket", # 4, 386 lines +"trio/_highlevel_open_unix_stream", +"trio/_highlevel_serve_listeners", +"trio/_highlevel_ssl_helpers", +"trio/_highlevel_socket", # 2755 -"trio/_core/_windows_cffi", # 2, 324 -"trio/_wait_for_object", # 2 (windows) +"trio/_core/_windows_cffi", +"trio/_wait_for_object", # 2761 -"trio/_core/_generated_io_windows", # 9 (win32), 84 -"trio/_core/_io_windows", # 47 (win32), 867 - - - +"trio/_core/_generated_io_windows", +"trio/_core/_io_windows", -"trio/_signals", # 13, 168 lines +"trio/_signals", # internal "trio/_windows_pipes", @@ -93,13 +87,6 @@ disallow_any_unimported = false disallow_incomplete_defs = false disallow_untyped_defs = false -[[tool.mypy.overrides]] -# Needs to use Any due to some complex introspection. -module = [ - "trio._path", -] -disallow_any_generics = false - [tool.pytest.ini_options] addopts = ["--strict-markers", "--strict-config"] faulthandler_timeout = 60 diff --git a/trio/_path.py b/trio/_path.py index cad83e0e6a..c2763e03af 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -116,7 +116,7 @@ async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Path: def classmethod_wrapper_factory( cls: AsyncAutoWrapperType, meth_name: str -) -> classmethod: +) -> classmethod: # type: ignore[type-arg] @async_wraps(cls, cls._wraps, meth_name) async def wrapper(cls: type[Path], *args: Any, **kwargs: Any) -> Path: # type: ignore[misc] # contains Any meth = getattr(cls._wraps, meth_name) @@ -163,7 +163,7 @@ def generate_forwards(cls, attrs: dict[str, object]) -> None: def generate_wraps(cls, attrs: dict[str, object]) -> None: # generate wrappers for functions of _wraps - wrapper: classmethod | Callable + wrapper: classmethod | Callable[..., object] # type: ignore[type-arg] for attr_name, attr in cls._wraps.__dict__.items(): # .z. exclude cls._wrap_iter if attr_name.startswith("_") or attr_name in attrs: @@ -188,7 +188,7 @@ def generate_magic(cls, attrs: dict[str, object]) -> None: def generate_iter(cls, attrs: dict[str, object]) -> None: # generate wrappers for methods that return iterators - wrapper: Callable + wrapper: Callable[..., object] for attr_name, attr in cls._wraps.__dict__.items(): if attr_name in cls._wrap_iter: wrapper = iter_wrapper_factory(cls, attr_name) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 9d2f5b3a55..c5e9c4dc66 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -14,7 +14,7 @@ "withUnknownType": 22 }, "ignoreUnknownTypesFromImports": true, - "missingClassDocStringCount": 0, + "missingClassDocStringCount": 1, "missingDefaultParamCount": 0, "missingFunctionDocStringCount": 4, "moduleName": "trio", @@ -105,7 +105,8 @@ "trio.testing.memory_stream_pair", "trio.testing.memory_stream_pump", "trio.testing.open_stream_to_socket_listener", - "trio.testing.trio_test" + "trio.testing.trio_test", + "trio.tests.TestsDeprecationWrapper" ] } } diff --git a/trio/tests.py b/trio/tests.py index 1c5f039f0f..4ffb583a3a 100644 --- a/trio/tests.py +++ b/trio/tests.py @@ -16,8 +16,6 @@ # This won't give deprecation warning on import, but will give a warning on use of any # attribute in tests, and static analysis tools will also not see any content inside. class TestsDeprecationWrapper: - """trio.tests is deprecated, use trio._tests""" - __name__ = "trio.tests" def __getattr__(self, attr: str) -> Any: From 47a7228b3bc97bbfd44708801d1b2db73176e597 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 17 Aug 2023 14:58:57 +0200 Subject: [PATCH 49/49] make CI run without -m --- check.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/check.sh b/check.sh index a0efa531b6..ace193a62a 100755 --- a/check.sh +++ b/check.sh @@ -27,9 +27,9 @@ fi flake8 trio/ || EXIT_STATUS=$? # Run mypy on all supported platforms -mypy -m trio -m trio.testing --platform linux || EXIT_STATUS=$? -mypy -m trio -m trio.testing --platform darwin || EXIT_STATUS=$? # tests FreeBSD too -mypy -m trio -m trio.testing --platform win32 || EXIT_STATUS=$? +mypy trio --platform linux || EXIT_STATUS=$? +mypy trio --platform darwin || EXIT_STATUS=$? # tests FreeBSD too +mypy trio --platform win32 || EXIT_STATUS=$? # Check pip compile is consistent pip-compile test-requirements.in