diff --git a/pyproject.toml b/pyproject.toml index a23f7f5db9..fc7d37dbf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,6 @@ extend-ignore = ['D', 'E', 'W', 'F403', 'F405', 'F821', 'F822'] per-file-ignores = [ 'trio/__init__.py: F401', 'trio/_core/__init__.py: F401', - 'trio/_core/_generated*.py: F401', 'trio/_core/_tests/test_multierror_scripts/*: F401', 'trio/abc.py: F401', 'trio/lowlevel.py: F401', @@ -51,6 +50,10 @@ disallow_untyped_defs = false module = [ "trio._abc", "trio._core._entry_queue", + "trio._core._generated_io_epoll", + "trio._core._generated_io_kqueue", + "trio._core._io_epoll", + "trio._core._io_kqueue", "trio._core._local", "trio._core._unbounded_queue", "trio._core._thread_cache", diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index 30c2f26b4e..605a6372f2 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -1,12 +1,14 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip -from ._instrumentation import Instrument -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT - +# Don't lint this file, generation will not format this too nicely. +# isort: skip_file # fmt: off +from __future__ import annotations + +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import GLOBAL_RUN_CONTEXT +from ._instrumentation import Instrument def add_instrument(instrument: Instrument) ->None: diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index 02fb3bc348..abe49ed3ff 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -1,15 +1,21 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip -from ._instrumentation import Instrument +# Don't lint this file, generation will not format this too nicely. +# isort: skip_file +# fmt: off +from __future__ import annotations + from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +from ._run import GLOBAL_RUN_CONTEXT +from socket import socket +from typing import TYPE_CHECKING +import sys -# fmt: off +assert not TYPE_CHECKING or sys.platform=="linux" -async def wait_readable(fd): +async def wait_readable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -17,7 +23,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -25,7 +31,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 94e819769c..cfcf6354c7 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -1,15 +1,29 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip -from ._instrumentation import Instrument +# Don't lint this file, generation will not format this too nicely. +# isort: skip_file +# fmt: off +from __future__ import annotations + from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +from ._run import GLOBAL_RUN_CONTEXT +from typing import Callable, ContextManager, TYPE_CHECKING -# fmt: off +if TYPE_CHECKING: + import select + from socket import socket + + from ._traps import Abort, RaiseCancelT + + from .. import _core + +import sys + +assert not TYPE_CHECKING or sys.platform=="darwin" -def current_kqueue(): +def current_kqueue() ->select.kqueue: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() @@ -17,7 +31,8 @@ def current_kqueue(): raise RuntimeError("must be called from async context") -def monitor_kevent(ident, filter): +def monitor_kevent(ident: int, filter: int) ->ContextManager[_core.UnboundedQueue + [select.kevent]]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) @@ -25,7 +40,8 @@ def monitor_kevent(ident, filter): raise RuntimeError("must be called from async context") -async def wait_kevent(ident, filter, abort_func): +async def wait_kevent(ident: int, filter: int, abort_func: Callable[[ + RaiseCancelT], Abort]) ->Abort: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func) @@ -33,7 +49,7 @@ async def wait_kevent(ident, filter, abort_func): raise RuntimeError("must be called from async context") -async def wait_readable(fd): +async def wait_readable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -41,7 +57,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -49,7 +65,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: (int | socket)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 26b4da697d..7fa6fd5126 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -1,12 +1,17 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip -from ._instrumentation import Instrument +# Don't lint this file, generation will not format this too nicely. +# isort: skip_file +# fmt: off +from __future__ import annotations + from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT +from ._run import GLOBAL_RUN_CONTEXT +from typing import TYPE_CHECKING +import sys -# fmt: off +assert not TYPE_CHECKING or sys.platform=="win32" async def wait_readable(sock): diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index d1e74a93f4..674c86aaec 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -1,12 +1,14 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip -from ._instrumentation import Instrument -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT - +# Don't lint this file, generation will not format this too nicely. +# isort: skip_file # fmt: off +from __future__ import annotations + +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import GLOBAL_RUN_CONTEXT +from ._run import _NO_SEND def current_statistics(): diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 376dd18a4e..c4a31f3722 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -1,23 +1,38 @@ +from __future__ import annotations + import select import sys from collections import defaultdict -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, DefaultDict, Literal import attr from .. import _core from ._io_common import wake_all -from ._run import _public +from ._run import Task, _public from ._wakeup_socketpair import WakeupSocketpair +if TYPE_CHECKING: + from socket import socket + + from .._core import Abort, RaiseCancelT + + +@attr.s(slots=True, eq=False) +class EpollWaiters: + read_task: Task | None = attr.ib(default=None) + write_task: Task | None = attr.ib(default=None) + current_flags: int = attr.ib(default=0) + + assert not TYPE_CHECKING or sys.platform == "linux" @attr.s(slots=True, eq=False, frozen=True) class _EpollStatistics: - tasks_waiting_read = attr.ib() - tasks_waiting_write = attr.ib() - backend = attr.ib(default="epoll") + tasks_waiting_read: int = attr.ib() + tasks_waiting_write: int = attr.ib() + backend: Literal["epoll"] = attr.ib(init=False, default="epoll") # Some facts about epoll @@ -178,28 +193,21 @@ class _EpollStatistics: # wanted to about how epoll works. -@attr.s(slots=True, eq=False) -class EpollWaiters: - read_task = attr.ib(default=None) - write_task = attr.ib(default=None) - current_flags = attr.ib(default=0) - - @attr.s(slots=True, eq=False, hash=False) class EpollIOManager: - _epoll = attr.ib(factory=select.epoll) + _epoll: select.epoll = attr.ib(factory=select.epoll) # {fd: EpollWaiters} - _registered = attr.ib( - factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters] + _registered: DefaultDict[int, EpollWaiters] = attr.ib( + factory=lambda: defaultdict(EpollWaiters) ) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: int | None = attr.ib(default=None) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() - def statistics(self): + def statistics(self) -> _EpollStatistics: tasks_waiting_read = 0 tasks_waiting_write = 0 for waiter in self._registered.values(): @@ -212,24 +220,24 @@ def statistics(self): tasks_waiting_write=tasks_waiting_write, ) - def close(self): + def close(self) -> None: self._epoll.close() self._force_wakeup.close() - def force_wakeup(self): + def force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() # Return value must be False-y IFF the timeout expired, NOT if any I/O # happened or force_wakeup was called. Otherwise it can be anything; gets # passed straight through to process_events. - def get_events(self, timeout): + def get_events(self, timeout: float) -> list[tuple[int, int]]: # max_events must be > 0 or epoll gets cranky # accessing self._registered from a thread looks dangerous, but it's # OK because it doesn't matter if our value is a little bit off. max_events = max(1, len(self._registered)) return self._epoll.poll(timeout, max_events) - def process_events(self, events): + def process_events(self, events: list[tuple[int, int]]) -> None: for fd, flags in events: if fd == self._force_wakeup_fd: self._force_wakeup.drain() @@ -248,7 +256,7 @@ def process_events(self, events): waiters.read_task = None self._update_registrations(fd) - def _update_registrations(self, fd): + def _update_registrations(self, fd: int) -> None: waiters = self._registered[fd] wanted_flags = 0 if waiters.read_task is not None: @@ -277,7 +285,7 @@ def _update_registrations(self, fd): if not wanted_flags: del self._registered[fd] - async def _epoll_wait(self, fd, attr_name): + async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None: if not isinstance(fd, int): fd = fd.fileno() waiters = self._registered[fd] @@ -288,7 +296,7 @@ async def _epoll_wait(self, fd, attr_name): setattr(waiters, attr_name, _core.current_task()) self._update_registrations(fd) - def abort(_): + def abort(_: RaiseCancelT) -> Abort: setattr(waiters, attr_name, None) self._update_registrations(fd) return _core.Abort.SUCCEEDED @@ -296,15 +304,15 @@ def abort(_): await _core.wait_task_rescheduled(abort) @_public - async def wait_readable(self, fd): + async def wait_readable(self, fd: int | socket) -> None: await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: int | socket) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd): + def notify_closing(self, fd: int | socket) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index d1151843e8..0b0f8ee557 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import errno import select import sys from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Iterator, Literal import attr import outcome @@ -11,32 +13,39 @@ from ._run import _public from ._wakeup_socketpair import WakeupSocketpair +if TYPE_CHECKING: + from socket import socket + + from .._core import Abort, RaiseCancelT, Task, UnboundedQueue + assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") @attr.s(slots=True, eq=False, frozen=True) class _KqueueStatistics: - tasks_waiting = attr.ib() - monitors = attr.ib() - backend = attr.ib(default="kqueue") + tasks_waiting: int = attr.ib() + monitors: int = attr.ib() + backend: Literal["kqueue"] = attr.ib(init=False, default="kqueue") @attr.s(slots=True, eq=False) class KqueueIOManager: - _kqueue = attr.ib(factory=select.kqueue) + _kqueue: select.kqueue = attr.ib(factory=select.kqueue) # {(ident, filter): Task or UnboundedQueue} - _registered = attr.ib(factory=dict) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) + _registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = attr.ib( + factory=dict + ) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: int | None = attr.ib(default=None) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: force_wakeup_event = select.kevent( self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD ) self._kqueue.control([force_wakeup_event], 0) self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() - def statistics(self): + def statistics(self) -> _KqueueStatistics: tasks_waiting = 0 monitors = 0 for receiver in self._registered.values(): @@ -46,14 +55,14 @@ def statistics(self): monitors += 1 return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors) - def close(self): + def close(self) -> None: self._kqueue.close() self._force_wakeup.close() - def force_wakeup(self): + def force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() - def get_events(self, timeout): + def get_events(self, timeout: float) -> list[select.kevent]: # max_events must be > 0 or kqueue gets cranky # and we generally want this to be strictly larger than the actual # number of events we get, so that we can tell that we've gotten @@ -70,7 +79,7 @@ def get_events(self, timeout): # and loop back to the start return events - def process_events(self, events): + def process_events(self, events: list[select.kevent]) -> None: for event in events: key = (event.ident, event.filter) if event.ident == self._force_wakeup_fd: @@ -79,7 +88,7 @@ def process_events(self, events): receiver = self._registered[key] if event.flags & select.KQ_EV_ONESHOT: del self._registered[key] - if type(receiver) is _core.Task: + if isinstance(receiver, _core.Task): _core.reschedule(receiver, outcome.Value(event)) else: receiver.put_nowait(event) @@ -96,18 +105,20 @@ def process_events(self, events): # be more ergonomic... @_public - def current_kqueue(self): + def current_kqueue(self) -> select.kqueue: return self._kqueue @contextmanager @_public - def monitor_kevent(self, ident, filter): + def monitor_kevent( + self, ident: int, filter: int + ) -> Iterator[_core.UnboundedQueue[select.kevent]]: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( "attempt to register multiple listeners for same ident/filter pair" ) - q = _core.UnboundedQueue() + q = _core.UnboundedQueue[select.kevent]() self._registered[key] = q try: yield q @@ -115,7 +126,9 @@ def monitor_kevent(self, ident, filter): del self._registered[key] @_public - async def wait_kevent(self, ident, filter, abort_func): + async def wait_kevent( + self, ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort] + ) -> Abort: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( @@ -123,22 +136,23 @@ async def wait_kevent(self, ident, filter, abort_func): ) self._registered[key] = _core.current_task() - def abort(raise_cancel): + def abort(raise_cancel: RaiseCancelT) -> Abort: r = abort_func(raise_cancel) if r is _core.Abort.SUCCEEDED: del self._registered[key] return r - return await _core.wait_task_rescheduled(abort) + # wait_task_rescheduled does not have its return type typed + return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return] - async def _wait_common(self, fd, filter): + async def _wait_common(self, fd: int | socket, filter: int) -> None: if not isinstance(fd, int): fd = fd.fileno() flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT event = select.kevent(fd, filter, flags) self._kqueue.control([event], 0) - def abort(_): + def abort(_: RaiseCancelT) -> Abort: event = select.kevent(fd, filter, select.KQ_EV_DELETE) try: self._kqueue.control([event], 0) @@ -163,15 +177,15 @@ def abort(_): await self.wait_kevent(fd, filter, abort) @_public - async def wait_readable(self, fd): + async def wait_readable(self, fd: int | socket) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: int | socket) -> None: await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public - def notify_closing(self, fd): + def notify_closing(self, fd: int | socket) -> None: if not isinstance(fd, int): fd = fd.fileno() diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 4084f72b6e..0130170af3 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -3,7 +3,7 @@ import socket import sys from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import attr from outcome import Value @@ -369,7 +369,7 @@ class _WindowsStatistics: tasks_waiting_write = attr.ib() tasks_waiting_overlapped = attr.ib() completion_key_monitors = attr.ib() - backend = attr.ib(default="windows") + backend: Literal["windows"] = attr.ib(init=False, default="windows") # Maximum number of events to dequeue from the completion port on each pass diff --git a/trio/_core/_run.py b/trio/_core/_run.py index ce8feb2827..7d247a2738 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1433,7 +1433,7 @@ def in_main_thread(): class Runner: clock = attr.ib() instruments: Instruments = attr.ib() - io_manager = attr.ib() + io_manager: TheIOManager = attr.ib() ki_manager = attr.ib() strict_exception_groups = attr.ib() diff --git a/trio/_tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py index 9436105fa4..7d2d6e99a1 100644 --- a/trio/_tests/tools/test_gen_exports.py +++ b/trio/_tests/tools/test_gen_exports.py @@ -2,9 +2,15 @@ import pytest -from trio._tools.gen_exports import create_passthrough_args, get_public_methods, process +from trio._tools.gen_exports import ( + File, + create_passthrough_args, + get_public_methods, + process, +) SOURCE = '''from _run import _public +from somewhere import Thing class Test: @_public @@ -14,7 +20,7 @@ def public_func(self): @ignore_this @_public @another_decorator - async def public_async_func(self): + async def public_async_func(self) -> Thing: pass # no doc string def not_public(self): @@ -24,6 +30,21 @@ async def not_public_async(self): pass ''' +IMPORT_1 = """\ +from somewhere import Thing +""" + +IMPORT_2 = """\ +from somewhere import Thing +import os +""" + +IMPORT_3 = """\ +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from somewhere import Thing +""" + def test_get_public_methods(): methods = list(get_public_methods(ast.parse(SOURCE))) @@ -48,18 +69,27 @@ def test_create_pass_through_args(): assert create_passthrough_args(func_node) == expected -def test_process(tmp_path): +@pytest.mark.parametrize("imports", ["", IMPORT_1, IMPORT_2, IMPORT_3]) +def test_process(tmp_path, imports): modpath = tmp_path / "_module.py" genpath = tmp_path / "_generated_module.py" modpath.write_text(SOURCE, encoding="utf-8") + file = File(modpath, "runner", platform="linux", imports=imports) assert not genpath.exists() with pytest.raises(SystemExit) as excinfo: - process([(str(modpath), "runner")], do_test=True) + process([file], do_test=True) assert excinfo.value.code == 1 - process([(str(modpath), "runner")], do_test=False) + process([file], do_test=False) assert genpath.exists() - process([(str(modpath), "runner")], do_test=True) + process([file], do_test=True) # But if we change the lookup path it notices with pytest.raises(SystemExit) as excinfo: - process([(str(modpath), "runner.io_manager")], do_test=True) + process( + [File(modpath, "runner.io_manager", platform="linux", imports=imports)], + do_test=True, + ) + assert excinfo.value.code == 1 + # Also if the platform is changed. + with pytest.raises(SystemExit) as excinfo: + process([File(modpath, "runner", imports=imports)], do_test=True) assert excinfo.value.code == 1 diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 5a0c59e33e..ec345facf3 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9202551834130781, + "completenessScore": 0.9250398724082934, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 577, - "withUnknownType": 50 + "withKnownType": 580, + "withUnknownType": 47 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index bae7e4f69d..f3ed2e26e7 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -18,18 +18,20 @@ from typing_extensions import TypeGuard import astor +import attr PREFIX = "_generated" HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -# isort: skip -from ._instrumentation import Instrument -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT - +# Don't lint this file, generation will not format this too nicely. +# isort: skip_file # fmt: off +from __future__ import annotations + +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import GLOBAL_RUN_CONTEXT """ FOOTER = """# fmt: on @@ -43,6 +45,14 @@ """ +@attr.define +class File: + path: Path + modname: str + platform: str = attr.field(default="", kw_only=True) + imports: str = attr.field(default="", kw_only=True) + + def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node is either a function or an async function @@ -94,18 +104,41 @@ def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> return "({})".format(", ".join(call_args)) -def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: +def gen_public_wrappers_source(file: File) -> str: """Scan the given .py file for @_public decorators, and generate wrapper functions. """ - generated = [HEADER] - source = astor.code_to_ast.parse_file(source_path) + header = [HEADER] + + if file.imports: + header.append(file.imports) + if file.platform: + # Simple checks to avoid repeating imports. If this messes up, type checkers/tests will + # just give errors. + if "TYPE_CHECKING" not in file.imports: + header.append("from typing import TYPE_CHECKING\n") + if "import sys" not in file.imports: # pragma: no cover + header.append("import sys\n") + header.append( + f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n' + ) + + generated = ["".join(header)] + + source = astor.code_to_ast.parse_file(file.path) for method in get_public_methods(source): # Remove self from arguments assert method.args.args[0].arg == "self" del method.args.args[0] + for dec in method.decorator_list: # pragma: no cover + if isinstance(dec, ast.Name) and dec.id == "contextmanager": + is_cm = True + break + else: + is_cm = False + # Remove decorators method.decorator_list = [] @@ -122,10 +155,13 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: # Create the function definition including the body func = astor.to_source(method, indent_with=" " * 4) + if is_cm: # pragma: no cover + func = func.replace("->Iterator", "->ContextManager") + # Create export function body template = TEMPLATE.format( " await " if isinstance(method, ast.AsyncFunctionDef) else " ", - lookup_path, + file.modname, method.name + new_args, ) @@ -149,12 +185,12 @@ def matches_disk_files(new_files: dict[str, str]) -> bool: return True -def process(sources_and_lookups: Iterable[tuple[Path, str]], *, do_test: bool) -> None: +def process(files: Iterable[File], *, do_test: bool) -> None: new_files = {} - for source_path, lookup_path in sources_and_lookups: - print("Scanning:", source_path) - new_source = gen_public_wrappers_source(source_path, lookup_path) - dirname, basename = os.path.split(source_path) + for file in files: + print("Scanning:", file.path) + new_source = gen_public_wrappers_source(file) + dirname, basename = os.path.split(file.path) new_path = os.path.join(dirname, PREFIX + basename) new_files[new_path] = new_source if do_test: @@ -186,15 +222,54 @@ def main() -> None: # pragma: no cover assert (source_root / "LICENSE").exists() core = source_root / "trio/_core" to_wrap = [ - (core / "_run.py", "runner"), - (core / "_instrumentation.py", "runner.instruments"), - (core / "_io_windows.py", "runner.io_manager"), - (core / "_io_epoll.py", "runner.io_manager"), - (core / "_io_kqueue.py", "runner.io_manager"), + File(core / "_run.py", "runner", imports=IMPORTS_RUN), + File( + core / "_instrumentation.py", + "runner.instruments", + imports=IMPORTS_INSTRUMENT, + ), + File(core / "_io_windows.py", "runner.io_manager", platform="win32"), + File( + core / "_io_epoll.py", + "runner.io_manager", + platform="linux", + imports=IMPORTS_EPOLL, + ), + File( + core / "_io_kqueue.py", + "runner.io_manager", + platform="darwin", + imports=IMPORTS_KQUEUE, + ), ] process(to_wrap, do_test=parsed_args.test) +IMPORTS_RUN = """\ +from ._run import _NO_SEND +""" +IMPORTS_INSTRUMENT = """\ +from ._instrumentation import Instrument +""" + +IMPORTS_EPOLL = """\ +from socket import socket +""" + +IMPORTS_KQUEUE = """\ +from typing import Callable, ContextManager, TYPE_CHECKING + +if TYPE_CHECKING: + import select + from socket import socket + + from ._traps import Abort, RaiseCancelT + + from .. import _core + +""" + + if __name__ == "__main__": # pragma: no cover main()