From 8eec6eabed5e5fe30940da5d3ab4f406aeb4d78d Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 6 Aug 2023 12:58:54 +0200
Subject: [PATCH 01/10] type _io_epoll and _io_kqueue
---
pyproject.toml | 4 ++
trio/_core/_generated_instrumentation.py | 16 +++++-
trio/_core/_generated_io_epoll.py | 24 +++++++--
trio/_core/_generated_io_kqueue.py | 32 ++++++++---
trio/_core/_generated_io_windows.py | 18 ++++++-
trio/_core/_generated_run.py | 16 +++++-
trio/_core/_io_epoll.py | 66 +++++++++++++----------
trio/_core/_io_kqueue.py | 68 ++++++++++++++----------
trio/_core/_run.py | 2 +-
trio/_tests/verify_types.json | 6 +--
trio/_tools/gen_exports.py | 42 ++++++++++++++-
11 files changed, 219 insertions(+), 75 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 0900f3a7d1..ba10d4c532 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,6 +51,10 @@ disallow_untyped_defs = false
module = [
"trio._abc",
"trio._core._entry_queue",
+ "trio._core._generated_io_epoll",
+ "trio._core._generated_io_kqueue",
+ "trio._core._io_epoll",
+ "trio._core._io_kqueue",
"trio._core._local",
"trio._core._unbounded_queue",
"trio._deprecate",
diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py
index 30c2f26b4e..b3f73f6616 100644
--- a/trio/_core/_generated_instrumentation.py
+++ b/trio/_core/_generated_instrumentation.py
@@ -1,11 +1,25 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+# isort: skip_file
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Callable, ContextManager
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ import select
+ from socket import socket
+
+ from ._traps import Abort, RaiseCancelT
+
+ from .. import _core
+ from ._unbounded_queue import UnboundedQueue
+
# fmt: off
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index 02fb3bc348..5ac1b99efb 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -1,15 +1,31 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+# isort: skip_file
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Callable, ContextManager
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ import select
+ from socket import socket
+
+ from ._traps import Abort, RaiseCancelT
+
+ from .. import _core
+ from ._unbounded_queue import UnboundedQueue
+
# fmt: off
-async def wait_readable(fd):
+assert not TYPE_CHECKING or sys.platform=="linux"
+
+async def wait_readable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -17,7 +33,7 @@ async def wait_readable(fd):
raise RuntimeError("must be called from async context")
-async def wait_writable(fd):
+async def wait_writable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -25,7 +41,7 @@ async def wait_writable(fd):
raise RuntimeError("must be called from async context")
-def notify_closing(fd):
+def notify_closing(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index 94e819769c..97199e4577 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -1,15 +1,31 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+# isort: skip_file
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Callable, ContextManager
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ import select
+ from socket import socket
+
+ from ._traps import Abort, RaiseCancelT
+
+ from .. import _core
+ from ._unbounded_queue import UnboundedQueue
+
# fmt: off
-def current_kqueue():
+assert not TYPE_CHECKING or sys.platform=="darwin"
+
+def current_kqueue() ->select.kqueue:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
@@ -17,7 +33,8 @@ def current_kqueue():
raise RuntimeError("must be called from async context")
-def monitor_kevent(ident, filter):
+def monitor_kevent(ident: int, filter: int) ->ContextManager[_core.UnboundedQueue
+ [select.kevent]]:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
@@ -25,7 +42,8 @@ def monitor_kevent(ident, filter):
raise RuntimeError("must be called from async context")
-async def wait_kevent(ident, filter, abort_func):
+async def wait_kevent(ident: int, filter: int, abort_func: Callable[[
+ RaiseCancelT], Abort]) ->Abort:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func)
@@ -33,7 +51,7 @@ async def wait_kevent(ident, filter, abort_func):
raise RuntimeError("must be called from async context")
-async def wait_readable(fd):
+async def wait_readable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -41,7 +59,7 @@ async def wait_readable(fd):
raise RuntimeError("must be called from async context")
-async def wait_writable(fd):
+async def wait_writable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -49,7 +67,7 @@ async def wait_writable(fd):
raise RuntimeError("must be called from async context")
-def notify_closing(fd):
+def notify_closing(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 26b4da697d..59e14b185d 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -1,14 +1,30 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+# isort: skip_file
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Callable, ContextManager
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ import select
+ from socket import socket
+
+ from ._traps import Abort, RaiseCancelT
+
+ from .. import _core
+ from ._unbounded_queue import UnboundedQueue
+
# fmt: off
+assert not TYPE_CHECKING or sys.platform=="win32"
+
async def wait_readable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index d1e74a93f4..c5b78beadd 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -1,11 +1,25 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+# isort: skip_file
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Callable, ContextManager
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ import select
+ from socket import socket
+
+ from ._traps import Abort, RaiseCancelT
+
+ from .. import _core
+ from ._unbounded_queue import UnboundedQueue
+
# fmt: off
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 376dd18a4e..ff873ee966 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -1,23 +1,38 @@
+from __future__ import annotations
+
import select
import sys
from collections import defaultdict
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, DefaultDict
import attr
from .. import _core
from ._io_common import wake_all
-from ._run import _public
+from ._run import Task, _public
from ._wakeup_socketpair import WakeupSocketpair
+if TYPE_CHECKING:
+ from socket import socket
+
+ from .._core import Abort, RaiseCancelT
+
+
+@attr.s(slots=True, eq=False)
+class EpollWaiters:
+ read_task: Task | None = attr.ib(default=None)
+ write_task: Task | None = attr.ib(default=None)
+ current_flags: int = attr.ib(default=0)
+
+
assert not TYPE_CHECKING or sys.platform == "linux"
@attr.s(slots=True, eq=False, frozen=True)
class _EpollStatistics:
- tasks_waiting_read = attr.ib()
- tasks_waiting_write = attr.ib()
- backend = attr.ib(default="epoll")
+ tasks_waiting_read: int = attr.ib()
+ tasks_waiting_write: int = attr.ib()
+ backend: str = attr.ib(default="epoll")
# Some facts about epoll
@@ -178,28 +193,21 @@ class _EpollStatistics:
# wanted to about how epoll works.
-@attr.s(slots=True, eq=False)
-class EpollWaiters:
- read_task = attr.ib(default=None)
- write_task = attr.ib(default=None)
- current_flags = attr.ib(default=0)
-
-
@attr.s(slots=True, eq=False, hash=False)
class EpollIOManager:
- _epoll = attr.ib(factory=select.epoll)
+ _epoll: select.epoll = attr.ib(factory=select.epoll)
# {fd: EpollWaiters}
- _registered = attr.ib(
- factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters]
+ _registered: DefaultDict[int, EpollWaiters] = attr.ib(
+ factory=lambda: defaultdict(EpollWaiters)
)
- _force_wakeup = attr.ib(factory=WakeupSocketpair)
- _force_wakeup_fd = attr.ib(default=None)
+ _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+ _force_wakeup_fd: int | None = attr.ib(default=None)
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
- def statistics(self):
+ def statistics(self) -> _EpollStatistics:
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._registered.values():
@@ -212,24 +220,24 @@ def statistics(self):
tasks_waiting_write=tasks_waiting_write,
)
- def close(self):
+ def close(self) -> None:
self._epoll.close()
self._force_wakeup.close()
- def force_wakeup(self):
+ def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
# Return value must be False-y IFF the timeout expired, NOT if any I/O
# happened or force_wakeup was called. Otherwise it can be anything; gets
# passed straight through to process_events.
- def get_events(self, timeout):
+ def get_events(self, timeout: float) -> list[tuple[int, int]]:
# max_events must be > 0 or epoll gets cranky
# accessing self._registered from a thread looks dangerous, but it's
# OK because it doesn't matter if our value is a little bit off.
max_events = max(1, len(self._registered))
return self._epoll.poll(timeout, max_events)
- def process_events(self, events):
+ def process_events(self, events: list[tuple[int, int]]) -> None:
for fd, flags in events:
if fd == self._force_wakeup_fd:
self._force_wakeup.drain()
@@ -248,7 +256,7 @@ def process_events(self, events):
waiters.read_task = None
self._update_registrations(fd)
- def _update_registrations(self, fd):
+ def _update_registrations(self, fd: int) -> None:
waiters = self._registered[fd]
wanted_flags = 0
if waiters.read_task is not None:
@@ -277,7 +285,7 @@ def _update_registrations(self, fd):
if not wanted_flags:
del self._registered[fd]
- async def _epoll_wait(self, fd, attr_name):
+ async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
@@ -288,7 +296,7 @@ async def _epoll_wait(self, fd, attr_name):
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd)
- def abort(_):
+ def abort(_: RaiseCancelT) -> Abort:
setattr(waiters, attr_name, None)
self._update_registrations(fd)
return _core.Abort.SUCCEEDED
@@ -296,15 +304,15 @@ def abort(_):
await _core.wait_task_rescheduled(abort)
@_public
- async def wait_readable(self, fd):
+ async def wait_readable(self, fd: int | socket) -> None:
await self._epoll_wait(fd, "read_task")
@_public
- async def wait_writable(self, fd):
+ async def wait_writable(self, fd: int | socket) -> None:
await self._epoll_wait(fd, "write_task")
@_public
- def notify_closing(self, fd):
+ def notify_closing(self, fd: int | socket) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index d1151843e8..f47d6f1c7c 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import errno
import select
import sys
from contextlib import contextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Iterator
import attr
import outcome
@@ -11,32 +13,39 @@
from ._run import _public
from ._wakeup_socketpair import WakeupSocketpair
+if TYPE_CHECKING:
+ from socket import socket
+
+ from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
+
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
@attr.s(slots=True, eq=False, frozen=True)
class _KqueueStatistics:
- tasks_waiting = attr.ib()
- monitors = attr.ib()
- backend = attr.ib(default="kqueue")
+ tasks_waiting: int = attr.ib()
+ monitors: int = attr.ib()
+ backend: str = attr.ib(default="kqueue")
@attr.s(slots=True, eq=False)
class KqueueIOManager:
- _kqueue = attr.ib(factory=select.kqueue)
+ _kqueue: select.kqueue = attr.ib(factory=select.kqueue)
# {(ident, filter): Task or UnboundedQueue}
- _registered = attr.ib(factory=dict)
- _force_wakeup = attr.ib(factory=WakeupSocketpair)
- _force_wakeup_fd = attr.ib(default=None)
+ _registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = attr.ib(
+ factory=dict
+ )
+ _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
+ _force_wakeup_fd: int | None = attr.ib(default=None)
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
force_wakeup_event = select.kevent(
self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD
)
self._kqueue.control([force_wakeup_event], 0)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
- def statistics(self):
+ def statistics(self) -> _KqueueStatistics:
tasks_waiting = 0
monitors = 0
for receiver in self._registered.values():
@@ -46,14 +55,14 @@ def statistics(self):
monitors += 1
return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)
- def close(self):
+ def close(self) -> None:
self._kqueue.close()
self._force_wakeup.close()
- def force_wakeup(self):
+ def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
- def get_events(self, timeout):
+ def get_events(self, timeout: float) -> list[select.kevent]:
# max_events must be > 0 or kqueue gets cranky
# and we generally want this to be strictly larger than the actual
# number of events we get, so that we can tell that we've gotten
@@ -70,7 +79,7 @@ def get_events(self, timeout):
# and loop back to the start
return events
- def process_events(self, events):
+ def process_events(self, events: list[select.kevent]) -> None:
for event in events:
key = (event.ident, event.filter)
if event.ident == self._force_wakeup_fd:
@@ -79,7 +88,7 @@ def process_events(self, events):
receiver = self._registered[key]
if event.flags & select.KQ_EV_ONESHOT:
del self._registered[key]
- if type(receiver) is _core.Task:
+ if isinstance(receiver, _core.Task):
_core.reschedule(receiver, outcome.Value(event))
else:
receiver.put_nowait(event)
@@ -96,18 +105,20 @@ def process_events(self, events):
# be more ergonomic...
@_public
- def current_kqueue(self):
+ def current_kqueue(self) -> select.kqueue:
return self._kqueue
- @contextmanager
@_public
- def monitor_kevent(self, ident, filter):
+ @contextmanager
+ def monitor_kevent(
+ self, ident: int, filter: int
+ ) -> Iterator[_core.UnboundedQueue[select.kevent]]:
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair"
)
- q = _core.UnboundedQueue()
+ q = _core.UnboundedQueue[select.kevent]()
self._registered[key] = q
try:
yield q
@@ -115,7 +126,9 @@ def monitor_kevent(self, ident, filter):
del self._registered[key]
@_public
- async def wait_kevent(self, ident, filter, abort_func):
+ async def wait_kevent(
+ self, ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort]
+ ) -> Abort:
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
@@ -123,22 +136,23 @@ async def wait_kevent(self, ident, filter, abort_func):
)
self._registered[key] = _core.current_task()
- def abort(raise_cancel):
+ def abort(raise_cancel: RaiseCancelT) -> Abort:
r = abort_func(raise_cancel)
if r is _core.Abort.SUCCEEDED:
del self._registered[key]
return r
- return await _core.wait_task_rescheduled(abort)
+ # wait_task_rescheduled does not have its return type typed
+ return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]
- async def _wait_common(self, fd, filter):
+ async def _wait_common(self, fd: int | socket, filter: int) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
event = select.kevent(fd, filter, flags)
self._kqueue.control([event], 0)
- def abort(_):
+ def abort(_: RaiseCancelT) -> Abort:
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
try:
self._kqueue.control([event], 0)
@@ -163,15 +177,15 @@ def abort(_):
await self.wait_kevent(fd, filter, abort)
@_public
- async def wait_readable(self, fd):
+ async def wait_readable(self, fd: int | socket) -> None:
await self._wait_common(fd, select.KQ_FILTER_READ)
@_public
- async def wait_writable(self, fd):
+ async def wait_writable(self, fd: int | socket) -> None:
await self._wait_common(fd, select.KQ_FILTER_WRITE)
@_public
- def notify_closing(self, fd):
+ def notify_closing(self, fd: int | socket) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index ce8feb2827..7d247a2738 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -1433,7 +1433,7 @@ def in_main_thread():
class Runner:
clock = attr.ib()
instruments: Instruments = attr.ib()
- io_manager = attr.ib()
+ io_manager: TheIOManager = attr.ib()
ki_manager = attr.ib()
strict_exception_groups = attr.ib()
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 55ba3b32d7..38014eac07 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9186602870813397,
+ "completenessScore": 0.9234449760765551,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 576,
- "withUnknownType": 51
+ "withKnownType": 579,
+ "withUnknownType": 48
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index bae7e4f69d..e887a7bbc0 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -24,11 +24,25 @@
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
-# isort: skip
+# isort: skip_file
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Callable, ContextManager
+
from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+if TYPE_CHECKING:
+ import select
+ from socket import socket
+
+ from ._traps import Abort, RaiseCancelT
+
+ from .. import _core
+ from ._unbounded_queue import UnboundedQueue
+
# fmt: off
"""
@@ -100,12 +114,35 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
"""
generated = [HEADER]
+
+ # This is only triggered by check.sh, not test_gen_exports.py
+ if lookup_path == "runner.io_manager": # pragma: no coverage
+ for file_indicator, platform in (
+ ("windows", "win32"),
+ ("kqueue", "darwin"),
+ ("epoll", "linux"),
+ ):
+ if file_indicator in source_path.stem:
+ generated.append(
+ f'assert not TYPE_CHECKING or sys.platform=="{platform}"'
+ )
+ break
+ else:
+ assert False
+
source = astor.code_to_ast.parse_file(source_path)
for method in get_public_methods(source):
# Remove self from arguments
assert method.args.args[0].arg == "self"
del method.args.args[0]
+ for dec in method.decorator_list:
+ if isinstance(dec, ast.Name) and dec.id == "contextmanager":
+ is_cm = True
+ break
+ else:
+ is_cm = False
+
# Remove decorators
method.decorator_list = []
@@ -122,6 +159,9 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
# Create the function definition including the body
func = astor.to_source(method, indent_with=" " * 4)
+ if is_cm:
+ func = func.replace("->Iterator", "->ContextManager")
+
# Create export function body
template = TEMPLATE.format(
" await " if isinstance(method, ast.AsyncFunctionDef) else " ",
From 326b21727d090a544284b9d7c720cab6ae58d718 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 6 Aug 2023 14:13:01 +0200
Subject: [PATCH 02/10] fix CI tests
---
trio/_tools/gen_exports.py | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index e887a7bbc0..3ec617e5ea 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -108,7 +108,7 @@ def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) ->
return "({})".format(", ".join(call_args))
-def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
+def gen_public_wrappers_source(source_path: Path | str, lookup_path: str) -> str:
"""Scan the given .py file for @_public decorators, and generate wrapper
functions.
@@ -116,7 +116,9 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
generated = [HEADER]
# This is only triggered by check.sh, not test_gen_exports.py
- if lookup_path == "runner.io_manager": # pragma: no coverage
+ if lookup_path == "runner.io_manager" and isinstance(
+ source_path, Path
+ ): # pragma: no coverage
for file_indicator, platform in (
("windows", "win32"),
("kqueue", "darwin"),
@@ -127,8 +129,6 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
f'assert not TYPE_CHECKING or sys.platform=="{platform}"'
)
break
- else:
- assert False
source = astor.code_to_ast.parse_file(source_path)
for method in get_public_methods(source):
@@ -189,7 +189,9 @@ def matches_disk_files(new_files: dict[str, str]) -> bool:
return True
-def process(sources_and_lookups: Iterable[tuple[Path, str]], *, do_test: bool) -> None:
+def process(
+ sources_and_lookups: Iterable[tuple[Path | str, str]], *, do_test: bool
+) -> None:
new_files = {}
for source_path, lookup_path in sources_and_lookups:
print("Scanning:", source_path)
From 6334bba99c62b0caf03af9ca3748f703273bda5e Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 6 Aug 2023 14:18:21 +0200
Subject: [PATCH 03/10] ignore coverage, undo unnecessary swap of decorator
order
---
trio/_core/_io_kqueue.py | 2 +-
trio/_tools/gen_exports.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index f47d6f1c7c..63528a40b2 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -108,8 +108,8 @@ def process_events(self, events: list[select.kevent]) -> None:
def current_kqueue(self) -> select.kqueue:
return self._kqueue
- @_public
@contextmanager
+ @_public
def monitor_kevent(
self, ident: int, filter: int
) -> Iterator[_core.UnboundedQueue[select.kevent]]:
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 3ec617e5ea..c3e0e3be08 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -136,7 +136,7 @@ def gen_public_wrappers_source(source_path: Path | str, lookup_path: str) -> str
assert method.args.args[0].arg == "self"
del method.args.args[0]
- for dec in method.decorator_list:
+ for dec in method.decorator_list: # pragma: no cover
if isinstance(dec, ast.Name) and dec.id == "contextmanager":
is_cm = True
break
@@ -159,7 +159,7 @@ def gen_public_wrappers_source(source_path: Path | str, lookup_path: str) -> str
# Create the function definition including the body
func = astor.to_source(method, indent_with=" " * 4)
- if is_cm:
+ if is_cm: # pragma: no cover
func = func.replace("->Iterator", "->ContextManager")
# Create export function body
From b9a6acb99858ad2f9920baea0e89b781ce2eb5de Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Mon, 7 Aug 2023 17:32:21 +1000
Subject: [PATCH 04/10] Pass the platform in to gen_exports.process(), instead
of special casing
---
trio/_tests/tools/test_gen_exports.py | 11 +++---
trio/_tools/gen_exports.py | 52 +++++++++++++--------------
2 files changed, 31 insertions(+), 32 deletions(-)
diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py
index 9436105fa4..a95f9634e7 100644
--- a/trio/_tests/tools/test_gen_exports.py
+++ b/trio/_tests/tools/test_gen_exports.py
@@ -2,7 +2,7 @@
import pytest
-from trio._tools.gen_exports import create_passthrough_args, get_public_methods, process
+from trio._tools.gen_exports import File, create_passthrough_args, get_public_methods, process
SOURCE = '''from _run import _public
@@ -52,14 +52,15 @@ def test_process(tmp_path):
modpath = tmp_path / "_module.py"
genpath = tmp_path / "_generated_module.py"
modpath.write_text(SOURCE, encoding="utf-8")
+ file = File(modpath, "runner")
assert not genpath.exists()
with pytest.raises(SystemExit) as excinfo:
- process([(str(modpath), "runner")], do_test=True)
+ process([file], do_test=True)
assert excinfo.value.code == 1
- process([(str(modpath), "runner")], do_test=False)
+ process([file], do_test=False)
assert genpath.exists()
- process([(str(modpath), "runner")], do_test=True)
+ process([file], do_test=True)
# But if we change the lookup path it notices
with pytest.raises(SystemExit) as excinfo:
- process([(str(modpath), "runner.io_manager")], do_test=True)
+ process([File(modpath, "runner.io_manager")], do_test=True)
assert excinfo.value.code == 1
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index c3e0e3be08..506c922767 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -18,6 +18,7 @@
from typing_extensions import TypeGuard
import astor
+import attr
PREFIX = "_generated"
@@ -57,6 +58,13 @@
"""
+@attr.define
+class File:
+ path: Path
+ modname: str
+ platform: str = attr.field(default='', kw_only=True)
+
+
def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
"""Check if the AST node is either a function
or an async function
@@ -108,29 +116,19 @@ def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) ->
return "({})".format(", ".join(call_args))
-def gen_public_wrappers_source(source_path: Path | str, lookup_path: str) -> str:
+def gen_public_wrappers_source(file: File) -> str:
"""Scan the given .py file for @_public decorators, and generate wrapper
functions.
"""
generated = [HEADER]
- # This is only triggered by check.sh, not test_gen_exports.py
- if lookup_path == "runner.io_manager" and isinstance(
- source_path, Path
- ): # pragma: no coverage
- for file_indicator, platform in (
- ("windows", "win32"),
- ("kqueue", "darwin"),
- ("epoll", "linux"),
- ):
- if file_indicator in source_path.stem:
- generated.append(
- f'assert not TYPE_CHECKING or sys.platform=="{platform}"'
- )
- break
+ if file.platform:
+ generated.append(
+ f'assert not TYPE_CHECKING or sys.platform=="{file.platform}"'
+ )
- source = astor.code_to_ast.parse_file(source_path)
+ source = astor.code_to_ast.parse_file(file.path)
for method in get_public_methods(source):
# Remove self from arguments
assert method.args.args[0].arg == "self"
@@ -165,7 +163,7 @@ def gen_public_wrappers_source(source_path: Path | str, lookup_path: str) -> str
# Create export function body
template = TEMPLATE.format(
" await " if isinstance(method, ast.AsyncFunctionDef) else " ",
- lookup_path,
+ file.modname,
method.name + new_args,
)
@@ -190,13 +188,13 @@ def matches_disk_files(new_files: dict[str, str]) -> bool:
def process(
- sources_and_lookups: Iterable[tuple[Path | str, str]], *, do_test: bool
+ files: Iterable[File], *, do_test: bool
) -> None:
new_files = {}
- for source_path, lookup_path in sources_and_lookups:
- print("Scanning:", source_path)
- new_source = gen_public_wrappers_source(source_path, lookup_path)
- dirname, basename = os.path.split(source_path)
+ for file in files:
+ print("Scanning:", file.path)
+ new_source = gen_public_wrappers_source(file)
+ dirname, basename = os.path.split(file.path)
new_path = os.path.join(dirname, PREFIX + basename)
new_files[new_path] = new_source
if do_test:
@@ -228,11 +226,11 @@ def main() -> None: # pragma: no cover
assert (source_root / "LICENSE").exists()
core = source_root / "trio/_core"
to_wrap = [
- (core / "_run.py", "runner"),
- (core / "_instrumentation.py", "runner.instruments"),
- (core / "_io_windows.py", "runner.io_manager"),
- (core / "_io_epoll.py", "runner.io_manager"),
- (core / "_io_kqueue.py", "runner.io_manager"),
+ File(core / "_run.py", "runner"),
+ File(core / "_instrumentation.py", "runner.instruments"),
+ File(core / "_io_windows.py", "runner.io_manager", platform="win32"),
+ File(core / "_io_epoll.py", "runner.io_manager", platform="linux"),
+ File(core / "_io_kqueue.py", "runner.io_manager", platform="darwin"),
]
process(to_wrap, do_test=parsed_args.test)
From b5ca5dc6a5d316cf0143f132965d2ebcb496a170 Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Mon, 7 Aug 2023 18:49:55 +1000
Subject: [PATCH 05/10] Specify individual sets of imports for each generated
module
This way they only import the things they actually use.
---
trio/_core/_generated_instrumentation.py | 20 ++-----
trio/_core/_generated_io_epoll.py | 24 +++-----
trio/_core/_generated_io_kqueue.py | 14 ++---
trio/_core/_generated_io_windows.py | 23 ++------
trio/_core/_generated_run.py | 20 ++-----
trio/_tests/tools/test_gen_exports.py | 7 ++-
trio/_tools/gen_exports.py | 70 ++++++++++++++++--------
7 files changed, 80 insertions(+), 98 deletions(-)
diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py
index b3f73f6616..605a6372f2 100644
--- a/trio/_core/_generated_instrumentation.py
+++ b/trio/_core/_generated_instrumentation.py
@@ -1,26 +1,14 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
+# Don't lint this file, generation will not format this too nicely.
# isort: skip_file
+# fmt: off
from __future__ import annotations
-import sys
-from typing import TYPE_CHECKING, Callable, ContextManager
-
-from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
-
-if TYPE_CHECKING:
- import select
- from socket import socket
-
- from ._traps import Abort, RaiseCancelT
-
- from .. import _core
- from ._unbounded_queue import UnboundedQueue
-
-# fmt: off
+from ._run import GLOBAL_RUN_CONTEXT
+from ._instrumentation import Instrument
def add_instrument(instrument: Instrument) ->None:
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index 5ac1b99efb..abe49ed3ff 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -1,30 +1,20 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
+# Don't lint this file, generation will not format this too nicely.
# isort: skip_file
+# fmt: off
from __future__ import annotations
-import sys
-from typing import TYPE_CHECKING, Callable, ContextManager
-
-from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
-
-if TYPE_CHECKING:
- import select
- from socket import socket
-
- from ._traps import Abort, RaiseCancelT
-
- from .. import _core
- from ._unbounded_queue import UnboundedQueue
-
-# fmt: off
-
+from ._run import GLOBAL_RUN_CONTEXT
+from socket import socket
+from typing import TYPE_CHECKING
+import sys
assert not TYPE_CHECKING or sys.platform=="linux"
+
async def wait_readable(fd: (int | socket)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index 97199e4577..df28a3ae23 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -1,15 +1,14 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
+# Don't lint this file, generation will not format this too nicely.
# isort: skip_file
+# fmt: off
from __future__ import annotations
-import sys
-from typing import TYPE_CHECKING, Callable, ContextManager
-
-from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
+from ._run import GLOBAL_RUN_CONTEXT
+from typing import Callable, ContextManager, TYPE_CHECKING
if TYPE_CHECKING:
import select
@@ -19,12 +18,11 @@
from .. import _core
from ._unbounded_queue import UnboundedQueue
-
-# fmt: off
-
+import sys
assert not TYPE_CHECKING or sys.platform=="darwin"
+
def current_kqueue() ->select.kqueue:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py
index 59e14b185d..7fa6fd5126 100644
--- a/trio/_core/_generated_io_windows.py
+++ b/trio/_core/_generated_io_windows.py
@@ -1,30 +1,19 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
+# Don't lint this file, generation will not format this too nicely.
# isort: skip_file
+# fmt: off
from __future__ import annotations
-import sys
-from typing import TYPE_CHECKING, Callable, ContextManager
-
-from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
-
-if TYPE_CHECKING:
- import select
- from socket import socket
-
- from ._traps import Abort, RaiseCancelT
-
- from .. import _core
- from ._unbounded_queue import UnboundedQueue
-
-# fmt: off
-
+from ._run import GLOBAL_RUN_CONTEXT
+from typing import TYPE_CHECKING
+import sys
assert not TYPE_CHECKING or sys.platform=="win32"
+
async def wait_readable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py
index c5b78beadd..674c86aaec 100644
--- a/trio/_core/_generated_run.py
+++ b/trio/_core/_generated_run.py
@@ -1,26 +1,14 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
+# Don't lint this file, generation will not format this too nicely.
# isort: skip_file
+# fmt: off
from __future__ import annotations
-import sys
-from typing import TYPE_CHECKING, Callable, ContextManager
-
-from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
-
-if TYPE_CHECKING:
- import select
- from socket import socket
-
- from ._traps import Abort, RaiseCancelT
-
- from .. import _core
- from ._unbounded_queue import UnboundedQueue
-
-# fmt: off
+from ._run import GLOBAL_RUN_CONTEXT
+from ._run import _NO_SEND
def current_statistics():
diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py
index a95f9634e7..58e63ff9bc 100644
--- a/trio/_tests/tools/test_gen_exports.py
+++ b/trio/_tests/tools/test_gen_exports.py
@@ -5,6 +5,7 @@
from trio._tools.gen_exports import File, create_passthrough_args, get_public_methods, process
SOURCE = '''from _run import _public
+from somewhere import Thing
class Test:
@_public
@@ -14,7 +15,7 @@ def public_func(self):
@ignore_this
@_public
@another_decorator
- async def public_async_func(self):
+ async def public_async_func(self) -> Thing:
pass # no doc string
def not_public(self):
@@ -24,6 +25,10 @@ async def not_public_async(self):
pass
'''
+IMPORTS = '''\
+from somewhere import Thing
+'''
+
def test_get_public_methods():
methods = list(get_public_methods(ast.parse(SOURCE)))
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 506c922767..121bfc0a9f 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -25,26 +25,13 @@
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
+# Don't lint this file, generation will not format this too nicely.
# isort: skip_file
+# fmt: off
from __future__ import annotations
-import sys
-from typing import TYPE_CHECKING, Callable, ContextManager
-
-from ._instrumentation import Instrument
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
-from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT
-
-if TYPE_CHECKING:
- import select
- from socket import socket
-
- from ._traps import Abort, RaiseCancelT
-
- from .. import _core
- from ._unbounded_queue import UnboundedQueue
-
-# fmt: off
+from ._run import GLOBAL_RUN_CONTEXT
"""
FOOTER = """# fmt: on
@@ -63,6 +50,7 @@ class File:
path: Path
modname: str
platform: str = attr.field(default='', kw_only=True)
+ imports: str = attr.field(default='', kw_only=True)
def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
@@ -121,13 +109,23 @@ def gen_public_wrappers_source(file: File) -> str:
functions.
"""
- generated = [HEADER]
+ header = [HEADER]
+ if file.imports:
+ header.append(file.imports)
if file.platform:
- generated.append(
- f'assert not TYPE_CHECKING or sys.platform=="{file.platform}"'
+ # Simple checks to avoid repeating imports. If this messes up, type checkers/tests will
+ # just give errors.
+ if 'TYPE_CHECKING' not in file.imports:
+ header.append('from typing import TYPE_CHECKING\n')
+ if 'import sys' not in file.imports:
+ header.append('import sys\n')
+ header.append(
+ f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n'
)
+ generated = [''.join(header)]
+
source = astor.code_to_ast.parse_file(file.path)
for method in get_public_methods(source):
# Remove self from arguments
@@ -226,15 +224,41 @@ def main() -> None: # pragma: no cover
assert (source_root / "LICENSE").exists()
core = source_root / "trio/_core"
to_wrap = [
- File(core / "_run.py", "runner"),
- File(core / "_instrumentation.py", "runner.instruments"),
+ File(core / "_run.py", "runner", imports=IMPORTS_RUN),
+ File(core / "_instrumentation.py", "runner.instruments", imports=IMPORTS_INSTRUMENT),
File(core / "_io_windows.py", "runner.io_manager", platform="win32"),
- File(core / "_io_epoll.py", "runner.io_manager", platform="linux"),
- File(core / "_io_kqueue.py", "runner.io_manager", platform="darwin"),
+ File(core / "_io_epoll.py", "runner.io_manager", platform="linux", imports=IMPORTS_EPOLL),
+ File(core / "_io_kqueue.py", "runner.io_manager", platform="darwin", imports=IMPORTS_KQUEUE),
]
process(to_wrap, do_test=parsed_args.test)
+IMPORTS_RUN = '''\
+from ._run import _NO_SEND
+'''
+IMPORTS_INSTRUMENT = '''\
+from ._instrumentation import Instrument
+'''
+
+IMPORTS_EPOLL = '''\
+from socket import socket
+'''
+
+IMPORTS_KQUEUE = '''\
+from typing import Callable, ContextManager, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import select
+ from socket import socket
+
+ from ._traps import Abort, RaiseCancelT
+
+ from .. import _core
+ from ._unbounded_queue import UnboundedQueue
+
+'''
+
+
if __name__ == "__main__": # pragma: no cover
main()
From 46ed04a497a3ab8fb775c8fd96c7564958c9149a Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Mon, 7 Aug 2023 19:14:09 +1000
Subject: [PATCH 06/10] Expand the tests
---
trio/_tests/tools/test_gen_exports.py | 36 +++++++++++++++---
trio/_tools/gen_exports.py | 54 ++++++++++++++++-----------
2 files changed, 63 insertions(+), 27 deletions(-)
diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py
index 58e63ff9bc..7fdce51daf 100644
--- a/trio/_tests/tools/test_gen_exports.py
+++ b/trio/_tests/tools/test_gen_exports.py
@@ -2,7 +2,12 @@
import pytest
-from trio._tools.gen_exports import File, create_passthrough_args, get_public_methods, process
+from trio._tools.gen_exports import (
+ File,
+ create_passthrough_args,
+ get_public_methods,
+ process,
+)
SOURCE = '''from _run import _public
from somewhere import Thing
@@ -25,9 +30,20 @@ async def not_public_async(self):
pass
'''
-IMPORTS = '''\
+IMPORT_1 = """\
from somewhere import Thing
-'''
+"""
+
+IMPORT_2 = """\
+from somewhere import Thing
+import os
+"""
+
+IMPORT_3 = """\
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from somewhere import Thing
+"""
def test_get_public_methods():
@@ -53,11 +69,12 @@ def test_create_pass_through_args():
assert create_passthrough_args(func_node) == expected
-def test_process(tmp_path):
+@pytest.mark.parametrize("imports", ['', IMPORT_1, IMPORT_2, IMPORT_3])
+def test_process(tmp_path, imports):
modpath = tmp_path / "_module.py"
genpath = tmp_path / "_generated_module.py"
modpath.write_text(SOURCE, encoding="utf-8")
- file = File(modpath, "runner")
+ file = File(modpath, "runner", platform="linux", imports=imports)
assert not genpath.exists()
with pytest.raises(SystemExit) as excinfo:
process([file], do_test=True)
@@ -67,5 +84,12 @@ def test_process(tmp_path):
process([file], do_test=True)
# But if we change the lookup path it notices
with pytest.raises(SystemExit) as excinfo:
- process([File(modpath, "runner.io_manager")], do_test=True)
+ process(
+ [File(modpath, "runner.io_manager", platform="linux", imports=imports)],
+ do_test=True,
+ )
+ assert excinfo.value.code == 1
+ # Also if the platform is changed.
+ with pytest.raises(SystemExit) as excinfo:
+ process([File(modpath, "runner", imports=imports)], do_test=True)
assert excinfo.value.code == 1
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 121bfc0a9f..ad41223609 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -49,8 +49,8 @@
class File:
path: Path
modname: str
- platform: str = attr.field(default='', kw_only=True)
- imports: str = attr.field(default='', kw_only=True)
+ platform: str = attr.field(default="", kw_only=True)
+ imports: str = attr.field(default="", kw_only=True)
def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
@@ -116,15 +116,15 @@ def gen_public_wrappers_source(file: File) -> str:
if file.platform:
# Simple checks to avoid repeating imports. If this messes up, type checkers/tests will
# just give errors.
- if 'TYPE_CHECKING' not in file.imports:
- header.append('from typing import TYPE_CHECKING\n')
- if 'import sys' not in file.imports:
- header.append('import sys\n')
+ if "TYPE_CHECKING" not in file.imports:
+ header.append("from typing import TYPE_CHECKING\n")
+ if "import sys" not in file.imports:
+ header.append("import sys\n")
header.append(
f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n'
)
- generated = [''.join(header)]
+ generated = ["".join(header)]
source = astor.code_to_ast.parse_file(file.path)
for method in get_public_methods(source):
@@ -185,9 +185,7 @@ def matches_disk_files(new_files: dict[str, str]) -> bool:
return True
-def process(
- files: Iterable[File], *, do_test: bool
-) -> None:
+def process(files: Iterable[File], *, do_test: bool) -> None:
new_files = {}
for file in files:
print("Scanning:", file.path)
@@ -225,27 +223,41 @@ def main() -> None: # pragma: no cover
core = source_root / "trio/_core"
to_wrap = [
File(core / "_run.py", "runner", imports=IMPORTS_RUN),
- File(core / "_instrumentation.py", "runner.instruments", imports=IMPORTS_INSTRUMENT),
+ File(
+ core / "_instrumentation.py",
+ "runner.instruments",
+ imports=IMPORTS_INSTRUMENT,
+ ),
File(core / "_io_windows.py", "runner.io_manager", platform="win32"),
- File(core / "_io_epoll.py", "runner.io_manager", platform="linux", imports=IMPORTS_EPOLL),
- File(core / "_io_kqueue.py", "runner.io_manager", platform="darwin", imports=IMPORTS_KQUEUE),
+ File(
+ core / "_io_epoll.py",
+ "runner.io_manager",
+ platform="linux",
+ imports=IMPORTS_EPOLL,
+ ),
+ File(
+ core / "_io_kqueue.py",
+ "runner.io_manager",
+ platform="darwin",
+ imports=IMPORTS_KQUEUE,
+ ),
]
process(to_wrap, do_test=parsed_args.test)
-IMPORTS_RUN = '''\
+IMPORTS_RUN = """\
from ._run import _NO_SEND
-'''
-IMPORTS_INSTRUMENT = '''\
+"""
+IMPORTS_INSTRUMENT = """\
from ._instrumentation import Instrument
-'''
+"""
-IMPORTS_EPOLL = '''\
+IMPORTS_EPOLL = """\
from socket import socket
-'''
+"""
-IMPORTS_KQUEUE = '''\
+IMPORTS_KQUEUE = """\
from typing import Callable, ContextManager, TYPE_CHECKING
if TYPE_CHECKING:
@@ -257,7 +269,7 @@ def main() -> None: # pragma: no cover
from .. import _core
from ._unbounded_queue import UnboundedQueue
-'''
+"""
if __name__ == "__main__": # pragma: no cover
From 7a0a66b24c9fa4a95d82ea3cacad7b32c4f5101f Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 8 Aug 2023 16:43:03 +0200
Subject: [PATCH 07/10] remove F401 exception for generated files and remove
unnecessary import
---
pyproject.toml | 1 -
trio/_core/_generated_io_kqueue.py | 2 +-
trio/_tools/gen_exports.py | 3 +--
3 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index ba10d4c532..fcc09a3b8b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,6 @@ extend-ignore = ['D', 'E', 'W', 'F403', 'F405', 'F821', 'F822']
per-file-ignores = [
'trio/__init__.py: F401',
'trio/_core/__init__.py: F401',
- 'trio/_core/_generated*.py: F401',
'trio/_core/_tests/test_multierror_scripts/*: F401',
'trio/abc.py: F401',
'trio/lowlevel.py: F401',
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index df28a3ae23..cfcf6354c7 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -17,7 +17,7 @@
from ._traps import Abort, RaiseCancelT
from .. import _core
- from ._unbounded_queue import UnboundedQueue
+
import sys
assert not TYPE_CHECKING or sys.platform=="darwin"
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index ad41223609..f3ed2e26e7 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -118,7 +118,7 @@ def gen_public_wrappers_source(file: File) -> str:
# just give errors.
if "TYPE_CHECKING" not in file.imports:
header.append("from typing import TYPE_CHECKING\n")
- if "import sys" not in file.imports:
+ if "import sys" not in file.imports: # pragma: no cover
header.append("import sys\n")
header.append(
f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n'
@@ -267,7 +267,6 @@ def main() -> None: # pragma: no cover
from ._traps import Abort, RaiseCancelT
from .. import _core
- from ._unbounded_queue import UnboundedQueue
"""
From c192655173efa67723c8fe757bba7ebdaca5e6fd Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 8 Aug 2023 17:06:25 +0200
Subject: [PATCH 08/10] change backend to a Literal
---
trio/_core/_io_epoll.py | 4 ++--
trio/_core/_io_kqueue.py | 4 ++--
trio/_core/_io_windows.py | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index ff873ee966..44c405f0c0 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -3,7 +3,7 @@
import select
import sys
from collections import defaultdict
-from typing import TYPE_CHECKING, DefaultDict
+from typing import TYPE_CHECKING, DefaultDict, Literal
import attr
@@ -32,7 +32,7 @@ class EpollWaiters:
class _EpollStatistics:
tasks_waiting_read: int = attr.ib()
tasks_waiting_write: int = attr.ib()
- backend: str = attr.ib(default="epoll")
+ backend: Literal["epoll", "kqueue", "windows"] = attr.ib(default="epoll")
# Some facts about epoll
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index 63528a40b2..bdbbc90766 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -4,7 +4,7 @@
import select
import sys
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable, Iterator
+from typing import TYPE_CHECKING, Callable, Iterator, Literal
import attr
import outcome
@@ -25,7 +25,7 @@
class _KqueueStatistics:
tasks_waiting: int = attr.ib()
monitors: int = attr.ib()
- backend: str = attr.ib(default="kqueue")
+ backend: Literal["epoll", "kqueue", "windows"] = attr.ib(default="kqueue")
@attr.s(slots=True, eq=False)
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py
index 4084f72b6e..866e7e5305 100644
--- a/trio/_core/_io_windows.py
+++ b/trio/_core/_io_windows.py
@@ -3,7 +3,7 @@
import socket
import sys
from contextlib import contextmanager
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
import attr
from outcome import Value
@@ -369,7 +369,7 @@ class _WindowsStatistics:
tasks_waiting_write = attr.ib()
tasks_waiting_overlapped = attr.ib()
completion_key_monitors = attr.ib()
- backend = attr.ib(default="windows")
+ backend: Literal["epoll", "kqueue", "windows"] = attr.ib(default="windows")
# Maximum number of events to dequeue from the completion port on each pass
From d74db2e3b512b5c436de6f524493932343b9b5d3 Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Wed, 9 Aug 2023 15:10:10 +1000
Subject: [PATCH 09/10] Reformat
(My fault)
---
trio/_tests/tools/test_gen_exports.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py
index 7fdce51daf..7d2d6e99a1 100644
--- a/trio/_tests/tools/test_gen_exports.py
+++ b/trio/_tests/tools/test_gen_exports.py
@@ -69,7 +69,7 @@ def test_create_pass_through_args():
assert create_passthrough_args(func_node) == expected
-@pytest.mark.parametrize("imports", ['', IMPORT_1, IMPORT_2, IMPORT_3])
+@pytest.mark.parametrize("imports", ["", IMPORT_1, IMPORT_2, IMPORT_3])
def test_process(tmp_path, imports):
modpath = tmp_path / "_module.py"
genpath = tmp_path / "_generated_module.py"
From 70fcae91385f7ceb34202230dbb078101483243c Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Wed, 9 Aug 2023 15:27:58 +1000
Subject: [PATCH 10/10] We know statically what backend is being specified
here.
---
trio/_core/_io_epoll.py | 2 +-
trio/_core/_io_kqueue.py | 2 +-
trio/_core/_io_windows.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 44c405f0c0..c4a31f3722 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -32,7 +32,7 @@ class EpollWaiters:
class _EpollStatistics:
tasks_waiting_read: int = attr.ib()
tasks_waiting_write: int = attr.ib()
- backend: Literal["epoll", "kqueue", "windows"] = attr.ib(default="epoll")
+ backend: Literal["epoll"] = attr.ib(init=False, default="epoll")
# Some facts about epoll
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index bdbbc90766..0b0f8ee557 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -25,7 +25,7 @@
class _KqueueStatistics:
tasks_waiting: int = attr.ib()
monitors: int = attr.ib()
- backend: Literal["epoll", "kqueue", "windows"] = attr.ib(default="kqueue")
+ backend: Literal["kqueue"] = attr.ib(init=False, default="kqueue")
@attr.s(slots=True, eq=False)
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py
index 866e7e5305..0130170af3 100644
--- a/trio/_core/_io_windows.py
+++ b/trio/_core/_io_windows.py
@@ -369,7 +369,7 @@ class _WindowsStatistics:
tasks_waiting_write = attr.ib()
tasks_waiting_overlapped = attr.ib()
completion_key_monitors = attr.ib()
- backend: Literal["epoll", "kqueue", "windows"] = attr.ib(default="windows")
+ backend: Literal["windows"] = attr.ib(init=False, default="windows")
# Maximum number of events to dequeue from the completion port on each pass