From 8dfdd56e88ff4578a4a2813fc0216e9fca907622 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 11 Jul 2023 15:07:42 +0200
Subject: [PATCH 01/49] various small type fixes, disallow_incomplete_defs =
true, except for trio._core._run
---
pyproject.toml | 4 ++++
trio/_core/_io_epoll.py | 12 +++++-----
trio/_core/_mock_clock.py | 2 +-
trio/_core/_run.py | 16 +++++++-------
trio/_core/_thread_cache.py | 18 ++++++++++-----
trio/_dtls.py | 41 ++++++++++++++++++++---------------
trio/_socket.py | 5 +++--
trio/_subprocess.py | 27 ++++++++++++++++++-----
trio/_sync.py | 12 +++++-----
trio/_tests/verify_types.json | 17 +++++----------
trio/_threads.py | 24 +++++++++++++++-----
11 files changed, 111 insertions(+), 67 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index cfb4060ee7..954e21e2d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,10 @@ disallow_untyped_defs = false
# DO NOT use `ignore_errors`; it doesn't apply
# downstream and users have to deal with them.
+[[tool.mypy.overrides]]
+disallow_incomplete_defs = false
+module = "trio._core._run"
+
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
faulthandler_timeout = 60
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 376dd18a4e..fbeb454c7d 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
import select
import sys
from collections import defaultdict
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, DefaultDict, Dict
import attr
@@ -187,13 +189,13 @@ class EpollWaiters:
@attr.s(slots=True, eq=False, hash=False)
class EpollIOManager:
- _epoll = attr.ib(factory=select.epoll)
+ _epoll: select.epoll = attr.ib(factory=select.epoll)
# {fd: EpollWaiters}
- _registered = attr.ib(
+ _registered: DefaultDict[int, EpollWaiters] = attr.ib(
factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters]
)
- _force_wakeup = attr.ib(factory=WakeupSocketpair)
- _force_wakeup_fd = attr.ib(default=None)
+ _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+ _force_wakeup_fd: int | None = attr.ib(default=None)
def __attrs_post_init__(self):
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py
index fe35298631..27a5829076 100644
--- a/trio/_core/_mock_clock.py
+++ b/trio/_core/_mock_clock.py
@@ -150,7 +150,7 @@ def deadline_to_sleep_time(self, deadline: float) -> float:
else:
return 999999999
- def jump(self, seconds) -> None:
+ def jump(self, seconds: float) -> None:
"""Manually advance the clock by the given number of seconds.
Args:
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 585dc4aa41..723370afd8 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -1168,7 +1168,7 @@ def __del__(self) -> None:
class Task(metaclass=NoPublicConstructor):
_parent_nursery: Nursery | None = attr.ib()
coro: Coroutine[Any, Outcome[object], Any] = attr.ib()
- _runner = attr.ib()
+ _runner: Runner = attr.ib()
name: str = attr.ib()
context: contextvars.Context = attr.ib()
_counter: int = attr.ib(init=False, factory=itertools.count().__next__)
@@ -1184,8 +1184,8 @@ class Task(metaclass=NoPublicConstructor):
# tracebacks with extraneous frames.
# - for scheduled tasks, custom_sleep_data is None
# Tasks start out unscheduled.
- _next_send_fn = attr.ib(default=None)
- _next_send = attr.ib(default=None)
+ _next_send_fn: Callable[[Outcome | None], None] = attr.ib(default=None)
+ _next_send: Outcome | None = attr.ib(default=None)
_abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib(
default=None
)
@@ -1386,13 +1386,13 @@ class _RunStatistics:
# worker thread.
@attr.s(eq=False, hash=False, slots=True)
class GuestState:
- runner = attr.ib()
- run_sync_soon_threadsafe = attr.ib()
- run_sync_soon_not_threadsafe = attr.ib()
- done_callback = attr.ib()
+ runner: Runner = attr.ib()
+ run_sync_soon_threadsafe: Callable = attr.ib()
+ run_sync_soon_not_threadsafe: Callable = attr.ib()
+ done_callback: Callable = attr.ib()
unrolled_run_gen = attr.ib()
_value_factory: Callable[[], Value] = lambda: Value(None)
- unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome)
+ unrolled_run_next_send: Outcome = attr.ib(factory=_value_factory, type=Outcome)
def guest_tick(self):
try:
diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py
index cc272fc92c..3e27ce6a32 100644
--- a/trio/_core/_thread_cache.py
+++ b/trio/_core/_thread_cache.py
@@ -18,7 +18,9 @@ def _to_os_thread_name(name: str) -> bytes:
# used to construct the method used to set os thread name, or None, depending on platform.
# called once on import
def get_os_thread_name_func() -> Optional[Callable[[Optional[int], str], None]]:
- def namefunc(setname: Callable[[int, bytes], int], ident: Optional[int], name: str):
+ def namefunc(
+ setname: Callable[[int, bytes], int], ident: Optional[int], name: str
+ ) -> None:
# Thread.ident is None "if it has not been started". Unclear if that can happen
# with current usage.
if ident is not None: # pragma: no cover
@@ -28,7 +30,7 @@ def namefunc(setname: Callable[[int, bytes], int], ident: Optional[int], name: s
# so the caller don't need to care about platform.
def darwin_namefunc(
setname: Callable[[bytes], int], ident: Optional[int], name: str
- ):
+ ) -> None:
# I don't know if Mac can rename threads that hasn't been started, but default
# to no to be on the safe side.
if ident is not None: # pragma: no cover
@@ -111,7 +113,9 @@ def darwin_namefunc(
class WorkerThread:
- def __init__(self, thread_cache):
+ def __init__(self, thread_cache: ThreadCache):
+ # deliver (the second value) can probably be Callable[[outcome.Value], None] ?
+ # should generate stubs for outcome
self._job: Optional[Tuple[Callable, Callable, str]] = None
self._thread_cache = thread_cache
# This Lock is used in an unconventional way.
@@ -188,7 +192,9 @@ class ThreadCache:
def __init__(self):
self._idle_workers = {}
- def start_thread_soon(self, fn, deliver, name: Optional[str] = None):
+ def start_thread_soon(
+ self, fn: Callable, deliver: Callable, name: Optional[str] = None
+ ) -> None:
try:
worker, _ = self._idle_workers.popitem()
except KeyError:
@@ -200,7 +206,9 @@ def start_thread_soon(self, fn, deliver, name: Optional[str] = None):
THREAD_CACHE = ThreadCache()
-def start_thread_soon(fn, deliver, name: Optional[str] = None):
+def start_thread_soon(
+ fn: Callable, deliver: Callable, name: Optional[str] = None
+) -> None:
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
Generally ``fn`` does some blocking work, and ``deliver`` delivers the
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 722a9499f8..a1551fda0e 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -16,7 +16,7 @@
import warnings
import weakref
from itertools import count
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Awaitable, Callable
import attr
@@ -26,6 +26,11 @@
if TYPE_CHECKING:
from types import TracebackType
+ from OpenSSL.SSL import Context
+ from typing_extensions import Self
+
+ from trio._socket import _SocketType
+
MAX_UDP_PACKET_SIZE = 65527
@@ -1126,17 +1131,17 @@ class DTLSEndpoint(metaclass=Final):
"""
- def __init__(self, socket, *, incoming_packets_buffer=10):
+ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
# We do this lazily on first construction, so only people who actually use DTLS
# have to install PyOpenSSL.
global SSL
from OpenSSL import SSL
- # TODO: create a `self._initialized` for `__del__`, so self.socket can be typed
- # as trio.socket.SocketType and `is not None` checks can be removed.
- self.socket = None # for __del__, in case the next line raises
+ # for __del__, in case the next line raises
+ self._initialized: bool = False
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
+ self._initialized = True
self.socket = socket
self.incoming_packets_buffer = incoming_packets_buffer
@@ -1146,8 +1151,8 @@ def __init__(self, socket, *, incoming_packets_buffer=10):
# as a peer provides a valid cookie, we can immediately tear down the
# old connection.
# {remote address: DTLSChannel}
- self._streams = weakref.WeakValueDictionary()
- self._listening_context = None
+ self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+ self._listening_context: Context | None = None
self._listening_key = None
self._incoming_connections_q = _Queue(float("inf"))
self._send_lock = trio.Lock()
@@ -1164,9 +1169,9 @@ def _ensure_receive_loop(self):
)
self._receive_loop_spawned = True
- def __del__(self):
+ def __del__(self) -> None:
# Do nothing if this object was never fully constructed
- if self.socket is None:
+ if not self._initialized:
return
# Close the socket in Trio context (if our Trio context still exists), so that
# the background task gets notified about the closure and can exit.
@@ -1186,17 +1191,13 @@ def close(self) -> None:
This object can also be used as a context manager.
"""
- # Do nothing if this object was never fully constructed
- if self.socket is None: # pragma: no cover
- return
-
self._closed = True
self.socket.close()
for stream in list(self._streams.values()):
stream.close()
self._incoming_connections_q.s.close()
- def __enter__(self):
+ def __enter__(self) -> Self:
return self
def __exit__(
@@ -1207,13 +1208,17 @@ def __exit__(
) -> None:
return self.close()
- def _check_closed(self):
+ def _check_closed(self) -> None:
if self._closed:
raise trio.ClosedResourceError
- async def serve(
- self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED
- ):
+ async def serve( # type: ignore[no-untyped-def]
+ self,
+ ssl_context: Context,
+ async_fn: Callable[..., Awaitable],
+ *args,
+ task_status=trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ???
+ ) -> None:
"""Listen for incoming connections, and spawn a handler for each using an
internal nursery.
diff --git a/trio/_socket.py b/trio/_socket.py
index eaf0e04d15..e1a8c5562a 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -434,6 +434,8 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo
return normed
+# TODO: stopping users from initializing this type should be done in a different way,
+# so SocketType can be used as a type.
class SocketType:
def __init__(self):
raise TypeError(
@@ -537,8 +539,7 @@ async def bind(self, address: tuple[object, ...] | str | bytes) -> None:
):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
- # remove the `type: ignore` when run.sync is typed.
- return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return]
+ return await trio.to_thread.run_sync(self._sock.bind, address)
else:
# POSIX actually says that bind can return EWOULDBLOCK and
# complete asynchronously, like connect. But in practice AFAICT
diff --git a/trio/_subprocess.py b/trio/_subprocess.py
index 1f8d0a8253..ab592d0e9a 100644
--- a/trio/_subprocess.py
+++ b/trio/_subprocess.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import subprocess
import sys
@@ -117,11 +119,17 @@ class Process(AsyncResource, metaclass=NoPublicConstructor):
# arbitrarily many threads if wait() keeps getting cancelled.
_wait_for_exit_data = None
- def __init__(self, popen, stdin, stdout, stderr):
+ def __init__(
+ self,
+ popen: subprocess.Popen,
+ stdin: Optional[SendStream],
+ stdout: Optional[ReceiveStream],
+ stderr: Optional[ReceiveStream],
+ ):
self._proc = popen
- self.stdin: Optional[SendStream] = stdin
- self.stdout: Optional[ReceiveStream] = stdout
- self.stderr: Optional[ReceiveStream] = stderr
+ self.stdin = stdin
+ self.stdout = stdout
+ self.stderr = stderr
self.stdio: Optional[StapledStream] = None
if self.stdin is not None and self.stdout is not None:
@@ -294,8 +302,17 @@ def kill(self):
self._proc.kill()
+from typing import Any
+
+
+# TODO: replace Any with a ParamSpec from Popen?? Or just type them out
async def open_process(
- command, *, stdin=None, stdout=None, stderr=None, **options
+ command: list[str] | str,
+ *,
+ stdin: int | None = None,
+ stdout: int | None = None,
+ stderr: int | None = None,
+ **options: Any,
) -> Process:
r"""Execute a child program in a new process.
diff --git a/trio/_sync.py b/trio/_sync.py
index 5a7f240d5e..0f05dd458c 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -143,7 +143,7 @@ class CapacityLimiterStatistics:
borrowed_tokens: int = attr.ib()
total_tokens: int | float = attr.ib()
- borrowers: list[Task] = attr.ib()
+ borrowers: list[object] = attr.ib()
tasks_waiting: int = attr.ib()
@@ -204,9 +204,9 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float):
self._lot = ParkingLot()
- self._borrowers: set[Task] = set()
+ self._borrowers: set[object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
- self._pending_borrowers: dict[Task, Task] = {}
+ self._pending_borrowers: dict[Task, object] = {}
# invoke the property setter for validation
self.total_tokens: int | float = total_tokens
assert self._total_tokens == total_tokens
@@ -268,7 +268,7 @@ def acquire_nowait(self) -> None:
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
@enable_ki_protection
- def acquire_on_behalf_of_nowait(self, borrower: Task) -> None:
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, without
blocking.
@@ -307,7 +307,7 @@ async def acquire(self) -> None:
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- async def acquire_on_behalf_of(self, borrower: Task) -> None:
+ async def acquire_on_behalf_of(self, borrower: object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
necessary.
@@ -347,7 +347,7 @@ def release(self) -> None:
self.release_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- def release_on_behalf_of(self, borrower: Task) -> None:
+ def release_on_behalf_of(self, borrower: object) -> None:
"""Put a token back into the sack on behalf of ``borrower``.
Raises:
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 9d7d7aa912..147fa6253a 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8764044943820225,
+ "completenessScore": 0.8812199036918138,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 546,
- "withUnknownType": 76
+ "withKnownType": 549,
+ "withUnknownType": 73
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -46,8 +46,8 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 433,
- "withUnknownType": 135
+ "withKnownType": 438,
+ "withUnknownType": 130
},
"packageName": "trio",
"symbols": [
@@ -70,7 +70,6 @@
"trio._core._local.RunVar.get",
"trio._core._local.RunVar.reset",
"trio._core._local.RunVar.set",
- "trio._core._mock_clock.MockClock.jump",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
"trio._core._unbounded_queue.UnboundedQueue.__aiter__",
@@ -90,13 +89,9 @@
"trio._dtls.DTLSChannel.send",
"trio._dtls.DTLSChannel.set_ciphertext_mtu",
"trio._dtls.DTLSChannel.statistics",
- "trio._dtls.DTLSEndpoint.__del__",
- "trio._dtls.DTLSEndpoint.__enter__",
"trio._dtls.DTLSEndpoint.__init__",
"trio._dtls.DTLSEndpoint.connect",
- "trio._dtls.DTLSEndpoint.incoming_packets_buffer",
"trio._dtls.DTLSEndpoint.serve",
- "trio._dtls.DTLSEndpoint.socket",
"trio._highlevel_socket.SocketListener",
"trio._highlevel_socket.SocketListener.__init__",
"trio._highlevel_socket.SocketStream.__init__",
@@ -168,14 +163,12 @@
"trio.lowlevel.current_trio_token",
"trio.lowlevel.currently_ki_protected",
"trio.lowlevel.notify_closing",
- "trio.lowlevel.open_process",
"trio.lowlevel.permanently_detach_coroutine_object",
"trio.lowlevel.reattach_detached_coroutine_object",
"trio.lowlevel.remove_instrument",
"trio.lowlevel.reschedule",
"trio.lowlevel.spawn_system_task",
"trio.lowlevel.start_guest_run",
- "trio.lowlevel.start_thread_soon",
"trio.lowlevel.temporarily_detach_coroutine_object",
"trio.lowlevel.wait_readable",
"trio.lowlevel.wait_writable",
diff --git a/trio/_threads.py b/trio/_threads.py
index 807212e0f9..52c742d588 100644
--- a/trio/_threads.py
+++ b/trio/_threads.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import contextvars
import functools
import inspect
@@ -57,10 +59,19 @@ class ThreadPlaceholder:
name = attr.ib()
+from typing import Any, Callable, TypeVar
+
+T = TypeVar("T")
+
+
@enable_ki_protection
async def to_thread_run_sync(
- sync_fn, *args, thread_name: Optional[str] = None, cancellable=False, limiter=None
-):
+ sync_fn: Callable[..., T],
+ *args: Any,
+ thread_name: Optional[str] = None,
+ cancellable: bool = False,
+ limiter: CapacityLimiter | None = None,
+) -> T:
"""Convert a blocking operation into an async operation using a thread.
These two lines are equivalent::
@@ -152,7 +163,7 @@ async def to_thread_run_sync(
# Holds a reference to the task that's blocked in this function waiting
# for the result – or None if this function was cancelled and we should
# discard the result.
- task_register = [trio.lowlevel.current_task()]
+ task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()]
name = f"trio.to_thread.run_sync-{next(_thread_counter)}"
placeholder = ThreadPlaceholder(name)
@@ -217,14 +228,17 @@ def deliver_worker_fn_result(result):
limiter.release_on_behalf_of(placeholder)
raise
- def abort(_):
+ from trio._core._traps import RaiseCancelT
+
+ def abort(_: RaiseCancelT) -> trio.lowlevel.Abort:
if cancellable:
task_register[0] = None
return trio.lowlevel.Abort.SUCCEEDED
else:
return trio.lowlevel.Abort.FAILED
- return await trio.lowlevel.wait_task_rescheduled(abort)
+ # wait_task_rescheduled return value cannot be typed
+ return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[no-any-return]
def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None):
From 46ac4e852bce69bcdff2cacff2932626d45d5110 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 12 Jul 2023 13:33:35 +0200
Subject: [PATCH 02/49] stuff
---
pyproject.toml | 2 +-
trio/__init__.py | 5 +-
trio/_abc.py | 54 +++++++++++++-----
trio/_core/_local.py | 46 ++++++++-------
trio/_core/_run.py | 4 +-
trio/_core/_unbounded_queue.py | 51 +++++++++++------
trio/_deprecate.py | 13 ++++-
trio/_dtls.py | 100 ++++++++++++++++++++-------------
trio/_socket.py | 65 +++++++++++++++------
trio/_tests/verify_types.json | 59 ++-----------------
trio/_threads.py | 2 +-
trio/_util.py | 2 +-
trio/testing/_fake_net.py | 23 +++++++-
13 files changed, 255 insertions(+), 171 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 954e21e2d3..121a398234 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,8 +45,8 @@ disallow_untyped_defs = false
# downstream and users have to deal with them.
[[tool.mypy.overrides]]
-disallow_incomplete_defs = false
module = "trio._core._run"
+disallow_incomplete_defs = false
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
diff --git a/trio/__init__.py b/trio/__init__.py
index 2b8810504b..35dc3e133f 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
"""Trio - A friendly Python library for async concurrency and I/O
"""
@@ -15,6 +17,7 @@
#
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)
+
# must be imported early to avoid circular import
from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: skip
@@ -112,7 +115,7 @@
_deprecate.enable_attribute_deprecations(__name__)
-__deprecated_attributes__ = {
+__deprecated_attributes__: dict[str, _deprecate.DeprecatedAttribute] = {
"open_process": _deprecate.DeprecatedAttribute(
value=lowlevel.open_process,
version="0.20.0",
diff --git a/trio/_abc.py b/trio/_abc.py
index 2a1721db13..0bb49e207d 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -6,10 +6,15 @@
import trio
if TYPE_CHECKING:
+ import socket
from types import TracebackType
from typing_extensions import Self
+ from trio.lowlevel import Task
+
+ from ._socket import _SocketType
+
# We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a
# __dict__ onto subclasses.
@@ -73,13 +78,13 @@ class Instrument(metaclass=ABCMeta):
__slots__ = ()
- def before_run(self):
+ def before_run(self) -> None:
"""Called at the beginning of :func:`trio.run`."""
- def after_run(self):
+ def after_run(self) -> None:
"""Called just before :func:`trio.run` returns."""
- def task_spawned(self, task):
+ def task_spawned(self, task: Task) -> None:
"""Called when the given task is created.
Args:
@@ -87,7 +92,7 @@ def task_spawned(self, task):
"""
- def task_scheduled(self, task):
+ def task_scheduled(self, task: Task) -> None:
"""Called when the given task becomes runnable.
It may still be some time before it actually runs, if there are other
@@ -98,7 +103,7 @@ def task_scheduled(self, task):
"""
- def before_task_step(self, task):
+ def before_task_step(self, task: Task) -> None:
"""Called immediately before we resume running the given task.
Args:
@@ -106,7 +111,7 @@ def before_task_step(self, task):
"""
- def after_task_step(self, task):
+ def after_task_step(self, task: Task) -> None:
"""Called when we return to the main run loop after a task has yielded.
Args:
@@ -114,7 +119,7 @@ def after_task_step(self, task):
"""
- def task_exited(self, task):
+ def task_exited(self, task: Task) -> None:
"""Called when the given task exits.
Args:
@@ -122,7 +127,7 @@ def task_exited(self, task):
"""
- def before_io_wait(self, timeout):
+ def before_io_wait(self, timeout: float) -> None:
"""Called before blocking to wait for I/O readiness.
Args:
@@ -130,7 +135,7 @@ def before_io_wait(self, timeout):
"""
- def after_io_wait(self, timeout):
+ def after_io_wait(self, timeout: float) -> None:
"""Called after handling pending I/O.
Args:
@@ -152,7 +157,23 @@ class HostnameResolver(metaclass=ABCMeta):
__slots__ = ()
@abstractmethod
- async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0):
+ async def getaddrinfo(
+ self,
+ host: bytes | str | None,
+ port: bytes | str | int | None,
+ family: int = 0,
+ type: int = 0,
+ proto: int = 0,
+ flags: int = 0,
+ ) -> list[
+ tuple[
+ socket.AddressFamily,
+ socket.SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+ ]:
"""A custom implementation of :func:`~trio.socket.getaddrinfo`.
Called by :func:`trio.socket.getaddrinfo`.
@@ -169,7 +190,9 @@ async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0):
"""
@abstractmethod
- async def getnameinfo(self, sockaddr, flags):
+ async def getnameinfo(
+ self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
+ ) -> tuple[str, str]:
"""A custom implementation of :func:`~trio.socket.getnameinfo`.
Called by :func:`trio.socket.getnameinfo`.
@@ -186,7 +209,12 @@ class SocketFactory(metaclass=ABCMeta):
"""
@abstractmethod
- def socket(self, family=None, type=None, proto=None):
+ def socket(
+ self,
+ family: socket.AddressFamily | int = socket.AF_INET,
+ type: socket.SocketKind | int = socket.SOCK_STREAM,
+ proto: int = 0,
+ ) -> _SocketType:
"""Create and return a socket object.
Your socket object must inherit from :class:`trio.socket.SocketType`,
@@ -537,7 +565,7 @@ class Listener(AsyncResource, Generic[T_resource]):
__slots__ = ()
@abstractmethod
- async def accept(self):
+ async def accept(self) -> AsyncResource:
"""Wait until an incoming connection arrives, and then return it.
Returns:
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index a54f424fdf..89ccf93e95 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -1,25 +1,32 @@
+from __future__ import annotations
+
+from typing import Generic, TypeVar
+
# Runvar implementations
import attr
from .._util import Final
from . import _run
+T = TypeVar("T")
+C = TypeVar("C", bound="_RunVarToken")
+
@attr.s(eq=False, hash=False, slots=True)
-class _RunVarToken:
- _no_value = object()
+class _RunVarToken(Generic[T]):
+ _no_value = None
- _var = attr.ib()
- previous_value = attr.ib(default=_no_value)
- redeemed = attr.ib(default=False, init=False)
+ _var: RunVar[T] = attr.ib()
+ previous_value: T | None = attr.ib(default=_no_value)
+ redeemed: bool = attr.ib(default=False, init=False)
@classmethod
- def empty(cls, var):
+ def empty(cls: type[C], var: RunVar[T]) -> C:
return cls(var)
@attr.s(eq=False, hash=False, slots=True)
-class RunVar(metaclass=Final):
+class RunVar(Generic[T], metaclass=Final):
"""The run-local variant of a context variable.
:class:`RunVar` objects are similar to context variable objects,
@@ -28,14 +35,15 @@ class RunVar(metaclass=Final):
"""
- _NO_DEFAULT = object()
- _name = attr.ib()
- _default = attr.ib(default=_NO_DEFAULT)
+ _NO_DEFAULT = None
+ _name: str = attr.ib()
+ _default: T | None = attr.ib(default=_NO_DEFAULT)
- def get(self, default=_NO_DEFAULT):
+ def get(self, default: T | None = _NO_DEFAULT) -> T | None:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
- return _run.GLOBAL_RUN_CONTEXT.runner._locals[self]
+ # not typed yet
+ return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
@@ -48,7 +56,7 @@ def get(self, default=_NO_DEFAULT):
raise LookupError(self) from None
- def set(self, value):
+ def set(self, value: T) -> _RunVarToken[T]:
"""Sets the value of this :class:`RunVar` for this current run
call.
@@ -56,16 +64,16 @@ def set(self, value):
try:
old_value = self.get()
except LookupError:
- token = _RunVarToken.empty(self)
+ token: _RunVarToken[T] = _RunVarToken.empty(self)
else:
token = _RunVarToken(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index]
return token
- def reset(self, token):
+ def reset(self, token: _RunVarToken[T]) -> None:
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
@@ -82,13 +90,13 @@ def reset(self, token):
previous = token.previous_value
try:
if previous is _RunVarToken._no_value:
- _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
+ _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type]
else:
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context")
token.redeemed = True
- def __repr__(self):
+ def __repr__(self) -> str:
        return f"<RunVar name={self._name!r}>"
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 723370afd8..804a958714 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -789,10 +789,10 @@ class _TaskStatus:
_called_started = attr.ib(default=False)
_value = attr.ib(default=None)
- def __repr__(self):
+ def __repr__(self) -> str:
        return f"<Task status object at {id(self):#x}>"
- def started(self, value=None):
+ def started(self, value: Any = None) -> None:
if self._called_started:
raise RuntimeError("called 'started' twice on the same task status")
self._called_started = True
diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py
index 9c747749b4..cbcf10cf89 100644
--- a/trio/_core/_unbounded_queue.py
+++ b/trio/_core/_unbounded_queue.py
@@ -1,17 +1,34 @@
+from __future__ import annotations
+
+from typing import Generic, TypeVar
+
import attr
+from typing_extensions import Self
from .. import _core
from .._deprecate import deprecated
from .._util import Final
+T = TypeVar("T")
+
@attr.s(frozen=True)
-class _UnboundedQueueStats:
- qsize = attr.ib()
- tasks_waiting = attr.ib()
+class UnboundedQueueStats:
+ """An object containing debugging information.
+
+ Currently the following fields are defined:
+
+ * ``qsize``: The number of items currently in the queue.
+ * ``tasks_waiting``: The number of tasks blocked on this queue's
+ :meth:`get_batch` method.
+
+ """
+
+ qsize: int = attr.ib()
+ tasks_waiting: int = attr.ib()
-class UnboundedQueue(metaclass=Final):
+class UnboundedQueue(Generic[T], metaclass=Final):
"""An unbounded queue suitable for certain unusual forms of inter-task
communication.
@@ -47,20 +64,20 @@ class UnboundedQueue(metaclass=Final):
thing="trio.lowlevel.UnboundedQueue",
instead="trio.open_memory_channel(math.inf)",
)
- def __init__(self):
+ def __init__(self) -> None:
self._lot = _core.ParkingLot()
- self._data = []
+ self._data: list[T] = []
# used to allow handoff from put to the first task in the lot
self._can_get = False
- def __repr__(self):
+ def __repr__(self) -> str:
        return f"<UnboundedQueue holding {len(self._data)} items>"
- def qsize(self):
+ def qsize(self) -> int:
"""Returns the number of items currently in the queue."""
return len(self._data)
- def empty(self):
+ def empty(self) -> bool:
"""Returns True if the queue is empty, False otherwise.
There is some subtlety to interpreting this method's return value: see
@@ -70,7 +87,7 @@ def empty(self):
return not self._data
@_core.enable_ki_protection
- def put_nowait(self, obj):
+ def put_nowait(self, obj: T) -> None:
"""Put an object into the queue, without blocking.
This always succeeds, because the queue is unbounded. We don't provide
@@ -88,13 +105,13 @@ def put_nowait(self, obj):
self._can_get = True
self._data.append(obj)
- def _get_batch_protected(self):
+ def _get_batch_protected(self) -> list[T]:
data = self._data.copy()
self._data.clear()
self._can_get = False
return data
- def get_batch_nowait(self):
+ def get_batch_nowait(self) -> list[T]:
"""Attempt to get the next batch from the queue, without blocking.
Returns:
@@ -110,7 +127,7 @@ def get_batch_nowait(self):
raise _core.WouldBlock
return self._get_batch_protected()
- async def get_batch(self):
+ async def get_batch(self) -> list[T]:
"""Get the next batch from the queue, blocking as necessary.
Returns:
@@ -128,7 +145,7 @@ async def get_batch(self):
finally:
await _core.cancel_shielded_checkpoint()
- def statistics(self):
+ def statistics(self) -> UnboundedQueueStats:
"""Return an object containing debugging information.
Currently the following fields are defined:
@@ -138,12 +155,12 @@ def statistics(self):
:meth:`get_batch` method.
"""
- return _UnboundedQueueStats(
+ return UnboundedQueueStats(
qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting
)
- def __aiter__(self):
+ def __aiter__(self) -> Self:
return self
- async def __anext__(self):
+ async def __anext__(self) -> list[T]:
return await self.get_batch()
diff --git a/trio/_deprecate.py b/trio/_deprecate.py
index fe00192583..aeebe80722 100644
--- a/trio/_deprecate.py
+++ b/trio/_deprecate.py
@@ -1,10 +1,15 @@
+from __future__ import annotations
+
import sys
import warnings
from functools import wraps
from types import ModuleType
+from typing import Callable, TypeVar
import attr
+T = TypeVar("T")
+
# We want our warnings to be visible by default (at least for now), but we
# also want it to be possible to override that using the -W switch. AFAICT
@@ -53,7 +58,9 @@ def warn_deprecated(thing, version, *, issue, instead, stacklevel=2):
# @deprecated("0.2.0", issue=..., instead=...)
# def ...
-def deprecated(version, *, thing=None, issue, instead):
+def deprecated(
+ version: str, *, thing: str | None = None, issue: int, instead: str
+) -> Callable[[T], T]:
def do_wrap(fn):
nonlocal thing
@@ -124,10 +131,10 @@ def __getattr__(self, name):
raise AttributeError(msg.format(self.__name__, name))
-def enable_attribute_deprecations(module_name):
+def enable_attribute_deprecations(module_name: str) -> None:
module = sys.modules[module_name]
module.__class__ = _ModuleWithDeprecations
# Make sure that this is always defined so that
# _ModuleWithDeprecations.__getattr__ can access it without jumping
# through hoops or risking infinite recursion.
- module.__deprecated_attributes__ = {}
+ module.__deprecated_attributes__ = {} # type: ignore[attr-defined]
diff --git a/trio/_dtls.py b/trio/_dtls.py
index a1551fda0e..aea15be735 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -16,9 +16,10 @@
import warnings
import weakref
from itertools import count
-from typing import TYPE_CHECKING, Awaitable, Callable
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Iterator, cast
import attr
+from OpenSSL import SSL
import trio
from trio._util import Final, NoPublicConstructor
@@ -31,24 +32,26 @@
from trio._socket import _SocketType
+ from ._core._run import _TaskStatus
+
MAX_UDP_PACKET_SIZE = 65527
-def packet_header_overhead(sock):
+def packet_header_overhead(sock: _SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 28
else:
return 48
-def worst_case_mtu(sock):
+def worst_case_mtu(sock: _SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 576 - packet_header_overhead(sock)
else:
return 1280 - packet_header_overhead(sock)
-def best_guess_mtu(sock):
+def best_guess_mtu(sock: _SocketType) -> int:
return 1500 - packet_header_overhead(sock)
@@ -110,14 +113,14 @@ class BadPacket(Exception):
# ChangeCipherSpec is used during the handshake but has its own ContentType.
#
# Cannot fail.
-def part_of_handshake_untrusted(packet):
+def part_of_handshake_untrusted(packet: bytes) -> bool:
# If the packet is too short, then slicing will successfully return a
# short string, which will necessarily fail to match.
return packet[3:5] == b"\x00\x00"
# Cannot fail
-def is_client_hello_untrusted(packet):
+def is_client_hello_untrusted(packet: bytes) -> bool:
try:
return (
packet[0] == ContentType.handshake
@@ -152,7 +155,7 @@ class Record:
payload: bytes = attr.ib(repr=to_hex)
-def records_untrusted(packet):
+def records_untrusted(packet: bytes) -> Iterator[Record]:
i = 0
while i < len(packet):
try:
@@ -170,7 +173,7 @@ def records_untrusted(packet):
yield Record(ct, version, epoch_seqno, payload)
-def encode_record(record):
+def encode_record(record: Record) -> bytes:
header = RECORD_HEADER.pack(
record.content_type,
record.version,
@@ -199,7 +202,7 @@ class HandshakeFragment:
frag: bytes = attr.ib(repr=to_hex)
-def decode_handshake_fragment_untrusted(payload):
+def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment:
# Raises BadPacket if decoding fails
try:
(
@@ -229,7 +232,7 @@ def decode_handshake_fragment_untrusted(payload):
)
-def encode_handshake_fragment(hsf):
+def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes:
hs_header = HANDSHAKE_MESSAGE_HEADER.pack(
hsf.msg_type,
hsf.msg_len.to_bytes(3, "big"),
@@ -240,7 +243,7 @@ def encode_handshake_fragment(hsf):
return hs_header + hsf.frag
-def decode_client_hello_untrusted(packet):
+def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]:
# Raises BadPacket if parsing fails
# Returns (record epoch_seqno, cookie from the packet, data that should be
# hashed into cookie)
@@ -340,8 +343,12 @@ class OpaqueHandshakeMessage:
# reconstructs the handshake messages inside it, so that we can repack them
# into records while retransmitting. So the data ought to be well-behaved --
# it's not coming from the network.
-def decode_volley_trusted(volley):
- messages = []
+def decode_volley_trusted(
+ volley: bytes,
+) -> list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage]:
+ messages: list[
+ HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage
+ ] = []
messages_by_seq = {}
for record in records_untrusted(volley):
# ChangeCipherSpec isn't a handshake message, so it can't be fragmented.
@@ -388,10 +395,16 @@ class RecordEncoder:
def __init__(self):
self._record_seq = count()
- def set_first_record_number(self, n):
+ def set_first_record_number(self, n: int) -> None:
self._record_seq = count(n)
- def encode_volley(self, messages, mtu):
+ def encode_volley(
+ self,
+ messages: Iterable[
+ HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage
+ ],
+ mtu: int,
+ ) -> list[bytearray]:
packets = []
packet = bytearray()
for message in messages:
@@ -523,13 +536,13 @@ def encode_volley(self, messages, mtu):
COOKIE_LENGTH = 32
-def _current_cookie_tick():
+def _current_cookie_tick() -> int:
return int(trio.current_time() / COOKIE_REFRESH_INTERVAL)
# Simple deterministic and invertible serializer -- i.e., a useful tool for converting
# structured data into something we can cryptographically sign.
-def _signable(*fields):
+def _signable(*fields: bytes) -> bytes:
out = []
for field in fields:
out.append(struct.pack("!Q", len(field)))
@@ -618,7 +631,7 @@ def __init__(self, incoming_packets_buffer):
self.s, self.r = trio.open_memory_channel(incoming_packets_buffer)
-def _read_loop(read_fn):
+def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
chunks = []
while True:
try:
@@ -778,7 +791,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
"""
- def __init__(self, endpoint, peer_address, ctx):
+ def __init__(self, endpoint: DTLSEndpoint, peer_address: str, ctx: Context):
self.endpoint = endpoint
self.peer_address = peer_address
self._packets_dropped_in_trio = 0
@@ -789,9 +802,9 @@ def __init__(self, endpoint, peer_address, ctx):
# OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to
# support and isn't useful anyway -- especially for DTLS where it's equivalent
# to just performing a new handshake.
- ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION)
+ ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined]
self._ssl = SSL.Connection(ctx)
- self._handshake_mtu = None
+ self._handshake_mtu: int | None = None
# This calls self._ssl.set_ciphertext_mtu, which is important, because if you
# don't call it then openssl doesn't work.
self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket))
@@ -841,7 +854,7 @@ def close(self) -> None:
# ClosedResourceError
self._q.r.close()
- def __enter__(self):
+ def __enter__(self) -> Self:
return self
def __exit__(
@@ -852,7 +865,7 @@ def __exit__(
) -> None:
return self.close()
- async def aclose(self):
+ async def aclose(self) -> None:
"""Close this connection, but asynchronously.
This is included to satisfy the `trio.abc.Channel` contract. It's
@@ -873,7 +886,7 @@ async def _send_volley(self, volley_messages):
async def _resend_final_volley(self):
await self._send_volley(self._final_volley)
- async def do_handshake(self, *, initial_retransmit_timeout=1.0):
+ async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None:
"""Perform the handshake.
Calling this is optional – if you don't, then it will be automatically called
@@ -906,17 +919,23 @@ async def do_handshake(self, *, initial_retransmit_timeout=1.0):
return
timeout = initial_retransmit_timeout
- volley_messages = []
+ volley_messages: list[
+ HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage
+ ] = []
volley_failed_sends = 0
- def read_volley():
+ def read_volley() -> (
+ list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage]
+ ):
volley_bytes = _read_loop(self._ssl.bio_read)
new_volley_messages = decode_volley_trusted(volley_bytes)
if (
new_volley_messages
and volley_messages
and isinstance(new_volley_messages[0], HandshakeMessage)
- and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
+ # TODO: add isinstance or do a cast?
+ and new_volley_messages[0].msg_seq
+ == cast(HandshakeMessage, volley_messages[0]).msg_seq
):
# openssl decided to retransmit; discard because we handle
# retransmits ourselves
@@ -1000,10 +1019,13 @@ def read_volley():
# PMTU estimate is wrong? Let's try dropping it to the minimum
# and hope that helps.
self._handshake_mtu = min(
- self._handshake_mtu, worst_case_mtu(self.endpoint.socket)
+ self._handshake_mtu or 0,
+ worst_case_mtu(self.endpoint.socket),
)
- async def send(self, data):
+ async def send(
+ self, data: bytes
+ ) -> None: # or str? SendChannel defines it as bytes
"""Send a packet of data, securely."""
if self._closed:
@@ -1019,7 +1041,7 @@ async def send(self, data):
_read_loop(self._ssl.bio_read), self.peer_address
)
- async def receive(self):
+ async def receive(self) -> bytes: # or str?
"""Fetch the next packet of data from this connection's peer, waiting if
necessary.
@@ -1045,7 +1067,7 @@ async def receive(self):
if cleartext:
return cleartext
- def set_ciphertext_mtu(self, new_mtu):
+ def set_ciphertext_mtu(self, new_mtu: int) -> None:
"""Tells Trio the `largest amount of data that can be sent in a single packet to
this peer `__.
@@ -1080,7 +1102,7 @@ def set_ciphertext_mtu(self, new_mtu):
self._handshake_mtu = new_mtu
self._ssl.set_ciphertext_mtu(new_mtu)
- def get_cleartext_mtu(self):
+ def get_cleartext_mtu(self) -> int:
"""Returns the largest number of bytes that you can pass in a single call to
`send` while still fitting within the network-level MTU.
@@ -1089,9 +1111,9 @@ def get_cleartext_mtu(self):
"""
if not self._did_handshake:
raise trio.NeedHandshakeError
- return self._ssl.get_cleartext_mtu()
+ return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return]
- def statistics(self):
+ def statistics(self) -> DTLSChannelStatistics:
"""Returns an object with statistics about this connection.
Currently this has only one attribute:
@@ -1142,7 +1164,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
self._initialized = True
- self.socket = socket
+ self.socket: _SocketType = socket
self.incoming_packets_buffer = incoming_packets_buffer
self._token = trio.lowlevel.current_trio_token()
@@ -1212,12 +1234,12 @@ def _check_closed(self) -> None:
if self._closed:
raise trio.ClosedResourceError
- async def serve( # type: ignore[no-untyped-def]
+ async def serve(
self,
ssl_context: Context,
async_fn: Callable[..., Awaitable],
- *args,
- task_status=trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ???
+ *args: Any,
+ task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ???
) -> None:
"""Listen for incoming connections, and spawn a handler for each using an
internal nursery.
@@ -1272,7 +1294,7 @@ async def handler_wrapper(stream):
finally:
self._listening_context = None
- def connect(self, address, ssl_context):
+ def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel:
"""Initiate an outgoing DTLS connection.
Notice that this is a synchronous method. That's because it doesn't actually
diff --git a/trio/_socket.py b/trio/_socket.py
index e1a8c5562a..0f5aa75fd2 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -5,7 +5,7 @@
import socket as _stdlib_socket
import sys
from functools import wraps as _wraps
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, SupportsIndex
import idna as _idna
@@ -19,6 +19,8 @@
from typing_extensions import Self
+ from ._abc import HostnameResolver, SocketFactory
+
# Usage:
#
@@ -73,11 +75,13 @@ async def __aexit__(
# Overrides
################################################################
-_resolver = _core.RunVar("hostname_resolver")
-_socket_factory = _core.RunVar("socket_factory")
+_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver")
+_socket_factory: _core.RunVar[SocketFactory] = _core.RunVar("socket_factory")
-def set_custom_hostname_resolver(hostname_resolver):
+def set_custom_hostname_resolver(
+ hostname_resolver: HostnameResolver | None,
+) -> HostnameResolver | None:
"""Set a custom hostname resolver.
By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions
@@ -143,7 +147,22 @@ def set_custom_socket_factory(socket_factory):
_NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV
-async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
+async def getaddrinfo(
+ host: bytes | str | None,
+ port: bytes | str | int | None,
+ family: int = 0,
+ type: int = 0,
+ proto: int = 0,
+ flags: int = 0,
+) -> list[
+ tuple[
+ _stdlib_socket.AddressFamily,
+ _stdlib_socket.SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+]:
"""Look up a numeric address given a name.
Arguments and return values are identical to :func:`socket.getaddrinfo`,
@@ -190,7 +209,7 @@ def numeric_only_failure(exc):
# idna.encode will error out if the hostname has Capital Letters
# in it; with uts46=True it will lowercase them instead.
host = _idna.encode(host, uts46=True)
- hr = _resolver.get(None)
+ hr: HostnameResolver | None = _resolver.get(None)
if hr is not None:
return await hr.getaddrinfo(host, port, family, type, proto, flags)
else:
@@ -206,7 +225,9 @@ def numeric_only_failure(exc):
)
-async def getnameinfo(sockaddr, flags):
+async def getnameinfo(
+ sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
+) -> tuple[str, str]:
"""Look up a name given a numeric address.
Arguments and return values are identical to :func:`socket.getnameinfo`,
@@ -244,7 +265,7 @@ async def getprotobyname(name):
################################################################
-def from_stdlib_socket(sock):
+def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType:
"""Convert a standard library :class:`socket.socket` object into a Trio
socket object.
@@ -253,7 +274,12 @@ def from_stdlib_socket(sock):
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
-def fromfd(fd, family, type, proto=0):
+def fromfd(
+ fd: SupportsIndex,
+ family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET,
+ type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM,
+ proto: int = 0,
+) -> _SocketType:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd)
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -280,11 +306,11 @@ def socketpair(*args, **kwargs):
@_wraps(_stdlib_socket.socket, assigned=(), updated=())
def socket(
- family=_stdlib_socket.AF_INET,
- type=_stdlib_socket.SOCK_STREAM,
- proto=0,
- fileno=None,
-):
+ family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET,
+ type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM,
+ proto: int = 0,
+ fileno: int | None = None,
+) -> _SocketType:
"""Create a new Trio socket, like :class:`socket.socket`.
This function's behavior can be customized using
@@ -483,7 +509,7 @@ def __init__(self, sock: _stdlib_socket.socket):
"share",
}
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
if name in self._forward:
return getattr(self._sock, name)
raise AttributeError(name)
@@ -619,9 +645,11 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
# accept
################################################################
- _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable)
+ _accept: Callable[
+ [], Awaitable[tuple[_stdlib_socket.socket, object]]
+ ] = _make_simple_sock_method_wrapper("accept", _core.wait_readable)
- async def accept(self):
+ async def accept(self) -> tuple[_SocketType, object]:
"""Like :meth:`socket.socket.accept`, but async."""
sock, addr = await self._accept()
return from_stdlib_socket(sock), addr
@@ -630,7 +658,8 @@ async def accept(self):
# connect
################################################################
- async def connect(self, address):
+ # TODO: typing addresses is ... a pain
+ async def connect(self, address: str) -> None:
# nonblocking connect is weird -- you call it to start things
# off, then the socket becomes writable as a completion
# notification. This means it isn't really cancellable... we close the
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 147fa6253a..00a964787a 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8812199036918138,
+ "completenessScore": 0.9020866773675762,
"exportedSymbolCounts": {
- "withAmbiguousType": 1,
- "withKnownType": 549,
- "withUnknownType": 73
+ "withAmbiguousType": 0,
+ "withKnownType": 562,
+ "withUnknownType": 61
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -46,53 +46,17 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 438,
- "withUnknownType": 130
+ "withKnownType": 523,
+ "withUnknownType": 89
},
"packageName": "trio",
"symbols": [
- "trio.__deprecated_attributes__",
- "trio._abc.HostnameResolver.getaddrinfo",
- "trio._abc.HostnameResolver.getnameinfo",
- "trio._abc.Instrument.after_io_wait",
- "trio._abc.Instrument.after_run",
- "trio._abc.Instrument.after_task_step",
- "trio._abc.Instrument.before_io_wait",
- "trio._abc.Instrument.before_run",
- "trio._abc.Instrument.before_task_step",
- "trio._abc.Instrument.task_exited",
- "trio._abc.Instrument.task_scheduled",
- "trio._abc.Instrument.task_spawned",
- "trio._abc.Listener.accept",
"trio._abc.SocketFactory.socket",
"trio._core._entry_queue.TrioToken.run_sync_soon",
- "trio._core._local.RunVar.__repr__",
- "trio._core._local.RunVar.get",
- "trio._core._local.RunVar.reset",
- "trio._core._local.RunVar.set",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
- "trio._core._unbounded_queue.UnboundedQueue.__aiter__",
- "trio._core._unbounded_queue.UnboundedQueue.__anext__",
- "trio._core._unbounded_queue.UnboundedQueue.__repr__",
- "trio._core._unbounded_queue.UnboundedQueue.empty",
- "trio._core._unbounded_queue.UnboundedQueue.get_batch",
- "trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait",
- "trio._core._unbounded_queue.UnboundedQueue.qsize",
- "trio._core._unbounded_queue.UnboundedQueue.statistics",
- "trio._dtls.DTLSChannel.__enter__",
"trio._dtls.DTLSChannel.__init__",
- "trio._dtls.DTLSChannel.aclose",
- "trio._dtls.DTLSChannel.do_handshake",
- "trio._dtls.DTLSChannel.get_cleartext_mtu",
- "trio._dtls.DTLSChannel.receive",
- "trio._dtls.DTLSChannel.send",
- "trio._dtls.DTLSChannel.set_ciphertext_mtu",
- "trio._dtls.DTLSChannel.statistics",
"trio._dtls.DTLSEndpoint.__init__",
- "trio._dtls.DTLSEndpoint.connect",
- "trio._dtls.DTLSEndpoint.serve",
- "trio._highlevel_socket.SocketListener",
"trio._highlevel_socket.SocketListener.__init__",
"trio._highlevel_socket.SocketStream.__init__",
"trio._highlevel_socket.SocketStream.getsockopt",
@@ -112,9 +76,6 @@
"trio._path.Path.__rtruediv__",
"trio._path.Path.__truediv__",
"trio._path.Path.open",
- "trio._socket._SocketType.__getattr__",
- "trio._socket._SocketType.accept",
- "trio._socket._SocketType.connect",
"trio._socket._SocketType.recv_into",
"trio._socket._SocketType.recvfrom",
"trio._socket._SocketType.recvfrom_into",
@@ -123,7 +84,6 @@
"trio._socket._SocketType.send",
"trio._socket._SocketType.sendmsg",
"trio._socket._SocketType.sendto",
- "trio._ssl.SSLListener",
"trio._ssl.SSLListener.__init__",
"trio._ssl.SSLListener.accept",
"trio._ssl.SSLListener.aclose",
@@ -155,7 +115,6 @@
"trio.current_time",
"trio.from_thread.run",
"trio.from_thread.run_sync",
- "trio.lowlevel.add_instrument",
"trio.lowlevel.cancel_shielded_checkpoint",
"trio.lowlevel.current_clock",
"trio.lowlevel.current_root_task",
@@ -165,7 +124,6 @@
"trio.lowlevel.notify_closing",
"trio.lowlevel.permanently_detach_coroutine_object",
"trio.lowlevel.reattach_detached_coroutine_object",
- "trio.lowlevel.remove_instrument",
"trio.lowlevel.reschedule",
"trio.lowlevel.spawn_system_task",
"trio.lowlevel.start_guest_run",
@@ -184,13 +142,8 @@
"trio.serve_ssl_over_tcp",
"trio.serve_tcp",
"trio.socket.from_stdlib_socket",
- "trio.socket.fromfd",
- "trio.socket.getaddrinfo",
- "trio.socket.getnameinfo",
"trio.socket.getprotobyname",
- "trio.socket.set_custom_hostname_resolver",
"trio.socket.set_custom_socket_factory",
- "trio.socket.socket",
"trio.socket.socketpair",
"trio.testing._memory_streams.MemoryReceiveStream.__init__",
"trio.testing._memory_streams.MemoryReceiveStream.aclose",
diff --git a/trio/_threads.py b/trio/_threads.py
index 52c742d588..45a416249e 100644
--- a/trio/_threads.py
+++ b/trio/_threads.py
@@ -27,7 +27,7 @@
# Global due to Threading API, thread local storage for trio token
TOKEN_LOCAL = threading.local()
-_limiter_local = RunVar("limiter")
+_limiter_local: RunVar[CapacityLimiter] = RunVar("limiter")
# I pulled this number out of the air; it isn't based on anything. Probably we
# should make some kind of measurements to pick a good value.
DEFAULT_LIMIT = 40
diff --git a/trio/_util.py b/trio/_util.py
index 0a0795fc15..0f73ff19e9 100644
--- a/trio/_util.py
+++ b/trio/_util.py
@@ -216,7 +216,7 @@ def decorator(func):
return decorator
-def fixup_module_metadata(module_name, namespace):
+def fixup_module_metadata(module_name: str, namespace: dict[str, object]) -> None:
seen_ids = set()
def fix_one(qualname, name, obj):
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index b3bdfd85c0..fdb4d45102 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -19,6 +19,7 @@
from trio._util import Final, NoPublicConstructor
if TYPE_CHECKING:
+ import socket
from types import TracebackType
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@@ -113,11 +114,27 @@ class FakeHostnameResolver(trio.abc.HostnameResolver):
fake_net: "FakeNet"
async def getaddrinfo(
- self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0
- ):
+ self,
+ host: bytes | str | None,
+ port: bytes | str | int | None,
+ family: int = 0,
+ type: int = 0,
+ proto: int = 0,
+ flags: int = 0,
+ ) -> list[
+ tuple[
+ socket.AddressFamily,
+ socket.SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+ ]:
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
- async def getnameinfo(self, sockaddr, flags: int):
+ async def getnameinfo(
+ self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
+ ) -> tuple[str, str]:
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
From 4678d44ede954b8e3431b56d0c973bdebacd683a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 13 Jul 2023 16:12:23 +0200
Subject: [PATCH 03/49] typecheck trio/_dtls.py
---
pyproject.toml | 7 ++
trio/_channel.py | 10 +-
trio/_dtls.py | 195 +++++++++++++++++++++-------------
trio/_tests/verify_types.json | 21 +---
4 files changed, 140 insertions(+), 93 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index cfb4060ee7..ee0c019af7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,13 @@ disallow_untyped_defs = false
# DO NOT use `ignore_errors`; it doesn't apply
# downstream and users have to deal with them.
+[[tool.mypy.overrides]]
+module = [
+ "trio._dtls"
+]
+disallow_incomplete_defs = true
+disallow_untyped_defs = true
+
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
faulthandler_timeout = 60
diff --git a/trio/_channel.py b/trio/_channel.py
index 7c8ff4660d..df596adddd 100644
--- a/trio/_channel.py
+++ b/trio/_channel.py
@@ -20,7 +20,7 @@
def _open_memory_channel(
- max_buffer_size: int,
+ max_buffer_size: int | float,
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
@@ -92,11 +92,11 @@ def _open_memory_channel(
# Need to use Tuple instead of tuple due to CI check running on 3.8
class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
def __new__( # type: ignore[misc] # "must return a subtype"
- cls, max_buffer_size: int
+ cls, max_buffer_size: int | float
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)
- def __init__(self, max_buffer_size: int):
+ def __init__(self, max_buffer_size: int | float):
...
else:
@@ -108,7 +108,7 @@ def __init__(self, max_buffer_size: int):
@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used: int = attr.ib()
- max_buffer_size: int = attr.ib()
+ max_buffer_size: int | float = attr.ib()
open_send_channels: int = attr.ib()
open_receive_channels: int = attr.ib()
tasks_waiting_send: int = attr.ib()
@@ -117,7 +117,7 @@ class MemoryChannelStats:
@attr.s(slots=True)
class MemoryChannelState(Generic[T]):
- max_buffer_size: int = attr.ib()
+ max_buffer_size: int | float = attr.ib()
data: deque[T] = attr.ib(factory=deque)
# Counts of open endpoints using this state
open_send_channels: int = attr.ib(default=0)
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 722a9499f8..acc5f950eb 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -16,34 +16,53 @@
import warnings
import weakref
from itertools import count
-from typing import TYPE_CHECKING
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Generic,
+ Iterable,
+ Iterator,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
+from OpenSSL import SSL
import trio
-from trio._util import Final, NoPublicConstructor
+
+from ._util import Final, NoPublicConstructor
if TYPE_CHECKING:
from types import TracebackType
+ from OpenSSL.SSL import Context
+ from typing_extensions import Self, TypeAlias
+
+ from ._core._run import _TaskStatus
+ from ._socket import _SocketType
+
MAX_UDP_PACKET_SIZE = 65527
-def packet_header_overhead(sock):
+def packet_header_overhead(sock: _SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 28
else:
return 48
-def worst_case_mtu(sock):
+def worst_case_mtu(sock: _SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 576 - packet_header_overhead(sock)
else:
return 1280 - packet_header_overhead(sock)
-def best_guess_mtu(sock):
+def best_guess_mtu(sock: _SocketType) -> int:
return 1500 - packet_header_overhead(sock)
@@ -105,14 +124,14 @@ class BadPacket(Exception):
# ChangeCipherSpec is used during the handshake but has its own ContentType.
#
# Cannot fail.
-def part_of_handshake_untrusted(packet):
+def part_of_handshake_untrusted(packet: bytes) -> bool:
# If the packet is too short, then slicing will successfully return a
# short string, which will necessarily fail to match.
return packet[3:5] == b"\x00\x00"
# Cannot fail
-def is_client_hello_untrusted(packet):
+def is_client_hello_untrusted(packet: bytes) -> bool:
try:
return (
packet[0] == ContentType.handshake
@@ -147,7 +166,7 @@ class Record:
payload: bytes = attr.ib(repr=to_hex)
-def records_untrusted(packet):
+def records_untrusted(packet: bytes) -> Iterator[Record]:
i = 0
while i < len(packet):
try:
@@ -165,7 +184,7 @@ def records_untrusted(packet):
yield Record(ct, version, epoch_seqno, payload)
-def encode_record(record):
+def encode_record(record: Record) -> bytes:
header = RECORD_HEADER.pack(
record.content_type,
record.version,
@@ -194,7 +213,7 @@ class HandshakeFragment:
frag: bytes = attr.ib(repr=to_hex)
-def decode_handshake_fragment_untrusted(payload):
+def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment:
# Raises BadPacket if decoding fails
try:
(
@@ -224,7 +243,7 @@ def decode_handshake_fragment_untrusted(payload):
)
-def encode_handshake_fragment(hsf):
+def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes:
hs_header = HANDSHAKE_MESSAGE_HEADER.pack(
hsf.msg_type,
hsf.msg_len.to_bytes(3, "big"),
@@ -235,7 +254,7 @@ def encode_handshake_fragment(hsf):
return hs_header + hsf.frag
-def decode_client_hello_untrusted(packet):
+def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]:
# Raises BadPacket if parsing fails
# Returns (record epoch_seqno, cookie from the packet, data that should be
# hashed into cookie)
@@ -331,12 +350,20 @@ class OpaqueHandshakeMessage:
record: Record
+# for some reason doesn't work with |
+_AnyHandshakeMessage: TypeAlias = Union[
+ HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage
+]
+
+
# This takes a raw outgoing handshake volley that openssl generated, and
# reconstructs the handshake messages inside it, so that we can repack them
# into records while retransmitting. So the data ought to be well-behaved --
# it's not coming from the network.
-def decode_volley_trusted(volley):
- messages = []
+def decode_volley_trusted(
+ volley: bytes,
+) -> list[_AnyHandshakeMessage]:
+ messages: list[_AnyHandshakeMessage] = []
messages_by_seq = {}
for record in records_untrusted(volley):
# ChangeCipherSpec isn't a handshake message, so it can't be fragmented.
@@ -380,13 +407,17 @@ def decode_volley_trusted(volley):
class RecordEncoder:
- def __init__(self):
+ def __init__(self) -> None:
self._record_seq = count()
- def set_first_record_number(self, n):
+ def set_first_record_number(self, n: int) -> None:
self._record_seq = count(n)
- def encode_volley(self, messages, mtu):
+ def encode_volley(
+ self,
+ messages: Iterable[_AnyHandshakeMessage],
+ mtu: int,
+ ) -> list[bytearray]:
packets = []
packet = bytearray()
for message in messages:
@@ -518,13 +549,13 @@ def encode_volley(self, messages, mtu):
COOKIE_LENGTH = 32
-def _current_cookie_tick():
+def _current_cookie_tick() -> int:
return int(trio.current_time() / COOKIE_REFRESH_INTERVAL)
# Simple deterministic and invertible serializer -- i.e., a useful tool for converting
# structured data into something we can cryptographically sign.
-def _signable(*fields):
+def _signable(*fields: bytes) -> bytes:
out = []
for field in fields:
out.append(struct.pack("!Q", len(field)))
@@ -532,7 +563,9 @@ def _signable(*fields):
return b"".join(out)
-def _make_cookie(key, salt, tick, address, client_hello_bits):
+def _make_cookie(
+ key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
+) -> bytes:
assert len(salt) == SALT_BYTES
assert len(key) == KEY_BYTES
@@ -548,7 +581,9 @@ def _make_cookie(key, salt, tick, address, client_hello_bits):
return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH]
-def valid_cookie(key, cookie, address, client_hello_bits):
+def valid_cookie(
+ key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
+) -> bool:
if len(cookie) > SALT_BYTES:
salt = cookie[:SALT_BYTES]
@@ -568,7 +603,9 @@ def valid_cookie(key, cookie, address, client_hello_bits):
return False
-def challenge_for(key, address, epoch_seqno, client_hello_bits):
+def challenge_for(
+ key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
+) -> bytes:
salt = os.urandom(SALT_BYTES)
tick = _current_cookie_tick()
cookie = _make_cookie(key, salt, tick, address, client_hello_bits)
@@ -608,12 +645,15 @@ def challenge_for(key, address, epoch_seqno, client_hello_bits):
return packet
-class _Queue:
- def __init__(self, incoming_packets_buffer):
- self.s, self.r = trio.open_memory_channel(incoming_packets_buffer)
+T = TypeVar("T")
+
+class _Queue(Generic[T]):
+ def __init__(self, incoming_packets_buffer: int | float):
+ self.s, self.r = trio.open_memory_channel[T](incoming_packets_buffer)
-def _read_loop(read_fn):
+
+def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
chunks = []
while True:
try:
@@ -624,7 +664,9 @@ def _read_loop(read_fn):
return b"".join(chunks)
-async def handle_client_hello_untrusted(endpoint, address, packet):
+async def handle_client_hello_untrusted(
+ endpoint: DTLSEndpoint, address: Any, packet: bytes
+) -> None:
if endpoint._listening_context is None:
return
@@ -697,7 +739,9 @@ async def handle_client_hello_untrusted(endpoint, address, packet):
endpoint._incoming_connections_q.s.send_nowait(stream)
-async def dtls_receive_loop(endpoint_ref, sock):
+async def dtls_receive_loop(
+ endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType
+) -> None:
try:
while True:
try:
@@ -773,7 +817,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
"""
- def __init__(self, endpoint, peer_address, ctx):
+ def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
self.endpoint = endpoint
self.peer_address = peer_address
self._packets_dropped_in_trio = 0
@@ -784,25 +828,27 @@ def __init__(self, endpoint, peer_address, ctx):
# OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to
# support and isn't useful anyway -- especially for DTLS where it's equivalent
# to just performing a new handshake.
- ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION)
+ ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined]
self._ssl = SSL.Connection(ctx)
- self._handshake_mtu = None
+ self._handshake_mtu = 0
# This calls self._ssl.set_ciphertext_mtu, which is important, because if you
# don't call it then openssl doesn't work.
self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket))
self._replaced = False
self._closed = False
- self._q = _Queue(endpoint.incoming_packets_buffer)
+ self._q = _Queue[bytes](endpoint.incoming_packets_buffer)
self._handshake_lock = trio.Lock()
- self._record_encoder = RecordEncoder()
+ self._record_encoder: RecordEncoder = RecordEncoder()
+
+ self._final_volley: list[_AnyHandshakeMessage] = []
- def _set_replaced(self):
+ def _set_replaced(self) -> None:
self._replaced = True
# Any packets we already received could maybe possibly still be processed, but
# there are no more coming. So we close this on the sender side.
self._q.s.close()
- def _check_replaced(self):
+ def _check_replaced(self) -> None:
if self._replaced:
raise trio.BrokenResourceError(
"peer tore down this connection to start a new one"
@@ -836,7 +882,7 @@ def close(self) -> None:
# ClosedResourceError
self._q.r.close()
- def __enter__(self):
+ def __enter__(self) -> Self:
return self
def __exit__(
@@ -847,7 +893,7 @@ def __exit__(
) -> None:
return self.close()
- async def aclose(self):
+ async def aclose(self) -> None:
"""Close this connection, but asynchronously.
This is included to satisfy the `trio.abc.Channel` contract. It's
@@ -857,7 +903,7 @@ async def aclose(self):
self.close()
await trio.lowlevel.checkpoint()
- async def _send_volley(self, volley_messages):
+ async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None:
packets = self._record_encoder.encode_volley(
volley_messages, self._handshake_mtu
)
@@ -865,10 +911,10 @@ async def _send_volley(self, volley_messages):
async with self.endpoint._send_lock:
await self.endpoint.socket.sendto(packet, self.peer_address)
- async def _resend_final_volley(self):
+ async def _resend_final_volley(self) -> None:
await self._send_volley(self._final_volley)
- async def do_handshake(self, *, initial_retransmit_timeout=1.0):
+ async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None:
"""Perform the handshake.
Calling this is optional – if you don't, then it will be automatically called
@@ -901,17 +947,19 @@ async def do_handshake(self, *, initial_retransmit_timeout=1.0):
return
timeout = initial_retransmit_timeout
- volley_messages = []
+ volley_messages: list[_AnyHandshakeMessage] = []
volley_failed_sends = 0
- def read_volley():
+ def read_volley() -> list[_AnyHandshakeMessage]:
volley_bytes = _read_loop(self._ssl.bio_read)
new_volley_messages = decode_volley_trusted(volley_bytes)
if (
new_volley_messages
and volley_messages
and isinstance(new_volley_messages[0], HandshakeMessage)
- and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
+ # TODO: add isinstance or do a cast?
+ and new_volley_messages[0].msg_seq
+ == cast(HandshakeMessage, volley_messages[0]).msg_seq
):
# openssl decided to retransmit; discard because we handle
# retransmits ourselves
@@ -995,10 +1043,13 @@ def read_volley():
# PMTU estimate is wrong? Let's try dropping it to the minimum
# and hope that helps.
self._handshake_mtu = min(
- self._handshake_mtu, worst_case_mtu(self.endpoint.socket)
+ self._handshake_mtu,
+ worst_case_mtu(self.endpoint.socket),
)
- async def send(self, data):
+ async def send(
+ self, data: bytes
+ ) -> None: # or str? SendChannel defines it as bytes
"""Send a packet of data, securely."""
if self._closed:
@@ -1014,7 +1065,7 @@ async def send(self, data):
_read_loop(self._ssl.bio_read), self.peer_address
)
- async def receive(self):
+ async def receive(self) -> bytes: # or str?
"""Fetch the next packet of data from this connection's peer, waiting if
necessary.
@@ -1040,7 +1091,7 @@ async def receive(self):
if cleartext:
return cleartext
- def set_ciphertext_mtu(self, new_mtu):
+ def set_ciphertext_mtu(self, new_mtu: int) -> None:
"""Tells Trio the `largest amount of data that can be sent in a single packet to
this peer `__.
@@ -1075,7 +1126,7 @@ def set_ciphertext_mtu(self, new_mtu):
self._handshake_mtu = new_mtu
self._ssl.set_ciphertext_mtu(new_mtu)
- def get_cleartext_mtu(self):
+ def get_cleartext_mtu(self) -> int:
"""Returns the largest number of bytes that you can pass in a single call to
`send` while still fitting within the network-level MTU.
@@ -1084,9 +1135,9 @@ def get_cleartext_mtu(self):
"""
if not self._did_handshake:
raise trio.NeedHandshakeError
- return self._ssl.get_cleartext_mtu()
+ return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return]
- def statistics(self):
+ def statistics(self) -> DTLSChannelStatistics:
"""Returns an object with statistics about this connection.
Currently this has only one attribute:
@@ -1126,18 +1177,18 @@ class DTLSEndpoint(metaclass=Final):
"""
- def __init__(self, socket, *, incoming_packets_buffer=10):
+ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
# We do this lazily on first construction, so only people who actually use DTLS
# have to install PyOpenSSL.
global SSL
from OpenSSL import SSL
- # TODO: create a `self._initialized` for `__del__`, so self.socket can be typed
- # as trio.socket.SocketType and `is not None` checks can be removed.
- self.socket = None # for __del__, in case the next line raises
+ # for __del__, in case the next line raises
+ self._initialized: bool = False
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
- self.socket = socket
+ self._initialized = True
+ self.socket: _SocketType = socket
self.incoming_packets_buffer = incoming_packets_buffer
self._token = trio.lowlevel.current_trio_token()
@@ -1146,15 +1197,15 @@ def __init__(self, socket, *, incoming_packets_buffer=10):
# as a peer provides a valid cookie, we can immediately tear down the
# old connection.
# {remote address: DTLSChannel}
- self._streams = weakref.WeakValueDictionary()
- self._listening_context = None
- self._listening_key = None
- self._incoming_connections_q = _Queue(float("inf"))
+ self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+ self._listening_context: Context | None = None
+ self._listening_key: bytes | None = None
+ self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
self._send_lock = trio.Lock()
self._closed = False
self._receive_loop_spawned = False
- def _ensure_receive_loop(self):
+ def _ensure_receive_loop(self) -> None:
# We have to spawn this lazily, because on Windows it will immediately error out
# if the socket isn't already bound -- which for clients might not happen until
# after we send our first packet.
@@ -1164,9 +1215,9 @@ def _ensure_receive_loop(self):
)
self._receive_loop_spawned = True
- def __del__(self):
+ def __del__(self) -> None:
# Do nothing if this object was never fully constructed
- if self.socket is None:
+ if not self._initialized:
return
# Close the socket in Trio context (if our Trio context still exists), so that
# the background task gets notified about the closure and can exit.
@@ -1186,17 +1237,13 @@ def close(self) -> None:
This object can also be used as a context manager.
"""
- # Do nothing if this object was never fully constructed
- if self.socket is None: # pragma: no cover
- return
-
self._closed = True
self.socket.close()
for stream in list(self._streams.values()):
stream.close()
self._incoming_connections_q.s.close()
- def __enter__(self):
+ def __enter__(self) -> Self:
return self
def __exit__(
@@ -1207,13 +1254,17 @@ def __exit__(
) -> None:
return self.close()
- def _check_closed(self):
+ def _check_closed(self) -> None:
if self._closed:
raise trio.ClosedResourceError
async def serve(
- self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED
- ):
+ self,
+ ssl_context: Context,
+ async_fn: Callable[[DTLSChannel], Awaitable],
+ *args: Any,
+ task_status: _TaskStatus = trio.TASK_STATUS_IGNORED,
+ ) -> None:
"""Listen for incoming connections, and spawn a handler for each using an
internal nursery.
@@ -1257,7 +1308,7 @@ async def handler(dtls_channel):
self._listening_context = ssl_context
task_status.started()
- async def handler_wrapper(stream):
+ async def handler_wrapper(stream: DTLSChannel) -> None:
with stream:
await async_fn(stream, *args)
@@ -1267,7 +1318,7 @@ async def handler_wrapper(stream):
finally:
self._listening_context = None
- def connect(self, address, ssl_context):
+ def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel:
"""Initiate an outgoing DTLS connection.
Notice that this is a synchronous method. That's because it doesn't actually
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 9d7d7aa912..cf1e7eccfb 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -45,9 +45,9 @@
}
],
"otherSymbolCounts": {
- "withAmbiguousType": 8,
- "withKnownType": 433,
- "withUnknownType": 135
+ "withAmbiguousType": 6,
+ "withKnownType": 460,
+ "withUnknownType": 129
},
"packageName": "trio",
"symbols": [
@@ -73,6 +73,8 @@
"trio._core._mock_clock.MockClock.jump",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
+ "trio._core._run._TaskStatus.__repr__",
+ "trio._core._run._TaskStatus.started",
"trio._core._unbounded_queue.UnboundedQueue.__aiter__",
"trio._core._unbounded_queue.UnboundedQueue.__anext__",
"trio._core._unbounded_queue.UnboundedQueue.__repr__",
@@ -81,22 +83,9 @@
"trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait",
"trio._core._unbounded_queue.UnboundedQueue.qsize",
"trio._core._unbounded_queue.UnboundedQueue.statistics",
- "trio._dtls.DTLSChannel.__enter__",
"trio._dtls.DTLSChannel.__init__",
- "trio._dtls.DTLSChannel.aclose",
- "trio._dtls.DTLSChannel.do_handshake",
- "trio._dtls.DTLSChannel.get_cleartext_mtu",
- "trio._dtls.DTLSChannel.receive",
- "trio._dtls.DTLSChannel.send",
- "trio._dtls.DTLSChannel.set_ciphertext_mtu",
- "trio._dtls.DTLSChannel.statistics",
- "trio._dtls.DTLSEndpoint.__del__",
- "trio._dtls.DTLSEndpoint.__enter__",
"trio._dtls.DTLSEndpoint.__init__",
- "trio._dtls.DTLSEndpoint.connect",
- "trio._dtls.DTLSEndpoint.incoming_packets_buffer",
"trio._dtls.DTLSEndpoint.serve",
- "trio._dtls.DTLSEndpoint.socket",
"trio._highlevel_socket.SocketListener",
"trio._highlevel_socket.SocketListener.__init__",
"trio._highlevel_socket.SocketStream.__init__",
From 2522bc529ca95602da1cb9b8aa1862dab3de69b2 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 13 Jul 2023 16:38:09 +0200
Subject: [PATCH 04/49] incorporate _abc
---
trio/_abc.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index 0bb49e207d..402bb78b27 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -1,19 +1,19 @@
from __future__ import annotations
+import socket
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Generic, TypeVar
import trio
if TYPE_CHECKING:
- import socket
from types import TracebackType
from typing_extensions import Self
- from trio.lowlevel import Task
-
+ # both of these introduce circular imports if outside a TYPE_CHECKING guard
from ._socket import _SocketType
+ from .lowlevel import Task
# We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a
From 909ac67fba675f036fcea18b9ba7b1eacc4d4ccb Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 13 Jul 2023 16:42:30 +0200
Subject: [PATCH 05/49] incorporate _dtls
---
pyproject.toml | 9 +++-
trio/_channel.py | 10 ++---
trio/_dtls.py | 108 +++++++++++++++++++++++++++++------------------
3 files changed, 79 insertions(+), 48 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 121a398234..fa46d76bc7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ warn_return_any = true
# Avoid subtle backsliding
#disallow_any_decorated = true
-#disallow_incomplete_defs = true
+disallow_incomplete_defs = true
#disallow_subclassing_any = true
# Enable gradually / for new modules
@@ -48,6 +48,13 @@ disallow_untyped_defs = false
module = "trio._core._run"
disallow_incomplete_defs = false
+[[tool.mypy.overrides]]
+module = [
+ "trio._abc",
+ "trio._dtls"
+]
+disallow_untyped_defs = true
+
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
faulthandler_timeout = 60
diff --git a/trio/_channel.py b/trio/_channel.py
index 7c8ff4660d..df596adddd 100644
--- a/trio/_channel.py
+++ b/trio/_channel.py
@@ -20,7 +20,7 @@
def _open_memory_channel(
- max_buffer_size: int,
+ max_buffer_size: int | float,
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
@@ -92,11 +92,11 @@ def _open_memory_channel(
# Need to use Tuple instead of tuple due to CI check running on 3.8
class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
def __new__( # type: ignore[misc] # "must return a subtype"
- cls, max_buffer_size: int
+ cls, max_buffer_size: int | float
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)
- def __init__(self, max_buffer_size: int):
+ def __init__(self, max_buffer_size: int | float):
...
else:
@@ -108,7 +108,7 @@ def __init__(self, max_buffer_size: int):
@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used: int = attr.ib()
- max_buffer_size: int = attr.ib()
+ max_buffer_size: int | float = attr.ib()
open_send_channels: int = attr.ib()
open_receive_channels: int = attr.ib()
tasks_waiting_send: int = attr.ib()
@@ -117,7 +117,7 @@ class MemoryChannelStats:
@attr.s(slots=True)
class MemoryChannelState(Generic[T]):
- max_buffer_size: int = attr.ib()
+ max_buffer_size: int | float = attr.ib()
data: deque[T] = attr.ib(factory=deque)
# Counts of open endpoints using this state
open_send_channels: int = attr.ib(default=0)
diff --git a/trio/_dtls.py b/trio/_dtls.py
index aea15be735..8aede92bc8 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -16,23 +16,34 @@
import warnings
import weakref
from itertools import count
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Iterator, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Generic,
+ Iterable,
+ Iterator,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
from OpenSSL import SSL
import trio
-from trio._util import Final, NoPublicConstructor
+
+from ._util import Final, NoPublicConstructor
if TYPE_CHECKING:
from types import TracebackType
from OpenSSL.SSL import Context
- from typing_extensions import Self
-
- from trio._socket import _SocketType
+ from typing_extensions import Self, TypeAlias
from ._core._run import _TaskStatus
+ from ._socket import _SocketType
MAX_UDP_PACKET_SIZE = 65527
@@ -339,16 +350,20 @@ class OpaqueHandshakeMessage:
record: Record
+# for some reason doesn't work with |
+_AnyHandshakeMessage: TypeAlias = Union[
+ HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage
+]
+
+
# This takes a raw outgoing handshake volley that openssl generated, and
# reconstructs the handshake messages inside it, so that we can repack them
# into records while retransmitting. So the data ought to be well-behaved --
# it's not coming from the network.
def decode_volley_trusted(
volley: bytes,
-) -> list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage]:
- messages: list[
- HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage
- ] = []
+) -> list[_AnyHandshakeMessage]:
+ messages: list[_AnyHandshakeMessage] = []
messages_by_seq = {}
for record in records_untrusted(volley):
# ChangeCipherSpec isn't a handshake message, so it can't be fragmented.
@@ -392,7 +407,7 @@ def decode_volley_trusted(
class RecordEncoder:
- def __init__(self):
+ def __init__(self) -> None:
self._record_seq = count()
def set_first_record_number(self, n: int) -> None:
@@ -400,9 +415,7 @@ def set_first_record_number(self, n: int) -> None:
def encode_volley(
self,
- messages: Iterable[
- HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage
- ],
+ messages: Iterable[_AnyHandshakeMessage],
mtu: int,
) -> list[bytearray]:
packets = []
@@ -550,7 +563,9 @@ def _signable(*fields: bytes) -> bytes:
return b"".join(out)
-def _make_cookie(key, salt, tick, address, client_hello_bits):
+def _make_cookie(
+ key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
+) -> bytes:
assert len(salt) == SALT_BYTES
assert len(key) == KEY_BYTES
@@ -566,7 +581,9 @@ def _make_cookie(key, salt, tick, address, client_hello_bits):
return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH]
-def valid_cookie(key, cookie, address, client_hello_bits):
+def valid_cookie(
+ key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
+) -> bool:
if len(cookie) > SALT_BYTES:
salt = cookie[:SALT_BYTES]
@@ -586,7 +603,9 @@ def valid_cookie(key, cookie, address, client_hello_bits):
return False
-def challenge_for(key, address, epoch_seqno, client_hello_bits):
+def challenge_for(
+ key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
+) -> bytes:
salt = os.urandom(SALT_BYTES)
tick = _current_cookie_tick()
cookie = _make_cookie(key, salt, tick, address, client_hello_bits)
@@ -626,9 +645,12 @@ def challenge_for(key, address, epoch_seqno, client_hello_bits):
return packet
-class _Queue:
- def __init__(self, incoming_packets_buffer):
- self.s, self.r = trio.open_memory_channel(incoming_packets_buffer)
+T = TypeVar("T")
+
+
+class _Queue(Generic[T]):
+ def __init__(self, incoming_packets_buffer: int | float):
+ self.s, self.r = trio.open_memory_channel[T](incoming_packets_buffer)
def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
@@ -642,7 +664,9 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
return b"".join(chunks)
-async def handle_client_hello_untrusted(endpoint, address, packet):
+async def handle_client_hello_untrusted(
+ endpoint: DTLSEndpoint, address: Any, packet: bytes
+) -> None:
if endpoint._listening_context is None:
return
@@ -715,7 +739,9 @@ async def handle_client_hello_untrusted(endpoint, address, packet):
endpoint._incoming_connections_q.s.send_nowait(stream)
-async def dtls_receive_loop(endpoint_ref, sock):
+async def dtls_receive_loop(
+ endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType
+) -> None:
try:
while True:
try:
@@ -791,7 +817,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
"""
- def __init__(self, endpoint: DTLSEndpoint, peer_address: str, ctx: Context):
+ def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
self.endpoint = endpoint
self.peer_address = peer_address
self._packets_dropped_in_trio = 0
@@ -804,23 +830,25 @@ def __init__(self, endpoint: DTLSEndpoint, peer_address: str, ctx: Context):
# to just performing a new handshake.
ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined]
self._ssl = SSL.Connection(ctx)
- self._handshake_mtu: int | None = None
+ self._handshake_mtu = 0
# This calls self._ssl.set_ciphertext_mtu, which is important, because if you
# don't call it then openssl doesn't work.
self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket))
self._replaced = False
self._closed = False
- self._q = _Queue(endpoint.incoming_packets_buffer)
+ self._q = _Queue[bytes](endpoint.incoming_packets_buffer)
self._handshake_lock = trio.Lock()
- self._record_encoder = RecordEncoder()
+ self._record_encoder: RecordEncoder = RecordEncoder()
+
+ self._final_volley: list[_AnyHandshakeMessage] = []
- def _set_replaced(self):
+ def _set_replaced(self) -> None:
self._replaced = True
# Any packets we already received could maybe possibly still be processed, but
# there are no more coming. So we close this on the sender side.
self._q.s.close()
- def _check_replaced(self):
+ def _check_replaced(self) -> None:
if self._replaced:
raise trio.BrokenResourceError(
"peer tore down this connection to start a new one"
@@ -875,7 +903,7 @@ async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()
- async def _send_volley(self, volley_messages):
+ async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None:
packets = self._record_encoder.encode_volley(
volley_messages, self._handshake_mtu
)
@@ -883,7 +911,7 @@ async def _send_volley(self, volley_messages):
async with self.endpoint._send_lock:
await self.endpoint.socket.sendto(packet, self.peer_address)
- async def _resend_final_volley(self):
+ async def _resend_final_volley(self) -> None:
await self._send_volley(self._final_volley)
async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None:
@@ -919,14 +947,10 @@ async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None
return
timeout = initial_retransmit_timeout
- volley_messages: list[
- HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage
- ] = []
+ volley_messages: list[_AnyHandshakeMessage] = []
volley_failed_sends = 0
- def read_volley() -> (
- list[HandshakeMessage | OpaqueHandshakeMessage | PseudoHandshakeMessage]
- ):
+ def read_volley() -> list[_AnyHandshakeMessage]:
volley_bytes = _read_loop(self._ssl.bio_read)
new_volley_messages = decode_volley_trusted(volley_bytes)
if (
@@ -1019,7 +1043,7 @@ def read_volley() -> (
# PMTU estimate is wrong? Let's try dropping it to the minimum
# and hope that helps.
self._handshake_mtu = min(
- self._handshake_mtu or 0,
+ self._handshake_mtu,
worst_case_mtu(self.endpoint.socket),
)
@@ -1175,13 +1199,13 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
# {remote address: DTLSChannel}
self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self._listening_context: Context | None = None
- self._listening_key = None
- self._incoming_connections_q = _Queue(float("inf"))
+ self._listening_key: bytes | None = None
+ self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
self._send_lock = trio.Lock()
self._closed = False
self._receive_loop_spawned = False
- def _ensure_receive_loop(self):
+ def _ensure_receive_loop(self) -> None:
# We have to spawn this lazily, because on Windows it will immediately error out
# if the socket isn't already bound -- which for clients might not happen until
# after we send our first packet.
@@ -1237,9 +1261,9 @@ def _check_closed(self) -> None:
async def serve(
self,
ssl_context: Context,
- async_fn: Callable[..., Awaitable],
+ async_fn: Callable[[DTLSChannel], Awaitable],
*args: Any,
- task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] # ???
+ task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
) -> None:
"""Listen for incoming connections, and spawn a handler for each using an
internal nursery.
@@ -1284,7 +1308,7 @@ async def handler(dtls_channel):
self._listening_context = ssl_context
task_status.started()
- async def handler_wrapper(stream):
+ async def handler_wrapper(stream: DTLSChannel) -> None:
with stream:
await async_fn(stream, *args)
From 40ff89f65c22397c4ecae52be49fe5b7d5974b34 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 13 Jul 2023 16:51:33 +0200
Subject: [PATCH 06/49] ignore weird error with TASK_STATUS_IGNORED, add
pyOpenSSL to docs-requirements
---
docs-requirements.in | 3 +++
docs-requirements.txt | 8 ++++++++
trio/_dtls.py | 15 ++++++++++++++-
3 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/docs-requirements.in b/docs-requirements.in
index 98d5030bc5..d6214ec1d0 100644
--- a/docs-requirements.in
+++ b/docs-requirements.in
@@ -19,3 +19,6 @@ exceptiongroup >= 1.0.0rc9
# See note in test-requirements.in
immutables >= 0.6
+
+# types used in annotations
+pyOpenSSL
diff --git a/docs-requirements.txt b/docs-requirements.txt
index 06136fd765..c607f4f186 100644
--- a/docs-requirements.txt
+++ b/docs-requirements.txt
@@ -16,6 +16,8 @@ babel==2.12.1
# via sphinx
certifi==2023.5.7
# via requests
+cffi==1.15.1
+ # via cryptography
charset-normalizer==3.1.0
# via requests
click==8.1.3
@@ -24,6 +26,8 @@ click==8.1.3
# towncrier
click-default-group==1.2.2
# via towncrier
+cryptography==41.0.2
+ # via pyopenssl
docutils==0.18.1
# via
# sphinx
@@ -55,8 +59,12 @@ outcome==1.2.0
# via -r docs-requirements.in
packaging==23.1
# via sphinx
+pycparser==2.21
+ # via cffi
pygments==2.15.1
# via sphinx
+pyopenssl==23.2.0
+ # via -r docs-requirements.in
pytz==2023.3
# via babel
requests==2.31.0
diff --git a/trio/_dtls.py b/trio/_dtls.py
index acc5f950eb..0af8340732 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -798,6 +798,19 @@ async def dtls_receive_loop(
@attr.frozen
class DTLSChannelStatistics:
+ """An object with statistics about this connection.
+
+ Currently this has only one attribute:
+
+ - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
+ incoming packets from this peer that Trio successfully received from the
+ network, but then got dropped because the internal channel buffer was full. If
+ this is non-zero, then you might want to call ``receive`` more often, or use a
+ larger ``incoming_packets_buffer``, or just not worry about it because your
+ UDP-based protocol should be able to handle the occasional lost packet, right?
+
+ """
+
incoming_packets_dropped_in_trio: int
@@ -1263,7 +1276,7 @@ async def serve(
ssl_context: Context,
async_fn: Callable[[DTLSChannel], Awaitable],
*args: Any,
- task_status: _TaskStatus = trio.TASK_STATUS_IGNORED,
+ task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
) -> None:
"""Listen for incoming connections, and spawn a handler for each using an
internal nursery.
From f18abdd9c92759528fcb1ae746ba16034c926c2d Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 14 Jul 2023 16:35:03 +0200
Subject: [PATCH 07/49] socket is done - other than getting rid of the
_SocketType <-> SocketType distinction
---
pyproject.toml | 51 ++++-
trio/_core/_io_epoll.py | 4 +-
trio/_core/_local.py | 24 ++-
trio/_core/_run.py | 2 +-
trio/_core/_thread_cache.py | 2 +
trio/_socket.py | 299 ++++++++++++++++++++-------
trio/_subprocess_platform/windows.py | 2 +-
trio/_tests/verify_types.json | 27 +--
8 files changed, 307 insertions(+), 104 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index fa46d76bc7..aa1644b443 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ disallow_incomplete_defs = true
# Enable gradually / for new modules
check_untyped_defs = false
disallow_untyped_calls = false
-disallow_untyped_defs = false
+disallow_untyped_defs = true
# DO NOT use `ignore_errors`; it doesn't apply
# downstream and users have to deal with them.
@@ -47,14 +47,61 @@ disallow_untyped_defs = false
[[tool.mypy.overrides]]
module = "trio._core._run"
disallow_incomplete_defs = false
+disallow_untyped_defs = false
[[tool.mypy.overrides]]
module = [
"trio._abc",
- "trio._dtls"
+ "trio._dtls",
+ "trio._socket",
]
disallow_untyped_defs = true
+[[tool.mypy.overrides]]
+module = [
+"trio/_core/_asyncgens",
+"trio/_core/_entry_queue",
+"trio/_core/_generated_io_epoll",
+"trio/_core/_generated_io_kqueue",
+"trio/_core/_generated_io_windows",
+"trio/_core/_generated_run",
+"trio/_core/_io_common",
+"trio/_core/_io_epoll",
+"trio/_core/_io_kqueue",
+"trio/_core/_io_windows",
+"trio/_core/_ki",
+"trio/_core/_multierror",
+"trio/_core/_parking_lot",
+"trio/_core/_thread_cache",
+"trio/_core/_traps",
+"trio/_core/_wakeup_socketpair",
+"trio/_core/_windows_cffi",
+"trio/_deprecate",
+"trio/_file_io",
+"trio/_highlevel_open_tcp_listeners",
+"trio/_highlevel_open_tcp_stream",
+"trio/_highlevel_open_unix_stream",
+"trio/_highlevel_serve_listeners",
+"trio/_highlevel_socket",
+"trio/_highlevel_ssl_helpers",
+"trio/_path",
+"trio/_signals",
+"trio/_ssl",
+"trio/_subprocess",
+"trio/_subprocess_platform/kqueue",
+"trio/_subprocess_platform/waitid",
+"trio/_sync",
+"trio/_threads",
+"trio/_util",
+"trio/_wait_for_object",
+"trio/testing/_check_streams",
+"trio/testing/_checkpoints",
+"trio/testing/_memory_streams",
+"trio/testing/_network",
+"trio/testing/_trio_test",
+]
+disallow_untyped_defs = false
+
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
faulthandler_timeout = 60
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index fbeb454c7d..9d7b250785 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -3,7 +3,7 @@
import select
import sys
from collections import defaultdict
-from typing import TYPE_CHECKING, DefaultDict, Dict
+from typing import TYPE_CHECKING, DefaultDict
import attr
@@ -192,7 +192,7 @@ class EpollIOManager:
_epoll: select.epoll = attr.ib(factory=select.epoll)
# {fd: EpollWaiters}
_registered: DefaultDict[int, EpollWaiters] = attr.ib(
- factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters]
+ factory=lambda: defaultdict(EpollWaiters)
)
_force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd: int | None = attr.ib(default=None)
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index 89ccf93e95..fe509ca7ad 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, overload
# Runvar implementations
import attr
@@ -12,12 +12,16 @@
C = TypeVar("C", bound="_RunVarToken")
+class NoValue(object):
+ ...
+
+
@attr.s(eq=False, hash=False, slots=True)
class _RunVarToken(Generic[T]):
- _no_value = None
+ _no_value = NoValue()
_var: RunVar[T] = attr.ib()
- previous_value: T | None = attr.ib(default=_no_value)
+ previous_value: T | NoValue = attr.ib(default=_no_value)
redeemed: bool = attr.ib(default=False, init=False)
@classmethod
@@ -35,11 +39,19 @@ class RunVar(Generic[T], metaclass=Final):
"""
- _NO_DEFAULT = None
+ _NO_DEFAULT = NoValue()
_name: str = attr.ib()
- _default: T | None = attr.ib(default=_NO_DEFAULT)
+ _default: T | NoValue = attr.ib(default=_NO_DEFAULT)
+
+ @overload
+ def get(self, default: T) -> T:
+ ...
+
+ @overload
+ def get(self, default: NoValue = _NO_DEFAULT) -> T | NoValue:
+ ...
- def get(self, default: T | None = _NO_DEFAULT) -> T | None:
+ def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
# not typed yet
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 804a958714..ecc9138b23 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -1392,7 +1392,7 @@ class GuestState:
done_callback: Callable = attr.ib()
unrolled_run_gen = attr.ib()
_value_factory: Callable[[], Value] = lambda: Value(None)
- unrolled_run_next_send: Outcome = attr.ib(factory=_value_factory, type=Outcome)
+ unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome)
def guest_tick(self):
try:
diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py
index 3e27ce6a32..157f14c5a1 100644
--- a/trio/_core/_thread_cache.py
+++ b/trio/_core/_thread_cache.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import ctypes
import ctypes.util
import sys
diff --git a/trio/_socket.py b/trio/_socket.py
index 0f5aa75fd2..d492bbc41f 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -5,9 +5,22 @@
import socket as _stdlib_socket
import sys
from functools import wraps as _wraps
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, SupportsIndex
+from socket import AddressFamily, SocketKind
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ NoReturn,
+ SupportsIndex,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
import idna as _idna
+from typing_extensions import Concatenate, ParamSpec
import trio
@@ -17,11 +30,20 @@
from collections.abc import Iterable
from types import TracebackType
- from typing_extensions import Self
+ from typing_extensions import Buffer, Self, TypeAlias
from ._abc import HostnameResolver, SocketFactory
+T = TypeVar("T")
+P = ParamSpec("P")
+
+# must use old-style typing for TypeAlias
+Address: TypeAlias = Union[
+ str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int]
+]
+
+
# Usage:
#
# async with _try_sync():
@@ -31,16 +53,18 @@
# return await do_it_properly_with_a_check_point()
#
class _try_sync:
- def __init__(self, blocking_exc_override=None):
+ def __init__(
+ self, blocking_exc_override: Callable[[BaseException], bool] | None = None
+ ):
self._blocking_exc_override = blocking_exc_override
- def _is_blocking_io_error(self, exc):
+ def _is_blocking_io_error(self, exc: BaseException) -> bool:
if self._blocking_exc_override is None:
return isinstance(exc, BlockingIOError)
else:
return self._blocking_exc_override(exc)
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
await trio.lowlevel.checkpoint_if_cancelled()
async def __aexit__(
@@ -76,7 +100,7 @@ async def __aexit__(
################################################################
_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver")
-_socket_factory: _core.RunVar[SocketFactory] = _core.RunVar("socket_factory")
+_socket_factory: _core.RunVar[SocketFactory | None] = _core.RunVar("socket_factory")
def set_custom_hostname_resolver(
@@ -113,7 +137,9 @@ def set_custom_hostname_resolver(
return old
-def set_custom_socket_factory(socket_factory):
+def set_custom_socket_factory(
+ socket_factory: SocketFactory | None,
+) -> SocketFactory | None:
"""Set a custom socket object factory.
This function allows you to replace Trio's normal socket class with a
@@ -147,6 +173,7 @@ def set_custom_socket_factory(socket_factory):
_NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV
+# It would be possible to @overload the return value depending on Literal[AddressFamily.INET/6], but should probably be added in typeshed first
async def getaddrinfo(
host: bytes | str | None,
port: bytes | str | int | None,
@@ -156,8 +183,8 @@ async def getaddrinfo(
flags: int = 0,
) -> list[
tuple[
- _stdlib_socket.AddressFamily,
- _stdlib_socket.SocketKind,
+ AddressFamily,
+ SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
@@ -183,7 +210,7 @@ async def getaddrinfo(
# skip the whole thread thing, which seems worthwhile. So we try first
# with the _NUMERIC_ONLY flags set, and then only spawn a thread if that
# fails with EAI_NONAME:
- def numeric_only_failure(exc):
+ def numeric_only_failure(exc: BaseException) -> bool:
return (
isinstance(exc, _stdlib_socket.gaierror)
and exc.errno == _stdlib_socket.EAI_NONAME
@@ -246,7 +273,7 @@ async def getnameinfo(
)
-async def getprotobyname(name):
+async def getprotobyname(name: str) -> int:
"""Look up a protocol number by name. (Rarely used.)
Like :func:`socket.getprotobyname`, but async.
@@ -276,12 +303,12 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType:
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
def fromfd(
fd: SupportsIndex,
- family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET,
- type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM,
+ family: AddressFamily | int = _stdlib_socket.AF_INET,
+ type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
) -> _SocketType:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
- family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd)
+ family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, int(fd))
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -290,24 +317,38 @@ def fromfd(
):
@_wraps(_stdlib_socket.fromshare, assigned=(), updated=())
- def fromshare(*args, **kwargs):
- return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs))
+ def fromshare(info: bytes) -> _SocketType:
+ return from_stdlib_socket(_stdlib_socket.fromshare(info))
+
+
+if sys.platform == "win32":
+ FamilyT = int
+ TypeT = int
+ FamilyDefault = _stdlib_socket.AF_INET
+else:
+ FamilyDefault = None
+ FamilyT = Union[int, AddressFamily, None]
+ TypeT = Union[_stdlib_socket.socket, int]
@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
-def socketpair(*args, **kwargs):
+def socketpair(
+ family: FamilyT = FamilyDefault,
+ type: TypeT = SocketKind.SOCK_STREAM,
+ proto: int = 0,
+) -> tuple[_SocketType, _SocketType]:
"""Like :func:`socket.socketpair`, but returns a pair of Trio socket
objects.
"""
- left, right = _stdlib_socket.socketpair(*args, **kwargs)
+ left, right = _stdlib_socket.socketpair(family, type, proto)
return (from_stdlib_socket(left), from_stdlib_socket(right))
@_wraps(_stdlib_socket.socket, assigned=(), updated=())
def socket(
- family: _stdlib_socket.AddressFamily | int = _stdlib_socket.AF_INET,
- type: _stdlib_socket.SocketKind | int = _stdlib_socket.SOCK_STREAM,
+ family: AddressFamily | int = _stdlib_socket.AF_INET,
+ type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
fileno: int | None = None,
) -> _SocketType:
@@ -327,14 +368,24 @@ def socket(
return from_stdlib_socket(stdlib_socket)
-def _sniff_sockopts_for_fileno(family, type, proto, fileno):
+def _sniff_sockopts_for_fileno(
+ family: AddressFamily | int,
+ type: SocketKind | int,
+ proto: int,
+ fileno: int | None,
+) -> tuple[AddressFamily | int, SocketKind | int, int]:
"""Correct SOCKOPTS for given fileno, falling back to provided values."""
# Wrap the raw fileno into a Python socket object
# This object might have the wrong metadata, but it lets us easily call getsockopt
# and then we'll throw it away and construct a new one with the correct metadata.
if sys.platform != "linux":
return family, type, proto
- from socket import SO_DOMAIN, SO_PROTOCOL, SO_TYPE, SOL_SOCKET
+ from socket import ( # type: ignore[attr-defined]
+ SO_DOMAIN,
+ SO_PROTOCOL,
+ SO_TYPE,
+ SOL_SOCKET,
+ )
sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno)
try:
@@ -364,19 +415,21 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno):
)
-def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False):
- fn = getattr(_stdlib_socket.socket, methname)
-
+def _make_simple_sock_method_wrapper(
+ fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
+ wait_fn: Callable,
+ maybe_avail: bool = False,
+) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]:
@_wraps(fn, assigned=("__name__",), updated=())
- async def wrapper(self, *args, **kwargs):
- return await self._nonblocking_helper(fn, args, kwargs, wait_fn)
+ async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T:
+ return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs)
- wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async.
+ wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async.
"""
if maybe_avail:
wrapper.__doc__ += (
- f"Only available on platforms where :meth:`socket.socket.{methname}` is "
+ f"Only available on platforms where :meth:`socket.socket.{fn.__name__}` is "
"available."
)
return wrapper
@@ -395,8 +448,21 @@ async def wrapper(self, *args, **kwargs):
# local=False means that the address is being used with connect() or sendto() or
# similar.
#
+
+
+# Using a TypeVar to indicate we return the same type of address appears to give errors
+# when passed a union of address types.
+# @overload likely works, but is extremely verbose.
# NOTE: this function does not always checkpoint
-async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local):
+async def _resolve_address_nocp(
+ type: int,
+ family: AddressFamily,
+ proto: int,
+ *,
+ ipv6_v6only: bool | int,
+ address: Address,
+ local: bool,
+) -> Address:
# Do some pre-checking (or exit early for non-IP sockets)
if family == _stdlib_socket.AF_INET:
if not isinstance(address, tuple) or not len(address) == 2:
@@ -406,13 +472,15 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo
raise ValueError(
"address should be a (host, port, [flowinfo, [scopeid]]) tuple"
)
- elif family == _stdlib_socket.AF_UNIX:
+ elif family == getattr(_stdlib_socket, "AF_UNIX"):
# unwrap path-likes
+ assert isinstance(address, (str, bytes))
return os.fspath(address)
else:
return address
# -- From here on we know we have IPv4 or IPV6 --
+ host: str | None
host, port, *_ = address
# Fast path for the simple case: already-resolved IP address,
# already-resolved port. This is particularly important for UDP, since
@@ -450,20 +518,20 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo
# The above ignored any flowid and scopeid in the passed-in address,
# so restore them if present:
if family == _stdlib_socket.AF_INET6:
- normed = list(normed)
+ list_normed = list(normed)
assert len(normed) == 4
if len(address) >= 3:
- normed[2] = address[2]
+ list_normed[2] = address[2] # type: ignore
if len(address) >= 4:
- normed[3] = address[3]
- normed = tuple(normed)
+ list_normed[3] = address[3] # type: ignore
+ return tuple(list_normed) # type: ignore
return normed
# TODO: stopping users from initializing this type should be done in a different way,
# so SocketType can be used as a type.
class SocketType:
- def __init__(self):
+ def __init__(self) -> NoReturn:
raise TypeError(
"SocketType is an abstract class; use trio.socket.socket if you "
"want to construct a socket object"
@@ -529,11 +597,11 @@ def __exit__(
return self._sock.__exit__(exc_type, exc_value, traceback)
@property
- def family(self) -> _stdlib_socket.AddressFamily:
+ def family(self) -> AddressFamily:
return self._sock.family
@property
- def type(self) -> _stdlib_socket.SocketKind:
+ def type(self) -> SocketKind:
return self._sock.type
@property
@@ -556,7 +624,7 @@ def close(self) -> None:
trio.lowlevel.notify_closing(self._sock)
self._sock.close()
- async def bind(self, address: tuple[object, ...] | str | bytes) -> None:
+ async def bind(self, address: Address) -> None:
address = await self._resolve_address_nocp(address, local=True)
if (
hasattr(_stdlib_socket, "AF_UNIX")
@@ -593,7 +661,12 @@ def is_readable(self) -> bool:
async def wait_writable(self) -> None:
await _core.wait_writable(self._sock)
- async def _resolve_address_nocp(self, address, *, local):
+ async def _resolve_address_nocp(
+ self,
+ address: Address,
+ *,
+ local: bool,
+ ) -> Address:
if self.family == _stdlib_socket.AF_INET6:
ipv6_v6only = self._sock.getsockopt(
IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY
@@ -609,7 +682,19 @@ async def _resolve_address_nocp(self, address, *, local):
local=local,
)
- async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
+ # args and kwargs must be starred, otherwise pyright complains:
+ # '"args" member of ParamSpec is valid only when used with *args parameter'
+ # '"kwargs" member of ParamSpec is valid only when used with **kwargs parameter'
+ # wait_fn and fn must also be first in the signature
+ # 'Keyword parameter cannot appear in signature after ParamSpec args parameter'
+
+ async def _nonblocking_helper(
+ self,
+ wait_fn: Callable[[_stdlib_socket.socket], Awaitable],
+ fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> T:
# We have to reconcile two conflicting goals:
# - We want to make it look like we always blocked in doing these
# operations. The obvious way is to always do an IO wait before
@@ -645,9 +730,9 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
# accept
################################################################
- _accept: Callable[
- [], Awaitable[tuple[_stdlib_socket.socket, object]]
- ] = _make_simple_sock_method_wrapper("accept", _core.wait_readable)
+ _accept = _make_simple_sock_method_wrapper(
+ _stdlib_socket.socket.accept, _core.wait_readable
+ )
async def accept(self) -> tuple[_SocketType, object]:
"""Like :meth:`socket.socket.accept`, but async."""
@@ -658,8 +743,7 @@ async def accept(self) -> tuple[_SocketType, object]:
# connect
################################################################
- # TODO: typing addresses is ... a pain
- async def connect(self, address: str) -> None:
+ async def connect(self, address: Address) -> None:
# nonblocking connect is weird -- you call it to start things
# off, then the socket becomes writable as a completion
# notification. This means it isn't really cancellable... we close the
@@ -727,38 +811,69 @@ async def connect(self, address: str) -> None:
# Okay, the connect finished, but it might have failed:
err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR)
if err != 0:
- raise OSError(err, f"Error connecting to {address}: {os.strerror(err)}")
+ raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}")
################################################################
# recv
################################################################
+ # Not possible to typecheck with a Callable (due to DefaultArg), nor with a
+ # callback Protocol (https://github.com/python/typing/discussions/1040)
+ # but this seems to work. If not explicitly defined then pyright --verifytypes will
+ # complain about AmbiguousType
if TYPE_CHECKING:
- async def recv(self, buffersize: int, flags: int = 0) -> bytes:
+ def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
...
- else:
- recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable)
+ # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct
+ recv = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recv, _core.wait_readable
+ )
################################################################
# recv_into
################################################################
- recv_into = _make_simple_sock_method_wrapper("recv_into", _core.wait_readable)
+ if TYPE_CHECKING:
+
+ def recv_into(
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[int]:
+ ...
+
+ recv_into = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recv_into, _core.wait_readable
+ )
################################################################
# recvfrom
################################################################
- recvfrom = _make_simple_sock_method_wrapper("recvfrom", _core.wait_readable)
+ if TYPE_CHECKING:
+ # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any]
+ def recvfrom(
+ __self, __bufsize: int, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, Address]]:
+ ...
+
+ recvfrom = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvfrom, _core.wait_readable
+ )
################################################################
# recvfrom_into
################################################################
- recvfrom_into = _make_simple_sock_method_wrapper(
- "recvfrom_into", _core.wait_readable
+ if TYPE_CHECKING:
+ # return type of socket.socket.recvfrom_into in typeshed is tuple[int, Any]
+ def recvfrom_into(
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[tuple[int, Address]]:
+ ...
+
+ recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvfrom_into, _core.wait_readable
)
################################################################
@@ -766,8 +881,15 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes:
################################################################
if hasattr(_stdlib_socket.socket, "recvmsg"):
- recvmsg = _make_simple_sock_method_wrapper(
- "recvmsg", _core.wait_readable, maybe_avail=True
+ if TYPE_CHECKING:
+
+ def recvmsg(
+ __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
+ ...
+
+ recvmsg = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True
)
################################################################
@@ -775,29 +897,58 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes:
################################################################
if hasattr(_stdlib_socket.socket, "recvmsg_into"):
- recvmsg_into = _make_simple_sock_method_wrapper(
- "recvmsg_into", _core.wait_readable, maybe_avail=True
+ if TYPE_CHECKING:
+
+ def recvmsg_into(
+ __self,
+ __buffers: Iterable[Buffer],
+ __ancbufsize: int = 0,
+ __flags: int = 0,
+ ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
+ ...
+
+ recvmsg_into = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True
)
################################################################
# send
################################################################
- send = _make_simple_sock_method_wrapper("send", _core.wait_writable)
+ if TYPE_CHECKING:
+
+ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
+ ...
+
+ send = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.send, _core.wait_writable
+ )
################################################################
# sendto
################################################################
+ @overload
+ async def sendto(
+ self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
+
+ @overload
+ async def sendto(
+ self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
+
@_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=())
- async def sendto(self, *args):
+ async def sendto(self, *args: Any) -> int:
"""Similar to :meth:`socket.socket.sendto`, but async."""
# args is: data[, flags], address)
# and kwargs are not accepted
- args = list(args)
- args[-1] = await self._resolve_address_nocp(args[-1], local=False)
+ args_list = list(args)
+ args_list[-1] = await self._resolve_address_nocp(args[-1], local=False)
return await self._nonblocking_helper(
- _stdlib_socket.socket.sendto, args, {}, _core.wait_writable
+ _core.wait_writable, _stdlib_socket.socket.sendto, *args_list
)
################################################################
@@ -809,20 +960,28 @@ async def sendto(self, *args):
):
@_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
- async def sendmsg(self, *args):
+ async def sendmsg(
+ self,
+ __buffers: Iterable[Buffer],
+ __ancdata: Iterable[tuple[int, int, Buffer]] = (),
+ __flags: int = 0,
+ __address: Address | None = None,
+ ) -> int:
"""Similar to :meth:`socket.socket.sendmsg`, but async.
Only available on platforms where :meth:`socket.socket.sendmsg` is
available.
"""
- # args is: buffers[, ancdata[, flags[, address]]]
- # and kwargs are not accepted
- if len(args) == 4 and args[-1] is not None:
- args = list(args)
- args[-1] = await self._resolve_address_nocp(args[-1], local=False)
+ if __address is not None:
+ __address = await self._resolve_address_nocp(__address, local=False)
return await self._nonblocking_helper(
- _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable
+ _core.wait_writable,
+ _stdlib_socket.socket.sendmsg,
+ __buffers,
+ __ancdata,
+ __flags,
+ __address,
)
################################################################
diff --git a/trio/_subprocess_platform/windows.py b/trio/_subprocess_platform/windows.py
index 958be8675c..816da4b203 100644
--- a/trio/_subprocess_platform/windows.py
+++ b/trio/_subprocess_platform/windows.py
@@ -3,4 +3,4 @@
async def wait_child_exiting(process: "_subprocess.Process") -> None:
- await WaitForSingleObject(int(process._proc._handle))
+ await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined]
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 00a964787a..5d19fe6729 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9020866773675762,
+ "completenessScore": 0.9149277688603531,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 562,
- "withUnknownType": 61
+ "withKnownType": 570,
+ "withUnknownType": 53
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -46,19 +46,14 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 523,
- "withUnknownType": 89
+ "withKnownType": 546,
+ "withUnknownType": 67
},
"packageName": "trio",
"symbols": [
- "trio._abc.SocketFactory.socket",
"trio._core._entry_queue.TrioToken.run_sync_soon",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
- "trio._dtls.DTLSChannel.__init__",
- "trio._dtls.DTLSEndpoint.__init__",
- "trio._highlevel_socket.SocketListener.__init__",
- "trio._highlevel_socket.SocketStream.__init__",
"trio._highlevel_socket.SocketStream.getsockopt",
"trio._highlevel_socket.SocketStream.send_all",
"trio._highlevel_socket.SocketStream.setsockopt",
@@ -76,14 +71,6 @@
"trio._path.Path.__rtruediv__",
"trio._path.Path.__truediv__",
"trio._path.Path.open",
- "trio._socket._SocketType.recv_into",
- "trio._socket._SocketType.recvfrom",
- "trio._socket._SocketType.recvfrom_into",
- "trio._socket._SocketType.recvmsg",
- "trio._socket._SocketType.recvmsg_into",
- "trio._socket._SocketType.send",
- "trio._socket._SocketType.sendmsg",
- "trio._socket._SocketType.sendto",
"trio._ssl.SSLListener.__init__",
"trio._ssl.SSLListener.accept",
"trio._ssl.SSLListener.aclose",
@@ -141,10 +128,6 @@
"trio.serve_listeners",
"trio.serve_ssl_over_tcp",
"trio.serve_tcp",
- "trio.socket.from_stdlib_socket",
- "trio.socket.getprotobyname",
- "trio.socket.set_custom_socket_factory",
- "trio.socket.socketpair",
"trio.testing._memory_streams.MemoryReceiveStream.__init__",
"trio.testing._memory_streams.MemoryReceiveStream.aclose",
"trio.testing._memory_streams.MemoryReceiveStream.close",
From 128a8d26945ee24ac6780b336c13a06c7846991e Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 20 Jul 2023 13:48:44 +0200
Subject: [PATCH 08/49] fix RTD, and export DTLSChannelStatistics and
TaskStatus
---
docs/source/conf.py | 4 ++++
docs/source/reference-core.rst | 1 +
docs/source/reference-io.rst | 2 ++
trio/__init__.py | 7 ++++++-
trio/_core/__init__.py | 1 +
trio/_core/_run.py | 8 ++++----
trio/_dtls.py | 4 ++--
trio/_tests/verify_types.json | 10 +++++-----
8 files changed, 25 insertions(+), 12 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 68a5a22a81..8efc702f02 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -63,6 +63,10 @@
("py:obj", "trio._abc.T"),
("py:obj", "trio._abc.T_resource"),
("py:class", "types.FrameType"),
+ # TODO: figure out if you can link this to SSL
+ ("py:class", "Context"),
+ # TODO: temporary type
+ ("py:class", "_SocketType"),
]
autodoc_inherit_docstrings = False
default_role = "obj"
diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst
index f571d23294..434b8c8b5b 100644
--- a/docs/source/reference-core.rst
+++ b/docs/source/reference-core.rst
@@ -922,6 +922,7 @@ The nursery API
See :meth:`~Nursery.start`.
+.. autoclass:: TaskStatus
.. _task-local-storage:
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index a3291ef2ae..3ae66699e1 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -304,6 +304,8 @@ unfortunately that's not yet possible.
.. automethod:: statistics
+.. autoclass:: DTLSChannelStatistics
+
.. module:: trio.socket
Low-level networking with :mod:`trio.socket`
diff --git a/trio/__init__.py b/trio/__init__.py
index 2b8810504b..ac0687f529 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -34,6 +34,7 @@
EndOfChannel as EndOfChannel,
Nursery as Nursery,
RunFinishedError as RunFinishedError,
+ TaskStatus as TaskStatus,
TrioInternalError as TrioInternalError,
WouldBlock as WouldBlock,
current_effective_deadline as current_effective_deadline,
@@ -46,7 +47,11 @@
NonBaseMultiError as _NonBaseMultiError,
)
from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning
-from ._dtls import DTLSChannel as DTLSChannel, DTLSEndpoint as DTLSEndpoint
+from ._dtls import (
+ DTLSChannel as DTLSChannel,
+ DTLSChannelStatistics as DTLSChannelStatistics,
+ DTLSEndpoint as DTLSEndpoint,
+)
from ._file_io import open_file as open_file, wrap_file as wrap_file
from ._highlevel_generic import (
StapledStream as StapledStream,
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index abd58245e3..aa898fffe0 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -28,6 +28,7 @@
CancelScope,
Nursery,
Task,
+ TaskStatus,
add_instrument,
checkpoint,
checkpoint_if_cancelled,
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 4f90889c5f..279061f872 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -783,7 +783,7 @@ def cancel_called(self) -> bool:
# This code needs to be read alongside the code from Nursery.start to make
# sense.
@attr.s(eq=False, hash=False, repr=False)
-class _TaskStatus:
+class TaskStatus:
_old_nursery = attr.ib()
_new_nursery = attr.ib()
_called_started = attr.ib(default=False)
@@ -1137,16 +1137,16 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED):
try:
self._pending_starts += 1
async with open_nursery() as old_nursery:
- task_status = _TaskStatus(old_nursery, self)
+ task_status = TaskStatus(old_nursery, self)
thunk = functools.partial(async_fn, task_status=task_status)
task = GLOBAL_RUN_CONTEXT.runner.spawn_impl(
thunk, args, old_nursery, name
)
task._eventual_parent_nursery = self
- # Wait for either _TaskStatus.started or an exception to
+ # Wait for either TaskStatus.started or an exception to
# cancel this nursery:
# If we get here, then the child either got reparented or exited
- # normally. The complicated logic is all in _TaskStatus.started().
+ # normally. The complicated logic is all in TaskStatus.started().
# (Any exceptions propagate directly out of the above.)
if not task_status._called_started:
raise RuntimeError("child exited without calling task_status.started()")
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 0af8340732..3b15c83b7e 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -42,7 +42,7 @@
from OpenSSL.SSL import Context
from typing_extensions import Self, TypeAlias
- from ._core._run import _TaskStatus
+ from ._core._run import TaskStatus
from ._socket import _SocketType
MAX_UDP_PACKET_SIZE = 65527
@@ -1276,7 +1276,7 @@ async def serve(
ssl_context: Context,
async_fn: Callable[[DTLSChannel], Awaitable],
*args: Any,
- task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
+ task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
) -> None:
"""Listen for incoming connections, and spawn a handler for each using an
internal nursery.
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index cf1e7eccfb..65bfd6a301 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8764044943820225,
+ "completenessScore": 0.8752,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 546,
- "withUnknownType": 76
+ "withKnownType": 547,
+ "withUnknownType": 77
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -73,8 +73,8 @@
"trio._core._mock_clock.MockClock.jump",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
- "trio._core._run._TaskStatus.__repr__",
- "trio._core._run._TaskStatus.started",
+ "trio._core._run.TaskStatus.__repr__",
+ "trio._core._run.TaskStatus.started",
"trio._core._unbounded_queue.UnboundedQueue.__aiter__",
"trio._core._unbounded_queue.UnboundedQueue.__anext__",
"trio._core._unbounded_queue.UnboundedQueue.__repr__",
From 0e9aedd9de39bd3f413096ff416f6ec13e568c8e Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 21 Jul 2023 12:38:56 +0200
Subject: [PATCH 09/49] update .gitattributes
---
.gitattributes | 1 +
1 file changed, 1 insertion(+)
diff --git a/.gitattributes b/.gitattributes
index 991065e069..3fd55705b6 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -2,3 +2,4 @@
trio/_core/_generated* linguist-generated=true
# Treat generated files as binary in git diff
trio/_core/_generated* -diff
+trio/_tests/verify_types.json merge=binary
From c5b43a0d7d02b45782b007aecb967777a43903a1 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 22 Jul 2023 00:26:06 +0200
Subject: [PATCH 10/49] fixes after review from ZacHD
---
docs/source/reference-core.rst | 1 +
docs/source/reference-io.rst | 1 +
trio/_dtls.py | 28 ++++++++--------------------
3 files changed, 10 insertions(+), 20 deletions(-)
diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst
index 5f9381cbfc..980a3106e5 100644
--- a/docs/source/reference-core.rst
+++ b/docs/source/reference-core.rst
@@ -923,6 +923,7 @@ The nursery API
See :meth:`~Nursery.start`.
.. autoclass:: TaskStatus
+ :members:
.. _task-local-storage:
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index 3ae66699e1..9ad11b2c5a 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -305,6 +305,7 @@ unfortunately that's not yet possible.
.. automethod:: statistics
.. autoclass:: DTLSChannelStatistics
+ :members:
.. module:: trio.socket
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 3b15c83b7e..7885b1ff21 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -798,9 +798,7 @@ async def dtls_receive_loop(
@attr.frozen
class DTLSChannelStatistics:
- """An object with statistics about this connection.
-
- Currently this has only one attribute:
+ """Currently this has only one attribute:
- ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
incoming packets from this peer that Trio successfully received from the
@@ -1060,9 +1058,7 @@ def read_volley() -> list[_AnyHandshakeMessage]:
worst_case_mtu(self.endpoint.socket),
)
- async def send(
- self, data: bytes
- ) -> None: # or str? SendChannel defines it as bytes
+ async def send(self, data: bytes) -> None:
"""Send a packet of data, securely."""
if self._closed:
@@ -1078,7 +1074,7 @@ async def send(
_read_loop(self._ssl.bio_read), self.peer_address
)
- async def receive(self) -> bytes: # or str?
+ async def receive(self) -> bytes:
"""Fetch the next packet of data from this connection's peer, waiting if
necessary.
@@ -1151,18 +1147,7 @@ def get_cleartext_mtu(self) -> int:
return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return]
def statistics(self) -> DTLSChannelStatistics:
- """Returns an object with statistics about this connection.
-
- Currently this has only one attribute:
-
- - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
- incoming packets from this peer that Trio successfully received from the
- network, but then got dropped because the internal channel buffer was full. If
- this is non-zero, then you might want to call ``receive`` more often, or use a
- larger ``incoming_packets_buffer``, or just not worry about it because your
- UDP-based protocol should be able to handle the occasional lost packet, right?
-
- """
+ """Returns a `DTLSChannelStatistics` object with statistics about this connection."""
return DTLSChannelStatistics(self._packets_dropped_in_trio)
@@ -1271,10 +1256,13 @@ def _check_closed(self) -> None:
if self._closed:
raise trio.ClosedResourceError
+ # async_fn cannot be typed with ParamSpec, since we don't accept
+ # kwargs. Can be typed with TypeVarTuple once it's fully supported
+ # in mypy.
async def serve(
self,
ssl_context: Context,
- async_fn: Callable[[DTLSChannel], Awaitable],
+ async_fn: Callable[..., Awaitable],
*args: Any,
task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
) -> None:
From 58e4412e9c7f4e4b56dddc3f2d61b3d122e3b83a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 23 Jul 2023 23:26:06 +0200
Subject: [PATCH 11/49] fixes after review by a5rocks
---
pyproject.toml | 3 +++
trio/_dtls.py | 37 ++++++++++++++++++++-----------------
trio/_socket.py | 9 +++++++--
3 files changed, 30 insertions(+), 19 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 4fe96b06b6..3b14a075da 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,6 +51,9 @@ module = [
]
disallow_incomplete_defs = true
disallow_untyped_defs = true
+disallow_any_generics = true
+disallow_any_decorated = true
+disallow_subclassing_any = true
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 7885b1ff21..6ba5f7931e 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -26,7 +26,6 @@
Iterator,
TypeVar,
Union,
- cast,
)
import attr
@@ -43,7 +42,7 @@
from typing_extensions import Self, TypeAlias
from ._core._run import TaskStatus
- from ._socket import _SocketType
+ from ._socket import Address, _SocketType
MAX_UDP_PACKET_SIZE = 65527
@@ -350,7 +349,7 @@ class OpaqueHandshakeMessage:
record: Record
-# for some reason doesn't work with |
+# Needs Union until <3.10 is dropped
_AnyHandshakeMessage: TypeAlias = Union[
HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage
]
@@ -564,7 +563,7 @@ def _signable(*fields: bytes) -> bytes:
def _make_cookie(
- key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
+ key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes
) -> bytes:
assert len(salt) == SALT_BYTES
assert len(key) == KEY_BYTES
@@ -582,7 +581,7 @@ def _make_cookie(
def valid_cookie(
- key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
+ key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes
) -> bool:
if len(cookie) > SALT_BYTES:
salt = cookie[:SALT_BYTES]
@@ -604,7 +603,7 @@ def valid_cookie(
def challenge_for(
- key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
+ key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes
) -> bytes:
salt = os.urandom(SALT_BYTES)
tick = _current_cookie_tick()
@@ -665,7 +664,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
async def handle_client_hello_untrusted(
- endpoint: DTLSEndpoint, address: Any, packet: bytes
+ endpoint: DTLSEndpoint, address: Address, packet: bytes
) -> None:
if endpoint._listening_context is None:
return
@@ -776,7 +775,8 @@ async def dtls_receive_loop(
await stream._resend_final_volley()
else:
try:
- stream._q.s.send_nowait(packet)
+ # mypy for some reason cannot determine type of _q
+ stream._q.s.send_nowait(packet) # type:ignore[has-type]
except trio.WouldBlock:
stream._packets_dropped_in_trio += 1
else:
@@ -828,7 +828,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
"""
- def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
+ def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context):
self.endpoint = endpoint
self.peer_address = peer_address
self._packets_dropped_in_trio = 0
@@ -839,7 +839,12 @@ def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
# OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to
# support and isn't useful anyway -- especially for DTLS where it's equivalent
# to just performing a new handshake.
- ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined]
+ ctx.set_options(
+ (
+ SSL.OP_NO_QUERY_MTU
+ | SSL.OP_NO_RENEGOTIATION # type: ignore[attr-defined]
+ )
+ )
self._ssl = SSL.Connection(ctx)
self._handshake_mtu = 0
# This calls self._ssl.set_ciphertext_mtu, which is important, because if you
@@ -968,9 +973,8 @@ def read_volley() -> list[_AnyHandshakeMessage]:
new_volley_messages
and volley_messages
and isinstance(new_volley_messages[0], HandshakeMessage)
- # TODO: add isinstance or do a cast?
- and new_volley_messages[0].msg_seq
- == cast(HandshakeMessage, volley_messages[0]).msg_seq
+ and isinstance(volley_messages[0], HandshakeMessage)
+ and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
):
# openssl decided to retransmit; discard because we handle
# retransmits ourselves
@@ -1054,8 +1058,7 @@ def read_volley() -> list[_AnyHandshakeMessage]:
# PMTU estimate is wrong? Let's try dropping it to the minimum
# and hope that helps.
self._handshake_mtu = min(
- self._handshake_mtu,
- worst_case_mtu(self.endpoint.socket),
+ self._handshake_mtu, worst_case_mtu(self.endpoint.socket)
)
async def send(self, data: bytes) -> None:
@@ -1195,7 +1198,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
# as a peer provides a valid cookie, we can immediately tear down the
# old connection.
# {remote address: DTLSChannel}
- self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+ self._streams = weakref.WeakValueDictionary[Address, DTLSChannel]()
self._listening_context: Context | None = None
self._listening_key: bytes | None = None
self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
@@ -1262,7 +1265,7 @@ def _check_closed(self) -> None:
async def serve(
self,
ssl_context: Context,
- async_fn: Callable[[...], Awaitable],
+ async_fn: Callable[..., Awaitable[object]],
*args: Any,
task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
) -> None:
diff --git a/trio/_socket.py b/trio/_socket.py
index 659f844078..26b03fc3e0 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -5,7 +5,7 @@
import socket as _stdlib_socket
import sys
from functools import wraps as _wraps
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple, Union
import idna as _idna
@@ -17,7 +17,12 @@
from collections.abc import Iterable
from types import TracebackType
- from typing_extensions import Self
+ from typing_extensions import Self, TypeAlias
+
+# must use old-style typing because it's evaluated at runtime
+Address: TypeAlias = Union[
+ str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int]
+]
# Usage:
From eda238b1feb919c01740d701831a7568e6a50173 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 23 Jul 2023 23:40:33 +0200
Subject: [PATCH 12/49] import ReferenceType/WeakValueDictionary from weakref
 and fix runtime-evaluated _streams annotation
---
trio/_dtls.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 6ba5f7931e..e8888d7871 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -27,6 +27,7 @@
TypeVar,
Union,
)
+from weakref import ReferenceType, WeakValueDictionary
import attr
from OpenSSL import SSL
@@ -739,7 +740,7 @@ async def handle_client_hello_untrusted(
async def dtls_receive_loop(
- endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType
+ endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType
) -> None:
try:
while True:
@@ -1198,7 +1199,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
# as a peer provides a valid cookie, we can immediately tear down the
# old connection.
# {remote address: DTLSChannel}
- self._streams = weakref.WeakValueDictionary[Address, DTLSChannel]()
+ self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary()
self._listening_context: Context | None = None
self._listening_key: bytes | None = None
self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
From c437821d74b54ccad30982fa45b65c26826bb104 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 00:20:19 +0200
Subject: [PATCH 13/49] annotate mypy override module lists with per-module
 error counts; rename disallow_generic_any to disallow_any_generics
---
pyproject.toml | 75 +++++++++++++++++++++++++-------------------------
1 file changed, 38 insertions(+), 37 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 8e44b75e9f..a67da04e11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ disallow_incomplete_defs = true
check_untyped_defs = false
disallow_untyped_calls = false
disallow_untyped_defs = true
+disallow_any_generics = true
# DO NOT use `ignore_errors`; it doesn't apply
# downstream and users have to deal with them.
@@ -48,6 +49,7 @@ disallow_untyped_defs = true
module = "trio._core._run"
disallow_incomplete_defs = false
disallow_untyped_defs = false
+disallow_any_generics = false
[[tool.mypy.overrides]]
module = [
@@ -57,52 +59,51 @@ module = [
]
disallow_untyped_defs = true
disallow_incomplete_defs = true
-disallow_generic_any = true
+disallow_any_generics = true
[[tool.mypy.overrides]]
module = [
-"trio/_core/_asyncgens",
-"trio/_core/_entry_queue",
-"trio/_core/_generated_io_epoll",
-"trio/_core/_generated_io_kqueue",
+"trio/_core/_asyncgens", # 10
+"trio/_core/_entry_queue", # 16
+"trio/_core/_generated_io_epoll", # 3
"trio/_core/_generated_io_windows",
-"trio/_core/_generated_run",
-"trio/_core/_io_common",
-"trio/_core/_io_epoll",
-"trio/_core/_io_kqueue",
+"trio/_core/_generated_run", # 8
+"trio/_core/_io_common", # 1
+"trio/_core/_io_epoll", # 21
+"trio/_core/_io_kqueue", # 16
"trio/_core/_io_windows",
-"trio/_core/_ki",
-"trio/_core/_multierror",
-"trio/_core/_parking_lot",
-"trio/_core/_thread_cache",
-"trio/_core/_traps",
-"trio/_core/_wakeup_socketpair",
+"trio/_core/_ki", # 14
+"trio/_core/_multierror", # 19
+"trio/_core/_parking_lot", # 1
+"trio/_core/_thread_cache", # 6
+"trio/_core/_traps", # 7
+"trio/_core/_wakeup_socketpair", # 12
"trio/_core/_windows_cffi",
-"trio/_deprecate",
-"trio/_file_io",
-"trio/_highlevel_open_tcp_listeners",
-"trio/_highlevel_open_tcp_stream",
-"trio/_highlevel_open_unix_stream",
-"trio/_highlevel_serve_listeners",
-"trio/_highlevel_socket",
-"trio/_highlevel_ssl_helpers",
-"trio/_path",
-"trio/_signals",
-"trio/_ssl",
-"trio/_subprocess",
-"trio/_subprocess_platform/kqueue",
-"trio/_subprocess_platform/waitid",
-"trio/_sync",
-"trio/_threads",
-"trio/_util",
+"trio/_deprecate", # 12
+"trio/_file_io", # 13
+"trio/_highlevel_open_tcp_listeners", # 3
+"trio/_highlevel_open_tcp_stream", # 5
+"trio/_highlevel_open_unix_stream", # 2
+"trio/_highlevel_serve_listeners", # 3
+"trio/_highlevel_socket", # 4
+"trio/_highlevel_ssl_helpers", # 3
+"trio/_path", # 21
+"trio/_signals", # 13
+"trio/_ssl", # 26
+"trio/_subprocess", # 21
+"trio/_subprocess_platform/waitid", # 2
+"trio/_sync", # 1
+"trio/_threads", # 15
+"trio/_util", # 13
"trio/_wait_for_object",
-"trio/testing/_check_streams",
-"trio/testing/_checkpoints",
-"trio/testing/_memory_streams",
-"trio/testing/_network",
-"trio/testing/_trio_test",
+"trio/testing/_check_streams", # 27
+"trio/testing/_checkpoints", # 3
+"trio/testing/_memory_streams", # 66
+"trio/testing/_network", # 1
+"trio/testing/_trio_test", # 2
]
disallow_untyped_defs = false
+disallow_any_generics = false
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
From 0fbb87f87846c61ceeaac8a990cf4d79f5e87d49 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 00:22:01 +0200
Subject: [PATCH 14/49] merge _dtls
---
trio/_dtls.py | 77 +++++++++++++++++++----------------
trio/_tests/verify_types.json | 2 +-
2 files changed, 42 insertions(+), 37 deletions(-)
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 8aede92bc8..e8888d7871 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -26,8 +26,8 @@
Iterator,
TypeVar,
Union,
- cast,
)
+from weakref import ReferenceType, WeakValueDictionary
import attr
from OpenSSL import SSL
@@ -42,8 +42,8 @@
from OpenSSL.SSL import Context
from typing_extensions import Self, TypeAlias
- from ._core._run import _TaskStatus
- from ._socket import _SocketType
+ from ._core._run import TaskStatus
+ from ._socket import Address, _SocketType
MAX_UDP_PACKET_SIZE = 65527
@@ -350,7 +350,7 @@ class OpaqueHandshakeMessage:
record: Record
-# for some reason doesn't work with |
+# Needs Union until <3.10 is dropped
_AnyHandshakeMessage: TypeAlias = Union[
HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage
]
@@ -564,7 +564,7 @@ def _signable(*fields: bytes) -> bytes:
def _make_cookie(
- key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
+ key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes
) -> bytes:
assert len(salt) == SALT_BYTES
assert len(key) == KEY_BYTES
@@ -582,7 +582,7 @@ def _make_cookie(
def valid_cookie(
- key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
+ key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes
) -> bool:
if len(cookie) > SALT_BYTES:
salt = cookie[:SALT_BYTES]
@@ -604,7 +604,7 @@ def valid_cookie(
def challenge_for(
- key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
+ key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes
) -> bytes:
salt = os.urandom(SALT_BYTES)
tick = _current_cookie_tick()
@@ -665,7 +665,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
async def handle_client_hello_untrusted(
- endpoint: DTLSEndpoint, address: Any, packet: bytes
+ endpoint: DTLSEndpoint, address: Address, packet: bytes
) -> None:
if endpoint._listening_context is None:
return
@@ -740,7 +740,7 @@ async def handle_client_hello_untrusted(
async def dtls_receive_loop(
- endpoint_ref: weakref.ReferenceType[DTLSEndpoint], sock: _SocketType
+ endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType
) -> None:
try:
while True:
@@ -776,7 +776,8 @@ async def dtls_receive_loop(
await stream._resend_final_volley()
else:
try:
- stream._q.s.send_nowait(packet)
+ # mypy for some reason cannot determine type of _q
+ stream._q.s.send_nowait(packet) # type:ignore[has-type]
except trio.WouldBlock:
stream._packets_dropped_in_trio += 1
else:
@@ -798,6 +799,17 @@ async def dtls_receive_loop(
@attr.frozen
class DTLSChannelStatistics:
+ """Currently this has only one attribute:
+
+ - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
+ incoming packets from this peer that Trio successfully received from the
+ network, but then got dropped because the internal channel buffer was full. If
+ this is non-zero, then you might want to call ``receive`` more often, or use a
+ larger ``incoming_packets_buffer``, or just not worry about it because your
+ UDP-based protocol should be able to handle the occasional lost packet, right?
+
+ """
+
incoming_packets_dropped_in_trio: int
@@ -817,7 +829,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
"""
- def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
+ def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context):
self.endpoint = endpoint
self.peer_address = peer_address
self._packets_dropped_in_trio = 0
@@ -828,7 +840,12 @@ def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
# OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to
# support and isn't useful anyway -- especially for DTLS where it's equivalent
# to just performing a new handshake.
- ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) # type: ignore[attr-defined]
+ ctx.set_options(
+ (
+ SSL.OP_NO_QUERY_MTU
+ | SSL.OP_NO_RENEGOTIATION # type: ignore[attr-defined]
+ )
+ )
self._ssl = SSL.Connection(ctx)
self._handshake_mtu = 0
# This calls self._ssl.set_ciphertext_mtu, which is important, because if you
@@ -957,9 +974,8 @@ def read_volley() -> list[_AnyHandshakeMessage]:
new_volley_messages
and volley_messages
and isinstance(new_volley_messages[0], HandshakeMessage)
- # TODO: add isinstance or do a cast?
- and new_volley_messages[0].msg_seq
- == cast(HandshakeMessage, volley_messages[0]).msg_seq
+ and isinstance(volley_messages[0], HandshakeMessage)
+ and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
):
# openssl decided to retransmit; discard because we handle
# retransmits ourselves
@@ -1043,13 +1059,10 @@ def read_volley() -> list[_AnyHandshakeMessage]:
# PMTU estimate is wrong? Let's try dropping it to the minimum
# and hope that helps.
self._handshake_mtu = min(
- self._handshake_mtu,
- worst_case_mtu(self.endpoint.socket),
+ self._handshake_mtu, worst_case_mtu(self.endpoint.socket)
)
- async def send(
- self, data: bytes
- ) -> None: # or str? SendChannel defines it as bytes
+ async def send(self, data: bytes) -> None:
"""Send a packet of data, securely."""
if self._closed:
@@ -1065,7 +1078,7 @@ async def send(
_read_loop(self._ssl.bio_read), self.peer_address
)
- async def receive(self) -> bytes: # or str?
+ async def receive(self) -> bytes:
"""Fetch the next packet of data from this connection's peer, waiting if
necessary.
@@ -1138,18 +1151,7 @@ def get_cleartext_mtu(self) -> int:
return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return]
def statistics(self) -> DTLSChannelStatistics:
- """Returns an object with statistics about this connection.
-
- Currently this has only one attribute:
-
- - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
- incoming packets from this peer that Trio successfully received from the
- network, but then got dropped because the internal channel buffer was full. If
- this is non-zero, then you might want to call ``receive`` more often, or use a
- larger ``incoming_packets_buffer``, or just not worry about it because your
- UDP-based protocol should be able to handle the occasional lost packet, right?
-
- """
+ """Returns a `DTLSChannelStatistics` object with statistics about this connection."""
return DTLSChannelStatistics(self._packets_dropped_in_trio)
@@ -1197,7 +1199,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
# as a peer provides a valid cookie, we can immediately tear down the
# old connection.
# {remote address: DTLSChannel}
- self._streams: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+ self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary()
self._listening_context: Context | None = None
self._listening_key: bytes | None = None
self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
@@ -1258,12 +1260,15 @@ def _check_closed(self) -> None:
if self._closed:
raise trio.ClosedResourceError
+ # async_fn cannot be typed with ParamSpec, since we don't accept
+ # kwargs. Can be typed with TypeVarTuple once it's fully supported
+ # in mypy.
async def serve(
self,
ssl_context: Context,
- async_fn: Callable[[DTLSChannel], Awaitable],
+ async_fn: Callable[..., Awaitable[object]],
*args: Any,
- task_status: _TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
+ task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type]
) -> None:
"""Listen for incoming connections, and spawn a handler for each using an
internal nursery.
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 5d19fe6729..be02c5203b 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -46,7 +46,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 546,
+ "withKnownType": 535,
"withUnknownType": 67
},
"packageName": "trio",
From 3d92de4c0d6bd87affb37cd0323935e0a2a66b96 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 00:31:48 +0200
Subject: [PATCH 15/49] merge-ish _socket
---
pyproject.toml | 11 ---
trio/_core/_local.py | 53 ++++++--------
trio/_core/_parking_lot.py | 2 +-
trio/_socket.py | 127 ++++++++++++++++++++++------------
trio/_tests/verify_types.json | 2 +-
5 files changed, 108 insertions(+), 87 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index a67da04e11..e6f61a698e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,16 +51,6 @@ disallow_incomplete_defs = false
disallow_untyped_defs = false
disallow_any_generics = false
-[[tool.mypy.overrides]]
-module = [
- "trio._abc",
- "trio._dtls",
- "trio._socket",
-]
-disallow_untyped_defs = true
-disallow_incomplete_defs = true
-disallow_any_generics = true
-
[[tool.mypy.overrides]]
module = [
"trio/_core/_asyncgens", # 10
@@ -74,7 +64,6 @@ module = [
"trio/_core/_io_windows",
"trio/_core/_ki", # 14
"trio/_core/_multierror", # 19
-"trio/_core/_parking_lot", # 1
"trio/_core/_thread_cache", # 6
"trio/_core/_traps", # 7
"trio/_core/_wakeup_socketpair", # 12
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index fe509ca7ad..b9dada64fe 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -1,32 +1,32 @@
from __future__ import annotations
-from typing import Generic, TypeVar, overload
+from typing import Generic, TypeVar, final
# Runvar implementations
import attr
-from .._util import Final
+from .._util import Final, NoPublicConstructor
from . import _run
+# `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released
+
T = TypeVar("T")
-C = TypeVar("C", bound="_RunVarToken")
-class NoValue(object):
+@final
+class _NoValue:
...
-@attr.s(eq=False, hash=False, slots=True)
-class _RunVarToken(Generic[T]):
- _no_value = NoValue()
-
+@attr.s(eq=False, hash=False, slots=False)
+class RunVarToken(Generic[T], metaclass=NoPublicConstructor):
_var: RunVar[T] = attr.ib()
- previous_value: T | NoValue = attr.ib(default=_no_value)
+ previous_value: T | type[_NoValue] = attr.ib(default=_NoValue)
redeemed: bool = attr.ib(default=False, init=False)
@classmethod
- def empty(cls: type[C], var: RunVar[T]) -> C:
- return cls(var)
+ def _empty(cls, var: RunVar[T]) -> RunVarToken[T]:
+ return cls._create(var)
@attr.s(eq=False, hash=False, slots=True)
@@ -39,19 +39,10 @@ class RunVar(Generic[T], metaclass=Final):
"""
- _NO_DEFAULT = NoValue()
_name: str = attr.ib()
- _default: T | NoValue = attr.ib(default=_NO_DEFAULT)
-
- @overload
- def get(self, default: T) -> T:
- ...
-
- @overload
- def get(self, default: NoValue = _NO_DEFAULT) -> T | NoValue:
- ...
+ _default: T | type[_NoValue] = attr.ib(default=_NoValue)
- def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue:
+ def get(self, default: T | type[_NoValue] = _NoValue) -> T:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
# not typed yet
@@ -60,15 +51,15 @@ def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
# contextvars consistency
- if default is not self._NO_DEFAULT:
- return default
+ if default is not _NoValue:
+ return default # type: ignore[return-value]
- if self._default is not self._NO_DEFAULT:
- return self._default
+ if self._default is not _NoValue:
+ return self._default # type: ignore[return-value]
raise LookupError(self) from None
- def set(self, value: T) -> _RunVarToken[T]:
+ def set(self, value: T) -> RunVarToken[T]:
"""Sets the value of this :class:`RunVar` for this current run
call.
@@ -76,16 +67,16 @@ def set(self, value: T) -> _RunVarToken[T]:
try:
old_value = self.get()
except LookupError:
- token: _RunVarToken[T] = _RunVarToken.empty(self)
+ token = RunVarToken._empty(self)
else:
- token = _RunVarToken(self, old_value)
+ token = RunVarToken[T]._create(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index]
return token
- def reset(self, token: _RunVarToken[T]) -> None:
+ def reset(self, token: RunVarToken[T]) -> None:
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
@@ -101,7 +92,7 @@ def reset(self, token: _RunVarToken[T]) -> None:
previous = token.previous_value
try:
- if previous is _RunVarToken._no_value:
+ if previous is _NoValue:
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type]
else:
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment]
diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py
index 74708433da..6510745e5b 100644
--- a/trio/_core/_parking_lot.py
+++ b/trio/_core/_parking_lot.py
@@ -139,7 +139,7 @@ async def park(self) -> None:
self._parked[task] = None
task.custom_sleep_data = self
- def abort_fn(_):
+ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:
del task.custom_sleep_data._parked[task]
return _core.Abort.SUCCEEDED
diff --git a/trio/_socket.py b/trio/_socket.py
index 72498a5482..e9fa8f3537 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -5,6 +5,7 @@
import socket as _stdlib_socket
import sys
from functools import wraps as _wraps
+from operator import index
from socket import AddressFamily, SocketKind
from typing import (
TYPE_CHECKING,
@@ -16,11 +17,11 @@
Tuple,
TypeVar,
Union,
+ cast,
overload,
)
import idna as _idna
-from typing_extensions import Concatenate, ParamSpec
import trio
@@ -30,13 +31,14 @@
from collections.abc import Iterable
from types import TracebackType
- from typing_extensions import Buffer, Self, TypeAlias
+ from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias
from ._abc import HostnameResolver, SocketFactory
+ P = ParamSpec("P")
+
T = TypeVar("T")
-P = ParamSpec("P")
# must use old-style typing because it's evaluated at runtime
Address: TypeAlias = Union[
@@ -224,7 +226,7 @@ def numeric_only_failure(exc: BaseException) -> bool:
# idna.encode will error out if the hostname has Capital Letters
# in it; with uts46=True it will lowercase them instead.
host = _idna.encode(host, uts46=True)
- hr: HostnameResolver | None = _resolver.get(None)
+ hr = _resolver.get(None)
if hr is not None:
return await hr.getaddrinfo(host, port, family, type, proto, flags)
else:
@@ -296,7 +298,7 @@ def fromfd(
proto: int = 0,
) -> _SocketType:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
- family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, int(fd))
+ family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd))
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -310,13 +312,13 @@ def fromshare(info: bytes) -> _SocketType:
if sys.platform == "win32":
- FamilyT = int
- TypeT = int
+ FamilyT: TypeAlias = int
+ TypeT: TypeAlias = int
FamilyDefault = _stdlib_socket.AF_INET
else:
FamilyDefault = None
- FamilyT = Union[int, AddressFamily, None]
- TypeT = Union[_stdlib_socket.socket, int]
+ FamilyT: TypeAlias = Union[int, AddressFamily, None]
+ TypeT: TypeAlias = Union[_stdlib_socket.socket, int]
@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
@@ -405,7 +407,7 @@ def _sniff_sockopts_for_fileno(
def _make_simple_sock_method_wrapper(
fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
- wait_fn: Callable,
+ wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]],
maybe_avail: bool = False,
) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]:
@_wraps(fn, assigned=("__name__",), updated=())
@@ -508,6 +510,8 @@ async def _resolve_address_nocp(
if family == _stdlib_socket.AF_INET6:
list_normed = list(normed)
assert len(normed) == 4
+ # typechecking certainly doesn't like this logic, but given just how broad
+ # Address is, it's quite cumbersome to write the below without type: ignore
if len(address) >= 3:
list_normed[2] = address[2] # type: ignore
if len(address) >= 4:
@@ -517,7 +521,9 @@ async def _resolve_address_nocp(
# TODO: stopping users from initializing this type should be done in a different way,
-# so SocketType can be used as a type.
+# so SocketType can be used as a type. Note that this is *far* from trivial without
+# breaking subclasses of SocketType. Should maybe just add abstract methods to SocketType,
+# or rename _SocketType.
class SocketType:
def __init__(self) -> NoReturn:
raise TypeError(
@@ -542,36 +548,69 @@ def __init__(self, sock: _stdlib_socket.socket):
# Simple + portable methods and attributes
################################################################
- # NB this doesn't work because for loops don't create a scope
- # for _name in [
- # ]:
- # _meth = getattr(_stdlib_socket.socket, _name)
- # @_wraps(_meth, assigned=("__name__", "__doc__"), updated=())
- # def _wrapped(self, *args, **kwargs):
- # return getattr(self._sock, _meth)(*args, **kwargs)
- # locals()[_meth] = _wrapped
- # del _name, _meth, _wrapped
-
- _forward = {
- "detach",
- "get_inheritable",
- "set_inheritable",
- "fileno",
- "getpeername",
- "getsockname",
- "getsockopt",
- "setsockopt",
- "listen",
- "share",
- }
-
- def __getattr__(self, name: str) -> Any:
- if name in self._forward:
- return getattr(self._sock, name)
- raise AttributeError(name)
-
- def __dir__(self) -> Iterable[str]:
- return [*super().__dir__(), *self._forward]
+ # forwarded methods
+ def detach(self) -> int:
+ return self._sock.detach()
+
+ def fileno(self) -> int:
+ return self._sock.fileno()
+
+ def getpeername(self) -> Any:
+ return self._sock.getpeername()
+
+ def getsockname(self) -> Any:
+ return self._sock.getsockname()
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int) -> int:
+ ...
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
+ ...
+
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
+ if buflen is None:
+ return self._sock.getsockopt(level, optname)
+ return self._sock.getsockopt(level, optname, buflen)
+
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
+ ...
+
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
+ ...
+
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
+ if optlen is None:
+ return self._sock.setsockopt(level, optname, cast("int|Buffer", value))
+ return self._sock.setsockopt(level, optname, cast(None, value), optlen)
+
+ def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+ return self._sock.listen(backlog)
+
+ def get_inheritable(self) -> bool:
+ return self._sock.get_inheritable()
+
+ def set_inheritable(self, inheritable: bool) -> None:
+ return self._sock.set_inheritable(inheritable)
+
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
+ ):
+
+ def share(self, /, process_id: int) -> bytes:
+ return self._sock.share(process_id)
def __enter__(self) -> Self:
return self
@@ -678,7 +717,7 @@ async def _resolve_address_nocp(
async def _nonblocking_helper(
self,
- wait_fn: Callable[[_stdlib_socket.socket], Awaitable],
+ wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]],
fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
*args: P.args,
**kwargs: P.kwargs,
@@ -814,7 +853,9 @@ async def connect(self, address: Address) -> None:
def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
...
- # _make_simple_sock_method_wrapper is typed, so this check that the above is correct
+ # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct
+ # this requires that we refrain from using `/` to specify pos-only
+ # args, or mypy thinks the signature differs from typeshed.
recv = _make_simple_sock_method_wrapper( # noqa: F811
_stdlib_socket.socket.recv, _core.wait_readable
)
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 56ca3ace4d..ad044eb65f 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -46,7 +46,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 546,
+ "withKnownType": 552,
"withUnknownType": 67
},
"packageName": "trio",
From 2a51953d22919bffdc81972d56c18809b0565777 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 00:34:16 +0200
Subject: [PATCH 16/49] _sync
---
pyproject.toml | 1 -
trio/_sync.py | 16 ++++++++--------
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index e6f61a698e..73131c90ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,7 +81,6 @@ module = [
"trio/_ssl", # 26
"trio/_subprocess", # 21
"trio/_subprocess_platform/waitid", # 2
-"trio/_sync", # 1
"trio/_threads", # 15
"trio/_util", # 13
"trio/_wait_for_object",
diff --git a/trio/_sync.py b/trio/_sync.py
index 0f05dd458c..9764ddce2d 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -8,7 +8,7 @@
import trio
from . import _core
-from ._core import ParkingLot, enable_ki_protection
+from ._core import Abort, ParkingLot, RaiseCancelT, enable_ki_protection
from ._util import Final
if TYPE_CHECKING:
@@ -87,7 +87,7 @@ async def wait(self) -> None:
task = _core.current_task()
self._tasks.add(task)
- def abort_fn(_):
+ def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
return _core.Abort.SUCCEEDED
@@ -143,7 +143,7 @@ class CapacityLimiterStatistics:
borrowed_tokens: int = attr.ib()
total_tokens: int | float = attr.ib()
- borrowers: list[object] = attr.ib()
+ borrowers: list[Task | object] = attr.ib()
tasks_waiting: int = attr.ib()
@@ -204,9 +204,9 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float):
self._lot = ParkingLot()
- self._borrowers: set[object] = set()
+ self._borrowers: set[Task | object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
- self._pending_borrowers: dict[Task, object] = {}
+ self._pending_borrowers: dict[Task, Task | object] = {}
# invoke the property setter for validation
self.total_tokens: int | float = total_tokens
assert self._total_tokens == total_tokens
@@ -268,7 +268,7 @@ def acquire_nowait(self) -> None:
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
@enable_ki_protection
- def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
+ def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, without
blocking.
@@ -307,7 +307,7 @@ async def acquire(self) -> None:
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- async def acquire_on_behalf_of(self, borrower: object) -> None:
+ async def acquire_on_behalf_of(self, borrower: Task | object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
necessary.
@@ -347,7 +347,7 @@ def release(self) -> None:
self.release_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- def release_on_behalf_of(self, borrower: object) -> None:
+ def release_on_behalf_of(self, borrower: Task | object) -> None:
"""Put a token back into the sack on behalf of ``borrower``.
Raises:
From 4b513bdab79050e154643838712cab1d15fb6f60 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 11:16:08 +0200
Subject: [PATCH 17/49] type _io_epoll and stuff
---
pyproject.toml | 77 ++++++++++++++++-------------
trio/_core/_generated_io_epoll.py | 6 +--
trio/_core/_io_common.py | 8 ++-
trio/_core/_io_epoll.py | 39 ++++++++-------
trio/_core/_io_kqueue.py | 8 +--
trio/_core/_run.py | 7 +--
trio/_core/_tests/test_ki.py | 10 +++-
trio/_ssl.py | 9 ++--
trio/_subprocess_platform/kqueue.py | 7 ++-
trio/_tests/verify_types.json | 9 ++--
trio/tests.py | 2 +
11 files changed, 108 insertions(+), 74 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 73131c90ad..dd3bafd078 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,42 +53,49 @@ disallow_any_generics = false
[[tool.mypy.overrides]]
module = [
-"trio/_core/_asyncgens", # 10
-"trio/_core/_entry_queue", # 16
-"trio/_core/_generated_io_epoll", # 3
-"trio/_core/_generated_io_windows",
-"trio/_core/_generated_run", # 8
-"trio/_core/_io_common", # 1
-"trio/_core/_io_epoll", # 21
-"trio/_core/_io_kqueue", # 16
-"trio/_core/_io_windows",
-"trio/_core/_ki", # 14
-"trio/_core/_multierror", # 19
-"trio/_core/_thread_cache", # 6
-"trio/_core/_traps", # 7
+#"trio/_core/_io_common", # 1, 24
+"trio/_core/_windows_cffi", # 2, 324
+#"trio/_core/_generated_io_epoll", # 3, 36
+"trio/_core/_thread_cache", # 6, 273
+"trio/_core/_generated_io_kqueue", # 6 (darwin), 60
+"trio/_core/_traps", # 7, 276
+"trio/_core/_generated_run", # 8, 242
+"trio/_core/_generated_io_windows", # 9 (win32), 84
+"trio/_core/_asyncgens", # 10, 194
+
"trio/_core/_wakeup_socketpair", # 12
-"trio/_core/_windows_cffi",
-"trio/_deprecate", # 12
-"trio/_file_io", # 13
-"trio/_highlevel_open_tcp_listeners", # 3
-"trio/_highlevel_open_tcp_stream", # 5
-"trio/_highlevel_open_unix_stream", # 2
-"trio/_highlevel_serve_listeners", # 3
-"trio/_highlevel_socket", # 4
-"trio/_highlevel_ssl_helpers", # 3
-"trio/_path", # 21
-"trio/_signals", # 13
-"trio/_ssl", # 26
-"trio/_subprocess", # 21
-"trio/_subprocess_platform/waitid", # 2
-"trio/_threads", # 15
-"trio/_util", # 13
-"trio/_wait_for_object",
-"trio/testing/_check_streams", # 27
-"trio/testing/_checkpoints", # 3
-"trio/testing/_memory_streams", # 66
-"trio/testing/_network", # 1
-"trio/testing/_trio_test", # 2
+"trio/_core/_ki", # 14, 210
+"trio/_core/_entry_queue", # 16, 195
+"trio/_core/_io_kqueue", # 16, 198
+"trio/_core/_multierror", # 19, 469
+
+#"trio/_core/_io_epoll", # 21, 323
+"trio/_core/_io_windows", # 47 (win32), 867
+
+
+"trio/testing/_network", # 1, 34
+"trio/testing/_trio_test", # 2, 29
+"trio/testing/_checkpoints", # 3, 62
+"trio/testing/_check_streams", # 27, 522
+"trio/testing/_memory_streams", # 66, 590
+
+"trio/_highlevel_open_unix_stream", # 2, 49 lines
+"trio/_highlevel_open_tcp_listeners", # 3, 227 lines
+"trio/_highlevel_serve_listeners", # 3, 121 lines
+"trio/_highlevel_ssl_helpers", # 3, 155 lines
+"trio/_highlevel_socket", # 4, 386 lines
+"trio/_highlevel_open_tcp_stream", # 5, 379 lines
+
+"trio/_subprocess_platform/waitid", # 2, 107 lines
+"trio/_wait_for_object", # 2 (windows)
+"trio/_deprecate", # 12, 140 lines
+"trio/_util", # 13, 348 lines
+"trio/_file_io", # 13, 191 lines
+"trio/_signals", # 13, 168 lines
+"trio/_threads", # 15, 398 lines
+"trio/_path", # 21, 295 lines
+"trio/_subprocess", # 21, 759 lines
+"trio/_ssl", # 26, 929 lines
]
disallow_untyped_defs = false
disallow_any_generics = false
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index 02fb3bc348..1de66b0a8a 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -9,7 +9,7 @@
# fmt: off
-async def wait_readable(fd):
+async def wait_readable(fd: int) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -17,7 +17,7 @@ async def wait_readable(fd):
raise RuntimeError("must be called from async context")
-async def wait_writable(fd):
+async def wait_writable(fd: int) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -25,7 +25,7 @@ async def wait_writable(fd):
raise RuntimeError("must be called from async context")
-def notify_closing(fd):
+def notify_closing(fd: int) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
diff --git a/trio/_core/_io_common.py b/trio/_core/_io_common.py
index b141474fda..c1af293278 100644
--- a/trio/_core/_io_common.py
+++ b/trio/_core/_io_common.py
@@ -1,12 +1,18 @@
+from __future__ import annotations
+
import copy
+from typing import TYPE_CHECKING
import outcome
from .. import _core
+if TYPE_CHECKING:
+ from ._io_epoll import EpollWaiters
+
# Utility function shared between _io_epoll and _io_windows
-def wake_all(waiters, exc):
+def wake_all(waiters: EpollWaiters, exc: BaseException) -> None:
try:
current_task = _core.current_task()
except RuntimeError:
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 9d7b250785..31c49ca230 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -12,14 +12,17 @@
from ._run import _public
from ._wakeup_socketpair import WakeupSocketpair
+if TYPE_CHECKING:
+ from .._core import Abort, RaiseCancelT
+
assert not TYPE_CHECKING or sys.platform == "linux"
@attr.s(slots=True, eq=False, frozen=True)
class _EpollStatistics:
- tasks_waiting_read = attr.ib()
- tasks_waiting_write = attr.ib()
- backend = attr.ib(default="epoll")
+ tasks_waiting_read: int = attr.ib()
+ tasks_waiting_write: int = attr.ib()
+ backend: str = attr.ib(default="epoll")
# Some facts about epoll
@@ -182,9 +185,9 @@ class _EpollStatistics:
@attr.s(slots=True, eq=False)
class EpollWaiters:
- read_task = attr.ib(default=None)
- write_task = attr.ib(default=None)
- current_flags = attr.ib(default=0)
+ read_task: None = attr.ib(default=None)
+ write_task: None = attr.ib(default=None)
+ current_flags: int = attr.ib(default=0)
@attr.s(slots=True, eq=False, hash=False)
@@ -197,11 +200,11 @@ class EpollIOManager:
_force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd: int | None = attr.ib(default=None)
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
- def statistics(self):
+ def statistics(self) -> _EpollStatistics:
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._registered.values():
@@ -214,24 +217,24 @@ def statistics(self):
tasks_waiting_write=tasks_waiting_write,
)
- def close(self):
+ def close(self) -> None:
self._epoll.close()
self._force_wakeup.close()
- def force_wakeup(self):
+ def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
# Return value must be False-y IFF the timeout expired, NOT if any I/O
# happened or force_wakeup was called. Otherwise it can be anything; gets
# passed straight through to process_events.
- def get_events(self, timeout):
+ def get_events(self, timeout: float) -> list[tuple[int, int]]:
# max_events must be > 0 or epoll gets cranky
# accessing self._registered from a thread looks dangerous, but it's
# OK because it doesn't matter if our value is a little bit off.
max_events = max(1, len(self._registered))
return self._epoll.poll(timeout, max_events)
- def process_events(self, events):
+ def process_events(self, events: list[tuple[int, int]]) -> None:
for fd, flags in events:
if fd == self._force_wakeup_fd:
self._force_wakeup.drain()
@@ -250,7 +253,7 @@ def process_events(self, events):
waiters.read_task = None
self._update_registrations(fd)
- def _update_registrations(self, fd):
+ def _update_registrations(self, fd: int) -> None:
waiters = self._registered[fd]
wanted_flags = 0
if waiters.read_task is not None:
@@ -279,7 +282,7 @@ def _update_registrations(self, fd):
if not wanted_flags:
del self._registered[fd]
- async def _epoll_wait(self, fd, attr_name):
+ async def _epoll_wait(self, fd: int, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
@@ -290,7 +293,7 @@ async def _epoll_wait(self, fd, attr_name):
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd)
- def abort(_):
+ def abort(_: RaiseCancelT) -> Abort:
setattr(waiters, attr_name, None)
self._update_registrations(fd)
return _core.Abort.SUCCEEDED
@@ -298,15 +301,15 @@ def abort(_):
await _core.wait_task_rescheduled(abort)
@_public
- async def wait_readable(self, fd):
+ async def wait_readable(self, fd: int) -> None:
await self._epoll_wait(fd, "read_task")
@_public
- async def wait_writable(self, fd):
+ async def wait_writable(self, fd: int) -> None:
await self._epoll_wait(fd, "write_task")
@_public
- def notify_closing(self, fd):
+ def notify_closing(self, fd: int) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index d1151843e8..5ce5a609ed 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -11,6 +11,8 @@
from ._run import _public
from ._wakeup_socketpair import WakeupSocketpair
+if TYPE_CHECKING:
+ from .._core import Abort, RaiseCancelT
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
@@ -123,11 +125,11 @@ async def wait_kevent(self, ident, filter, abort_func):
)
self._registered[key] = _core.current_task()
- def abort(raise_cancel):
+ def abort(raise_cancel: RaiseCancelT) -> Abort:
r = abort_func(raise_cancel)
if r is _core.Abort.SUCCEEDED:
del self._registered[key]
- return r
+ return r # type: ignore[no-any-return]
return await _core.wait_task_rescheduled(abort)
@@ -138,7 +140,7 @@ async def _wait_common(self, fd, filter):
event = select.kevent(fd, filter, flags)
self._kqueue.control([event], 0)
- def abort(_):
+ def abort(_: RaiseCancelT) -> Abort:
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
try:
self._kqueue.control([event], 0)
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 5daf08f462..49c0bcef67 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -37,6 +37,7 @@
Abort,
CancelShieldedCheckpoint,
PermanentlyDetachCoroutineObject,
+ RaiseCancelT,
WaitTaskRescheduled,
cancel_shielded_checkpoint,
wait_task_rescheduled,
@@ -1022,7 +1023,7 @@ async def _nested_child_finished(self, nested_child_exc):
# If we get cancelled (or have an exception injected, like
# KeyboardInterrupt), then save that, but still wait until our
# children finish.
- def aborted(raise_cancel):
+ def abort(raise_cancel: RaiseCancelT) -> Abort:
self._add_exc(capture(raise_cancel).error)
return Abort.FAILED
@@ -1433,7 +1434,7 @@ def in_main_thread():
class Runner:
clock = attr.ib()
instruments: Instruments = attr.ib()
- io_manager = attr.ib()
+ io_manager: TheIOManager = attr.ib()
ki_manager = attr.ib()
strict_exception_groups = attr.ib()
@@ -1905,7 +1906,7 @@ async def test_lock_fairness():
key = (cushion, id(task))
self.waiting_for_idle[key] = task
- def abort(_):
+ def abort(_: RaiseCancelT) -> Abort:
del self.waiting_for_idle[key]
return Abort.SUCCEEDED
diff --git a/trio/_core/_tests/test_ki.py b/trio/_core/_tests/test_ki.py
index fdbada4624..b6eef68e22 100644
--- a/trio/_core/_tests/test_ki.py
+++ b/trio/_core/_tests/test_ki.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
import contextlib
import inspect
import signal
import threading
+from typing import TYPE_CHECKING
import outcome
import pytest
@@ -16,6 +19,9 @@
from ..._util import signal_raise
from ...testing import wait_all_tasks_blocked
+if TYPE_CHECKING:
+ from ..._core import Abort, RaiseCancelT
+
def ki_self():
signal_raise(signal.SIGINT)
@@ -375,7 +381,7 @@ async def main():
ki_self()
task = _core.current_task()
- def abort(_):
+ def abort(_: RaiseCancelT) -> Abort:
_core.reschedule(task, outcome.Value(1))
return _core.Abort.FAILED
@@ -394,7 +400,7 @@ async def main():
ki_self()
task = _core.current_task()
- def abort(raise_cancel):
+ def abort(raise_cancel: RaiseCancelT) -> Abort:
result = outcome.capture(raise_cancel)
_core.reschedule(task, result)
return _core.Abort.FAILED
diff --git a/trio/_ssl.py b/trio/_ssl.py
index bd8b3b06b6..352f95edaf 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -148,10 +148,12 @@
# stream)
# docs will need to make very clear that this is different from all the other
# cancellations in core Trio
+from __future__ import annotations
import operator as _operator
import ssl as _stdlib_ssl
from enum import Enum as _Enum
+from typing import Any, Awaitable, Callable
import trio
@@ -209,13 +211,14 @@ class NeedHandshakeError(Exception):
class _Once:
- def __init__(self, afn, *args):
+ # needs TypeVarTuple
+ def __init__(self, afn: Callable[..., Awaitable[object]], *args: Any):
self._afn = afn
self._args = args
self.started = False
self._done = _sync.Event()
- async def ensure(self, *, checkpoint):
+ async def ensure(self, *, checkpoint: bool) -> None:
if not self.started:
self.started = True
await self._afn(*self._args)
@@ -226,7 +229,7 @@ async def ensure(self, *, checkpoint):
await self._done.wait()
@property
- def done(self):
+ def done(self) -> bool:
return self._done.is_set()
diff --git a/trio/_subprocess_platform/kqueue.py b/trio/_subprocess_platform/kqueue.py
index 9839fd046b..b40db75953 100644
--- a/trio/_subprocess_platform/kqueue.py
+++ b/trio/_subprocess_platform/kqueue.py
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
import select
import sys
from typing import TYPE_CHECKING
from .. import _core, _subprocess
+if TYPE_CHECKING:
+ from .._core import Abort, RaiseCancelT
+
assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING
@@ -35,7 +40,7 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None:
# in Chromium it seems we should still keep the check.
return
- def abort(_):
+ def abort(_: RaiseCancelT) -> Abort:
kqueue.control([make_event(select.KQ_EV_DELETE)], 0)
return _core.Abort.SUCCEEDED
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index ad044eb65f..c454c4d3b9 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,14 +7,14 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9152,
+ "completenessScore": 0.92,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 572,
- "withUnknownType": 53
+ "withKnownType": 575,
+ "withUnknownType": 50
},
"ignoreUnknownTypesFromImports": true,
- "missingClassDocStringCount": 1,
+ "missingClassDocStringCount": 0,
"missingDefaultParamCount": 0,
"missingFunctionDocStringCount": 4,
"moduleName": "trio",
@@ -159,7 +159,6 @@
"trio.testing.open_stream_to_socket_listener",
"trio.testing.trio_test",
"trio.testing.wait_all_tasks_blocked",
- "trio.tests.TestsDeprecationWrapper",
"trio.to_thread.current_default_thread_limiter",
"trio.wrap_file"
]
diff --git a/trio/tests.py b/trio/tests.py
index 573a076da8..472befb1ce 100644
--- a/trio/tests.py
+++ b/trio/tests.py
@@ -16,6 +16,8 @@
# This won't give deprecation warning on import, but will give a warning on use of any
# attribute in tests, and static analysis tools will also not see any content inside.
class TestsDeprecationWrapper:
+ """trio.tests is deprecated, use trio._tests"""
+
__name__ = "trio.tests"
def __getattr__(self, attr: str) -> Any:
From 1831eaef19ede65303add9ab287bd46e436b5ec8 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 26 Jul 2023 13:22:03 +0200
Subject: [PATCH 18/49] aborted -> abort
---
trio/_core/_run.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 49c0bcef67..4fb1c78048 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -1028,7 +1028,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
return Abort.FAILED
self._parent_waiting_in_aexit = True
- await wait_task_rescheduled(aborted)
+ await wait_task_rescheduled(abort)
else:
# Nothing to wait for, so just execute a checkpoint -- but we
# still need to mix any exception (e.g. from an external
From 091063b2a08a5e001681785d096f5cdb016d0965 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 11:59:03 +0200
Subject: [PATCH 19/49] merge _socket, more stuff
---
.coveragerc | 1 +
docs/source/conf.py | 2 ++
docs/source/reference-io.rst | 8 ++++++++
pyproject.toml | 8 ++++++++
trio/_core/_generated_instrumentation.py | 7 +++++++
trio/_core/_generated_io_epoll.py | 13 ++++++++++---
trio/_core/_generated_io_kqueue.py | 7 +++++++
trio/_core/_generated_io_windows.py | 7 +++++++
trio/_core/_generated_run.py | 7 +++++++
trio/_core/_io_epoll.py | 10 ++++++----
trio/_threads.py | 12 ++++--------
trio/_tools/gen_exports.py | 5 +++++
trio/socket.py | 1 +
13 files changed, 73 insertions(+), 15 deletions(-)
diff --git a/.coveragerc b/.coveragerc
index 98f923bd8e..d577aa8adf 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -21,6 +21,7 @@ exclude_lines =
abc.abstractmethod
if TYPE_CHECKING:
if _t.TYPE_CHECKING:
+ @overload
partial_branches =
pragma: no branch
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 91ce7d884c..7e8626c20d 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -62,6 +62,8 @@
("py:obj", "trio._abc.SendType"),
("py:obj", "trio._abc.T"),
("py:obj", "trio._abc.T_resource"),
+ ("py:class", "trio._threads.T"),
+ # why aren't these found in stdlib?
("py:class", "types.FrameType"),
# TODO: figure out if you can link this to SSL
("py:class", "Context"),
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index 9ad11b2c5a..0669eb5323 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -504,6 +504,14 @@ Socket objects
* :meth:`~socket.socket.set_inheritable`
* :meth:`~socket.socket.get_inheritable`
+The internal SocketType
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: _SocketType
+..
+ TODO: adding `:members:` here gives error due to overload+_wraps on `sendto`
+ TODO: rewrite ... all of the above when fixing _SocketType vs SocketType
+
+
.. currentmodule:: trio
diff --git a/pyproject.toml b/pyproject.toml
index dd3bafd078..2b2678a366 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,14 @@ disallow_any_generics = true
# DO NOT use `ignore_errors`; it doesn't apply
# downstream and users have to deal with them.
+[[tool.mypy.overrides]]
+module = [
+ "trio._socket",
+ "trio._core._local",
+ "trio._sync",
+]
+disallow_untyped_defs = true
+disallow_any_generics = true
[[tool.mypy.overrides]]
module = "trio._core._run"
diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py
index 30c2f26b4e..e38df6c1ad 100644
--- a/trio/_core/_generated_instrumentation.py
+++ b/trio/_core/_generated_instrumentation.py
@@ -2,10 +2,17 @@
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
# isort: skip
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ from socket import socket
+
# fmt: off
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index 1de66b0a8a..a6ea291d91 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -2,14 +2,21 @@
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
# isort: skip
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ from socket import socket
+
# fmt: off
-async def wait_readable(fd: int) ->None:
+async def wait_readable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -17,7 +24,7 @@ async def wait_readable(fd: int) ->None:
raise RuntimeError("must be called from async context")
-async def wait_writable(fd: int) ->None:
+async def wait_writable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -25,7 +32,7 @@ async def wait_writable(fd: int) ->None:
raise RuntimeError("must be called from async context")
-def notify_closing(fd: int) ->None:
+def notify_closing(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index 94e819769c..5179f150c6 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -2,10 +2,17 @@
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
# isort: skip
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ from socket import socket
+
# fmt: off
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 26b4da697d..71172ef4df 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -2,10 +2,17 @@
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
# isort: skip
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ from socket import socket
+
# fmt: off
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index d1e74a93f4..e3f08a49e3 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -2,10 +2,17 @@
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
# isort: skip
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ from socket import socket
+
# fmt: off
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 31c49ca230..750f85fabb 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -13,6 +13,8 @@
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
+ from socket import socket
+
from .._core import Abort, RaiseCancelT
assert not TYPE_CHECKING or sys.platform == "linux"
@@ -282,7 +284,7 @@ def _update_registrations(self, fd: int) -> None:
if not wanted_flags:
del self._registered[fd]
- async def _epoll_wait(self, fd: int, attr_name: str) -> None:
+ async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
@@ -301,15 +303,15 @@ def abort(_: RaiseCancelT) -> Abort:
await _core.wait_task_rescheduled(abort)
@_public
- async def wait_readable(self, fd: int) -> None:
+ async def wait_readable(self, fd: int | socket) -> None:
await self._epoll_wait(fd, "read_task")
@_public
- async def wait_writable(self, fd: int) -> None:
+ async def wait_writable(self, fd: int | socket) -> None:
await self._epoll_wait(fd, "write_task")
@_public
- def notify_closing(self, fd: int) -> None:
+ def notify_closing(self, fd: int | socket) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
diff --git a/trio/_threads.py b/trio/_threads.py
index 45a416249e..3fbab05750 100644
--- a/trio/_threads.py
+++ b/trio/_threads.py
@@ -6,13 +6,14 @@
import queue as stdlib_queue
import threading
from itertools import count
-from typing import Optional
+from typing import Any, Callable, Optional, TypeVar
import attr
import outcome
from sniffio import current_async_library_cvar
import trio
+from trio._core._traps import RaiseCancelT
from ._core import (
RunVar,
@@ -24,6 +25,8 @@
from ._sync import CapacityLimiter
from ._util import coroutine_or_error
+T = TypeVar("T")
+
# Global due to Threading API, thread local storage for trio token
TOKEN_LOCAL = threading.local()
@@ -59,11 +62,6 @@ class ThreadPlaceholder:
name = attr.ib()
-from typing import Any, Callable, TypeVar
-
-T = TypeVar("T")
-
-
@enable_ki_protection
async def to_thread_run_sync(
sync_fn: Callable[..., T],
@@ -228,8 +226,6 @@ def deliver_worker_fn_result(result):
limiter.release_on_behalf_of(placeholder)
raise
- from trio._core._traps import RaiseCancelT
-
def abort(_: RaiseCancelT) -> trio.lowlevel.Abort:
if cancellable:
task_register[0] = None
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index a5d8529b53..9c9d91f413 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -18,9 +18,14 @@
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
# isort: skip
+from __future__ import annotations
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from socket import socket
# fmt: off
"""
diff --git a/trio/socket.py b/trio/socket.py
index a9e276c782..f6aebb6a6e 100644
--- a/trio/socket.py
+++ b/trio/socket.py
@@ -35,6 +35,7 @@
# import the overwrites
from ._socket import (
SocketType as SocketType,
+ _SocketType as _SocketType,
from_stdlib_socket as from_stdlib_socket,
fromfd as fromfd,
getaddrinfo as getaddrinfo,
From 46d9e9596d4528e86fb70770c48da9a903804e28 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 12:35:25 +0200
Subject: [PATCH 20/49] regen
---
trio/_tools/gen_exports.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 9c9d91f413..95b92cedae 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -18,11 +18,13 @@
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
# isort: skip
+
+from typing import TYPE_CHECKING
+
from __future__ import annotations
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
-from typing import TYPE_CHECKING
if TYPE_CHECKING:
from socket import socket
From 156c94f99652760bf8b67ffa8878008640a13eb9 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 24 Jul 2023 13:56:21 +0200
Subject: [PATCH 21/49] progress
---
pyproject.toml | 4 +-
trio/_core/_generated_instrumentation.py | 10 +++-
trio/_core/_generated_io_epoll.py | 10 +++-
trio/_core/_generated_io_kqueue.py | 24 +++++---
trio/_core/_generated_io_windows.py | 10 +++-
trio/_core/_generated_run.py | 10 +++-
trio/_core/_io_epoll.py | 17 +++---
trio/_core/_io_kqueue.py | 75 +++++++++++++++---------
trio/_tools/gen_exports.py | 12 +++-
9 files changed, 114 insertions(+), 58 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 2b2678a366..f2422f8eea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,7 +65,7 @@ module = [
"trio/_core/_windows_cffi", # 2, 324
#"trio/_core/_generated_io_epoll", # 3, 36
"trio/_core/_thread_cache", # 6, 273
-"trio/_core/_generated_io_kqueue", # 6 (darwin), 60
+#"trio/_core/_generated_io_kqueue", # 6 (darwin), 60
"trio/_core/_traps", # 7, 276
"trio/_core/_generated_run", # 8, 242
"trio/_core/_generated_io_windows", # 9 (win32), 84
@@ -74,7 +74,7 @@ module = [
"trio/_core/_wakeup_socketpair", # 12
"trio/_core/_ki", # 14, 210
"trio/_core/_entry_queue", # 16, 195
-"trio/_core/_io_kqueue", # 16, 198
+#"trio/_core/_io_kqueue", # 16, 198
"trio/_core/_multierror", # 19, 469
#"trio/_core/_io_epoll", # 21, 323
diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py
index e38df6c1ad..c783452bfc 100644
--- a/trio/_core/_generated_instrumentation.py
+++ b/trio/_core/_generated_instrumentation.py
@@ -1,18 +1,24 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
+ import select
from socket import socket
+ from _contextlib import _GeneratorContextManager
+ from _core import Abort, RaiseCancelT
+
+ from .. import _core
+
# fmt: off
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index a6ea291d91..f35b927737 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -1,18 +1,24 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
+ import select
from socket import socket
+ from _contextlib import _GeneratorContextManager
+ from _core import Abort, RaiseCancelT
+
+ from .. import _core
+
# fmt: off
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index 5179f150c6..f4b14cd500 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -1,22 +1,28 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
+ import select
from socket import socket
+ from _contextlib import _GeneratorContextManager
+ from _core import Abort, RaiseCancelT
+
+ from .. import _core
+
# fmt: off
-def current_kqueue():
+def current_kqueue() ->select.kqueue:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
@@ -24,7 +30,8 @@ def current_kqueue():
raise RuntimeError("must be called from async context")
-def monitor_kevent(ident, filter):
+def monitor_kevent(ident: int, filter: int) ->_GeneratorContextManager[_core
+ .UnboundedQueue[select.kevent]]:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
@@ -32,7 +39,8 @@ def monitor_kevent(ident, filter):
raise RuntimeError("must be called from async context")
-async def wait_kevent(ident, filter, abort_func):
+async def wait_kevent(ident: int, filter: int, abort_func: Callable[[
+ RaiseCancelT], Abort]) ->Abort:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func)
@@ -40,7 +48,7 @@ async def wait_kevent(ident, filter, abort_func):
raise RuntimeError("must be called from async context")
-async def wait_readable(fd):
+async def wait_readable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -48,7 +56,7 @@ async def wait_readable(fd):
raise RuntimeError("must be called from async context")
-async def wait_writable(fd):
+async def wait_writable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -56,7 +64,7 @@ async def wait_writable(fd):
raise RuntimeError("must be called from async context")
-def notify_closing(fd):
+def notify_closing(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 71172ef4df..90d7ce0d70 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -1,18 +1,24 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
+ import select
from socket import socket
+ from _contextlib import _GeneratorContextManager
+ from _core import Abort, RaiseCancelT
+
+ from .. import _core
+
# fmt: off
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index e3f08a49e3..e644de78fb 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -1,18 +1,24 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
+ import select
from socket import socket
+ from _contextlib import _GeneratorContextManager
+ from _core import Abort, RaiseCancelT
+
+ from .. import _core
+
# fmt: off
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 750f85fabb..cc38d9d537 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -17,7 +17,15 @@
from .._core import Abort, RaiseCancelT
-assert not TYPE_CHECKING or sys.platform == "linux"
+
+@attr.s(slots=True, eq=False)
+class EpollWaiters:
+ read_task: None = attr.ib(default=None)
+ write_task: None = attr.ib(default=None)
+ current_flags: int = attr.ib(default=0)
+
+
+assert not TYPE_CHECKING or sys.platform == "linux" or sys.platform == "darwin"
@attr.s(slots=True, eq=False, frozen=True)
@@ -185,13 +193,6 @@ class _EpollStatistics:
# wanted to about how epoll works.
-@attr.s(slots=True, eq=False)
-class EpollWaiters:
- read_task: None = attr.ib(default=None)
- write_task: None = attr.ib(default=None)
- current_flags: int = attr.ib(default=0)
-
-
@attr.s(slots=True, eq=False, hash=False)
class EpollIOManager:
_epoll: select.epoll = attr.ib(factory=select.epoll)
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index 5ce5a609ed..ee25a748f7 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import errno
import select
import sys
-from contextlib import contextmanager
-from typing import TYPE_CHECKING
+from contextlib import _GeneratorContextManager, contextmanager
+from typing import TYPE_CHECKING, Callable, Iterator
import attr
import outcome
@@ -12,33 +14,38 @@
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
- from .._core import Abort, RaiseCancelT
+ from socket import socket
+
+ from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
@attr.s(slots=True, eq=False, frozen=True)
class _KqueueStatistics:
- tasks_waiting = attr.ib()
- monitors = attr.ib()
- backend = attr.ib(default="kqueue")
+ tasks_waiting: int = attr.ib()
+ monitors: int = attr.ib()
+ backend: str = attr.ib(default="kqueue")
@attr.s(slots=True, eq=False)
class KqueueIOManager:
- _kqueue = attr.ib(factory=select.kqueue)
+ _kqueue: select.kqueue = attr.ib(factory=select.kqueue)
# {(ident, filter): Task or UnboundedQueue}
- _registered = attr.ib(factory=dict)
- _force_wakeup = attr.ib(factory=WakeupSocketpair)
- _force_wakeup_fd = attr.ib(default=None)
-
- def __attrs_post_init__(self):
+ # TODO: int, int?
+ _registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = attr.ib(
+ factory=dict
+ )
+ _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+ _force_wakeup_fd: None = attr.ib(default=None)
+
+ def __attrs_post_init__(self) -> None:
force_wakeup_event = select.kevent(
self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD
)
self._kqueue.control([force_wakeup_event], 0)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
- def statistics(self):
+ def statistics(self) -> _KqueueStatistics:
tasks_waiting = 0
monitors = 0
for receiver in self._registered.values():
@@ -48,14 +55,14 @@ def statistics(self):
monitors += 1
return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)
- def close(self):
+ def close(self) -> None:
self._kqueue.close()
self._force_wakeup.close()
- def force_wakeup(self):
+ def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
- def get_events(self, timeout):
+ def get_events(self, timeout: float) -> list[select.kevent]:
# max_events must be > 0 or kqueue gets cranky
# and we generally want this to be strictly larger than the actual
# number of events we get, so that we can tell that we've gotten
@@ -72,7 +79,7 @@ def get_events(self, timeout):
# and loop back to the start
return events
- def process_events(self, events):
+ def process_events(self, events: list[select.kevent]) -> None:
for event in events:
key = (event.ident, event.filter)
if event.ident == self._force_wakeup_fd:
@@ -81,7 +88,7 @@ def process_events(self, events):
receiver = self._registered[key]
if event.flags & select.KQ_EV_ONESHOT:
del self._registered[key]
- if type(receiver) is _core.Task:
+ if isinstance(receiver, _core.Task):
_core.reschedule(receiver, outcome.Value(event))
else:
receiver.put_nowait(event)
@@ -98,18 +105,25 @@ def process_events(self, events):
# be more ergonomic...
@_public
- def current_kqueue(self):
+ def current_kqueue(self) -> select.kqueue:
return self._kqueue
- @contextmanager
@_public
- def monitor_kevent(self, ident, filter):
+ def monitor_kevent(
+ self, ident: int, filter: int
+ ) -> _GeneratorContextManager[_core.UnboundedQueue[select.kevent]]:
+ return self._monitor_kevent(ident, filter)
+
+ @contextmanager
+ def _monitor_kevent(
+ self, ident: int, filter: int
+ ) -> Iterator[_core.UnboundedQueue[select.kevent]]:
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair"
)
- q = _core.UnboundedQueue()
+ q = _core.UnboundedQueue[select.kevent]()
self._registered[key] = q
try:
yield q
@@ -117,7 +131,9 @@ def monitor_kevent(self, ident, filter):
del self._registered[key]
@_public
- async def wait_kevent(self, ident, filter, abort_func):
+ async def wait_kevent(
+ self, ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort]
+ ) -> Abort:
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
@@ -129,11 +145,12 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
r = abort_func(raise_cancel)
if r is _core.Abort.SUCCEEDED:
del self._registered[key]
- return r # type: ignore[no-any-return]
+ return r
- return await _core.wait_task_rescheduled(abort)
+ # wait_task_rescheduled does not have its return type typed
+ return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]
- async def _wait_common(self, fd, filter):
+ async def _wait_common(self, fd: int | socket, filter: int) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
@@ -165,15 +182,15 @@ def abort(_: RaiseCancelT) -> Abort:
await self.wait_kevent(fd, filter, abort)
@_public
- async def wait_readable(self, fd):
+ async def wait_readable(self, fd: int | socket) -> None:
await self._wait_common(fd, select.KQ_FILTER_READ)
@_public
- async def wait_writable(self, fd):
+ async def wait_writable(self, fd: int | socket) -> None:
await self._wait_common(fd, select.KQ_FILTER_WRITE)
@_public
- def notify_closing(self, fd):
+ def notify_closing(self, fd: int | socket) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 95b92cedae..2996cfaaad 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -17,18 +17,24 @@
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
-
-from typing import TYPE_CHECKING
from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, Iterator
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
+ import select
from socket import socket
+ from _contextlib import _GeneratorContextManager
+ from _core import Abort, RaiseCancelT
+
+ from .. import _core
+
# fmt: off
"""
From 12b5af701737cf40f9bc4386edf6fd6285d43fb2 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 25 Jul 2023 15:45:30 +0200
Subject: [PATCH 22/49] mypy can now be run on trio/
---
pyproject.toml | 22 ++++++-
trio/_core/__init__.py | 1 +
trio/_core/_generated_instrumentation.py | 8 ++-
trio/_core/_generated_io_epoll.py | 8 ++-
trio/_core/_generated_io_kqueue.py | 8 ++-
trio/_core/_generated_io_windows.py | 8 ++-
trio/_core/_generated_run.py | 25 ++++----
trio/_core/_io_epoll.py | 2 +-
trio/_core/_io_kqueue.py | 1 +
trio/_core/_run.py | 59 ++++++++++++-------
trio/_core/_tests/test_io.py | 13 ++--
trio/_core/_tests/test_multierror.py | 2 +-
.../apport_excepthook.py | 2 +-
.../ipython_custom_exc.py | 2 +-
.../simple_excepthook.py | 2 +-
trio/_deprecate.py | 2 +-
trio/_subprocess_platform/waitid.py | 4 +-
trio/_tests/check_type_completeness.py | 2 +
trio/_tests/test_contextvars.py | 4 +-
trio/_tests/test_dtls.py | 4 +-
trio/_tests/test_exports.py | 10 +++-
trio/_tests/test_highlevel_serve_listeners.py | 2 +-
trio/_tests/test_subprocess.py | 14 ++++-
trio/_tests/test_threads.py | 12 ++--
trio/_tests/test_tracing.py | 10 ++--
trio/_tests/test_unix_pipes.py | 8 ++-
trio/_tests/verify_types.json | 14 +----
trio/_tools/gen_exports.py | 30 +++++++---
trio/_unix_pipes.py | 3 +
trio/testing/_fake_net.py | 18 +++---
30 files changed, 199 insertions(+), 101 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index f2422f8eea..74d9108644 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,15 +59,31 @@ disallow_incomplete_defs = false
disallow_untyped_defs = false
disallow_any_generics = false
+# TODO: gen_exports add platform checks to specific files
+[[tool.mypy.overrides]]
+module = "trio/_core/_generated_run"
+disable_error_code = ['has-type']
+[[tool.mypy.overrides]]
+module = "trio/_core/_generated_io_kqueue"
+disable_error_code = ['name-defined', 'attr-defined', 'no-any-return']
+[[tool.mypy.overrides]]
+module = "trio/_core/_generated_io_epoll"
+disable_error_code = ['no-any-return']
+
[[tool.mypy.overrides]]
module = [
+"trio/_core/_tests/*",
+"trio/_tests/*",
+
+
+"trio/_windows_pipes",
#"trio/_core/_io_common", # 1, 24
"trio/_core/_windows_cffi", # 2, 324
#"trio/_core/_generated_io_epoll", # 3, 36
"trio/_core/_thread_cache", # 6, 273
-#"trio/_core/_generated_io_kqueue", # 6 (darwin), 60
+"trio/_core/_generated_io_kqueue", # 6 (darwin), 60
"trio/_core/_traps", # 7, 276
-"trio/_core/_generated_run", # 8, 242
+#"trio/_core/_generated_run", # 8, 242
"trio/_core/_generated_io_windows", # 9 (win32), 84
"trio/_core/_asyncgens", # 10, 194
@@ -85,6 +101,7 @@ module = [
"trio/testing/_trio_test", # 2, 29
"trio/testing/_checkpoints", # 3, 62
"trio/testing/_check_streams", # 27, 522
+"trio/testing/_fake_net", # 30
"trio/testing/_memory_streams", # 66, 590
"trio/_highlevel_open_unix_stream", # 2, 49 lines
@@ -107,6 +124,7 @@ module = [
]
disallow_untyped_defs = false
disallow_any_generics = false
+disallow_incomplete_defs = false
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index aa898fffe0..26f6f04e7c 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -27,6 +27,7 @@
TASK_STATUS_IGNORED,
CancelScope,
Nursery,
+ RunStatistics,
Task,
TaskStatus,
add_instrument,
diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py
index c783452bfc..a1d38519d9 100644
--- a/trio/_core/_generated_instrumentation.py
+++ b/trio/_core/_generated_instrumentation.py
@@ -4,7 +4,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Callable, Iterator
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
@@ -12,12 +12,16 @@
if TYPE_CHECKING:
import select
+ import sys
+ from contextvars import Context
from socket import socket
from _contextlib import _GeneratorContextManager
- from _core import Abort, RaiseCancelT
+ from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken
+ from outcome import Outcome
from .. import _core
+ from .._abc import Clock
# fmt: off
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index f35b927737..ea30ddf1fc 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -4,7 +4,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Callable, Iterator
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
@@ -12,12 +12,16 @@
if TYPE_CHECKING:
import select
+ import sys
+ from contextvars import Context
from socket import socket
from _contextlib import _GeneratorContextManager
- from _core import Abort, RaiseCancelT
+ from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken
+ from outcome import Outcome
from .. import _core
+ from .._abc import Clock
# fmt: off
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index f4b14cd500..57fbd6f423 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -4,7 +4,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Callable, Iterator
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
@@ -12,12 +12,16 @@
if TYPE_CHECKING:
import select
+ import sys
+ from contextvars import Context
from socket import socket
from _contextlib import _GeneratorContextManager
- from _core import Abort, RaiseCancelT
+ from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken
+ from outcome import Outcome
from .. import _core
+ from .._abc import Clock
# fmt: off
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 90d7ce0d70..4c92661c50 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -4,7 +4,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Callable, Iterator
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
@@ -12,12 +12,16 @@
if TYPE_CHECKING:
import select
+ import sys
+ from contextvars import Context
from socket import socket
from _contextlib import _GeneratorContextManager
- from _core import Abort, RaiseCancelT
+ from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken
+ from outcome import Outcome
from .. import _core
+ from .._abc import Clock
# fmt: off
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index e644de78fb..62d6eaff68 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -4,7 +4,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Callable, Iterator
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
@@ -12,17 +12,21 @@
if TYPE_CHECKING:
import select
+ import sys
+ from contextvars import Context
from socket import socket
from _contextlib import _GeneratorContextManager
- from _core import Abort, RaiseCancelT
+ from _core import Abort, RaiseCancelT, RunStatistics, SystemClock, Task, TrioToken
+ from outcome import Outcome
from .. import _core
+ from .._abc import Clock
# fmt: off
-def current_statistics():
+def current_statistics() ->RunStatistics:
"""Returns an object containing run-loop-level debugging information.
Currently the following fields are defined:
@@ -52,7 +56,7 @@ def current_statistics():
raise RuntimeError("must be called from async context")
-def current_time():
+def current_time() ->float:
"""Returns the current time according to Trio's internal clock.
Returns:
@@ -69,7 +73,7 @@ def current_time():
raise RuntimeError("must be called from async context")
-def current_clock():
+def current_clock() ->(SystemClock | Clock):
"""Returns the current :class:`~trio.abc.Clock`."""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
@@ -78,7 +82,7 @@ def current_clock():
raise RuntimeError("must be called from async context")
-def current_root_task():
+def current_root_task() ->(Task | None):
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
@@ -91,7 +95,7 @@ def current_root_task():
raise RuntimeError("must be called from async context")
-def reschedule(task, next_send=_NO_SEND):
+def reschedule(task: Task, next_send: Outcome=_NO_SEND) ->None:
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
@@ -116,7 +120,8 @@ def reschedule(task, next_send=_NO_SEND):
raise RuntimeError("must be called from async context")
-def spawn_system_task(async_fn, *args, name=None, context=None):
+def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args:
+ Any, name: (str | None)=None, context: (Context | None)=None) ->Task:
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
@@ -175,7 +180,7 @@ def spawn_system_task(async_fn, *args, name=None, context=None):
raise RuntimeError("must be called from async context")
-def current_trio_token():
+def current_trio_token() ->TrioToken:
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
@@ -187,7 +192,7 @@ def current_trio_token():
raise RuntimeError("must be called from async context")
-async def wait_all_tasks_blocked(cushion=0.0):
+async def wait_all_tasks_blocked(cushion: float=0.0) ->None:
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index cc38d9d537..130403df73 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -25,7 +25,7 @@ class EpollWaiters:
current_flags: int = attr.ib(default=0)
-assert not TYPE_CHECKING or sys.platform == "linux" or sys.platform == "darwin"
+assert not TYPE_CHECKING or sys.platform == "linux"
@attr.s(slots=True, eq=False, frozen=True)
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index ee25a748f7..0014ed9ac7 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -17,6 +17,7 @@
from socket import socket
from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
+
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 4fb1c78048..79ed7ba8e3 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -17,7 +17,7 @@
from math import inf
from time import perf_counter
from types import TracebackType
-from typing import TYPE_CHECKING, Any, NoReturn, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Iterable, NoReturn, TypeVar
import attr
from outcome import Error, Outcome, Value, capture
@@ -49,11 +49,14 @@
from types import FrameType
if TYPE_CHECKING:
- import contextvars
+ from contextvars import Context
# An unfortunate name collision here with trio._util.Final
from typing import Final as FinalT
+ from .._abc import Clock
+ from ._mock_clock import MockClock
+
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
_NO_SEND: FinalT = object()
@@ -120,6 +123,7 @@ def function_with_unique_name_xyzzy() -> NoReturn:
CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames()
+# Why doesn't this inherit from abc.Clock?
@attr.s(frozen=True, slots=True)
class SystemClock:
# Add a large random offset to our clock to ensure that if people
@@ -1171,7 +1175,7 @@ class Task(metaclass=NoPublicConstructor):
coro: Coroutine[Any, Outcome[object], Any] = attr.ib()
_runner: Runner = attr.ib()
name: str = attr.ib()
- context: contextvars.Context = attr.ib()
+ context: Context = attr.ib()
_counter: int = attr.ib(init=False, factory=itertools.count().__next__)
# Invariant:
@@ -1358,7 +1362,7 @@ class RunContext(threading.local):
@attr.s(frozen=True)
-class _RunStatistics:
+class RunStatistics:
tasks_living = attr.ib()
tasks_runnable = attr.ib()
seconds_to_next_deadline = attr.ib()
@@ -1432,7 +1436,7 @@ def in_main_thread():
@attr.s(eq=False, hash=False, slots=True)
class Runner:
- clock = attr.ib()
+ clock: SystemClock | Clock | MockClock = attr.ib()
instruments: Instruments = attr.ib()
io_manager: TheIOManager = attr.ib()
ki_manager = attr.ib()
@@ -1442,18 +1446,18 @@ class Runner:
_locals = attr.ib(factory=dict)
runq: deque[Task] = attr.ib(factory=deque)
- tasks = attr.ib(factory=set)
+ tasks: set[Task] = attr.ib(factory=set)
deadlines = attr.ib(factory=Deadlines)
- init_task = attr.ib(default=None)
+ init_task: Task | None = attr.ib(default=None)
system_nursery = attr.ib(default=None)
system_context = attr.ib(default=None)
main_task = attr.ib(default=None)
main_task_outcome = attr.ib(default=None)
entry_queue = attr.ib(factory=EntryQueue)
- trio_token = attr.ib(default=None)
+ trio_token: TrioToken | None = attr.ib(default=None)
asyncgens = attr.ib(factory=AsyncGenerators)
# If everything goes idle for this long, we call clock._autojump()
@@ -1479,7 +1483,7 @@ def close(self):
self.ki_manager.close()
@_public
- def current_statistics(self):
+ def current_statistics(self) -> RunStatistics:
"""Returns an object containing run-loop-level debugging information.
Currently the following fields are defined:
@@ -1503,7 +1507,7 @@ def current_statistics(self):
"""
seconds_to_next_deadline = self.deadlines.next_deadline() - self.current_time()
- return _RunStatistics(
+ return RunStatistics(
tasks_living=len(self.tasks),
tasks_runnable=len(self.runq),
seconds_to_next_deadline=seconds_to_next_deadline,
@@ -1512,7 +1516,7 @@ def current_statistics(self):
)
@_public
- def current_time(self):
+ def current_time(self) -> float:
"""Returns the current time according to Trio's internal clock.
Returns:
@@ -1524,13 +1528,15 @@ def current_time(self):
"""
return self.clock.current_time()
+ # TODO: abc.Clock or SystemClock? (the latter which doesn't inherit
+ # from abc.Clock)
@_public
- def current_clock(self):
+ def current_clock(self) -> SystemClock | Clock:
"""Returns the current :class:`~trio.abc.Clock`."""
return self.clock
@_public
- def current_root_task(self):
+ def current_root_task(self) -> Task | None:
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
@@ -1543,7 +1549,7 @@ def current_root_task(self):
################
@_public
- def reschedule(self, task, next_send=_NO_SEND):
+ def reschedule(self, task: Task, next_send: Outcome = _NO_SEND) -> None:
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
@@ -1577,8 +1583,15 @@ def reschedule(self, task, next_send=_NO_SEND):
self.instruments.call("task_scheduled", task)
def spawn_impl(
- self, async_fn, args, nursery, name, *, system_task=False, context=None
- ):
+ self,
+ async_fn: Callable[..., Awaitable[object]],
+ args: Iterable[Any],
+ nursery: Nursery | None,
+ name: str | functools.partial | Callable[..., Awaitable[object]] | None,
+ *,
+ system_task: bool = False,
+ context: Context | None = None,
+ ) -> Task:
######
# Make sure the nursery is in working order
######
@@ -1696,7 +1709,13 @@ def task_exited(self, task, outcome):
################
@_public
- def spawn_system_task(self, async_fn, *args, name=None, context=None):
+ def spawn_system_task(
+ self,
+ async_fn: Callable[..., Awaitable[object]],
+ *args: Any,
+ name: str | None = None,
+ context: Context | None = None,
+ ) -> Task:
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
@@ -1795,7 +1814,7 @@ async def init(self, async_fn, args):
################
@_public
- def current_trio_token(self):
+ def current_trio_token(self) -> TrioToken:
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
@@ -1844,7 +1863,7 @@ def _deliver_ki_cb(self):
waiting_for_idle = attr.ib(factory=SortedDict)
@_public
- async def wait_all_tasks_blocked(self, cushion=0.0):
+ async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None:
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
@@ -2311,7 +2330,7 @@ def unrolled_run(
break
else:
assert idle_primed is IdlePrimedTypes.AUTOJUMP_CLOCK
- runner.clock._autojump()
+ runner.clock._autojump() # type: ignore[union-attr]
# Process all runnable tasks, but only the ones that are already
# runnable now. Anything that becomes runnable during this cycle
diff --git a/trio/_core/_tests/test_io.py b/trio/_core/_tests/test_io.py
index 21a954941c..2205c83976 100644
--- a/trio/_core/_tests/test_io.py
+++ b/trio/_core/_tests/test_io.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import random
import socket as stdlib_socket
from contextlib import suppress
+from typing import Callable
import pytest
@@ -47,15 +50,15 @@ def fileno_wrapper(fileobj):
return fileno_wrapper
-wait_readable_options = [trio.lowlevel.wait_readable]
-wait_writable_options = [trio.lowlevel.wait_writable]
-notify_closing_options = [trio.lowlevel.notify_closing]
+wait_readable_options: list[Callable] = [trio.lowlevel.wait_readable]
+wait_writable_options: list[Callable] = [trio.lowlevel.wait_writable]
+notify_closing_options: list[Callable] = [trio.lowlevel.notify_closing]
-for options_list in [
+for options_list in (
wait_readable_options,
wait_writable_options,
notify_closing_options,
-]:
+):
options_list += [using_fileno(f) for f in options_list]
# Decorators that feed in different settings for wait_readable / wait_writable
diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py
index 7a8bd2f9a8..52e5e39d1b 100644
--- a/trio/_core/_tests/test_multierror.py
+++ b/trio/_core/_tests/test_multierror.py
@@ -555,7 +555,7 @@ def test_apport_excepthook_monkeypatch_interaction():
@pytest.mark.parametrize("protocol", range(0, pickle.HIGHEST_PROTOCOL + 1))
-def test_pickle_multierror(protocol) -> None:
+def test_pickle_multierror(protocol: int) -> None:
# use trio.MultiError to make sure that pickle works through the deprecation layer
import trio
diff --git a/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py
index 3e1d23ca8e..e51b8cdca0 100644
--- a/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py
+++ b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py
@@ -12,4 +12,4 @@
import trio
-raise trio.MultiError([KeyError("key_error"), ValueError("value_error")])
+raise trio.MultiError([KeyError("key_error"), ValueError("value_error")]) # type: ignore[attr-defined]
diff --git a/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py
index 80e42b6a2c..c8086d3a0e 100644
--- a/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py
+++ b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py
@@ -33,4 +33,4 @@ def custom_exc_hook(etype, value, tb, tb_offset=None):
# The custom excepthook should run, because Trio was polite and didn't
# override it
-raise trio.MultiError([ValueError(), KeyError()])
+raise trio.MultiError([ValueError(), KeyError()]) # type: ignore[attr-defined]
diff --git a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py
index 94004525db..c2297df400 100644
--- a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py
+++ b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py
@@ -18,4 +18,4 @@ def exc2_fn():
# This should be printed nicely, because Trio overrode sys.excepthook
-raise trio.MultiError([exc1_fn(), exc2_fn()])
+raise trio.MultiError([exc1_fn(), exc2_fn()]) # type: ignore[attr-defined]
diff --git a/trio/_deprecate.py b/trio/_deprecate.py
index aeebe80722..7deecb7042 100644
--- a/trio/_deprecate.py
+++ b/trio/_deprecate.py
@@ -59,7 +59,7 @@ def warn_deprecated(thing, version, *, issue, instead, stacklevel=2):
# @deprecated("0.2.0", issue=..., instead=...)
# def ...
def deprecated(
- version: str, *, thing: str | None = None, issue: int, instead: str
+ version: str, *, thing: str | None = None, issue: int | None, instead: object
) -> Callable[[T], T]:
def do_wrap(fn):
nonlocal thing
diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py
index ad69017219..f90a3f5b65 100644
--- a/trio/_subprocess_platform/waitid.py
+++ b/trio/_subprocess_platform/waitid.py
@@ -2,13 +2,15 @@
import math
import os
import sys
+from typing import TYPE_CHECKING
from .. import _core, _subprocess
from .._sync import CapacityLimiter, Event
from .._threads import to_thread_run_sync
try:
- from os import waitid
+ if not TYPE_CHECKING or sys.platform == "unix":
+ from os import waitid
def sync_wait_reapable(pid):
waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT)
diff --git a/trio/_tests/check_type_completeness.py b/trio/_tests/check_type_completeness.py
index 7a65a4249e..abaabcf785 100755
--- a/trio/_tests/check_type_completeness.py
+++ b/trio/_tests/check_type_completeness.py
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
+from __future__ import annotations
+
# this file is not run as part of the tests, instead it's run standalone from check.sh
import argparse
import json
diff --git a/trio/_tests/test_contextvars.py b/trio/_tests/test_contextvars.py
index 63853f5171..0ff13435cf 100644
--- a/trio/_tests/test_contextvars.py
+++ b/trio/_tests/test_contextvars.py
@@ -2,7 +2,9 @@
from .. import _core
-trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar")
+trio_testing_contextvar: contextvars.ContextVar = contextvars.ContextVar(
+ "trio_testing_contextvar"
+)
async def test_contextvars_default():
diff --git a/trio/_tests/test_dtls.py b/trio/_tests/test_dtls.py
index b8c32c6d5f..8cb06ccb3d 100644
--- a/trio/_tests/test_dtls.py
+++ b/trio/_tests/test_dtls.py
@@ -17,10 +17,10 @@
ca = trustme.CA()
server_cert = ca.issue_cert("example.com")
-server_ctx = SSL.Context(SSL.DTLS_METHOD)
+server_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined]
server_cert.configure_cert(server_ctx)
-client_ctx = SSL.Context(SSL.DTLS_METHOD)
+client_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined]
ca.configure_trust(client_ctx)
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index 20635b0022..6e65a39316 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import enum
import functools
import importlib
@@ -175,7 +177,7 @@ def no_underscores(symbols):
if modname == "trio":
static_names.add("testing")
- # these are hidden behind `if sys.plaftorm != "win32" or not TYPE_CHECKING`
+ # these are hidden behind `if sys.platform != "win32" or not TYPE_CHECKING`
# so presumably pyright is parsing that if statement, in which case we don't
# care about them being missing.
if modname == "trio.socket" and sys.platform == "win32":
@@ -226,7 +228,9 @@ def no_underscores(symbols):
)
@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["jedi", "mypy"])
-def test_static_tool_sees_class_members(tool, module_name, tmpdir) -> None:
+def test_static_tool_sees_class_members(
+ tool: str, module_name: str, tmpdir: Path
+) -> None:
module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
# ignore hidden, but not dunder, symbols
@@ -483,7 +487,7 @@ def test_classes_are_final():
continue
# These are classes that are conceptually abstract, but
# inspect.isabstract returns False for boring reasons.
- if class_ in {trio.abc.Instrument, trio.socket.SocketType}:
+ if class_ in (trio.abc.Instrument, trio.socket.SocketType):
continue
# Enums have their own metaclass, so we can't use our metaclasses.
# And I don't think there's a lot of risk from people subclassing
diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py
index 4385263899..67e2eecbc8 100644
--- a/trio/_tests/test_highlevel_serve_listeners.py
+++ b/trio/_tests/test_highlevel_serve_listeners.py
@@ -12,7 +12,7 @@
class MemoryListener(trio.abc.Listener):
closed = attr.ib(default=False)
accepted_streams = attr.ib(factory=list)
- queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1)))
+ queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel[object](1)))
accept_hook = attr.ib(default=None)
async def connect(self):
diff --git a/trio/_tests/test_subprocess.py b/trio/_tests/test_subprocess.py
index 4dfaef4c7f..7986dfd71e 100644
--- a/trio/_tests/test_subprocess.py
+++ b/trio/_tests/test_subprocess.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import random
import signal
@@ -6,6 +8,7 @@
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path as SyncPath
+from typing import TYPE_CHECKING
import pytest
@@ -24,8 +27,15 @@
from ..lowlevel import open_process
from ..testing import assert_no_checkpoints, wait_all_tasks_blocked
+if TYPE_CHECKING:
+ ...
+ from signal import Signals
+
posix = os.name == "posix"
-if posix:
+SIGKILL: Signals | None
+SIGTERM: Signals | None
+SIGUSR1: Signals | None
+if (not TYPE_CHECKING and posix) or sys.platform != "win32":
from signal import SIGKILL, SIGTERM, SIGUSR1
else:
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
@@ -574,7 +584,7 @@ async def test_for_leaking_fds():
async def test_subprocess_pidfd_unnotified():
noticed_exit = None
- async def wait_and_tell(proc) -> None:
+ async def wait_and_tell(proc: Process) -> None:
nonlocal noticed_exit
noticed_exit = Event()
await proc.wait()
diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py
index 21eb7b12e8..9149f43037 100644
--- a/trio/_tests/test_threads.py
+++ b/trio/_tests/test_threads.py
@@ -170,7 +170,7 @@ async def main():
async def test_named_thread():
ending = " from trio._tests.test_threads.test_named_thread"
- def inner(name="inner" + ending) -> threading.Thread:
+ def inner(name: str = "inner" + ending) -> threading.Thread:
assert threading.current_thread().name == name
return threading.current_thread()
@@ -185,7 +185,7 @@ def f(name: str) -> Callable[[None], threading.Thread]:
await to_thread_run_sync(f("None" + ending))
# test that you can set a custom name, and that it's reset afterwards
- async def test_thread_name(name: str):
+ async def test_thread_name(name: str) -> None:
thread = await to_thread_run_sync(f(name), thread_name=name)
assert re.match("Trio thread [0-9]*", thread.name)
@@ -235,7 +235,7 @@ def _get_thread_name(ident: Optional[int] = None) -> Optional[str]:
# and most mac machines. So unless the platform is linux it will just skip
# in case it fails to fetch the os thread name.
async def test_named_thread_os():
- def inner(name) -> threading.Thread:
+ def inner(name: str) -> threading.Thread:
os_thread_name = _get_thread_name()
if os_thread_name is None and sys.platform != "linux":
pytest.skip(f"no pthread OS support on {sys.platform}")
@@ -253,7 +253,7 @@ def f(name: str) -> Callable[[None], threading.Thread]:
await to_thread_run_sync(f(default), thread_name=None)
# test that you can set a custom name, and that it's reset afterwards
- async def test_thread_name(name: str, expected: Optional[str] = None):
+ async def test_thread_name(name: str, expected: Optional[str] = None) -> None:
if expected is None:
expected = name
thread = await to_thread_run_sync(f(expected), thread_name=name)
@@ -584,7 +584,9 @@ async def async_fn(): # pragma: no cover
await to_thread_run_sync(async_fn)
-trio_test_contextvar = contextvars.ContextVar("trio_test_contextvar")
+trio_test_contextvar: contextvars.ContextVar = contextvars.ContextVar(
+ "trio_test_contextvar"
+)
async def test_trio_to_thread_run_sync_contextvars():
diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py
index 07d1ff7609..e5110eaff3 100644
--- a/trio/_tests/test_tracing.py
+++ b/trio/_tests/test_tracing.py
@@ -1,26 +1,26 @@
import trio
-async def coro1(event: trio.Event):
+async def coro1(event: trio.Event) -> None:
event.set()
await trio.sleep_forever()
-async def coro2(event: trio.Event):
+async def coro2(event: trio.Event) -> None:
await coro1(event)
-async def coro3(event: trio.Event):
+async def coro3(event: trio.Event) -> None:
await coro2(event)
-async def coro2_async_gen(event: trio.Event):
+async def coro2_async_gen(event):
yield await trio.lowlevel.checkpoint()
yield await coro1(event)
yield await trio.lowlevel.checkpoint()
-async def coro3_async_gen(event: trio.Event):
+async def coro3_async_gen(event: trio.Event) -> None:
async for x in coro2_async_gen(event):
pass
diff --git a/trio/_tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py
index acee75aafb..0b0d2ceb23 100644
--- a/trio/_tests/test_unix_pipes.py
+++ b/trio/_tests/test_unix_pipes.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
import errno
import os
import select
import sys
+from typing import TYPE_CHECKING
import pytest
@@ -11,6 +14,9 @@
posix = os.name == "posix"
pytestmark = pytest.mark.skipif(not posix, reason="posix only")
+
+assert not TYPE_CHECKING or sys.platform == "unix"
+
if posix:
from .._unix_pipes import FdStream
else:
@@ -19,7 +25,7 @@
# Have to use quoted types so import doesn't crash on windows
-async def make_pipe() -> "Tuple[FdStream, FdStream]":
+async def make_pipe() -> "tuple[FdStream, FdStream]":
"""Makes a new pair of pipes."""
(r, w) = os.pipe()
return FdStream(w), FdStream(r)
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index c454c4d3b9..1824df59e1 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.92,
+ "completenessScore": 0.9328,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 575,
- "withUnknownType": 50
+ "withKnownType": 583,
+ "withUnknownType": 42
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 0,
@@ -99,20 +99,13 @@
"trio._subprocess.Process.send_signal",
"trio._subprocess.Process.terminate",
"trio._subprocess.Process.wait",
- "trio.current_time",
"trio.from_thread.run",
"trio.from_thread.run_sync",
"trio.lowlevel.cancel_shielded_checkpoint",
- "trio.lowlevel.current_clock",
- "trio.lowlevel.current_root_task",
- "trio.lowlevel.current_statistics",
- "trio.lowlevel.current_trio_token",
"trio.lowlevel.currently_ki_protected",
"trio.lowlevel.notify_closing",
"trio.lowlevel.permanently_detach_coroutine_object",
"trio.lowlevel.reattach_detached_coroutine_object",
- "trio.lowlevel.reschedule",
- "trio.lowlevel.spawn_system_task",
"trio.lowlevel.start_guest_run",
"trio.lowlevel.temporarily_detach_coroutine_object",
"trio.lowlevel.wait_readable",
@@ -158,7 +151,6 @@
"trio.testing.memory_stream_pump",
"trio.testing.open_stream_to_socket_listener",
"trio.testing.trio_test",
- "trio.testing.wait_all_tasks_blocked",
"trio.to_thread.current_default_thread_limiter",
"trio.wrap_file"
]
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 2996cfaaad..0e2dfb7b42 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -3,15 +3,21 @@
Code generation script for class methods
to be exported as public API
"""
+from __future__ import annotations
+
import argparse
import ast
import os
import sys
from pathlib import Path
from textwrap import indent
+from typing import TYPE_CHECKING, Iterator, Union
import astor
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias, TypeGuard
+
PREFIX = "_generated"
HEADER = """# ***********************************************************
@@ -20,18 +26,22 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Callable, Iterator
+from typing import TYPE_CHECKING, Callable, Iterator, Awaitable, Any
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
+ import sys
import select
from socket import socket
from _contextlib import _GeneratorContextManager
- from _core import Abort, RaiseCancelT
+ from contextvars import Context
+ from _core import Abort, RaiseCancelT, RunStatistics, Task, SystemClock, TrioToken
+ from outcome import Outcome
+ from .._abc import Clock
from .. import _core
@@ -48,8 +58,10 @@
raise RuntimeError("must be called from async context")
"""
+AstFun: TypeAlias = Union[ast.FunctionDef, ast.AsyncFunctionDef]
+
-def is_function(node):
+def is_function(node: ast.AST) -> TypeGuard[AstFun]:
"""Check if the AST node is either a function
or an async function
"""
@@ -58,7 +70,7 @@ def is_function(node):
return False
-def is_public(node):
+def is_public(node: ast.AST) -> TypeGuard[AstFun]:
"""Check if the AST node has a _public decorator"""
if not is_function(node):
return False
@@ -68,7 +80,7 @@ def is_public(node):
return False
-def get_public_methods(tree):
+def get_public_methods(tree: ast.AST) -> Iterator[AstFun]:
"""Return a list of methods marked as public.
The function walks the given tree and extracts
all objects that are functions which are marked
@@ -79,7 +91,7 @@ def get_public_methods(tree):
yield node
-def create_passthrough_args(funcdef):
+def create_passthrough_args(funcdef: AstFun) -> str:
"""Given a function definition, create a string that represents taking all
the arguments from the function, and passing them through to another
invocation of the same function.
@@ -143,7 +155,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
return "\n\n".join(generated)
-def matches_disk_files(new_files):
+def matches_disk_files(new_files: dict[str, str]) -> bool:
for new_path, new_source in new_files.items():
if not os.path.exists(new_path):
return False
@@ -154,7 +166,7 @@ def matches_disk_files(new_files):
return True
-def process(sources_and_lookups, *, do_test):
+def process(sources_and_lookups: list[tuple[Path, str]], *, do_test: bool) -> None:
new_files = {}
for source_path, lookup_path in sources_and_lookups:
print("Scanning:", source_path)
@@ -177,7 +189,7 @@ def process(sources_and_lookups, *, do_test):
# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
-def main(): # pragma: no cover
+def main() -> None: # pragma: no cover
parser = argparse.ArgumentParser(
description="Generate python code for public api wrappers"
)
diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py
index 716550790e..1a389e12dd 100644
--- a/trio/_unix_pipes.py
+++ b/trio/_unix_pipes.py
@@ -2,6 +2,7 @@
import errno
import os
+import sys
from typing import TYPE_CHECKING
import trio
@@ -12,6 +13,8 @@
if TYPE_CHECKING:
from typing import Final as FinalType
+assert not TYPE_CHECKING or sys.platform != "win32"
+
if os.name != "posix":
# We raise an error here rather than gating the import in lowlevel.py
# in order to keep jedi static analysis happy.
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index fdb4d45102..9befedf21b 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -19,7 +19,7 @@
from trio._util import Final, NoPublicConstructor
if TYPE_CHECKING:
- import socket
+ from socket import AddressFamily, SocketKind
from types import TracebackType
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@@ -105,7 +105,7 @@ def reply(self, payload):
class FakeSocketFactory(trio.abc.SocketFactory):
fake_net: "FakeNet"
- def socket(self, family: int, type: int, proto: int) -> "FakeSocket":
+ def socket(self, family: int, type: int, proto: int) -> FakeSocket: # type: ignore[override]
return FakeSocket._create(self.fake_net, family, type, proto)
@@ -123,8 +123,8 @@ async def getaddrinfo(
flags: int = 0,
) -> list[
tuple[
- socket.AddressFamily,
- socket.SocketKind,
+ AddressFamily,
+ SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
@@ -139,13 +139,13 @@ async def getnameinfo(
class FakeNet(metaclass=Final):
- def __init__(self):
+ def __init__(self) -> None:
# When we need to pick an arbitrary unique ip address/port, use these:
self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts()
- self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts()
+ self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() # type: ignore[assignment]
self._auto_port_iter = iter(range(50000, 65535))
- self._bound: Dict[UDPBinding, FakeSocket] = {}
+ self._bound: dict[UDPBinding, FakeSocket] = {}
self.route_packet = None
@@ -193,7 +193,7 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
self._closed = False
- self._packet_sender, self._packet_receiver = trio.open_memory_channel(
+ self._packet_sender, self._packet_receiver = trio.open_memory_channel[object](
float("inf")
)
@@ -223,7 +223,7 @@ async def _resolve_address_nocp(self, address, *, local):
local=local,
)
- def _deliver_packet(self, packet: UDPPacket):
+ def _deliver_packet(self, packet: UDPPacket) -> None:
try:
self._packet_sender.send_nowait(packet)
except trio.BrokenResourceError:
From 2490873beee68557babcc87f10ae7a3232dd6c1b Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 25 Jul 2023 16:13:59 +0200
Subject: [PATCH 23/49] started playing around with windows but mostly gave up.
Some types added though
---
pyproject.toml | 24 ++++---
trio/_core/_generated_instrumentation.py | 1 +
trio/_core/_generated_io_epoll.py | 1 +
trio/_core/_generated_io_kqueue.py | 1 +
trio/_core/_generated_io_windows.py | 14 ++--
trio/_core/_generated_run.py | 1 +
trio/_core/_io_epoll.py | 1 +
trio/_core/_io_windows.py | 89 ++++++++++++++----------
trio/_core/_windows_cffi.py | 3 +-
trio/_tools/gen_exports.py | 1 +
10 files changed, 83 insertions(+), 53 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 74d9108644..51a1bd8178 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,11 @@ disallow_any_generics = false
module = "trio/_core/_generated_run"
disable_error_code = ['has-type']
[[tool.mypy.overrides]]
-module = "trio/_core/_generated_io_kqueue"
+module = [
+ "trio/_core/_generated_io_kqueue",
+ "trio/_core/_generated_io_windows",
+ ]
+
disable_error_code = ['name-defined', 'attr-defined', 'no-any-return']
[[tool.mypy.overrides]]
module = "trio/_core/_generated_io_epoll"
@@ -75,26 +79,28 @@ module = [
"trio/_core/_tests/*",
"trio/_tests/*",
-
+# windows
"trio/_windows_pipes",
-#"trio/_core/_io_common", # 1, 24
"trio/_core/_windows_cffi", # 2, 324
+"trio/_core/_generated_io_windows", # 9 (win32), 84
+"trio/_core/_io_windows", # 47 (win32), 867
+
+#"trio/_core/_io_common", # 1, 24
#"trio/_core/_generated_io_epoll", # 3, 36
+#"trio/_core/_generated_io_kqueue", # 6 (darwin), 60
+#"trio/_core/_generated_run", # 8, 242
+#"trio/_core/_io_kqueue", # 16, 198
+#"trio/_core/_io_epoll", # 21, 323
+
"trio/_core/_thread_cache", # 6, 273
-"trio/_core/_generated_io_kqueue", # 6 (darwin), 60
"trio/_core/_traps", # 7, 276
-#"trio/_core/_generated_run", # 8, 242
-"trio/_core/_generated_io_windows", # 9 (win32), 84
"trio/_core/_asyncgens", # 10, 194
"trio/_core/_wakeup_socketpair", # 12
"trio/_core/_ki", # 14, 210
"trio/_core/_entry_queue", # 16, 195
-#"trio/_core/_io_kqueue", # 16, 198
"trio/_core/_multierror", # 19, 469
-#"trio/_core/_io_epoll", # 21, 323
-"trio/_core/_io_windows", # 47 (win32), 867
"trio/testing/_network", # 1, 34
diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py
index a1d38519d9..bb19c2fbe5 100644
--- a/trio/_core/_generated_instrumentation.py
+++ b/trio/_core/_generated_instrumentation.py
@@ -22,6 +22,7 @@
from .. import _core
from .._abc import Clock
+ from ._unbounded_queue import UnboundedQueue
# fmt: off
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index ea30ddf1fc..17d712726b 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -22,6 +22,7 @@
from .. import _core
from .._abc import Clock
+ from ._unbounded_queue import UnboundedQueue
# fmt: off
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index 57fbd6f423..bfba052792 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -22,6 +22,7 @@
from .. import _core
from .._abc import Clock
+ from ._unbounded_queue import UnboundedQueue
# fmt: off
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 4c92661c50..7072f12aaf 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -22,11 +22,12 @@
from .. import _core
from .._abc import Clock
+ from ._unbounded_queue import UnboundedQueue
# fmt: off
-async def wait_readable(sock):
+async def wait_readable(sock) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock)
@@ -34,7 +35,7 @@ async def wait_readable(sock):
raise RuntimeError("must be called from async context")
-async def wait_writable(sock):
+async def wait_writable(sock) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock)
@@ -42,7 +43,7 @@ async def wait_writable(sock):
raise RuntimeError("must be called from async context")
-def notify_closing(handle):
+def notify_closing(handle) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle)
@@ -50,7 +51,7 @@ def notify_closing(handle):
raise RuntimeError("must be called from async context")
-def register_with_iocp(handle):
+def register_with_iocp(handle) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle)
@@ -82,7 +83,7 @@ async def readinto_overlapped(handle, buffer, file_offset=0):
raise RuntimeError("must be called from async context")
-def current_iocp():
+def current_iocp() ->int:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp()
@@ -90,7 +91,8 @@ def current_iocp():
raise RuntimeError("must be called from async context")
-def monitor_completion_key():
+def monitor_completion_key() ->_GeneratorContextManager[tuple[int,
+ UnboundedQueue[object]]]:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index 62d6eaff68..999ce9d1e5 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -22,6 +22,7 @@
from .. import _core
from .._abc import Clock
+ from ._unbounded_queue import UnboundedQueue
# fmt: off
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 130403df73..dfa979ac82 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -20,6 +20,7 @@
@attr.s(slots=True, eq=False)
class EpollWaiters:
+ # TODO: why is nobody complaining about this?
read_task: None = attr.ib(default=None)
write_task: None = attr.ib(default=None)
current_flags: int = attr.ib(default=0)
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py
index 4084f72b6e..66bf6a9349 100644
--- a/trio/_core/_io_windows.py
+++ b/trio/_core/_io_windows.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
import enum
import itertools
import socket
import sys
from contextlib import contextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterator, Literal
import attr
from outcome import Value
@@ -29,6 +31,12 @@
assert not TYPE_CHECKING or sys.platform == "win32"
+if TYPE_CHECKING:
+    from contextlib import _GeneratorContextManager
+
+ from ._traps import Abort, RaiseCancelT
+    from ._unbounded_queue import UnboundedQueue
+
# There's a lot to be said about the overall design of a Windows event
# loop. See
#
@@ -179,13 +187,15 @@ class CKeys(enum.IntEnum):
USER_DEFINED = 4 # and above
-def _check(success):
+def _check(success: bool) -> Literal[True]:
if not success:
raise_winerror()
return success
-def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
+def _get_underlying_socket(
+ sock: socket.socket | int, *, which=WSAIoctls.SIO_BASE_HANDLE
+):
if hasattr(sock, "fileno"):
sock = sock.fileno()
base_ptr = ffi.new("HANDLE *")
@@ -330,9 +340,9 @@ def _afd_helper_handle():
# operation and start a new one.
@attr.s(slots=True, eq=False)
class AFDWaiters:
- read_task = attr.ib(default=None)
- write_task = attr.ib(default=None)
- current_op = attr.ib(default=None)
+ read_task: None = attr.ib(default=None)
+ write_task: None = attr.ib(default=None)
+ current_op: None = attr.ib(default=None)
# We also need to bundle up all the info for a single op into a standalone
@@ -340,10 +350,10 @@ class AFDWaiters:
# finishes, even if we're throwing it away.
@attr.s(slots=True, eq=False, frozen=True)
class AFDPollOp:
- lpOverlapped = attr.ib()
- poll_info = attr.ib()
- waiters = attr.ib()
- afd_group = attr.ib()
+ lpOverlapped: None = attr.ib()
+ poll_info: None = attr.ib()
+ waiters: None = attr.ib()
+ afd_group: None = attr.ib()
# The Windows kernel has a weird issue when using AFD handles. If you have N
@@ -359,17 +369,17 @@ class AFDPollOp:
@attr.s(slots=True, eq=False)
class AFDGroup:
- size = attr.ib()
- handle = attr.ib()
+ size: int = attr.ib()
+ handle: None = attr.ib()
@attr.s(slots=True, eq=False, frozen=True)
class _WindowsStatistics:
- tasks_waiting_read = attr.ib()
- tasks_waiting_write = attr.ib()
- tasks_waiting_overlapped = attr.ib()
- completion_key_monitors = attr.ib()
- backend = attr.ib(default="windows")
+ tasks_waiting_read: int = attr.ib()
+ tasks_waiting_write: int = attr.ib()
+ tasks_waiting_overlapped: int = attr.ib()
+ completion_key_monitors: int = attr.ib()
+ backend: str = attr.ib(default="windows")
# Maximum number of events to dequeue from the completion port on each pass
@@ -381,8 +391,8 @@ class _WindowsStatistics:
@attr.s(frozen=True)
class CompletionKeyEventInfo:
- lpOverlapped = attr.ib()
- dwNumberOfBytesTransferred = attr.ib()
+ lpOverlapped: None = attr.ib()
+ dwNumberOfBytesTransferred: int = attr.ib()
class WindowsIOManager:
@@ -449,7 +459,7 @@ def __init__(self):
"netsh winsock show catalog"
)
- def close(self):
+ def close(self) -> None:
try:
if self._iocp is not None:
iocp = self._iocp
@@ -460,10 +470,10 @@ def close(self):
afd_handle = self._all_afd_handles.pop()
_check(kernel32.CloseHandle(afd_handle))
- def __del__(self):
+ def __del__(self) -> None:
self.close()
- def statistics(self):
+ def statistics(self) -> _WindowsStatistics:
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._afd_waiters.values():
@@ -478,7 +488,7 @@ def statistics(self):
completion_key_monitors=len(self._completion_key_queues),
)
- def force_wakeup(self):
+ def force_wakeup(self) -> None:
_check(
kernel32.PostQueuedCompletionStatus(
self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL
@@ -502,7 +512,7 @@ def get_events(self, timeout):
return 0
return received[0]
- def process_events(self, received):
+ def process_events(self, received: int) -> None:
for i in range(received):
entry = self._events[i]
if entry.lpCompletionKey == CKeys.AFD_POLL:
@@ -582,7 +592,7 @@ def process_events(self, received):
)
queue.put_nowait(info)
- def _register_with_iocp(self, handle, completion_key):
+ def _register_with_iocp(self, handle, completion_key) -> None:
handle = _handle(handle)
_check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0))
# Supposedly this makes things slightly faster, by disabling the
@@ -599,7 +609,7 @@ def _register_with_iocp(self, handle, completion_key):
# AFD stuff
################################################################
- def _refresh_afd(self, base_handle):
+ def _refresh_afd(self, base_handle) -> None:
waiters = self._afd_waiters[base_handle]
if waiters.current_op is not None:
afd_group = waiters.current_op.afd_group
@@ -675,7 +685,7 @@ def _refresh_afd(self, base_handle):
if afd_group.size >= MAX_AFD_GROUP_SIZE:
self._vacant_afd_groups.remove(afd_group)
- async def _afd_poll(self, sock, mode):
+ async def _afd_poll(self, sock, mode) -> None:
base_handle = _get_base_socket(sock)
waiters = self._afd_waiters.get(base_handle)
if waiters is None:
@@ -688,7 +698,7 @@ async def _afd_poll(self, sock, mode):
# we let it escape.
self._refresh_afd(base_handle)
- def abort_fn(_):
+ def abort_fn(_: RaiseCancelT) -> Abort:
setattr(waiters, mode, None)
self._refresh_afd(base_handle)
return _core.Abort.SUCCEEDED
@@ -696,15 +706,15 @@ def abort_fn(_):
await _core.wait_task_rescheduled(abort_fn)
@_public
- async def wait_readable(self, sock):
+ async def wait_readable(self, sock) -> None:
await self._afd_poll(sock, "read_task")
@_public
- async def wait_writable(self, sock):
+ async def wait_writable(self, sock) -> None:
await self._afd_poll(sock, "write_task")
@_public
- def notify_closing(self, handle):
+ def notify_closing(self, handle) -> None:
handle = _get_base_socket(handle)
waiters = self._afd_waiters.get(handle)
if waiters is not None:
@@ -716,7 +726,7 @@ def notify_closing(self, handle):
################################################################
@_public
- def register_with_iocp(self, handle):
+ def register_with_iocp(self, handle) -> None:
self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED)
@_public
@@ -732,7 +742,7 @@ async def wait_overlapped(self, handle, lpOverlapped):
self._overlapped_waiters[lpOverlapped] = task
raise_cancel = None
- def abort(raise_cancel_):
+ def abort(raise_cancel_: RaiseCancelT) -> Abort:
nonlocal raise_cancel
raise_cancel = raise_cancel_
try:
@@ -852,14 +862,19 @@ def submit_read(lpOverlapped):
################################################################
@_public
- def current_iocp(self):
+ def current_iocp(self) -> int:
return int(ffi.cast("uintptr_t", self._iocp))
- @contextmanager
@_public
- def monitor_completion_key(self):
+ def monitor_completion_key(
+ self,
+ ) -> _GeneratorContextManager[tuple[int, UnboundedQueue[object]]]:
+ return self._monitor_completion_key()
+
+ @contextmanager
+ def _monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]:
key = next(self._completion_key_counter)
- queue = _core.UnboundedQueue()
+ queue = _core.UnboundedQueue[object]()
self._completion_key_queues[key] = queue
try:
yield (key, queue)
diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py
index 639e75b50e..50d598c2be 100644
--- a/trio/_core/_windows_cffi.py
+++ b/trio/_core/_windows_cffi.py
@@ -1,5 +1,6 @@
import enum
import re
+from typing import NoReturn
import cffi
@@ -315,7 +316,7 @@ def _handle(obj):
return obj
-def raise_winerror(winerror=None, *, filename=None, filename2=None):
+def raise_winerror(winerror=None, *, filename=None, filename2=None) -> NoReturn:
if winerror is None:
winerror, msg = ffi.getwinerror()
else:
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 0e2dfb7b42..8dd2d4fae3 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -37,6 +37,7 @@
import select
from socket import socket
+ from ._unbounded_queue import UnboundedQueue
from _contextlib import _GeneratorContextManager
from contextvars import Context
from _core import Abort, RaiseCancelT, RunStatistics, Task, SystemClock, TrioToken
From a349d998dec5846ee56e3655c9c3dff34e71baf4 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 25 Jul 2023 16:28:52 +0200
Subject: [PATCH 24/49] _core/_thread_cache
---
pyproject.toml | 2 +-
trio/_core/_thread_cache.py | 27 ++++++++++++++++++---------
2 files changed, 19 insertions(+), 10 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 51a1bd8178..f839b5b78e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,7 +92,7 @@ module = [
#"trio/_core/_io_kqueue", # 16, 198
#"trio/_core/_io_epoll", # 21, 323
-"trio/_core/_thread_cache", # 6, 273
+#"trio/_core/_thread_cache", # 6, 273
"trio/_core/_traps", # 7, 276
"trio/_core/_asyncgens", # 10, 194
diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py
index 157f14c5a1..823d22a10a 100644
--- a/trio/_core/_thread_cache.py
+++ b/trio/_core/_thread_cache.py
@@ -7,10 +7,13 @@
from functools import partial
from itertools import count
from threading import Lock, Thread
-from typing import Callable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple
import outcome
+if TYPE_CHECKING:
+ from outcome import Value
+
def _to_os_thread_name(name: str) -> bytes:
# ctypes handles the trailing \00
@@ -116,9 +119,10 @@ def darwin_namefunc(
class WorkerThread:
def __init__(self, thread_cache: ThreadCache):
- # deliver (the second value) can probably be Callable[[outcome.Value], None] ?
# should generate stubs for outcome
- self._job: Optional[Tuple[Callable, Callable, str]] = None
+ self._job: Optional[
+ Tuple[Callable[[None], None], Callable[[Value], None], str | None]
+ ] = None
self._thread_cache = thread_cache
# This Lock is used in an unconventional way.
#
@@ -136,7 +140,7 @@ def __init__(self, thread_cache: ThreadCache):
set_os_thread_name(self._thread.ident, self._default_name)
self._thread.start()
- def _handle_job(self):
+ def _handle_job(self) -> None:
# Handle job in a separate method to ensure user-created
# objects are cleaned up in a consistent manner.
assert self._job is not None
@@ -167,7 +171,7 @@ def _handle_job(self):
print("Exception while delivering result of thread", file=sys.stderr)
traceback.print_exception(type(e), e, e.__traceback__)
- def _work(self):
+ def _work(self) -> None:
while True:
if self._worker_lock.acquire(timeout=IDLE_TIMEOUT):
# We got a job
@@ -191,11 +195,14 @@ def _work(self):
class ThreadCache:
- def __init__(self):
- self._idle_workers = {}
+ def __init__(self) -> None:
+ self._idle_workers: dict[WorkerThread, None] = {}
def start_thread_soon(
- self, fn: Callable, deliver: Callable, name: Optional[str] = None
+ self,
+ fn: Callable[[None], Any] | partial[Any],
+ deliver: Callable[[Value], None],
+ name: Optional[str] = None,
) -> None:
try:
worker, _ = self._idle_workers.popitem()
@@ -209,7 +216,9 @@ def start_thread_soon(
def start_thread_soon(
- fn: Callable, deliver: Callable, name: Optional[str] = None
+ fn: Callable[[None], None] | partial[Any],
+ deliver: Callable[[Value], None],
+ name: Optional[str] = None,
) -> None:
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
From 7ed3a7b7d16286d92e247b67c18f99f71390c558 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 25 Jul 2023 16:39:20 +0200
Subject: [PATCH 25/49] type _traps as far as it's possible ... but tests are
failing
---
pyproject.toml | 2 +-
trio/_core/_traps.py | 33 +++++++++++++++++++++++----------
trio/_tests/verify_types.json | 10 +++-------
3 files changed, 27 insertions(+), 18 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index f839b5b78e..9e156299cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -93,7 +93,7 @@ module = [
#"trio/_core/_io_epoll", # 21, 323
#"trio/_core/_thread_cache", # 6, 273
-"trio/_core/_traps", # 7, 276
+#"trio/_core/_traps", # 7, 276
"trio/_core/_asyncgens", # 10, 194
"trio/_core/_wakeup_socketpair", # 12
diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py
index 08a8ceac01..77fc9966fe 100644
--- a/trio/_core/_traps.py
+++ b/trio/_core/_traps.py
@@ -1,14 +1,25 @@
-# These are the only functions that ever yield back to the task runner.
+from __future__ import annotations
import enum
import types
-from typing import Any, Callable, NoReturn
+from typing import TYPE_CHECKING, Any, Callable, Iterator, NoReturn, TypeVar
import attr
import outcome
from . import _run
+# These are the only functions that ever yield back to the task runner.
+
+
+if TYPE_CHECKING:
+ from outcome import Outcome
+ from typing_extensions import TypeAlias
+
+ from ._run import Task
+
+T = TypeVar("T")
+
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
@@ -18,7 +29,7 @@
# tracking machinery. Since our traps are public APIs, we make them real async
# functions, and then this helper takes care of the actual yield:
@types.coroutine
-def _async_yield(obj):
+def _async_yield(obj: T) -> Iterator[T]:
return (yield obj)
@@ -28,7 +39,7 @@ class CancelShieldedCheckpoint:
pass
-async def cancel_shielded_checkpoint():
+async def cancel_shielded_checkpoint() -> Any:
"""Introduce a schedule point, but not a cancel point.
This is *not* a :ref:`checkpoint `, but it is half of a
@@ -62,10 +73,10 @@ class Abort(enum.Enum):
# Not exported in the trio._core namespace, but imported directly by _run.
@attr.s(frozen=True)
class WaitTaskRescheduled:
- abort_func = attr.ib()
+ abort_func: Callable[[RaiseCancelT], Abort] = attr.ib()
-RaiseCancelT = Callable[[], NoReturn] # TypeAlias
+RaiseCancelT: TypeAlias = Callable[[], NoReturn]
# Should always return the type a Task "expects", unless you willfully reschedule it
@@ -175,10 +186,10 @@ def abort(inner_raise_cancel):
# Not exported in the trio._core namespace, but imported directly by _run.
@attr.s(frozen=True)
class PermanentlyDetachCoroutineObject:
- final_outcome = attr.ib()
+ final_outcome: Outcome = attr.ib()
-async def permanently_detach_coroutine_object(final_outcome):
+async def permanently_detach_coroutine_object(final_outcome: Outcome) -> Any:
"""Permanently detach the current task from the Trio scheduler.
Normally, a Trio task doesn't exit until its coroutine object exits. When
@@ -209,7 +220,9 @@ async def permanently_detach_coroutine_object(final_outcome):
return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome))
-async def temporarily_detach_coroutine_object(abort_func):
+async def temporarily_detach_coroutine_object(
+ abort_func: Callable[[RaiseCancelT], Abort]
+) -> Any:
"""Temporarily detach the current coroutine object from the Trio
scheduler.
@@ -245,7 +258,7 @@ async def temporarily_detach_coroutine_object(abort_func):
return await _async_yield(WaitTaskRescheduled(abort_func))
-async def reattach_detached_coroutine_object(task, yield_value):
+async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None:
"""Reattach a coroutine object that was detached using
:func:`temporarily_detach_coroutine_object`.
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 1824df59e1..a61c417781 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9328,
+ "completenessScore": 0.9392,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 583,
- "withUnknownType": 42
+ "withKnownType": 587,
+ "withUnknownType": 38
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 0,
@@ -101,13 +101,9 @@
"trio._subprocess.Process.wait",
"trio.from_thread.run",
"trio.from_thread.run_sync",
- "trio.lowlevel.cancel_shielded_checkpoint",
"trio.lowlevel.currently_ki_protected",
"trio.lowlevel.notify_closing",
- "trio.lowlevel.permanently_detach_coroutine_object",
- "trio.lowlevel.reattach_detached_coroutine_object",
"trio.lowlevel.start_guest_run",
- "trio.lowlevel.temporarily_detach_coroutine_object",
"trio.lowlevel.wait_readable",
"trio.lowlevel.wait_writable",
"trio.open_file",
From 1974a5cdc67028a5a9ab6e0317d20a4b9a4ade39 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 26 Jul 2023 13:17:50 +0200
Subject: [PATCH 26/49] fix after nit from a5
---
trio/_dtls.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_dtls.py b/trio/_dtls.py
index e8888d7871..9795357f94 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -645,12 +645,12 @@ def challenge_for(
return packet
-T = TypeVar("T")
+_T = TypeVar("_T")
-class _Queue(Generic[T]):
+class _Queue(Generic[_T]):
def __init__(self, incoming_packets_buffer: int | float):
- self.s, self.r = trio.open_memory_channel[T](incoming_packets_buffer)
+ self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer)
def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
From d0fae93eaf09f41a45987cfcd3846b14dde8220e Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 26 Jul 2023 13:32:36 +0200
Subject: [PATCH 27/49] fix test errors
---
trio/_core/__init__.py | 1 -
trio/_core/_run.py | 6 +++---
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index 26f6f04e7c..aa898fffe0 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -27,7 +27,6 @@
TASK_STATUS_IGNORED,
CancelScope,
Nursery,
- RunStatistics,
Task,
TaskStatus,
add_instrument,
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 79ed7ba8e3..463e6a7a1d 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -1362,7 +1362,7 @@ class RunContext(threading.local):
@attr.s(frozen=True)
-class RunStatistics:
+class _RunStatistics:
tasks_living = attr.ib()
tasks_runnable = attr.ib()
seconds_to_next_deadline = attr.ib()
@@ -1483,7 +1483,7 @@ def close(self):
self.ki_manager.close()
@_public
- def current_statistics(self) -> RunStatistics:
+ def current_statistics(self) -> _RunStatistics:
"""Returns an object containing run-loop-level debugging information.
Currently the following fields are defined:
@@ -1507,7 +1507,7 @@ def current_statistics(self) -> RunStatistics:
"""
seconds_to_next_deadline = self.deadlines.next_deadline() - self.current_time()
- return RunStatistics(
+ return _RunStatistics(
tasks_living=len(self.tasks),
tasks_runnable=len(self.runq),
seconds_to_next_deadline=seconds_to_next_deadline,
From f353058e7364d4c6554f54700089b9dd0205fbb3 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 26 Jul 2023 14:29:49 +0200
Subject: [PATCH 28/49] _asyncgens
---
pyproject.toml | 4 ++--
trio/_core/_asyncgens.py | 36 ++++++++++++++++++++++++------------
2 files changed, 26 insertions(+), 14 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 9e156299cc..543b8587d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,6 +84,7 @@ module = [
"trio/_core/_windows_cffi", # 2, 324
"trio/_core/_generated_io_windows", # 9 (win32), 84
"trio/_core/_io_windows", # 47 (win32), 867
+"trio/_wait_for_object", # 2 (windows)
#"trio/_core/_io_common", # 1, 24
#"trio/_core/_generated_io_epoll", # 3, 36
@@ -94,7 +95,7 @@ module = [
#"trio/_core/_thread_cache", # 6, 273
#"trio/_core/_traps", # 7, 276
-"trio/_core/_asyncgens", # 10, 194
+#"trio/_core/_asyncgens", # 10, 194
"trio/_core/_wakeup_socketpair", # 12
"trio/_core/_ki", # 14, 210
@@ -118,7 +119,6 @@ module = [
"trio/_highlevel_open_tcp_stream", # 5, 379 lines
"trio/_subprocess_platform/waitid", # 2, 107 lines
-"trio/_wait_for_object", # 2 (windows)
"trio/_deprecate", # 12, 140lines
"trio/_util", # 13, 348 lines
"trio/_file_io", # 13, 191 lines
diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py
index 5f02ebe76d..eacdbb4923 100644
--- a/trio/_core/_asyncgens.py
+++ b/trio/_core/_asyncgens.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import logging
import sys
import warnings
@@ -12,6 +14,16 @@
# Used to log exceptions in async generator finalizers
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")
+from typing import TYPE_CHECKING, AsyncGenerator
+
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias
+
+ from ._run import Runner
+
+# can this be typed more strictly in any way?
+AGenT: TypeAlias = AsyncGenerator[object, object]
+
@attr.s(eq=False, slots=True)
class AsyncGenerators:
@@ -22,17 +34,17 @@ class AsyncGenerators:
# asyncgens after the system nursery has been closed, it's a
# regular set so we don't have to deal with GC firing at
# unexpected times.
- alive = attr.ib(factory=weakref.WeakSet)
+ alive: weakref.WeakSet[AGenT] = attr.ib(factory=weakref.WeakSet)
# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
# init task starting end-of-run asyncgen finalization.
- trailing_needs_finalize = attr.ib(factory=set)
+ trailing_needs_finalize: set[AGenT] = attr.ib(factory=set)
prev_hooks = attr.ib(init=False)
- def install_hooks(self, runner):
- def firstiter(agen):
+ def install_hooks(self, runner: Runner) -> None:
+ def firstiter(agen: AGenT) -> None:
if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"):
self.alive.add(agen)
else:
@@ -46,7 +58,7 @@ def firstiter(agen):
if self.prev_hooks.firstiter is not None:
self.prev_hooks.firstiter(agen)
- def finalize_in_trio_context(agen, agen_name):
+ def finalize_in_trio_context(agen: AGenT, agen_name: str) -> None:
try:
runner.spawn_system_task(
self._finalize_one,
@@ -61,7 +73,7 @@ def finalize_in_trio_context(agen, agen_name):
# have hit it.
self.trailing_needs_finalize.add(agen)
- def finalizer(agen):
+ def finalizer(agen: AGenT) -> None:
agen_name = name_asyncgen(agen)
try:
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
@@ -99,7 +111,7 @@ def finalizer(agen):
try:
# If the next thing is a yield, this will raise RuntimeError
# which we allow to propagate
- closer.send(None)
+ closer.send(None) # type: ignore[attr-defined]
except StopIteration:
pass
else:
@@ -114,7 +126,7 @@ def finalizer(agen):
self.prev_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer)
- async def finalize_remaining(self, runner):
+ async def finalize_remaining(self, runner: Runner) -> None:
# This is called from init after shutting down the system nursery.
# The only tasks running at this point are init and
# the run_sync_soon task, and since the system nursery is closed,
@@ -125,7 +137,7 @@ async def finalize_remaining(self, runner):
# To make async generator finalization easier to reason
# about, we'll shut down asyncgen garbage collection by turning
# the alive WeakSet into a regular set.
- self.alive = set(self.alive)
+ self.alive = set(self.alive) # type: ignore
# Process all pending run_sync_soon callbacks, in case one of
# them was an asyncgen finalizer that snuck in under the wire.
@@ -170,14 +182,14 @@ async def finalize_remaining(self, runner):
# all are gone.
while self.alive:
batch = self.alive
- self.alive = set()
+ self.alive = set() # type: ignore
for agen in batch:
await self._finalize_one(agen, name_asyncgen(agen))
- def close(self):
+ def close(self) -> None:
sys.set_asyncgen_hooks(*self.prev_hooks)
- async def _finalize_one(self, agen, name):
+ async def _finalize_one(self, agen: AGenT, name: str) -> None:
try:
# This shield ensures that finalize_asyncgen never exits
# with an exception, not even a Cancelled. The inside
From 766efdcf9302505232dc1454372edf7409ca3ecd Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 26 Jul 2023 14:32:56 +0200
Subject: [PATCH 29/49] _wakeup_socketpair
---
pyproject.toml | 2 +-
trio/_core/_wakeup_socketpair.py | 16 +++++++++-------
2 files changed, 10 insertions(+), 8 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 543b8587d9..b866125bc8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -97,7 +97,7 @@ module = [
#"trio/_core/_traps", # 7, 276
#"trio/_core/_asyncgens", # 10, 194
-"trio/_core/_wakeup_socketpair", # 12
+#"trio/_core/_wakeup_socketpair", # 12
"trio/_core/_ki", # 14, 210
"trio/_core/_entry_queue", # 16, 195
"trio/_core/_multierror", # 19, 469
diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py
index 51a80ef024..2ad1a023fe 100644
--- a/trio/_core/_wakeup_socketpair.py
+++ b/trio/_core/_wakeup_socketpair.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import signal
import socket
import warnings
@@ -7,7 +9,7 @@
class WakeupSocketpair:
- def __init__(self):
+ def __init__(self) -> None:
self.wakeup_sock, self.write_sock = socket.socketpair()
self.wakeup_sock.setblocking(False)
self.write_sock.setblocking(False)
@@ -27,26 +29,26 @@ def __init__(self):
self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except OSError:
pass
- self.old_wakeup_fd = None
+ self.old_wakeup_fd: int | None = None
- def wakeup_thread_and_signal_safe(self):
+ def wakeup_thread_and_signal_safe(self) -> None:
try:
self.write_sock.send(b"\x00")
except BlockingIOError:
pass
- async def wait_woken(self):
+ async def wait_woken(self) -> None:
await _core.wait_readable(self.wakeup_sock)
self.drain()
- def drain(self):
+ def drain(self) -> None:
try:
while True:
self.wakeup_sock.recv(2**16)
except BlockingIOError:
pass
- def wakeup_on_signals(self):
+ def wakeup_on_signals(self) -> None:
assert self.old_wakeup_fd is None
if not is_main_thread():
return
@@ -64,7 +66,7 @@ def wakeup_on_signals(self):
)
)
- def close(self):
+ def close(self) -> None:
self.wakeup_sock.close()
self.write_sock.close()
if self.old_wakeup_fd is not None:
From 6a13de7a657839da7255735a3bfd3b0278328f1a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 26 Jul 2023 15:51:20 +0200
Subject: [PATCH 30/49] trio/_core/_ki
---
pyproject.toml | 2 +-
trio/_core/_ki.py | 33 +++++++++++++++++++--------------
trio/_tests/verify_types.json | 7 +++----
3 files changed, 23 insertions(+), 19 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index b866125bc8..6f49c6c1ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,7 +98,7 @@ module = [
#"trio/_core/_asyncgens", # 10, 194
#"trio/_core/_wakeup_socketpair", # 12
-"trio/_core/_ki", # 14, 210
+#"trio/_core/_ki", # 14, 210
"trio/_core/_entry_queue", # 16, 195
"trio/_core/_multierror", # 19, 469
diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py
index cc05ef9177..aca321bce9 100644
--- a/trio/_core/_ki.py
+++ b/trio/_core/_ki.py
@@ -11,6 +11,7 @@
from .._util import is_main_thread
if TYPE_CHECKING:
+ from types import FrameType
from typing import Any, Callable, TypeVar
F = TypeVar("F", bound=Callable[..., Any])
@@ -85,17 +86,17 @@
# NB: according to the signal.signal docs, 'frame' can be None on entry to
# this function:
-def ki_protection_enabled(frame):
+def ki_protection_enabled(frame: FrameType | None) -> bool:
while frame is not None:
if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals:
- return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]
+ return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] # type: ignore[no-any-return]
if frame.f_code.co_name == "__del__":
return True
frame = frame.f_back
return True
-def currently_ki_protected():
+def currently_ki_protected() -> bool:
r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection
enabled.
@@ -115,19 +116,19 @@ def currently_ki_protected():
# functions decorated @async_generator are given this magic property that's a
# reference to the object itself
# see python-trio/async_generator/async_generator/_impl.py
-def legacy_isasyncgenfunction(obj):
+def legacy_isasyncgenfunction(obj: object) -> bool:
return getattr(obj, "_async_gen_function", None) == id(obj)
-def _ki_protection_decorator(enabled):
- def decorator(fn):
+def _ki_protection_decorator(enabled: bool) -> Callable[[F], F]:
+ def decorator(fn): # type: ignore[no-untyped-def]
# In some version of Python, isgeneratorfunction returns true for
# coroutine functions, so we have to check for coroutine functions
# first.
if inspect.iscoroutinefunction(fn):
@wraps(fn)
- def wrapper(*args, **kwargs):
+ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
# See the comment for regular generators below
coro = fn(*args, **kwargs)
coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
@@ -137,7 +138,7 @@ def wrapper(*args, **kwargs):
elif inspect.isgeneratorfunction(fn):
@wraps(fn)
- def wrapper(*args, **kwargs):
+ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
# It's important that we inject this directly into the
# generator's locals, as opposed to setting it here and then
# doing 'yield from'. The reason is, if a generator is
@@ -154,7 +155,7 @@ def wrapper(*args, **kwargs):
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
@wraps(fn)
- def wrapper(*args, **kwargs):
+ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
# See the comment for regular generators above
agen = fn(*args, **kwargs)
agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
@@ -164,7 +165,7 @@ def wrapper(*args, **kwargs):
else:
@wraps(fn)
- def wrapper(*args, **kwargs):
+ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return fn(*args, **kwargs)
@@ -182,9 +183,13 @@ def wrapper(*args, **kwargs):
@attr.s
class KIManager:
- handler = attr.ib(default=None)
+ handler: Callable[[int, FrameType | None], None] | None = attr.ib(default=None)
- def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints):
+ def install(
+ self,
+ deliver_cb: Callable[[], None],
+ restrict_keyboard_interrupt_to_checkpoints: bool,
+ ) -> None:
assert self.handler is None
if (
not is_main_thread()
@@ -192,7 +197,7 @@ def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints):
):
return
- def handler(signum, frame):
+ def handler(signum: int, frame: FrameType | None) -> None:
assert signum == signal.SIGINT
protection_enabled = ki_protection_enabled(frame)
if protection_enabled or restrict_keyboard_interrupt_to_checkpoints:
@@ -203,7 +208,7 @@ def handler(signum, frame):
self.handler = handler
signal.signal(signal.SIGINT, handler)
- def close(self):
+ def close(self) -> None:
if self.handler is not None:
if signal.getsignal(signal.SIGINT) is self.handler:
signal.signal(signal.SIGINT, signal.default_int_handler)
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index a61c417781..b8eb3d5dcd 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9392,
+ "completenessScore": 0.9408,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 587,
- "withUnknownType": 38
+ "withKnownType": 588,
+ "withUnknownType": 37
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 0,
@@ -101,7 +101,6 @@
"trio._subprocess.Process.wait",
"trio.from_thread.run",
"trio.from_thread.run_sync",
- "trio.lowlevel.currently_ki_protected",
"trio.lowlevel.notify_closing",
"trio.lowlevel.start_guest_run",
"trio.lowlevel.wait_readable",
From 420e98b0aab3e34ed0516b102c4ae45487f87f48 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 26 Jul 2023 16:01:51 +0200
Subject: [PATCH 31/49] trio/_core/_entry_queue
---
pyproject.toml | 2 +-
trio/_core/_entry_queue.py | 41 ++++++++++++++++++++++-------------
trio/_tests/verify_types.json | 11 +++++-----
3 files changed, 32 insertions(+), 22 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 6f49c6c1ae..867d4df1a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,7 +99,7 @@ module = [
#"trio/_core/_wakeup_socketpair", # 12
#"trio/_core/_ki", # 14, 210
-"trio/_core/_entry_queue", # 16, 195
+#"trio/_core/_entry_queue", # 16, 195
"trio/_core/_multierror", # 19, 469
diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py
index 878506bb2b..553143ceb1 100644
--- a/trio/_core/_entry_queue.py
+++ b/trio/_core/_entry_queue.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import threading
from collections import deque
+from typing import Any, Callable, Iterable, Literal, NoReturn
import attr
@@ -17,11 +20,13 @@ class EntryQueue:
# atomic WRT signal delivery (signal handlers can run on either side, but
# not *during* a deque operation). dict makes similar guarantees - and
# it's even ordered!
- queue = attr.ib(factory=deque)
- idempotent_queue = attr.ib(factory=dict)
+ queue: deque[tuple[Callable[..., Any], Iterable[Any]]] = attr.ib(factory=deque)
+ idempotent_queue: dict[tuple[Callable[..., Any], Iterable[Any]], None] = attr.ib(
+ factory=dict
+ )
- wakeup = attr.ib(factory=WakeupSocketpair)
- done = attr.ib(default=False)
+ wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+ done: bool = attr.ib(default=False)
# Must be a reentrant lock, because it's acquired from signal handlers.
# RLock is signal-safe as of cpython 3.2. NB that this does mean that the
# lock is effectively *disabled* when we enter from signal context. The
@@ -30,9 +35,9 @@ class EntryQueue:
# main thread -- it just might happen at some inconvenient place. But if
# you look at the one place where the main thread holds the lock, it's
# just to make 1 assignment, so that's atomic WRT a signal anyway.
- lock = attr.ib(factory=threading.RLock)
+ lock: threading.RLock = attr.ib(factory=threading.RLock)
- async def task(self):
+ async def task(self) -> None:
assert _core.currently_ki_protected()
# RLock has two implementations: a signal-safe version in _thread, and
# and signal-UNsafe version in threading. We need the signal safe
@@ -43,7 +48,7 @@ async def task(self):
# https://bugs.python.org/issue13697#msg237140
assert self.lock.__class__.__module__ == "_thread"
- def run_cb(job):
+ def run_cb(job: tuple[Callable[..., object], Iterable[Any]]) -> Literal[True]:
# We run this with KI protection enabled; it's the callback's
# job to disable it if it wants it disabled. Exceptions are
# treated like system task exceptions (i.e., converted into
@@ -53,7 +58,7 @@ def run_cb(job):
sync_fn(*args)
except BaseException as exc:
- async def kill_everything(exc):
+ async def kill_everything(exc: BaseException) -> NoReturn:
raise exc
try:
@@ -63,14 +68,16 @@ async def kill_everything(exc):
# system nursery is already closed.
# TODO(2020-06): this is a gross hack and should
# be fixed soon when we address #1607.
- _core.current_task().parent_nursery.start_soon(kill_everything, exc)
+ parent_nursery = _core.current_task().parent_nursery
+ assert parent_nursery is not None
+ parent_nursery.start_soon(kill_everything, exc)
return True
# This has to be carefully written to be safe in the face of new items
# being queued while we iterate, and to do a bounded amount of work on
# each pass:
- def run_all_bounded():
+ def run_all_bounded() -> None:
for _ in range(len(self.queue)):
run_cb(self.queue.popleft())
for job in list(self.idempotent_queue):
@@ -104,13 +111,15 @@ def run_all_bounded():
assert not self.queue
assert not self.idempotent_queue
- def close(self):
+ def close(self) -> None:
self.wakeup.close()
- def size(self):
+ def size(self) -> int:
return len(self.queue) + len(self.idempotent_queue)
- def run_sync_soon(self, sync_fn, *args, idempotent=False):
+ def run_sync_soon(
+ self, sync_fn: Callable[..., object], *args: object, idempotent: bool = False
+ ) -> None:
with self.lock:
if self.done:
raise _core.RunFinishedError("run() has exited")
@@ -146,9 +155,11 @@ class TrioToken(metaclass=NoPublicConstructor):
"""
- _reentry_queue = attr.ib()
+ _reentry_queue: EntryQueue = attr.ib()
- def run_sync_soon(self, sync_fn, *args, idempotent=False):
+ def run_sync_soon(
+ self, sync_fn: Callable[..., object], *args: object, idempotent: bool = False
+ ) -> None:
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
Trio task.
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index b8eb3d5dcd..e6e074802c 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9408,
+ "completenessScore": 0.9424,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 588,
- "withUnknownType": 37
+ "withKnownType": 589,
+ "withUnknownType": 36
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 0,
@@ -46,12 +46,11 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 552,
- "withUnknownType": 67
+ "withKnownType": 554,
+ "withUnknownType": 65
},
"packageName": "trio",
"symbols": [
- "trio._core._entry_queue.TrioToken.run_sync_soon",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
"trio._highlevel_socket.SocketStream.getsockopt",
From 2370c3fb0a45ca58d2fcd574e7ce5aa9369aae02 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 27 Jul 2023 15:08:59 +0200
Subject: [PATCH 32/49] type trio/_core/_run
---
pyproject.toml | 10 +-
trio/_core/_local.py | 9 +-
trio/_core/_run.py | 245 +++++++++++++++++++++-------------
trio/_core/_thread_cache.py | 8 +-
trio/_tests/verify_types.json | 14 +-
5 files changed, 172 insertions(+), 114 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 867d4df1a2..975d1f5fda 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,11 +53,11 @@ module = [
disallow_untyped_defs = true
disallow_any_generics = true
-[[tool.mypy.overrides]]
-module = "trio._core._run"
-disallow_incomplete_defs = false
-disallow_untyped_defs = false
-disallow_any_generics = false
+#[[tool.mypy.overrides]]
+#module = "trio._core._run"
+#disallow_incomplete_defs = false
+#disallow_untyped_defs = false
+#disallow_any_generics = false
# TODO: gen_exports add platform checks to specific files
[[tool.mypy.overrides]]
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index b9dada64fe..1965d44eb1 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -45,8 +45,7 @@ class RunVar(Generic[T], metaclass=Final):
def get(self, default: T | type[_NoValue] = _NoValue) -> T:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
- # not typed yet
- return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index]
+ return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[no-any-return]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
@@ -73,7 +72,7 @@ def set(self, value: T) -> RunVarToken[T]:
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index]
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
return token
def reset(self, token: RunVarToken[T]) -> None:
@@ -93,9 +92,9 @@ def reset(self, token: RunVarToken[T]) -> None:
previous = token.previous_value
try:
if previous is _NoValue:
- _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type]
+ _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
else:
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment]
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context")
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 463e6a7a1d..8c7b89ad2a 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -17,7 +17,16 @@
from math import inf
from time import perf_counter
from types import TracebackType
-from typing import TYPE_CHECKING, Any, Awaitable, Iterable, NoReturn, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Generator,
+ Iterable,
+ NoReturn,
+ Sequence,
+ TypeVar,
+)
import attr
from outcome import Error, Outcome, Value, capture
@@ -54,15 +63,27 @@
# An unfortunate name collision here with trio._util.Final
from typing import Final as FinalT
+ from typing_extensions import Self, TypeAlias
+
from .._abc import Clock
+ from ._local import RunVar
from ._mock_clock import MockClock
+ if sys.platform == "win32":
+ from ._io_windows import _WindowsStatistics
+ elif sys.platform == "darwin":
+ from ._io_kqueue import _KqueueStatistics
+ elif sys.platform == "linux":
+ from ._io_epoll import _EpollStatistics
+
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
_NO_SEND: FinalT = object()
FnT = TypeVar("FnT", bound="Callable[..., Any]")
+T = TypeVar("T")
+
# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
@@ -154,7 +175,9 @@ class IdlePrimedTypes(enum.Enum):
################################################################
-def collapse_exception_group(excgroup):
+def collapse_exception_group(
+ excgroup: BaseExceptionGroup[BaseException],
+) -> BaseExceptionGroup[BaseException] | BaseException:
"""Recursively collapse any single-exception groups into that single contained
exception.
@@ -174,7 +197,7 @@ def collapse_exception_group(excgroup):
)
return exceptions[0]
elif modified:
- return excgroup.derive(exceptions)
+ return excgroup.derive(exceptions) # type: ignore[no-any-return]
else:
return excgroup
@@ -189,18 +212,18 @@ class Deadlines:
"""
# Heap of (deadline, id(CancelScope), CancelScope)
- _heap = attr.ib(factory=list)
+ _heap: list[tuple[float, int, CancelScope]] = attr.ib(factory=list)
# Count of active deadlines (those that haven't been changed)
- _active = attr.ib(default=0)
+ _active: int = attr.ib(default=0)
- def add(self, deadline, cancel_scope):
+ def add(self, deadline: float, cancel_scope: CancelScope) -> None:
heappush(self._heap, (deadline, id(cancel_scope), cancel_scope))
self._active += 1
- def remove(self, deadline, cancel_scope):
+ def remove(self, deadline: float, cancel_scope: CancelScope) -> None:
self._active -= 1
- def next_deadline(self):
+ def next_deadline(self) -> float:
while self._heap:
deadline, _, cancel_scope = self._heap[0]
if deadline == cancel_scope._registered_deadline:
@@ -210,7 +233,7 @@ def next_deadline(self):
heappop(self._heap)
return inf
- def _prune(self):
+ def _prune(self) -> None:
# In principle, it's possible for a cancel scope to toggle back and
# forth repeatedly between the same two deadlines, and end up with
# lots of stale entries that *look* like they're still active, because
@@ -231,7 +254,7 @@ def _prune(self):
heapify(pruned_heap)
self._heap = pruned_heap
- def expire(self, now):
+ def expire(self, now: float) -> bool:
did_something = False
while self._heap and self._heap[0][0] <= now:
deadline, _, cancel_scope = heappop(self._heap)
@@ -382,14 +405,14 @@ def close(self) -> None:
child.recalculate()
@property
- def parent_cancellation_is_visible_to_us(self):
+ def parent_cancellation_is_visible_to_us(self) -> bool:
return (
self._parent is not None
and not self._scope.shield
and self._parent.effectively_cancelled
)
- def recalculate(self):
+ def recalculate(self) -> None:
# This does a depth-first traversal over this and descendent cancel
# statuses, to ensure their state is up-to-date. It's basically a
# recursive algorithm, but we use an explicit stack to avoid any
@@ -408,7 +431,7 @@ def recalculate(self):
task._attempt_delivery_of_any_pending_cancel()
todo.extend(current._children)
- def _mark_abandoned(self):
+ def _mark_abandoned(self) -> None:
self.abandoned_by_misnesting = True
for child in self._children:
child._mark_abandoned()
@@ -496,7 +519,7 @@ class CancelScope(metaclass=Final):
_shield: bool = attr.ib(default=False, kw_only=True)
@enable_ki_protection
- def __enter__(self):
+ def __enter__(self) -> Self:
task = _core.current_task()
if self._has_been_entered:
raise RuntimeError(
@@ -510,7 +533,7 @@ def __enter__(self):
task._activate_cancel_status(self._cancel_status)
return self
- def _close(self, exc):
+ def _close(self, exc: BaseException | None) -> BaseException | None:
if self._cancel_status is None:
new_exc = RuntimeError(
"Cancel scope stack corrupted: attempted to exit {!r} "
@@ -548,6 +571,7 @@ def _close(self, exc):
# CancelStatus.close() will take care of the plumbing;
# we just need to make sure we don't let the error
# pass silently.
+ assert scope_task._cancel_status is not None
new_exc = RuntimeError(
"Cancel scope stack corrupted: attempted to exit {!r} "
"in {!r} that's still within its child {!r}\n{}".format(
@@ -789,10 +813,10 @@ def cancel_called(self) -> bool:
# sense.
@attr.s(eq=False, hash=False, repr=False)
class TaskStatus(metaclass=Final):
- _old_nursery = attr.ib()
- _new_nursery = attr.ib()
- _called_started = attr.ib(default=False)
- _value = attr.ib(default=None)
+ _old_nursery: Nursery = attr.ib()
+ _new_nursery: Nursery = attr.ib()
+ _called_started: bool = attr.ib(default=False)
+ _value: Any = attr.ib(default=None)
def __repr__(self) -> str:
        return f"<TaskStatus at {id(self):#x}>"
@@ -807,6 +831,7 @@ def started(self, value: Any = None) -> None:
# will eventually exit on its own, and we don't want to risk moving
# children that might have propagating Cancelled exceptions into
# a place with no cancelled cancel scopes to catch them.
+ assert self._old_nursery._cancel_status is not None
if self._old_nursery._cancel_status.effectively_cancelled:
return
@@ -830,6 +855,7 @@ def started(self, value: Any = None) -> None:
# do something evil like cancel the old nursery. We thus break
# everything off from the old nursery before we start attaching
# anything to the new.
+ assert self._old_nursery._cancel_status is not None
cancel_status_children = self._old_nursery._cancel_status.children
cancel_status_tasks = set(self._old_nursery._cancel_status.tasks)
cancel_status_tasks.discard(self._old_nursery._parent_task)
@@ -1002,20 +1028,22 @@ def _add_exc(self, exc: BaseException) -> None:
self._pending_excs.append(exc)
self.cancel_scope.cancel()
- def _check_nursery_closed(self):
+ def _check_nursery_closed(self) -> None:
if not any([self._nested_child_running, self._children, self._pending_starts]):
self._closed = True
if self._parent_waiting_in_aexit:
self._parent_waiting_in_aexit = False
GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task)
- def _child_finished(self, task, outcome):
+ def _child_finished(self, task: Task, outcome: Outcome) -> None:
self._children.remove(task)
if isinstance(outcome, Error):
self._add_exc(outcome.error)
self._check_nursery_closed()
- async def _nested_child_finished(self, nested_child_exc):
+ async def _nested_child_finished(
+ self, nested_child_exc: BaseException | None
+ ) -> MultiError | None:
# Returns MultiError instance (or any exception if the nursery is in loose mode
# and there is just one contained exception) if there are pending exceptions
if nested_child_exc is not None:
@@ -1053,8 +1081,14 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
# avoid a garbage cycle
# (see test_nursery_cancel_doesnt_create_cyclic_garbage)
del self._pending_excs
+ return None
- def start_soon(self, async_fn, *args, name=None):
+ def start_soon(
+ self,
+ async_fn: Callable[..., Awaitable[object]],
+ *args: object,
+ name: str | None = None,
+ ) -> None:
"""Creates a child task, scheduling ``await async_fn(*args)``.
If you want to run a function and immediately wait for its result,
@@ -1096,7 +1130,12 @@ def start_soon(self, async_fn, *args, name=None):
"""
GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name)
- async def start(self, async_fn, *args, name=None):
+ async def start(
+ self,
+ async_fn: Callable[..., Awaitable[object]],
+ *args: object,
+ name: str | None = None,
+ ) -> Value:
r"""Creates and initializes a child task.
Like :meth:`start_soon`, but blocks until the new task has
@@ -1295,9 +1334,9 @@ def print_stack_for_task(task):
# The CancelStatus object that is currently active for this task.
# Don't change this directly; instead, use _activate_cancel_status().
- _cancel_status: CancelStatus = attr.ib(default=None, repr=False)
+ _cancel_status: CancelStatus | None = attr.ib(default=None, repr=False)
- def _activate_cancel_status(self, cancel_status: CancelStatus) -> None:
+ def _activate_cancel_status(self, cancel_status: CancelStatus | None) -> None:
if self._cancel_status is not None:
self._cancel_status._tasks.remove(self)
self._cancel_status = cancel_status
@@ -1328,10 +1367,11 @@ def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None:
def _attempt_delivery_of_any_pending_cancel(self) -> None:
if self._abort_func is None:
return
+ assert self._cancel_status is not None
if not self._cancel_status.effectively_cancelled:
return
- def raise_cancel():
+ def raise_cancel() -> NoReturn:
raise Cancelled._create()
self._attempt_abort(raise_cancel)
@@ -1360,14 +1400,24 @@ class RunContext(threading.local):
GLOBAL_RUN_CONTEXT: FinalT = RunContext()
+if TYPE_CHECKING:
+ if sys.platform == "win32":
+ IO_STATISTICS_TYPE: TypeAlias = _WindowsStatistics
+ elif sys.platform == "darwin":
+ IO_STATISTICS_TYPE: TypeAlias = _KqueueStatistics
+ elif sys.platform == "linux":
+ IO_STATISTICS_TYPE: TypeAlias = _EpollStatistics
+else:
+ IO_STATISTICS_TYPE = None
+
@attr.s(frozen=True)
class _RunStatistics:
- tasks_living = attr.ib()
- tasks_runnable = attr.ib()
- seconds_to_next_deadline = attr.ib()
- io_statistics = attr.ib()
- run_sync_soon_queue_size = attr.ib()
+ tasks_living: int = attr.ib()
+ tasks_runnable: int = attr.ib()
+ seconds_to_next_deadline: float = attr.ib()
+ io_statistics: IO_STATISTICS_TYPE = attr.ib()
+ run_sync_soon_queue_size: int = attr.ib()
# This holds all the state that gets trampolined back and forth between
@@ -1392,14 +1442,14 @@ class _RunStatistics:
@attr.s(eq=False, hash=False, slots=True)
class GuestState:
runner: Runner = attr.ib()
- run_sync_soon_threadsafe: Callable = attr.ib()
- run_sync_soon_not_threadsafe: Callable = attr.ib()
- done_callback: Callable = attr.ib()
- unrolled_run_gen = attr.ib()
+ run_sync_soon_threadsafe: Callable[[Callable[[], None]], None] = attr.ib()
+ run_sync_soon_not_threadsafe: Callable[[Callable[[], None]], None] = attr.ib()
+ done_callback: Callable[[Outcome], None] = attr.ib()
+ unrolled_run_gen: Generator[float, list[tuple[int, int]], None] = attr.ib()
_value_factory: Callable[[], Value] = lambda: Value(None)
unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome)
- def guest_tick(self):
+ def guest_tick(self) -> None:
try:
timeout = self.unrolled_run_next_send.send(self.unrolled_run_gen)
except StopIteration:
@@ -1420,11 +1470,11 @@ def guest_tick(self):
# Need to go into the thread and call get_events() there
self.runner.guest_tick_scheduled = False
- def get_events():
+ def get_events() -> list[tuple[int, int]]:
return self.runner.io_manager.get_events(timeout)
- def deliver(events_outcome):
- def in_main_thread():
+ def deliver(events_outcome: Outcome) -> None:
+ def in_main_thread() -> None:
self.unrolled_run_next_send = events_outcome
self.runner.guest_tick_scheduled = True
self.guest_tick()
@@ -1439,41 +1489,41 @@ class Runner:
clock: SystemClock | Clock | MockClock = attr.ib()
instruments: Instruments = attr.ib()
io_manager: TheIOManager = attr.ib()
- ki_manager = attr.ib()
- strict_exception_groups = attr.ib()
+ ki_manager: KIManager = attr.ib()
+ strict_exception_groups: bool = attr.ib()
# Run-local values, see _local.py
- _locals = attr.ib(factory=dict)
+ _locals: dict[RunVar[Any], Any] = attr.ib(factory=dict)
runq: deque[Task] = attr.ib(factory=deque)
tasks: set[Task] = attr.ib(factory=set)
- deadlines = attr.ib(factory=Deadlines)
+ deadlines: Deadlines = attr.ib(factory=Deadlines)
init_task: Task | None = attr.ib(default=None)
- system_nursery = attr.ib(default=None)
- system_context = attr.ib(default=None)
- main_task = attr.ib(default=None)
- main_task_outcome = attr.ib(default=None)
+ system_nursery: Nursery | None = attr.ib(default=None)
+ system_context: Context | None = attr.ib(default=None)
+ main_task: Task | None = attr.ib(default=None)
+ main_task_outcome: Outcome | None = attr.ib(default=None)
- entry_queue = attr.ib(factory=EntryQueue)
+ entry_queue: EntryQueue = attr.ib(factory=EntryQueue)
trio_token: TrioToken | None = attr.ib(default=None)
- asyncgens = attr.ib(factory=AsyncGenerators)
+ asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators)
# If everything goes idle for this long, we call clock._autojump()
- clock_autojump_threshold = attr.ib(default=inf)
+ clock_autojump_threshold: float = attr.ib(default=inf)
# Guest mode stuff
- is_guest = attr.ib(default=False)
- guest_tick_scheduled = attr.ib(default=False)
+ is_guest: bool = attr.ib(default=False)
+ guest_tick_scheduled: bool = attr.ib(default=False)
- def force_guest_tick_asap(self):
+ def force_guest_tick_asap(self) -> None:
if self.guest_tick_scheduled:
return
self.guest_tick_scheduled = True
self.io_manager.force_wakeup()
- def close(self):
+ def close(self) -> None:
self.io_manager.close()
self.entry_queue.close()
self.asyncgens.close()
@@ -1587,7 +1637,7 @@ def spawn_impl(
async_fn: Callable[..., Awaitable[object]],
args: Iterable[Any],
nursery: Nursery | None,
- name: str | functools.partial | Callable[..., Awaitable[object]] | None,
+ name: str | functools.partial[Any] | Callable[..., Awaitable[object]] | None,
*,
system_task: bool = False,
context: Context | None = None,
@@ -1609,6 +1659,7 @@ def spawn_impl(
######
if context is None:
if system_task:
+ assert self.system_context is not None
context = self.system_context.copy()
else:
context = copy_context()
@@ -1635,7 +1686,7 @@ def spawn_impl(
if not hasattr(coro, "cr_frame"):
# This async function is implemented in C or Cython
- async def python_wrapper(orig_coro):
+ async def python_wrapper(orig_coro: Awaitable[T]) -> T:
return await orig_coro
coro = python_wrapper(coro)
@@ -1660,7 +1711,7 @@ async def python_wrapper(orig_coro):
self.reschedule(task, None)
return task
- def task_exited(self, task, outcome):
+ def task_exited(self, task: Task, outcome: Outcome) -> None:
if (
task._cancel_status is not None
and task._cancel_status.abandoned_by_misnesting
@@ -1699,6 +1750,7 @@ def task_exited(self, task, outcome):
if task is self.main_task:
self.main_task_outcome = outcome
outcome = Value(None)
+ assert task._parent_nursery is not None
task._parent_nursery._child_finished(task, outcome)
if "task_exited" in self.instruments:
@@ -1776,7 +1828,9 @@ def spawn_system_task(
context=context,
)
- async def init(self, async_fn, args):
+ async def init(
+ self, async_fn: Callable[..., Awaitable[object]], args: Iterable[object]
+ ) -> None:
# run_sync_soon task runs here:
async with open_nursery() as run_sync_soon_nursery:
# All other system tasks run here:
@@ -1827,7 +1881,7 @@ def current_trio_token(self) -> TrioToken:
# KI handling
################
- ki_pending = attr.ib(default=False)
+ ki_pending: bool = attr.ib(default=False)
# deliver_ki is broke. Maybe move all the actual logic and state into
# RunToken, and we'll only have one instance per runner? But then we can't
@@ -1836,14 +1890,14 @@ def current_trio_token(self) -> TrioToken:
# keep the class public so people can isinstance() it if they want.
# This gets called from signal context
- def deliver_ki(self):
+ def deliver_ki(self) -> None:
self.ki_pending = True
try:
self.entry_queue.run_sync_soon(self._deliver_ki_cb)
except RunFinishedError:
pass
- def _deliver_ki_cb(self):
+ def _deliver_ki_cb(self) -> None:
if not self.ki_pending:
return
# Can't happen because main_task and run_sync_soon_task are created at
@@ -1860,7 +1914,7 @@ def _deliver_ki_cb(self):
# Quiescing
################
- waiting_for_idle = attr.ib(factory=SortedDict)
+ waiting_for_idle: SortedDict = attr.ib(factory=SortedDict)
@_public
async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None:
@@ -2000,11 +2054,11 @@ def abort(_: RaiseCancelT) -> Abort:
def setup_runner(
- clock,
- instruments,
- restrict_keyboard_interrupt_to_checkpoints,
- strict_exception_groups,
-):
+ clock: Clock | None,
+ instruments: Sequence[Instrument],
+ restrict_keyboard_interrupt_to_checkpoints: bool,
+ strict_exception_groups: bool,
+) -> Runner:
"""Create a Runner object and install it as the GLOBAL_RUN_CONTEXT."""
# It wouldn't be *hard* to support nested calls to run(), but I can't
# think of a single good reason for it, so let's be conservative for
@@ -2013,15 +2067,17 @@ def setup_runner(
raise RuntimeError("Attempted to call run() from inside a run()")
if clock is None:
- clock = SystemClock()
- instruments = Instruments(instruments)
+ _clock: Clock | SystemClock = SystemClock()
+ else:
+ _clock = clock
+ _instruments = Instruments(instruments)
io_manager = TheIOManager()
system_context = copy_context()
ki_manager = KIManager()
runner = Runner(
- clock=clock,
- instruments=instruments,
+ clock=_clock,
+ instruments=_instruments,
io_manager=io_manager,
system_context=system_context,
ki_manager=ki_manager,
@@ -2038,13 +2094,13 @@ def setup_runner(
def run(
- async_fn,
- *args,
- clock=None,
- instruments=(),
+ async_fn: Callable[..., Awaitable[T]],
+ *args: object,
+ clock: Clock | None = None,
+ instruments: Sequence[Instrument] = (),
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = False,
-):
+) -> T:
"""Run a Trio-flavored async function, and return the result.
Calling::
@@ -2130,30 +2186,32 @@ def run(
next_send = None
while True:
try:
- timeout = gen.send(next_send)
+ # sending next_send==None here ... should not work??
+ timeout = gen.send(next_send) # type: ignore[arg-type]
except StopIteration:
break
next_send = runner.io_manager.get_events(timeout)
# Inlined copy of runner.main_task_outcome.unwrap() to avoid
# cluttering every single Trio traceback with an extra frame.
if isinstance(runner.main_task_outcome, Value):
- return runner.main_task_outcome.value
+ return runner.main_task_outcome.value # type: ignore[no-any-return]
else:
+ assert runner.main_task_outcome is not None
raise runner.main_task_outcome.error
def start_guest_run(
- async_fn,
- *args,
- run_sync_soon_threadsafe,
- done_callback,
- run_sync_soon_not_threadsafe=None,
+ async_fn: Callable[..., Awaitable[object]],
+ *args: object,
+ run_sync_soon_threadsafe: Callable[[Callable[..., None]], None],
+ done_callback: Callable[[Outcome], None],
+ run_sync_soon_not_threadsafe: Callable[[Callable[..., None]], None] | None = None,
host_uses_signal_set_wakeup_fd: bool = False,
- clock=None,
- instruments=(),
+ clock: Clock | None = None,
+ instruments: tuple[Instrument, ...] = (),
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = False,
-):
+) -> None:
"""Start a "guest" run of Trio on top of some other "host" event loop.
Each host loop can only have one guest run at a time.
@@ -2241,10 +2299,10 @@ def my_done_callback(run_outcome):
# straight through.
def unrolled_run(
runner: Runner,
- async_fn,
- args,
+ async_fn: Callable[..., object],
+ args: Iterable[object],
host_uses_signal_set_wakeup_fd: bool = False,
-):
+) -> Generator[float, list[tuple[int, int]], None]:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
__tracebackhide__ = True
@@ -2530,7 +2588,10 @@ def current_effective_deadline() -> float:
float: the effective deadline, as an absolute time.
"""
- return current_task()._cancel_status.effective_deadline()
+ curr_cancel_status = current_task()._cancel_status
+ assert curr_cancel_status is not None
+ return curr_cancel_status.effective_deadline()
+ # return current_task()._cancel_status.effective_deadline()
async def checkpoint() -> None:
@@ -2553,6 +2614,7 @@ async def checkpoint() -> None:
await cancel_shielded_checkpoint()
task = current_task()
task._cancel_points += 1
+ assert task._cancel_status is not None
if task._cancel_status.effectively_cancelled or (
task is task._runner.main_task and task._runner.ki_pending
):
@@ -2576,6 +2638,7 @@ async def checkpoint_if_cancelled() -> None:
"""
task = current_task()
+ assert task._cancel_status is not None
if task._cancel_status.effectively_cancelled or (
task is task._runner.main_task and task._runner.ki_pending
):
diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py
index 823d22a10a..d66bf7c05d 100644
--- a/trio/_core/_thread_cache.py
+++ b/trio/_core/_thread_cache.py
@@ -7,7 +7,7 @@
from functools import partial
from itertools import count
from threading import Lock, Thread
-from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Optional, Tuple
import outcome
@@ -121,7 +121,7 @@ class WorkerThread:
def __init__(self, thread_cache: ThreadCache):
# should generate stubs for outcome
self._job: Optional[
- Tuple[Callable[[None], None], Callable[[Value], None], str | None]
+ Tuple[Callable[[], object], Callable[[Value], None], str | None]
] = None
self._thread_cache = thread_cache
# This Lock is used in an unconventional way.
@@ -200,7 +200,7 @@ def __init__(self) -> None:
def start_thread_soon(
self,
- fn: Callable[[None], Any] | partial[Any],
+ fn: Callable[[], object] | partial[object],
deliver: Callable[[Value], None],
name: Optional[str] = None,
) -> None:
@@ -216,7 +216,7 @@ def start_thread_soon(
def start_thread_soon(
- fn: Callable[[None], None] | partial[Any],
+ fn: Callable[[], object] | partial[object],
deliver: Callable[[Value], None],
name: Optional[str] = None,
) -> None:
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index e6e074802c..8f75bb13b2 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9424,
+ "completenessScore": 0.9472,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 589,
- "withUnknownType": 36
+ "withKnownType": 592,
+ "withUnknownType": 33
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 0,
@@ -46,13 +46,11 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 554,
- "withUnknownType": 65
+ "withKnownType": 557,
+ "withUnknownType": 62
},
"packageName": "trio",
"symbols": [
- "trio._core._run.Nursery.start",
- "trio._core._run.Nursery.start_soon",
"trio._highlevel_socket.SocketStream.getsockopt",
"trio._highlevel_socket.SocketStream.send_all",
"trio._highlevel_socket.SocketStream.setsockopt",
@@ -101,7 +99,6 @@
"trio.from_thread.run",
"trio.from_thread.run_sync",
"trio.lowlevel.notify_closing",
- "trio.lowlevel.start_guest_run",
"trio.lowlevel.wait_readable",
"trio.lowlevel.wait_writable",
"trio.open_file",
@@ -110,7 +107,6 @@
"trio.open_tcp_listeners",
"trio.open_tcp_stream",
"trio.open_unix_socket",
- "trio.run",
"trio.run_process",
"trio.serve_listeners",
"trio.serve_ssl_over_tcp",
From ae32ea903249f20ff473edb1b810fddcd3a7f919 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 14:25:02 +0200
Subject: [PATCH 33/49] typecheck trio/_core/_unbounded_queue
---
pyproject.toml | 3 +-
trio/_core/__init__.py | 2 +-
trio/_core/_unbounded_queue.py | 62 ++++++++++++++++++++--------------
trio/_tests/verify_types.json | 18 +++-------
4 files changed, 44 insertions(+), 41 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index d479442c7a..1a95434ec5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,8 @@ disallow_untyped_defs = true
[[tool.mypy.overrides]]
module = [
"trio._dtls",
- "trio._abc"
+ "trio._abc",
+ "trio._core._unbounded_queue",
]
disallow_incomplete_defs = true
disallow_untyped_defs = true
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index aa898fffe0..0325572376 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -62,7 +62,7 @@
temporarily_detach_coroutine_object,
wait_task_rescheduled,
)
-from ._unbounded_queue import UnboundedQueue
+from ._unbounded_queue import UnboundedQueue, UnboundedQueueStats
# Windows imports
if sys.platform == "win32":
diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py
index 9c747749b4..27d36c965d 100644
--- a/trio/_core/_unbounded_queue.py
+++ b/trio/_core/_unbounded_queue.py
@@ -1,17 +1,34 @@
+from __future__ import annotations
+
+from typing import Generic, TypeVar
+
import attr
+from typing_extensions import Self
from .. import _core
from .._deprecate import deprecated
from .._util import Final
+T = TypeVar("T")
+
@attr.s(frozen=True)
-class _UnboundedQueueStats:
- qsize = attr.ib()
- tasks_waiting = attr.ib()
+class UnboundedQueueStats:
+ """An object containing debugging information.
+
+ Currently the following fields are defined:
+
+ * ``qsize``: The number of items currently in the queue.
+ * ``tasks_waiting``: The number of tasks blocked on this queue's
+ :meth:`get_batch` method.
+
+ """
+
+ qsize: int = attr.ib()
+ tasks_waiting: int = attr.ib()
-class UnboundedQueue(metaclass=Final):
+class UnboundedQueue(Generic[T], metaclass=Final):
"""An unbounded queue suitable for certain unusual forms of inter-task
communication.
@@ -41,26 +58,27 @@ class UnboundedQueue(metaclass=Final):
"""
+ # deprecated is not typed
@deprecated(
"0.9.0",
issue=497,
thing="trio.lowlevel.UnboundedQueue",
instead="trio.open_memory_channel(math.inf)",
)
- def __init__(self):
+ def __init__(self) -> None: # type: ignore[misc]
self._lot = _core.ParkingLot()
- self._data = []
+ self._data: list[T] = []
# used to allow handoff from put to the first task in the lot
self._can_get = False
- def __repr__(self):
+ def __repr__(self) -> str:
        return f"<UnboundedQueue holding {len(self._data)} items>"
- def qsize(self):
+ def qsize(self) -> int:
"""Returns the number of items currently in the queue."""
return len(self._data)
- def empty(self):
+ def empty(self) -> bool:
"""Returns True if the queue is empty, False otherwise.
There is some subtlety to interpreting this method's return value: see
@@ -70,7 +88,7 @@ def empty(self):
return not self._data
@_core.enable_ki_protection
- def put_nowait(self, obj):
+ def put_nowait(self, obj: T) -> None:
"""Put an object into the queue, without blocking.
This always succeeds, because the queue is unbounded. We don't provide
@@ -88,13 +106,13 @@ def put_nowait(self, obj):
self._can_get = True
self._data.append(obj)
- def _get_batch_protected(self):
+ def _get_batch_protected(self) -> list[T]:
data = self._data.copy()
self._data.clear()
self._can_get = False
return data
- def get_batch_nowait(self):
+ def get_batch_nowait(self) -> list[T]:
"""Attempt to get the next batch from the queue, without blocking.
Returns:
@@ -110,7 +128,7 @@ def get_batch_nowait(self):
raise _core.WouldBlock
return self._get_batch_protected()
- async def get_batch(self):
+ async def get_batch(self) -> list[T]:
"""Get the next batch from the queue, blocking as necessary.
Returns:
@@ -128,22 +146,14 @@ async def get_batch(self):
finally:
await _core.cancel_shielded_checkpoint()
- def statistics(self):
- """Return an object containing debugging information.
-
- Currently the following fields are defined:
-
- * ``qsize``: The number of items currently in the queue.
- * ``tasks_waiting``: The number of tasks blocked on this queue's
- :meth:`get_batch` method.
-
- """
- return _UnboundedQueueStats(
+ def statistics(self) -> UnboundedQueueStats:
+ """Return an UnboundedQueueStats object containing debugging information."""
+ return UnboundedQueueStats(
qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting
)
- def __aiter__(self):
+ def __aiter__(self) -> Self:
return self
- async def __anext__(self):
+ async def __anext__(self) -> list[T]:
return await self.get_batch()
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index d08c03060c..f1276037a2 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.888,
+ "completenessScore": 0.8896,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 555,
- "withUnknownType": 69
+ "withKnownType": 556,
+ "withUnknownType": 68
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -46,8 +46,8 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 3,
- "withKnownType": 529,
- "withUnknownType": 102
+ "withKnownType": 551,
+ "withUnknownType": 93
},
"packageName": "trio",
"symbols": [
@@ -63,14 +63,6 @@
"trio._core._run.Nursery.start_soon",
"trio._core._run.TaskStatus.__repr__",
"trio._core._run.TaskStatus.started",
- "trio._core._unbounded_queue.UnboundedQueue.__aiter__",
- "trio._core._unbounded_queue.UnboundedQueue.__anext__",
- "trio._core._unbounded_queue.UnboundedQueue.__repr__",
- "trio._core._unbounded_queue.UnboundedQueue.empty",
- "trio._core._unbounded_queue.UnboundedQueue.get_batch",
- "trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait",
- "trio._core._unbounded_queue.UnboundedQueue.qsize",
- "trio._core._unbounded_queue.UnboundedQueue.statistics",
"trio._dtls.DTLSChannel.__init__",
"trio._dtls.DTLSEndpoint.__init__",
"trio._dtls.DTLSEndpoint.serve",
From ea1998894ab26fa4658515b2a2ba414bd7031d41 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 14:40:06 +0200
Subject: [PATCH 34/49] fix CI
---
trio/_core/__init__.py | 2 +-
trio/_core/_unbounded_queue.py | 14 ++++++++------
2 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py
index 0325572376..8e42d2743b 100644
--- a/trio/_core/__init__.py
+++ b/trio/_core/__init__.py
@@ -62,7 +62,7 @@
temporarily_detach_coroutine_object,
wait_task_rescheduled,
)
-from ._unbounded_queue import UnboundedQueue, UnboundedQueueStats
+from ._unbounded_queue import UnboundedQueue, UnboundedQueueStatistics
# Windows imports
if sys.platform == "win32":
diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py
index 27d36c965d..7659845a0f 100644
--- a/trio/_core/_unbounded_queue.py
+++ b/trio/_core/_unbounded_queue.py
@@ -1,9 +1,8 @@
from __future__ import annotations
-from typing import Generic, TypeVar
+from typing import TYPE_CHECKING, Generic, TypeVar
import attr
-from typing_extensions import Self
from .. import _core
from .._deprecate import deprecated
@@ -11,9 +10,12 @@
T = TypeVar("T")
+if TYPE_CHECKING:
+ from typing_extensions import Self
+
@attr.s(frozen=True)
-class UnboundedQueueStats:
+class UnboundedQueueStatistics:
"""An object containing debugging information.
Currently the following fields are defined:
@@ -146,9 +148,9 @@ async def get_batch(self) -> list[T]:
finally:
await _core.cancel_shielded_checkpoint()
- def statistics(self) -> UnboundedQueueStats:
- """Return an UnboundedQueueStats object containing debugging information."""
- return UnboundedQueueStats(
+ def statistics(self) -> UnboundedQueueStatistics:
+ """Return an UnboundedQueueStatistics object containing debugging information."""
+ return UnboundedQueueStatistics(
qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting
)
From b4621e550d81e37c1f1509273d7537cb537bda2b Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 14:46:11 +0200
Subject: [PATCH 35/49] fix test
---
trio/_tests/verify_types.json | 4 ++--
trio/lowlevel.py | 1 +
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index f1276037a2..af3283c2ef 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,10 +7,10 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8896,
+ "completenessScore": 0.889776357827476,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 556,
+ "withKnownType": 557,
"withUnknownType": 68
},
"ignoreUnknownTypesFromImports": true,
diff --git a/trio/lowlevel.py b/trio/lowlevel.py
index 54f4ef3141..36d23d5955 100644
--- a/trio/lowlevel.py
+++ b/trio/lowlevel.py
@@ -17,6 +17,7 @@
Task as Task,
TrioToken as TrioToken,
UnboundedQueue as UnboundedQueue,
+ UnboundedQueueStatistics as UnboundedQueueStatistics,
add_instrument as add_instrument,
cancel_shielded_checkpoint as cancel_shielded_checkpoint,
checkpoint as checkpoint,
From d2511967de12feae0868940b8eec3b932a8b571d Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 15:44:37 +0200
Subject: [PATCH 36/49] fix Statistics not defining slots breaking tests
---
trio/_core/_unbounded_queue.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py
index 7659845a0f..94348bfc26 100644
--- a/trio/_core/_unbounded_queue.py
+++ b/trio/_core/_unbounded_queue.py
@@ -14,7 +14,7 @@
from typing_extensions import Self
-@attr.s(frozen=True)
+@attr.s(slots=True, frozen=True)
class UnboundedQueueStatistics:
"""An object containing debugging information.
From 34341c2c3592e4d109aaf0ecad6fabce5a54f6a2 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 15:49:13 +0200
Subject: [PATCH 37/49] Any -> object in _entry_queue
---
trio/_core/_entry_queue.py | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py
index 553143ceb1..68e1a89180 100644
--- a/trio/_core/_entry_queue.py
+++ b/trio/_core/_entry_queue.py
@@ -2,7 +2,7 @@
import threading
from collections import deque
-from typing import Any, Callable, Iterable, Literal, NoReturn
+from typing import Callable, Iterable, Literal, NoReturn
import attr
@@ -20,10 +20,12 @@ class EntryQueue:
# atomic WRT signal delivery (signal handlers can run on either side, but
# not *during* a deque operation). dict makes similar guarantees - and
# it's even ordered!
- queue: deque[tuple[Callable[..., Any], Iterable[Any]]] = attr.ib(factory=deque)
- idempotent_queue: dict[tuple[Callable[..., Any], Iterable[Any]], None] = attr.ib(
- factory=dict
+ queue: deque[tuple[Callable[..., object], Iterable[object]]] = attr.ib(
+ factory=deque
)
+ idempotent_queue: dict[
+ tuple[Callable[..., object], Iterable[object]], None
+ ] = attr.ib(factory=dict)
wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
done: bool = attr.ib(default=False)
@@ -48,7 +50,9 @@ async def task(self) -> None:
# https://bugs.python.org/issue13697#msg237140
assert self.lock.__class__.__module__ == "_thread"
- def run_cb(job: tuple[Callable[..., object], Iterable[Any]]) -> Literal[True]:
+ def run_cb(
+ job: tuple[Callable[..., object], Iterable[object]]
+ ) -> Literal[True]:
# We run this with KI protection enabled; it's the callback's
# job to disable it if it wants it disabled. Exceptions are
# treated like system task exceptions (i.e., converted into
From cf20b631a3a6bd9ea0ab20e7d094a256e4a541d2 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 16:34:42 +0200
Subject: [PATCH 38/49] adding py.typed, for the fun of it
---
trio/py.typed | 0
1 file changed, 0 insertions(+), 0 deletions(-)
create mode 100644 trio/py.typed
diff --git a/trio/py.typed b/trio/py.typed
new file mode 100644
index 0000000000..e69de29bb2
From dc8c18cc1bb02f6a9e454a9adbec15178df1d75c Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 6 Aug 2023 11:50:22 +0200
Subject: [PATCH 39/49] remove duplicate `from __future__ import annotations` and stale comment
---
pyproject.toml | 1 -
trio/__init__.py | 2 --
2 files changed, 3 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index dbdf1bff5a..c0619d58e6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,6 @@ disallow_any_unimported = false
# DO NOT use `ignore_errors`; it doesn't apply
# downstream and users have to deal with them.
[[tool.mypy.overrides]]
-# Fully typed, enable stricter checks
module = [
"trio/_core/_tests/*",
"trio/_tests/*",
diff --git a/trio/__init__.py b/trio/__init__.py
index c193fe58e3..8db5439d70 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
"""Trio - A friendly Python library for async concurrency and I/O
"""
from __future__ import annotations
From 9bde0a1b2f1351972c0b94a2fc3bbfbe258ad241 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 10 Aug 2023 16:46:06 +0200
Subject: [PATCH 40/49] reset _core/_run in expectation of #2733
---
pyproject.toml | 46 +++--
trio/_core/_generated_run.py | 17 +-
trio/_core/_local.py | 8 +-
trio/_core/_run.py | 330 ++++++++++++----------------------
trio/_tests/verify_types.json | 21 ++-
5 files changed, 174 insertions(+), 248 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 93e5e71ea9..037e0d4403 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ disallow_untyped_defs = true
# Enable gradually / for new modules
check_untyped_defs = false
disallow_untyped_calls = false
-disallow_any_unimported = false
+disallow_any_unimported = true
@@ -55,6 +55,24 @@ module = [
"trio/_core/_tests/*",
"trio/_tests/*",
+# 2749
+"trio/_threads", # 15, 398 lines
+# 2747
+"trio/testing/_network", # 1, 34
+"trio/testing/_trio_test", # 2, 29
+"trio/testing/_checkpoints", # 3, 62
+"trio/testing/_check_streams", # 27, 522
+"trio/testing/_memory_streams", # 66, 590
+# 2745
+"trio/_ssl", # 26, 929 lines
+# 2742
+"trio/_core/_multierror", # 19, 469
+# 2735 trio/_core/_asyncgens
+
+# 2733
+"trio/_core/_run",
+"trio/_core/_generated_run",
+
# windows
"trio/_windows_pipes",
"trio/_core/_windows_cffi", # 2, 324
@@ -62,15 +80,8 @@ module = [
"trio/_core/_io_windows", # 47 (win32), 867
"trio/_wait_for_object", # 2 (windows)
-"trio/_core/_multierror", # 19, 469
-
-"trio/testing/_network", # 1, 34
-"trio/testing/_trio_test", # 2, 29
-"trio/testing/_checkpoints", # 3, 62
-"trio/testing/_check_streams", # 27, 522
"trio/testing/_fake_net", # 30
-"trio/testing/_memory_streams", # 66, 590
"trio/_highlevel_open_unix_stream", # 2, 49 lines
"trio/_highlevel_open_tcp_listeners", # 3, 227 lines
@@ -80,26 +91,27 @@ module = [
"trio/_subprocess_platform/waitid", # 2, 107 lines
"trio/_signals", # 13, 168 lines
-"trio/_threads", # 15, 398 lines
"trio/_subprocess", # 21, 759 lines
-"trio/_ssl", # 26, 929 lines
]
-disallow_untyped_defs = false
+disallow_any_decorated = false
disallow_any_generics = false
+disallow_any_unimported = false
disallow_incomplete_defs = false
-disallow_any_decorated = false
+disallow_untyped_defs = false
+
+[[tool.mypy.overrides]]
+# awaiting typing of Outcome
+module = [
+ "trio._core._traps",
+]
+disallow_any_unimported = false
[[tool.mypy.overrides]]
# Needs to use Any due to some complex introspection.
module = [
"trio._path",
]
-disallow_incomplete_defs = true
-disallow_untyped_defs = true
disallow_any_generics = false
-disallow_any_decorated = true
-disallow_any_unimported = true
-disallow_subclassing_any = true
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index 038ef0e5e2..674c86aaec 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -11,7 +11,7 @@
from ._run import _NO_SEND
-def current_statistics() ->_RunStatistics:
+def current_statistics():
"""Returns an object containing run-loop-level debugging information.
Currently the following fields are defined:
@@ -41,7 +41,7 @@ def current_statistics() ->_RunStatistics:
raise RuntimeError("must be called from async context")
-def current_time() ->float:
+def current_time():
"""Returns the current time according to Trio's internal clock.
Returns:
@@ -58,7 +58,7 @@ def current_time() ->float:
raise RuntimeError("must be called from async context")
-def current_clock() ->(SystemClock | Clock):
+def current_clock():
"""Returns the current :class:`~trio.abc.Clock`."""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
@@ -67,7 +67,7 @@ def current_clock() ->(SystemClock | Clock):
raise RuntimeError("must be called from async context")
-def current_root_task() ->(Task | None):
+def current_root_task():
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
@@ -80,7 +80,7 @@ def current_root_task() ->(Task | None):
raise RuntimeError("must be called from async context")
-def reschedule(task: Task, next_send: Outcome=_NO_SEND) ->None:
+def reschedule(task, next_send=_NO_SEND):
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
@@ -105,8 +105,7 @@ def reschedule(task: Task, next_send: Outcome=_NO_SEND) ->None:
raise RuntimeError("must be called from async context")
-def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args:
- Any, name: (str | None)=None, context: (Context | None)=None) ->Task:
+def spawn_system_task(async_fn, *args, name=None, context=None):
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
@@ -165,7 +164,7 @@ def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args:
raise RuntimeError("must be called from async context")
-def current_trio_token() ->TrioToken:
+def current_trio_token():
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
@@ -177,7 +176,7 @@ def current_trio_token() ->TrioToken:
raise RuntimeError("must be called from async context")
-async def wait_all_tasks_blocked(cushion: float=0.0) ->None:
+async def wait_all_tasks_blocked(cushion=0.0):
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index ee74f36c49..4f267ba006 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -44,7 +44,7 @@ def get(self, default: T | type[_NoValue] = _NoValue) -> T:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
# not typed yet
- return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[no-any-return]
+ return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
@@ -72,7 +72,7 @@ def set(self, value: T) -> RunVarToken[T]:
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[index,assignment]
return token
def reset(self, token: RunVarToken[T]) -> None:
@@ -92,9 +92,9 @@ def reset(self, token: RunVarToken[T]) -> None:
previous = token.previous_value
try:
if previous is _NoValue:
- _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
+ _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type]
else:
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index,assignment]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context")
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 3361dbe375..7d247a2738 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -17,16 +17,7 @@
from math import inf
from time import perf_counter
from types import TracebackType
-from typing import (
- TYPE_CHECKING,
- Any,
- Awaitable,
- Generator,
- Iterable,
- NoReturn,
- Sequence,
- TypeVar,
-)
+from typing import TYPE_CHECKING, Any, NoReturn, TypeVar
import attr
from outcome import Error, Outcome, Value, capture
@@ -46,7 +37,6 @@
Abort,
CancelShieldedCheckpoint,
PermanentlyDetachCoroutineObject,
- RaiseCancelT,
WaitTaskRescheduled,
cancel_shielded_checkpoint,
wait_task_rescheduled,
@@ -58,34 +48,17 @@
from types import FrameType
if TYPE_CHECKING:
- from contextvars import Context
+ import contextvars
# An unfortunate name collision here with trio._util.Final
from typing import Final as FinalT
- from typing_extensions import Self, TypeAlias
-
- from .._abc import Clock
- from ._local import RunVar
- from ._mock_clock import MockClock
-
- if sys.platform == "win32":
- from ._io_windows import _WindowsStatistics
- elif sys.platform == "darwin":
- from select import kevent
-
- from ._io_kqueue import _KqueueStatistics
- elif sys.platform == "linux":
- from ._io_epoll import _EpollStatistics
-
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000
_NO_SEND: FinalT = object()
FnT = TypeVar("FnT", bound="Callable[..., Any]")
-T = TypeVar("T")
-
# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
@@ -146,7 +119,6 @@ def function_with_unique_name_xyzzy() -> NoReturn:
CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames()
-# Why doesn't this inherit from abc.Clock?
@attr.s(frozen=True, slots=True)
class SystemClock:
# Add a large random offset to our clock to ensure that if people
@@ -177,9 +149,7 @@ class IdlePrimedTypes(enum.Enum):
################################################################
-def collapse_exception_group(
- excgroup: BaseExceptionGroup[BaseException],
-) -> BaseExceptionGroup[BaseException] | BaseException:
+def collapse_exception_group(excgroup):
"""Recursively collapse any single-exception groups into that single contained
exception.
@@ -199,7 +169,7 @@ def collapse_exception_group(
)
return exceptions[0]
elif modified:
- return excgroup.derive(exceptions) # type: ignore[no-any-return]
+ return excgroup.derive(exceptions)
else:
return excgroup
@@ -214,18 +184,18 @@ class Deadlines:
"""
# Heap of (deadline, id(CancelScope), CancelScope)
- _heap: list[tuple[float, int, CancelScope]] = attr.ib(factory=list)
+ _heap = attr.ib(factory=list)
# Count of active deadlines (those that haven't been changed)
- _active: int = attr.ib(default=0)
+ _active = attr.ib(default=0)
- def add(self, deadline: float, cancel_scope: CancelScope) -> None:
+ def add(self, deadline, cancel_scope):
heappush(self._heap, (deadline, id(cancel_scope), cancel_scope))
self._active += 1
- def remove(self, deadline: float, cancel_scope: CancelScope) -> None:
+ def remove(self, deadline, cancel_scope):
self._active -= 1
- def next_deadline(self) -> float:
+ def next_deadline(self):
while self._heap:
deadline, _, cancel_scope = self._heap[0]
if deadline == cancel_scope._registered_deadline:
@@ -235,7 +205,7 @@ def next_deadline(self) -> float:
heappop(self._heap)
return inf
- def _prune(self) -> None:
+ def _prune(self):
# In principle, it's possible for a cancel scope to toggle back and
# forth repeatedly between the same two deadlines, and end up with
# lots of stale entries that *look* like they're still active, because
@@ -256,7 +226,7 @@ def _prune(self) -> None:
heapify(pruned_heap)
self._heap = pruned_heap
- def expire(self, now: float) -> bool:
+ def expire(self, now):
did_something = False
while self._heap and self._heap[0][0] <= now:
deadline, _, cancel_scope = heappop(self._heap)
@@ -407,14 +377,14 @@ def close(self) -> None:
child.recalculate()
@property
- def parent_cancellation_is_visible_to_us(self) -> bool:
+ def parent_cancellation_is_visible_to_us(self):
return (
self._parent is not None
and not self._scope.shield
and self._parent.effectively_cancelled
)
- def recalculate(self) -> None:
+ def recalculate(self):
# This does a depth-first traversal over this and descendent cancel
# statuses, to ensure their state is up-to-date. It's basically a
# recursive algorithm, but we use an explicit stack to avoid any
@@ -433,7 +403,7 @@ def recalculate(self) -> None:
task._attempt_delivery_of_any_pending_cancel()
todo.extend(current._children)
- def _mark_abandoned(self) -> None:
+ def _mark_abandoned(self):
self.abandoned_by_misnesting = True
for child in self._children:
child._mark_abandoned()
@@ -521,7 +491,7 @@ class CancelScope(metaclass=Final):
_shield: bool = attr.ib(default=False, kw_only=True)
@enable_ki_protection
- def __enter__(self) -> Self:
+ def __enter__(self):
task = _core.current_task()
if self._has_been_entered:
raise RuntimeError(
@@ -535,7 +505,7 @@ def __enter__(self) -> Self:
task._activate_cancel_status(self._cancel_status)
return self
- def _close(self, exc: BaseException | None) -> BaseException | None:
+ def _close(self, exc):
if self._cancel_status is None:
new_exc = RuntimeError(
"Cancel scope stack corrupted: attempted to exit {!r} "
@@ -573,7 +543,6 @@ def _close(self, exc: BaseException | None) -> BaseException | None:
# CancelStatus.close() will take care of the plumbing;
# we just need to make sure we don't let the error
# pass silently.
- assert scope_task._cancel_status is not None
new_exc = RuntimeError(
"Cancel scope stack corrupted: attempted to exit {!r} "
"in {!r} that's still within its child {!r}\n{}".format(
@@ -815,15 +784,15 @@ def cancel_called(self) -> bool:
# sense.
@attr.s(eq=False, hash=False, repr=False)
class TaskStatus(metaclass=Final):
- _old_nursery: Nursery = attr.ib()
- _new_nursery: Nursery = attr.ib()
- _called_started: bool = attr.ib(default=False)
- _value: Any = attr.ib(default=None)
+ _old_nursery = attr.ib()
+ _new_nursery = attr.ib()
+ _called_started = attr.ib(default=False)
+ _value = attr.ib(default=None)
- def __repr__(self) -> str:
+ def __repr__(self):
        return f"<TaskStatus at {id(self):#x}>"
- def started(self, value: Any = None) -> None:
+ def started(self, value=None):
if self._called_started:
raise RuntimeError("called 'started' twice on the same task status")
self._called_started = True
@@ -833,7 +802,6 @@ def started(self, value: Any = None) -> None:
# will eventually exit on its own, and we don't want to risk moving
# children that might have propagating Cancelled exceptions into
# a place with no cancelled cancel scopes to catch them.
- assert self._old_nursery._cancel_status is not None
if self._old_nursery._cancel_status.effectively_cancelled:
return
@@ -857,7 +825,6 @@ def started(self, value: Any = None) -> None:
# do something evil like cancel the old nursery. We thus break
# everything off from the old nursery before we start attaching
# anything to the new.
- assert self._old_nursery._cancel_status is not None
cancel_status_children = self._old_nursery._cancel_status.children
cancel_status_tasks = set(self._old_nursery._cancel_status.tasks)
cancel_status_tasks.discard(self._old_nursery._parent_task)
@@ -1030,22 +997,20 @@ def _add_exc(self, exc: BaseException) -> None:
self._pending_excs.append(exc)
self.cancel_scope.cancel()
- def _check_nursery_closed(self) -> None:
+ def _check_nursery_closed(self):
if not any([self._nested_child_running, self._children, self._pending_starts]):
self._closed = True
if self._parent_waiting_in_aexit:
self._parent_waiting_in_aexit = False
GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task)
- def _child_finished(self, task: Task, outcome: Outcome) -> None:
+ def _child_finished(self, task, outcome):
self._children.remove(task)
if isinstance(outcome, Error):
self._add_exc(outcome.error)
self._check_nursery_closed()
- async def _nested_child_finished(
- self, nested_child_exc: BaseException | None
- ) -> MultiError | None:
+ async def _nested_child_finished(self, nested_child_exc):
# Returns MultiError instance (or any exception if the nursery is in loose mode
# and there is just one contained exception) if there are pending exceptions
if nested_child_exc is not None:
@@ -1057,12 +1022,12 @@ async def _nested_child_finished(
# If we get cancelled (or have an exception injected, like
# KeyboardInterrupt), then save that, but still wait until our
# children finish.
- def abort(raise_cancel: RaiseCancelT) -> Abort:
+ def aborted(raise_cancel):
self._add_exc(capture(raise_cancel).error)
return Abort.FAILED
self._parent_waiting_in_aexit = True
- await wait_task_rescheduled(abort)
+ await wait_task_rescheduled(aborted)
else:
# Nothing to wait for, so just execute a checkpoint -- but we
# still need to mix any exception (e.g. from an external
@@ -1083,14 +1048,8 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
# avoid a garbage cycle
# (see test_nursery_cancel_doesnt_create_cyclic_garbage)
del self._pending_excs
- return None
- def start_soon(
- self,
- async_fn: Callable[..., Awaitable[object]],
- *args: object,
- name: str | None = None,
- ) -> None:
+ def start_soon(self, async_fn, *args, name=None):
"""Creates a child task, scheduling ``await async_fn(*args)``.
If you want to run a function and immediately wait for its result,
@@ -1132,12 +1091,7 @@ def start_soon(
"""
GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name)
- async def start(
- self,
- async_fn: Callable[..., Awaitable[object]],
- *args: object,
- name: str | None = None,
- ) -> Value:
+ async def start(self, async_fn, *args, name=None):
r"""Creates and initializes a child task.
Like :meth:`start_soon`, but blocks until the new task has
@@ -1214,9 +1168,9 @@ def __del__(self) -> None:
class Task(metaclass=NoPublicConstructor):
_parent_nursery: Nursery | None = attr.ib()
coro: Coroutine[Any, Outcome[object], Any] = attr.ib()
- _runner: Runner = attr.ib()
+ _runner = attr.ib()
name: str = attr.ib()
- context: Context = attr.ib()
+ context: contextvars.Context = attr.ib()
_counter: int = attr.ib(init=False, factory=itertools.count().__next__)
# Invariant:
@@ -1230,8 +1184,8 @@ class Task(metaclass=NoPublicConstructor):
# tracebacks with extraneous frames.
# - for scheduled tasks, custom_sleep_data is None
# Tasks start out unscheduled.
- _next_send_fn: Callable[[Outcome | None], None] = attr.ib(default=None)
- _next_send: Outcome | None = attr.ib(default=None)
+ _next_send_fn = attr.ib(default=None)
+ _next_send = attr.ib(default=None)
_abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib(
default=None
)
@@ -1336,9 +1290,9 @@ def print_stack_for_task(task):
# The CancelStatus object that is currently active for this task.
# Don't change this directly; instead, use _activate_cancel_status().
- _cancel_status: CancelStatus | None = attr.ib(default=None, repr=False)
+ _cancel_status: CancelStatus = attr.ib(default=None, repr=False)
- def _activate_cancel_status(self, cancel_status: CancelStatus | None) -> None:
+ def _activate_cancel_status(self, cancel_status: CancelStatus) -> None:
if self._cancel_status is not None:
self._cancel_status._tasks.remove(self)
self._cancel_status = cancel_status
@@ -1369,11 +1323,10 @@ def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None:
def _attempt_delivery_of_any_pending_cancel(self) -> None:
if self._abort_func is None:
return
- assert self._cancel_status is not None
if not self._cancel_status.effectively_cancelled:
return
- def raise_cancel() -> NoReturn:
+ def raise_cancel():
raise Cancelled._create()
self._attempt_abort(raise_cancel)
@@ -1402,32 +1355,14 @@ class RunContext(threading.local):
GLOBAL_RUN_CONTEXT: FinalT = RunContext()
-if TYPE_CHECKING:
- if sys.platform == "win32":
- IO_STATISTICS_TYPE: TypeAlias = _WindowsStatistics
- elif sys.platform == "darwin":
- IO_STATISTICS_TYPE: TypeAlias = _KqueueStatistics
- elif sys.platform == "linux":
- IO_STATISTICS_TYPE: TypeAlias = _EpollStatistics
-else:
- IO_STATISTICS_TYPE = None
-
@attr.s(frozen=True)
class _RunStatistics:
- tasks_living: int = attr.ib()
- tasks_runnable: int = attr.ib()
- seconds_to_next_deadline: float = attr.ib()
- io_statistics: IO_STATISTICS_TYPE = attr.ib()
- run_sync_soon_queue_size: int = attr.ib()
-
-
-if sys.platform == "linux":
- GetEventsT: TypeAlias = "list[tuple[int, int]]"
-elif sys.platform == "darwin":
- GetEventsT: TypeAlias = "list[kevent]"
-else:
- GetEventsT: TypeAlias = int
+ tasks_living = attr.ib()
+ tasks_runnable = attr.ib()
+ seconds_to_next_deadline = attr.ib()
+ io_statistics = attr.ib()
+ run_sync_soon_queue_size = attr.ib()
# This holds all the state that gets trampolined back and forth between
@@ -1451,15 +1386,15 @@ class _RunStatistics:
# worker thread.
@attr.s(eq=False, hash=False, slots=True)
class GuestState:
- runner: Runner = attr.ib()
- run_sync_soon_threadsafe: Callable[[Callable[[], None]], None] = attr.ib()
- run_sync_soon_not_threadsafe: Callable[[Callable[[], None]], None] = attr.ib()
- done_callback: Callable[[Outcome], None] = attr.ib()
- unrolled_run_gen: Generator[float, GetEventsT, None] = attr.ib()
+ runner = attr.ib()
+ run_sync_soon_threadsafe = attr.ib()
+ run_sync_soon_not_threadsafe = attr.ib()
+ done_callback = attr.ib()
+ unrolled_run_gen = attr.ib()
_value_factory: Callable[[], Value] = lambda: Value(None)
unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome)
- def guest_tick(self) -> None:
+ def guest_tick(self):
try:
timeout = self.unrolled_run_next_send.send(self.unrolled_run_gen)
except StopIteration:
@@ -1480,11 +1415,11 @@ def guest_tick(self) -> None:
# Need to go into the thread and call get_events() there
self.runner.guest_tick_scheduled = False
- def get_events() -> GetEventsT:
+ def get_events():
return self.runner.io_manager.get_events(timeout)
- def deliver(events_outcome: Outcome) -> None:
- def in_main_thread() -> None:
+ def deliver(events_outcome):
+ def in_main_thread():
self.unrolled_run_next_send = events_outcome
self.runner.guest_tick_scheduled = True
self.guest_tick()
@@ -1496,44 +1431,44 @@ def in_main_thread() -> None:
@attr.s(eq=False, hash=False, slots=True)
class Runner:
- clock: SystemClock | Clock | MockClock = attr.ib()
+ clock = attr.ib()
instruments: Instruments = attr.ib()
io_manager: TheIOManager = attr.ib()
- ki_manager: KIManager = attr.ib()
- strict_exception_groups: bool = attr.ib()
+ ki_manager = attr.ib()
+ strict_exception_groups = attr.ib()
# Run-local values, see _local.py
- _locals: dict[RunVar[Any], Any] = attr.ib(factory=dict)
+ _locals = attr.ib(factory=dict)
runq: deque[Task] = attr.ib(factory=deque)
- tasks: set[Task] = attr.ib(factory=set)
+ tasks = attr.ib(factory=set)
- deadlines: Deadlines = attr.ib(factory=Deadlines)
+ deadlines = attr.ib(factory=Deadlines)
- init_task: Task | None = attr.ib(default=None)
- system_nursery: Nursery | None = attr.ib(default=None)
- system_context: Context | None = attr.ib(default=None)
- main_task: Task | None = attr.ib(default=None)
- main_task_outcome: Outcome | None = attr.ib(default=None)
+ init_task = attr.ib(default=None)
+ system_nursery = attr.ib(default=None)
+ system_context = attr.ib(default=None)
+ main_task = attr.ib(default=None)
+ main_task_outcome = attr.ib(default=None)
- entry_queue: EntryQueue = attr.ib(factory=EntryQueue)
- trio_token: TrioToken | None = attr.ib(default=None)
- asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators)
+ entry_queue = attr.ib(factory=EntryQueue)
+ trio_token = attr.ib(default=None)
+ asyncgens = attr.ib(factory=AsyncGenerators)
# If everything goes idle for this long, we call clock._autojump()
- clock_autojump_threshold: float = attr.ib(default=inf)
+ clock_autojump_threshold = attr.ib(default=inf)
# Guest mode stuff
- is_guest: bool = attr.ib(default=False)
- guest_tick_scheduled: bool = attr.ib(default=False)
+ is_guest = attr.ib(default=False)
+ guest_tick_scheduled = attr.ib(default=False)
- def force_guest_tick_asap(self) -> None:
+ def force_guest_tick_asap(self):
if self.guest_tick_scheduled:
return
self.guest_tick_scheduled = True
self.io_manager.force_wakeup()
- def close(self) -> None:
+ def close(self):
self.io_manager.close()
self.entry_queue.close()
self.asyncgens.close()
@@ -1543,7 +1478,7 @@ def close(self) -> None:
self.ki_manager.close()
@_public
- def current_statistics(self) -> _RunStatistics:
+ def current_statistics(self):
"""Returns an object containing run-loop-level debugging information.
Currently the following fields are defined:
@@ -1576,7 +1511,7 @@ def current_statistics(self) -> _RunStatistics:
)
@_public
- def current_time(self) -> float:
+ def current_time(self):
"""Returns the current time according to Trio's internal clock.
Returns:
@@ -1588,15 +1523,13 @@ def current_time(self) -> float:
"""
return self.clock.current_time()
- # TODO: abc.Clock or SystemClock? (the latter which doesn't inherit
- # from abc.Clock)
@_public
- def current_clock(self) -> SystemClock | Clock:
+ def current_clock(self):
"""Returns the current :class:`~trio.abc.Clock`."""
return self.clock
@_public
- def current_root_task(self) -> Task | None:
+ def current_root_task(self):
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
@@ -1608,9 +1541,8 @@ def current_root_task(self) -> Task | None:
# Core task handling primitives
################
- # Outcome is not typed
@_public
- def reschedule(self, task: Task, next_send: Outcome = _NO_SEND) -> None: # type: ignore[misc]
+ def reschedule(self, task, next_send=_NO_SEND):
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
@@ -1644,15 +1576,8 @@ def reschedule(self, task: Task, next_send: Outcome = _NO_SEND) -> None: # type
self.instruments.call("task_scheduled", task)
def spawn_impl(
- self,
- async_fn: Callable[..., Awaitable[object]],
- args: Iterable[Any],
- nursery: Nursery | None,
- name: str | functools.partial[Any] | Callable[..., Awaitable[object]] | None,
- *,
- system_task: bool = False,
- context: Context | None = None,
- ) -> Task:
+ self, async_fn, args, nursery, name, *, system_task=False, context=None
+ ):
######
# Make sure the nursery is in working order
######
@@ -1670,7 +1595,6 @@ def spawn_impl(
######
if context is None:
if system_task:
- assert self.system_context is not None
context = self.system_context.copy()
else:
context = copy_context()
@@ -1683,8 +1607,7 @@ def spawn_impl(
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes.
######
- # TODO: ??
- coro = context.run(coroutine_or_error, async_fn, *args) # type: ignore[arg-type]
+ coro = context.run(coroutine_or_error, async_fn, *args)
if name is None:
name = async_fn
@@ -1698,7 +1621,7 @@ def spawn_impl(
if not hasattr(coro, "cr_frame"):
# This async function is implemented in C or Cython
- async def python_wrapper(orig_coro: Awaitable[T]) -> T:
+ async def python_wrapper(orig_coro):
return await orig_coro
coro = python_wrapper(coro)
@@ -1723,7 +1646,7 @@ async def python_wrapper(orig_coro: Awaitable[T]) -> T:
self.reschedule(task, None)
return task
- def task_exited(self, task: Task, outcome: Outcome) -> None:
+ def task_exited(self, task, outcome):
if (
task._cancel_status is not None
and task._cancel_status.abandoned_by_misnesting
@@ -1762,7 +1685,6 @@ def task_exited(self, task: Task, outcome: Outcome) -> None:
if task is self.main_task:
self.main_task_outcome = outcome
outcome = Value(None)
- assert task._parent_nursery is not None
task._parent_nursery._child_finished(task, outcome)
if "task_exited" in self.instruments:
@@ -1772,15 +1694,8 @@ def task_exited(self, task: Task, outcome: Outcome) -> None:
# System tasks and init
################
- # TODO: [misc]typed with Any
@_public
- def spawn_system_task( # type: ignore[misc]
- self,
- async_fn: Callable[..., Awaitable[object]],
- *args: Any,
- name: str | None = None,
- context: Context | None = None,
- ) -> Task:
+ def spawn_system_task(self, async_fn, *args, name=None, context=None):
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
@@ -1841,9 +1756,7 @@ def spawn_system_task( # type: ignore[misc]
context=context,
)
- async def init(
- self, async_fn: Callable[..., Awaitable[object]], args: Iterable[object]
- ) -> None:
+ async def init(self, async_fn, args):
# run_sync_soon task runs here:
async with open_nursery() as run_sync_soon_nursery:
# All other system tasks run here:
@@ -1881,7 +1794,7 @@ async def init(
################
@_public
- def current_trio_token(self) -> TrioToken:
+ def current_trio_token(self):
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
@@ -1894,7 +1807,7 @@ def current_trio_token(self) -> TrioToken:
# KI handling
################
- ki_pending: bool = attr.ib(default=False)
+ ki_pending = attr.ib(default=False)
# deliver_ki is broke. Maybe move all the actual logic and state into
# RunToken, and we'll only have one instance per runner? But then we can't
@@ -1903,14 +1816,14 @@ def current_trio_token(self) -> TrioToken:
# keep the class public so people can isinstance() it if they want.
# This gets called from signal context
- def deliver_ki(self) -> None:
+ def deliver_ki(self):
self.ki_pending = True
try:
self.entry_queue.run_sync_soon(self._deliver_ki_cb)
except RunFinishedError:
pass
- def _deliver_ki_cb(self) -> None:
+ def _deliver_ki_cb(self):
if not self.ki_pending:
return
# Can't happen because main_task and run_sync_soon_task are created at
@@ -1927,10 +1840,10 @@ def _deliver_ki_cb(self) -> None:
# Quiescing
################
- waiting_for_idle: SortedDict = attr.ib(factory=SortedDict)
+ waiting_for_idle = attr.ib(factory=SortedDict)
@_public
- async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None:
+ async def wait_all_tasks_blocked(self, cushion=0.0):
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
@@ -1992,7 +1905,7 @@ async def test_lock_fairness():
key = (cushion, id(task))
self.waiting_for_idle[key] = task
- def abort(_: RaiseCancelT) -> Abort:
+ def abort(_):
del self.waiting_for_idle[key]
return Abort.SUCCEEDED
@@ -2067,11 +1980,11 @@ def abort(_: RaiseCancelT) -> Abort:
def setup_runner(
- clock: Clock | None,
- instruments: Sequence[Instrument],
- restrict_keyboard_interrupt_to_checkpoints: bool,
- strict_exception_groups: bool,
-) -> Runner:
+ clock,
+ instruments,
+ restrict_keyboard_interrupt_to_checkpoints,
+ strict_exception_groups,
+):
"""Create a Runner object and install it as the GLOBAL_RUN_CONTEXT."""
# It wouldn't be *hard* to support nested calls to run(), but I can't
# think of a single good reason for it, so let's be conservative for
@@ -2080,17 +1993,15 @@ def setup_runner(
raise RuntimeError("Attempted to call run() from inside a run()")
if clock is None:
- _clock: Clock | SystemClock = SystemClock()
- else:
- _clock = clock
- _instruments = Instruments(instruments)
+ clock = SystemClock()
+ instruments = Instruments(instruments)
io_manager = TheIOManager()
system_context = copy_context()
ki_manager = KIManager()
runner = Runner(
- clock=_clock,
- instruments=_instruments,
+ clock=clock,
+ instruments=instruments,
io_manager=io_manager,
system_context=system_context,
ki_manager=ki_manager,
@@ -2107,13 +2018,13 @@ def setup_runner(
def run(
- async_fn: Callable[..., Awaitable[T]],
- *args: object,
- clock: Clock | None = None,
- instruments: Sequence[Instrument] = (),
+ async_fn,
+ *args,
+ clock=None,
+ instruments=(),
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = False,
-) -> T:
+):
"""Run a Trio-flavored async function, and return the result.
Calling::
@@ -2199,32 +2110,30 @@ def run(
next_send = None
while True:
try:
- # sending next_send==None here ... should not work??
- timeout = gen.send(next_send) # type: ignore[arg-type]
+ timeout = gen.send(next_send)
except StopIteration:
break
next_send = runner.io_manager.get_events(timeout)
# Inlined copy of runner.main_task_outcome.unwrap() to avoid
# cluttering every single Trio traceback with an extra frame.
if isinstance(runner.main_task_outcome, Value):
- return runner.main_task_outcome.value # type: ignore[no-any-return]
+ return runner.main_task_outcome.value
else:
- assert runner.main_task_outcome is not None
raise runner.main_task_outcome.error
def start_guest_run(
- async_fn: Callable[..., Awaitable[object]],
- *args: object,
- run_sync_soon_threadsafe: Callable[[Callable[..., None]], None],
- done_callback: Callable[[Outcome], None],
- run_sync_soon_not_threadsafe: Callable[[Callable[..., None]], None] | None = None,
+ async_fn,
+ *args,
+ run_sync_soon_threadsafe,
+ done_callback,
+ run_sync_soon_not_threadsafe=None,
host_uses_signal_set_wakeup_fd: bool = False,
- clock: Clock | None = None,
- instruments: tuple[Instrument, ...] = (),
+ clock=None,
+ instruments=(),
restrict_keyboard_interrupt_to_checkpoints: bool = False,
strict_exception_groups: bool = False,
-) -> None:
+):
"""Start a "guest" run of Trio on top of some other "host" event loop.
Each host loop can only have one guest run at a time.
@@ -2312,10 +2221,10 @@ def my_done_callback(run_outcome):
# straight through.
def unrolled_run(
runner: Runner,
- async_fn: Callable[..., object],
- args: Iterable[object],
+ async_fn,
+ args,
host_uses_signal_set_wakeup_fd: bool = False,
-) -> Generator[float, GetEventsT, None]:
+):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
__tracebackhide__ = True
@@ -2401,7 +2310,7 @@ def unrolled_run(
break
else:
assert idle_primed is IdlePrimedTypes.AUTOJUMP_CLOCK
- runner.clock._autojump() # type: ignore[union-attr]
+ runner.clock._autojump()
# Process all runnable tasks, but only the ones that are already
# runnable now. Anything that becomes runnable during this cycle
@@ -2601,10 +2510,7 @@ def current_effective_deadline() -> float:
float: the effective deadline, as an absolute time.
"""
- curr_cancel_status = current_task()._cancel_status
- assert curr_cancel_status is not None
- return curr_cancel_status.effective_deadline()
- # return current_task()._cancel_status.effective_deadline()
+ return current_task()._cancel_status.effective_deadline()
async def checkpoint() -> None:
@@ -2627,7 +2533,6 @@ async def checkpoint() -> None:
await cancel_shielded_checkpoint()
task = current_task()
task._cancel_points += 1
- assert task._cancel_status is not None
if task._cancel_status.effectively_cancelled or (
task is task._runner.main_task and task._runner.ki_pending
):
@@ -2651,7 +2556,6 @@ async def checkpoint_if_cancelled() -> None:
"""
task = current_task()
- assert task._cancel_status is not None
if task._cancel_status.effectively_cancelled or (
task is task._runner.main_task and task._runner.ki_pending
):
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index ce82e1ec65..cf108c4bff 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9473684210526315,
+ "completenessScore": 0.9330143540669856,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 594,
- "withUnknownType": 33
+ "withKnownType": 585,
+ "withUnknownType": 42
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 0,
@@ -46,11 +46,17 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 5,
- "withKnownType": 613,
- "withUnknownType": 48
+ "withKnownType": 602,
+ "withUnknownType": 59
},
"packageName": "trio",
"symbols": [
+ "trio._core._run.Nursery.start",
+ "trio._core._run.Nursery.start_soon",
+ "trio._core._run.TaskStatus.__repr__",
+ "trio._core._run.TaskStatus.started",
+ "trio._dtls.DTLSChannel.__init__",
+ "trio._dtls.DTLSEndpoint.serve",
"trio._highlevel_socket.SocketStream.getsockopt",
"trio._highlevel_socket.SocketStream.send_all",
"trio._highlevel_socket.SocketStream.setsockopt",
@@ -80,6 +86,7 @@
"trio._subprocess.Process.send_signal",
"trio._subprocess.Process.terminate",
"trio._subprocess.Process.wait",
+ "trio.current_time",
"trio.from_thread.run",
"trio.from_thread.run_sync",
"trio.lowlevel.current_clock",
@@ -89,14 +96,17 @@
"trio.lowlevel.notify_closing",
"trio.lowlevel.reschedule",
"trio.lowlevel.spawn_system_task",
+ "trio.lowlevel.start_guest_run",
"trio.lowlevel.wait_readable",
"trio.lowlevel.wait_writable",
"trio.open_ssl_over_tcp_listeners",
"trio.open_ssl_over_tcp_stream",
"trio.open_unix_socket",
+ "trio.run",
"trio.run_process",
"trio.serve_listeners",
"trio.serve_ssl_over_tcp",
+ "trio.serve_tcp",
"trio.testing._memory_streams.MemoryReceiveStream.__init__",
"trio.testing._memory_streams.MemoryReceiveStream.aclose",
"trio.testing._memory_streams.MemoryReceiveStream.close",
@@ -127,6 +137,7 @@
"trio.testing.memory_stream_pump",
"trio.testing.open_stream_to_socket_listener",
"trio.testing.trio_test",
+ "trio.testing.wait_all_tasks_blocked",
"trio.to_thread.current_default_thread_limiter"
]
}
From 048b48a4db86323189654c1ca2d907d8f9ab241c Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 10 Aug 2023 17:13:46 +0200
Subject: [PATCH 41/49] rearrange files in toml
---
pyproject.toml | 35 +++++++++++++++++++----------------
1 file changed, 19 insertions(+), 16 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 037e0d4403..ebbe88f5ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,8 +52,6 @@ disallow_any_unimported = true
# downstream and users have to deal with them.
[[tool.mypy.overrides]]
module = [
-"trio/_core/_tests/*",
-"trio/_tests/*",
# 2749
"trio/_threads", # 15, 398 lines
@@ -67,31 +65,36 @@ module = [
"trio/_ssl", # 26, 929 lines
# 2742
"trio/_core/_multierror", # 19, 469
-# 2735 trio/_core/_asyncgens
-
# 2733
"trio/_core/_run",
"trio/_core/_generated_run",
+# 2724
+"trio/_highlevel_open_tcp_listeners", # 3, 227 lines
+# 2735 trio/_core/_asyncgens
-# windows
-"trio/_windows_pipes",
-"trio/_core/_windows_cffi", # 2, 324
-"trio/_core/_generated_io_windows", # 9 (win32), 84
-"trio/_core/_io_windows", # 47 (win32), 867
-"trio/_wait_for_object", # 2 (windows)
-
-
-"trio/testing/_fake_net", # 30
-
+# exported API
"trio/_highlevel_open_unix_stream", # 2, 49 lines
-"trio/_highlevel_open_tcp_listeners", # 3, 227 lines
"trio/_highlevel_serve_listeners", # 3, 121 lines
"trio/_highlevel_ssl_helpers", # 3, 155 lines
"trio/_highlevel_socket", # 4, 386 lines
-"trio/_subprocess_platform/waitid", # 2, 107 lines
"trio/_signals", # 13, 168 lines
"trio/_subprocess", # 21, 759 lines
+
+# windows API
+"trio/_core/_generated_io_windows", # 9 (win32), 84
+"trio/_core/_io_windows", # 47 (win32), 867
+"trio/_wait_for_object", # 2 (windows)
+
+# internal
+"trio/_windows_pipes",
+"trio/_core/_windows_cffi", # 2, 324
+
+# tests
+"trio/_subprocess_platform/waitid", # 2, 107 lines
+"trio/_core/_tests/*",
+"trio/_tests/*",
+"trio/testing/_fake_net", # 30
]
disallow_any_decorated = false
disallow_any_generics = false
From 5c0a25cbe651b9fdd2cdb0d91db380cc4352425a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 10 Aug 2023 17:21:39 +0200
Subject: [PATCH 42/49] remove unnecessary/incorrect diffs relative to master
---
.coveragerc | 1 -
trio/__init__.py | 1 -
trio/_core/_asyncgens.py | 8 +++++---
trio/_socket.py | 4 ++--
4 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/.coveragerc b/.coveragerc
index a7e309c11b..431a02971b 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -10,7 +10,6 @@ omit=
*/trio/_core/_generated_*
# Script used to check type completeness that isn't run in tests
*/trio/_tests/check_type_completeness.py
-
# The test suite spawns subprocesses to test some stuff, so make sure
# this doesn't corrupt the coverage files
parallel=True
diff --git a/trio/__init__.py b/trio/__init__.py
index d147012b0a..277baa5339 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -16,7 +16,6 @@
#
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)
-
# must be imported early to avoid circular import
from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: split
diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py
index 975db63555..4261328278 100644
--- a/trio/_core/_asyncgens.py
+++ b/trio/_core/_asyncgens.py
@@ -140,7 +140,7 @@ async def finalize_remaining(self, runner: _run.Runner) -> None:
# To make async generator finalization easier to reason
# about, we'll shut down asyncgen garbage collection by turning
# the alive WeakSet into a regular set.
- self.alive = set(self.alive) # type: ignore
+ self.alive = set(self.alive)
# Process all pending run_sync_soon callbacks, in case one of
# them was an asyncgen finalizer that snuck in under the wire.
@@ -185,14 +185,16 @@ async def finalize_remaining(self, runner: _run.Runner) -> None:
# all are gone.
while self.alive:
batch = self.alive
- self.alive = set() # type: ignore
+ self.alive = _ASYNC_GEN_SET()
for agen in batch:
await self._finalize_one(agen, name_asyncgen(agen))
def close(self) -> None:
sys.set_asyncgen_hooks(*self.prev_hooks)
- async def _finalize_one(self, agen: AGenT, name: str) -> None:
+ async def _finalize_one(
+ self, agen: AsyncGeneratorType[object, NoReturn], name: object
+ ) -> None:
try:
# This shield ensures that finalize_asyncgen never exits
# with an exception, not even a Cancelled. The inside
diff --git a/trio/_socket.py b/trio/_socket.py
index cb739bb903..b0ec1d480d 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -979,8 +979,8 @@ async def sendto(
) -> int:
...
- @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=())
- async def sendto(self, *args: Any) -> int: # type: ignore[misc] # Any
+ @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc]
+ async def sendto(self, *args: Any) -> int:
"""Similar to :meth:`socket.socket.sendto`, but async."""
# args is: data[, flags], address)
# and kwargs are not accepted
From 9f5370564cf78fa216b048d1515d1c3516e84d7a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 10 Aug 2023 17:34:16 +0200
Subject: [PATCH 43/49] _traps cleanup
---
trio/_core/_traps.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py
index 77fc9966fe..e0f40f8ad6 100644
--- a/trio/_core/_traps.py
+++ b/trio/_core/_traps.py
@@ -1,3 +1,4 @@
+# These are the only functions that ever yield back to the task runner.
from __future__ import annotations
import enum
@@ -9,9 +10,6 @@
from . import _run
-# These are the only functions that ever yield back to the task runner.
-
-
if TYPE_CHECKING:
from outcome import Outcome
from typing_extensions import TypeAlias
From 63a8de7ef38285cddae1cf5653a1cda4cdc46781 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 10 Aug 2023 17:54:06 +0200
Subject: [PATCH 44/49] cleanup
---
trio/_core/_generated_io_windows.py | 8 +++++---
trio/_core/_io_windows.py | 15 ++++-----------
trio/_core/_windows_cffi.py | 3 +--
trio/_ssl.py | 9 +++------
trio/_tools/gen_exports.py | 14 +++++++++++++-
5 files changed, 26 insertions(+), 23 deletions(-)
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 29f4eb56db..301573c6ee 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -8,7 +8,10 @@
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, ContextManager
+
+if TYPE_CHECKING:
+ from ._unbounded_queue import UnboundedQueue
import sys
assert not TYPE_CHECKING or sys.platform=="win32"
@@ -78,8 +81,7 @@ def current_iocp() ->int:
raise RuntimeError("must be called from async context")
-def monitor_completion_key() ->_GeneratorContextManager[tuple[int,
- UnboundedQueue[object]]]:
+def monitor_completion_key() ->ContextManager[tuple[int, UnboundedQueue[object]]]:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py
index 7a9f827b0c..74c5ff1552 100644
--- a/trio/_core/_io_windows.py
+++ b/trio/_core/_io_windows.py
@@ -32,10 +32,8 @@
assert not TYPE_CHECKING or sys.platform == "win32"
if TYPE_CHECKING:
- from _contextlib import _GeneratorContextManager
-
from ._traps import Abort, RaiseCancelT
- from ._unbouded_queue import UnboundedQueue
+ from ._unbounded_queue import UnboundedQueue
# There's a lot to be said about the overall design of a Windows event
# loop. See
@@ -190,7 +188,7 @@ class CKeys(enum.IntEnum):
def _check(success: bool) -> Literal[True]:
if not success:
raise_winerror()
- return success
+ return True
def _get_underlying_socket(
@@ -865,14 +863,9 @@ def submit_read(lpOverlapped):
def current_iocp(self) -> int:
return int(ffi.cast("uintptr_t", self._iocp))
- @_public
- def monitor_completion_key(
- self,
- ) -> _GeneratorContextManager[tuple[int, UnboundedQueue[object]]]:
- return self._monitor_completion_key()
-
@contextmanager
- def _monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]:
+ @_public
+ def monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]:
key = next(self._completion_key_counter)
queue = _core.UnboundedQueue[object]()
self._completion_key_queues[key] = queue
diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py
index 50d598c2be..639e75b50e 100644
--- a/trio/_core/_windows_cffi.py
+++ b/trio/_core/_windows_cffi.py
@@ -1,6 +1,5 @@
import enum
import re
-from typing import NoReturn
import cffi
@@ -316,7 +315,7 @@ def _handle(obj):
return obj
-def raise_winerror(winerror=None, *, filename=None, filename2=None) -> NoReturn:
+def raise_winerror(winerror=None, *, filename=None, filename2=None):
if winerror is None:
winerror, msg = ffi.getwinerror()
else:
diff --git a/trio/_ssl.py b/trio/_ssl.py
index 352f95edaf..bd8b3b06b6 100644
--- a/trio/_ssl.py
+++ b/trio/_ssl.py
@@ -148,12 +148,10 @@
# stream)
# docs will need to make very clear that this is different from all the other
# cancellations in core Trio
-from __future__ import annotations
import operator as _operator
import ssl as _stdlib_ssl
from enum import Enum as _Enum
-from typing import Any, Awaitable, Callable
import trio
@@ -211,14 +209,13 @@ class NeedHandshakeError(Exception):
class _Once:
- # needs TypeVarTuple
- def __init__(self, afn: Callable[..., Awaitable[object]], *args: Any):
+ def __init__(self, afn, *args):
self._afn = afn
self._args = args
self.started = False
self._done = _sync.Event()
- async def ensure(self, *, checkpoint: bool) -> None:
+ async def ensure(self, *, checkpoint):
if not self.started:
self.started = True
await self._afn(*self._args)
@@ -229,7 +226,7 @@ async def ensure(self, *, checkpoint: bool) -> None:
await self._done.wait()
@property
- def done(self) -> bool:
+ def done(self):
return self._done.is_set()
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index f3ed2e26e7..4517eb7bf9 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -228,7 +228,12 @@ def main() -> None: # pragma: no cover
"runner.instruments",
imports=IMPORTS_INSTRUMENT,
),
- File(core / "_io_windows.py", "runner.io_manager", platform="win32"),
+ File(
+ core / "_io_windows.py",
+ "runner.io_manager",
+ platform="win32",
+ imports=IMPORTS_WINDOWS,
+ ),
File(
core / "_io_epoll.py",
"runner.io_manager",
@@ -270,6 +275,13 @@ def main() -> None: # pragma: no cover
"""
+IMPORTS_WINDOWS = """\
+from typing import TYPE_CHECKING, ContextManager
+
+if TYPE_CHECKING:
+ from ._unbounded_queue import UnboundedQueue
+"""
+
if __name__ == "__main__": # pragma: no cover
main()
From 8a4bc3349fbc7dc359c30f1c6d2fe3e99a99e247 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 17 Aug 2023 13:38:28 +0200
Subject: [PATCH 45/49] revert changes to pyproject.toml, properly merge stuff
so Success: no issues found in 129 source files works again
---
pyproject.toml | 49 +++++------------------------------
trio/_core/_generated_run.py | 2 +-
trio/_core/_local.py | 4 +--
trio/_core/_tests/test_run.py | 4 +--
trio/_tests/test_exports.py | 4 +--
trio/_tests/verify_types.json | 18 ++++---------
trio/_tools/gen_exports.py | 7 -----
7 files changed, 18 insertions(+), 70 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index eec4402e73..d4418ba053 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ disallow_untyped_defs = true
# Enable gradually / for new modules
check_untyped_defs = false
disallow_untyped_calls = false
-disallow_any_unimported = true
+disallow_any_unimported = false # awaiting Outcome
@@ -63,11 +63,6 @@ module = [
"trio/testing/_memory_streams", # 66, 590
# 2745
"trio/_ssl", # 26, 929 lines
-# 2742
-"trio/_core/_multierror", # 19, 469
-# 2733
-"trio/_core/_run",
-"trio/_core/_generated_run",
# 2724
"trio/_highlevel_open_tcp_listeners", # 3, 227 lines
@@ -94,37 +89,12 @@ module = [
"trio/_core/_tests/*",
"trio/_tests/*",
"trio/testing/_fake_net", # 30
-
- "trio._abc",
- "trio._core._asyncgens",
- "trio._core._entry_queue",
- "trio._core._generated_run",
- "trio._core._generated_io_epoll",
- "trio._core._generated_io_kqueue",
- "trio._core._io_epoll",
- "trio._core._io_kqueue",
- "trio._core._local",
- "trio._core._multierror",
- "trio._core._thread_cache",
- "trio._core._unbounded_queue",
- "trio._core._run",
- "trio._deprecate",
- "trio._dtls",
- "trio._file_io",
- "trio._highlevel_open_tcp_stream",
- "trio._ki",
- "trio._socket",
- "trio._sync",
- "trio._tools.gen_exports",
- "trio._util",
]
-disallow_incomplete_defs = true
-disallow_untyped_defs = true
-disallow_untyped_decorators = true
-disallow_any_generics = true
-disallow_any_decorated = true
-disallow_any_unimported = false # Enable once outcome has stubs.
-disallow_subclassing_any = true
+disallow_any_decorated = false
+disallow_any_generics = false
+disallow_any_unimported = false
+disallow_incomplete_defs = false
+disallow_untyped_defs = false
[[tool.mypy.overrides]]
# Needs to use Any due to some complex introspection.
@@ -133,13 +103,6 @@ module = [
]
disallow_any_generics = false
-[[tool.mypy.overrides]]
-# awaiting typing of OutCome
-module = [
- "trio._core._traps",
-]
-disallow_any_unimported = false
-
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
faulthandler_timeout = 60
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index 35ecd45a1b..bd5abbd639 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -88,7 +88,7 @@ def current_root_task() ->(Task | None):
raise RuntimeError("must be called from async context")
-def reschedule(task: Task, next_send: Outcome[Any]=_NO_SEND) ->None: # type: ignore[has-type]
+def reschedule(task: Task, next_send: Outcome[Any]=_NO_SEND) ->None:
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index 83826fc63f..8286a5578f 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -71,7 +71,7 @@ def set(self, value: T) -> RunVarToken[T]:
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[index,assignment]
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
return token
def reset(self, token: RunVarToken[T]) -> None:
@@ -93,7 +93,7 @@ def reset(self, token: RunVarToken[T]) -> None:
if previous is _NoValue:
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
else:
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index,assignment]
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context")
diff --git a/trio/_core/_tests/test_run.py b/trio/_core/_tests/test_run.py
index 81c3b73cc4..6d34d8f223 100644
--- a/trio/_core/_tests/test_run.py
+++ b/trio/_core/_tests/test_run.py
@@ -1954,7 +1954,7 @@ async def test_Nursery_private_init():
def test_Nursery_subclass():
with pytest.raises(TypeError):
- class Subclass(_core._run.Nursery):
+ class Subclass(_core._run.Nursery): # type: ignore[misc]
pass
@@ -1984,7 +1984,7 @@ class Subclass(_core.Cancelled):
def test_CancelScope_subclass():
with pytest.raises(TypeError):
- class Subclass(_core.CancelScope):
+ class Subclass(_core.CancelScope): # type: ignore[misc]
pass
diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py
index 1553474c7a..1a145f397e 100644
--- a/trio/_tests/test_exports.py
+++ b/trio/_tests/test_exports.py
@@ -9,7 +9,7 @@
import sys
from pathlib import Path
from types import ModuleType
-from typing import Protocol
+from typing import Dict, Protocol
import attrs
import pytest
@@ -27,7 +27,7 @@
try: # If installed, check both versions of this class.
from typing_extensions import Protocol as Protocol_ext
except ImportError: # pragma: no cover
- Protocol_ext = Protocol
+ Protocol_ext = Protocol # type: ignore[assignment]
def _ensure_mypy_cache_updated():
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index ac5ab46812..28288652c9 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9330143540669856,
+ "completenessScore": 0.9570063694267515,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 585,
- "withUnknownType": 42
+ "withKnownType": 601,
+ "withUnknownType": 27
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 0,
@@ -46,8 +46,8 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 5,
- "withKnownType": 602,
- "withUnknownType": 59
+ "withKnownType": 627,
+ "withUnknownType": 48
},
"packageName": "trio",
"symbols": [
@@ -82,14 +82,7 @@
"trio._subprocess.Process.wait",
"trio.from_thread.run",
"trio.from_thread.run_sync",
- "trio.lowlevel.current_clock",
- "trio.lowlevel.current_root_task",
- "trio.lowlevel.current_statistics",
- "trio.lowlevel.current_trio_token",
"trio.lowlevel.notify_closing",
- "trio.lowlevel.reschedule",
- "trio.lowlevel.spawn_system_task",
- "trio.lowlevel.start_guest_run",
"trio.lowlevel.wait_readable",
"trio.lowlevel.wait_writable",
"trio.open_ssl_over_tcp_listeners",
@@ -128,7 +121,6 @@
"trio.testing.memory_stream_pump",
"trio.testing.open_stream_to_socket_listener",
"trio.testing.trio_test",
- "trio.testing.wait_all_tasks_blocked",
"trio.to_thread.current_default_thread_limiter"
]
}
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 43ed8b8bb8..0730c16684 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -158,13 +158,6 @@ def gen_public_wrappers_source(file: File) -> str:
if is_cm: # pragma: no cover
func = func.replace("->Iterator", "->ContextManager")
- # TODO: hacky workaround until we run mypy without `-m`, which breaks imports
- # enough that it cannot figure out the type of _NO_SEND
- if file.path.stem == "_run" and func.startswith(
- "def reschedule"
- ): # pragma: no cover
- func = func.replace("None:\n", "None: # type: ignore[has-type]\n")
-
# Create export function body
template = TEMPLATE.format(
" await " if isinstance(method, ast.AsyncFunctionDef) else " ",
From f0e18ad490c40e5094e0ac5a7f26ceca7292bc1c Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 17 Aug 2023 14:04:39 +0200
Subject: [PATCH 46/49] _io_windows changes moved to #2761
---
trio/_core/_generated_io_windows.py | 17 +++-----
trio/_core/_io_windows.py | 66 ++++++++++++++---------------
trio/_tools/gen_exports.py | 14 +-----
3 files changed, 39 insertions(+), 58 deletions(-)
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 301573c6ee..7fa6fd5126 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -8,16 +8,13 @@
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
-from typing import TYPE_CHECKING, ContextManager
-
-if TYPE_CHECKING:
- from ._unbounded_queue import UnboundedQueue
+from typing import TYPE_CHECKING
import sys
assert not TYPE_CHECKING or sys.platform=="win32"
-async def wait_readable(sock) ->None:
+async def wait_readable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock)
@@ -25,7 +22,7 @@ async def wait_readable(sock) ->None:
raise RuntimeError("must be called from async context")
-async def wait_writable(sock) ->None:
+async def wait_writable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock)
@@ -33,7 +30,7 @@ async def wait_writable(sock) ->None:
raise RuntimeError("must be called from async context")
-def notify_closing(handle) ->None:
+def notify_closing(handle):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle)
@@ -41,7 +38,7 @@ def notify_closing(handle) ->None:
raise RuntimeError("must be called from async context")
-def register_with_iocp(handle) ->None:
+def register_with_iocp(handle):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle)
@@ -73,7 +70,7 @@ async def readinto_overlapped(handle, buffer, file_offset=0):
raise RuntimeError("must be called from async context")
-def current_iocp() ->int:
+def current_iocp():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp()
@@ -81,7 +78,7 @@ def current_iocp() ->int:
raise RuntimeError("must be called from async context")
-def monitor_completion_key() ->ContextManager[tuple[int, UnboundedQueue[object]]]:
+def monitor_completion_key():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py
index ba2f78cfe2..9757d25b5f 100644
--- a/trio/_core/_io_windows.py
+++ b/trio/_core/_io_windows.py
@@ -5,7 +5,7 @@
import socket
import sys
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Iterator, Literal
+from typing import TYPE_CHECKING, Literal
import attr
from outcome import Value
@@ -32,8 +32,6 @@
assert not TYPE_CHECKING or sys.platform == "win32"
if TYPE_CHECKING:
- from ._traps import Abort, RaiseCancelT
- from ._unbounded_queue import UnboundedQueue
from typing_extensions import TypeAlias
EventResult: TypeAlias = int
@@ -187,15 +185,13 @@ class CKeys(enum.IntEnum):
USER_DEFINED = 4 # and above
-def _check(success: bool) -> Literal[True]:
+def _check(success):
if not success:
raise_winerror()
- return True
+ return success
-def _get_underlying_socket(
- sock: socket.socket | int, *, which=WSAIoctls.SIO_BASE_HANDLE
-):
+def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
if hasattr(sock, "fileno"):
sock = sock.fileno()
base_ptr = ffi.new("HANDLE *")
@@ -340,9 +336,9 @@ def _afd_helper_handle():
# operation and start a new one.
@attr.s(slots=True, eq=False)
class AFDWaiters:
- read_task: None = attr.ib(default=None)
- write_task: None = attr.ib(default=None)
- current_op: None = attr.ib(default=None)
+ read_task = attr.ib(default=None)
+ write_task = attr.ib(default=None)
+ current_op = attr.ib(default=None)
# We also need to bundle up all the info for a single op into a standalone
@@ -350,10 +346,10 @@ class AFDWaiters:
# finishes, even if we're throwing it away.
@attr.s(slots=True, eq=False, frozen=True)
class AFDPollOp:
- lpOverlapped: None = attr.ib()
- poll_info: None = attr.ib()
- waiters: None = attr.ib()
- afd_group: None = attr.ib()
+ lpOverlapped = attr.ib()
+ poll_info = attr.ib()
+ waiters = attr.ib()
+ afd_group = attr.ib()
# The Windows kernel has a weird issue when using AFD handles. If you have N
@@ -369,8 +365,8 @@ class AFDPollOp:
@attr.s(slots=True, eq=False)
class AFDGroup:
- size: int = attr.ib()
- handle: None = attr.ib()
+ size = attr.ib()
+ handle = attr.ib()
@attr.s(slots=True, eq=False, frozen=True)
@@ -391,8 +387,8 @@ class _WindowsStatistics:
@attr.s(frozen=True)
class CompletionKeyEventInfo:
- lpOverlapped: None = attr.ib()
- dwNumberOfBytesTransferred: int = attr.ib()
+ lpOverlapped = attr.ib()
+ dwNumberOfBytesTransferred = attr.ib()
class WindowsIOManager:
@@ -459,7 +455,7 @@ def __init__(self):
"netsh winsock show catalog"
)
- def close(self) -> None:
+ def close(self):
try:
if self._iocp is not None:
iocp = self._iocp
@@ -470,10 +466,10 @@ def close(self) -> None:
afd_handle = self._all_afd_handles.pop()
_check(kernel32.CloseHandle(afd_handle))
- def __del__(self) -> None:
+ def __del__(self):
self.close()
- def statistics(self) -> _WindowsStatistics:
+ def statistics(self):
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._afd_waiters.values():
@@ -488,7 +484,7 @@ def statistics(self) -> _WindowsStatistics:
completion_key_monitors=len(self._completion_key_queues),
)
- def force_wakeup(self) -> None:
+ def force_wakeup(self):
_check(
kernel32.PostQueuedCompletionStatus(
self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL
@@ -594,7 +590,7 @@ def process_events(self, received: EventResult) -> None:
)
queue.put_nowait(info)
- def _register_with_iocp(self, handle, completion_key) -> None:
+ def _register_with_iocp(self, handle, completion_key):
handle = _handle(handle)
_check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0))
# Supposedly this makes things slightly faster, by disabling the
@@ -611,7 +607,7 @@ def _register_with_iocp(self, handle, completion_key) -> None:
# AFD stuff
################################################################
- def _refresh_afd(self, base_handle) -> None:
+ def _refresh_afd(self, base_handle):
waiters = self._afd_waiters[base_handle]
if waiters.current_op is not None:
afd_group = waiters.current_op.afd_group
@@ -687,7 +683,7 @@ def _refresh_afd(self, base_handle) -> None:
if afd_group.size >= MAX_AFD_GROUP_SIZE:
self._vacant_afd_groups.remove(afd_group)
- async def _afd_poll(self, sock, mode) -> None:
+ async def _afd_poll(self, sock, mode):
base_handle = _get_base_socket(sock)
waiters = self._afd_waiters.get(base_handle)
if waiters is None:
@@ -700,7 +696,7 @@ async def _afd_poll(self, sock, mode) -> None:
# we let it escape.
self._refresh_afd(base_handle)
- def abort_fn(_: RaiseCancelT) -> Abort:
+ def abort_fn(_):
setattr(waiters, mode, None)
self._refresh_afd(base_handle)
return _core.Abort.SUCCEEDED
@@ -708,15 +704,15 @@ def abort_fn(_: RaiseCancelT) -> Abort:
await _core.wait_task_rescheduled(abort_fn)
@_public
- async def wait_readable(self, sock) -> None:
+ async def wait_readable(self, sock):
await self._afd_poll(sock, "read_task")
@_public
- async def wait_writable(self, sock) -> None:
+ async def wait_writable(self, sock):
await self._afd_poll(sock, "write_task")
@_public
- def notify_closing(self, handle) -> None:
+ def notify_closing(self, handle):
handle = _get_base_socket(handle)
waiters = self._afd_waiters.get(handle)
if waiters is not None:
@@ -728,7 +724,7 @@ def notify_closing(self, handle) -> None:
################################################################
@_public
- def register_with_iocp(self, handle) -> None:
+ def register_with_iocp(self, handle):
self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED)
@_public
@@ -744,7 +740,7 @@ async def wait_overlapped(self, handle, lpOverlapped):
self._overlapped_waiters[lpOverlapped] = task
raise_cancel = None
- def abort(raise_cancel_: RaiseCancelT) -> Abort:
+ def abort(raise_cancel_):
nonlocal raise_cancel
raise_cancel = raise_cancel_
try:
@@ -864,14 +860,14 @@ def submit_read(lpOverlapped):
################################################################
@_public
- def current_iocp(self) -> int:
+ def current_iocp(self):
return int(ffi.cast("uintptr_t", self._iocp))
@contextmanager
@_public
- def monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]:
+ def monitor_completion_key(self):
key = next(self._completion_key_counter)
- queue = _core.UnboundedQueue[object]()
+ queue = _core.UnboundedQueue()
self._completion_key_queues[key] = queue
try:
yield (key, queue)
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 0730c16684..3c598e8eae 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -228,12 +228,7 @@ def main() -> None: # pragma: no cover
"runner.instruments",
imports=IMPORTS_INSTRUMENT,
),
- File(
- core / "_io_windows.py",
- "runner.io_manager",
- platform="win32",
- imports=IMPORTS_WINDOWS,
- ),
+ File(core / "_io_windows.py", "runner.io_manager", platform="win32"),
File(
core / "_io_epoll.py",
"runner.io_manager",
@@ -283,13 +278,6 @@ def main() -> None: # pragma: no cover
"""
-IMPORTS_WINDOWS = """\
-from typing import TYPE_CHECKING, ContextManager
-
-if TYPE_CHECKING:
- from ._unbounded_queue import UnboundedQueue
-"""
-
if __name__ == "__main__": # pragma: no cover
main()
From cd4e907fe45ec6ee95d87604c480cc922b2e4687 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 17 Aug 2023 14:11:22 +0200
Subject: [PATCH 47/49] move around in pyproject.toml
---
pyproject.toml | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 51ed378b49..4526571fc3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,15 +69,15 @@ module = [
# 2755
"trio/_core/_windows_cffi", # 2, 324
"trio/_wait_for_object", # 2 (windows)
+# 2761
+"trio/_core/_generated_io_windows", # 9 (win32), 84
+"trio/_core/_io_windows", # 47 (win32), 867
"trio/_signals", # 13, 168 lines
-# windows API
-"trio/_core/_generated_io_windows", # 9 (win32), 84
-"trio/_core/_io_windows", # 47 (win32), 867
# internal
"trio/_windows_pipes",
From 69b7e77d5a1d493198283fdd63fa0fbe715f48e5 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 17 Aug 2023 14:58:06 +0200
Subject: [PATCH 48/49] cleanup
---
pyproject.toml | 49 +++++++++++++----------------------
trio/_path.py | 6 ++---
trio/_tests/verify_types.json | 5 ++--
trio/tests.py | 2 --
4 files changed, 24 insertions(+), 38 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 4526571fc3..a212393452 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,48 +36,42 @@ warn_return_any = true
# Avoid subtle backsliding
disallow_any_decorated = true
disallow_any_generics = true
+disallow_any_unimported = false # Enable once Outcome has stubs.
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
-# Enable gradually / for new modules
+# Enable once other problems are dealt with
check_untyped_defs = false
disallow_untyped_calls = false
-disallow_any_unimported = false # awaiting Outcome
-
-# DO NOT use `ignore_errors`; it doesn't apply
-# downstream and users have to deal with them.
+# files not yet fully typed
[[tool.mypy.overrides]]
module = [
-
# 2747
-"trio/testing/_network", # 1, 34
-"trio/testing/_trio_test", # 2, 29
-"trio/testing/_checkpoints", # 3, 62
-"trio/testing/_check_streams", # 27, 522
-"trio/testing/_memory_streams", # 66, 590
+"trio/testing/_network",
+"trio/testing/_trio_test",
+"trio/testing/_checkpoints",
+"trio/testing/_check_streams",
+"trio/testing/_memory_streams",
# 2745
-"trio/_ssl", # 26, 929 lines
+"trio/_ssl",
# 2756
-"trio/_highlevel_open_unix_stream", # 2, 49 lines
-"trio/_highlevel_serve_listeners", # 3, 121 lines
-"trio/_highlevel_ssl_helpers", # 3, 155 lines
-"trio/_highlevel_socket", # 4, 386 lines
+"trio/_highlevel_open_unix_stream",
+"trio/_highlevel_serve_listeners",
+"trio/_highlevel_ssl_helpers",
+"trio/_highlevel_socket",
# 2755
-"trio/_core/_windows_cffi", # 2, 324
-"trio/_wait_for_object", # 2 (windows)
+"trio/_core/_windows_cffi",
+"trio/_wait_for_object",
# 2761
-"trio/_core/_generated_io_windows", # 9 (win32), 84
-"trio/_core/_io_windows", # 47 (win32), 867
-
-
-
+"trio/_core/_generated_io_windows",
+"trio/_core/_io_windows",
-"trio/_signals", # 13, 168 lines
+"trio/_signals",
# internal
"trio/_windows_pipes",
@@ -93,13 +87,6 @@ disallow_any_unimported = false
disallow_incomplete_defs = false
disallow_untyped_defs = false
-[[tool.mypy.overrides]]
-# Needs to use Any due to some complex introspection.
-module = [
- "trio._path",
-]
-disallow_any_generics = false
-
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
faulthandler_timeout = 60
diff --git a/trio/_path.py b/trio/_path.py
index cad83e0e6a..c2763e03af 100644
--- a/trio/_path.py
+++ b/trio/_path.py
@@ -116,7 +116,7 @@ async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Path:
def classmethod_wrapper_factory(
cls: AsyncAutoWrapperType, meth_name: str
-) -> classmethod:
+) -> classmethod: # type: ignore[type-arg]
@async_wraps(cls, cls._wraps, meth_name)
async def wrapper(cls: type[Path], *args: Any, **kwargs: Any) -> Path: # type: ignore[misc] # contains Any
meth = getattr(cls._wraps, meth_name)
@@ -163,7 +163,7 @@ def generate_forwards(cls, attrs: dict[str, object]) -> None:
def generate_wraps(cls, attrs: dict[str, object]) -> None:
# generate wrappers for functions of _wraps
- wrapper: classmethod | Callable
+ wrapper: classmethod | Callable[..., object] # type: ignore[type-arg]
for attr_name, attr in cls._wraps.__dict__.items():
# .z. exclude cls._wrap_iter
if attr_name.startswith("_") or attr_name in attrs:
@@ -188,7 +188,7 @@ def generate_magic(cls, attrs: dict[str, object]) -> None:
def generate_iter(cls, attrs: dict[str, object]) -> None:
# generate wrappers for methods that return iterators
- wrapper: Callable
+ wrapper: Callable[..., object]
for attr_name, attr in cls._wraps.__dict__.items():
if attr_name in cls._wrap_iter:
wrapper = iter_wrapper_factory(cls, attr_name)
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 9d2f5b3a55..c5e9c4dc66 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -14,7 +14,7 @@
"withUnknownType": 22
},
"ignoreUnknownTypesFromImports": true,
- "missingClassDocStringCount": 0,
+ "missingClassDocStringCount": 1,
"missingDefaultParamCount": 0,
"missingFunctionDocStringCount": 4,
"moduleName": "trio",
@@ -105,7 +105,8 @@
"trio.testing.memory_stream_pair",
"trio.testing.memory_stream_pump",
"trio.testing.open_stream_to_socket_listener",
- "trio.testing.trio_test"
+ "trio.testing.trio_test",
+ "trio.tests.TestsDeprecationWrapper"
]
}
}
diff --git a/trio/tests.py b/trio/tests.py
index 1c5f039f0f..4ffb583a3a 100644
--- a/trio/tests.py
+++ b/trio/tests.py
@@ -16,8 +16,6 @@
# This won't give deprecation warning on import, but will give a warning on use of any
# attribute in tests, and static analysis tools will also not see any content inside.
class TestsDeprecationWrapper:
- """trio.tests is deprecated, use trio._tests"""
-
__name__ = "trio.tests"
def __getattr__(self, attr: str) -> Any:
From 47a7228b3bc97bbfd44708801d1b2db73176e597 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 17 Aug 2023 14:58:57 +0200
Subject: [PATCH 49/49] make CI run without -m
---
check.sh | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/check.sh b/check.sh
index a0efa531b6..ace193a62a 100755
--- a/check.sh
+++ b/check.sh
@@ -27,9 +27,9 @@ fi
flake8 trio/ || EXIT_STATUS=$?
# Run mypy on all supported platforms
-mypy -m trio -m trio.testing --platform linux || EXIT_STATUS=$?
-mypy -m trio -m trio.testing --platform darwin || EXIT_STATUS=$? # tests FreeBSD too
-mypy -m trio -m trio.testing --platform win32 || EXIT_STATUS=$?
+mypy trio --platform linux || EXIT_STATUS=$?
+mypy trio --platform darwin || EXIT_STATUS=$? # tests FreeBSD too
+mypy trio --platform win32 || EXIT_STATUS=$?
# Check pip compile is consistent
pip-compile test-requirements.in