From 6d6af4dd67cef7da42e28771f6ea2dae601e4c22 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 22 Jan 2021 12:10:30 -0500 Subject: [PATCH 01/50] mypy the entire trio package --- check.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/check.sh b/check.sh index 57f1e2db40..de185cab67 100755 --- a/check.sh +++ b/check.sh @@ -24,9 +24,9 @@ flake8 trio/ \ || EXIT_STATUS=$? # Run mypy on all supported platforms -mypy -m trio -m trio.testing --platform linux || EXIT_STATUS=$? -mypy -m trio -m trio.testing --platform darwin || EXIT_STATUS=$? # tests FreeBSD too -mypy -m trio -m trio.testing --platform win32 || EXIT_STATUS=$? +mypy -p trio --platform linux || EXIT_STATUS=$? +mypy -p trio --platform darwin || EXIT_STATUS=$? # tests FreeBSD too +mypy -p trio --platform win32 || EXIT_STATUS=$? # Finally, leave a really clear warning of any issues and exit if [ $EXIT_STATUS -ne 0 ]; then From 082f6565bdfb53da648e8b7449ee421312888f21 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 22 Jan 2021 12:22:47 -0500 Subject: [PATCH 02/50] update mypy to 0.800 --- test-requirements.txt | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test-requirements.txt b/test-requirements.txt index 00994509a4..457e3b4f27 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --output-file test-requirements.txt test-requirements.in +# pip-compile --output-file=test-requirements.txt test-requirements.in # appdirs==1.4.4 # via black @@ -64,7 +64,7 @@ mypy-extensions==0.4.3 ; implementation_name == "cpython" # -r test-requirements.in # black # mypy -mypy==0.790 ; implementation_name == "cpython" +mypy==0.800 ; implementation_name == "cpython" # via -r test-requirements.in outcome==1.1.0 # via -r test-requirements.in @@ -140,3 +140,6 @@ wcwidth==0.2.5 # via prompt-toolkit wrapt==1.12.1 # via astroid + +# The following packages are considered to be unsafe in a requirements file: +# setuptools From 1c00364f454fe9dfa74ec84967ede8b3e546a845 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 22 Jan 2021 21:09:35 -0500 Subject: [PATCH 03/50] more specific ignores, etc --- trio/_channel.py | 5 ++++- trio/_core/_multierror.py | 6 +++--- trio/_core/_run.py | 3 ++- trio/_core/tests/test_io.py | 12 ++++++++--- trio/_core/tests/test_multierror.py | 2 +- trio/_path.py | 22 +++++++++++++++++--- trio/_path.pyi | 1 - trio/_subprocess_platform/waitid.py | 3 ++- trio/_util.py | 18 ++++++++++------ trio/socket.py | 4 ++-- trio/tests/module_with_deprecations.py | 2 +- trio/tests/test_highlevel_serve_listeners.py | 8 ++++++- trio/tests/test_subprocess.py | 6 ++++++ trio/tests/test_sync.py | 8 ++++++- trio/tests/test_unix_pipes.py | 3 ++- trio/tests/test_windows_pipes.py | 9 ++++---- 16 files changed, 82 insertions(+), 30 deletions(-) delete mode 100644 trio/_path.pyi diff --git a/trio/_channel.py b/trio/_channel.py index 1cecc55621..e2a184950b 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,5 +1,6 @@ from collections import deque, OrderedDict from math import inf +from typing import cast, Callable, Tuple, TypeVar import attr from outcome import Error, Value @@ -12,7 +13,9 @@ @generic_function -def open_memory_channel(max_buffer_size): +def open_memory_channel( + max_buffer_size, +) -> Tuple["MemorySendChannel", "MemoryReceiveChannel"]: """Open a channel for passing objects between tasks within a process. Memory channels are lightweight, cheap to allocate, and entirely diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 13fc3f3d0f..337b9e64ac 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -419,7 +419,7 @@ def traceback_exception_init( self.embedded = [] -traceback.TracebackException.__init__ = traceback_exception_init # type: ignore +traceback.TracebackException.__init__ = traceback_exception_init # type: ignore[assignment] traceback_exception_original_format = traceback.TracebackException.format @@ -431,7 +431,7 @@ def traceback_exception_format(self, *, chain=True): yield from (textwrap.indent(line, " " * 2) for line in exc.format(chain=chain)) -traceback.TracebackException.format = traceback_exception_format # type: ignore +traceback.TracebackException.format = traceback_exception_format # type: ignore[assignment] def trio_excepthook(etype, value, tb): @@ -493,7 +493,7 @@ class TrioFakeSysModuleForApport: fake_sys = TrioFakeSysModuleForApport() fake_sys.__dict__.update(sys.__dict__) - fake_sys.__excepthook__ = trio_excepthook # type: ignore + fake_sys.__excepthook__ = trio_excepthook # type: ignore[attr-defined] apport_python_hook.sys = fake_sys monkeypatched_or_warned = True diff --git a/trio/_core/_run.py b/trio/_core/_run.py index e56977f386..3f62f68f66 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -657,7 +657,8 @@ def shield(self): """ return self._shield - @shield.setter # type: ignore # "decorated property not supported" + # ignore for "decorated property not supported" + @shield.setter # type: ignore[misc] @enable_ki_protection def shield(self, new_value): if not isinstance(new_value, bool): diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index 397375503d..a35cb09e7a 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -5,6 +5,7 @@ import random import errno from contextlib import suppress +from typing import Callable from ... import _core from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints @@ -59,18 +60,23 @@ def fileno_wrapper(fileobj): ]: options_list += [using_fileno(f) for f in options_list] + +def get__name__(fn: Callable) -> str: + return fn.__name__ + + # Decorators that feed in different settings for wait_readable / wait_writable # / notify_closing. # Note that if you use all three decorators on the same test, it will run all # N**3 *combinations* read_socket_test = pytest.mark.parametrize( - "wait_readable", wait_readable_options, ids=lambda fn: fn.__name__ + "wait_readable", wait_readable_options, ids=get__name__ ) write_socket_test = pytest.mark.parametrize( - "wait_writable", wait_writable_options, ids=lambda fn: fn.__name__ + "wait_writable", wait_writable_options, ids=get__name__ ) notify_closing_test = pytest.mark.parametrize( - "notify_closing", notify_closing_options, ids=lambda fn: fn.__name__ + "notify_closing", notify_closing_options, ids=get__name__ ) diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/tests/test_multierror.py index 14eab22df7..7976c46ae2 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/tests/test_multierror.py @@ -6,7 +6,7 @@ print_exception, format_exception, ) -from traceback import _cause_message # type: ignore +from traceback import _cause_message # type: ignore[attr-defined] import sys import os import re diff --git a/trio/_path.py b/trio/_path.py index 4077c449d7..a4162693db 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -1,14 +1,16 @@ -# type: ignore - from functools import wraps, partial import os import types import pathlib +from typing import Iterator, TypeVar, Union import trio from trio._util import async_wraps, Final +_P = TypeVar("_P", bound="Path") + + # re-wrap return value from methods that return new instances of pathlib.Path def rewrap_path(value): if isinstance(value, pathlib.Path): @@ -153,6 +155,20 @@ class Path(metaclass=AsyncAutoWrapperType): ] _wrap_iter = ["glob", "rglob", "iterdir"] + # TODO: fill out the rest. Just copy from typeshed? Maybe this design won't pan + # out cleaner than a stub .pyi in the long run. + # https://github.com/python/typeshed/blob/58032a701811093d7bd24f9f75ad5e5de07e7723/stdlib/3/pathlib.pyi#L17-L53 + + # NOTE: These are effectively type hints compensating for Mypy not being able to + # see through AsyncAutoWrapperType. They are inline here such that the rest + # of the file can be hinted regularly rather than in a separate stub .pyi. + + def joinpath(self: _P, *other: Union[str, os.PathLike[str]]) -> _P: + ... + + def iterdir(self: _P) -> Iterator[_P]: + ... + def __init__(self, *args): self._wrapped = pathlib.Path(*args) @@ -201,6 +217,6 @@ async def open(self, *args, **kwargs): # The value of Path.absolute.__doc__ makes a reference to # :meth:~pathlib.Path.absolute, which does not exist. Removing this makes more # sense than inventing our own special docstring for this. -del Path.absolute.__doc__ +del Path.absolute.__doc__ # type: ignore[attr-defined] os.PathLike.register(Path) diff --git a/trio/_path.pyi b/trio/_path.pyi deleted file mode 100644 index 85a8e1f960..0000000000 --- a/trio/_path.pyi +++ /dev/null @@ -1 +0,0 @@ -class Path: ... diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 91ba224546..486c37acb9 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -102,7 +102,8 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: # process. if process._wait_for_exit_data is None: - process._wait_for_exit_data = event = Event() # type: ignore + event = Event() + process._wait_for_exit_data = event # type: ignore[assignment] _core.spawn_system_task(_waitid_system_task, process.pid, event) assert isinstance(process._wait_for_exit_data, Event) await process._wait_for_exit_data.wait() diff --git a/trio/_util.py b/trio/_util.py index ec0350b305..17bdf1a673 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -65,6 +65,9 @@ def signal_raise(signum): signal.pthread_kill(threading.get_ident(), signum) +_T = t.TypeVar("_T") + + # See: #461 as to why this is needed. # The gist is that threading.main_thread() has the capability to lie to us # if somebody else edits the threading ident cache to replace the main @@ -241,7 +244,13 @@ def fix_one(qualname, name, obj): fix_one(objname, objname, obj) -class generic_function: +# TODO: This does not account for the generic parametrization via __getitem__. +# Presumably this will not work out in the long run even though it helped +# with trio._channel.open_memory_channel right now. +generic_function: t.Callable[[_T], _T] + + +class generic_function: # type: ignore[no-redef] """Decorator that makes a function indexable, to communicate non-inferrable generic type parameters to a static type checker. @@ -307,9 +316,6 @@ def __new__(cls, name, bases, cls_namespace): return super().__new__(cls, name, bases, cls_namespace) -T = t.TypeVar("T") - - class NoPublicConstructor(Final): """Metaclass that enforces a class to be final (i.e., subclass not allowed) and ensures a private constructor. @@ -334,8 +340,8 @@ def __call__(cls, *args, **kwargs): f"{cls.__module__}.{cls.__qualname__} has no public constructor" ) - def _create(cls: t.Type[T], *args: t.Any, **kwargs: t.Any) -> T: - return super().__call__(*args, **kwargs) # type: ignore + def _create(cls: t.Type[_T], *args: t.Any, **kwargs: t.Any) -> _T: + return super().__call__(*args, **kwargs) # type: ignore[no-any-return,misc] def name_asyncgen(agen): diff --git a/trio/socket.py b/trio/socket.py index 5402f5bc73..b4c9649502 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -22,7 +22,7 @@ # kept up to date. try: # fmt: off - from socket import ( # type: ignore + from socket import ( # type: ignore[attr-defined] CMSG_LEN, CMSG_SPACE, CAPI, AF_UNSPEC, AF_INET, AF_UNIX, AF_IPX, AF_APPLETALK, AF_INET6, AF_ROUTE, AF_LINK, AF_SNA, PF_SYSTEM, AF_SYSTEM, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, SOCK_SEQPACKET, SOCK_RDM, @@ -137,7 +137,7 @@ globals().update( { _name: getattr(_stdlib_socket, _name) - for _name in _stdlib_socket.__all__ # type: ignore + for _name in _stdlib_socket.__all__ # type: ignore[attr-defined] if _name.isupper() and _name not in _bad_symbols } ) diff --git a/trio/tests/module_with_deprecations.py b/trio/tests/module_with_deprecations.py index 73184d11e8..ed51f150c3 100644 --- a/trio/tests/module_with_deprecations.py +++ b/trio/tests/module_with_deprecations.py @@ -10,7 +10,7 @@ import sys this_mod = sys.modules[__name__] -assert this_mod.regular == "hi" +assert this_mod.regular == "hi" # type: ignore[attr-defined] assert not hasattr(this_mod, "dep1") __deprecated_attributes__ = { diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/tests/test_highlevel_serve_listeners.py index b028092eb9..f488b5e7ff 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/tests/test_highlevel_serve_listeners.py @@ -2,18 +2,24 @@ from functools import partial import errno +from typing import Tuple import attr import trio from trio.testing import memory_stream_pair, wait_all_tasks_blocked +from trio._channel import MemorySendChannel, MemoryReceiveChannel + + +def _open_memory_channel_1() -> Tuple[MemorySendChannel, MemoryReceiveChannel]: + return trio.open_memory_channel(1) @attr.s(hash=False, eq=False) class MemoryListener(trio.abc.Listener): closed = attr.ib(default=False) accepted_streams = attr.ib(factory=list) - queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1))) + queued_streams = attr.ib(factory=_open_memory_channel_1) accept_hook = attr.ib(default=None) async def connect(self): diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 7ba794a428..43c5bf5dd0 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -4,6 +4,8 @@ import sys import pytest import random +import signal +from typing import Optional from functools import partial from .. import ( @@ -20,6 +22,10 @@ from .._core.tests.tutil import slow, skip_if_fbsd_pipes_broken from ..testing import wait_all_tasks_blocked +SIGKILL: Optional[signal.Signals] +SIGTERM: Optional[signal.Signals] +SIGUSR1: Optional[signal.Signals] + posix = os.name == "posix" if posix: from signal import SIGKILL, SIGTERM, SIGUSR1 diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index 229dea301c..944adec2cf 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -1,3 +1,5 @@ +from typing import cast, Callable + import pytest import weakref @@ -241,7 +243,11 @@ async def test_Semaphore_bounded(): assert bs.value == 1 -@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) +def get__name__(fn: Callable) -> str: + return fn.__name__ + + +@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=get__name__) async def test_Lock_and_StrictFIFOLock(lockcls): l = lockcls() # noqa assert not l.locked() diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index 55dd4e3734..da5e696bf6 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -3,6 +3,7 @@ import os import tempfile import sys +from typing import Tuple import pytest @@ -20,7 +21,7 @@ # Have to use quoted types so import doesn't crash on windows -async def make_pipe() -> "Tuple[FdStream, FdStream]": +async def make_pipe() -> Tuple["FdStream", "FdStream"]: """Makes a new pair of pipes.""" (r, w) = os.pipe() return FdStream(w), FdStream(r) diff --git a/trio/tests/test_windows_pipes.py b/trio/tests/test_windows_pipes.py index 361cd64ce2..919ca75427 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/tests/test_windows_pipes.py @@ -3,6 +3,7 @@ import os import sys +from typing import Any, Tuple import pytest from .._core.tests.tutil import gc_collect_harder @@ -15,12 +16,12 @@ from asyncio.windows_utils import pipe else: pytestmark = pytest.mark.skip(reason="windows only") - pipe = None # type: Any - PipeSendStream = None # type: Any - PipeReceiveStream = None # type: Any + pipe: Any = None + PipeSendStream: Any = None + PipeReceiveStream: Any = None -async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]": +async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]: """Makes a new pair of pipes.""" (r, w) = pipe() return PipeSendStream(w), PipeReceiveStream(r) From 9f94c1e022af7f9e9885669c93d5e877798d09dd Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 14:23:21 -0500 Subject: [PATCH 04/50] disallow_incomplete_defs = True --- mypy.ini | 2 +- trio/_channel.py | 5 +- trio/_core/_io_epoll.py | 12 +- trio/_core/_run.py | 75 ++-- trio/_subprocess.py | 874 ++++++++++++++++++++++--------------- trio/_unix_pipes.py | 5 +- trio/testing/_sequencer.py | 6 +- 7 files changed, 596 insertions(+), 383 deletions(-) diff --git a/mypy.ini b/mypy.ini index 31eeef1cd0..d308880c33 100644 --- a/mypy.ini +++ b/mypy.ini @@ -13,7 +13,7 @@ warn_return_any = True # Avoid subtle backsliding #disallow_any_decorated = True -#disallow_incomplete_defs = True +disallow_incomplete_defs = True #disallow_subclassing_any = True # Enable gradually / for new modules diff --git a/trio/_channel.py b/trio/_channel.py index e2a184950b..b4f09e01ed 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,6 +1,6 @@ from collections import deque, OrderedDict from math import inf -from typing import cast, Callable, Tuple, TypeVar +from typing import cast, Callable, Tuple, TypeVar, Union import attr from outcome import Error, Value @@ -14,7 +14,8 @@ @generic_function def open_memory_channel( - max_buffer_size, + # TODO: should restrict the float bit to just the inf value + max_buffer_size: Union[int, float], ) -> Tuple["MemorySendChannel", "MemoryReceiveChannel"]: """Open a channel for passing objects between tasks within a process. diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index c1537cf53e..802f425409 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -2,7 +2,7 @@ import sys import attr from collections import defaultdict -from typing import Dict, TYPE_CHECKING +from typing import DefaultDict, Dict, TYPE_CHECKING from .. import _core from ._run import _public @@ -186,13 +186,13 @@ class EpollWaiters: @attr.s(slots=True, eq=False, hash=False) class EpollIOManager: - _epoll = attr.ib(factory=select.epoll) + _epoll: select.epoll = attr.ib(factory=select.epoll) # {fd: EpollWaiters} - _registered = attr.ib( - factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters] + _registered: DefaultDict[int, EpollWaiters] = attr.ib( + factory=lambda: defaultdict(EpollWaiters) ) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: int = attr.ib(default=None) def __attrs_post_init__(self): self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 3f62f68f66..3865f4f755 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1,5 +1,6 @@ # coding: utf-8 +from contextvars import Context import functools import itertools import logging @@ -18,7 +19,18 @@ from contextvars import copy_context from math import inf from time import perf_counter -from typing import Callable, TYPE_CHECKING +from typing import ( + Callable, + Deque, + Dict, + Generator, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) from sniffio import current_async_library_cvar @@ -44,9 +56,12 @@ WaitTaskRescheduled, ) from ._asyncgens import AsyncGenerators +from ._entry_queue import TrioToken from ._thread_cache import start_thread_soon from ._instrumentation import Instruments +from ._local import RunVar from .. import _core +from ..abc import Clock from .._deprecate import warn_deprecated from .._util import Final, NoPublicConstructor, coroutine_or_error @@ -1229,13 +1244,15 @@ class _RunStatistics: # worker thread. @attr.s(eq=False, hash=False, slots=True) class GuestState: - runner = attr.ib() - run_sync_soon_threadsafe = attr.ib() - run_sync_soon_not_threadsafe = attr.ib() - done_callback = attr.ib() - unrolled_run_gen = attr.ib() + runner: "Runner" = attr.ib() + run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] = attr.ib() + run_sync_soon_not_threadsafe: Optional[ + Callable[[Callable[[], object]], object] + ] = attr.ib() + done_callback: Callable[[Outcome], object] = attr.ib() + unrolled_run_gen: Generator[int, List[Tuple[int, int]], None] = attr.ib() _value_factory: Callable[[], Value] = lambda: Value(None) - unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) + unrolled_run_next_send: Outcome = attr.ib(factory=_value_factory) def guest_tick(self): try: @@ -1274,35 +1291,38 @@ def in_main_thread(): @attr.s(eq=False, hash=False, slots=True) class Runner: - clock = attr.ib() + clock: Clock = attr.ib() instruments: Instruments = attr.ib() - io_manager = attr.ib() - ki_manager = attr.ib() + # TODO: It seems that down at the bottom kqueue is the IO manager chosen for for + # type checking. Seems like there ought to be a protocol or union here. + # io_manager: Union["KqueueIOManager", "EpollIOManager", "WindowsIOManager"] = attr.ib() + io_manager: "TheIOManager" = attr.ib() + ki_manager: KIManager = attr.ib() # Run-local values, see _local.py - _locals = attr.ib(factory=dict) + _locals: Dict[RunVar, object] = attr.ib(factory=dict) - runq = attr.ib(factory=deque) - tasks = attr.ib(factory=set) + runq: Deque[Task] = attr.ib(factory=deque) + tasks: Set[Task] = attr.ib(factory=set) - deadlines = attr.ib(factory=Deadlines) + deadlines: Deadlines = attr.ib(factory=Deadlines) - init_task = attr.ib(default=None) - system_nursery = attr.ib(default=None) - system_context = attr.ib(default=None) - main_task = attr.ib(default=None) - main_task_outcome = attr.ib(default=None) + init_task: Optional[Task] = attr.ib(default=None) + system_nursery: Optional[Nursery] = attr.ib(default=None) + system_context: Optional[Context] = attr.ib(default=None) + main_task: Optional[Task] = attr.ib(default=None) + main_task_outcome: Optional[Outcome] = attr.ib(default=None) - entry_queue = attr.ib(factory=EntryQueue) - trio_token = attr.ib(default=None) - asyncgens = attr.ib(factory=AsyncGenerators) + entry_queue: EntryQueue = attr.ib(factory=EntryQueue) + trio_token: TrioToken = attr.ib(default=None) + asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators) # If everything goes idle for this long, we call clock._autojump() - clock_autojump_threshold = attr.ib(default=inf) + clock_autojump_threshold: float = attr.ib(default=inf) # Guest mode stuff - is_guest = attr.ib(default=False) - guest_tick_scheduled = attr.ib(default=False) + is_guest: bool = attr.ib(default=False) + guest_tick_scheduled: bool = attr.ib(default=False) def force_guest_tick_asap(self): if self.guest_tick_scheduled: @@ -1631,7 +1651,7 @@ def current_trio_token(self): # KI handling ################ - ki_pending = attr.ib(default=False) + ki_pending: bool = attr.ib(default=False) # deliver_ki is broke. Maybe move all the actual logic and state into # RunToken, and we'll only have one instance per runner? But then we can't @@ -1664,7 +1684,8 @@ def _deliver_ki_cb(self): # Quiescing ################ - waiting_for_idle = attr.ib(factory=SortedDict) + # TODO: how to hint a SortedDict with it's content type as well? + waiting_for_idle: SortedDict = attr.ib(factory=SortedDict) @_public async def wait_all_tasks_blocked(self, cushion=0.0): diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 876cc0d7c9..48e242378d 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -6,7 +6,20 @@ from typing import Optional from functools import partial import warnings -from typing import TYPE_CHECKING +from typing import ( + Any, + Awaitable, + Callable, + Literal, + Mapping, + Optional, + overload, + Union, + Sequence, + TYPE_CHECKING, +) + +from typing_extensions import Protocol from ._abc import AsyncResource, SendStream, ReceiveStream from ._highlevel_generic import StapledStream @@ -276,119 +289,532 @@ def kill(self): self._proc.kill() -async def open_process( - command, *, stdin=None, stdout=None, stderr=None, **options -) -> Process: - r"""Execute a child program in a new process. - - After construction, you can interact with the child process by writing - data to its `~Process.stdin` stream (a `~trio.abc.SendStream`), reading - data from its `~Process.stdout` and/or `~Process.stderr` streams (both - `~trio.abc.ReceiveStream`\s), sending it signals using - `~Process.terminate`, `~Process.kill`, or `~Process.send_signal`, and - waiting for it to exit using `~Process.wait`. See `Process` for details. - - Each standard stream is only available if you specify that a pipe should - be created for it. For example, if you pass ``stdin=subprocess.PIPE``, you - can write to the `~Process.stdin` stream, else `~Process.stdin` will be - ``None``. - - Args: - command (list or str): The command to run. Typically this is a - sequence of strings such as ``['ls', '-l', 'directory with spaces']``, - where the first element names the executable to invoke and the other - elements specify its arguments. With ``shell=True`` in the - ``**options``, or on Windows, ``command`` may alternatively - be a string, which will be parsed following platform-dependent - :ref:`quoting rules `. - stdin: Specifies what the child process's standard input - stream should connect to: output written by the parent - (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), - or an open file (pass a file descriptor or something whose - ``fileno`` method returns one). If ``stdin`` is unspecified, - the child process will have the same standard input stream - as its parent. - stdout: Like ``stdin``, but for the child process's standard output - stream. - stderr: Like ``stdin``, but for the child process's standard error - stream. An additional value ``subprocess.STDOUT`` is supported, - which causes the child's standard output and standard error - messages to be intermixed on a single standard output stream, - attached to whatever the ``stdout`` option says to attach it to. - **options: Other :ref:`general subprocess options ` - are also accepted. - - Returns: - A new `Process` object. - - Raises: - OSError: if the process spawning fails, for example because the - specified command could not be found. +class _HasFileno(Protocol): + def fileno(self) -> int: + ... - """ - for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"): - if options.get(key): - raise TypeError( - "trio.Process only supports communicating over " - "unbuffered byte streams; the '{}' option is not supported".format(key) - ) - if os.name == "posix": - if isinstance(command, str) and not options.get("shell"): - raise TypeError( - "command must be a sequence (not a string) if shell=False " - "on UNIX systems" - ) - if not isinstance(command, str) and options.get("shell"): - raise TypeError( - "command must be a string (not a sequence) if shell=True " - "on UNIX systems" +_Redirect = Union[int, _HasFileno, None] + +# There's a lot of duplication here because mypy doesn't +# have a good way to represent overloads that differ only +# slightly. A cheat sheet: +# - on Windows, command is Union[str, Sequence[str]]; +# on Unix, command is str if shell=True and Sequence[str] otherwise +# - on Windows, there are startupinfo and creationflags options; +# on Unix, there are preexec_fn, restore_signals, start_new_session, and pass_fds +# - run_process() has the signature of open_process() plus arguments +# capture_stdout, capture_stderr, check, deliver_cancel, and the ability to pass +# bytes as stdin + +if TYPE_CHECKING: + if sys.platform == "win32": + + async def open_process( + command: Union[str, Sequence[str]], + *, + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + startupinfo: subprocess.STARTUPINFO = ..., + creationflags: int = ..., + ) -> Process: + ... + + async def run_process( + command: Union[str, Sequence[str]], + *, + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + startupinfo: subprocess.STARTUPINFO = ..., + creationflags: int = ..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + else: + + @overload + async def open_process( + command: str, + *, + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Literal[True], + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> Process: + ... + + @overload + async def open_process( + command: Sequence[str], + *, + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> Process: + ... + + async def open_process( # type: ignore[no-untyped-def] + command, + *, + stdin=..., + stdout=..., + stderr=..., + close_fds=..., + shell=..., + cwd=..., + env=..., + preexec_fn=..., + restore_signals=..., + start_new_session=..., + pass_fds=..., + ) -> Process: + ... + + @overload + async def run_process( + command: str, + *, + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Literal[True], + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + @overload + async def run_process( + command: Sequence[str], + *, + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + async def run_process( # type: ignore[no-untyped-def] + command, + *, + stdin=..., + capture_stdout=..., + capture_stderr=..., + check=..., + deliver_cancel=..., + stdout=..., + stderr=..., + close_fds=..., + shell=..., + cwd=..., + env=..., + preexec_fn=..., + restore_signals=..., + start_new_session=..., + pass_fds=..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + +else: + + async def open_process(command, *, stdin=None, stdout=None, stderr=None, **options): + r"""Execute a child program in a new process. + + After construction, you can interact with the child process by writing + data to its `~Process.stdin` stream (a `~trio.abc.SendStream`), reading + data from its `~Process.stdout` and/or `~Process.stderr` streams (both + `~trio.abc.ReceiveStream`\s), sending it signals using + `~Process.terminate`, `~Process.kill`, or `~Process.send_signal`, and + waiting for it to exit using `~Process.wait`. See `Process` for details. + + Each standard stream is only available if you specify that a pipe should + be created for it. For example, if you pass ``stdin=subprocess.PIPE``, you + can write to the `~Process.stdin` stream, else `~Process.stdin` will be + ``None``. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules `. + stdin: Specifies what the child process's standard input + stream should connect to: output written by the parent + (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), + or an open file (pass a file descriptor or something whose + ``fileno`` method returns one). If ``stdin`` is unspecified, + the child process will have the same standard input stream + as its parent. + stdout: Like ``stdin``, but for the child process's standard output + stream. + stderr: Like ``stdin``, but for the child process's standard error + stream. An additional value ``subprocess.STDOUT`` is supported, + which causes the child's standard output and standard error + messages to be intermixed on a single standard output stream, + attached to whatever the ``stdout`` option says to attach it to. + **options: Other :ref:`general subprocess options ` + are also accepted. + + Returns: + A new `Process` object. + + Raises: + OSError: if the process spawning fails, for example because the + specified command could not be found. + + """ + for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"): + if options.get(key): + raise TypeError( + "trio.Process only supports communicating over " + "unbuffered byte streams; the '{}' option is not supported".format( + key + ) + ) + + if os.name == "posix": + if isinstance(command, str) and not options.get("shell"): + raise TypeError( + "command must be a sequence (not a string) if shell=False " + "on UNIX systems" + ) + if not isinstance(command, str) and options.get("shell"): + raise TypeError( + "command must be a string (not a sequence) if shell=True " + "on UNIX systems" + ) + + trio_stdin = None # type: Optional[SendStream] + trio_stdout = None # type: Optional[ReceiveStream] + trio_stderr = None # type: Optional[ReceiveStream] + + if stdin == subprocess.PIPE: + trio_stdin, stdin = create_pipe_to_child_stdin() + if stdout == subprocess.PIPE: + trio_stdout, stdout = create_pipe_from_child_output() + if stderr == subprocess.STDOUT: + # If we created a pipe for stdout, pass the same pipe for + # stderr. If stdout was some non-pipe thing (DEVNULL or a + # given FD), pass the same thing. If stdout was passed as + # None, keep stderr as STDOUT to allow subprocess to dup + # our stdout. Regardless of which of these is applicable, + # don't create a new Trio stream for stderr -- if stdout + # is piped, stderr will be intermixed on the stdout stream. + if stdout is not None: + stderr = stdout + elif stderr == subprocess.PIPE: + trio_stderr, stderr = create_pipe_from_child_output() + + try: + popen = await trio.to_thread.run_sync( + partial( + subprocess.Popen, + command, + stdin=stdin, + stdout=stdout, + stderr=stderr, + **options, + ) ) + finally: + # Close the parent's handle for each child side of a pipe; + # we want the child to have the only copy, so that when + # it exits we can read EOF on our side. + if trio_stdin is not None: + os.close(stdin) + if trio_stdout is not None: + os.close(stdout) + if trio_stderr is not None: + os.close(stderr) + + return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) + + async def run_process( + command, + *, + stdin=b"", + capture_stdout=False, + capture_stderr=False, + check=True, + deliver_cancel=None, + **options, + ): + """Run ``command`` in a subprocess, wait for it to complete, and + return a :class:`subprocess.CompletedProcess` instance describing + the results. + + If cancelled, :func:`run_process` terminates the subprocess and + waits for it to exit before propagating the cancellation, like + :meth:`Process.aclose`. + + **Input:** The subprocess's standard input stream is set up to + receive the bytes provided as ``stdin``. Once the given input has + been fully delivered, or if none is provided, the subprocess will + receive end-of-file when reading from its standard input. + Alternatively, if you want the subprocess to read its + standard input from the same place as the parent Trio process, you + can pass ``stdin=None``. + + **Output:** By default, any output produced by the subprocess is + passed through to the standard output and error streams of the + parent Trio process. If you would like to capture this output and + do something with it, you can pass ``capture_stdout=True`` to + capture the subprocess's standard output, and/or + ``capture_stderr=True`` to capture its standard error. Captured + data is provided as the + :attr:`~subprocess.CompletedProcess.stdout` and/or + :attr:`~subprocess.CompletedProcess.stderr` attributes of the + returned :class:`~subprocess.CompletedProcess` object. The value + for any stream that was not captured will be ``None``. + + If you want to capture both stdout and stderr while keeping them + separate, pass ``capture_stdout=True, capture_stderr=True``. + + If you want to capture both stdout and stderr but mixed together + in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. + This directs the child's stderr into its stdout, so the combined + output will be available in the `~subprocess.CompletedProcess.stdout` + attribute. + + **Error checking:** If the subprocess exits with a nonzero status + code, indicating failure, :func:`run_process` raises a + :exc:`subprocess.CalledProcessError` exception rather than + returning normally. The captured outputs are still available as + the :attr:`~subprocess.CalledProcessError.stdout` and + :attr:`~subprocess.CalledProcessError.stderr` attributes of that + exception. To disable this behavior, so that :func:`run_process` + returns normally even if the subprocess exits abnormally, pass + ``check=False``. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules `. + + stdin (:obj:`bytes`, file descriptor, or None): The bytes to provide to + the subprocess on its standard input stream, or ``None`` if the + subprocess's standard input should come from the same place as + the parent Trio process's standard input. As is the case with + the :mod:`subprocess` module, you can also pass a + file descriptor or an object with a ``fileno()`` method, + in which case the subprocess's standard input will come from + that file. + + capture_stdout (bool): If true, capture the bytes that the subprocess + writes to its standard output stream and return them in the + :attr:`~subprocess.CompletedProcess.stdout` attribute + of the returned :class:`~subprocess.CompletedProcess` object. + + capture_stderr (bool): If true, capture the bytes that the subprocess + writes to its standard error stream and return them in the + :attr:`~subprocess.CompletedProcess.stderr` attribute + of the returned :class:`~subprocess.CompletedProcess` object. + + check (bool): If false, don't validate that the subprocess exits + successfully. You should be sure to check the + ``returncode`` attribute of the returned object if you pass + ``check=False``, so that errors don't pass silently. + + deliver_cancel (async function or None): If `run_process` is cancelled, + then it needs to kill the child process. There are multiple ways to + do this, so we let you customize it. + + If you pass None (the default), then the behavior depends on the + platform: + + - On Windows, Trio calls ``TerminateProcess``, which should kill the + process immediately. + + - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait + 5 seconds, and send a ``SIGKILL``. + + Alternatively, you can customize this behavior by passing in an + arbitrary async function, which will be called with the `Process` + object as an argument. For example, the default Unix behavior could + be implemented like this:: + + async def my_deliver_cancel(process): + process.send_signal(signal.SIGTERM) + await trio.sleep(5) + process.send_signal(signal.SIGKILL) + + When the process actually exits, the ``deliver_cancel`` function + will automatically be cancelled – so if the process exits after + ``SIGTERM``, then we'll never reach the ``SIGKILL``. + + In any case, `run_process` will always wait for the child process to + exit before raising `Cancelled`. + + **options: :func:`run_process` also accepts any :ref:`general subprocess + options ` and passes them on to the + :class:`~trio.Process` constructor. This includes the + ``stdout`` and ``stderr`` options, which provide additional + redirection possibilities such as ``stderr=subprocess.STDOUT``, + ``stdout=subprocess.DEVNULL``, or file descriptors. - trio_stdin = None # type: Optional[SendStream] - trio_stdout = None # type: Optional[ReceiveStream] - trio_stderr = None # type: Optional[ReceiveStream] - - if stdin == subprocess.PIPE: - trio_stdin, stdin = create_pipe_to_child_stdin() - if stdout == subprocess.PIPE: - trio_stdout, stdout = create_pipe_from_child_output() - if stderr == subprocess.STDOUT: - # If we created a pipe for stdout, pass the same pipe for - # stderr. If stdout was some non-pipe thing (DEVNULL or a - # given FD), pass the same thing. If stdout was passed as - # None, keep stderr as STDOUT to allow subprocess to dup - # our stdout. Regardless of which of these is applicable, - # don't create a new Trio stream for stderr -- if stdout - # is piped, stderr will be intermixed on the stdout stream. - if stdout is not None: - stderr = stdout - elif stderr == subprocess.PIPE: - trio_stderr, stderr = create_pipe_from_child_output() + Returns: + A :class:`subprocess.CompletedProcess` instance describing the + return code and outputs. + + Raises: + UnicodeError: if ``stdin`` is specified as a Unicode string, rather + than bytes + ValueError: if multiple redirections are specified for the same + stream, e.g., both ``capture_stdout=True`` and + ``stdout=subprocess.DEVNULL`` + subprocess.CalledProcessError: if ``check=False`` is not passed + and the process exits with a nonzero exit status + OSError: if an error is encountered starting or communicating with + the process + + .. note:: The child process runs in the same process group as the parent + Trio process, so a Ctrl+C will be delivered simultaneously to both + parent and child. If you don't want this behavior, consult your + platform's documentation for starting child processes in a different + process group. - try: - popen = await trio.to_thread.run_sync( - partial( - subprocess.Popen, - command, - stdin=stdin, - stdout=stdout, - stderr=stderr, - **options, + """ + + if isinstance(stdin, str): + raise UnicodeError("process stdin must be bytes, not str") + if stdin == subprocess.PIPE: + raise ValueError( + "stdin=subprocess.PIPE doesn't make sense since the pipe " + "is internal to run_process(); pass the actual data you " + "want to send over that pipe instead" ) - ) - finally: - # Close the parent's handle for each child side of a pipe; - # we want the child to have the only copy, so that when - # it exits we can read EOF on our side. - if trio_stdin is not None: - os.close(stdin) - if trio_stdout is not None: - os.close(stdout) - if trio_stderr is not None: - os.close(stderr) + if isinstance(stdin, (bytes, bytearray, memoryview)): + input = stdin + options["stdin"] = subprocess.PIPE + else: + # stdin should be something acceptable to Process + # (None, DEVNULL, a file descriptor, etc) and Process + # will raise if it's not + input = None + options["stdin"] = stdin + + if capture_stdout: + if "stdout" in options: + raise ValueError("can't specify both stdout and capture_stdout") + options["stdout"] = subprocess.PIPE + if capture_stderr: + if "stderr" in options: + raise ValueError("can't specify both stderr and capture_stderr") + options["stderr"] = subprocess.PIPE + + if deliver_cancel is None: + if os.name == "nt": + deliver_cancel = _windows_deliver_cancel + else: + assert os.name == "posix" + deliver_cancel = _posix_deliver_cancel + + stdout_chunks = [] + stderr_chunks = [] + + async with await open_process(command, **options) as proc: + + async def feed_input(): + async with proc.stdin: + try: + await proc.stdin.send_all(input) + except trio.BrokenResourceError: + pass + + async def read_output(stream, chunks): + async with stream: + async for chunk in stream: + chunks.append(chunk) + + async with trio.open_nursery() as nursery: + if proc.stdin is not None: + nursery.start_soon(feed_input) + if proc.stdout is not None: + nursery.start_soon(read_output, proc.stdout, stdout_chunks) + if proc.stderr is not None: + nursery.start_soon(read_output, proc.stderr, stderr_chunks) + try: + await proc.wait() + except trio.Cancelled: + with trio.CancelScope(shield=True): + killer_cscope = trio.CancelScope(shield=True) + + async def killer(): + with killer_cscope: + await deliver_cancel(proc) - return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) + nursery.start_soon(killer) + await proc.wait() + killer_cscope.cancel() + raise + + stdout = b"".join(stdout_chunks) if proc.stdout is not None else None + stderr = b"".join(stderr_chunks) if proc.stderr is not None else None + + if proc.returncode and check: + raise subprocess.CalledProcessError( + proc.returncode, proc.args, output=stdout, stderr=stderr + ) + else: + return subprocess.CompletedProcess( + proc.args, proc.returncode, stdout, stderr + ) async def _windows_deliver_cancel(p): @@ -414,237 +840,3 @@ async def _posix_deliver_cancel(p): warnings.warn( RuntimeWarning(f"tried to kill process {p!r}, but failed with: {exc!r}") ) - - -async def run_process( - command, - *, - stdin=b"", - capture_stdout=False, - capture_stderr=False, - check=True, - deliver_cancel=None, - **options, -): - """Run ``command`` in a subprocess, wait for it to complete, and - return a :class:`subprocess.CompletedProcess` instance describing - the results. - - If cancelled, :func:`run_process` terminates the subprocess and - waits for it to exit before propagating the cancellation, like - :meth:`Process.aclose`. - - **Input:** The subprocess's standard input stream is set up to - receive the bytes provided as ``stdin``. Once the given input has - been fully delivered, or if none is provided, the subprocess will - receive end-of-file when reading from its standard input. - Alternatively, if you want the subprocess to read its - standard input from the same place as the parent Trio process, you - can pass ``stdin=None``. - - **Output:** By default, any output produced by the subprocess is - passed through to the standard output and error streams of the - parent Trio process. If you would like to capture this output and - do something with it, you can pass ``capture_stdout=True`` to - capture the subprocess's standard output, and/or - ``capture_stderr=True`` to capture its standard error. Captured - data is provided as the - :attr:`~subprocess.CompletedProcess.stdout` and/or - :attr:`~subprocess.CompletedProcess.stderr` attributes of the - returned :class:`~subprocess.CompletedProcess` object. The value - for any stream that was not captured will be ``None``. - - If you want to capture both stdout and stderr while keeping them - separate, pass ``capture_stdout=True, capture_stderr=True``. - - If you want to capture both stdout and stderr but mixed together - in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. - This directs the child's stderr into its stdout, so the combined - output will be available in the `~subprocess.CompletedProcess.stdout` - attribute. - - **Error checking:** If the subprocess exits with a nonzero status - code, indicating failure, :func:`run_process` raises a - :exc:`subprocess.CalledProcessError` exception rather than - returning normally. The captured outputs are still available as - the :attr:`~subprocess.CalledProcessError.stdout` and - :attr:`~subprocess.CalledProcessError.stderr` attributes of that - exception. To disable this behavior, so that :func:`run_process` - returns normally even if the subprocess exits abnormally, pass - ``check=False``. - - Args: - command (list or str): The command to run. Typically this is a - sequence of strings such as ``['ls', '-l', 'directory with spaces']``, - where the first element names the executable to invoke and the other - elements specify its arguments. With ``shell=True`` in the - ``**options``, or on Windows, ``command`` may alternatively - be a string, which will be parsed following platform-dependent - :ref:`quoting rules `. - - stdin (:obj:`bytes`, file descriptor, or None): The bytes to provide to - the subprocess on its standard input stream, or ``None`` if the - subprocess's standard input should come from the same place as - the parent Trio process's standard input. As is the case with - the :mod:`subprocess` module, you can also pass a - file descriptor or an object with a ``fileno()`` method, - in which case the subprocess's standard input will come from - that file. - - capture_stdout (bool): If true, capture the bytes that the subprocess - writes to its standard output stream and return them in the - :attr:`~subprocess.CompletedProcess.stdout` attribute - of the returned :class:`~subprocess.CompletedProcess` object. - - capture_stderr (bool): If true, capture the bytes that the subprocess - writes to its standard error stream and return them in the - :attr:`~subprocess.CompletedProcess.stderr` attribute - of the returned :class:`~subprocess.CompletedProcess` object. - - check (bool): If false, don't validate that the subprocess exits - successfully. You should be sure to check the - ``returncode`` attribute of the returned object if you pass - ``check=False``, so that errors don't pass silently. - - deliver_cancel (async function or None): If `run_process` is cancelled, - then it needs to kill the child process. There are multiple ways to - do this, so we let you customize it. - - If you pass None (the default), then the behavior depends on the - platform: - - - On Windows, Trio calls ``TerminateProcess``, which should kill the - process immediately. - - - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait - 5 seconds, and send a ``SIGKILL``. - - Alternatively, you can customize this behavior by passing in an - arbitrary async function, which will be called with the `Process` - object as an argument. For example, the default Unix behavior could - be implemented like this:: - - async def my_deliver_cancel(process): - process.send_signal(signal.SIGTERM) - await trio.sleep(5) - process.send_signal(signal.SIGKILL) - - When the process actually exits, the ``deliver_cancel`` function - will automatically be cancelled – so if the process exits after - ``SIGTERM``, then we'll never reach the ``SIGKILL``. - - In any case, `run_process` will always wait for the child process to - exit before raising `Cancelled`. - - **options: :func:`run_process` also accepts any :ref:`general subprocess - options ` and passes them on to the - :class:`~trio.Process` constructor. This includes the - ``stdout`` and ``stderr`` options, which provide additional - redirection possibilities such as ``stderr=subprocess.STDOUT``, - ``stdout=subprocess.DEVNULL``, or file descriptors. - - Returns: - A :class:`subprocess.CompletedProcess` instance describing the - return code and outputs. - - Raises: - UnicodeError: if ``stdin`` is specified as a Unicode string, rather - than bytes - ValueError: if multiple redirections are specified for the same - stream, e.g., both ``capture_stdout=True`` and - ``stdout=subprocess.DEVNULL`` - subprocess.CalledProcessError: if ``check=False`` is not passed - and the process exits with a nonzero exit status - OSError: if an error is encountered starting or communicating with - the process - - .. note:: The child process runs in the same process group as the parent - Trio process, so a Ctrl+C will be delivered simultaneously to both - parent and child. If you don't want this behavior, consult your - platform's documentation for starting child processes in a different - process group. - - """ - - if isinstance(stdin, str): - raise UnicodeError("process stdin must be bytes, not str") - if stdin == subprocess.PIPE: - raise ValueError( - "stdin=subprocess.PIPE doesn't make sense since the pipe " - "is internal to run_process(); pass the actual data you " - "want to send over that pipe instead" - ) - if isinstance(stdin, (bytes, bytearray, memoryview)): - input = stdin - options["stdin"] = subprocess.PIPE - else: - # stdin should be something acceptable to Process - # (None, DEVNULL, a file descriptor, etc) and Process - # will raise if it's not - input = None - options["stdin"] = stdin - - if capture_stdout: - if "stdout" in options: - raise ValueError("can't specify both stdout and capture_stdout") - options["stdout"] = subprocess.PIPE - if capture_stderr: - if "stderr" in options: - raise ValueError("can't specify both stderr and capture_stderr") - options["stderr"] = subprocess.PIPE - - if deliver_cancel is None: - if os.name == "nt": - deliver_cancel = _windows_deliver_cancel - else: - assert os.name == "posix" - deliver_cancel = _posix_deliver_cancel - - stdout_chunks = [] - stderr_chunks = [] - - async with await open_process(command, **options) as proc: - - async def feed_input(): - async with proc.stdin: - try: - await proc.stdin.send_all(input) - except trio.BrokenResourceError: - pass - - async def read_output(stream, chunks): - async with stream: - async for chunk in stream: - chunks.append(chunk) - - async with trio.open_nursery() as nursery: - if proc.stdin is not None: - nursery.start_soon(feed_input) - if proc.stdout is not None: - nursery.start_soon(read_output, proc.stdout, stdout_chunks) - if proc.stderr is not None: - nursery.start_soon(read_output, proc.stderr, stderr_chunks) - try: - await proc.wait() - except trio.Cancelled: - with trio.CancelScope(shield=True): - killer_cscope = trio.CancelScope(shield=True) - - async def killer(): - with killer_cscope: - await deliver_cancel(proc) - - nursery.start_soon(killer) - await proc.wait() - killer_cscope.cancel() - raise - - stdout = b"".join(stdout_chunks) if proc.stdout is not None else None - stderr = b"".join(stderr_chunks) if proc.stderr is not None else None - - if proc.returncode and check: - raise subprocess.CalledProcessError( - proc.returncode, proc.args, output=stdout, stderr=stderr - ) - else: - return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 0d2d11c53c..6645812f58 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -1,5 +1,6 @@ import os import errno +from typing import Optional from ._abc import Stream from ._util import ConflictDetector, Final @@ -116,7 +117,7 @@ def __init__(self, fd: int): "another task is using this stream for receive" ) - async def send_all(self, data: bytes): + async def send_all(self, data: bytes) -> None: with self._send_conflict_detector: # have to check up front, because send_all(b"") on a closed pipe # should raise @@ -152,7 +153,7 @@ async def wait_send_all_might_not_block(self) -> None: # of sending, which is annoying raise trio.BrokenResourceError from e - async def receive_some(self, max_bytes=None) -> bytes: + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: with self._receive_conflict_detector: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE diff --git a/trio/testing/_sequencer.py b/trio/testing/_sequencer.py index a7e6e50ff0..0922e1c9ad 100644 --- a/trio/testing/_sequencer.py +++ b/trio/testing/_sequencer.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import AsyncIterator, DefaultDict, Set import attr from async_generator import asynccontextmanager @@ -7,9 +8,6 @@ from .. import _util from .. import Event -if False: - from typing import DefaultDict, Set - @attr.s(eq=False, hash=False) class Sequencer(metaclass=_util.Final): @@ -59,7 +57,7 @@ async def main(): _broken = attr.ib(default=False, init=False) @asynccontextmanager - async def __call__(self, position: int): + async def __call__(self, position: int) -> AsyncIterator[None]: if position in self._claimed: raise RuntimeError("Attempted to re-use sequence point {}".format(position)) if self._broken: From c5d01709cbc4acce0eb40609ccf17f644de02667 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 14:41:40 -0500 Subject: [PATCH 05/50] just os.PathLike without [str] for now --- trio/_path.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/trio/_path.py b/trio/_path.py index a4162693db..53a52ddf15 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -2,7 +2,7 @@ import os import types import pathlib -from typing import Iterator, TypeVar, Union +from typing import Iterator, TYPE_CHECKING, TypeVar, Union import trio from trio._util import async_wraps, Final @@ -155,19 +155,21 @@ class Path(metaclass=AsyncAutoWrapperType): ] _wrap_iter = ["glob", "rglob", "iterdir"] - # TODO: fill out the rest. Just copy from typeshed? Maybe this design won't pan - # out cleaner than a stub .pyi in the long run. - # https://github.com/python/typeshed/blob/58032a701811093d7bd24f9f75ad5e5de07e7723/stdlib/3/pathlib.pyi#L17-L53 + if TYPE_CHECKING: + # TODO: fill out the rest. Just copy from typeshed? Maybe this design won't pan + # out cleaner than a stub .pyi in the long run. + # https://github.com/python/typeshed/blob/58032a701811093d7bd24f9f75ad5e5de07e7723/stdlib/3/pathlib.pyi#L17-L53 - # NOTE: These are effectively type hints compensating for Mypy not being able to - # see through AsyncAutoWrapperType. They are inline here such that the rest - # of the file can be hinted regularly rather than in a separate stub .pyi. + # NOTE: These are effectively type hints compensating for Mypy not being able to + # see through AsyncAutoWrapperType. They are inline here such that the rest + # of the file can be hinted regularly rather than in a separate stub .pyi. - def joinpath(self: _P, *other: Union[str, os.PathLike[str]]) -> _P: - ... + # TODO: Can we handle os.PathLike[str] at least for 3.9+? + def joinpath(self: _P, *other: Union[str, os.PathLike]) -> _P: + ... - def iterdir(self: _P) -> Iterator[_P]: - ... + def iterdir(self: _P) -> Iterator[_P]: + ... def __init__(self, *args): self._wrapped = pathlib.Path(*args) From 874eea17495368a289ed6af93ab8d1d3c381737d Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 14:46:31 -0500 Subject: [PATCH 06/50] get Literal from typing_extensions --- trio/_subprocess.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 48e242378d..f3b2a7109c 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -10,7 +10,6 @@ Any, Awaitable, Callable, - Literal, Mapping, Optional, overload, @@ -19,7 +18,7 @@ TYPE_CHECKING, ) -from typing_extensions import Protocol +from typing_extensions import Literal, Protocol from ._abc import AsyncResource, SendStream, ReceiveStream from ._highlevel_generic import StapledStream From 778b4257815c242e3ecbbc8178b5f1b1dcd8b73f Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 14:55:56 -0500 Subject: [PATCH 07/50] cleanup --- trio/_subprocess.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trio/_subprocess.py b/trio/_subprocess.py index f3b2a7109c..8cb8755699 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -3,7 +3,6 @@ import os import subprocess import sys -from typing import Optional from functools import partial import warnings from typing import ( From b0a5e9ea8549b7ef141d2bd761750bc5c7fcfd72 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 15:50:51 -0500 Subject: [PATCH 08/50] work on Windows specific issues --- trio/_subprocess.py | 60 ++++++++++++++--------------- trio/_subprocess_platform/waitid.py | 9 +++-- trio/_windows_pipes.py | 6 +-- trio/tests/test_subprocess.py | 11 ++++-- 4 files changed, 47 insertions(+), 39 deletions(-) diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 8cb8755699..6b29c4465c 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -380,20 +380,20 @@ async def open_process( ) -> Process: ... - async def open_process( # type: ignore[no-untyped-def] - command, + async def open_process( + command: Union[str, Sequence[str]], *, - stdin=..., - stdout=..., - stderr=..., - close_fds=..., - shell=..., - cwd=..., - env=..., - preexec_fn=..., - restore_signals=..., - start_new_session=..., - pass_fds=..., + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Union[Literal[True], bool] = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., ) -> Process: ... @@ -441,24 +441,24 @@ async def run_process( ) -> subprocess.CompletedProcess[bytes]: ... - async def run_process( # type: ignore[no-untyped-def] - command, + async def run_process( + command: Union[str, Sequence[str]], *, - stdin=..., - capture_stdout=..., - capture_stderr=..., - check=..., - deliver_cancel=..., - stdout=..., - stderr=..., - close_fds=..., - shell=..., - cwd=..., - env=..., - preexec_fn=..., - restore_signals=..., - start_new_session=..., - pass_fds=..., + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Union[Literal[True], bool] = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., ) -> subprocess.CompletedProcess[bytes]: ... diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 486c37acb9..dad765a0f1 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -7,14 +7,17 @@ from .._sync import CapacityLimiter, Event from .._threads import to_thread_run_sync -try: - from os import waitid + +waitid = getattr(os, "waitid", None) + + +if waitid is not None: def sync_wait_reapable(pid): waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT) -except ImportError: +else: # pypy doesn't define os.waitid so we need to pull it out ourselves # using cffi: https://bitbucket.org/pypy/pypy/issues/2922/ import cffi diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index fb420535f4..025c8742a7 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -1,5 +1,5 @@ import sys -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from . import _core from ._abc import SendStream, ReceiveStream from ._util import ConflictDetector, Final @@ -52,7 +52,7 @@ def __init__(self, handle: int) -> None: "another task is currently using this pipe" ) - async def send_all(self, data: bytes): + async def send_all(self, data: bytes) -> None: with self._conflict_detector: if self._handle_holder.closed: raise _core.ClosedResourceError("this pipe is already closed") @@ -91,7 +91,7 @@ def __init__(self, handle: int) -> None: "another task is currently using this pipe" ) - async def receive_some(self, max_bytes=None) -> bytes: + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: with self._conflict_detector: if self._handle_holder.closed: raise _core.ClosedResourceError("this pipe is already closed") diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 43c5bf5dd0..3bfc367224 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -26,9 +26,14 @@ SIGTERM: Optional[signal.Signals] SIGUSR1: Optional[signal.Signals] -posix = os.name == "posix" -if posix: - from signal import SIGKILL, SIGTERM, SIGUSR1 +# TODO: is this the proper translation from os.name to sys.platform? +# Mypy understands sys.platform but not os.name +posix = sys.platform != "win32" + +if sys.platform != "win32": + import signal + + SIGKILL, SIGTERM, SIGUSR1 = signal.SIGKILL, signal.SIGTERM, signal.SIGUSR1 else: SIGKILL, SIGTERM, SIGUSR1 = None, None, None From 5366d378823bd94ca6510c1238660c63dedef6b4 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 16:32:25 -0500 Subject: [PATCH 09/50] a couple more details --- trio/_unix_pipes.py | 339 +++++++++++++++++----------------- trio/tests/test_unix_pipes.py | 7 +- 2 files changed, 171 insertions(+), 175 deletions(-) diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 6645812f58..19b3bb4348 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -1,5 +1,6 @@ import os import errno +import sys from typing import Optional from ._abc import Stream @@ -7,182 +8,178 @@ import trio -if os.name != "posix": - # We raise an error here rather than gating the import in lowlevel.py - # in order to keep jedi static analysis happy. - raise ImportError - -# XX TODO: is this a good number? who knows... it does match the default Linux -# pipe capacity though. -DEFAULT_RECEIVE_SIZE = 65536 - - -class _FdHolder: - # This class holds onto a raw file descriptor, in non-blocking mode, and - # is responsible for managing its lifecycle. In particular, it's - # responsible for making sure it gets closed, and also for tracking - # whether it's been closed. - # - # The way we track closure is to set the .fd field to -1, discarding the - # original value. You might think that this is a strange idea, since it - # overloads the same field to do two different things. Wouldn't it be more - # natural to have a dedicated .closed field? But that would be more - # error-prone. Fds are represented by small integers, and once an fd is - # closed, its integer value may be reused immediately. If we accidentally - # used the old fd after being closed, we might end up doing something to - # another unrelated fd that happened to get assigned the same integer - # value. By throwing away the integer value immediately, it becomes - # impossible to make this mistake – we'll just get an EBADF. - # - # (This trick was copied from the stdlib socket module.) - def __init__(self, fd: int): - # make sure self.fd is always initialized to *something*, because even - # if we error out here then __del__ will run and access it. - self.fd = -1 - if not isinstance(fd, int): - raise TypeError("file descriptor must be an int") - self.fd = fd - # Store original state, and ensure non-blocking mode is enabled - self._original_is_blocking = os.get_blocking(fd) - os.set_blocking(fd, False) - - @property - def closed(self): - return self.fd == -1 - - def _raw_close(self): - # This doesn't assume it's in a Trio context, so it can be called from - # __del__. You should never call it from Trio context, because it - # skips calling notify_fd_close. But from __del__, skipping that is - # OK, because notify_fd_close just wakes up other tasks that are - # waiting on this fd, and those tasks hold a reference to this object. - # So if __del__ is being called, we know there aren't any tasks that - # need to be woken. - if self.closed: - return - fd = self.fd - self.fd = -1 - os.set_blocking(fd, self._original_is_blocking) - os.close(fd) - - def __del__(self): - self._raw_close() - - async def aclose(self): - if not self.closed: - trio.lowlevel.notify_closing(self.fd) +if sys.platform != "win32": + # XX TODO: is this a good number? who knows... it does match the default Linux + # pipe capacity though. + DEFAULT_RECEIVE_SIZE = 65536 + + + class _FdHolder: + # This class holds onto a raw file descriptor, in non-blocking mode, and + # is responsible for managing its lifecycle. In particular, it's + # responsible for making sure it gets closed, and also for tracking + # whether it's been closed. + # + # The way we track closure is to set the .fd field to -1, discarding the + # original value. You might think that this is a strange idea, since it + # overloads the same field to do two different things. Wouldn't it be more + # natural to have a dedicated .closed field? But that would be more + # error-prone. Fds are represented by small integers, and once an fd is + # closed, its integer value may be reused immediately. If we accidentally + # used the old fd after being closed, we might end up doing something to + # another unrelated fd that happened to get assigned the same integer + # value. By throwing away the integer value immediately, it becomes + # impossible to make this mistake – we'll just get an EBADF. + # + # (This trick was copied from the stdlib socket module.) + def __init__(self, fd: int): + # make sure self.fd is always initialized to *something*, because even + # if we error out here then __del__ will run and access it. + self.fd = -1 + if not isinstance(fd, int): + raise TypeError("file descriptor must be an int") + self.fd = fd + # Store original state, and ensure non-blocking mode is enabled + self._original_is_blocking = os.get_blocking(fd) + os.set_blocking(fd, False) + + @property + def closed(self): + return self.fd == -1 + + def _raw_close(self): + # This doesn't assume it's in a Trio context, so it can be called from + # __del__. You should never call it from Trio context, because it + # skips calling notify_fd_close. But from __del__, skipping that is + # OK, because notify_fd_close just wakes up other tasks that are + # waiting on this fd, and those tasks hold a reference to this object. + # So if __del__ is being called, we know there aren't any tasks that + # need to be woken. + if self.closed: + return + fd = self.fd + self.fd = -1 + os.set_blocking(fd, self._original_is_blocking) + os.close(fd) + + def __del__(self): self._raw_close() - await trio.lowlevel.checkpoint() - - -class FdStream(Stream, metaclass=Final): - """ - Represents a stream given the file descriptor to a pipe, TTY, etc. - - *fd* must refer to a file that is open for reading and/or writing and - supports non-blocking I/O (pipes and TTYs will work, on-disk files probably - not). The returned stream takes ownership of the fd, so closing the stream - will close the fd too. As with `os.fdopen`, you should not directly use - an fd after you have wrapped it in a stream using this function. - - To be used as a Trio stream, an open file must be placed in non-blocking - mode. Unfortunately, this impacts all I/O that goes through the - underlying open file, including I/O that uses a different - file descriptor than the one that was passed to Trio. If other threads - or processes are using file descriptors that are related through `os.dup` - or inheritance across `os.fork` to the one that Trio is using, they are - unlikely to be prepared to have non-blocking I/O semantics suddenly - thrust upon them. For example, you can use - ``FdStream(os.dup(sys.stdin.fileno()))`` to obtain a stream for reading - from standard input, but it is only safe to do so with heavy caveats: your - stdin must not be shared by any other processes and you must not make any - calls to synchronous methods of `sys.stdin` until the stream returned by - `FdStream` is closed. See `issue #174 - `__ for a discussion of the - challenges involved in relaxing this restriction. - - Args: - fd (int): The fd to be wrapped. - - Returns: - A new `FdStream` object. - """ - - def __init__(self, fd: int): - self._fd_holder = _FdHolder(fd) - self._send_conflict_detector = ConflictDetector( - "another task is using this stream for send" - ) - self._receive_conflict_detector = ConflictDetector( - "another task is using this stream for receive" - ) - - async def send_all(self, data: bytes) -> None: - with self._send_conflict_detector: - # have to check up front, because send_all(b"") on a closed pipe - # should raise - if self._fd_holder.closed: - raise trio.ClosedResourceError("file was already closed") - await trio.lowlevel.checkpoint() - length = len(data) - # adapted from the SocketStream code - with memoryview(data) as view: - sent = 0 - while sent < length: - with view[sent:] as remaining: - try: - sent += os.write(self._fd_holder.fd, remaining) - except BlockingIOError: - await trio.lowlevel.wait_writable(self._fd_holder.fd) - except OSError as e: - if e.errno == errno.EBADF: - raise trio.ClosedResourceError( - "file was already closed" - ) from None - else: - raise trio.BrokenResourceError from e - - async def wait_send_all_might_not_block(self) -> None: - with self._send_conflict_detector: - if self._fd_holder.closed: - raise trio.ClosedResourceError("file was already closed") - try: - await trio.lowlevel.wait_writable(self._fd_holder.fd) - except BrokenPipeError as e: - # kqueue: raises EPIPE on wait_writable instead - # of sending, which is annoying - raise trio.BrokenResourceError from e - - async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: - with self._receive_conflict_detector: - if max_bytes is None: - max_bytes = DEFAULT_RECEIVE_SIZE - else: - if not isinstance(max_bytes, int): - raise TypeError("max_bytes must be integer >= 1") - if max_bytes < 1: - raise ValueError("max_bytes must be integer >= 1") + async def aclose(self): + if not self.closed: + trio.lowlevel.notify_closing(self.fd) + self._raw_close() await trio.lowlevel.checkpoint() - while True: + + + class FdStream(Stream, metaclass=Final): + """ + Represents a stream given the file descriptor to a pipe, TTY, etc. + + *fd* must refer to a file that is open for reading and/or writing and + supports non-blocking I/O (pipes and TTYs will work, on-disk files probably + not). The returned stream takes ownership of the fd, so closing the stream + will close the fd too. As with `os.fdopen`, you should not directly use + an fd after you have wrapped it in a stream using this function. + + To be used as a Trio stream, an open file must be placed in non-blocking + mode. Unfortunately, this impacts all I/O that goes through the + underlying open file, including I/O that uses a different + file descriptor than the one that was passed to Trio. If other threads + or processes are using file descriptors that are related through `os.dup` + or inheritance across `os.fork` to the one that Trio is using, they are + unlikely to be prepared to have non-blocking I/O semantics suddenly + thrust upon them. For example, you can use + ``FdStream(os.dup(sys.stdin.fileno()))`` to obtain a stream for reading + from standard input, but it is only safe to do so with heavy caveats: your + stdin must not be shared by any other processes and you must not make any + calls to synchronous methods of `sys.stdin` until the stream returned by + `FdStream` is closed. See `issue #174 + `__ for a discussion of the + challenges involved in relaxing this restriction. + + Args: + fd (int): The fd to be wrapped. + + Returns: + A new `FdStream` object. + """ + + def __init__(self, fd: int): + self._fd_holder = _FdHolder(fd) + self._send_conflict_detector = ConflictDetector( + "another task is using this stream for send" + ) + self._receive_conflict_detector = ConflictDetector( + "another task is using this stream for receive" + ) + + async def send_all(self, data: bytes) -> None: + with self._send_conflict_detector: + # have to check up front, because send_all(b"") on a closed pipe + # should raise + if self._fd_holder.closed: + raise trio.ClosedResourceError("file was already closed") + await trio.lowlevel.checkpoint() + length = len(data) + # adapted from the SocketStream code + with memoryview(data) as view: + sent = 0 + while sent < length: + with view[sent:] as remaining: + try: + sent += os.write(self._fd_holder.fd, remaining) + except BlockingIOError: + await trio.lowlevel.wait_writable(self._fd_holder.fd) + except OSError as e: + if e.errno == errno.EBADF: + raise trio.ClosedResourceError( + "file was already closed" + ) from None + else: + raise trio.BrokenResourceError from e + + async def wait_send_all_might_not_block(self) -> None: + with self._send_conflict_detector: + if self._fd_holder.closed: + raise trio.ClosedResourceError("file was already closed") try: - data = os.read(self._fd_holder.fd, max_bytes) - except BlockingIOError: - await trio.lowlevel.wait_readable(self._fd_holder.fd) - except OSError as e: - if e.errno == errno.EBADF: - raise trio.ClosedResourceError( - "file was already closed" - ) from None - else: - raise trio.BrokenResourceError from e + await trio.lowlevel.wait_writable(self._fd_holder.fd) + except BrokenPipeError as e: + # kqueue: raises EPIPE on wait_writable instead + # of sending, which is annoying + raise trio.BrokenResourceError from e + + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: + with self._receive_conflict_detector: + if max_bytes is None: + max_bytes = DEFAULT_RECEIVE_SIZE else: - break + if not isinstance(max_bytes, int): + raise TypeError("max_bytes must be integer >= 1") + if max_bytes < 1: + raise ValueError("max_bytes must be integer >= 1") + + await trio.lowlevel.checkpoint() + while True: + try: + data = os.read(self._fd_holder.fd, max_bytes) + except BlockingIOError: + await trio.lowlevel.wait_readable(self._fd_holder.fd) + except OSError as e: + if e.errno == errno.EBADF: + raise trio.ClosedResourceError( + "file was already closed" + ) from None + else: + raise trio.BrokenResourceError from e + else: + break - return data + return data - async def aclose(self): - await self._fd_holder.aclose() + async def aclose(self): + await self._fd_holder.aclose() - def fileno(self): - return self._fd_holder.fd + def fileno(self): + return self._fd_holder.fd diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index da5e696bf6..f2f38e776a 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -11,13 +11,12 @@ from .. import _core, move_on_after from ..testing import wait_all_tasks_blocked, check_one_way_stream -posix = os.name == "posix" -pytestmark = pytest.mark.skipif(not posix, reason="posix only") -if posix: +pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="posix only") +if sys.platform != "win32": from .._unix_pipes import FdStream else: with pytest.raises(ImportError): - from .._unix_pipes import FdStream + from .._unix_pipes import FdStream # type: ignore[attr-defined] # Have to use quoted types so import doesn't crash on windows From 7b380474ed365b13d2ad5dc002f67a15b873fcf4 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 16:36:42 -0500 Subject: [PATCH 10/50] black --- trio/_unix_pipes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 19b3bb4348..6be860f9f9 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -13,7 +13,6 @@ # pipe capacity though. DEFAULT_RECEIVE_SIZE = 65536 - class _FdHolder: # This class holds onto a raw file descriptor, in non-blocking mode, and # is responsible for managing its lifecycle. In particular, it's @@ -71,7 +70,6 @@ async def aclose(self): self._raw_close() await trio.lowlevel.checkpoint() - class FdStream(Stream, metaclass=Final): """ Represents a stream given the file descriptor to a pipe, TTY, etc. From 8371daf65675968089067020ec48880b6520b471 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 17:45:31 -0500 Subject: [PATCH 11/50] maybe this time --- trio/_unix_pipes.py | 333 ++++++++++++------------ trio/tests/test_unix_pipes.py | 470 +++++++++++++++++----------------- 2 files changed, 401 insertions(+), 402 deletions(-) diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 6be860f9f9..1c28e88d64 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -8,176 +8,179 @@ import trio -if sys.platform != "win32": - # XX TODO: is this a good number? who knows... it does match the default Linux - # pipe capacity though. - DEFAULT_RECEIVE_SIZE = 65536 - - class _FdHolder: - # This class holds onto a raw file descriptor, in non-blocking mode, and - # is responsible for managing its lifecycle. In particular, it's - # responsible for making sure it gets closed, and also for tracking - # whether it's been closed. - # - # The way we track closure is to set the .fd field to -1, discarding the - # original value. You might think that this is a strange idea, since it - # overloads the same field to do two different things. Wouldn't it be more - # natural to have a dedicated .closed field? But that would be more - # error-prone. Fds are represented by small integers, and once an fd is - # closed, its integer value may be reused immediately. If we accidentally - # used the old fd after being closed, we might end up doing something to - # another unrelated fd that happened to get assigned the same integer - # value. By throwing away the integer value immediately, it becomes - # impossible to make this mistake – we'll just get an EBADF. - # - # (This trick was copied from the stdlib socket module.) - def __init__(self, fd: int): - # make sure self.fd is always initialized to *something*, because even - # if we error out here then __del__ will run and access it. - self.fd = -1 - if not isinstance(fd, int): - raise TypeError("file descriptor must be an int") - self.fd = fd - # Store original state, and ensure non-blocking mode is enabled - self._original_is_blocking = os.get_blocking(fd) - os.set_blocking(fd, False) - - @property - def closed(self): - return self.fd == -1 - - def _raw_close(self): - # This doesn't assume it's in a Trio context, so it can be called from - # __del__. You should never call it from Trio context, because it - # skips calling notify_fd_close. But from __del__, skipping that is - # OK, because notify_fd_close just wakes up other tasks that are - # waiting on this fd, and those tasks hold a reference to this object. - # So if __del__ is being called, we know there aren't any tasks that - # need to be woken. - if self.closed: - return - fd = self.fd - self.fd = -1 - os.set_blocking(fd, self._original_is_blocking) - os.close(fd) - - def __del__(self): +assert sys.platform != "win32" + +# XX TODO: is this a good number? who knows... it does match the default Linux +# pipe capacity though. +DEFAULT_RECEIVE_SIZE = 65536 + + +class _FdHolder: + # This class holds onto a raw file descriptor, in non-blocking mode, and + # is responsible for managing its lifecycle. In particular, it's + # responsible for making sure it gets closed, and also for tracking + # whether it's been closed. + # + # The way we track closure is to set the .fd field to -1, discarding the + # original value. You might think that this is a strange idea, since it + # overloads the same field to do two different things. Wouldn't it be more + # natural to have a dedicated .closed field? But that would be more + # error-prone. Fds are represented by small integers, and once an fd is + # closed, its integer value may be reused immediately. If we accidentally + # used the old fd after being closed, we might end up doing something to + # another unrelated fd that happened to get assigned the same integer + # value. By throwing away the integer value immediately, it becomes + # impossible to make this mistake – we'll just get an EBADF. + # + # (This trick was copied from the stdlib socket module.) + def __init__(self, fd: int): + # make sure self.fd is always initialized to *something*, because even + # if we error out here then __del__ will run and access it. + self.fd = -1 + if not isinstance(fd, int): + raise TypeError("file descriptor must be an int") + self.fd = fd + # Store original state, and ensure non-blocking mode is enabled + self._original_is_blocking = os.get_blocking(fd) + os.set_blocking(fd, False) + + @property + def closed(self): + return self.fd == -1 + + def _raw_close(self): + # This doesn't assume it's in a Trio context, so it can be called from + # __del__. You should never call it from Trio context, because it + # skips calling notify_fd_close. But from __del__, skipping that is + # OK, because notify_fd_close just wakes up other tasks that are + # waiting on this fd, and those tasks hold a reference to this object. + # So if __del__ is being called, we know there aren't any tasks that + # need to be woken. + if self.closed: + return + fd = self.fd + self.fd = -1 + os.set_blocking(fd, self._original_is_blocking) + os.close(fd) + + def __del__(self): + self._raw_close() + + async def aclose(self): + if not self.closed: + trio.lowlevel.notify_closing(self.fd) self._raw_close() - - async def aclose(self): - if not self.closed: - trio.lowlevel.notify_closing(self.fd) - self._raw_close() + await trio.lowlevel.checkpoint() + + +class FdStream(Stream, metaclass=Final): + """ + Represents a stream given the file descriptor to a pipe, TTY, etc. + + *fd* must refer to a file that is open for reading and/or writing and + supports non-blocking I/O (pipes and TTYs will work, on-disk files probably + not). The returned stream takes ownership of the fd, so closing the stream + will close the fd too. As with `os.fdopen`, you should not directly use + an fd after you have wrapped it in a stream using this function. + + To be used as a Trio stream, an open file must be placed in non-blocking + mode. Unfortunately, this impacts all I/O that goes through the + underlying open file, including I/O that uses a different + file descriptor than the one that was passed to Trio. If other threads + or processes are using file descriptors that are related through `os.dup` + or inheritance across `os.fork` to the one that Trio is using, they are + unlikely to be prepared to have non-blocking I/O semantics suddenly + thrust upon them. For example, you can use + ``FdStream(os.dup(sys.stdin.fileno()))`` to obtain a stream for reading + from standard input, but it is only safe to do so with heavy caveats: your + stdin must not be shared by any other processes and you must not make any + calls to synchronous methods of `sys.stdin` until the stream returned by + `FdStream` is closed. See `issue #174 + `__ for a discussion of the + challenges involved in relaxing this restriction. + + Args: + fd (int): The fd to be wrapped. + + Returns: + A new `FdStream` object. + """ + + def __init__(self, fd: int): + self._fd_holder = _FdHolder(fd) + self._send_conflict_detector = ConflictDetector( + "another task is using this stream for send" + ) + self._receive_conflict_detector = ConflictDetector( + "another task is using this stream for receive" + ) + + async def send_all(self, data: bytes) -> None: + with self._send_conflict_detector: + # have to check up front, because send_all(b"") on a closed pipe + # should raise + if self._fd_holder.closed: + raise trio.ClosedResourceError("file was already closed") await trio.lowlevel.checkpoint() + length = len(data) + # adapted from the SocketStream code + with memoryview(data) as view: + sent = 0 + while sent < length: + with view[sent:] as remaining: + try: + sent += os.write(self._fd_holder.fd, remaining) + except BlockingIOError: + await trio.lowlevel.wait_writable(self._fd_holder.fd) + except OSError as e: + if e.errno == errno.EBADF: + raise trio.ClosedResourceError( + "file was already closed" + ) from None + else: + raise trio.BrokenResourceError from e + + async def wait_send_all_might_not_block(self) -> None: + with self._send_conflict_detector: + if self._fd_holder.closed: + raise trio.ClosedResourceError("file was already closed") + try: + await trio.lowlevel.wait_writable(self._fd_holder.fd) + except BrokenPipeError as e: + # kqueue: raises EPIPE on wait_writable instead + # of sending, which is annoying + raise trio.BrokenResourceError from e + + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: + with self._receive_conflict_detector: + if max_bytes is None: + max_bytes = DEFAULT_RECEIVE_SIZE + else: + if not isinstance(max_bytes, int): + raise TypeError("max_bytes must be integer >= 1") + if max_bytes < 1: + raise ValueError("max_bytes must be integer >= 1") - class FdStream(Stream, metaclass=Final): - """ - Represents a stream given the file descriptor to a pipe, TTY, etc. - - *fd* must refer to a file that is open for reading and/or writing and - supports non-blocking I/O (pipes and TTYs will work, on-disk files probably - not). The returned stream takes ownership of the fd, so closing the stream - will close the fd too. As with `os.fdopen`, you should not directly use - an fd after you have wrapped it in a stream using this function. - - To be used as a Trio stream, an open file must be placed in non-blocking - mode. Unfortunately, this impacts all I/O that goes through the - underlying open file, including I/O that uses a different - file descriptor than the one that was passed to Trio. If other threads - or processes are using file descriptors that are related through `os.dup` - or inheritance across `os.fork` to the one that Trio is using, they are - unlikely to be prepared to have non-blocking I/O semantics suddenly - thrust upon them. For example, you can use - ``FdStream(os.dup(sys.stdin.fileno()))`` to obtain a stream for reading - from standard input, but it is only safe to do so with heavy caveats: your - stdin must not be shared by any other processes and you must not make any - calls to synchronous methods of `sys.stdin` until the stream returned by - `FdStream` is closed. See `issue #174 - `__ for a discussion of the - challenges involved in relaxing this restriction. - - Args: - fd (int): The fd to be wrapped. - - Returns: - A new `FdStream` object. - """ - - def __init__(self, fd: int): - self._fd_holder = _FdHolder(fd) - self._send_conflict_detector = ConflictDetector( - "another task is using this stream for send" - ) - self._receive_conflict_detector = ConflictDetector( - "another task is using this stream for receive" - ) - - async def send_all(self, data: bytes) -> None: - with self._send_conflict_detector: - # have to check up front, because send_all(b"") on a closed pipe - # should raise - if self._fd_holder.closed: - raise trio.ClosedResourceError("file was already closed") - await trio.lowlevel.checkpoint() - length = len(data) - # adapted from the SocketStream code - with memoryview(data) as view: - sent = 0 - while sent < length: - with view[sent:] as remaining: - try: - sent += os.write(self._fd_holder.fd, remaining) - except BlockingIOError: - await trio.lowlevel.wait_writable(self._fd_holder.fd) - except OSError as e: - if e.errno == errno.EBADF: - raise trio.ClosedResourceError( - "file was already closed" - ) from None - else: - raise trio.BrokenResourceError from e - - async def wait_send_all_might_not_block(self) -> None: - with self._send_conflict_detector: - if self._fd_holder.closed: - raise trio.ClosedResourceError("file was already closed") + await trio.lowlevel.checkpoint() + while True: try: - await trio.lowlevel.wait_writable(self._fd_holder.fd) - except BrokenPipeError as e: - # kqueue: raises EPIPE on wait_writable instead - # of sending, which is annoying - raise trio.BrokenResourceError from e - - async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: - with self._receive_conflict_detector: - if max_bytes is None: - max_bytes = DEFAULT_RECEIVE_SIZE - else: - if not isinstance(max_bytes, int): - raise TypeError("max_bytes must be integer >= 1") - if max_bytes < 1: - raise ValueError("max_bytes must be integer >= 1") - - await trio.lowlevel.checkpoint() - while True: - try: - data = os.read(self._fd_holder.fd, max_bytes) - except BlockingIOError: - await trio.lowlevel.wait_readable(self._fd_holder.fd) - except OSError as e: - if e.errno == errno.EBADF: - raise trio.ClosedResourceError( - "file was already closed" - ) from None - else: - raise trio.BrokenResourceError from e + data = os.read(self._fd_holder.fd, max_bytes) + except BlockingIOError: + await trio.lowlevel.wait_readable(self._fd_holder.fd) + except OSError as e: + if e.errno == errno.EBADF: + raise trio.ClosedResourceError( + "file was already closed" + ) from None else: - break + raise trio.BrokenResourceError from e + else: + break - return data + return data - async def aclose(self): - await self._fd_holder.aclose() + async def aclose(self): + await self._fd_holder.aclose() - def fileno(self): - return self._fd_holder.fd + def fileno(self): + return self._fd_holder.fd diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index f2f38e776a..12d65b8dcb 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -12,254 +12,250 @@ from ..testing import wait_all_tasks_blocked, check_one_way_stream pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="posix only") -if sys.platform != "win32": - from .._unix_pipes import FdStream -else: - with pytest.raises(ImportError): - from .._unix_pipes import FdStream # type: ignore[attr-defined] - - -# Have to use quoted types so import doesn't crash on windows -async def make_pipe() -> Tuple["FdStream", "FdStream"]: - """Makes a new pair of pipes.""" - (r, w) = os.pipe() - return FdStream(w), FdStream(r) - - -async def make_clogged_pipe(): - s, r = await make_pipe() - try: - while True: - # We want to totally fill up the pipe buffer. - # This requires working around a weird feature that POSIX pipes - # have. - # If you do a write of <= PIPE_BUF bytes, then it's guaranteed - # to either complete entirely, or not at all. So if we tried to - # write PIPE_BUF bytes, and the buffer's free space is only - # PIPE_BUF/2, then the write will raise BlockingIOError... even - # though a smaller write could still succeed! To avoid this, - # make sure to write >PIPE_BUF bytes each time, which disables - # the special behavior. - # For details, search for PIPE_BUF here: - # http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html - - # for the getattr: - # https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3 - buf_size = getattr(select, "PIPE_BUF", 8192) - os.write(s.fileno(), b"x" * buf_size * 2) - except BlockingIOError: - pass - return s, r - - -async def test_send_pipe(): - r, w = os.pipe() - async with FdStream(w) as send: - assert send.fileno() == w - await send.send_all(b"123") - assert (os.read(r, 8)) == b"123" - - os.close(r) - - -async def test_receive_pipe(): - r, w = os.pipe() - async with FdStream(r) as recv: - assert (recv.fileno()) == r - os.write(w, b"123") - assert (await recv.receive_some(8)) == b"123" - - os.close(w) - - -async def test_pipes_combined(): - write, read = await make_pipe() - count = 2 ** 20 - - async def sender(): - big = bytearray(count) - await write.send_all(big) - - async def reader(): - await wait_all_tasks_blocked() - received = 0 - while received < count: - received += len(await read.receive_some(4096)) - - assert received == count - - async with _core.open_nursery() as n: - n.start_soon(sender) - n.start_soon(reader) - - await read.aclose() - await write.aclose() - - -async def test_pipe_errors(): - with pytest.raises(TypeError): - FdStream(None) - - r, w = os.pipe() - os.close(w) - async with FdStream(r) as s: - with pytest.raises(ValueError): - await s.receive_some(0) - - -async def test_del(): - w, r = await make_pipe() - f1, f2 = w.fileno(), r.fileno() - del w, r - gc_collect_harder() - - with pytest.raises(OSError) as excinfo: - os.close(f1) - assert excinfo.value.errno == errno.EBADF - - with pytest.raises(OSError) as excinfo: - os.close(f2) - assert excinfo.value.errno == errno.EBADF - -async def test_async_with(): - w, r = await make_pipe() - async with w, r: - pass - - assert w.fileno() == -1 - assert r.fileno() == -1 - - with pytest.raises(OSError) as excinfo: - os.close(w.fileno()) - assert excinfo.value.errno == errno.EBADF - - with pytest.raises(OSError) as excinfo: - os.close(r.fileno()) - assert excinfo.value.errno == errno.EBADF - - -async def test_misdirected_aclose_regression(): - # https://github.com/python-trio/trio/issues/661#issuecomment-456582356 - w, r = await make_pipe() - old_r_fd = r.fileno() - - # Close the original objects - await w.aclose() - await r.aclose() - - # Do a little dance to get a new pipe whose receive handle matches the old - # receive handle. - r2_fd, w2_fd = os.pipe() - if r2_fd != old_r_fd: # pragma: no cover - os.dup2(r2_fd, old_r_fd) - os.close(r2_fd) - async with FdStream(old_r_fd) as r2: - assert r2.fileno() == old_r_fd - - # And now set up a background task that's working on the new receive - # handle - async def expect_eof(): - assert await r2.receive_some(10) == b"" +if sys.platform == "win32": + with pytest.raises(ImportError): + # Using sys instead of FdStream since sys is created before the assertion that + # terminates the import of _unix_pipes and this makes Mypy happier. The type + # warning can't be ignored since it is not present on all platforms and thus + # triggers a warning about being unneeded on other platforms. + from .._unix_pipes import sys +else: + from .._unix_pipes import FdStream - async with _core.open_nursery() as nursery: - nursery.start_soon(expect_eof) - await wait_all_tasks_blocked() + # Have to use quoted types so import doesn't crash on windows + async def make_pipe() -> Tuple["FdStream", "FdStream"]: + """Makes a new pair of pipes.""" + (r, w) = os.pipe() + return FdStream(w), FdStream(r) - # Here's the key test: does calling aclose() again on the *old* - # handle, cause the task blocked on the *new* handle to raise - # ClosedResourceError? - await r.aclose() + async def make_clogged_pipe(): + s, r = await make_pipe() + try: + while True: + # We want to totally fill up the pipe buffer. + # This requires working around a weird feature that POSIX pipes + # have. + # If you do a write of <= PIPE_BUF bytes, then it's guaranteed + # to either complete entirely, or not at all. So if we tried to + # write PIPE_BUF bytes, and the buffer's free space is only + # PIPE_BUF/2, then the write will raise BlockingIOError... even + # though a smaller write could still succeed! To avoid this, + # make sure to write >PIPE_BUF bytes each time, which disables + # the special behavior. + # For details, search for PIPE_BUF here: + # http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html + + # for the getattr: + # https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3 + buf_size = getattr(select, "PIPE_BUF", 8192) + os.write(s.fileno(), b"x" * buf_size * 2) + except BlockingIOError: + pass + return s, r + + async def test_send_pipe(): + r, w = os.pipe() + async with FdStream(w) as send: + assert send.fileno() == w + await send.send_all(b"123") + assert (os.read(r, 8)) == b"123" + + os.close(r) + + async def test_receive_pipe(): + r, w = os.pipe() + async with FdStream(r) as recv: + assert (recv.fileno()) == r + os.write(w, b"123") + assert (await recv.receive_some(8)) == b"123" + + os.close(w) + + async def test_pipes_combined(): + write, read = await make_pipe() + count = 2 ** 20 + + async def sender(): + big = bytearray(count) + await write.send_all(big) + + async def reader(): await wait_all_tasks_blocked() + received = 0 + while received < count: + received += len(await read.receive_some(4096)) - # Guess we survived! Close the new write handle so that the task - # gets an EOF and can exit cleanly. - os.close(w2_fd) + assert received == count + async with _core.open_nursery() as n: + n.start_soon(sender) + n.start_soon(reader) -async def test_close_at_bad_time_for_receive_some(monkeypatch): - # We used to have race conditions where if one task was using the pipe, - # and another closed it at *just* the wrong moment, it would give an - # unexpected error instead of ClosedResourceError: - # https://github.com/python-trio/trio/issues/661 - # - # This tests what happens if the pipe gets closed in the moment *between* - # when receive_some wakes up, and when it tries to call os.read - async def expect_closedresourceerror(): - with pytest.raises(_core.ClosedResourceError): - await r.receive_some(10) + await read.aclose() + await write.aclose() - orig_wait_readable = _core._run.TheIOManager.wait_readable + async def test_pipe_errors(): + with pytest.raises(TypeError): + FdStream(None) - async def patched_wait_readable(*args, **kwargs): - await orig_wait_readable(*args, **kwargs) + r, w = os.pipe() + os.close(w) + async with FdStream(r) as s: + with pytest.raises(ValueError): + await s.receive_some(0) + + async def test_del(): + w, r = await make_pipe() + f1, f2 = w.fileno(), r.fileno() + del w, r + gc_collect_harder() + + with pytest.raises(OSError) as excinfo: + os.close(f1) + assert excinfo.value.errno == errno.EBADF + + with pytest.raises(OSError) as excinfo: + os.close(f2) + assert excinfo.value.errno == errno.EBADF + + async def test_async_with(): + w, r = await make_pipe() + async with w, r: + pass + + assert w.fileno() == -1 + assert r.fileno() == -1 + + with pytest.raises(OSError) as excinfo: + os.close(w.fileno()) + assert excinfo.value.errno == errno.EBADF + + with pytest.raises(OSError) as excinfo: + os.close(r.fileno()) + assert excinfo.value.errno == errno.EBADF + + async def test_misdirected_aclose_regression(): + # https://github.com/python-trio/trio/issues/661#issuecomment-456582356 + w, r = await make_pipe() + old_r_fd = r.fileno() + + # Close the original objects + await w.aclose() await r.aclose() - monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable) - s, r = await make_pipe() - async with s, r: - async with _core.open_nursery() as nursery: - nursery.start_soon(expect_closedresourceerror) - await wait_all_tasks_blocked() - # Trigger everything by waking up the receiver - await s.send_all(b"x") - - -async def test_close_at_bad_time_for_send_all(monkeypatch): - # We used to have race conditions where if one task was using the pipe, - # and another closed it at *just* the wrong moment, it would give an - # unexpected error instead of ClosedResourceError: - # https://github.com/python-trio/trio/issues/661 - # - # This tests what happens if the pipe gets closed in the moment *between* - # when send_all wakes up, and when it tries to call os.write - async def expect_closedresourceerror(): - with pytest.raises(_core.ClosedResourceError): - await s.send_all(b"x" * 100) - - orig_wait_writable = _core._run.TheIOManager.wait_writable - - async def patched_wait_writable(*args, **kwargs): - await orig_wait_writable(*args, **kwargs) - await s.aclose() - - monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable) - s, r = await make_clogged_pipe() - async with s, r: - async with _core.open_nursery() as nursery: - nursery.start_soon(expect_closedresourceerror) - await wait_all_tasks_blocked() - # Trigger everything by waking up the sender - await r.receive_some(10000) - - -# On FreeBSD, directories are readable, and we haven't found any other trick -# for making an unreadable fd, so there's no way to run this test. Fortunately -# the logic this is testing doesn't depend on the platform, so testing on -# other platforms is probably good enough. -@pytest.mark.skipif( - sys.platform.startswith("freebsd"), - reason="no way to make read() return a bizarro error on FreeBSD", -) -async def test_bizarro_OSError_from_receive(): - # Make sure that if the read syscall returns some bizarro error, then we - # get a BrokenResourceError. This is incredibly unlikely; there's almost - # no way to trigger a failure here intentionally (except for EBADF, but we - # exploit that to detect file closure, so it takes a different path). So - # we set up a strange scenario where the pipe fd somehow transmutes into a - # directory fd, causing os.read to raise IsADirectoryError (yes, that's a - # real built-in exception type). - s, r = await make_pipe() - async with s, r: - dir_fd = os.open("/", os.O_DIRECTORY, 0) - try: - os.dup2(dir_fd, r.fileno()) - with pytest.raises(_core.BrokenResourceError): + # Do a little dance to get a new pipe whose receive handle matches the old + # receive handle. + r2_fd, w2_fd = os.pipe() + if r2_fd != old_r_fd: # pragma: no cover + os.dup2(r2_fd, old_r_fd) + os.close(r2_fd) + async with FdStream(old_r_fd) as r2: + assert r2.fileno() == old_r_fd + + # And now set up a background task that's working on the new receive + # handle + async def expect_eof(): + assert await r2.receive_some(10) == b"" + + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_eof) + await wait_all_tasks_blocked() + + # Here's the key test: does calling aclose() again on the *old* + # handle, cause the task blocked on the *new* handle to raise + # ClosedResourceError? + await r.aclose() + await wait_all_tasks_blocked() + + # Guess we survived! Close the new write handle so that the task + # gets an EOF and can exit cleanly. + os.close(w2_fd) + + async def test_close_at_bad_time_for_receive_some(monkeypatch): + # We used to have race conditions where if one task was using the pipe, + # and another closed it at *just* the wrong moment, it would give an + # unexpected error instead of ClosedResourceError: + # https://github.com/python-trio/trio/issues/661 + # + # This tests what happens if the pipe gets closed in the moment *between* + # when receive_some wakes up, and when it tries to call os.read + async def expect_closedresourceerror(): + with pytest.raises(_core.ClosedResourceError): await r.receive_some(10) - finally: - os.close(dir_fd) + orig_wait_readable = _core._run.TheIOManager.wait_readable + + async def patched_wait_readable(*args, **kwargs): + await orig_wait_readable(*args, **kwargs) + await r.aclose() -@skip_if_fbsd_pipes_broken -async def test_pipe_fully(): - await check_one_way_stream(make_pipe, make_clogged_pipe) + monkeypatch.setattr( + _core._run.TheIOManager, "wait_readable", patched_wait_readable + ) + s, r = await make_pipe() + async with s, r: + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_closedresourceerror) + await wait_all_tasks_blocked() + # Trigger everything by waking up the receiver + await s.send_all(b"x") + + async def test_close_at_bad_time_for_send_all(monkeypatch): + # We used to have race conditions where if one task was using the pipe, + # and another closed it at *just* the wrong moment, it would give an + # unexpected error instead of ClosedResourceError: + # https://github.com/python-trio/trio/issues/661 + # + # This tests what happens if the pipe gets closed in the moment *between* + # when send_all wakes up, and when it tries to call os.write + async def expect_closedresourceerror(): + with pytest.raises(_core.ClosedResourceError): + await s.send_all(b"x" * 100) + + orig_wait_writable = _core._run.TheIOManager.wait_writable + + async def patched_wait_writable(*args, **kwargs): + await orig_wait_writable(*args, **kwargs) + await s.aclose() + + monkeypatch.setattr( + _core._run.TheIOManager, "wait_writable", patched_wait_writable + ) + s, r = await make_clogged_pipe() + async with s, r: + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_closedresourceerror) + await wait_all_tasks_blocked() + # Trigger everything by waking up the sender + await r.receive_some(10000) + + # On FreeBSD, directories are readable, and we haven't found any other trick + # for making an unreadable fd, so there's no way to run this test. Fortunately + # the logic this is testing doesn't depend on the platform, so testing on + # other platforms is probably good enough. + @pytest.mark.skipif( + sys.platform.startswith("freebsd"), + reason="no way to make read() return a bizarro error on FreeBSD", + ) + async def test_bizarro_OSError_from_receive(): + # Make sure that if the read syscall returns some bizarro error, then we + # get a BrokenResourceError. This is incredibly unlikely; there's almost + # no way to trigger a failure here intentionally (except for EBADF, but we + # exploit that to detect file closure, so it takes a different path). So + # we set up a strange scenario where the pipe fd somehow transmutes into a + # directory fd, causing os.read to raise IsADirectoryError (yes, that's a + # real built-in exception type). + s, r = await make_pipe() + async with s, r: + dir_fd = os.open("/", os.O_DIRECTORY, 0) + try: + os.dup2(dir_fd, r.fileno()) + with pytest.raises(_core.BrokenResourceError): + await r.receive_some(10) + finally: + os.close(dir_fd) + + @skip_if_fbsd_pipes_broken + async def test_pipe_fully(): + await check_one_way_stream(make_pipe, make_clogged_pipe) From 0f313e5e03251c1637bf2af1aa89009432a2c29e Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 18:03:46 -0500 Subject: [PATCH 12/50] correct to AssertionError --- trio/tests/test_unix_pipes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index 12d65b8dcb..a49de1796a 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -14,7 +14,7 @@ pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="posix only") if sys.platform == "win32": - with pytest.raises(ImportError): + with pytest.raises(AssertionError): # Using sys instead of FdStream since sys is created before the assertion that # terminates the import of _unix_pipes and this makes Mypy happier. The type # warning can't be ignored since it is not present on all platforms and thus From 8da751495e8839ea69c71785261588a12b7dcb9f Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 18:07:38 -0500 Subject: [PATCH 13/50] flake8 --- trio/_core/_run.py | 1 - trio/tests/test_subprocess.py | 1 - 2 files changed, 2 deletions(-) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 3865f4f755..6d46809e95 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -56,7 +56,6 @@ WaitTaskRescheduled, ) from ._asyncgens import AsyncGenerators -from ._entry_queue import TrioToken from ._thread_cache import start_thread_soon from ._instrumentation import Instruments from ._local import RunVar diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 3bfc367224..a9962a0aa2 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -4,7 +4,6 @@ import sys import pytest import random -import signal from typing import Optional from functools import partial From e4fc53cc21d766fd16615ca7f7f08d95d40644af Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 18:34:42 -0500 Subject: [PATCH 14/50] fixup docs requirements --- docs-requirements.in | 1 + docs-requirements.txt | 7 ++++++- setup.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs-requirements.in b/docs-requirements.in index 23a1b0f652..d2eeb71fe9 100644 --- a/docs-requirements.in +++ b/docs-requirements.in @@ -13,6 +13,7 @@ async_generator >= 1.9 idna outcome sniffio +typing-extensions # See note in test-requirements.in immutables >= 0.6 diff --git a/docs-requirements.txt b/docs-requirements.txt index c5e7b66ab9..fca20252c5 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --output-file docs-requirements.txt docs-requirements.in +# pip-compile --output-file=docs-requirements.txt docs-requirements.in # alabaster==0.7.12 # via sphinx @@ -81,5 +81,10 @@ toml==0.10.2 # via towncrier towncrier==19.2.0 # via -r docs-requirements.in +typing-extensions==3.7.4.3 + # via -r docs-requirements.in urllib3==1.26.2 # via requests + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/setup.py b/setup.py index 726da99365..6a6dda78ca 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,7 @@ "idna", "outcome", "sniffio", + "typing-extensions", # cffi 1.12 adds from_buffer(require_writable=True) and ffi.release() # cffi 1.14 fixes memory leak inside ffi.getwinerror() # cffi is required on Windows, except on PyPy where it is built-in From 55fad66a572c74f599c15f68bd016c16e9955d58 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 23 Jan 2021 19:04:50 -0500 Subject: [PATCH 15/50] disallow_subclassing_any = True --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index d308880c33..709f3cee86 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,7 +14,7 @@ warn_return_any = True # Avoid subtle backsliding #disallow_any_decorated = True disallow_incomplete_defs = True -#disallow_subclassing_any = True +disallow_subclassing_any = True # Enable gradually / for new modules check_untyped_defs = False From ce44c041fd166486fac0ff0cb3372330e41c7658 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 11:45:54 -0500 Subject: [PATCH 16/50] disallow_any_decorated = True --- mypy.ini | 2 +- trio/_abc.py | 60 ++- trio/_channel.py | 64 +-- trio/_core/_io_epoll.py | 18 +- trio/_core/_io_kqueue.py | 45 ++- trio/_core/_io_windows.py | 33 +- trio/_core/_ki.py | 24 +- trio/_core/_local.py | 56 ++- trio/_core/_mock_clock.py | 6 +- trio/_core/_multierror.py | 36 +- trio/_core/_parking_lot.py | 10 +- trio/_core/_run.py | 166 +++++--- trio/_core/_traps.py | 7 +- trio/_core/_unbounded_queue.py | 13 +- trio/_core/tests/conftest.py | 8 +- trio/_core/tests/test_asyncgen.py | 15 +- trio/_core/tests/test_guest_mode.py | 2 +- trio/_core/tests/test_instrumentation.py | 2 +- trio/_core/tests/test_io.py | 47 ++- trio/_core/tests/test_ki.py | 93 ++--- trio/_core/tests/test_mock_clock.py | 4 +- trio/_core/tests/test_multierror.py | 10 +- trio/_core/tests/test_run.py | 14 +- trio/_core/tests/test_thread_cache.py | 12 +- trio/_core/tests/test_windows.py | 371 +++++++++--------- trio/_core/tests/tutil.py | 4 +- trio/_deprecate.py | 46 ++- trio/_file_io.py | 165 +++++++- trio/_highlevel_open_tcp_stream.py | 7 +- trio/_highlevel_open_unix_stream.py | 11 +- trio/_highlevel_socket.py | 3 +- trio/_path.py | 74 +++- trio/_signals.py | 13 +- trio/_socket.py | 273 +++++++++++-- trio/_ssl.py | 6 +- trio/_subprocess.py | 16 +- trio/_subprocess_platform/windows.py | 2 +- trio/_sync.py | 69 ++-- trio/_threads.py | 48 ++- trio/_timeouts.py | 5 +- trio/_unix_pipes.py | 2 +- trio/_windows_pipes.py | 2 +- trio/testing/_check_streams.py | 3 +- trio/testing/_checkpoints.py | 3 +- trio/testing/_sequencer.py | 2 +- trio/testing/_trio_test.py | 11 +- trio/tests/conftest.py | 8 +- trio/tests/module_with_deprecations.py | 3 + trio/tests/test_deprecate.py | 68 +++- trio/tests/test_exports.py | 2 +- trio/tests/test_file_io.py | 20 +- .../test_highlevel_open_tcp_listeners.py | 11 +- trio/tests/test_highlevel_open_unix_stream.py | 2 +- trio/tests/test_path.py | 36 +- trio/tests/test_socket.py | 25 +- trio/tests/test_ssl.py | 23 +- trio/tests/test_subprocess.py | 13 +- trio/tests/test_sync.py | 27 +- trio/tests/test_threads.py | 22 +- trio/tests/test_timeouts.py | 8 +- trio/tests/test_unix_pipes.py | 5 +- trio/tests/test_util.py | 16 +- trio/tests/test_wait_for_object.py | 4 +- 63 files changed, 1510 insertions(+), 666 deletions(-) diff --git a/mypy.ini b/mypy.ini index 709f3cee86..b253e4418e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,7 +12,7 @@ warn_redundant_casts = True warn_return_any = True # Avoid subtle backsliding -#disallow_any_decorated = True +disallow_any_decorated = True disallow_incomplete_defs = True disallow_subclassing_any = True diff --git a/trio/_abc.py b/trio/_abc.py index b8e341fdaa..060f66cecc 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,9 +1,16 @@ # coding: utf-8 from abc import ABCMeta, abstractmethod -from typing import Generic, TypeVar +from typing import Generic, List, Optional, Text, Tuple, TYPE_CHECKING, TypeVar, Union +import socket import trio +if TYPE_CHECKING: + from ._socket import SocketType + + +_T = TypeVar("_T") + # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a # __dict__ onto subclasses. @@ -13,7 +20,7 @@ class Clock(metaclass=ABCMeta): __slots__ = () @abstractmethod - def start_clock(self): + def start_clock(self) -> None: """Do any setup this clock might need. Called at the beginning of the run. @@ -21,7 +28,7 @@ def start_clock(self): """ @abstractmethod - def current_time(self): + def current_time(self) -> float: """Return the current time, according to this clock. This is used to implement functions like :func:`trio.current_time` and @@ -33,7 +40,7 @@ def current_time(self): """ @abstractmethod - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: """Compute the real time until the given deadline. This is called before we enter a system-specific wait function like @@ -146,7 +153,23 @@ class HostnameResolver(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): + async def getaddrinfo( + self, + host: Optional[Union[bytearray, bytes, Text]], + port: Union[str, int, None], + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> List[ + Tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] + ]: """A custom implementation of :func:`~trio.socket.getaddrinfo`. Called by :func:`trio.socket.getaddrinfo`. @@ -163,7 +186,9 @@ async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): """ @abstractmethod - async def getnameinfo(self, sockaddr, flags): + async def getnameinfo( + self, sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], flags: int + ) -> Tuple[str, Union[int, str]]: """A custom implementation of :func:`~trio.socket.getnameinfo`. Called by :func:`trio.socket.getnameinfo`. @@ -180,7 +205,12 @@ class SocketFactory(metaclass=ABCMeta): """ @abstractmethod - def socket(self, family=None, type=None, proto=None): + def socket( + self, + family: Optional[int] = None, + type: Optional[int] = None, + proto: Optional[int] = None, + ) -> "SocketType": """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, @@ -226,7 +256,7 @@ class AsyncResource(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def aclose(self): + async def aclose(self) -> None: """Close this resource, possibly blocking. IMPORTANT: This method may block in order to perform a "graceful" @@ -254,10 +284,10 @@ async def aclose(self): """ - async def __aenter__(self): + async def __aenter__(self: _T) -> _T: return self - async def __aexit__(self, *args): + async def __aexit__(self, *args: object) -> None: await self.aclose() @@ -280,7 +310,7 @@ class SendStream(AsyncResource): __slots__ = () @abstractmethod - async def send_all(self, data): + async def send_all(self, data: Union[bytes, memoryview]) -> None: """Sends the given data through the stream, blocking if necessary. Args: @@ -306,7 +336,7 @@ async def send_all(self, data): """ @abstractmethod - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Block until it's possible that :meth:`send_all` might not block. This method may return early: it's possible that after it returns, @@ -386,7 +416,7 @@ class ReceiveStream(AsyncResource): __slots__ = () @abstractmethod - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: Optional[int] = ...) -> bytes: """Wait until there is data available on this stream, and then return some of it. @@ -447,7 +477,7 @@ class HalfCloseableStream(Stream): __slots__ = () @abstractmethod - async def send_eof(self): + async def send_eof(self) -> None: """Send an end-of-file indication on this stream, if possible. The difference between :meth:`send_eof` and @@ -526,7 +556,7 @@ class Listener(AsyncResource, Generic[T_resource]): __slots__ = () @abstractmethod - async def accept(self): + async def accept(self) -> T_resource: """Wait until an incoming connection arrives, and then return it. Returns: diff --git a/trio/_channel.py b/trio/_channel.py index b4f09e01ed..85a4ae618f 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,19 +1,25 @@ from collections import deque, OrderedDict from math import inf -from typing import cast, Callable, Tuple, TypeVar, Union +import typing +from typing import cast, Callable, Deque, Generic, Set, Tuple, TypeVar, Union import attr from outcome import Error, Value from .abc import SendChannel, ReceiveChannel, Channel from ._util import generic_function, NoPublicConstructor +from ._core._run import Task import trio from ._core import enable_ki_protection +_T_contra = TypeVar("_T_contra", contravariant=True) +_T_co = TypeVar("_T_co", covariant=True) + + @generic_function -def open_memory_channel( +def open_memory_channel( # type: ignore[misc] # TODO: should restrict the float bit to just the inf value max_buffer_size: Union[int, float], ) -> Tuple["MemorySendChannel", "MemoryReceiveChannel"]: @@ -72,7 +78,7 @@ def open_memory_channel( raise TypeError("max_buffer_size must be an integer or math.inf") if max_buffer_size < 0: raise ValueError("max_buffer_size must be >= 0") - state = MemoryChannelState(max_buffer_size) + state = MemoryChannelState(max_buffer_size) # type: ignore[var-annotated] return ( MemorySendChannel._create(state), MemoryReceiveChannel._create(state), @@ -90,16 +96,16 @@ class MemoryChannelStats: @attr.s(slots=True) -class MemoryChannelState: - max_buffer_size = attr.ib() - data = attr.ib(factory=deque) +class MemoryChannelState(Generic[_T_contra]): + max_buffer_size: float = attr.ib() + data: Deque[_T_contra] = attr.ib(factory=deque) # Counts of open endpoints using this state - open_send_channels = attr.ib(default=0) - open_receive_channels = attr.ib(default=0) + open_send_channels: int = attr.ib(default=0) + open_receive_channels: int = attr.ib(default=0) # {task: value} - send_tasks = attr.ib(factory=OrderedDict) + send_tasks: typing.OrderedDict[Task, _T_contra] = attr.ib(factory=OrderedDict) # {task: None} - receive_tasks = attr.ib(factory=OrderedDict) + receive_tasks: typing.OrderedDict[Task, None] = attr.ib(factory=OrderedDict) def statistics(self): return MemoryChannelStats( @@ -113,13 +119,13 @@ def statistics(self): @attr.s(eq=False, repr=False) -class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) +class MemorySendChannel(SendChannel[_T_contra], metaclass=NoPublicConstructor): + _state: MemoryChannelState[_T_contra] = attr.ib() + _closed: bool = attr.ib(default=False) # This is just the tasks waiting on *this* object. As compared to # self._state.send_tasks, which includes tasks from this object and # all clones. - _tasks = attr.ib(factory=set) + _tasks: Set[Task] = attr.ib(factory=set) def __attrs_post_init__(self): self._state.open_send_channels += 1 @@ -134,7 +140,7 @@ def statistics(self): return self._state.statistics() @enable_ki_protection - def send_nowait(self, value): + def send_nowait(self, value: _T_contra) -> None: """Like `~trio.abc.SendChannel.send`, but if the channel's buffer is full, raises `WouldBlock` instead of blocking. @@ -154,7 +160,7 @@ def send_nowait(self, value): raise trio.WouldBlock @enable_ki_protection - async def send(self, value): + async def send(self, value: _T_contra) -> None: """See `SendChannel.send `. Memory channels allow multiple tasks to call `send` at the same time. @@ -182,7 +188,7 @@ def abort_fn(_): await trio.lowlevel.wait_task_rescheduled(abort_fn) @enable_ki_protection - def clone(self): + def clone(self) -> "MemorySendChannel[_T_contra]": """Clone this send channel object. This returns a new `MemorySendChannel` object, which acts as a @@ -217,7 +223,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() @enable_ki_protection - def close(self): + def close(self) -> None: """Close this send channel object synchronously. All channel objects have an asynchronous `~.AsyncResource.aclose` method. @@ -245,16 +251,16 @@ def close(self): self._state.receive_tasks.clear() @enable_ki_protection - async def aclose(self): + async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() @attr.s(eq=False, repr=False) -class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) - _tasks = attr.ib(factory=set) +class MemoryReceiveChannel(ReceiveChannel[_T_co], metaclass=NoPublicConstructor): + _state: MemoryChannelState[_T_co] = attr.ib() + _closed: bool = attr.ib(default=False) + _tasks: Set[Task] = attr.ib(factory=set) def __attrs_post_init__(self): self._state.open_receive_channels += 1 @@ -268,7 +274,7 @@ def __repr__(self): ) @enable_ki_protection - def receive_nowait(self): + def receive_nowait(self) -> _T_co: """Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing ready to receive, raises `WouldBlock` instead of blocking. @@ -288,7 +294,7 @@ def receive_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def receive(self): + async def receive(self) -> _T_co: """See `ReceiveChannel.receive `. Memory channels allow multiple tasks to call `receive` at the same @@ -315,10 +321,10 @@ def abort_fn(_): del self._state.receive_tasks[task] return trio.lowlevel.Abort.SUCCEEDED - return await trio.lowlevel.wait_task_rescheduled(abort_fn) + return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[return-value] @enable_ki_protection - def clone(self): + def clone(self) -> "MemoryReceiveChannel[_T_co]": """Clone this receive channel object. This returns a new `MemoryReceiveChannel` object, which acts as a @@ -356,7 +362,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() @enable_ki_protection - def close(self): + def close(self) -> None: """Close this receive channel object synchronously. All channel objects have an asynchronous `~.AsyncResource.aclose` method. @@ -385,6 +391,6 @@ def close(self): self._state.data.clear() @enable_ki_protection - async def aclose(self): + async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 802f425409..eea857a3d4 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -2,7 +2,9 @@ import sys import attr from collections import defaultdict -from typing import DefaultDict, Dict, TYPE_CHECKING +from typing import DefaultDict, Dict, TYPE_CHECKING, Union + +from typing_extensions import Protocol from .. import _core from ._run import _public @@ -12,6 +14,11 @@ assert not TYPE_CHECKING or sys.platform == "linux" +class _HasFileno(Protocol): + def fileno(self) -> int: + ... + + @attr.s(slots=True, eq=False, frozen=True) class _EpollStatistics: tasks_waiting_read = attr.ib() @@ -295,15 +302,18 @@ def abort(_): await _core.wait_task_rescheduled(abort) @_public - async def wait_readable(self, fd): + async def wait_readable( + self, + fd: Union[int, _HasFileno], + ) -> None: await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: Union[int, _HasFileno]) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd): + def notify_closing(self, fd: Union[int, _HasFileno]) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 31940d5694..08c6122778 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -1,6 +1,8 @@ import select import sys -from typing import TYPE_CHECKING +from typing import Callable, Dict, Iterator, Optional, Tuple, TYPE_CHECKING, Union + +from typing_extensions import Protocol import outcome from contextlib import contextmanager @@ -8,12 +10,18 @@ import errno from .. import _core -from ._run import _public +from ._run import _public, Task +from ._unbounded_queue import UnboundedQueue from ._wakeup_socketpair import WakeupSocketpair assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") +class _HasFileno(Protocol): + def fileno(self) -> int: + ... + + @attr.s(slots=True, eq=False, frozen=True) class _KqueueStatistics: tasks_waiting = attr.ib() @@ -23,11 +31,13 @@ class _KqueueStatistics: @attr.s(slots=True, eq=False) class KqueueIOManager: - _kqueue = attr.ib(factory=select.kqueue) + _kqueue: select.kqueue = attr.ib(factory=select.kqueue) # {(ident, filter): Task or UnboundedQueue} - _registered = attr.ib(factory=dict) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) + _registered: Dict[Tuple[int, int], Union[Task, UnboundedQueue[Task]]] = attr.ib( + factory=dict + ) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: Optional[int] = attr.ib(default=None) def __attrs_post_init__(self): force_wakeup_event = select.kevent( @@ -96,18 +106,20 @@ def process_events(self, events): # be more ergonomic... @_public - def current_kqueue(self): + def current_kqueue(self) -> select.kqueue: return self._kqueue @contextmanager @_public - def monitor_kevent(self, ident, filter): + def monitor_kevent( + self, ident: int, filter: int + ) -> Iterator[_core.UnboundedQueue[Task]]: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( "attempt to register multiple listeners for same ident/filter pair" ) - q = _core.UnboundedQueue() + q = _core.UnboundedQueue[Task]() self._registered[key] = q try: yield q @@ -115,7 +127,12 @@ def monitor_kevent(self, ident, filter): del self._registered[key] @_public - async def wait_kevent(self, ident, filter, abort_func): + async def wait_kevent( + self, + ident: int, + filter: int, + abort_func: Callable[[Callable[[], None]], _core.Abort], + ) -> object: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( @@ -131,7 +148,7 @@ def abort(raise_cancel): return await _core.wait_task_rescheduled(abort) - async def _wait_common(self, fd, filter): + async def _wait_common(self, fd: Union[int, _HasFileno], filter: int) -> None: if not isinstance(fd, int): fd = fd.fileno() flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT @@ -163,15 +180,15 @@ def abort(_): await self.wait_kevent(fd, filter, abort) @_public - async def wait_readable(self, fd): + async def wait_readable(self, fd: Union[int, _HasFileno]) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: Union[int, _HasFileno]) -> None: await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public - def notify_closing(self, fd): + def notify_closing(self, fd: Union[int, _HasFileno]) -> None: if not isinstance(fd, int): fd = fd.fileno() diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index c318d59629..0b5dd034b7 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -3,7 +3,7 @@ import enum import socket import sys -from typing import TYPE_CHECKING +from typing import Iterator, Tuple, TYPE_CHECKING, Union import attr @@ -27,7 +27,7 @@ IoControlCodes, ) -assert not TYPE_CHECKING or sys.platform == "win32" +# assert not TYPE_CHECKING or sys.platform == "win32" # There's a lot to be said about the overall design of a Windows event # loop. See @@ -691,15 +691,15 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) @_public - async def wait_readable(self, sock): + async def wait_readable(self, sock: socket.socket) -> None: await self._afd_poll(sock, "read_task") @_public - async def wait_writable(self, sock): + async def wait_writable(self, sock: socket.socket) -> None: await self._afd_poll(sock, "write_task") @_public - def notify_closing(self, handle): + def notify_closing(self, handle: socket.socket) -> None: handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: @@ -711,11 +711,14 @@ def notify_closing(self, handle): ################################################################ @_public - def register_with_iocp(self, handle): + def register_with_iocp(self, handle: socket.socket) -> None: self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) + # TODO: what else can lpOverlapped be? @_public - async def wait_overlapped(self, handle, lpOverlapped): + async def wait_overlapped( + self, handle: socket.socket, lpOverlapped: Union[int, object] + ) -> None: handle = _handle(handle) if isinstance(lpOverlapped, int): lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) @@ -798,7 +801,9 @@ async def _perform_overlapped(self, handle, submit_fn): return lpOverlapped @_public - async def write_overlapped(self, handle, data, file_offset=0): + async def write_overlapped( + self, handle: int, data: bytes, file_offset: int = 0 + ) -> int: with ffi.from_buffer(data) as cbuf: def submit_write(lpOverlapped): @@ -821,7 +826,9 @@ def submit_write(lpOverlapped): return lpOverlapped.InternalHigh @_public - async def readinto_overlapped(self, handle, buffer, file_offset=0): + async def readinto_overlapped( + self, handle: int, buffer: memoryview, file_offset: int = 0 + ) -> int: with ffi.from_buffer(buffer, require_writable=True) as cbuf: def submit_read(lpOverlapped): @@ -846,14 +853,16 @@ def submit_read(lpOverlapped): ################################################################ @_public - def current_iocp(self): + def current_iocp(self) -> int: return int(ffi.cast("uintptr_t", self._iocp)) @contextmanager @_public - def monitor_completion_key(self): + def monitor_completion_key( + self, + ) -> Iterator[Tuple[int, _core.UnboundedQueue[CompletionKeyEventInfo]]]: key = next(self._completion_key_counter) - queue = _core.UnboundedQueue() + queue = _core.UnboundedQueue[CompletionKeyEventInfo]() self._completion_key_queues[key] = queue try: yield (key, queue) diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index 36aacecd96..ee0dab4b04 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -2,16 +2,16 @@ import signal import sys from functools import wraps +from typing import Any, TypeVar, Callable import attr import async_generator from .._util import is_main_thread -if False: - from typing import Any, TypeVar, Callable - F = TypeVar("F", bound=Callable[..., Any]) +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. @@ -109,15 +109,17 @@ def currently_ki_protected(): return ki_protection_enabled(sys._getframe()) -def _ki_protection_decorator(enabled): - def decorator(fn): +def _ki_protection_decorator(enabled: bool) -> Callable[[_Fn], _Fn]: + def decorator(fn: _Fn) -> _Fn: # In some version of Python, isgeneratorfunction returns true for # coroutine functions, so we have to check for coroutine functions # first. + wrapper: _Fn + if inspect.iscoroutinefunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: # See the comment for regular generators below coro = fn(*args, **kwargs) coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled @@ -127,7 +129,7 @@ def wrapper(*args, **kwargs): elif inspect.isgeneratorfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: # It's important that we inject this directly into the # generator's locals, as opposed to setting it here and then # doing 'yield from'. The reason is, if a generator is @@ -144,7 +146,7 @@ def wrapper(*args, **kwargs): elif async_generator.isasyncgenfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: # See the comment for regular generators above agen = fn(*args, **kwargs) agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled @@ -154,7 +156,7 @@ def wrapper(*args, **kwargs): else: @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return fn(*args, **kwargs) @@ -163,10 +165,10 @@ def wrapper(*args, **kwargs): return decorator -enable_ki_protection = _ki_protection_decorator(True) # type: Callable[[F], F] +enable_ki_protection: Callable[[_Fn], _Fn] = _ki_protection_decorator(True) enable_ki_protection.__name__ = "enable_ki_protection" -disable_ki_protection = _ki_protection_decorator(False) # type: Callable[[F], F] +disable_ki_protection: Callable[[_Fn], _Fn] = _ki_protection_decorator(False) disable_ki_protection.__name__ = "disable_ki_protection" diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 1f64d4ce85..6dbf190f0d 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,25 +1,38 @@ # Runvar implementations +from typing import Generic, overload, TypeVar, Union + from . import _run from .._util import Final -class _RunVarToken: - _no_value = object() +_T = TypeVar("_T") + + +class _NoValue: + pass + + +class _RunVarToken(Generic[_T]): + _no_value = _NoValue() __slots__ = ("_var", "previous_value", "redeemed") @classmethod - def empty(cls, var): + def empty(cls, var: "RunVar[_T]") -> "_RunVarToken[_T]": return cls(var, value=cls._no_value) - def __init__(self, var, value): + def __init__(self, var: "RunVar", value: Union[_NoValue, _T]): self._var = var self.previous_value = value self.redeemed = False -class RunVar(metaclass=Final): +class _NoDefault: + pass + + +class RunVar(Generic[_T], metaclass=Final): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, @@ -28,30 +41,43 @@ class RunVar(metaclass=Final): """ - _NO_DEFAULT = object() + _NO_DEFAULT = _NoDefault() __slots__ = ("_name", "_default") - def __init__(self, name, default=_NO_DEFAULT): + @overload + def __init__(self, name: str) -> None: + ... + + @overload + def __init__(self, name: str, default: _T): + ... + + def __init__(self, name: str, default: object = _NO_DEFAULT) -> None: self._name = name self._default = default - def get(self, default=_NO_DEFAULT): + def get(self, default: Union[_NoDefault, _T] = _NO_DEFAULT) -> _T: """Gets the value of this :class:`RunVar` for the current run call.""" + + # Ignoring type hint return complaints since the underlying dict can't really + # be hinted per run local and other options including casting and checking + # instance types would result in runtime overhead. + try: - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] + return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: # contextvars consistency if default is not self._NO_DEFAULT: - return default + return default # type: ignore[return-value] if self._default is not self._NO_DEFAULT: - return self._default + return self._default # type: ignore[return-value] raise LookupError(self) from None - def set(self, value): + def set(self, value: _T) -> _RunVarToken[_T]: """Sets the value of this :class:`RunVar` for this current run call. @@ -59,16 +85,16 @@ def set(self, value): try: old_value = self.get() except LookupError: - token = _RunVarToken.empty(self) + token = _RunVarToken[_T].empty(self) else: - token = _RunVarToken(self, old_value) + token = _RunVarToken[_T](self, old_value) # This can't fail, because if we weren't in Trio context then the # get() above would have failed. _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value return token - def reset(self, token): + def reset(self, token: _RunVarToken) -> None: """Resets the value of this :class:`RunVar` to what it was previously specified by the token. diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index 0e95e4e5c5..50a60325fc 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -62,7 +62,7 @@ class MockClock(Clock, metaclass=Final): """ - def __init__(self, rate=0.0, autojump_threshold=inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. @@ -83,7 +83,7 @@ def __repr__(self): ) @property - def rate(self): + def rate(self) -> float: return self._rate @rate.setter @@ -98,7 +98,7 @@ def rate(self, new_rate): self._rate = float(new_rate) @property - def autojump_threshold(self): + def autojump_threshold(self) -> float: return self._autojump_threshold @autojump_threshold.setter diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 337b9e64ac..0f59184030 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,6 +1,7 @@ import sys import traceback import textwrap +from typing import Callable, Optional, overload, Set, Union import warnings import attr @@ -17,7 +18,24 @@ ################################################################ -def _filter_impl(handler, root_exc): +@overload +def _filter_impl( + handler: Callable[[Exception], Optional[Exception]], root_exc: Exception +) -> Optional[Exception]: + ... + + +@overload +def _filter_impl( + handler: Callable[[Exception], Optional[Exception]], root_exc: "MultiError" +) -> Optional[Union[Exception, "MultiError"]]: + ... + + +def _filter_impl( + handler: Callable[[Exception], Optional[Exception]], + root_exc: Union[Exception, "MultiError"], +) -> Optional[Union[Exception, "MultiError"]]: # We have a tree of MultiError's, like: # # MultiError([ @@ -76,7 +94,9 @@ def _filter_impl(handler, root_exc): # Filters a subtree, ignoring tracebacks, while keeping a record of # which MultiErrors were preserved unchanged - def filter_tree(exc, preserved): + def filter_tree( + exc: Union[Exception, "MultiError"], preserved: Set[int] + ) -> Optional[Union[Exception, "MultiError"]]: if isinstance(exc, MultiError): new_exceptions = [] changed = False @@ -111,7 +131,7 @@ def push_tb_down(tb, exc, preserved): else: exc.__traceback__ = new_tb - preserved = set() + preserved: Set[int] = set() new_root_exc = filter_tree(root_exc, preserved) push_tb_down(None, root_exc, preserved) # Delete the local functions to avoid a reference cycle (see @@ -213,7 +233,11 @@ def __repr__(self): return "".format(self) @classmethod - def filter(cls, handler, root_exc): + def filter( + cls, + handler: Callable[[Exception], Optional[Exception]], + root_exc: Union[Exception, "MultiError"], + ) -> Optional[Union[Exception, "MultiError"]]: """Apply the given ``handler`` to all the exceptions in ``root_exc``. Args: @@ -232,7 +256,9 @@ def filter(cls, handler, root_exc): return _filter_impl(handler, root_exc) @classmethod - def catch(cls, handler): + def catch( + cls, handler: Callable[[Exception], Optional[Exception]] + ) -> MultiErrorCatcher: """Return a context manager that catches and re-throws exceptions after running :meth:`filter` on them. diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 8b114b5230..237bd5677d 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -73,7 +73,9 @@ from itertools import count import attr from collections import OrderedDict +import typing +from ._run import Task from .. import _core from .._util import Final @@ -101,7 +103,7 @@ class ParkingLot(metaclass=Final): # {task: None}, we just want a deque where we can quickly delete random # items - _parked = attr.ib(factory=OrderedDict, init=False) + _parked: typing.OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False) def __len__(self): """Returns the number of parked tasks.""" @@ -116,7 +118,7 @@ def __bool__(self): # line (for false wakeups), then we could have it return a ticket that # abstracts the "place in line" concept. @_core.enable_ki_protection - async def park(self): + async def park(self) -> None: """Park the current task until woken by a call to :meth:`unpark` or :meth:`unpark_all`. @@ -137,7 +139,7 @@ def _pop_several(self, count): yield task @_core.enable_ki_protection - def unpark(self, *, count=1): + def unpark(self, *, count: int = 1) -> typing.List[Task]: """Unpark one or more tasks. This wakes up ``count`` tasks that are blocked in :meth:`park`. If @@ -158,7 +160,7 @@ def unpark_all(self): return self.unpark(count=len(self)) @_core.enable_ki_protection - def repark(self, new_lot, *, count=1): + def repark(self, new_lot: "ParkingLot", *, count: int = 1) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. This dequeues ``count`` tasks from one lot, and requeues them on diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 6d46809e95..9e95816ce1 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -19,19 +19,31 @@ from contextvars import copy_context from math import inf from time import perf_counter +from types import TracebackType from typing import ( + Any, + Awaitable, Callable, + Coroutine, Deque, Dict, + FrozenSet, Generator, + Iterator, List, Optional, + overload, + Sequence, Set, Tuple, + Type, + TypeVar, TYPE_CHECKING, Union, ) +from typing_extensions import Protocol + from sniffio import current_async_library_cvar import attr @@ -68,10 +80,12 @@ _NO_SEND = object() +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn): +def _public(fn: _Fn) -> _Fn: return fn @@ -211,6 +225,15 @@ def expire(self, now): return did_something +class Scope(Protocol): + deadline: float + shield: bool + + @property + def cancel_called(self) -> bool: + ... + + @attr.s(eq=False, slots=True) class CancelStatus: """Tracks the cancellation status for a contiguous extent @@ -244,7 +267,7 @@ class CancelStatus: # Our associated cancel scope. Can be any object with attributes # `deadline`, `shield`, and `cancel_called`, but in current usage # is always a CancelScope object. Must not be None. - _scope = attr.ib() + _scope: Scope = attr.ib() # True iff the tasks in self._tasks should receive cancellations # when they checkpoint. Always True when scope.cancel_called is True; @@ -254,29 +277,29 @@ class CancelStatus: # effectively cancelled due to the cancel scope two levels out # becoming cancelled, but then the cancel scope one level out # becomes shielded so we're not effectively cancelled anymore. - effectively_cancelled = attr.ib(default=False) + effectively_cancelled: bool = attr.ib(default=False) # The CancelStatus whose cancellations can propagate to us; we # become effectively cancelled when they do, unless scope.shield # is True. May be None (for the outermost CancelStatus in a call # to trio.run(), briefly during TaskStatus.started(), or during # recovery from mis-nesting of cancel scopes). - _parent = attr.ib(default=None, repr=False) + _parent: Optional["CancelStatus"] = attr.ib(default=None, repr=False) # All of the CancelStatuses that have this CancelStatus as their parent. - _children = attr.ib(factory=set, init=False, repr=False) + _children: Set["CancelStatus"] = attr.ib(factory=set, init=False, repr=False) # Tasks whose cancellation state is currently tied directly to # the cancellation state of this CancelStatus object. Don't modify # this directly; instead, use Task._activate_cancel_status(). # Invariant: all(task._cancel_status is self for task in self._tasks) - _tasks = attr.ib(factory=set, init=False, repr=False) + _tasks: Set["Task"] = attr.ib(factory=set, init=False, repr=False) # Set to True on still-active cancel statuses that are children # of a cancel status that's been closed. This is used to permit # recovery from mis-nested cancel scopes (well, at least enough # recovery to show a useful traceback). - abandoned_by_misnesting = attr.ib(default=False, init=False, repr=False) + abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) def __attrs_post_init__(self): if self._parent is not None: @@ -286,11 +309,11 @@ def __attrs_post_init__(self): # parent/children/tasks accessors are used by TaskStatus.started() @property - def parent(self): + def parent(self) -> Optional["CancelStatus"]: return self._parent @parent.setter - def parent(self, parent): + def parent(self, parent: Optional["CancelStatus"]) -> None: if self._parent is not None: self._parent._children.remove(self) self._parent = parent @@ -299,11 +322,11 @@ def parent(self, parent): self.recalculate() @property - def children(self): + def children(self) -> FrozenSet["CancelStatus"]: return frozenset(self._children) @property - def tasks(self): + def tasks(self) -> FrozenSet["Task"]: return frozenset(self._tasks) def encloses(self, other): @@ -345,7 +368,7 @@ def close(self): child.recalculate() @property - def parent_cancellation_is_visible_to_us(self): + def parent_cancellation_is_visible_to_us(self) -> bool: return ( self._parent is not None and not self._scope.shield @@ -459,7 +482,7 @@ class CancelScope(metaclass=Final): _shield = attr.ib(default=False, kw_only=True) @enable_ki_protection - def __enter__(self): + def __enter__(self) -> "CancelScope": task = _core.current_task() if self._has_been_entered: raise RuntimeError( @@ -543,7 +566,12 @@ def _close(self, exc): return exc @enable_ki_protection - def __exit__(self, etype, exc, tb): + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: # NB: NurseryManager calls _close() directly rather than __exit__(), # so __exit__() must be just _close() plus this logic for adapting # the exception-filtering result to the context manager API. @@ -593,7 +621,7 @@ def __repr__(self): @contextmanager @enable_ki_protection - def _might_change_registered_deadline(self): + def _might_change_registered_deadline(self) -> Iterator[None]: try: yield finally: @@ -617,7 +645,7 @@ def _might_change_registered_deadline(self): runner.force_guest_tick_asap() @property - def deadline(self): + def deadline(self) -> float: """Read-write, :class:`float`. An absolute time on the current run's clock at which this scope will automatically become cancelled. You can adjust the deadline by modifying this @@ -643,12 +671,12 @@ def deadline(self): return self._deadline @deadline.setter - def deadline(self, new_deadline): + def deadline(self, new_deadline: float) -> None: with self._might_change_registered_deadline(): self._deadline = float(new_deadline) @property - def shield(self): + def shield(self) -> bool: """Read-write, :class:`bool`, default :data:`False`. So long as this is set to :data:`True`, then the code inside this scope will not receive :exc:`~trio.Cancelled` exceptions from scopes @@ -674,7 +702,7 @@ def shield(self): # ignore for "decorated property not supported" @shield.setter # type: ignore[misc] @enable_ki_protection - def shield(self, new_value): + def shield(self, new_value: bool) -> None: if not isinstance(new_value, bool): raise TypeError("shield must be a bool") self._shield = new_value @@ -682,7 +710,7 @@ def shield(self, new_value): self._cancel_status.recalculate() @enable_ki_protection - def cancel(self): + def cancel(self) -> None: """Cancels this scope immediately. This method is idempotent, i.e., if the scope was already @@ -696,7 +724,7 @@ def cancel(self): self._cancel_status.recalculate() @property - def cancel_called(self): + def cancel_called(self) -> bool: """Readonly :class:`bool`. Records whether cancellation has been requested for this scope, either by an explicit call to :meth:`cancel` or by the deadline expiring. @@ -806,14 +834,19 @@ class NurseryManager: """ @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self) -> "Nursery": self._scope = CancelScope() self._scope.__enter__() self._nursery = Nursery._create(current_task(), self._scope) return self._nursery @enable_ki_protection - async def __aexit__(self, etype, exc, tb): + async def __aexit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: new_exc = await self._nursery._nested_child_finished(exc) # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context. @@ -875,7 +908,7 @@ class Nursery(metaclass=NoPublicConstructor): in response to some external event. """ - def __init__(self, parent_task, cancel_scope): + def __init__(self, parent_task: "Task", cancel_scope: CancelScope): self._parent_task = parent_task parent_task._child_nurseries.append(self) # the cancel status that children inherit - we take a snapshot, so it @@ -885,8 +918,8 @@ def __init__(self, parent_task, cancel_scope): # children. self.cancel_scope = cancel_scope assert self.cancel_scope._cancel_status is self._cancel_status - self._children = set() - self._pending_excs = [] + self._children: Set["Task"] = set() + self._pending_excs: List[Exception] = [] # The "nested child" is how this code refers to the contents of the # nursery's 'async with' block, which acts like a child Task in all # the ways we can make it. @@ -896,13 +929,13 @@ def __init__(self, parent_task, cancel_scope): self._closed = False @property - def child_tasks(self): + def child_tasks(self) -> FrozenSet["Task"]: """(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task` objects which are still running.""" return frozenset(self._children) @property - def parent_task(self): + def parent_task(self) -> "Task": "(`~trio.lowlevel.Task`): The Task that opened this nursery." return self._parent_task @@ -1077,13 +1110,13 @@ def __del__(self): @attr.s(eq=False, hash=False, repr=False) class Task(metaclass=NoPublicConstructor): - _parent_nursery = attr.ib() - coro = attr.ib() - _runner = attr.ib() - name = attr.ib() + _parent_nursery: Nursery = attr.ib() + coro: Coroutine = attr.ib() + _runner: "Runner" = attr.ib() + name: str = attr.ib() # PEP 567 contextvars context - context = attr.ib() - _counter = attr.ib(init=False, factory=itertools.count().__next__) + context: Context = attr.ib() + _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None @@ -1096,26 +1129,26 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn = attr.ib(default=None) - _next_send = attr.ib(default=None) - _abort_func = attr.ib(default=None) - custom_sleep_data = attr.ib(default=None) + _next_send_fn: Callable = attr.ib(default=None) + _next_send: Optional[Union[Outcome, Exception, MultiError]] = attr.ib(default=None) + _abort_func: Callable = attr.ib(default=None) + custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() - _child_nurseries = attr.ib(factory=list) - _eventual_parent_nursery = attr.ib(default=None) + _child_nurseries: List[Nursery] = attr.ib(factory=list) + _eventual_parent_nursery: Optional[Nursery] = attr.ib(default=None) # these are counts of how many cancel/schedule points this task has # executed, for assert{_no,}_checkpoints # XX maybe these should be exposed as part of a statistics() method? - _cancel_points = attr.ib(default=0) - _schedule_points = attr.ib(default=0) + _cancel_points: int = attr.ib(default=0) + _schedule_points: int = attr.ib(default=0) def __repr__(self): return "".format(self.name, id(self)) @property - def parent_nursery(self): + def parent_nursery(self) -> Nursery: """The nursery this task is inside (or None if this is the "init" task). @@ -1126,7 +1159,7 @@ def parent_nursery(self): return self._parent_nursery @property - def eventual_parent_nursery(self): + def eventual_parent_nursery(self) -> Optional[Nursery]: """The nursery this task will be inside after it calls ``task_status.started()``. @@ -1138,7 +1171,7 @@ def eventual_parent_nursery(self): return self._eventual_parent_nursery @property - def child_nurseries(self): + def child_nurseries(self) -> List[Nursery]: """The nurseries this task contains. This is a list, with outer nurseries before inner nurseries. @@ -1152,7 +1185,7 @@ def child_nurseries(self): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). - _cancel_status = attr.ib(default=None, repr=False) + _cancel_status: Optional[CancelStatus] = attr.ib(default=None, repr=False) def _activate_cancel_status(self, cancel_status): if self._cancel_status is not None: @@ -1339,7 +1372,7 @@ def close(self): self.ki_manager.close() @_public - def current_statistics(self): + def current_statistics(self) -> _RunStatistics: """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -1372,7 +1405,7 @@ def current_statistics(self): ) @_public - def current_time(self): + def current_time(self) -> float: """Returns the current time according to Trio's internal clock. Returns: @@ -1385,12 +1418,12 @@ def current_time(self): return self.clock.current_time() @_public - def current_clock(self): + def current_clock(self) -> Clock: """Returns the current :class:`~trio.abc.Clock`.""" return self.clock @_public - def current_root_task(self): + def current_root_task(self) -> Optional[Task]: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -1402,8 +1435,16 @@ def current_root_task(self): # Core task handling primitives ################ + @overload + def reschedule(self, task: Task) -> None: + ... + + @overload + def reschedule(self, task: Task, next_send: Outcome) -> None: + ... + @_public - def reschedule(self, task, next_send=_NO_SEND): + def reschedule(self, task: Task, next_send: object = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1436,7 +1477,15 @@ def reschedule(self, task, next_send=_NO_SEND): if "task_scheduled" in self.instruments: self.instruments.call("task_scheduled", task) - def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): + def spawn_impl( + self, + async_fn: Callable[..., Awaitable[object]], + args: Sequence[object], + nursery: Optional[Nursery], + name: Optional[Union[str, Callable]], + *, + system_task: bool = False, + ) -> Task: ###### # Make sure the nursery is in working order @@ -1467,7 +1516,7 @@ def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): name = repr(name) if system_task: - context = self.system_context.copy() + context = self.system_context.copy() # type: ignore[union-attr] else: context = copy_context() @@ -1547,7 +1596,12 @@ def task_exited(self, task, outcome): ################ @_public - def spawn_system_task(self, async_fn, *args, name=None): + def spawn_system_task( # type: ignore[misc] + self, + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: Optional[str] = None, + ) -> Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -1637,7 +1691,7 @@ async def init(self, async_fn, args): ################ @_public - def current_trio_token(self): + def current_trio_token(self) -> TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -1687,7 +1741,7 @@ def _deliver_ki_cb(self): waiting_for_idle: SortedDict = attr.ib(factory=SortedDict) @_public - async def wait_all_tasks_blocked(self, cushion=0.0): + async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 95cf46de9b..7d6f160434 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -2,6 +2,7 @@ import types import enum +from typing import Callable import attr import outcome @@ -17,7 +18,7 @@ # tracking machinery. Since our traps are public APIs, we make them real async # functions, and then this helper takes care of the actual yield: @types.coroutine -def _async_yield(obj): +def _async_yield(obj: object) -> object: return (yield obj) @@ -64,7 +65,9 @@ class WaitTaskRescheduled: abort_func = attr.ib() -async def wait_task_rescheduled(abort_func): +async def wait_task_rescheduled( + abort_func: Callable[[Callable[[], None]], Abort] +) -> object: """Put the current task to sleep, with cancellation support. This is the lowest-level API for blocking in Trio. Every time a diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index f877e42a0c..b57cbddf78 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -1,3 +1,5 @@ +from typing import Generic, List, TypeVar + import attr from .. import _core @@ -5,13 +7,16 @@ from .._util import Final +_T = TypeVar("_T") + + @attr.s(frozen=True) class _UnboundedQueueStats: qsize = attr.ib() tasks_waiting = attr.ib() -class UnboundedQueue(metaclass=Final): +class UnboundedQueue(Generic[_T], metaclass=Final): """An unbounded queue suitable for certain unusual forms of inter-task communication. @@ -47,9 +52,9 @@ class UnboundedQueue(metaclass=Final): thing="trio.lowlevel.UnboundedQueue", instead="trio.open_memory_channel(math.inf)", ) - def __init__(self): + def __init__(self) -> None: self._lot = _core.ParkingLot() - self._data = [] + self._data: List[_T] = [] # used to allow handoff from put to the first task in the lot self._can_get = False @@ -70,7 +75,7 @@ def empty(self): return not self._data @_core.enable_ki_protection - def put_nowait(self, obj): + def put_nowait(self, obj: _T) -> None: """Put an object into the queue, without blocking. This always succeeds, because the queue is unbounded. We don't provide diff --git a/trio/_core/tests/conftest.py b/trio/_core/tests/conftest.py index aca1f98a65..bdb9a9a04f 100644 --- a/trio/_core/tests/conftest.py +++ b/trio/_core/tests/conftest.py @@ -1,4 +1,6 @@ import pytest +import _pytest.python + import inspect # XX this should move into a global something @@ -6,12 +8,12 @@ @pytest.fixture -def mock_clock(): +def mock_clock() -> MockClock: return MockClock() @pytest.fixture -def autojump_clock(): +def autojump_clock() -> MockClock: return MockClock(autojump_threshold=0) @@ -20,6 +22,6 @@ def autojump_clock(): # guess it's useful with the class- and file-level marking machinery (where # the raw @trio_test decorator isn't enough). @pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): +def pytest_pyfunc_call(pyfuncitem: _pytest.python.Function) -> None: # type: ignore[misc] if inspect.iscoroutinefunction(pyfuncitem.obj): pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/trio/_core/tests/test_asyncgen.py b/trio/_core/tests/test_asyncgen.py index 1f886e11ab..3635b4bdbc 100644 --- a/trio/_core/tests/test_asyncgen.py +++ b/trio/_core/tests/test_asyncgen.py @@ -7,6 +7,8 @@ from ... import _core from .tutil import gc_collect_harder, buggy_pypy_asyncgens +import _pytest.capture + def test_asyncgen_basics(): collected = [] @@ -105,7 +107,7 @@ async def agen(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_firstiter_after_closing(): +def test_firstiter_after_closing() -> None: saved = [] record = [] @@ -132,7 +134,7 @@ async def async_main(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_interdependent_asyncgen_cleanup_order(): +def test_interdependent_asyncgen_cleanup_order() -> None: saved = [] record = [] @@ -167,7 +169,7 @@ async def async_main(): assert record == [] _core.run(async_main) - assert record == ["innermost"] + list(range(100)) + assert record == ["innermost", *range(100)] def test_last_minute_gc_edge_case(): @@ -252,8 +254,11 @@ def abort_fn(_): nursery.cancel_scope.deadline = _core.current_time() +# can switch to annotating from pytest directly as of 6.2.0 @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -async def test_fallback_when_no_hook_claims_it(capsys): +async def test_fallback_when_no_hook_claims_it( + capsys: _pytest.capture.CaptureFixture[str], +) -> None: async def well_behaved(): yield 42 @@ -281,7 +286,7 @@ async def awaits_after_yield(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_delegation_to_existing_hooks(): +def test_delegation_to_existing_hooks() -> None: record = [] def my_firstiter(agen): diff --git a/trio/_core/tests/test_guest_mode.py b/trio/_core/tests/test_guest_mode.py index c9701e7cdd..ad0322c794 100644 --- a/trio/_core/tests/test_guest_mode.py +++ b/trio/_core/tests/test_guest_mode.py @@ -506,7 +506,7 @@ async def trio_main(in_host): sys.implementation.name == "pypy" and sys.version_info >= (3, 7), reason="async generator issue under investigation", ) -def test_guest_mode_asyncgens(): +def test_guest_mode_asyncgens() -> None: import sniffio record = set() diff --git a/trio/_core/tests/test_instrumentation.py b/trio/_core/tests/test_instrumentation.py index 57d3461d3b..837a3e53b7 100644 --- a/trio/_core/tests/test_instrumentation.py +++ b/trio/_core/tests/test_instrumentation.py @@ -238,7 +238,7 @@ def task_exited(self, task): assert False # pragma: no cover @property - def after_run(self): + def after_run(self) -> None: raise ValueError("oops") async def main(): diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index a35cb09e7a..0cdd9c91a6 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -5,7 +5,7 @@ import random import errno from contextlib import suppress -from typing import Callable +from typing import Awaitable, Callable, Iterator, Tuple from ... import _core from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints @@ -30,8 +30,11 @@ def drain_socket(sock): pass +_SocketPair = Tuple[stdlib_socket.socket, stdlib_socket.socket] + + @pytest.fixture -def socketpair(): +def socketpair() -> Iterator[_SocketPair]: pair = stdlib_socket.socketpair() for sock in pair: sock.setblocking(False) @@ -49,10 +52,15 @@ def fileno_wrapper(fileobj): return fileno_wrapper +_WaitReadable = Callable[[stdlib_socket.socket], Awaitable[None]] +_WaitWritable = Callable[[stdlib_socket.socket], Awaitable[None]] +_NotifyClosing = Callable[[stdlib_socket.socket], None] + wait_readable_options = [trio.lowlevel.wait_readable] wait_writable_options = [trio.lowlevel.wait_writable] notify_closing_options = [trio.lowlevel.notify_closing] + for options_list in [ wait_readable_options, wait_writable_options, @@ -85,7 +93,9 @@ def get__name__(fn: Callable) -> str: # momentarily and then immediately resuming. @read_socket_test @write_socket_test -async def test_wait_basic(socketpair, wait_readable, wait_writable): +async def test_wait_basic( + socketpair: _SocketPair, wait_readable: _WaitReadable, wait_writable: _WaitWritable +) -> None: a, b = socketpair # They start out writable() @@ -151,7 +161,9 @@ async def block_on_write(): @read_socket_test -async def test_double_read(socketpair, wait_readable): +async def test_double_read( + socketpair: _SocketPair, wait_readable: _WaitWritable +) -> None: a, b = socketpair # You can't have two tasks trying to read from a socket at the same time @@ -164,7 +176,9 @@ async def test_double_read(socketpair, wait_readable): @write_socket_test -async def test_double_write(socketpair, wait_writable): +async def test_double_write( + socketpair: _SocketPair, wait_writable: _WaitWritable +) -> None: a, b = socketpair # You can't have two tasks trying to write to a socket at the same time @@ -181,8 +195,11 @@ async def test_double_write(socketpair, wait_writable): @write_socket_test @notify_closing_test async def test_interrupted_by_close( - socketpair, wait_readable, wait_writable, notify_closing -): + socketpair: _SocketPair, + wait_readable: _WaitReadable, + wait_writable: _WaitWritable, + notify_closing: _NotifyClosing, +) -> None: a, b = socketpair async def reader(): @@ -204,7 +221,9 @@ async def writer(): @read_socket_test @write_socket_test -async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable): +async def test_socket_simultaneous_read_write( + socketpair: _SocketPair, wait_readable: _WaitReadable, wait_writable: _WaitWritable +) -> None: record = [] async def r_task(sock): @@ -232,7 +251,9 @@ async def w_task(sock): @read_socket_test @write_socket_test -async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable): +async def test_socket_actual_streaming( + socketpair: _SocketPair, wait_readable: _WaitReadable, wait_writable: _WaitWritable +) -> None: a, b = socketpair # Use a small send buffer on one of the sockets to increase the chance of @@ -281,7 +302,7 @@ async def receiver(sock, key): assert results["send_b"] == results["recv_a"] -async def test_notify_closing_on_invalid_object(): +async def test_notify_closing_on_invalid_object() -> None: # It should either be a no-op (generally on Unix, where we don't know # which fds are valid), or an OSError (on Windows, where we currently only # support sockets, so we have to do some validation to figure out whether @@ -297,7 +318,7 @@ async def test_notify_closing_on_invalid_object(): assert got_oserror or got_no_error -async def test_wait_on_invalid_object(): +async def test_wait_on_invalid_object() -> None: # We definitely want to raise an error everywhere if you pass in an # invalid fd to wait_* for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]: @@ -309,7 +330,7 @@ async def test_wait_on_invalid_object(): await wait(fileno) -async def test_io_manager_statistics(): +async def test_io_manager_statistics() -> None: def check(*, expected_readers, expected_writers): statistics = _core.current_statistics() print(statistics) @@ -357,7 +378,7 @@ def check(*, expected_readers, expected_writers): check(expected_readers=1, expected_writers=0) -async def test_can_survive_unnotified_close(): +async def test_can_survive_unnotified_close() -> None: # An "unnotified" close is when the user closes an fd/socket/handle # directly, without calling notify_closing first. This should never happen # -- users should call notify_closing before closing things. But, just in diff --git a/trio/_core/tests/test_ki.py b/trio/_core/tests/test_ki.py index 0e4db4af49..b803c16536 100644 --- a/trio/_core/tests/test_ki.py +++ b/trio/_core/tests/test_ki.py @@ -6,6 +6,7 @@ import threading import contextlib import time +from typing import Any, AsyncIterator, Iterator from async_generator import ( async_generator, @@ -30,13 +31,13 @@ def test_ki_self(): ki_self() -async def test_ki_enabled(): +async def test_ki_enabled() -> None: # Regular tasks aren't KI-protected assert not _core.currently_ki_protected() # Low-level call-soon callbacks are KI-protected token = _core.current_trio_token() - record = [] + record: Any = [] def check(): record.append(_core.currently_ki_protected()) @@ -46,23 +47,23 @@ def check(): assert record == [True] @_core.enable_ki_protection - def protected(): + def protected() -> None: assert _core.currently_ki_protected() unprotected() @_core.disable_ki_protection - def unprotected(): + def unprotected() -> None: assert not _core.currently_ki_protected() protected() @_core.enable_ki_protection - async def aprotected(): + async def aprotected() -> None: assert _core.currently_ki_protected() await aunprotected() @_core.disable_ki_protection - async def aunprotected(): + async def aunprotected() -> None: assert not _core.currently_ki_protected() await aprotected() @@ -74,7 +75,7 @@ async def aunprotected(): nursery.start_soon(aunprotected) @_core.enable_ki_protection - def gen_protected(): + def gen_protected() -> Iterator[None]: assert _core.currently_ki_protected() yield @@ -82,7 +83,7 @@ def gen_protected(): pass @_core.disable_ki_protection - def gen_unprotected(): + def gen_unprotected() -> Iterator[None]: assert not _core.currently_ki_protected() yield @@ -99,16 +100,16 @@ def gen_unprotected(): # .throw(), not the actual caller. So child() here would have a caller deep in # the guts of the run loop, and always be protected, even when it shouldn't # have been. (Solution: we don't use .throw() anymore.) -async def test_ki_enabled_after_yield_briefly(): +async def test_ki_enabled_after_yield_briefly() -> None: @_core.enable_ki_protection - async def protected(): + async def protected() -> None: await child(True) @_core.disable_ki_protection - async def unprotected(): + async def unprotected() -> None: await child(False) - async def child(expected): + async def child(expected: bool) -> None: import traceback traceback.print_stack() @@ -123,10 +124,10 @@ async def child(expected): # This also used to be broken due to # https://bugs.python.org/issue29590 -async def test_generator_based_context_manager_throw(): +async def test_generator_based_context_manager_throw() -> None: @contextlib.contextmanager @_core.enable_ki_protection - def protected_manager(): + def protected_manager() -> Iterator[None]: assert _core.currently_ki_protected() try: yield @@ -142,10 +143,10 @@ def protected_manager(): raise KeyError -async def test_agen_protection(): +async def test_agen_protection() -> None: @_core.enable_ki_protection @async_generator - async def agen_protected1(): + async def agen_protected1(): # type: ignore[misc] assert _core.currently_ki_protected() try: await yield_() @@ -154,7 +155,7 @@ async def agen_protected1(): @_core.disable_ki_protection @async_generator - async def agen_unprotected1(): + async def agen_unprotected1(): # type: ignore[misc] assert not _core.currently_ki_protected() try: await yield_() @@ -164,7 +165,7 @@ async def agen_unprotected1(): # Swap the order of the decorators: @async_generator @_core.enable_ki_protection - async def agen_protected2(): + async def agen_protected2(): # type: ignore[misc] assert _core.currently_ki_protected() try: await yield_() @@ -173,7 +174,7 @@ async def agen_protected2(): @async_generator @_core.disable_ki_protection - async def agen_unprotected2(): + async def agen_unprotected2(): # type: ignore[misc] assert not _core.currently_ki_protected() try: await yield_() @@ -182,7 +183,7 @@ async def agen_unprotected2(): # Native async generators @_core.enable_ki_protection - async def agen_protected3(): + async def agen_protected3() -> AsyncIterator[None]: assert _core.currently_ki_protected() try: yield @@ -190,7 +191,7 @@ async def agen_protected3(): assert _core.currently_ki_protected() @_core.disable_ki_protection - async def agen_unprotected3(): + async def agen_unprotected3() -> AsyncIterator[None]: assert not _core.currently_ki_protected() try: yield @@ -226,7 +227,7 @@ def test_ki_disabled_out_of_context(): assert _core.currently_ki_protected() -def test_ki_disabled_in_del(): +def test_ki_disabled_in_del() -> None: def nestedfunction(): return _core.currently_ki_protected() @@ -235,7 +236,7 @@ def __del__(): assert nestedfunction() @_core.disable_ki_protection - def outerfunction(): + def outerfunction() -> None: assert not _core.currently_ki_protected() assert not nestedfunction() __del__() @@ -245,7 +246,7 @@ def outerfunction(): assert nestedfunction() -def test_ki_protection_works(): +def test_ki_protection_works() -> None: async def sleeper(name, record): try: while True: @@ -276,7 +277,7 @@ async def raiser(name, record): # simulated control-C during raiser, which is *unprotected* print("check 1") - record = set() + record: Any = set() async def check_unprotected_kill(): async with _core.open_nursery() as nursery: @@ -342,19 +343,19 @@ async def main(): print("check 5") @_core.enable_ki_protection - async def main(): + async def main_a() -> None: assert _core.currently_ki_protected() ki_self() with pytest.raises(KeyboardInterrupt): await _core.checkpoint_if_cancelled() - _core.run(main) + _core.run(main_a) # KI arrives while main task is not abortable, b/c already scheduled print("check 6") @_core.enable_ki_protection - async def main(): + async def main_b() -> None: assert _core.currently_ki_protected() ki_self() await _core.cancel_shielded_checkpoint() @@ -363,13 +364,13 @@ async def main(): with pytest.raises(KeyboardInterrupt): await _core.checkpoint() - _core.run(main) + _core.run(main_b) # KI arrives while main task is not abortable, b/c refuses to be aborted print("check 7") @_core.enable_ki_protection - async def main(): + async def main_c() -> None: assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -382,13 +383,13 @@ def abort(_): with pytest.raises(KeyboardInterrupt): await _core.checkpoint() - _core.run(main) + _core.run(main_c) # KI delivered via slow abort print("check 8") @_core.enable_ki_protection - async def main(): + async def main_d() -> None: assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -402,7 +403,7 @@ def abort(raise_cancel): assert await _core.wait_task_rescheduled(abort) await _core.checkpoint() - _core.run(main) + _core.run(main_d) # KI arrives just before main task exits, so the run_sync_soon machinery # is still functioning and will accept the callback to deliver the KI, but @@ -411,18 +412,18 @@ def abort(raise_cancel): print("check 9") @_core.enable_ki_protection - async def main(): + async def main_e() -> None: ki_self() with pytest.raises(KeyboardInterrupt): - _core.run(main) + _core.run(main_e) print("check 10") # KI in unprotected code, with # restrict_keyboard_interrupt_to_checkpoints=True record = [] - async def main(): + async def main_f(): # We're not KI protected... assert not _core.currently_ki_protected() ki_self() @@ -432,13 +433,13 @@ async def main(): with pytest.raises(KeyboardInterrupt): await sleep(10) - _core.run(main, restrict_keyboard_interrupt_to_checkpoints=True) + _core.run(main_f, restrict_keyboard_interrupt_to_checkpoints=True) assert record == ["ok"] record = [] # Exact same code raises KI early if we leave off the argument, doesn't # even reach the record.append call: with pytest.raises(KeyboardInterrupt): - _core.run(main) + _core.run(main_f) assert record == [] # KI arrives while main task is inside a cancelled cancellation scope @@ -446,7 +447,7 @@ async def main(): print("check 11") @_core.enable_ki_protection - async def main(): + async def main_g() -> None: assert _core.currently_ki_protected() with _core.CancelScope() as cancel_scope: cancel_scope.cancel() @@ -458,7 +459,7 @@ async def main(): with pytest.raises(_core.Cancelled): await _core.checkpoint() - _core.run(main) + _core.run(main_g) def test_ki_is_good_neighbor(): @@ -481,31 +482,31 @@ async def main(): # Regression test for #461 -def test_ki_with_broken_threads(): +def test_ki_with_broken_threads() -> None: thread = threading.main_thread() # scary! - original = threading._active[thread.ident] + original = threading._active[thread.ident] # type: ignore[attr-defined] # put this in a try finally so we don't have a chance of cascading a # breakage down to everything else try: - del threading._active[thread.ident] + del threading._active[thread.ident] # type: ignore[attr-defined] @_core.enable_ki_protection - async def inner(): + async def inner() -> None: assert signal.getsignal(signal.SIGINT) != signal.default_int_handler _core.run(inner) finally: - threading._active[thread.ident] = original + threading._active[thread.ident] = original # type: ignore[attr-defined] # For details on why this test is non-trivial, see: # https://github.com/python-trio/trio/issues/42 # https://github.com/python-trio/trio/issues/109 @slow -def test_ki_wakes_us_up(): +def test_ki_wakes_us_up() -> None: assert is_main_thread() # This test is flaky due to a race condition on Windows; see: diff --git a/trio/_core/tests/test_mock_clock.py b/trio/_core/tests/test_mock_clock.py index bea9509686..35944760d4 100644 --- a/trio/_core/tests/test_mock_clock.py +++ b/trio/_core/tests/test_mock_clock.py @@ -147,7 +147,9 @@ async def waiter(): @slow -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clock): +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero( + mock_clock: MockClock, +) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with a non-zero cushion. diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/tests/test_multierror.py index 7976c46ae2..c733c0ddd1 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/tests/test_multierror.py @@ -682,20 +682,20 @@ def test_custom_excepthook(): @slow @need_ipython -def test_ipython_exc_handler(): +def test_ipython_exc_handler() -> None: completed = run_script("simple_excepthook.py", use_ipython=True) check_simple_excepthook(completed) @slow @need_ipython -def test_ipython_imported_but_unused(): +def test_ipython_imported_but_unused() -> None: completed = run_script("simple_excepthook_IPython.py") check_simple_excepthook(completed) @slow -def test_partial_imported_but_unused(): +def test_partial_imported_but_unused() -> None: # Check that a functools.partial as sys.excepthook doesn't cause an exception when # importing trio. This was a problem due to the lack of a .__name__ attribute and # happens when inside a pytest-qt test case for example. @@ -705,7 +705,7 @@ def test_partial_imported_but_unused(): @slow @need_ipython -def test_ipython_custom_exc_handler(): +def test_ipython_custom_exc_handler() -> None: # Check we get a nice warning (but only one!) if the user is using IPython # and already has some other set_custom_exc handler installed. completed = run_script("ipython_custom_exc.py", use_ipython=True) @@ -731,7 +731,7 @@ def test_ipython_custom_exc_handler(): not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(), reason="need Ubuntu with python3-apport installed", ) -def test_apport_excepthook_monkeypatch_interaction(): +def test_apport_excepthook_monkeypatch_interaction() -> None: completed = run_script("apport_excepthook.py") stdout = completed.stdout.decode("utf-8") diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 4c4e12b5df..f5b731c17b 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -10,6 +10,7 @@ from math import inf from textwrap import dedent import gc +from typing import Iterator, TypeVar import attr import outcome @@ -35,6 +36,9 @@ ) +_T = TypeVar("_T") + + # slightly different from _timeouts.sleep_forever because it returns the value # its rescheduled with, which is really only useful for tests of # rescheduling... @@ -792,7 +796,7 @@ async def task3(task_status): @slow -async def test_timekeeping(): +async def test_timekeeping() -> None: # probably a good idea to use a real clock for *one* test anyway... TARGET = 1.0 # give it a few tries in case of random CI server flakiness @@ -1334,7 +1338,7 @@ def cb(i): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") -def test_TrioToken_run_sync_soon_late_crash(): +def test_TrioToken_run_sync_soon_late_crash() -> None: # Crash after system nursery is closed -- easiest way to do that is # from an async generator finalizer. record = [] @@ -2027,7 +2031,7 @@ async def test_Task_custom_sleep_data(): @types.coroutine -def async_yield(value): +def async_yield(value: _T) -> Iterator[_T]: yield value @@ -2201,7 +2205,7 @@ async def test_cancel_scope_deadline_duplicates(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage(): +async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/issues/1770 gc.collect() @@ -2231,7 +2235,7 @@ async def do_a_cancel(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_nursery_cancel_doesnt_create_cyclic_garbage(): +async def test_nursery_cancel_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/issues/1770#issuecomment-730229423 gc.collect() diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/tests/test_thread_cache.py index 0f6e0a0715..18bfd2a8c9 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/tests/test_thread_cache.py @@ -1,9 +1,12 @@ import pytest +import _pytest.monkeypatch import threading from queue import Queue import time import sys +from outcome import Outcome + from .tutil import slow, gc_collect_harder from .. import _thread_cache from .._thread_cache import start_thread_soon, ThreadCache @@ -49,7 +52,7 @@ def deliver(outcome): @slow -def test_spawning_new_thread_from_deliver_reuses_starting_thread(): +def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # We know that no-one else is using the thread cache, so if we keep # submitting new jobs the instant the previous one is finished, we should # keep getting the same thread over and over. This tests both that the @@ -58,7 +61,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread(): # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q = Queue() + q = Queue[Outcome]() COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) @@ -83,14 +86,15 @@ def deliver(n, _): assert len(seen_threads) == 1 +# can switch to annotating from pytest directly as of 6.2.0 @slow -def test_idle_threads_exit(monkeypatch): +def test_idle_threads_exit(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None: # Temporarily set the idle timeout to something tiny, to speed up the # test. (But non-zero, so that the worker loop will at least yield the # CPU.) monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q = Queue() + q = Queue[threading.Thread]() start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread diff --git a/trio/_core/tests/test_windows.py b/trio/_core/tests/test_windows.py index e6bab82204..930408dba3 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/tests/test_windows.py @@ -1,6 +1,9 @@ +from io import BufferedWriter import os import tempfile from contextlib import contextmanager +import sys +from typing import Iterator, Tuple import pytest @@ -22,196 +25,196 @@ ) -# The undocumented API that this is testing should be changed to stop using -# UnboundedQueue (or just removed until we have time to redo it), but until -# then we filter out the warning. -@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") -async def test_completion_key_listen(): - async def post(key): - iocp = ffi.cast("HANDLE", _core.current_iocp()) - for i in range(10): - print("post", i) - if i % 3 == 0: - await _core.checkpoint() - success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL) - assert success - - with _core.monitor_completion_key() as (key, queue): - async with _core.open_nursery() as nursery: - nursery.start_soon(post, key) - i = 0 - print("loop") - async for batch in queue: # pragma: no branch - print("got some", batch) - for info in batch: - assert info.lpOverlapped == 0 - assert info.dwNumberOfBytesTransferred == i - i += 1 - if i == 10: - break - print("end loop") - - -async def test_readinto_overlapped(): - data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024 - buffer = bytearray(len(data)) - - with tempfile.TemporaryDirectory() as tdir: - tfile = os.path.join(tdir, "numbers.txt") - with open(tfile, "wb") as fp: - fp.write(data) - fp.flush() - - rawname = tfile.encode("utf-16le") + b"\0\0" - rawname_buf = ffi.from_buffer(rawname) - handle = kernel32.CreateFileW( - ffi.cast("LPCWSTR", rawname_buf), - FileFlags.GENERIC_READ, - FileFlags.FILE_SHARE_READ, - ffi.NULL, # no security attributes - FileFlags.OPEN_EXISTING, - FileFlags.FILE_FLAG_OVERLAPPED, - ffi.NULL, # no template file - ) - if handle == INVALID_HANDLE_VALUE: # pragma: no cover - raise_winerror() - - try: - with memoryview(buffer) as buffer_view: - - async def read_region(start, end): - await _core.readinto_overlapped( - handle, buffer_view[start:end], start - ) +# mypy recognizes this. an assert would break the pytest skipif +if sys.platform == "win32": + # The undocumented API that this is testing should be changed to stop using + # UnboundedQueue (or just removed until we have time to redo it), but until + # then we filter out the warning. + @pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") + async def test_completion_key_listen() -> None: + async def post(key): + iocp = ffi.cast("HANDLE", _core.current_iocp()) + for i in range(10): + print("post", i) + if i % 3 == 0: + await _core.checkpoint() + success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL) + assert success + + with _core.monitor_completion_key() as (key, queue): + async with _core.open_nursery() as nursery: + nursery.start_soon(post, key) + i = 0 + print("loop") + async for batch in queue: # pragma: no branch + print("got some", batch) + for info in batch: + assert info.lpOverlapped == 0 + assert info.dwNumberOfBytesTransferred == i + i += 1 + if i == 10: + break + print("end loop") + + async def test_readinto_overlapped(): + data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024 + buffer = bytearray(len(data)) + + with tempfile.TemporaryDirectory() as tdir: + tfile = os.path.join(tdir, "numbers.txt") + with open(tfile, "wb") as fp: + fp.write(data) + fp.flush() + + rawname = tfile.encode("utf-16le") + b"\0\0" + rawname_buf = ffi.from_buffer(rawname) + handle = kernel32.CreateFileW( + ffi.cast("LPCWSTR", rawname_buf), + FileFlags.GENERIC_READ, + FileFlags.FILE_SHARE_READ, + ffi.NULL, # no security attributes + FileFlags.OPEN_EXISTING, + FileFlags.FILE_FLAG_OVERLAPPED, + ffi.NULL, # no template file + ) + if handle == INVALID_HANDLE_VALUE: # pragma: no cover + raise_winerror() - _core.register_with_iocp(handle) - async with _core.open_nursery() as nursery: - for start in range(0, 4096, 512): - nursery.start_soon(read_region, start, start + 512) - - assert buffer == data - - with pytest.raises(BufferError): - await _core.readinto_overlapped(handle, b"immutable") - finally: - kernel32.CloseHandle(handle) + try: + with memoryview(buffer) as buffer_view: + async def read_region(start, end): + await _core.readinto_overlapped( + handle, buffer_view[start:end], start + ) -@contextmanager -def pipe_with_overlapped_read(): - from asyncio.windows_utils import pipe - import msvcrt + _core.register_with_iocp(handle) + async with _core.open_nursery() as nursery: + for start in range(0, 4096, 512): + nursery.start_soon(read_region, start, start + 512) - read_handle, write_handle = pipe(overlapped=(True, False)) - try: - write_fd = msvcrt.open_osfhandle(write_handle, 0) - yield os.fdopen(write_fd, "wb", closefd=False), read_handle - finally: - kernel32.CloseHandle(ffi.cast("HANDLE", read_handle)) - kernel32.CloseHandle(ffi.cast("HANDLE", write_handle)) + assert buffer == data + with pytest.raises(BufferError): + await _core.readinto_overlapped(handle, b"immutable") + finally: + kernel32.CloseHandle(handle) -def test_forgot_to_register_with_iocp(): - with pipe_with_overlapped_read() as (write_fp, read_handle): - with write_fp: - write_fp.write(b"test\n") + @contextmanager + def pipe_with_overlapped_read() -> Iterator[Tuple[BufferedWriter, int]]: + from asyncio.windows_utils import pipe + import msvcrt - left_run_yet = False + read_handle, write_handle = pipe(overlapped=(True, False)) + try: + write_fd = msvcrt.open_osfhandle(write_handle, 0) + yield os.fdopen(write_fd, "wb", closefd=False), read_handle + finally: + kernel32.CloseHandle(ffi.cast("HANDLE", read_handle)) + kernel32.CloseHandle(ffi.cast("HANDLE", write_handle)) - async def main(): - target = bytearray(1) - try: - async with _core.open_nursery() as nursery: - nursery.start_soon( - _core.readinto_overlapped, read_handle, target, name="xyz" - ) - await wait_all_tasks_blocked() - nursery.cancel_scope.cancel() - finally: - # Run loop is exited without unwinding running tasks, so - # we don't get here until the main() coroutine is GC'ed - assert left_run_yet - - with pytest.raises(_core.TrioInternalError) as exc_info: - _core.run(main) - left_run_yet = True - assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value) - assert "forget to call register_with_iocp()?" in str(exc_info.value) - - # Make sure the Nursery.__del__ assertion about dangling children - # gets put with the correct test - del exc_info - gc_collect_harder() - - -@slow -async def test_too_late_to_cancel(): - import time - - with pipe_with_overlapped_read() as (write_fp, read_handle): - _core.register_with_iocp(read_handle) - target = bytearray(6) - async with _core.open_nursery() as nursery: - # Start an async read in the background - nursery.start_soon(_core.readinto_overlapped, read_handle, target) - await wait_all_tasks_blocked() - - # Synchronous write to the other end of the pipe + def test_forgot_to_register_with_iocp(): + with pipe_with_overlapped_read() as (write_fp, read_handle): with write_fp: - write_fp.write(b"test1\ntest2\n") - - # Note: not trio.sleep! We're making sure the OS level - # ReadFile completes, before Trio has a chance to execute - # another checkpoint and notice it completed. - time.sleep(1) - nursery.cancel_scope.cancel() - assert target[:6] == b"test1\n" - - # Do another I/O to make sure we've actually processed the - # fallback completion that was posted when CancelIoEx failed. - assert await _core.readinto_overlapped(read_handle, target) == 6 - assert target[:6] == b"test2\n" - - -def test_lsp_that_hooks_select_gives_good_error(monkeypatch): - from .._windows_cffi import WSAIoctls, _handle - from .. import _io_windows - - def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): - if hasattr(sock, "fileno"): # pragma: no branch - sock = sock.fileno() - if which == WSAIoctls.SIO_BSP_HANDLE_SELECT: - return _handle(sock + 1) - else: - return _handle(sock) - - monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) - with pytest.raises( - RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ" - ): - _core.run(sleep, 0) - - -def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch): - # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns - # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns - # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to - # make sure we get an error rather than an infinite loop. - - from .._windows_cffi import WSAIoctls, _handle - from .. import _io_windows - - def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): - if hasattr(sock, "fileno"): # pragma: no branch - sock = sock.fileno() - if which == WSAIoctls.SIO_BASE_HANDLE: - raise OSError("nope") - else: - return _handle(sock) - - monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) - with pytest.raises( - RuntimeError, - match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff", - ): - _core.run(sleep, 0) + write_fp.write(b"test\n") + + left_run_yet = False + + async def main(): + target = bytearray(1) + try: + async with _core.open_nursery() as nursery: + nursery.start_soon( + _core.readinto_overlapped, read_handle, target, name="xyz" + ) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + finally: + # Run loop is exited without unwinding running tasks, so + # we don't get here until the main() coroutine is GC'ed + assert left_run_yet + + with pytest.raises(_core.TrioInternalError) as exc_info: + _core.run(main) + left_run_yet = True + assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value) + assert "forget to call register_with_iocp()?" in str(exc_info.value) + + # Make sure the Nursery.__del__ assertion about dangling children + # gets put with the correct test + del exc_info + gc_collect_harder() + + @slow + async def test_too_late_to_cancel() -> None: + import time + + with pipe_with_overlapped_read() as (write_fp, read_handle): + _core.register_with_iocp(read_handle) + target = bytearray(6) + async with _core.open_nursery() as nursery: + # Start an async read in the background + nursery.start_soon(_core.readinto_overlapped, read_handle, target) + await wait_all_tasks_blocked() + + # Synchronous write to the other end of the pipe + with write_fp: + write_fp.write(b"test1\ntest2\n") + + # Note: not trio.sleep! We're making sure the OS level + # ReadFile completes, before Trio has a chance to execute + # another checkpoint and notice it completed. + time.sleep(1) + nursery.cancel_scope.cancel() + assert target[:6] == b"test1\n" + + # Do another I/O to make sure we've actually processed the + # fallback completion that was posted when CancelIoEx failed. + assert await _core.readinto_overlapped(read_handle, target) == 6 + assert target[:6] == b"test2\n" + + def test_lsp_that_hooks_select_gives_good_error(monkeypatch): + from .._windows_cffi import WSAIoctls, _handle + from .. import _io_windows + + def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + if hasattr(sock, "fileno"): # pragma: no branch + sock = sock.fileno() + if which == WSAIoctls.SIO_BSP_HANDLE_SELECT: + return _handle(sock + 1) + else: + return _handle(sock) + + monkeypatch.setattr( + _io_windows, "_get_underlying_socket", patched_get_underlying + ) + with pytest.raises( + RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ" + ): + _core.run(sleep, 0) + + def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch): + # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns + # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns + # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to + # make sure we get an error rather than an infinite loop. + + from .._windows_cffi import WSAIoctls, _handle + from .. import _io_windows + + def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + if hasattr(sock, "fileno"): # pragma: no branch + sock = sock.fileno() + if which == WSAIoctls.SIO_BASE_HANDLE: + raise OSError("nope") + else: + return _handle(sock) + + monkeypatch.setattr( + _io_windows, "_get_underlying_socket", patched_get_underlying + ) + with pytest.raises( + RuntimeError, + match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff", + ): + _core.run(sleep, 0) diff --git a/trio/_core/tests/tutil.py b/trio/_core/tests/tutil.py index 00669e883e..498c2cf6b0 100644 --- a/trio/_core/tests/tutil.py +++ b/trio/_core/tests/tutil.py @@ -2,7 +2,7 @@ import socket as stdlib_socket import os import sys -from typing import TYPE_CHECKING +from typing import Iterator, TYPE_CHECKING import pytest import warnings @@ -69,7 +69,7 @@ def gc_collect_harder(): # manager should be used anywhere this happens to hide those messages, because # when expected they're clutter. @contextmanager -def ignore_coroutine_never_awaited_warnings(): +def ignore_coroutine_never_awaited_warnings() -> Iterator[None]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited") try: diff --git a/trio/_deprecate.py b/trio/_deprecate.py index 4f9f15ec35..9010d6953e 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -1,11 +1,15 @@ import sys from functools import wraps from types import ModuleType +from typing import Any, Callable, Optional, TypeVar, Union import warnings import attr +_T = TypeVar("_T", bound=Callable[..., Any]) + + # We want our warnings to be visible by default (at least for now), but we # also want it to be possible to override that using the -W switch. AFAICT # this means we cannot inherit from DeprecationWarning, because the only way @@ -39,7 +43,14 @@ def _stringify(thing): return str(thing) -def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): +def warn_deprecated( + thing: object, + version: str, + *, + issue: Optional[int], + instead: Optional[object], + stacklevel: int = 2, +) -> None: stacklevel += 1 msg = "{} is deprecated since Trio {}".format(_stringify(thing), version) if instead is None: @@ -53,20 +64,29 @@ def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): # @deprecated("0.2.0", issue=..., instead=...) # def ... -def deprecated(version, *, thing=None, issue, instead): - def do_wrap(fn): - nonlocal thing - - @wraps(fn) - def wrapper(*args, **kwargs): - warn_deprecated(thing, version, instead=instead, issue=issue) +def deprecated( + version: str, + *, + thing: Optional[str] = None, + issue: Optional[int], + instead: object, +) -> Callable[[_T], _T]: + def do_wrap(fn: _T) -> _T: + wrapper: _T + + @wraps(fn) # type: ignore[no-redef] + def wrapper(*args: object, **kwargs: object) -> object: + warn_deprecated(final_thing, version, instead=instead, issue=issue) return fn(*args, **kwargs) # If our __module__ or __qualname__ get modified, we want to pick up # on that, so we read them off the wrapper object instead of the (now # hidden) fn object + final_thing: Union[str, _T] if thing is None: - thing = wrapper + final_thing = wrapper + else: + final_thing = thing if wrapper.__doc__ is not None: doc = wrapper.__doc__ @@ -87,10 +107,12 @@ def wrapper(*args, **kwargs): return do_wrap -def deprecated_alias(old_qualname, new_fn, version, *, issue): - @deprecated(version, issue=issue, instead=new_fn) +def deprecated_alias(old_qualname: str, new_fn: _T, version: str, *, issue: int) -> _T: + wrapper: _T + + @deprecated(version, issue=issue, instead=new_fn) # type: ignore[no-redef] @wraps(new_fn, assigned=("__module__", "__annotations__")) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: "Deprecated alias." return new_fn(*args, **kwargs) diff --git a/trio/_file_io.py b/trio/_file_io.py index 8c8425c775..71cebffa6e 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -1,5 +1,30 @@ from functools import partial import io +from typing import ( + Any, + # AnyStr, + # AsyncContextManager, + AsyncIterator, + # Awaitable, + # Callable, + # ContextManager, + # FrozenSet, + # Iterator, + # Mapping, + # NoReturn, + Optional, + # Sequence, + Union, + # Sequence, + # TypeVar, + Tuple, + List, + Iterable, + TextIO, + BinaryIO, + IO, + overload, +) from .abc import AsyncResource from ._util import async_wraps @@ -58,11 +83,11 @@ class AsyncIOWrapper(AsyncResource): """ - def __init__(self, file): + def __init__(self, file: io.IOBase): self._wrapped = file @property - def wrapped(self): + def wrapped(self) -> io.IOBase: """object: A reference to the wrapped file object""" return self._wrapped @@ -74,7 +99,7 @@ def __getattr__(self, name): meth = getattr(self._wrapped, name) @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): + async def wrapper(*args, **kwargs): # type: ignore[misc] func = partial(meth, *args, **kwargs) return await trio.to_thread.run_sync(func) @@ -126,6 +151,120 @@ async def aclose(self): await trio.lowlevel.checkpoint_if_cancelled() +# _file_io +class _AsyncIOBase(trio.abc.AsyncResource): + closed: bool + + def __aiter__(self) -> AsyncIterator[bytes]: + ... + + async def __anext__(self) -> bytes: + ... + + async def aclose(self) -> None: + ... + + def fileno(self) -> int: + ... + + async def flush(self) -> None: + ... + + def isatty(self) -> bool: + ... + + def readable(self) -> bool: + ... + + async def readlines(self, hint: int = ...) -> List[bytes]: + ... + + async def seek(self, offset: int, whence: int = ...) -> int: + ... + + def seekable(self) -> bool: + ... + + async def tell(self) -> int: + ... + + async def truncate(self, size: Optional[int] = ...) -> int: + ... + + def writable(self) -> bool: + ... + + async def writelines(self, lines: Iterable[bytes]) -> None: + ... + + async def readline(self, size: int = ...) -> bytes: + ... + + +class _AsyncRawIOBase(_AsyncIOBase): + async def readall(self) -> bytes: + ... + + async def readinto(self, b: bytearray) -> Optional[int]: + ... + + async def write(self, b: bytes) -> Optional[int]: + ... + + async def read(self, size: int = ...) -> Optional[bytes]: + ... + + +class _AsyncBufferedIOBase(_AsyncIOBase): + async def detach(self) -> _AsyncRawIOBase: + ... + + async def readinto(self, b: bytearray) -> int: + ... + + async def write(self, b: bytes) -> int: + ... + + async def readinto1(self, b: bytearray) -> int: + ... + + async def read(self, size: Optional[int] = ...) -> bytes: + ... + + async def read1(self, size: int = ...) -> bytes: + ... + + +class _AsyncTextIOBase(_AsyncIOBase): + encoding: str + errors: Optional[str] + newlines: Union[str, Tuple[str, ...], None] + + def __aiter__(self) -> AsyncIterator[str]: # type: ignore + ... + + async def __anext__(self) -> str: # type: ignore + ... + + async def detach(self) -> _AsyncRawIOBase: + ... + + async def write(self, s: str) -> int: + ... + + async def readline(self, size: int = ...) -> str: # type: ignore + ... + + async def read(self, size: Optional[int] = ...) -> str: + ... + + async def seek(self, offset: int, whence: int = ...) -> int: + ... + + async def tell(self) -> int: + ... + + async def open_file( file, mode="r", @@ -161,6 +300,26 @@ async def open_file( return _file +@overload +def wrap_file(obj: Union[TextIO, io.TextIOBase]) -> _AsyncTextIOBase: + ... + + +@overload +def wrap_file(obj: Union[BinaryIO, io.BufferedIOBase]) -> _AsyncBufferedIOBase: + ... + + +@overload +def wrap_file(obj: io.RawIOBase) -> _AsyncRawIOBase: + ... + + +@overload +def wrap_file(obj: Union[IO[Any], io.IOBase]) -> _AsyncIOBase: + ... + + def wrap_file(file): """This wraps any file object in a wrapper that provides an asynchronous file object interface. diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 545fac8641..172468b4de 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -1,7 +1,8 @@ from contextlib import contextmanager +from typing import Iterator, Set import trio -from trio.socket import getaddrinfo, SOCK_STREAM, socket +from trio.socket import getaddrinfo, SOCK_STREAM, socket, SocketType # Implementation of RFC 6555 "Happy eyeballs" # https://tools.ietf.org/html/rfc6555 @@ -103,8 +104,8 @@ @contextmanager -def close_all(): - sockets_to_close = set() +def close_all() -> Iterator[Set[SocketType]]: + sockets_to_close: Set[SocketType] = set() try: yield sockets_to_close finally: diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index e5aba4695f..3e294200f0 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -1,5 +1,6 @@ import os from contextlib import contextmanager +from typing import Iterator, Protocol, TypeVar import trio from trio.socket import socket, SOCK_STREAM @@ -12,8 +13,16 @@ has_unix = False +class Closable(Protocol): + def close(self): + ... + + +_CL = TypeVar("_CL", bound=Closable) + + @contextmanager -def close_on_error(obj): +def close_on_error(obj: _CL) -> Iterator[_CL]: try: yield obj except: diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 0d9dbc0e92..5ab52f7c21 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -2,6 +2,7 @@ import errno from contextlib import contextmanager +from typing import Iterator import trio from . import socket as tsocket @@ -23,7 +24,7 @@ @contextmanager -def _translate_socket_errors_to_stream_errors(): +def _translate_socket_errors_to_stream_errors() -> Iterator[None]: try: yield except OSError as exc: diff --git a/trio/_path.py b/trio/_path.py index 53a52ddf15..2012a7c87b 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -2,12 +2,26 @@ import os import types import pathlib -from typing import Iterator, TYPE_CHECKING, TypeVar, Union +from typing import ( + Any, + Callable, + Iterator, + Optional, + overload, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +from typing_extensions import Protocol import trio from trio._util import async_wraps, Final +from ._file_io import _AsyncIOBase +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) _P = TypeVar("_P", bound="Path") @@ -18,9 +32,15 @@ def rewrap_path(value): return value -def _forward_factory(cls, attr_name, attr): - @wraps(attr) - def wrapper(self, *args, **kwargs): +class _Wrapper(Protocol): + _wrapped: object + + +def _forward_factory(cls: object, attr_name: str, attr: _Fn) -> _Fn: + wrapper: _Fn + + @wraps(attr) # type: ignore[no-redef] + def wrapper(self: _Wrapper, *args: object, **kwargs: object) -> object: attr = getattr(self._wrapped, attr_name) value = attr(*args, **kwargs) return rewrap_path(value) @@ -28,11 +48,13 @@ def wrapper(self, *args, **kwargs): return wrapper -def _forward_magic(cls, attr): +def _forward_magic(cls: Type, attr: _Fn) -> _Fn: sentinel = object() - @wraps(attr) - def wrapper(self, other=sentinel): + wrapper: _Fn + + @wraps(attr) # type: ignore[no-redef] + def wrapper(self: _Wrapper, other: object = sentinel) -> object: if other is sentinel: return attr(self._wrapped) if isinstance(other, cls): @@ -43,9 +65,9 @@ def wrapper(self, other=sentinel): return wrapper -def iter_wrapper_factory(cls, meth_name): +def iter_wrapper_factory(cls: Type, meth_name: str): # type: ignore[no-untyped-def] @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): # type: ignore[misc] meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) # Make sure that the full iteration is performed in the thread @@ -56,9 +78,9 @@ async def wrapper(self, *args, **kwargs): return wrapper -def thread_wrapper_factory(cls, meth_name): +def thread_wrapper_factory(cls: Type, meth_name: str): # type: ignore[no-untyped-def] @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): # type: ignore[misc] meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) @@ -67,10 +89,10 @@ async def wrapper(self, *args, **kwargs): return wrapper -def classmethod_wrapper_factory(cls, meth_name): - @classmethod +def classmethod_wrapper_factory(cls: Type, meth_name: str): # type: ignore[no-untyped-def] + @classmethod # type: ignore[misc] @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(cls, *args, **kwargs): + async def wrapper(cls, *args, **kwargs): # type: ignore[misc] meth = getattr(cls._wraps, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) @@ -165,12 +187,21 @@ class Path(metaclass=AsyncAutoWrapperType): # of the file can be hinted regularly rather than in a separate stub .pyi. # TODO: Can we handle os.PathLike[str] at least for 3.9+? - def joinpath(self: _P, *other: Union[str, os.PathLike]) -> _P: + def joinpath(self: _P, *other: Union[os.PathLike, str]) -> _P: ... def iterdir(self: _P) -> Iterator[_P]: ... + def __gt__(self, other: os.PathLike) -> bool: + ... + + def __lt__(self, other: os.PathLike) -> bool: + ... + + def __truediv__(self: _P, *args: Union[os.PathLike, str]) -> _P: + ... + def __init__(self, *args): self._wrapped = pathlib.Path(*args) @@ -189,8 +220,19 @@ def __repr__(self): def __fspath__(self): return os.fspath(self._wrapped) + @overload # type: ignore[misc] + async def open( + self, + mode: str = ..., + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + ) -> _AsyncIOBase: + ... + @wraps(pathlib.Path.open) - async def open(self, *args, **kwargs): + async def open(self, *args: object, **kwargs: object) -> object: """Open the file pointed to by the path, like the :func:`trio.open_file` function does. diff --git a/trio/_signals.py b/trio/_signals.py index cee3b7db53..252872d7bd 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -1,10 +1,17 @@ import signal from contextlib import contextmanager from collections import OrderedDict +from types import FrameType +from typing import Any, Callable, Iterable, Iterator, Optional, Set, Tuple, Union import trio from ._util import signal_raise, is_main_thread, ConflictDetector +# https://github.com/python/typeshed/blob/master/stdlib/3/signal.pyi#L82-L83 +_SignalNumber = Union[int, signal.Signals] +_Handler = Union[Callable[[signal.Signals, FrameType], Any], int, signal.Handlers, None] + + # Discussion of signal handling strategies: # # - On Windows signals barely exist. There are no options; signal handlers are @@ -42,7 +49,9 @@ @contextmanager -def _signal_handler(signals, handler): +def _signal_handler( + signals: Iterable[_SignalNumber], handler: _Handler +) -> Iterator[None]: original_handlers = {} try: for signum in set(signals): @@ -111,7 +120,7 @@ async def __anext__(self): @contextmanager -def open_signal_receiver(*signals): +def open_signal_receiver(*signals: _SignalNumber) -> Iterator[SignalReceiver]: """A context manager for catching signals. Entering this context manager starts listening for the given signals and diff --git a/trio/_socket.py b/trio/_socket.py index fcf26e072b..6fe70714d8 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -3,12 +3,33 @@ import select import socket as _stdlib_socket from functools import wraps as _wraps -from typing import TYPE_CHECKING +from typing import ( + Awaitable, + Callable, + Mapping, + Union, + Optional, + Iterable, + Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, + List, + Any, + overload, +) import idna as _idna import trio from . import _core +from ._abc import HostnameResolver, SocketFactory + + +_T = TypeVar("_T") + + +_Address = Union[tuple, str] # Usage: @@ -59,8 +80,8 @@ async def __aexit__(self, etype, value, tb): # Overrides ################################################################ -_resolver = _core.RunVar("hostname_resolver") -_socket_factory = _core.RunVar("socket_factory") +_resolver = _core.RunVar[Optional[HostnameResolver]]("hostname_resolver") +_socket_factory = _core.RunVar[Optional[SocketFactory]]("socket_factory") def set_custom_hostname_resolver(hostname_resolver): @@ -95,7 +116,9 @@ def set_custom_hostname_resolver(hostname_resolver): return old -def set_custom_socket_factory(socket_factory): +def set_custom_socket_factory( + socket_factory: Optional[SocketFactory], +) -> Optional[SocketFactory]: """Set a custom socket object factory. This function allows you to replace Trio's normal socket class with a @@ -230,7 +253,7 @@ async def getprotobyname(name): ################################################################ -def from_stdlib_socket(sock): +def from_stdlib_socket(sock: _stdlib_socket.socket) -> "_SocketType": """Convert a standard library :func:`socket.socket` object into a Trio socket object. @@ -239,7 +262,7 @@ def from_stdlib_socket(sock): @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd(fd, family, type, proto=0): +def fromfd(fd: int, family: int, type: int, proto: int = 0) -> "SocketType": """Like :func:`socket.fromfd`, but returns a Trio socket object.""" family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -250,27 +273,55 @@ def fromfd(fd, family, type, proto=0): ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(*args, **kwargs): + def fromshare(*args: object, **kwargs: object) -> "SocketType": return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) +# @overload +# def socketpair() -> Tuple["SocketType", "SocketType"]: +# ... +# +# +# @overload +# def socketpair(family: int = ...) -> Tuple["SocketType", "SocketType"]: +# ... +# +# +# @overload +# def socketpair(family: int = ..., type: int = ...) -> Tuple["SocketType", "SocketType"]: +# ... + + +# @overload # type: ignore[misc] +# def socketpair( +# family: int = ..., type: int = ..., proto: int = ... +# ) -> Tuple["SocketType", "SocketType"]: +# ... + +# TODO: uh... stuff... comments... @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair(*args, **kwargs): +# def socketpair( +# family: int = _stdlib_socket.AF_INET, +## type: int = _stdlib_socket.SOCK_STREAM, +# proto: int = 0, +# ) -> Tuple["SocketType", "SocketType"]: +def socketpair(*args: object, **kwargs: object) -> Tuple["SocketType", "SocketType"]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. """ - left, right = _stdlib_socket.socketpair(*args, **kwargs) + # left, right = _stdlib_socket.socketpair(family=family, type=type, proto=proto) + left, right = _stdlib_socket.socketpair(*args, **kwargs) # type: ignore[arg-type] return (from_stdlib_socket(left), from_stdlib_socket(right)) @_wraps(_stdlib_socket.socket, assigned=(), updated=()) def socket( - family=_stdlib_socket.AF_INET, - type=_stdlib_socket.SOCK_STREAM, - proto=0, - fileno=None, -): + family: int = _stdlib_socket.AF_INET, + type: int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, + fileno: Optional[int] = None, +) -> "SocketType": """Create a new Trio socket, like :func:`socket.socket`. This function's behavior can be customized using @@ -318,7 +369,7 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): # But on other platforms (e.g. Windows) SOCK_NONBLOCK and SOCK_CLOEXEC aren't # even defined. To recover the actual socket type (e.g. SOCK_STREAM) from a # socket.type attribute, mask with this: -_SOCK_TYPE_MASK = ~( +_SOCK_TYPE_MASK: int = ~( getattr(_stdlib_socket, "SOCK_NONBLOCK", 0) | getattr(_stdlib_socket, "SOCK_CLOEXEC", 0) ) @@ -327,7 +378,7 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): # This function will modify the given socket to match the behavior in python # 3.7. This will become unecessary and can be removed when support for versions # older than 3.7 is dropped. -def real_socket_type(type_num): +def real_socket_type(type_num: int) -> int: return type_num & _SOCK_TYPE_MASK @@ -335,7 +386,7 @@ def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False): fn = getattr(_stdlib_socket.socket, methname) @_wraps(fn, assigned=("__name__",), updated=()) - async def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): # type: ignore[misc] return await self._nonblocking_helper(fn, args, kwargs, wait_fn) wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async. @@ -356,9 +407,153 @@ def __init__(self): "want to construct a socket object" ) + if TYPE_CHECKING: + + @property + def family(self) -> int: + ... + + @property + def type(self) -> int: + ... + + @property + def proto(self) -> int: + ... + + @property + def did_shutdown_SHUT_WR(self) -> bool: + ... + + def __enter__(self: _T) -> _T: + ... + + def __exit__(self, *args: Any) -> None: + ... + + def dup(self) -> "SocketType": + ... + + def close(self) -> None: + ... + + async def bind(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + ... + + def shutdown(self, flag: int) -> None: + ... + + def is_readable(self) -> bool: + ... + + async def wait_writable(self) -> None: + ... + + async def accept(self) -> Tuple[SocketType, Any]: + ... + + async def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + ... + + async def recv(self, bufsize: int, flags: int = ...) -> bytes: + ... + + async def recv_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> int: + ... + + async def recvfrom(self, bufsize: int, flags: int = ...) -> Tuple[bytes, Any]: + ... + + async def recvfrom_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> Tuple[int, Any]: + ... + + async def recvmsg( + self, bufsize: int, ancbufsize: int = ..., flags: int = ... + ) -> Tuple[bytes, List[Tuple[int, int, bytes]], int, Any]: + ... + + async def recvmsg_into( + self, + buffers: Iterable[Union[bytearray, memoryview]], + ancbufsize: int = ..., + flags: int = ..., + ) -> Tuple[int, List[Tuple[int, int, bytes]], int, Any]: + ... + + async def send(self, data: bytes, flags: int = ...) -> int: + ... + + async def sendmsg( + self, + buffers: Iterable[Union[bytes, memoryview]], + ancdata: Iterable[Tuple[int, int, Union[bytes, memoryview]]] = ..., + flags: int = ..., + address: Union[Tuple[Any, ...], str] = ..., + ) -> int: + ... + + @overload + async def sendto( + self, data: bytes, address: Union[Tuple[Any, ...], str] + ) -> int: + ... + + @overload + async def sendto( + self, data: bytes, flags: int, address: Union[Tuple[Any, ...], str] + ) -> int: + ... + + async def sendto(self, *args: object, **kwargs: object) -> int: + ... + + def detach(self) -> int: + ... + + def get_inheritable(self) -> bool: + ... + + def set_inheritable(self, inheritable: bool) -> None: + ... + + def fileno(self) -> int: + ... + + def getpeername(self) -> Any: + ... + + def getsockname(self) -> Any: + ... + + @overload + def getsockopt(self, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt(self, *args: object, **kwargs: object) -> object: + ... + + def setsockopt( + self, level: int, optname: int, value: Union[int, bytes] + ) -> None: + ... + + def listen(self, backlog: int) -> None: + ... + + def share(self, process_id: int) -> bytes: + ... + class _SocketType(SocketType): - def __init__(self, sock): + def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. @@ -413,22 +608,22 @@ def __exit__(self, *exc_info): return self._sock.__exit__(*exc_info) @property - def family(self): + def family(self) -> int: return self._sock.family @property - def type(self): + def type(self) -> int: # Modify the socket type do match what is done on python 3.7. When # support for versions older than 3.7 is dropped, this can be updated # to just return self._sock.type return real_socket_type(self._sock.type) @property - def proto(self): + def proto(self) -> int: return self._sock.proto @property - def did_shutdown_SHUT_WR(self): + def did_shutdown_SHUT_WR(self) -> bool: return self._did_shutdown_SHUT_WR def __repr__(self): @@ -564,7 +759,13 @@ async def _resolve_local_address_nocp(self, address): async def _resolve_remote_address_nocp(self, address): return await self._resolve_address_nocp(address, 0) - async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): + async def _nonblocking_helper( + self, + fn: Callable[..., _T], + args: Sequence[object], + kwargs: Mapping[str, object], + wait_fn: Callable[..., Awaitable[object]], + ) -> _T: # We have to reconcile two conflicting goals: # - We want to make it look like we always blocked in doing these # operations. The obvious way is to always do an IO wait before @@ -735,15 +936,23 @@ async def connect(self, address): # sendto ################################################################ - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) - async def sendto(self, *args): + @overload + async def sendto(self, data: bytes, address: _Address) -> int: + ... + + @overload + async def sendto(self, data: bytes, flags: int, address: _Address) -> int: + ... + + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] + async def sendto(self, *args: object) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address) # and kwargs are not accepted - args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + list_args = list(args) + list_args[-1] = await self._resolve_remote_address_nocp(list_args[-1]) return await self._nonblocking_helper( - _stdlib_socket.socket.sendto, args, {}, _core.wait_writable + _stdlib_socket.socket.sendto, list_args, {}, _core.wait_writable ) ################################################################ @@ -755,7 +964,7 @@ async def sendto(self, *args): ): @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg(self, *args): + async def sendmsg(self, *args: object) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. Only available on platforms where :meth:`socket.socket.sendmsg` is @@ -765,10 +974,10 @@ async def sendmsg(self, *args): # args is: buffers[, ancdata[, flags[, address]]] # and kwargs are not accepted if len(args) == 4 and args[-1] is not None: - args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + list_args = list(args) + list_args[-1] = await self._resolve_remote_address_nocp(list_args[-1]) return await self._nonblocking_helper( - _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable + _stdlib_socket.socket.sendmsg, list_args, {}, _core.wait_writable ) ################################################################ diff --git a/trio/_ssl.py b/trio/_ssl.py index c4ffa3ddbe..a163f9dee0 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -202,8 +202,8 @@ class _Once: def __init__(self, afn, *args): self._afn = afn self._args = args - self.started = False - self._done = _sync.Event() + self.started: bool = False + self._done: _sync.Event = _sync.Event() async def ensure(self, *, checkpoint): if not self.started: @@ -216,7 +216,7 @@ async def ensure(self, *, checkpoint): await self._done.wait() @property - def done(self): + def done(self) -> bool: return self._done.is_set() diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 6b29c4465c..e3944d9673 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -128,11 +128,17 @@ class Process(AsyncResource, metaclass=NoPublicConstructor): # arbitrarily many threads if wait() keeps getting cancelled. _wait_for_exit_data = None - def __init__(self, popen, stdin, stdout, stderr): + def __init__( + self, + popen: subprocess.Popen, + stdin: Optional[SendStream], + stdout: Optional[ReceiveStream], + stderr: Optional[ReceiveStream], + ): self._proc = popen - self.stdin = stdin # type: Optional[SendStream] - self.stdout = stdout # type: Optional[ReceiveStream] - self.stderr = stderr # type: Optional[ReceiveStream] + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr self.stdio = None # type: Optional[StapledStream] if self.stdin is not None and self.stdout is not None: @@ -170,7 +176,7 @@ def __repr__(self): return "".format(self.args, status) @property - def returncode(self): + def returncode(self) -> Optional[int]: """The exit status of the process (an integer), or ``None`` if it's still running. diff --git a/trio/_subprocess_platform/windows.py b/trio/_subprocess_platform/windows.py index 958be8675c..816da4b203 100644 --- a/trio/_subprocess_platform/windows.py +++ b/trio/_subprocess_platform/windows.py @@ -3,4 +3,4 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: - await WaitForSingleObject(int(process._proc._handle)) + await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined] diff --git a/trio/_sync.py b/trio/_sync.py index bed339ef6b..526be8e318 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,4 +1,7 @@ import math +from typing import Optional, Type, TypeVar, Union + +from typing_extensions import Protocol import attr import outcome @@ -6,6 +9,7 @@ import trio from ._core import enable_ki_protection, ParkingLot +from ._core._run import Task from ._deprecate import deprecated from ._util import Final @@ -40,12 +44,12 @@ class Event(metaclass=Final): _lot = attr.ib(factory=ParkingLot, init=False) _flag = attr.ib(default=False, init=False) - def is_set(self): + def is_set(self) -> bool: """Return the current value of the internal flag.""" return self._flag @enable_ki_protection - def set(self): + def set(self) -> None: """Set the internal flag value to True, and wake any waiting tasks.""" self._flag = True self._lot.unpark_all() @@ -73,20 +77,31 @@ def statistics(self): return self._lot.statistics() -def async_cm(cls): +class _HasAcquire(Protocol): + async def acquire(self) -> object: + ... + + def release(self) -> object: + ... + + +_TA = TypeVar("_TA", bound=_HasAcquire) + + +def async_cm(cls: Type[_TA]) -> Type[_TA]: @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self: _TA) -> None: await self.acquire() __aenter__.__qualname__ = cls.__qualname__ + ".__aenter__" - cls.__aenter__ = __aenter__ + cls.__aenter__ = __aenter__ # type: ignore[attr-defined] @enable_ki_protection - async def __aexit__(self, *args): + async def __aexit__(self: _TA, *args: object) -> None: self.release() __aexit__.__qualname__ = cls.__qualname__ + ".__aexit__" - cls.__aexit__ = __aexit__ + cls.__aexit__ = __aexit__ # type: ignore[attr-defined] return cls @@ -153,6 +168,8 @@ class CapacityLimiter(metaclass=Final): """ + _total_tokens: int + def __init__(self, total_tokens): self._lot = ParkingLot() self._borrowers = set() @@ -168,7 +185,7 @@ def __repr__(self): ) @property - def total_tokens(self): + def total_tokens(self) -> int: """The total capacity available. You can change :attr:`total_tokens` by assigning to this attribute. If @@ -197,17 +214,17 @@ def _wake_waiters(self): self._borrowers.add(self._pending_borrowers.pop(woken)) @property - def borrowed_tokens(self): + def borrowed_tokens(self) -> int: """The amount of capacity that's currently in use.""" return len(self._borrowers) @property - def available_tokens(self): + def available_tokens(self) -> int: """The amount of capacity that's available to use.""" return self.total_tokens - self.borrowed_tokens @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Borrow a token from the sack, without blocking. Raises: @@ -219,7 +236,7 @@ def acquire_nowait(self): self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower): + def acquire_on_behalf_of_nowait(self, borrower: Union[object, Task]) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. @@ -248,7 +265,7 @@ def acquire_on_behalf_of_nowait(self, borrower): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Borrow a token from the sack, blocking if necessary. Raises: @@ -259,7 +276,7 @@ async def acquire(self): await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower: Union[object, Task]) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -288,7 +305,7 @@ async def acquire_on_behalf_of(self, borrower): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Put a token back into the sack. Raises: @@ -299,7 +316,7 @@ def release(self): self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower: Union[object, Task]) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: @@ -369,7 +386,7 @@ class Semaphore(metaclass=Final): """ - def __init__(self, initial_value, *, max_value=None): + def __init__(self, initial_value: int, *, max_value: Optional[int] = None): if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -397,17 +414,17 @@ def __repr__(self): ) @property - def value(self): + def value(self) -> int: """The current value of the semaphore.""" return self._value @property - def max_value(self): + def max_value(self) -> Optional[int]: """The maximum allowed value. May be None to indicate no limit.""" return self._max_value @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to decrement the semaphore value, without blocking. Raises: @@ -421,7 +438,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Decrement the semaphore value, blocking if necessary to avoid letting it drop below zero. @@ -435,7 +452,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Increment the semaphore value, possibly waking a task blocked in :meth:`acquire`. @@ -498,7 +515,7 @@ def locked(self): return self._owner is not None @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the lock, without blocking. Raises: @@ -516,7 +533,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Acquire the lock, blocking if necessary.""" await trio.lowlevel.checkpoint_if_cancelled() try: @@ -530,7 +547,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Release the lock. Raises: @@ -696,7 +713,7 @@ def release(self): self._lock.release() @enable_ki_protection - async def wait(self): + async def wait(self) -> None: """Wait for another task to call :meth:`notify` or :meth:`notify_all`. diff --git a/trio/_threads.py b/trio/_threads.py index 648b87d801..e3b0f0abe0 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -3,6 +3,7 @@ import threading import queue as stdlib_queue from itertools import count +from typing import Awaitable, Callable, Optional, Sequence, TypeVar import attr import inspect @@ -23,13 +24,16 @@ # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() -_limiter_local = RunVar("limiter") +_limiter_local = RunVar[CapacityLimiter]("limiter") # I pulled this number out of the air; it isn't based on anything. Probably we # should make some kind of measurements to pick a good value. DEFAULT_LIMIT = 40 _thread_counter = count() +_T = TypeVar("_T") + + def current_default_thread_limiter(): """Get the default `~trio.CapacityLimiter` used by `trio.to_thread.run_sync`. @@ -55,8 +59,16 @@ class ThreadPlaceholder: name = attr.ib() +# TODO: maybe we don't want to ban Any in decorated functions? just any? maybe +# then we would just ignore the line with the Any? (it is the callable's +# unspecified parameter list) @enable_ki_protection -async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None): +async def to_thread_run_sync( # type: ignore[misc] + sync_fn: Callable[..., _T], + *args: object, + cancellable: bool = False, + limiter: Optional[CapacityLimiter] = None, +) -> _T: """Convert a blocking operation into an async operation using a thread. These two lines are equivalent:: @@ -204,7 +216,7 @@ def abort(_): else: return trio.lowlevel.Abort.FAILED - return await trio.lowlevel.wait_task_rescheduled(abort) + return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[return-value] def _run_fn_as_system_task(cb, fn, *args, trio_token=None): @@ -238,7 +250,11 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None): return q.get().unwrap() -def from_thread_run(afn, *args, trio_token=None): +def from_thread_run( + afn: Callable[..., Awaitable[_T]], + *args: object, + trio_token: Optional[TrioToken] = None, +) -> _T: """Run the given async function in the parent Trio thread, blocking until it is complete. @@ -273,11 +289,13 @@ def from_thread_run(afn, *args, trio_token=None): to enter Trio. """ - def callback(q, afn, args): + def callback( + q: stdlib_queue.Queue, afn: Callable[..., Awaitable[_T]], args: Sequence[object] + ) -> None: @disable_ki_protection - async def unprotected_afn(): + async def unprotected_afn() -> _T: coro = coroutine_or_error(afn, *args) - return await coro + return await coro # type: ignore[no-any-return] async def await_in_trio_thread_task(): q.put_nowait(await outcome.acapture(unprotected_afn)) @@ -289,10 +307,12 @@ async def await_in_trio_thread_task(): outcome.Error(trio.RunFinishedError("system nursery is closed")) ) - return _run_fn_as_system_task(callback, afn, *args, trio_token=trio_token) + return _run_fn_as_system_task(callback, afn, *args, trio_token=trio_token) # type: ignore[no-any-return] -def from_thread_run_sync(fn, *args, trio_token=None): +def from_thread_run_sync( + fn: Callable[..., _T], *args: object, trio_token: Optional[TrioToken] = None +) -> _T: """Run the given sync function in the parent Trio thread, blocking until it is complete. @@ -323,14 +343,16 @@ def from_thread_run_sync(fn, *args, trio_token=None): to enter Trio. """ - def callback(q, fn, args): + def callback( + q: stdlib_queue.Queue, fn: Callable[..., _T], args: Sequence[object] + ) -> None: @disable_ki_protection - def unprotected_fn(): + def unprotected_fn() -> _T: ret = fn(*args) if inspect.iscoroutine(ret): # Manually close coroutine to avoid RuntimeWarnings - ret.close() + ret.close() # type: ignore[attr-defined] raise TypeError( "Trio expected a sync function, but {!r} appears to be " "asynchronous".format(getattr(fn, "__qualname__", fn)) @@ -341,4 +363,4 @@ def unprotected_fn(): res = outcome.capture(unprotected_fn) q.put_nowait(res) - return _run_fn_as_system_task(callback, fn, *args, trio_token=trio_token) + return _run_fn_as_system_task(callback, fn, *args, trio_token=trio_token) # type: ignore[no-any-return] diff --git a/trio/_timeouts.py b/trio/_timeouts.py index 1f7878f89e..1683b6157a 100644 --- a/trio/_timeouts.py +++ b/trio/_timeouts.py @@ -1,9 +1,10 @@ from contextlib import contextmanager +from typing import Iterator import trio -def move_on_at(deadline): +def move_on_at(deadline: float) -> trio.CancelScope: """Use as a context manager to create a cancel scope with the given absolute deadline. @@ -84,7 +85,7 @@ class TooSlowError(Exception): @contextmanager -def fail_at(deadline): +def fail_at(deadline: float) -> Iterator[trio.CancelScope]: """Creates a cancel scope with the given deadline, and raises an error if it is actually cancelled. diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 1c28e88d64..4424550020 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -45,7 +45,7 @@ def __init__(self, fd: int): os.set_blocking(fd, False) @property - def closed(self): + def closed(self) -> bool: return self.fd == -1 def _raw_close(self): diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index 025c8742a7..8b3cb93e0f 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -22,7 +22,7 @@ def __init__(self, handle: int) -> None: _core.register_with_iocp(self.handle) @property - def closed(self): + def closed(self) -> bool: return self.handle == -1 def _close(self): diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 7a9006ff43..c1760bd62a 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -2,6 +2,7 @@ from contextlib import contextmanager import random +from typing import Iterator, Type from .. import _core from .._highlevel_generic import aclose_forcefully @@ -24,7 +25,7 @@ async def __aexit__(self, *args): @contextmanager -def _assert_raises(exc): +def _assert_raises(exc: Type[Exception]) -> Iterator[None]: __tracebackhide__ = True try: yield diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 5804295300..2f4be323a9 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -1,10 +1,11 @@ from contextlib import contextmanager +from typing import Iterator from .. import _core @contextmanager -def _assert_yields_or_not(expected): +def _assert_yields_or_not(expected: bool) -> Iterator[None]: __tracebackhide__ = True task = _core.current_task() orig_cancel = task._cancel_points diff --git a/trio/testing/_sequencer.py b/trio/testing/_sequencer.py index 0922e1c9ad..4a1d42eca8 100644 --- a/trio/testing/_sequencer.py +++ b/trio/testing/_sequencer.py @@ -57,7 +57,7 @@ async def main(): _broken = attr.ib(default=False, init=False) @asynccontextmanager - async def __call__(self, position: int) -> AsyncIterator[None]: + async def __call__(self, position: int) -> AsyncIterator[None]: # type: ignore[misc] if position in self._claimed: raise RuntimeError("Attempted to re-use sequence point {}".format(position)) if self._broken: diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 4fcaeae372..c372811a65 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -1,9 +1,12 @@ from functools import wraps, partial +from typing import Any, Callable, TypeVar from .. import _core from ..abc import Clock, Instrument +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + # Use: # # @trio_test @@ -12,9 +15,11 @@ # # Also: if a pytest fixture is passed in that subclasses the Clock abc, then # that clock is passed to trio.run(). -def trio_test(fn): - @wraps(fn) - def wrapper(**kwargs): +def trio_test(fn: _Fn) -> _Fn: + wrapper: _Fn + + @wraps(fn) # type: ignore[no-redef] + def wrapper(**kwargs: object) -> object: __tracebackhide__ = True clocks = [c for c in kwargs.values() if isinstance(c, Clock)] if not clocks: diff --git a/trio/tests/conftest.py b/trio/tests/conftest.py index 772486e1eb..f7e92662d8 100644 --- a/trio/tests/conftest.py +++ b/trio/tests/conftest.py @@ -5,7 +5,9 @@ # this stuff should become a proper pytest plugin import pytest +import _pytest.python import inspect +from typing import Callable from ..testing import trio_test, MockClock @@ -22,12 +24,12 @@ def pytest_configure(config): @pytest.fixture -def mock_clock(): +def mock_clock() -> MockClock: return MockClock() @pytest.fixture -def autojump_clock(): +def autojump_clock() -> MockClock: return MockClock(autojump_threshold=0) @@ -36,6 +38,6 @@ def autojump_clock(): # guess it's useful with the class- and file-level marking machinery (where # the raw @trio_test decorator isn't enough). @pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): +def pytest_pyfunc_call(pyfuncitem: _pytest.python.Function) -> None: # type: ignore[misc] if inspect.iscoroutinefunction(pyfuncitem.obj): pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/trio/tests/module_with_deprecations.py b/trio/tests/module_with_deprecations.py index ed51f150c3..7db72e16c6 100644 --- a/trio/tests/module_with_deprecations.py +++ b/trio/tests/module_with_deprecations.py @@ -19,3 +19,6 @@ "value2", "1.2", issue=1, instead="instead-string" ), } + +dep1 = None +dep2 = None diff --git a/trio/tests/test_deprecate.py b/trio/tests/test_deprecate.py index e5e1da8c5f..5a53655ac4 100644 --- a/trio/tests/test_deprecate.py +++ b/trio/tests/test_deprecate.py @@ -1,4 +1,5 @@ import pytest +import _pytest.recwarn import inspect import warnings @@ -14,7 +15,9 @@ @pytest.fixture -def recwarn_always(recwarn): +def recwarn_always( + recwarn: _pytest.recwarn.WarningsRecorder, +) -> _pytest.recwarn.WarningsRecorder: warnings.simplefilter("always") # ResourceWarnings about unclosed sockets can occur nondeterministically # (during GC) which throws off the tests in this file @@ -27,7 +30,7 @@ def _here(): return (info.filename, info.lineno) -def test_warn_deprecated(recwarn_always): +def test_warn_deprecated(recwarn_always: _pytest.recwarn.WarningsRecorder) -> None: def deprecated_thing(): warn_deprecated("ice", "1.2", issue=1, instead="water") @@ -35,6 +38,7 @@ def deprecated_thing(): filename, lineno = _here() assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "ice is deprecated" in got.message.args[0] assert "Trio 1.2" in got.message.args[0] assert "water instead" in got.message.args[0] @@ -43,17 +47,22 @@ def deprecated_thing(): assert got.lineno == lineno - 1 -def test_warn_deprecated_no_instead_or_issue(recwarn_always): +def test_warn_deprecated_no_instead_or_issue( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: # Explicitly no instead or issue warn_deprecated("water", "1.3", issue=None, instead=None) assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "water is deprecated" in got.message.args[0] assert "no replacement" in got.message.args[0] assert "Trio 1.3" in got.message.args[0] -def test_warn_deprecated_stacklevel(recwarn_always): +def test_warn_deprecated_stacklevel( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: def nested1(): nested2() @@ -75,21 +84,25 @@ def new(): # pragma: no cover pass -def test_warn_deprecated_formatting(recwarn_always): +def test_warn_deprecated_formatting( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: warn_deprecated(old, "1.0", issue=1, instead=new) got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.old is deprecated" in got.message.args[0] assert "test_deprecate.new instead" in got.message.args[0] @deprecated("1.5", issue=123, instead=new) -def deprecated_old(): +def deprecated_old() -> int: return 3 -def test_deprecated_decorator(recwarn_always): +def test_deprecated_decorator(recwarn_always: _pytest.recwarn.WarningsRecorder) -> None: assert deprecated_old() == 3 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0] assert "1.5" in got.message.args[0] assert "test_deprecate.new" in got.message.args[0] @@ -98,25 +111,31 @@ def test_deprecated_decorator(recwarn_always): class Foo: @deprecated("1.0", issue=123, instead="crying") - def method(self): + def method(self) -> int: return 7 -def test_deprecated_decorator_method(recwarn_always): +def test_deprecated_decorator_method( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: f = Foo() assert f.method() == 7 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.Foo.method is deprecated" in got.message.args[0] @deprecated("1.2", thing="the thing", issue=None, instead=None) -def deprecated_with_thing(): +def deprecated_with_thing() -> int: return 72 -def test_deprecated_decorator_with_explicit_thing(recwarn_always): +def test_deprecated_decorator_with_explicit_thing( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: assert deprecated_with_thing() == 72 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "the thing is deprecated" in got.message.args[0] @@ -127,14 +146,16 @@ def new_hotness(): old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1) -def test_deprecated_alias(recwarn_always): +def test_deprecated_alias(recwarn_always: _pytest.recwarn.WarningsRecorder) -> None: assert old_hotness() == "new hotness" got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.old_hotness is deprecated" in got.message.args[0] assert "1.23" in got.message.args[0] assert "test_deprecate.new_hotness instead" in got.message.args[0] assert "issues/1" in got.message.args[0] + assert old_hotness.__doc__ is not None assert ".. deprecated:: 1.23" in old_hotness.__doc__ assert "test_deprecate.new_hotness instead" in old_hotness.__doc__ assert "issues/1>`__" in old_hotness.__doc__ @@ -149,36 +170,39 @@ def new_hotness_method(self): ) -def test_deprecated_alias_method(recwarn_always): +def test_deprecated_alias_method( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: obj = Alias() assert obj.old_hotness_method() == "new hotness method" got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) msg = got.message.args[0] assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg assert "test_deprecate.Alias.new_hotness_method instead" in msg @deprecated("2.1", issue=1, instead="hi") -def docstring_test1(): # pragma: no cover +def docstring_test1() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=None, instead="hi") -def docstring_test2(): # pragma: no cover +def docstring_test2() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=1, instead=None) -def docstring_test3(): # pragma: no cover +def docstring_test3() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=None, instead=None) -def docstring_test4(): # pragma: no cover +def docstring_test4() -> None: # pragma: no cover """Hello!""" -def test_deprecated_docstring_munging(): +def test_deprecated_docstring_munging() -> None: assert ( docstring_test1.__doc__ == """Hello! @@ -220,7 +244,9 @@ def test_deprecated_docstring_munging(): ) -def test_module_with_deprecations(recwarn_always): +def test_module_with_deprecations( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: assert module_with_deprecations.regular == "hi" assert len(recwarn_always) == 0 @@ -230,6 +256,7 @@ def test_module_with_deprecations(recwarn_always): assert got.filename == filename assert got.lineno == lineno + 1 + assert isinstance(got.message, Warning) assert "module_with_deprecations.dep1" in got.message.args[0] assert "Trio 1.1" in got.message.args[0] assert "/issues/1" in got.message.args[0] @@ -237,7 +264,8 @@ def test_module_with_deprecations(recwarn_always): assert module_with_deprecations.dep2 == "value2" got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "instead-string instead" in got.message.args[0] with pytest.raises(AttributeError): - module_with_deprecations.asdf + module_with_deprecations.asdf # type: ignore [attr-defined] diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py index 374ce8c044..0ad7c7a20f 100644 --- a/trio/tests/test_exports.py +++ b/trio/tests/test_exports.py @@ -69,7 +69,7 @@ def public_modules(module): ) @pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES) @pytest.mark.parametrize("tool", ["pylint", "jedi"]) -def test_static_tool_sees_all_symbols(tool, modname): +def test_static_tool_sees_all_symbols(tool: str, modname: str) -> None: module = importlib.import_module(modname) def no_underscores(symbols): diff --git a/trio/tests/test_file_io.py b/trio/tests/test_file_io.py index b40f7518a9..70266df699 100644 --- a/trio/tests/test_file_io.py +++ b/trio/tests/test_file_io.py @@ -1,27 +1,39 @@ import io import os +from typing import Union +import py.path import pytest from unittest import mock from unittest.mock import sentinel import trio from trio import _core -from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS +from trio._file_io import ( + _AsyncTextIOBase, + _AsyncBufferedIOBase, + _AsyncRawIOBase, + _AsyncIOBase, + AsyncIOWrapper, + _FILE_SYNC_ATTRS, + _FILE_ASYNC_METHODS, +) @pytest.fixture -def path(tmpdir): +def path(tmpdir: py.path.local) -> str: return os.fspath(tmpdir.join("test")) @pytest.fixture -def wrapped(): +def wrapped() -> mock.Mock: return mock.Mock(spec_set=io.StringIO) @pytest.fixture -def async_file(wrapped): +def async_file( + wrapped: mock.Mock, +) -> Union[_AsyncTextIOBase, _AsyncBufferedIOBase, _AsyncRawIOBase, _AsyncIOBase]: return trio.wrap_file(wrapped) diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/tests/test_highlevel_open_tcp_listeners.py index d5fc576ec5..9c068293db 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/tests/test_highlevel_open_tcp_listeners.py @@ -2,6 +2,7 @@ import socket as stdlib_socket import errno +from typing import Set import attr @@ -9,6 +10,7 @@ from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream from trio.testing import open_stream_to_socket_listener from .. import socket as tsocket +from .._abc import SocketFactory from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6 @@ -53,7 +55,7 @@ async def test_open_tcp_listeners_specific_port_specific_host(): @binds_ipv6 -async def test_open_tcp_listeners_ipv6_v6only(): +async def test_open_tcp_listeners_ipv6_v6only() -> None: # Check IPV6_V6ONLY is working properly (ipv6_listener,) = await open_tcp_listeners(0, host="::1") async with ipv6_listener: @@ -141,7 +143,7 @@ def close(self): @attr.s -class FakeSocketFactory: +class FakeSocketFactory(SocketFactory): poison_after = attr.ib() sockets = attr.ib(factory=list) raise_on_family = attr.ib(factory=dict) # family => errno @@ -222,8 +224,9 @@ async def handler(stream): [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], ) async def test_open_tcp_listeners_some_address_families_unavailable( - try_families, fail_families -): + try_families: Set[stdlib_socket.AddressFamily], + fail_families: Set[stdlib_socket.AddressFamily], +) -> None: fsf = FakeSocketFactory( 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} ) diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/tests/test_highlevel_open_unix_stream.py index 211aff3e70..58e4cdf5ec 100644 --- a/trio/tests/test_highlevel_open_unix_stream.py +++ b/trio/tests/test_highlevel_open_unix_stream.py @@ -29,7 +29,7 @@ def close(self): @pytest.mark.parametrize("filename", [4, 4.5]) -async def test_open_with_bad_filename_type(filename): +async def test_open_with_bad_filename_type(filename: float) -> None: with pytest.raises(TypeError): await open_unix_socket(filename) diff --git a/trio/tests/test_path.py b/trio/tests/test_path.py index 284bcf82dd..54fb11dc5d 100644 --- a/trio/tests/test_path.py +++ b/trio/tests/test_path.py @@ -1,15 +1,17 @@ import os import pathlib +from typing import Callable, Type, Union +import py.path import pytest import trio -from trio._path import AsyncAutoWrapperType as Type +from trio._path import AsyncAutoWrapperType as WrapperType from trio._file_io import AsyncIOWrapper @pytest.fixture -def path(tmpdir): +def path(tmpdir: py.path.local) -> trio.Path: p = str(tmpdir.join("test")) return trio.Path(p) @@ -42,7 +44,10 @@ async def test_magic(): @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_cmp_magic(cls_a, cls_b): +async def test_cmp_magic( + cls_a: Union[Type[pathlib.Path], Type[trio.Path]], + cls_b: Union[Type[pathlib.Path], Type[trio.Path]], +) -> None: a, b = cls_a(""), cls_b("") assert a == b assert not a != b @@ -69,7 +74,10 @@ async def test_cmp_magic(cls_a, cls_b): @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_div_magic(cls_a, cls_b): +async def test_div_magic( + cls_a: Union[Type[pathlib.Path], Type[trio.Path]], + cls_b: Union[Type[pathlib.Path], Type[trio.Path]], +) -> None: a, b = cls_a("a"), cls_b("b") result = a / b @@ -81,7 +89,11 @@ async def test_div_magic(cls_a, cls_b): "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] ) @pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) -async def test_hash_magic(cls_a, cls_b, path): +async def test_hash_magic( + cls_a: Union[Type[pathlib.Path], Type[trio.Path]], + cls_b: Union[Type[pathlib.Path], Type[trio.Path]], + path: str, +) -> None: a, b = cls_a(path), cls_b(path) assert hash(a) == hash(b) @@ -103,7 +115,7 @@ async def test_async_method_signature(path): @pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) -async def test_compare_async_stat_methods(method_name): +async def test_compare_async_stat_methods(method_name: str) -> None: method, async_method = method_pair(".", method_name) @@ -119,7 +131,7 @@ async def test_invalid_name_not_wrapped(path): @pytest.mark.parametrize("method_name", ["absolute", "resolve"]) -async def test_async_methods_rewrap(method_name): +async def test_async_methods_rewrap(method_name: str) -> None: method, async_method = method_pair(".", method_name) @@ -168,28 +180,28 @@ class MockWrapper: async def test_type_forwards_unsupported(): with pytest.raises(TypeError): - Type.generate_forwards(MockWrapper, {}) + WrapperType.generate_forwards(MockWrapper, {}) async def test_type_wraps_unsupported(): with pytest.raises(TypeError): - Type.generate_wraps(MockWrapper, {}) + WrapperType.generate_wraps(MockWrapper, {}) async def test_type_forwards_private(): - Type.generate_forwards(MockWrapper, {"unsupported": None}) + WrapperType.generate_forwards(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") async def test_type_wraps_private(): - Type.generate_wraps(MockWrapper, {"unsupported": None}) + WrapperType.generate_wraps(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") @pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) -async def test_path_wraps_path(path, meth): +async def test_path_wraps_path(path: trio.Path, meth: Callable[..., trio.Path]) -> None: # type: ignore[misc] wrapped = await path.absolute() result = meth(path, wrapped) if result is None: diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index f8c061ffd3..c307f61284 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -1,6 +1,7 @@ import errno import pytest +import _pytest.monkepatch import attr import os @@ -50,7 +51,7 @@ def getaddrinfo(self, *args, **kwargs): @pytest.fixture -def monkeygai(monkeypatch): +def monkeygai(monkeypatch: _pytest.monkepatch.MonkeyPatch) -> MonkeypatchedGAI: # type: ignore[misc] controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo) monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo) return controller @@ -254,12 +255,12 @@ async def child(sock): @pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") -async def test_fromshare(): +async def test_fromshare() -> None: a, b = tsocket.socketpair() with a, b: # share with ourselves shared = a.share(os.getpid()) - a2 = tsocket.fromshare(shared) + a2 = tsocket.fromshare(shared) # type: ignore[attr-defined] with a2: assert a.fileno() != a2.fileno() await a2.send(b"x") @@ -273,14 +274,14 @@ async def test_socket(): @creates_ipv6 -async def test_socket_v6(): +async def test_socket_v6() -> None: with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET6 @pytest.mark.skipif(not _sys.platform == "linux", reason="linux only") -async def test_sniff_sockopts(): +async def test_sniff_sockopts() -> None: from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM # generate the combinations of families/types we're testing: @@ -408,7 +409,9 @@ async def test_SocketType_shutdown(): pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), ], ) -async def test_SocketType_simple_server(address, socket_type): +async def test_SocketType_simple_server( + address: str, socket_type: stdlib_socket.AddressFamily +) -> None: # listen, bind, accept, connect, getpeername, getsockname listener = tsocket.socket(socket_type) client = tsocket.socket(socket_type) @@ -481,7 +484,9 @@ class Addresses: ), ], ) -async def test_SocketType_resolve(socket_type, addrs): +async def test_SocketType_resolve( + socket_type: stdlib_socket.AddressFamily, addrs: Addresses +) -> None: v6 = socket_type == tsocket.AF_INET6 def pad(addr): @@ -498,9 +503,9 @@ def assert_eq(actual, expected): # getaddrinfo They also error out on None, but whatever, None is much # more consistent, so we accept it too. for null in [None, ""]: - got = await sock._resolve_local_address_nocp((null, 80)) + got = await sock._resolve_local_address_nocp((null, 80)) # type: ignore[attr-defined] assert_eq(got, (addrs.bind_all, 80)) - got = await sock._resolve_remote_address_nocp((null, 80)) + got = await sock._resolve_remote_address_nocp((null, 80)) # type: ignore[attr-defined] assert_eq(got, (addrs.localhost, 80)) # AI_PASSIVE only affects the wildcard address, so for everything else @@ -935,7 +940,7 @@ async def test_SocketType_is_abstract(): @pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets") -async def test_unix_domain_socket(): +async def test_unix_domain_socket() -> None: # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index f160af4999..aa5a6a463f 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -1,10 +1,12 @@ import pytest +import _pytest.fixtures import threading import socket as stdlib_socket import ssl from contextlib import contextmanager from functools import partial +from typing import AsyncIterator, Iterator from OpenSSL import SSL import trustme @@ -26,6 +28,7 @@ assert_checkpoints, Sequencer, memory_stream_pair, + MockClock, lockstep_stream_pair, check_two_way_stream, ) @@ -71,7 +74,7 @@ @pytest.fixture(scope="module", params=client_ctx_params) -def client_ctx(request): +def client_ctx(request: _pytest.fixtures.SubRequest) -> ssl.SSLContext: ctx = ssl.create_default_context() TRIO_TEST_CA.configure_trust(ctx) if request.param in ["default", "tls13"]: @@ -141,7 +144,7 @@ def ssl_echo_serve_sync(sock, *, expect_fail=False): # (running in a thread). Useful for testing making connections with different # SSLContexts. @asynccontextmanager -async def ssl_echo_server_raw(**kwargs): +async def ssl_echo_server_raw(**kwargs: object) -> AsyncIterator[SocketStream]: # type: ignore[misc] a, b = stdlib_socket.socketpair() async with trio.open_nursery() as nursery: # Exiting the 'with a, b' context manager closes the sockets, which @@ -158,7 +161,9 @@ async def ssl_echo_server_raw(**kwargs): # Fixture that gives a properly set up SSLStream connected to a trio-test-1 # echo server (running in a thread) @asynccontextmanager -async def ssl_echo_server(client_ctx, **kwargs): +async def ssl_echo_server( # type: ignore[misc] + client_ctx: ssl.SSLContext, **kwargs: object +) -> AsyncIterator[SSLStream]: async with ssl_echo_server_raw(**kwargs) as sock: yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") @@ -357,7 +362,9 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): @contextmanager -def virtual_ssl_echo_server(client_ctx, **kwargs): +def virtual_ssl_echo_server( + client_ctx: ssl.SSLContext, **kwargs: object +) -> Iterator[SSLStream]: fakesock = PyOpenSSLEchoStream(**kwargs) yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") @@ -576,7 +583,9 @@ async def test_renegotiation_simple(client_ctx): @slow -async def test_renegotiation_randomized(mock_clock, client_ctx): +async def test_renegotiation_randomized( + mock_clock: MockClock, client_ctx: ssl.SSLContext +) -> None: # The only blocking things in this function are our random sleeps, so 0 is # a good threshold. mock_clock.autojump_threshold = 0 @@ -793,7 +802,9 @@ async def test_send_all_empty_string(client_ctx): @pytest.mark.parametrize("https_compatible", [False, True]) -async def test_SSLStream_generic(client_ctx, https_compatible): +async def test_SSLStream_generic( + client_ctx: ssl.SSLContext, https_compatible: bool +) -> None: async def stream_maker(): return ssl_memory_stream_pair( client_ctx, diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index a9962a0aa2..5d3f94eb42 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -3,6 +3,7 @@ import subprocess import sys import pytest +import _pytest.monkeypatch import random from typing import Optional from functools import partial @@ -19,7 +20,7 @@ TrioDeprecationWarning, ) from .._core.tests.tutil import slow, skip_if_fbsd_pipes_broken -from ..testing import wait_all_tasks_blocked +from ..testing import MockClock, wait_all_tasks_blocked SIGKILL: Optional[signal.Signals] SIGTERM: Optional[signal.Signals] @@ -280,7 +281,7 @@ async def test_run_check(): @skip_if_fbsd_pipes_broken -async def test_run_with_broken_pipe(): +async def test_run_with_broken_pipe() -> None: result = await run_process( [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072 ) @@ -391,7 +392,7 @@ async def test_one_signal(send_it, signum): @pytest.mark.skipif(not posix, reason="POSIX specific") -async def test_wait_reapable_fails(): +async def test_wait_reapable_fails() -> None: old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) try: # With SIGCHLD disabled, the wait() syscall will wait for the @@ -410,7 +411,7 @@ async def test_wait_reapable_fails(): @slow -def test_waitid_eintr(): +def test_waitid_eintr() -> None: # This only matters on PyPy (where we're coding EINTR handling # ourselves) but the test works on all waitid platforms. from .._subprocess_platform import wait_child_exiting @@ -482,7 +483,9 @@ def broken_terminate(self): @pytest.mark.skipif(os.name != "posix", reason="posix only") -async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch): +async def test_warn_on_cancel_SIGKILL_escalation( + autojump_clock: MockClock, monkeypatch: _pytest.monkeypatch.MonkeyPatch +) -> None: monkeypatch.setattr(Process, "terminate", lambda *args: None) with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"): diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index 944adec2cf..b82351af76 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -248,7 +248,9 @@ def get__name__(fn: Callable) -> str: @pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=get__name__) -async def test_Lock_and_StrictFIFOLock(lockcls): +async def test_Lock_and_StrictFIFOLock( + lockcls: Union[Type[Lock], Type[StrictFIFOLock]] +) -> None: l = lockcls() # noqa assert not l.locked() @@ -260,7 +262,8 @@ async def test_Lock_and_StrictFIFOLock(lockcls): # make sure repr uses the right name for subclasses assert lockcls.__name__ in repr(l) with assert_checkpoints(): - async with l: + # TODO: hint async_cm + async with l: # type: ignore[union-attr] assert l.locked() repr(l) # smoke test (repr branches on locked/unlocked) assert not l.locked() @@ -500,11 +503,23 @@ def release(self): "lock_factory", lock_factories, ids=lock_factory_names ) +_LockFactory = Callable[ + [], + Union[ + CapacityLimiter, + Semaphore, + Lock, + StrictFIFOLock, + ChannelLock1, + ChannelLock2, + ChannelLock3, + ], +] # Spawn a bunch of workers that take a lock and then yield; make sure that # only one worker is ever in the critical section at a time. @generic_lock_test -async def test_generic_lock_exclusion(lock_factory): +async def test_generic_lock_exclusion(lock_factory: _LockFactory) -> None: LOOPS = 10 WORKERS = 5 in_critical_section = False @@ -533,7 +548,7 @@ async def worker(lock_like): # Several workers queue on the same lock; make sure they each get it, in # order. @generic_lock_test -async def test_generic_lock_fifo_fairness(lock_factory): +async def test_generic_lock_fifo_fairness(lock_factory: _LockFactory) -> None: initial_order = [] record = [] LOOPS = 5 @@ -557,7 +572,9 @@ async def loopy(name, lock_like): @generic_lock_test -async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory): +async def test_generic_lock_acquire_nowait_blocks_acquire( + lock_factory: _LockFactory, +) -> None: lock_like = lock_factory() record = [] diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index 9da2838cbd..a2b1fce23a 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -2,6 +2,7 @@ import queue as stdlib_queue import time +import attr import pytest from .. import _core @@ -277,7 +278,9 @@ async def child(): @pytest.mark.parametrize("MAX", [3, 5, 10]) @pytest.mark.parametrize("cancel", [False, True]) @pytest.mark.parametrize("use_default_limiter", [False, True]) -async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter): +async def test_run_in_worker_thread_limiter( + MAX: int, cancel: bool, use_default_limiter: bool +) -> None: # This test is a bit tricky. The goal is to make sure that if we set # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever # running at a time, even if there are more concurrent calls to @@ -306,13 +309,16 @@ async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter): # # Mutating them in-place is OK though (as long as you use proper # locking etc.). - class state: - pass - state.ran = 0 - state.high_water = 0 - state.running = 0 - state.parked = 0 + # TODO: does this break the concerns explained above...? + @attr.s() + class State: + ran: int = attr.ib(default=0) + high_water: int = attr.ib(default=0) + running: int = attr.ib(default=0) + parked: int = attr.ib(default=0) + + state = State() token = _core.current_trio_token() @@ -558,7 +564,7 @@ def not_called(): # pragma: no cover @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_from_thread_run_during_shutdown(): +def test_from_thread_run_during_shutdown() -> None: save = [] record = [] diff --git a/trio/tests/test_timeouts.py b/trio/tests/test_timeouts.py index 382c015b1d..97cb728954 100644 --- a/trio/tests/test_timeouts.py +++ b/trio/tests/test_timeouts.py @@ -42,8 +42,8 @@ async def check_takes_about(f, expected_dur): @slow -async def test_sleep(): - async def sleep_1(): +async def test_sleep() -> None: + async def sleep_1() -> None: await sleep_until(_core.current_time() + TARGET) await check_takes_about(sleep_1, TARGET) @@ -65,7 +65,7 @@ async def sleep_2(): @slow -async def test_move_on_after(): +async def test_move_on_after() -> None: with pytest.raises(ValueError): with move_on_after(-1): pass # pragma: no cover @@ -78,7 +78,7 @@ async def sleep_3(): @slow -async def test_fail(): +async def test_fail() -> None: async def sleep_4(): with fail_at(_core.current_time() + TARGET): await sleep(100) diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index a49de1796a..4511f08cf4 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -13,6 +13,7 @@ pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="posix only") +# mypy recognizes this. an assert would break the pytest skipif if sys.platform == "win32": with pytest.raises(AssertionError): # Using sys instead of FdStream since sys is created before the assertion that @@ -238,7 +239,7 @@ async def patched_wait_writable(*args, **kwargs): sys.platform.startswith("freebsd"), reason="no way to make read() return a bizarro error on FreeBSD", ) - async def test_bizarro_OSError_from_receive(): + async def test_bizarro_OSError_from_receive() -> None: # Make sure that if the read syscall returns some bizarro error, then we # get a BrokenResourceError. This is incredibly unlikely; there's almost # no way to trigger a failure here intentionally (except for EBADF, but we @@ -257,5 +258,5 @@ async def test_bizarro_OSError_from_receive(): os.close(dir_fd) @skip_if_fbsd_pipes_broken - async def test_pipe_fully(): + async def test_pipe_fully() -> None: await check_one_way_stream(make_pipe, make_clogged_pipe) diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 2ea0a1e287..95c9c5fbb8 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -1,4 +1,5 @@ import signal +from typing import Iterator, TypeVar import pytest import trio @@ -16,6 +17,9 @@ from ..testing import wait_all_tasks_blocked +_T = TypeVar("_T") + + def test_signal_raise(): record = [] @@ -90,7 +94,7 @@ def not_main_thread(): # @coroutine is deprecated since python 3.8, which is fine with us. @pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") -def test_coroutine_or_error(): +def test_coroutine_or_error() -> None: class Deferred: "Just kidding" @@ -106,7 +110,7 @@ async def f(): # pragma: no cover import asyncio @asyncio.coroutine - def generator_based_coro(): # pragma: no cover + def generator_based_coro() -> Iterator[None]: # pragma: no cover yield from asyncio.sleep(1) with pytest.raises(TypeError) as excinfo: @@ -146,14 +150,14 @@ async def async_gen(arg): # pragma: no cover del excinfo -def test_generic_function(): +def test_generic_function() -> None: @generic_function - def test_func(arg): + def test_func(arg: _T) -> _T: """Look, a docstring!""" return arg - assert test_func is test_func[int] is test_func[int, str] - assert test_func(42) == test_func[int](42) == 42 + assert test_func is test_func[int] is test_func[int, str] # type: ignore[index] + assert test_func(42) == test_func[int](42) == 42 # type: ignore[index] assert test_func.__doc__ == "Look, a docstring!" assert test_func.__qualname__ == "test_generic_function..test_func" assert test_func.__name__ == "test_func" diff --git a/trio/tests/test_wait_for_object.py b/trio/tests/test_wait_for_object.py index 38acfa802d..54af6b873b 100644 --- a/trio/tests/test_wait_for_object.py +++ b/trio/tests/test_wait_for_object.py @@ -73,7 +73,7 @@ async def test_WaitForMultipleObjects_sync(): @slow -async def test_WaitForMultipleObjects_sync_slow(): +async def test_WaitForMultipleObjects_sync_slow() -> None: # This does a series of test in which the main thread sync-waits for # handles, while we spawn a thread to set the handles after a short while. @@ -163,7 +163,7 @@ async def test_WaitForSingleObject(): @slow -async def test_WaitForSingleObject_slow(): +async def test_WaitForSingleObject_slow() -> None: # This does a series of test for setting the handle in another task, # and cancelling the wait task. From f8e83c62dc7f717cd07590d038406de85c4a2ea3 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 15:54:43 -0500 Subject: [PATCH 17/50] fixup tests --- trio/_socket.py | 5 ++--- trio/tests/module_with_deprecations.py | 3 --- trio/tests/test_deprecate.py | 4 ++-- trio/tests/test_socket.py | 4 ++-- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/trio/_socket.py b/trio/_socket.py index 6fe70714d8..69f36c7fc7 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -974,10 +974,9 @@ async def sendmsg(self, *args: object) -> int: # args is: buffers[, ancdata[, flags[, address]]] # and kwargs are not accepted if len(args) == 4 and args[-1] is not None: - list_args = list(args) - list_args[-1] = await self._resolve_remote_address_nocp(list_args[-1]) + args = (*args[:-1], await self._resolve_remote_address_nocp(args[-1])) return await self._nonblocking_helper( - _stdlib_socket.socket.sendmsg, list_args, {}, _core.wait_writable + _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable ) ################################################################ diff --git a/trio/tests/module_with_deprecations.py b/trio/tests/module_with_deprecations.py index 7db72e16c6..ed51f150c3 100644 --- a/trio/tests/module_with_deprecations.py +++ b/trio/tests/module_with_deprecations.py @@ -19,6 +19,3 @@ "value2", "1.2", issue=1, instead="instead-string" ), } - -dep1 = None -dep2 = None diff --git a/trio/tests/test_deprecate.py b/trio/tests/test_deprecate.py index 5a53655ac4..bf3743f395 100644 --- a/trio/tests/test_deprecate.py +++ b/trio/tests/test_deprecate.py @@ -251,7 +251,7 @@ def test_module_with_deprecations( assert len(recwarn_always) == 0 filename, lineno = _here() - assert module_with_deprecations.dep1 == "value1" + assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined] got = recwarn_always.pop(TrioDeprecationWarning) assert got.filename == filename assert got.lineno == lineno + 1 @@ -262,7 +262,7 @@ def test_module_with_deprecations( assert "/issues/1" in got.message.args[0] assert "value1 instead" in got.message.args[0] - assert module_with_deprecations.dep2 == "value2" + assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined] got = recwarn_always.pop(TrioDeprecationWarning) assert isinstance(got.message, Warning) assert "instead-string instead" in got.message.args[0] diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index c307f61284..1c7350f8bd 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -1,7 +1,7 @@ import errno import pytest -import _pytest.monkepatch +import _pytest.monkeypatch import attr import os @@ -51,7 +51,7 @@ def getaddrinfo(self, *args, **kwargs): @pytest.fixture -def monkeygai(monkeypatch: _pytest.monkepatch.MonkeyPatch) -> MonkeypatchedGAI: # type: ignore[misc] +def monkeygai(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> MonkeypatchedGAI: controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo) monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo) return controller From 83600f8b9dcc043ac1c86f24967654e7514f31ab Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 21:47:15 -0500 Subject: [PATCH 18/50] fix import loop --- trio/_core/_io_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 0b5dd034b7..4f4bd12a1a 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -860,7 +860,7 @@ def current_iocp(self) -> int: @_public def monitor_completion_key( self, - ) -> Iterator[Tuple[int, _core.UnboundedQueue[CompletionKeyEventInfo]]]: + ) -> Iterator[Tuple[int, "_core.UnboundedQueue[CompletionKeyEventInfo]"]]: key = next(self._completion_key_counter) queue = _core.UnboundedQueue[CompletionKeyEventInfo]() self._completion_key_queues[key] = queue From e4785033c0878db38359f9b826439b3df4f96f54 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 21:58:38 -0500 Subject: [PATCH 19/50] get Protocol from typing_extensions --- trio/_highlevel_open_unix_stream.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index 3e294200f0..8922d6d68c 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -1,6 +1,7 @@ import os from contextlib import contextmanager -from typing import Iterator, Protocol, TypeVar +from typing import Iterator, TypeVar +from typing_extensions import Protocol import trio from trio.socket import socket, SOCK_STREAM From c34029b1173bea2516b8ef20eb2595f5d1f08a95 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 22:05:29 -0500 Subject: [PATCH 20/50] queue.Queue isn't typing.Queue --- trio/_core/tests/test_thread_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/tests/test_thread_cache.py index 18bfd2a8c9..9f84ec4932 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/tests/test_thread_cache.py @@ -4,6 +4,7 @@ from queue import Queue import time import sys +import typing from outcome import Outcome @@ -61,7 +62,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q = Queue[Outcome]() + q: typing.Queue[Outcome] = Queue() COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) From 0ec01a1df9e5fb22da9cdbc809a8bc60ca5ef984 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 22:08:13 -0500 Subject: [PATCH 21/50] queue.Queue isn't typing.Queue (again) --- trio/_core/tests/test_thread_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/tests/test_thread_cache.py index 9f84ec4932..3393646290 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/tests/test_thread_cache.py @@ -95,7 +95,7 @@ def test_idle_threads_exit(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None # CPU.) monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q = Queue[threading.Thread]() + q: typing.Queue[threading.Thread] = Queue() start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread From b33978a0f9969efac17f437730d4e4b9ae386393 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 22:25:19 -0500 Subject: [PATCH 22/50] try again --- trio/_channel.py | 4 ++-- trio/_core/_parking_lot.py | 2 +- trio/_core/tests/test_thread_cache.py | 5 ++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/trio/_channel.py b/trio/_channel.py index 85a4ae618f..77cd16660b 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -103,9 +103,9 @@ class MemoryChannelState(Generic[_T_contra]): open_send_channels: int = attr.ib(default=0) open_receive_channels: int = attr.ib(default=0) # {task: value} - send_tasks: typing.OrderedDict[Task, _T_contra] = attr.ib(factory=OrderedDict) + send_tasks: "OrderedDict[Task, _T_contra]" = attr.ib(factory=OrderedDict) # {task: None} - receive_tasks: typing.OrderedDict[Task, None] = attr.ib(factory=OrderedDict) + receive_tasks: "OrderedDict[Task, None]" = attr.ib(factory=OrderedDict) def statistics(self): return MemoryChannelStats( diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 237bd5677d..b2130417b2 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -103,7 +103,7 @@ class ParkingLot(metaclass=Final): # {task: None}, we just want a deque where we can quickly delete random # items - _parked: typing.OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False) + _parked: "OrderedDict[Task, None]" = attr.ib(factory=OrderedDict, init=False) def __len__(self): """Returns the number of parked tasks.""" diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/tests/test_thread_cache.py index 3393646290..9e58d77b9f 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/tests/test_thread_cache.py @@ -4,7 +4,6 @@ from queue import Queue import time import sys -import typing from outcome import Outcome @@ -62,7 +61,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q: typing.Queue[Outcome] = Queue() + q: "Queue[Outcome]" = Queue() COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) @@ -95,7 +94,7 @@ def test_idle_threads_exit(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None # CPU.) monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q: typing.Queue[threading.Thread] = Queue() + q: "Queue[threading.Thread]" = Queue() start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread From 2d1882e4dfc6a15d68008efa69f03c5d54590c51 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 22:51:04 -0500 Subject: [PATCH 23/50] some workarounds for 3.6 --- trio/_channel.py | 6 ++++-- trio/_core/_local.py | 21 ++++++++++++++++++++- trio/tests/test_exports.py | 3 +++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/trio/_channel.py b/trio/_channel.py index 77cd16660b..bd7f314ce1 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,6 +1,6 @@ from collections import deque, OrderedDict from math import inf -import typing +import sys from typing import cast, Callable, Deque, Generic, Set, Tuple, TypeVar, Union import attr @@ -95,7 +95,9 @@ class MemoryChannelStats: tasks_waiting_receive = attr.ib() -@attr.s(slots=True) +# TODO: ick... how to handle 3.6? +# https://github.com/python-attrs/attrs/issues/313 +@attr.s(slots=sys.version_info >= (3, 7)) class MemoryChannelState(Generic[_T_contra]): max_buffer_size: float = attr.ib() data: Deque[_T_contra] = attr.ib(factory=deque) diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 6dbf190f0d..46ef6d683d 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -32,7 +32,26 @@ class _NoDefault: pass -class RunVar(Generic[_T], metaclass=Final): +# TODO: ack! this is... not pleasant. But otherwise we hit the exception below when +# testing in 3.6. Part of cleaning this up is undoing the skip in +# test_classes_are_final(). +# ImportError while loading conftest '/home/altendky/repos/trio/trio/tests/conftest.py'. +# trio/__init__.py:67: in +# from ._highlevel_socket import SocketStream, SocketListener +# trio/_highlevel_socket.py:8: in +# from . import socket as tsocket +# trio/socket.py:9: in +# from . import _socket +# trio/_socket.py:83: in +# _resolver = _core.RunVar[Optional[HostnameResolver]]("hostname_resolver") +# ../../.pyenv/versions/3.6.12/lib/python3.6/typing.py:682: in inner +# return func(*args, **kwds) +# ../../.pyenv/versions/3.6.12/lib/python3.6/typing.py:1143: in __getitem__ +# orig_bases=self.__orig_bases__) +# E TypeError: __new__() got an unexpected keyword argument 'tvars' +import sys +from .._util import BaseMeta +class RunVar(Generic[_T], metaclass=Final if sys.version_info >= (3, 7) else BaseMeta): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py index 0ad7c7a20f..ce3fcec887 100644 --- a/trio/tests/test_exports.py +++ b/trio/tests/test_exports.py @@ -121,6 +121,9 @@ def test_classes_are_final(): # Deprecated classes are exported with a leading underscore if name.startswith("_"): # pragma: no cover continue + # TODO: fix RunVar as a generic to work in 3.6 + if name == "RunVar": + continue # Abstract classes can be subclassed, because that's the whole # point of ABCs From b77b3cb2c3d91a81ccd8ab7a11127082fedad8c4 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 22:59:42 -0500 Subject: [PATCH 24/50] another 3.6 (only windows) workaround --- trio/_core/_io_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 4f4bd12a1a..57d29911f1 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -862,7 +862,7 @@ def monitor_completion_key( self, ) -> Iterator[Tuple[int, "_core.UnboundedQueue[CompletionKeyEventInfo]"]]: key = next(self._completion_key_counter) - queue = _core.UnboundedQueue[CompletionKeyEventInfo]() + queue: "_core.UnboundedQueue[CompletionKeyEventInfo]" = _core.UnboundedQueue() self._completion_key_queues[key] = queue try: yield (key, queue) From b9dc659a89d712304fef29ba16573b9c13f9689a Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 23:09:51 -0500 Subject: [PATCH 25/50] circular import fix for macos --- trio/_core/_io_kqueue.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 08c6122778..d7fbd022b1 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -33,7 +33,7 @@ class _KqueueStatistics: class KqueueIOManager: _kqueue: select.kqueue = attr.ib(factory=select.kqueue) # {(ident, filter): Task or UnboundedQueue} - _registered: Dict[Tuple[int, int], Union[Task, UnboundedQueue[Task]]] = attr.ib( + _registered: Dict[Tuple[int, int], Union[Task, "UnboundedQueue[Task]"]] = attr.ib( factory=dict ) _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) @@ -113,7 +113,7 @@ def current_kqueue(self) -> select.kqueue: @_public def monitor_kevent( self, ident: int, filter: int - ) -> Iterator[_core.UnboundedQueue[Task]]: + ) -> Iterator["_core.UnboundedQueue[Task]"]: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( From 44e60ea5f85866745af693a50fe1a396c32d70e0 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 26 Jan 2021 23:13:18 -0500 Subject: [PATCH 26/50] circular import fix for macos (again) --- trio/_core/_io_kqueue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index d7fbd022b1..c53ce516e3 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -131,7 +131,7 @@ async def wait_kevent( self, ident: int, filter: int, - abort_func: Callable[[Callable[[], None]], _core.Abort], + abort_func: Callable[[Callable[[], None]], "_core.Abort"], ) -> object: key = (ident, filter) if key in self._registered: From 490cccdb2ab031a675269359d2093b3adb94c0d2 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 8 Feb 2021 13:56:35 -0500 Subject: [PATCH 27/50] handle some more win32 stuff --- trio/tests/test_socket.py | 25 ++++++---- trio/tests/test_subprocess.py | 89 +++++++++++++++++++++-------------- 2 files changed, 69 insertions(+), 45 deletions(-) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index 1c7350f8bd..5b8790772e 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -6,6 +6,7 @@ import os import socket as stdlib_socket +import sys import inspect import tempfile import sys as _sys @@ -256,15 +257,21 @@ async def child(sock): @pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") async def test_fromshare() -> None: - a, b = tsocket.socketpair() - with a, b: - # share with ourselves - shared = a.share(os.getpid()) - a2 = tsocket.fromshare(shared) # type: ignore[attr-defined] - with a2: - assert a.fileno() != a2.fileno() - await a2.send(b"x") - assert await b.recv(1) == b"x" + if sys.platform != "win32": + # mypy doesn't recognize the pytest.mark.skipif and ignores an assert inside + # this function. + # https://github.com/python/mypy/issues/9025 + assert False # we should have been skipped, if not then fail + else: + a, b = tsocket.socketpair() + with a, b: + # share with ourselves + shared = a.share(os.getpid()) + a2 = tsocket.fromshare(shared) + with a2: + assert a.fileno() != a2.fileno() + await a2.send(b"x") + assert await b.recv(1) == b"x" async def test_socket(): diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 5d3f94eb42..3f5a69ea12 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -25,6 +25,7 @@ SIGKILL: Optional[signal.Signals] SIGTERM: Optional[signal.Signals] SIGUSR1: Optional[signal.Signals] +SIGCHLD: Optional[signal.Signals] # TODO: is this the proper translation from os.name to sys.platform? # Mypy understands sys.platform but not os.name @@ -33,9 +34,12 @@ if sys.platform != "win32": import signal - SIGKILL, SIGTERM, SIGUSR1 = signal.SIGKILL, signal.SIGTERM, signal.SIGUSR1 + SIGKILL = signal.SIGKILL + SIGTERM = signal.SIGTERM + SIGUSR1 = signal.SIGUSR1 + SIGCHLD = signal.SIGCHLD else: - SIGKILL, SIGTERM, SIGUSR1 = None, None, None + SIGKILL, SIGTERM, SIGUSR1, SIGCHLD = None, None, None, None # Since Windows has very few command-line utilities generally available, @@ -393,21 +397,28 @@ async def test_one_signal(send_it, signum): @pytest.mark.skipif(not posix, reason="POSIX specific") async def test_wait_reapable_fails() -> None: - old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) - try: - # With SIGCHLD disabled, the wait() syscall will wait for the - # process to exit but then fail with ECHILD. Make sure we - # support this case as the stdlib subprocess module does. - async with await open_process(SLEEP(3600)) as proc: - async with _core.open_nursery() as nursery: - nursery.start_soon(proc.wait) - await wait_all_tasks_blocked() - proc.kill() - nursery.cancel_scope.deadline = _core.current_time() + 1.0 - assert not nursery.cancel_scope.cancelled_caught - assert proc.returncode == 0 # exit status unknowable, so... - finally: - signal.signal(signal.SIGCHLD, old_sigchld) + if sys.platform == "win32": + # mypy doesn't recognize the pytest.mark.skipif and ignores an assert inside + # this function. + # https://github.com/python/mypy/issues/9025 + assert False # we should have been skipped, if not then fail + else: + assert SIGCHLD is not None # for mypy + old_sigchld = signal.signal(SIGCHLD, signal.SIG_IGN) + try: + # With SIGCHLD disabled, the wait() syscall will wait for the + # process to exit but then fail with ECHILD. Make sure we + # support this case as the stdlib subprocess module does. + async with await open_process(SLEEP(3600)) as proc: + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + proc.kill() + nursery.cancel_scope.deadline = _core.current_time() + 1.0 + assert not nursery.cancel_scope.cancelled_caught + assert proc.returncode == 0 # exit status unknowable, so... + finally: + signal.signal(SIGCHLD, old_sigchld) @slow @@ -420,26 +431,32 @@ def test_waitid_eintr() -> None: pytest.skip("waitid only") from .._subprocess_platform.waitid import sync_wait_reapable - got_alarm = False - sleeper = subprocess.Popen(["sleep", "3600"]) - - def on_alarm(sig, frame): - nonlocal got_alarm - got_alarm = True - sleeper.kill() - - old_sigalrm = signal.signal(signal.SIGALRM, on_alarm) - try: - signal.alarm(1) - sync_wait_reapable(sleeper.pid) - assert sleeper.wait(timeout=1) == -9 - finally: - if sleeper.returncode is None: # pragma: no cover - # We only get here if something fails in the above; - # if the test passes, wait() will reap the process + if sys.platform == "win32": + # mypy doesn't recognize the waitid checks above as representing not-Windows + # and ignores an assert inside this function. + # https://github.com/python/mypy/issues/9025 + assert False # we should have been skipped, if not then fail + else: + got_alarm = False + sleeper = subprocess.Popen(["sleep", "3600"]) + + def on_alarm(sig, frame): + nonlocal got_alarm + got_alarm = True sleeper.kill() - sleeper.wait() - signal.signal(signal.SIGALRM, old_sigalrm) + + old_sigalrm = signal.signal(signal.SIGALRM, on_alarm) + try: + signal.alarm(1) + sync_wait_reapable(sleeper.pid) + assert sleeper.wait(timeout=1) == -9 + finally: + if sleeper.returncode is None: # pragma: no cover + # We only get here if something fails in the above; + # if the test passes, wait() will reap the process + sleeper.kill() + sleeper.wait() + signal.signal(signal.SIGALRM, old_sigalrm) async def test_custom_deliver_cancel(): From c3db271f38fe0ce3525aaa09fcda07e98101575c Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 8 Feb 2021 15:55:52 -0500 Subject: [PATCH 28/50] more --- trio/_core/_io_windows.py | 9 +++++---- trio/_core/_local.py | 2 +- trio/_socket.py | 10 ++++++++-- trio/tests/test_ssl.py | 5 ++++- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 57d29911f1..e0ade8bf2e 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -721,6 +721,7 @@ async def wait_overlapped( ) -> None: handle = _handle(handle) if isinstance(lpOverlapped, int): + # TODO: figure out how to hint this? lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) if lpOverlapped in self._overlapped_waiters: raise _core.BusyResourceError( @@ -766,11 +767,11 @@ def abort(raise_cancel_): return _core.Abort.FAILED await _core.wait_task_rescheduled(abort) - if lpOverlapped.Internal != 0: + if lpOverlapped.Internal != 0: # type: ignore[attr-defined] # the lpOverlapped reports the error as an NT status code, # which we must convert back to a Win32 error code before # it will produce the right sorts of exceptions - code = ntdll.RtlNtStatusToDosError(lpOverlapped.Internal) + code = ntdll.RtlNtStatusToDosError(lpOverlapped.Internal) # type: ignore[attr-defined] if code == ErrorCodes.ERROR_OPERATION_ABORTED: if raise_cancel is not None: raise_cancel() @@ -823,7 +824,7 @@ def submit_write(lpOverlapped): lpOverlapped = await self._perform_overlapped(handle, submit_write) # this is "number of bytes transferred" - return lpOverlapped.InternalHigh + return lpOverlapped.InternalHigh # type: ignore[no-any-return] @_public async def readinto_overlapped( @@ -846,7 +847,7 @@ def submit_read(lpOverlapped): ) lpOverlapped = await self._perform_overlapped(handle, submit_read) - return lpOverlapped.InternalHigh + return lpOverlapped.InternalHigh # type: ignore[no-any-return] ################################################################ # Raw IOCP operations diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 46ef6d683d..d7ebfa6ac7 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -51,7 +51,7 @@ class _NoDefault: # E TypeError: __new__() got an unexpected keyword argument 'tvars' import sys from .._util import BaseMeta -class RunVar(Generic[_T], metaclass=Final if sys.version_info >= (3, 7) else BaseMeta): +class RunVar(Generic[_T], metaclass=Final if sys.version_info >= (3, 7) else BaseMeta): # type: ignore[misc] """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, diff --git a/trio/_socket.py b/trio/_socket.py index 69f36c7fc7..f09ad6f16c 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -273,8 +273,14 @@ def fromfd(fd: int, family: int, type: int, proto: int = 0) -> "SocketType": ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(*args: object, **kwargs: object) -> "SocketType": - return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) + def fromshare(data: bytes) -> "SocketType": + # Not using *args, **kwargs to make mypy happy. + # trio/_socket.py:277: error: Argument 1 to "fromshare" has incompatible type "*Tuple[object, ...]"; expected "bytes" [arg-type] + # trio/_socket.py:277: error: Argument 2 to "fromshare" has incompatible type "**Dict[str, object]"; expected "bytes" [arg-type] + # So, we will just have to keep this in sync with the stdlib function in such + # case as it ever changes in the future. + # https://docs.python.org/3.9/library/socket.html#socket.fromshare + return from_stdlib_socket(_stdlib_socket.fromshare(data)) # @overload diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index aa5a6a463f..5b66eea942 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -80,7 +80,10 @@ def client_ctx(request: _pytest.fixtures.SubRequest) -> ssl.SSLContext: if request.param in ["default", "tls13"]: return ctx elif request.param == "tls12": - ctx.options |= ssl.OP_NO_TLSv1_3 + # https://github.com/python/typeshed/blob/a3f5541830205400cdf3aac04625e8a09f86cace/stdlib/ssl.pyi#L146-L148 + # but it is there in ~=3.6.3 as well... + # https://docs.python.org/3.9/library/ssl.html#ssl.OP_NO_TLSv1_3 + ctx.options |= ssl.OP_NO_TLSv1_3 # type: ignore[attr-defined] return ctx else: # pragma: no cover assert False From 1915661c235460ce4a9f51edd06af07696454166 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 8 Feb 2021 22:40:36 -0500 Subject: [PATCH 29/50] catch up --- trio/_core/_io_windows.py | 4 ++-- trio/_core/tests/test_run.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 29202642d8..2f4eef768e 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -724,7 +724,7 @@ def register_with_iocp(self, handle: socket.socket) -> None: @_public async def wait_overlapped( self, handle: socket.socket, lpOverlapped: Union[int, object] - ) -> None: + ) -> object: handle = _handle(handle) if isinstance(lpOverlapped, int): # TODO: figure out how to hint this? @@ -788,7 +788,7 @@ def abort(raise_cancel_): raise _core.ClosedResourceError("another task closed this resource") else: raise_winerror(code) - return info + return object async def _perform_overlapped(self, handle, submit_fn): # submit_fn(lpOverlapped) submits some I/O diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 1af746ae6f..6def9c8e97 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -2259,7 +2259,7 @@ async def test_nursery_cancel_doesnt_create_cyclic_garbage() -> None: @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_locals_destroyed_promptly_on_cancel(): +async def test_locals_destroyed_promptly_on_cancel() -> None: destroyed = False def finalizer(): From 2d70fa0500de91c09200c6751f598127b32c3415 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 9 Feb 2021 14:10:44 -0500 Subject: [PATCH 30/50] separate type checking --- .github/workflows/ci.yml | 5 +++++ check.sh | 5 ----- ci.sh | 3 +++ typing.sh | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 5 deletions(-) create mode 100755 typing.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e0959c1797..c9ed826d17 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,12 +61,16 @@ jobs: matrix: python: ['pypy-3.6', 'pypy-3.7', '3.6', '3.7', '3.8', '3.9', '3.6-dev', '3.7-dev', '3.8-dev', '3.9-dev'] check_formatting: ['0'] + check_typing: ['0'] pypy_nightly_branch: [''] extra_name: [''] include: - python: '3.8' check_formatting: '1' extra_name: ', check formatting' + - python: '3.8' + check_typing: '1' + extra_name: ', check typing' - python: '3.7' # <- not actually used pypy_nightly_branch: 'py3.7' extra_name: ', pypy 3.7 nightly' @@ -88,6 +92,7 @@ jobs: env: PYPY_NIGHTLY_BRANCH: '${{ matrix.pypy_nightly_branch }}' CHECK_FORMATTING: '${{ matrix.check_formatting }}' + CHECK_TYPING: '${{ matrix.check_typing }}' # Should match 'name:' up above JOB_NAME: 'Ubuntu (${{ matrix.python }}${{ matrix.extra_name }})' diff --git a/check.sh b/check.sh index de185cab67..c1e81986ae 100755 --- a/check.sh +++ b/check.sh @@ -23,11 +23,6 @@ flake8 trio/ \ --ignore=D,E,W,F401,F403,F405,F821,F822\ || EXIT_STATUS=$? -# Run mypy on all supported platforms -mypy -p trio --platform linux || EXIT_STATUS=$? -mypy -p trio --platform darwin || EXIT_STATUS=$? # tests FreeBSD too -mypy -p trio --platform win32 || EXIT_STATUS=$? - # Finally, leave a really clear warning of any issues and exit if [ $EXIT_STATUS -ne 0 ]; then cat < Date: Tue, 9 Feb 2021 14:20:10 -0500 Subject: [PATCH 31/50] black --- trio/_core/_local.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trio/_core/_local.py b/trio/_core/_local.py index d7ebfa6ac7..3d63edb14b 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -51,6 +51,8 @@ class _NoDefault: # E TypeError: __new__() got an unexpected keyword argument 'tvars' import sys from .._util import BaseMeta + + class RunVar(Generic[_T], metaclass=Final if sys.version_info >= (3, 7) else BaseMeta): # type: ignore[misc] """The run-local variant of a context variable. From c1dd0d9a54cebf49dbf560925ba9ede8ff2fad9a Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 9 Feb 2021 14:28:42 -0500 Subject: [PATCH 32/50] type check 3.6 - 3.9 on each platform --- typing.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/typing.sh b/typing.sh index c6f496a7af..fe156450fd 100755 --- a/typing.sh +++ b/typing.sh @@ -9,9 +9,11 @@ python ./trio/_tools/gen_exports.py --test \ || EXIT_STATUS=$? # Run mypy on all supported platforms -mypy -p trio --platform linux || EXIT_STATUS=$? -mypy -p trio --platform darwin || EXIT_STATUS=$? # tests FreeBSD too -mypy -p trio --platform win32 || EXIT_STATUS=$? +for PLATFORM in linux darwin win32; do + for VERSION in 3.6 3.7 3.8 3.9; do + mypy -p trio --platform $PLATFORM --python-version $VERSION || EXIT_STATUS=$? + done +done # Finally, leave a really clear warning of any issues and exit if [ $EXIT_STATUS -ne 0 ]; then From 2623aff4e09cc2ecc29b31a0ec2abb61ca7cefb3 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 9 Feb 2021 21:31:50 -0500 Subject: [PATCH 33/50] oops --- trio/_core/_io_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 2f4eef768e..40229a775d 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -788,7 +788,7 @@ def abort(raise_cancel_): raise _core.ClosedResourceError("another task closed this resource") else: raise_winerror(code) - return object + return info async def _perform_overlapped(self, handle, submit_fn): # submit_fn(lpOverlapped) submits some I/O From 6dbb49634e4538f98904d58852d18d371a0d66c3 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 9 Feb 2021 21:34:30 -0500 Subject: [PATCH 34/50] misc --- docs/source/conf.py | 4 ++++ trio/_core/_multierror.py | 4 ++-- trio/_socket.py | 2 +- trio/tests/test_ssl.py | 5 +---- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6045ffd828..cfce9d2256 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,6 +54,10 @@ # https://github.com/sphinx-doc/sphinx/issues/7722 ("py:class", "SendType"), ("py:class", "ReceiveType"), + ("py:class", "_T_contra"), + ("py:class", "_T_co"), + ("py:class", "_T"), + ("py:class", "T_resource"), ] autodoc_inherit_docstrings = False default_role = "obj" diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 0f59184030..1d535edeab 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,7 +1,7 @@ import sys import traceback import textwrap -from typing import Callable, Optional, overload, Set, Union +from typing import Callable, ContextManager, Optional, overload, Set, Union import warnings import attr @@ -258,7 +258,7 @@ def filter( @classmethod def catch( cls, handler: Callable[[Exception], Optional[Exception]] - ) -> MultiErrorCatcher: + ) -> ContextManager[None]: """Return a context manager that catches and re-throws exceptions after running :meth:`filter` on them. diff --git a/trio/_socket.py b/trio/_socket.py index f09ad6f16c..fab8e3f4dc 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -455,7 +455,7 @@ def is_readable(self) -> bool: async def wait_writable(self) -> None: ... - async def accept(self) -> Tuple[SocketType, Any]: + async def accept(self) -> Tuple["SocketType", Any]: ... async def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 5b66eea942..aa5a6a463f 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -80,10 +80,7 @@ def client_ctx(request: _pytest.fixtures.SubRequest) -> ssl.SSLContext: if request.param in ["default", "tls13"]: return ctx elif request.param == "tls12": - # https://github.com/python/typeshed/blob/a3f5541830205400cdf3aac04625e8a09f86cace/stdlib/ssl.pyi#L146-L148 - # but it is there in ~=3.6.3 as well... - # https://docs.python.org/3.9/library/ssl.html#ssl.OP_NO_TLSv1_3 - ctx.options |= ssl.OP_NO_TLSv1_3 # type: ignore[attr-defined] + ctx.options |= ssl.OP_NO_TLSv1_3 return ctx else: # pragma: no cover assert False From 65e85723de2528ef99ca9b90ee9be2e5bd4f9d2c Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 9 Feb 2021 21:40:04 -0500 Subject: [PATCH 35/50] generate generated files --- trio/_core/_generated_io_epoll.py | 6 +++--- trio/_core/_generated_io_kqueue.py | 14 ++++++++------ trio/_core/_generated_io_windows.py | 21 ++++++++++++--------- trio/_core/_generated_run.py | 17 +++++++++-------- 4 files changed, 32 insertions(+), 26 deletions(-) diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index 9ae54e4f68..a33f6768f4 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -8,7 +8,7 @@ # fmt: off -async def wait_readable(fd): +async def wait_readable(fd: Union[int, _HasFileno]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -16,7 +16,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: Union[int, _HasFileno]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -24,7 +24,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: Union[int, _HasFileno]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 7549899dbe..8d798e7112 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -8,7 +8,7 @@ # fmt: off -def current_kqueue(): +def current_kqueue() ->select.kqueue: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() @@ -16,7 +16,8 @@ def current_kqueue(): raise RuntimeError("must be called from async context") -def monitor_kevent(ident, filter): +def monitor_kevent(ident: int, filter: int) ->Iterator[ + '_core.UnboundedQueue[Task]']: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) @@ -24,7 +25,8 @@ def monitor_kevent(ident, filter): raise RuntimeError("must be called from async context") -async def wait_kevent(ident, filter, abort_func): +async def wait_kevent(ident: int, filter: int, abort_func: Callable[[ + Callable[[], None]], '_core.Abort']) ->object: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func) @@ -32,7 +34,7 @@ async def wait_kevent(ident, filter, abort_func): raise RuntimeError("must be called from async context") -async def wait_readable(fd): +async def wait_readable(fd: Union[int, _HasFileno]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -40,7 +42,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: Union[int, _HasFileno]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -48,7 +50,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: Union[int, _HasFileno]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index e6337e94b0..49d76baf0f 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -8,7 +8,7 @@ # fmt: off -async def wait_readable(sock): +async def wait_readable(sock: socket.socket) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) @@ -16,7 +16,7 @@ async def wait_readable(sock): raise RuntimeError("must be called from async context") -async def wait_writable(sock): +async def wait_writable(sock: socket.socket) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) @@ -24,7 +24,7 @@ async def wait_writable(sock): raise RuntimeError("must be called from async context") -def notify_closing(handle): +def notify_closing(handle: socket.socket) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) @@ -32,7 +32,7 @@ def notify_closing(handle): raise RuntimeError("must be called from async context") -def register_with_iocp(handle): +def register_with_iocp(handle: socket.socket) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) @@ -40,7 +40,8 @@ def register_with_iocp(handle): raise RuntimeError("must be called from async context") -async def wait_overlapped(handle, lpOverlapped): +async def wait_overlapped(handle: socket.socket, lpOverlapped: Union[int, + object]) ->object: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped) @@ -48,7 +49,7 @@ async def wait_overlapped(handle, lpOverlapped): raise RuntimeError("must be called from async context") -async def write_overlapped(handle, data, file_offset=0): +async def write_overlapped(handle: int, data: bytes, file_offset: int=0) ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset) @@ -56,7 +57,8 @@ async def write_overlapped(handle, data, file_offset=0): raise RuntimeError("must be called from async context") -async def readinto_overlapped(handle, buffer, file_offset=0): +async def readinto_overlapped(handle: int, buffer: memoryview, file_offset: + int=0) ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset) @@ -64,7 +66,7 @@ async def readinto_overlapped(handle, buffer, file_offset=0): raise RuntimeError("must be called from async context") -def current_iocp(): +def current_iocp() ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() @@ -72,7 +74,8 @@ def current_iocp(): raise RuntimeError("must be called from async context") -def monitor_completion_key(): +def monitor_completion_key() ->Iterator[Tuple[int, + '_core.UnboundedQueue[CompletionKeyEventInfo]']]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 1272b4c73c..283417351f 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -8,7 +8,7 @@ # fmt: off -def current_statistics(): +def current_statistics() ->_RunStatistics: """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -38,7 +38,7 @@ def current_statistics(): raise RuntimeError("must be called from async context") -def current_time(): +def current_time() ->float: """Returns the current time according to Trio's internal clock. Returns: @@ -55,7 +55,7 @@ def current_time(): raise RuntimeError("must be called from async context") -def current_clock(): +def current_clock() ->Clock: """Returns the current :class:`~trio.abc.Clock`.""" locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: @@ -64,7 +64,7 @@ def current_clock(): raise RuntimeError("must be called from async context") -def current_root_task(): +def current_root_task() ->Optional[Task]: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -77,7 +77,7 @@ def current_root_task(): raise RuntimeError("must be called from async context") -def reschedule(task, next_send=_NO_SEND): +def reschedule(task: Task, next_send: object=_NO_SEND) ->None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -102,7 +102,8 @@ def reschedule(task, next_send=_NO_SEND): raise RuntimeError("must be called from async context") -def spawn_system_task(async_fn, *args, name=None): +def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args: + object, name: Optional[str]=None) ->Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -157,7 +158,7 @@ def spawn_system_task(async_fn, *args, name=None): raise RuntimeError("must be called from async context") -def current_trio_token(): +def current_trio_token() ->TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -169,7 +170,7 @@ def current_trio_token(): raise RuntimeError("must be called from async context") -async def wait_all_tasks_blocked(cushion=0.0): +async def wait_all_tasks_blocked(cushion: float=0.0) ->None: """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a From 74543e01814de17943ab79c31edb9138d709891b Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 9 Feb 2021 22:07:32 -0500 Subject: [PATCH 36/50] update generated imports --- trio/_core/_generated_instrumentation.py | 7 ++++++- trio/_core/_generated_io_epoll.py | 7 ++++++- trio/_core/_generated_io_kqueue.py | 7 ++++++- trio/_core/_generated_io_windows.py | 7 ++++++- trio/_core/_generated_run.py | 7 ++++++- trio/_core/_io_epoll.py | 6 +----- trio/_core/_io_kqueue.py | 8 +------- trio/_subprocess.py | 6 +----- trio/_tools/gen_exports.py | 7 ++++++- trio/_typing.py | 6 ++++++ 10 files changed, 45 insertions(+), 23 deletions(-) create mode 100644 trio/_typing.py diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index 986ab2c7f5..e1ee232a6c 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -1,7 +1,12 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +from typing import Awaitable, Callable, Optional, Union + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index a33f6768f4..f4e824c3ae 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -1,7 +1,12 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +from typing import Awaitable, Callable, Optional, Union + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 8d798e7112..38aa39b0ac 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -1,7 +1,12 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +from typing import Awaitable, Callable, Optional, Union + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 49d76baf0f..2dba97cea7 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -1,7 +1,12 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +from typing import Awaitable, Callable, Optional, Union + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 283417351f..ccb13b46fb 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -1,7 +1,12 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +from typing import Awaitable, Callable, Optional, Union + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index eea857a3d4..24ecff15c7 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -7,6 +7,7 @@ from typing_extensions import Protocol from .. import _core +from .._typing import _HasFileno from ._run import _public from ._io_common import wake_all from ._wakeup_socketpair import WakeupSocketpair @@ -14,11 +15,6 @@ assert not TYPE_CHECKING or sys.platform == "linux" -class _HasFileno(Protocol): - def fileno(self) -> int: - ... - - @attr.s(slots=True, eq=False, frozen=True) class _EpollStatistics: tasks_waiting_read = attr.ib() diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index c53ce516e3..fe8b765dac 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -2,14 +2,13 @@ import sys from typing import Callable, Dict, Iterator, Optional, Tuple, TYPE_CHECKING, Union -from typing_extensions import Protocol - import outcome from contextlib import contextmanager import attr import errno from .. import _core +from .._typing import _HasFileno from ._run import _public, Task from ._unbounded_queue import UnboundedQueue from ._wakeup_socketpair import WakeupSocketpair @@ -17,11 +16,6 @@ assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") -class _HasFileno(Protocol): - def fileno(self) -> int: - ... - - @attr.s(slots=True, eq=False, frozen=True) class _KqueueStatistics: tasks_waiting = attr.ib() diff --git a/trio/_subprocess.py b/trio/_subprocess.py index e3944d9673..5756bc64b0 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -27,6 +27,7 @@ create_pipe_to_child_stdin, create_pipe_from_child_output, ) +from ._typing import _HasFileno from ._util import NoPublicConstructor import trio @@ -293,11 +294,6 @@ def kill(self): self._proc.kill() -class _HasFileno(Protocol): - def fileno(self) -> int: - ... - - _Redirect = Union[int, _HasFileno, None] # There's a lot of duplication here because mypy doesn't diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index d7e6326ce6..958520c9d7 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -18,7 +18,12 @@ HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +from typing import Awaitable, Callable, Optional, Union + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument diff --git a/trio/_typing.py b/trio/_typing.py new file mode 100644 index 0000000000..fc614f1613 --- /dev/null +++ b/trio/_typing.py @@ -0,0 +1,6 @@ +from typing_extensions import Protocol + + +class _HasFileno(Protocol): + def fileno(self) -> int: + ... From 7e8aed5ef456ea780f689576c162a56b39dff74a Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 9 Feb 2021 22:19:28 -0500 Subject: [PATCH 37/50] more imports for the generated files --- trio/_core/_generated_instrumentation.py | 4 +++- trio/_core/_generated_io_epoll.py | 4 +++- trio/_core/_generated_io_kqueue.py | 4 +++- trio/_core/_generated_io_windows.py | 4 +++- trio/_core/_generated_run.py | 4 +++- trio/_tools/gen_exports.py | 4 +++- 6 files changed, 18 insertions(+), 6 deletions(-) diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index e1ee232a6c..7c49bc10f8 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -1,7 +1,9 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from typing import Awaitable, Callable, Optional, Union +import select +import socket +from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index f4e824c3ae..a4cd0f2f05 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -1,7 +1,9 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from typing import Awaitable, Callable, Optional, Union +import select +import socket +from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 38aa39b0ac..242f9c0369 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -1,7 +1,9 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from typing import Awaitable, Callable, Optional, Union +import select +import socket +from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 2dba97cea7..3fe79460b8 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -1,7 +1,9 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from typing import Awaitable, Callable, Optional, Union +import select +import socket +from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index ccb13b46fb..36ab4a181f 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -1,7 +1,9 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from typing import Awaitable, Callable, Optional, Union +import select +import socket +from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 958520c9d7..4bf7106604 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -18,7 +18,9 @@ HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from typing import Awaitable, Callable, Optional, Union +import select +import socket +from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union from .._abc import Clock from .._typing import _HasFileno From 8a7ce7490b6297f299fd4b6f783e1464a49232f5 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Wed, 10 Feb 2021 22:14:53 -0500 Subject: [PATCH 38/50] always more --- trio/_core/_generated_instrumentation.py | 7 ++++++- trio/_core/_generated_io_epoll.py | 10 +++++++++- trio/_core/_generated_io_kqueue.py | 12 ++++++++++-- trio/_core/_generated_io_windows.py | 22 +++++++++++++++------- trio/_core/_generated_run.py | 9 +++++++-- trio/_core/_io_kqueue.py | 4 +++- trio/_core/_io_windows.py | 12 ++++++------ trio/_core/_run.py | 2 +- trio/_core/tests/test_io.py | 11 +++++++---- trio/_tools/gen_exports.py | 21 ++++++++++++++++++++- 10 files changed, 84 insertions(+), 26 deletions(-) diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index 7c49bc10f8..3b77db7c18 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -3,15 +3,20 @@ # ************************************************************* import select import socket -from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union +import sys +from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union from .._abc import Clock from .._typing import _HasFileno from .._core._entry_queue import TrioToken +from .. import _core from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index a4cd0f2f05..4be0360a5a 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -3,18 +3,26 @@ # ************************************************************* import select import socket -from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union +import sys +from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union from .._abc import Clock from .._typing import _HasFileno from .._core._entry_queue import TrioToken +from .. import _core from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off +assert not TYPE_CHECKING or sys.platform == 'linux' + + async def wait_readable(fd: Union[int, _HasFileno]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 242f9c0369..09d8b24811 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -3,18 +3,26 @@ # ************************************************************* import select import socket -from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union +import sys +from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union from .._abc import Clock from .._typing import _HasFileno from .._core._entry_queue import TrioToken +from .. import _core from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off +assert not TYPE_CHECKING or sys.platform != 'linux' and sys.platform != 'win32' + + def current_kqueue() ->select.kqueue: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: @@ -23,7 +31,7 @@ def current_kqueue() ->select.kqueue: raise RuntimeError("must be called from async context") -def monitor_kevent(ident: int, filter: int) ->Iterator[ +def monitor_kevent(ident: int, filter: int) ->ContextManager[ '_core.UnboundedQueue[Task]']: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 3fe79460b8..66a878ab8a 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -3,19 +3,27 @@ # ************************************************************* import select import socket -from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union +import sys +from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union from .._abc import Clock from .._typing import _HasFileno from .._core._entry_queue import TrioToken +from .. import _core from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off -async def wait_readable(sock: socket.socket) ->None: +assert not TYPE_CHECKING or sys.platform == 'win32' + + +async def wait_readable(sock: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) @@ -23,7 +31,7 @@ async def wait_readable(sock: socket.socket) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(sock: socket.socket) ->None: +async def wait_writable(sock: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) @@ -31,7 +39,7 @@ async def wait_writable(sock: socket.socket) ->None: raise RuntimeError("must be called from async context") -def notify_closing(handle: socket.socket) ->None: +def notify_closing(handle: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) @@ -39,7 +47,7 @@ def notify_closing(handle: socket.socket) ->None: raise RuntimeError("must be called from async context") -def register_with_iocp(handle: socket.socket) ->None: +def register_with_iocp(handle: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) @@ -64,7 +72,7 @@ async def write_overlapped(handle: int, data: bytes, file_offset: int=0) ->int: raise RuntimeError("must be called from async context") -async def readinto_overlapped(handle: int, buffer: memoryview, file_offset: +async def readinto_overlapped(handle: int, buffer: bytearray, file_offset: int=0) ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: @@ -81,7 +89,7 @@ def current_iocp() ->int: raise RuntimeError("must be called from async context") -def monitor_completion_key() ->Iterator[Tuple[int, +def monitor_completion_key() ->ContextManager[Tuple[int, '_core.UnboundedQueue[CompletionKeyEventInfo]']]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 36ab4a181f..b06dd70589 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -3,15 +3,20 @@ # ************************************************************* import select import socket -from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union +import sys +from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union from .._abc import Clock from .._typing import _HasFileno from .._core._entry_queue import TrioToken +from .. import _core from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off @@ -110,7 +115,7 @@ def reschedule(task: Task, next_send: object=_NO_SEND) ->None: def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args: - object, name: Optional[str]=None) ->Task: + object, name: Optional[Union[str, Callable]]=None) ->Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index fe8b765dac..d9c0773958 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -193,7 +193,9 @@ def notify_closing(self, fd: Union[int, _HasFileno]) -> None: if receiver is None: continue - if type(receiver) is _core.Task: + # if type(receiver) is _core.Task: + # TODO: is this unacceptably less specific? + if isinstance(receiver, _core.Task): event = select.kevent(fd, filter, select.KQ_EV_DELETE) self._kqueue.control([event], 0) exc = _core.ClosedResourceError("another task closed this fd") diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 40229a775d..c7c40f6701 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -28,7 +28,7 @@ IoControlCodes, ) -# assert not TYPE_CHECKING or sys.platform == "win32" +assert not TYPE_CHECKING or sys.platform == "win32" # There's a lot to be said about the overall design of a Windows event # loop. See @@ -697,15 +697,15 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) @_public - async def wait_readable(self, sock: socket.socket) -> None: + async def wait_readable(self, sock: int) -> None: await self._afd_poll(sock, "read_task") @_public - async def wait_writable(self, sock: socket.socket) -> None: + async def wait_writable(self, sock: int) -> None: await self._afd_poll(sock, "write_task") @_public - def notify_closing(self, handle: socket.socket) -> None: + def notify_closing(self, handle: int) -> None: handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: @@ -717,7 +717,7 @@ def notify_closing(self, handle: socket.socket) -> None: ################################################################ @_public - def register_with_iocp(self, handle: socket.socket) -> None: + def register_with_iocp(self, handle: int) -> None: self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) # TODO: what else can lpOverlapped be? @@ -835,7 +835,7 @@ def submit_write(lpOverlapped): @_public async def readinto_overlapped( - self, handle: int, buffer: memoryview, file_offset: int = 0 + self, handle: int, buffer: bytearray, file_offset: int = 0 ) -> int: with ffi.from_buffer(buffer, require_writable=True) as cbuf: diff --git a/trio/_core/_run.py b/trio/_core/_run.py index a1ff64c69f..7ddf83507a 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1600,7 +1600,7 @@ def spawn_system_task( # type: ignore[misc] self, async_fn: Callable[..., Awaitable[object]], *args: object, - name: Optional[str] = None, + name: Optional[Union[str, Callable]] = None, ) -> Task: """Spawn a "system" task. diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index 0cdd9c91a6..1418df4080 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -5,9 +5,10 @@ import random import errno from contextlib import suppress -from typing import Awaitable, Callable, Iterator, Tuple +from typing import Awaitable, Callable, Iterator, List, Tuple, Union from ... import _core +from ..._typing import _HasFileno from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints import trio @@ -56,9 +57,11 @@ def fileno_wrapper(fileobj): _WaitWritable = Callable[[stdlib_socket.socket], Awaitable[None]] _NotifyClosing = Callable[[stdlib_socket.socket], None] -wait_readable_options = [trio.lowlevel.wait_readable] -wait_writable_options = [trio.lowlevel.wait_writable] -notify_closing_options = [trio.lowlevel.notify_closing] +# OptionsList = List[Callable[[Union[int, _HasFileno]], Union[Awaitable[None], None]]] + +wait_readable_options: List = [trio.lowlevel.wait_readable] +wait_writable_options: List = [trio.lowlevel.wait_writable] +notify_closing_options: List = [trio.lowlevel.notify_closing] for options_list in [ diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 4bf7106604..25ea047d14 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -20,15 +20,20 @@ # ************************************************************* import select import socket -from typing import Awaitable, Callable, Iterator, Optional, Tuple, Union +import sys +from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union from .._abc import Clock from .._typing import _HasFileno from .._core._entry_queue import TrioToken +from .. import _core from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off """ @@ -99,12 +104,24 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: """ generated = [HEADER] + # source_string = source_path.read_text("utf-8") + # source = astor.code_to_ast.parse_string(source_string) source = astor.code_to_ast.parse_file(source_path) + + asserts = [node for node in ast.iter_child_nodes(source) if isinstance(node, ast.Assert)] + if len(asserts) > 0: + the_assert = asserts[0] + generated.append(astor.to_source(the_assert)) + for method in get_public_methods(source): # Remove self from arguments assert method.args.args[0].arg == "self" del method.args.args[0] + contextmanager_decorated = any( + decorator.id in {'contextmanager', 'contextlib.contextmanager'} + for decorator in method.decorator_list + ) # Remove decorators method.decorator_list = [] @@ -120,6 +137,8 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: # Create the function definition including the body func = astor.to_source(method, indent_with=" " * 4) + if contextmanager_decorated: + func = func.replace("->Iterator[", "->ContextManager[") # Create export function body template = TEMPLATE.format( From 80752336b87b47dd58a3f34cc9d5aa42c04e386b Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Wed, 10 Feb 2021 22:18:32 -0500 Subject: [PATCH 39/50] reformat imports for generated code --- trio/_core/_generated_instrumentation.py | 11 ++++++++++- trio/_core/_generated_io_epoll.py | 11 ++++++++++- trio/_core/_generated_io_kqueue.py | 11 ++++++++++- trio/_core/_generated_io_windows.py | 11 ++++++++++- trio/_core/_generated_run.py | 11 ++++++++++- trio/_tools/gen_exports.py | 17 ++++++++++++++--- 6 files changed, 64 insertions(+), 8 deletions(-) diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index 3b77db7c18..dc0495fae7 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -4,7 +4,16 @@ import select import socket import sys -from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index 4be0360a5a..cd76014d5c 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -4,7 +4,16 @@ import select import socket import sys -from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 09d8b24811..200017ffec 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -4,7 +4,16 @@ import select import socket import sys -from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 66a878ab8a..85ed1a3f0d 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -4,7 +4,16 @@ import select import socket import sys -from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index b06dd70589..799b44d0cd 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -4,7 +4,16 @@ import select import socket import sys -from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) from .._abc import Clock from .._typing import _HasFileno diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 25ea047d14..61e2bf0584 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -21,7 +21,16 @@ import select import socket import sys -from typing import Awaitable, Callable, ContextManager, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) from .._abc import Clock from .._typing import _HasFileno @@ -108,7 +117,9 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: # source = astor.code_to_ast.parse_string(source_string) source = astor.code_to_ast.parse_file(source_path) - asserts = [node for node in ast.iter_child_nodes(source) if isinstance(node, ast.Assert)] + asserts = [ + node for node in ast.iter_child_nodes(source) if isinstance(node, ast.Assert) + ] if len(asserts) > 0: the_assert = asserts[0] generated.append(astor.to_source(the_assert)) @@ -119,7 +130,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: del method.args.args[0] contextmanager_decorated = any( - decorator.id in {'contextmanager', 'contextlib.contextmanager'} + decorator.id in {"contextmanager", "contextlib.contextmanager"} for decorator in method.decorator_list ) # Remove decorators From 6e58b9d479f803a02b9dcb29d3dd03707a1834a6 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Thu, 11 Feb 2021 15:57:05 -0500 Subject: [PATCH 40/50] some doc stuff --- trio/_abc.py | 5 +- trio/_socket.py | 227 +++++++++++++++++++++++------------------------- 2 files changed, 112 insertions(+), 120 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index 060f66cecc..1abaac743f 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -5,9 +5,6 @@ import socket import trio -if TYPE_CHECKING: - from ._socket import SocketType - _T = TypeVar("_T") @@ -210,7 +207,7 @@ def socket( family: Optional[int] = None, type: Optional[int] = None, proto: Optional[int] = None, - ) -> "SocketType": + ) -> "trio.socket.SocketType": """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, diff --git a/trio/_socket.py b/trio/_socket.py index fab8e3f4dc..4df7903f01 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,3 +1,4 @@ +from abc import ABCMeta import os import sys import select @@ -23,7 +24,6 @@ import trio from . import _core -from ._abc import HostnameResolver, SocketFactory _T = TypeVar("_T") @@ -80,8 +80,8 @@ async def __aexit__(self, etype, value, tb): # Overrides ################################################################ -_resolver = _core.RunVar[Optional[HostnameResolver]]("hostname_resolver") -_socket_factory = _core.RunVar[Optional[SocketFactory]]("socket_factory") +_resolver = _core.RunVar[Optional["trio._abc.HostnameResolver"]]("hostname_resolver") +_socket_factory = _core.RunVar[Optional["trio._abc.SocketFactory"]]("socket_factory") def set_custom_hostname_resolver(hostname_resolver): @@ -117,8 +117,8 @@ def set_custom_hostname_resolver(hostname_resolver): def set_custom_socket_factory( - socket_factory: Optional[SocketFactory], -) -> Optional[SocketFactory]: + socket_factory: Optional["trio.abc.SocketFactory"], +) -> Optional["trio.abc.SocketFactory"]: """Set a custom socket object factory. This function allows you to replace Trio's normal socket class with a @@ -253,7 +253,7 @@ async def getprotobyname(name): ################################################################ -def from_stdlib_socket(sock: _stdlib_socket.socket) -> "_SocketType": +def from_stdlib_socket(sock: _stdlib_socket.socket) -> "SocketType": """Convert a standard library :func:`socket.socket` object into a Trio socket object. @@ -406,156 +406,151 @@ async def wrapper(self, *args, **kwargs): # type: ignore[misc] return wrapper -class SocketType: - def __init__(self): +class SocketType(metaclass=ABCMeta): + @property + def family(self) -> int: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" ) - if TYPE_CHECKING: - - @property - def family(self) -> int: - ... - - @property - def type(self) -> int: - ... + @property + def type(self) -> int: + ... - @property - def proto(self) -> int: - ... + @property + def proto(self) -> int: + ... - @property - def did_shutdown_SHUT_WR(self) -> bool: - ... + @property + def did_shutdown_SHUT_WR(self) -> bool: + ... - def __enter__(self: _T) -> _T: - ... + def __enter__(self: _T) -> _T: + ... - def __exit__(self, *args: Any) -> None: - ... + def __exit__(self, *args: Any) -> None: + ... - def dup(self) -> "SocketType": - ... + def dup(self) -> "SocketType": + ... - def close(self) -> None: - ... + def close(self) -> None: + ... - async def bind(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: - ... + async def bind(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + ... - def shutdown(self, flag: int) -> None: - ... + def shutdown(self, flag: int) -> None: + ... - def is_readable(self) -> bool: - ... + def is_readable(self) -> bool: + ... - async def wait_writable(self) -> None: - ... + async def wait_writable(self) -> None: + ... - async def accept(self) -> Tuple["SocketType", Any]: - ... + async def accept(self) -> Tuple["SocketType", Any]: + ... - async def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: - ... + async def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + ... - async def recv(self, bufsize: int, flags: int = ...) -> bytes: - ... + async def recv(self, bufsize: int, flags: int = ...) -> bytes: + ... - async def recv_into( - self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... - ) -> int: - ... + async def recv_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> int: + ... - async def recvfrom(self, bufsize: int, flags: int = ...) -> Tuple[bytes, Any]: - ... + async def recvfrom(self, bufsize: int, flags: int = ...) -> Tuple[bytes, Any]: + ... - async def recvfrom_into( - self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... - ) -> Tuple[int, Any]: - ... + async def recvfrom_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> Tuple[int, Any]: + ... - async def recvmsg( - self, bufsize: int, ancbufsize: int = ..., flags: int = ... - ) -> Tuple[bytes, List[Tuple[int, int, bytes]], int, Any]: - ... + async def recvmsg( + self, bufsize: int, ancbufsize: int = ..., flags: int = ... + ) -> Tuple[bytes, List[Tuple[int, int, bytes]], int, Any]: + ... - async def recvmsg_into( - self, - buffers: Iterable[Union[bytearray, memoryview]], - ancbufsize: int = ..., - flags: int = ..., - ) -> Tuple[int, List[Tuple[int, int, bytes]], int, Any]: - ... + async def recvmsg_into( + self, + buffers: Iterable[Union[bytearray, memoryview]], + ancbufsize: int = ..., + flags: int = ..., + ) -> Tuple[int, List[Tuple[int, int, bytes]], int, Any]: + ... - async def send(self, data: bytes, flags: int = ...) -> int: - ... + async def send(self, data: bytes, flags: int = ...) -> int: + ... - async def sendmsg( - self, - buffers: Iterable[Union[bytes, memoryview]], - ancdata: Iterable[Tuple[int, int, Union[bytes, memoryview]]] = ..., - flags: int = ..., - address: Union[Tuple[Any, ...], str] = ..., - ) -> int: - ... + async def sendmsg( + self, + buffers: Iterable[Union[bytes, memoryview]], + ancdata: Iterable[Tuple[int, int, Union[bytes, memoryview]]] = ..., + flags: int = ..., + address: Union[Tuple[Any, ...], str] = ..., + ) -> int: + ... - @overload - async def sendto( - self, data: bytes, address: Union[Tuple[Any, ...], str] - ) -> int: - ... + @overload + async def sendto( + self, data: bytes, address: Union[Tuple[Any, ...], str] + ) -> int: + ... - @overload - async def sendto( - self, data: bytes, flags: int, address: Union[Tuple[Any, ...], str] - ) -> int: - ... + @overload + async def sendto( + self, data: bytes, flags: int, address: Union[Tuple[Any, ...], str] + ) -> int: + ... - async def sendto(self, *args: object, **kwargs: object) -> int: - ... + async def sendto(self, *args: object, **kwargs: object) -> int: + ... - def detach(self) -> int: - ... + def detach(self) -> int: + ... - def get_inheritable(self) -> bool: - ... + def get_inheritable(self) -> bool: + ... - def set_inheritable(self, inheritable: bool) -> None: - ... + def set_inheritable(self, inheritable: bool) -> None: + ... - def fileno(self) -> int: - ... + def fileno(self) -> int: + ... - def getpeername(self) -> Any: - ... + def getpeername(self) -> Any: + ... - def getsockname(self) -> Any: - ... + def getsockname(self) -> Any: + ... - @overload - def getsockopt(self, level: int, optname: int) -> int: - ... + @overload + def getsockopt(self, level: int, optname: int) -> int: + ... - @overload - def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: - ... + @overload + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: + ... - def getsockopt(self, *args: object, **kwargs: object) -> object: - ... + def getsockopt(self, *args: object, **kwargs: object) -> object: + ... - def setsockopt( - self, level: int, optname: int, value: Union[int, bytes] - ) -> None: - ... + def setsockopt( + self, level: int, optname: int, value: Union[int, bytes] + ) -> None: + ... - def listen(self, backlog: int) -> None: - ... + def listen(self, backlog: int) -> None: + ... - def share(self, process_id: int) -> bytes: - ... + def share(self, process_id: int) -> bytes: + ... class _SocketType(SocketType): From 3ecd3df5ed4562b97914023832c35d75e0cebc48 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 13 Feb 2021 14:56:09 -0500 Subject: [PATCH 41/50] drop the SocketType ABCMeta change for now --- trio/_socket.py | 215 +++++++++++++++++++++++++----------------------- 1 file changed, 110 insertions(+), 105 deletions(-) diff --git a/trio/_socket.py b/trio/_socket.py index 4df7903f01..232efe8f48 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -406,151 +406,156 @@ async def wrapper(self, *args, **kwargs): # type: ignore[misc] return wrapper -class SocketType(metaclass=ABCMeta): - @property - def family(self) -> int: +class SocketType: + def __init__(self): raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" ) - @property - def type(self) -> int: - ... + if TYPE_CHECKING: - @property - def proto(self) -> int: - ... + @property + def family(self) -> int: + ... - @property - def did_shutdown_SHUT_WR(self) -> bool: - ... + @property + def type(self) -> int: + ... - def __enter__(self: _T) -> _T: - ... + @property + def proto(self) -> int: + ... - def __exit__(self, *args: Any) -> None: - ... + @property + def did_shutdown_SHUT_WR(self) -> bool: + ... - def dup(self) -> "SocketType": - ... + def __enter__(self: _T) -> _T: + ... - def close(self) -> None: - ... + def __exit__(self, *args: Any) -> None: + ... - async def bind(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: - ... + def dup(self) -> "SocketType": + ... - def shutdown(self, flag: int) -> None: - ... + def close(self) -> None: + ... - def is_readable(self) -> bool: - ... + async def bind(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + ... - async def wait_writable(self) -> None: - ... + def shutdown(self, flag: int) -> None: + ... - async def accept(self) -> Tuple["SocketType", Any]: - ... + def is_readable(self) -> bool: + ... - async def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: - ... + async def wait_writable(self) -> None: + ... - async def recv(self, bufsize: int, flags: int = ...) -> bytes: - ... + async def accept(self) -> Tuple["SocketType", Any]: + ... - async def recv_into( - self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... - ) -> int: - ... + async def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + ... - async def recvfrom(self, bufsize: int, flags: int = ...) -> Tuple[bytes, Any]: - ... + async def recv(self, bufsize: int, flags: int = ...) -> bytes: + ... - async def recvfrom_into( - self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... - ) -> Tuple[int, Any]: - ... + async def recv_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> int: + ... - async def recvmsg( - self, bufsize: int, ancbufsize: int = ..., flags: int = ... - ) -> Tuple[bytes, List[Tuple[int, int, bytes]], int, Any]: - ... + async def recvfrom(self, bufsize: int, flags: int = ...) -> Tuple[bytes, Any]: + ... - async def recvmsg_into( - self, - buffers: Iterable[Union[bytearray, memoryview]], - ancbufsize: int = ..., - flags: int = ..., - ) -> Tuple[int, List[Tuple[int, int, bytes]], int, Any]: - ... + async def recvfrom_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> Tuple[int, Any]: + ... - async def send(self, data: bytes, flags: int = ...) -> int: - ... + async def recvmsg( + self, bufsize: int, ancbufsize: int = ..., flags: int = ... + ) -> Tuple[bytes, List[Tuple[int, int, bytes]], int, Any]: + ... - async def sendmsg( - self, - buffers: Iterable[Union[bytes, memoryview]], - ancdata: Iterable[Tuple[int, int, Union[bytes, memoryview]]] = ..., - flags: int = ..., - address: Union[Tuple[Any, ...], str] = ..., - ) -> int: - ... + async def recvmsg_into( + self, + buffers: Iterable[Union[bytearray, memoryview]], + ancbufsize: int = ..., + flags: int = ..., + ) -> Tuple[int, List[Tuple[int, int, bytes]], int, Any]: + ... - @overload - async def sendto( - self, data: bytes, address: Union[Tuple[Any, ...], str] - ) -> int: - ... + async def send(self, data: bytes, flags: int = ...) -> int: + ... - @overload - async def sendto( - self, data: bytes, flags: int, address: Union[Tuple[Any, ...], str] - ) -> int: - ... + async def sendmsg( + self, + buffers: Iterable[Union[bytes, memoryview]], + ancdata: Iterable[Tuple[int, int, Union[bytes, memoryview]]] = ..., + flags: int = ..., + address: Union[Tuple[Any, ...], str] = ..., + ) -> int: + ... - async def sendto(self, *args: object, **kwargs: object) -> int: - ... + @overload + async def sendto( + self, data: bytes, address: Union[Tuple[Any, ...], str] + ) -> int: + ... - def detach(self) -> int: - ... + @overload + async def sendto( + self, data: bytes, flags: int, address: Union[Tuple[Any, ...], str] + ) -> int: + ... - def get_inheritable(self) -> bool: - ... + async def sendto(self, *args: object, **kwargs: object) -> int: + ... - def set_inheritable(self, inheritable: bool) -> None: - ... + def detach(self) -> int: + ... - def fileno(self) -> int: - ... + def get_inheritable(self) -> bool: + ... - def getpeername(self) -> Any: - ... + def set_inheritable(self, inheritable: bool) -> None: + ... - def getsockname(self) -> Any: - ... + def fileno(self) -> int: + ... - @overload - def getsockopt(self, level: int, optname: int) -> int: - ... + def getpeername(self) -> Any: + ... - @overload - def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: - ... + def getsockname(self) -> Any: + ... - def getsockopt(self, *args: object, **kwargs: object) -> object: - ... + @overload + def getsockopt(self, level: int, optname: int) -> int: + ... - def setsockopt( - self, level: int, optname: int, value: Union[int, bytes] - ) -> None: - ... + @overload + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: + ... - def listen(self, backlog: int) -> None: - ... + def getsockopt(self, *args: object, **kwargs: object) -> object: + ... - def share(self, process_id: int) -> bytes: - ... + def setsockopt( + self, level: int, optname: int, value: Union[int, bytes] + ) -> None: + ... + + def listen(self, backlog: int) -> None: + ... + + def share(self, process_id: int) -> bytes: + ... class _SocketType(SocketType): From 37b21c11124f5d9c6d287c1bbfab0136cf2464ba Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 13 Feb 2021 15:21:13 -0500 Subject: [PATCH 42/50] nitpick ignore --- docs/source/conf.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index cfce9d2256..7d09b5900e 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -58,6 +58,14 @@ ("py:class", "_T_co"), ("py:class", "_T"), ("py:class", "T_resource"), + ("py:class", "AbstractContextManager"), + ("py:class", "_socket.socket"), + ("py:class", "signal.Signals"), + ("py:class", "trio._signals.SignalReceiver"), + ("py:class", "socket.socket"), + ("py:class", "trio._core._run._RunStatistics"), + ("py:class", "socket.AddressFamily"), + ("py:class", "socket.SocketKind"), ] autodoc_inherit_docstrings = False default_role = "obj" From f3c02e4d520d9bdb19d5b6dd8f23a7b698353e41 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sat, 13 Feb 2021 22:19:25 -0500 Subject: [PATCH 43/50] a start at disallow_untyped_defs = True --- mypy.ini | 2 +- trio/_channel.py | 4 +- trio/_core/_entry_queue.py | 2 +- trio/_core/_exceptions.py | 2 +- trio/_core/_generated_run.py | 2 +- trio/_core/_instrumentation.py | 2 +- trio/_core/_io_windows.py | 2 +- trio/_core/_local.py | 6 +- trio/_core/_mock_clock.py | 4 +- trio/_core/_multierror.py | 109 +++++----- trio/_core/_run.py | 12 +- trio/_core/_thread_cache.py | 24 ++- trio/_core/_unbounded_queue.py | 2 +- trio/_core/_wakeup_socketpair.py | 2 +- trio/_core/_windows_cffi.py | 10 +- trio/_core/tests/test_asyncgen.py | 6 +- trio/_core/tests/test_guest_mode.py | 24 +-- trio/_core/tests/test_instrumentation.py | 16 +- trio/_core/tests/test_ki.py | 6 +- trio/_core/tests/test_local.py | 8 +- trio/_core/tests/test_mock_clock.py | 10 +- trio/_core/tests/test_multierror.py | 28 +-- trio/_core/tests/test_parking_lot.py | 8 +- trio/_core/tests/test_run.py | 188 +++++++++--------- trio/_core/tests/test_thread_cache.py | 8 +- trio/_core/tests/test_tutil.py | 2 +- trio/_core/tests/test_unbounded_queue.py | 10 +- trio/_core/tests/test_windows.py | 8 +- trio/_deprecate.py | 24 ++- trio/_file_io.py | 2 +- trio/_highlevel_socket.py | 4 +- trio/_path.py | 4 +- trio/_signals.py | 2 +- trio/_socket.py | 8 +- trio/_ssl.py | 2 +- trio/_subprocess.py | 2 +- trio/_sync.py | 12 +- trio/_tools/gen_exports.py | 27 ++- trio/_unix_pipes.py | 4 +- trio/_util.py | 56 +++--- trio/_wait_for_object.py | 6 +- trio/testing/_check_streams.py | 2 +- trio/testing/_checkpoints.py | 6 +- trio/testing/_memory_streams.py | 10 +- trio/testing/_trio_test.py | 2 +- trio/tests/test_abc.py | 4 +- trio/tests/test_channel.py | 26 +-- trio/tests/test_exports.py | 4 +- trio/tests/test_file_io.py | 32 +-- trio/tests/test_highlevel_generic.py | 4 +- .../test_highlevel_open_tcp_listeners.py | 16 +- trio/tests/test_highlevel_open_tcp_stream.py | 46 ++--- trio/tests/test_highlevel_open_unix_stream.py | 6 +- trio/tests/test_highlevel_serve_listeners.py | 8 +- trio/tests/test_highlevel_socket.py | 16 +- trio/tests/test_highlevel_ssl_helpers.py | 4 +- trio/tests/test_path.py | 36 ++-- trio/tests/test_scheduler_determinism.py | 4 +- trio/tests/test_signals.py | 16 +- trio/tests/test_socket.py | 50 ++--- trio/tests/test_ssl.py | 62 +++--- trio/tests/test_subprocess.py | 28 +-- trio/tests/test_sync.py | 22 +- trio/tests/test_testing.py | 36 ++-- trio/tests/test_threads.py | 38 ++-- trio/tests/test_unix_pipes.py | 18 +- trio/tests/test_util.py | 12 +- trio/tests/test_wait_for_object.py | 4 +- trio/tests/test_windows_pipes.py | 12 +- trio/tests/tools/test_gen_exports.py | 7 +- 70 files changed, 622 insertions(+), 569 deletions(-) diff --git a/mypy.ini b/mypy.ini index b253e4418e..147dc059bc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -19,7 +19,7 @@ disallow_subclassing_any = True # Enable gradually / for new modules check_untyped_defs = False disallow_untyped_calls = False -disallow_untyped_defs = False +disallow_untyped_defs = True # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. diff --git a/trio/_channel.py b/trio/_channel.py index bd7f314ce1..eb6a211dad 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -132,7 +132,7 @@ class MemorySendChannel(SendChannel[_T_contra], metaclass=NoPublicConstructor): def __attrs_post_init__(self): self._state.open_send_channels += 1 - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), id(self._state) ) @@ -270,7 +270,7 @@ def __attrs_post_init__(self): def statistics(self): return self._state.statistics() - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), id(self._state) ) diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py index a1587a18cd..bf887d8b7f 100644 --- a/trio/_core/_entry_queue.py +++ b/trio/_core/_entry_queue.py @@ -147,7 +147,7 @@ class TrioToken(metaclass=NoPublicConstructor): __slots__ = ("_reentry_queue",) - def __init__(self, reentry_queue): + def __init__(self, reentry_queue) -> None: self._reentry_queue = reentry_queue def run_sync_soon(self, sync_fn, *args, idempotent=False): diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py index 6189c484b4..3d5c1d831d 100644 --- a/trio/_core/_exceptions.py +++ b/trio/_core/_exceptions.py @@ -61,7 +61,7 @@ class Cancelled(BaseException, metaclass=NoPublicConstructor): """ - def __str__(self): + def __str__(self) -> str: return "Cancelled" diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 799b44d0cd..8789f52b91 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -225,7 +225,7 @@ async def lock_taker(lock): await lock.acquire() lock.release() - async def test_lock_fairness(): + async def test_lock_fairness() -> None: lock = trio.Lock() await lock.acquire() async with trio.open_nursery() as nursery: diff --git a/trio/_core/_instrumentation.py b/trio/_core/_instrumentation.py index e14c1ef1e0..ea990c0b95 100644 --- a/trio/_core/_instrumentation.py +++ b/trio/_core/_instrumentation.py @@ -29,7 +29,7 @@ class Instruments(Dict[str, Dict[Instrument, None]]): __slots__ = () - def __init__(self, incoming: Sequence[Instrument]): + def __init__(self, incoming: Sequence[Instrument]) -> None: self["_all"] = {} for instrument in incoming: self.add_instrument(instrument) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index c7c40f6701..f67e03b1f7 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -387,7 +387,7 @@ class CompletionKeyEventInfo: class WindowsIOManager: - def __init__(self): + def __init__(self) -> None: # If this method raises an exception, then __del__ could run on a # half-initialized object. So we initialize everything that __del__ # touches to safe values up front, before we do anything that can diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 3d63edb14b..b9a1974b8d 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -22,7 +22,7 @@ class _RunVarToken(Generic[_T]): def empty(cls, var: "RunVar[_T]") -> "_RunVarToken[_T]": return cls(var, value=cls._no_value) - def __init__(self, var: "RunVar", value: Union[_NoValue, _T]): + def __init__(self, var: "RunVar", value: Union[_NoValue, _T]) -> None: self._var = var self.previous_value = value self.redeemed = False @@ -70,7 +70,7 @@ def __init__(self, name: str) -> None: ... @overload - def __init__(self, name: str, default: _T): + def __init__(self, name: str, default: _T) -> None: ... def __init__(self, name: str, default: object = _NO_DEFAULT) -> None: @@ -140,5 +140,5 @@ def reset(self, token: _RunVarToken) -> None: token.redeemed = True - def __repr__(self): + def __repr__(self) -> str: return "".format(self._name) diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index 50a60325fc..afd8517d4f 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -62,7 +62,7 @@ class MockClock(Clock, metaclass=Final): """ - def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf) -> None: # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. @@ -77,7 +77,7 @@ def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): self.rate = rate self.autojump_threshold = autojump_threshold - def __repr__(self): + def __repr__(self) -> str: return "".format( self.current_time(), self._rate, id(self) ) diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 1d535edeab..da008f8f44 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,7 +1,18 @@ import sys import traceback import textwrap -from typing import Callable, ContextManager, Optional, overload, Set, Union +from types import TracebackType +from typing import ( + Callable, + ContextManager, + Iterator, + List, + Optional, + overload, + Set, + Type, + Union, +) import warnings import attr @@ -17,25 +28,23 @@ # MultiError ################################################################ +_Handler = Callable[[BaseException], Optional[BaseException]] + @overload -def _filter_impl( - handler: Callable[[Exception], Optional[Exception]], root_exc: Exception -) -> Optional[Exception]: +def _filter_impl(handler: _Handler, root_exc: "MultiError") -> Optional[BaseException]: ... @overload -def _filter_impl( - handler: Callable[[Exception], Optional[Exception]], root_exc: "MultiError" -) -> Optional[Union[Exception, "MultiError"]]: +def _filter_impl(handler: _Handler, root_exc: BaseException) -> Optional[BaseException]: ... def _filter_impl( - handler: Callable[[Exception], Optional[Exception]], - root_exc: Union[Exception, "MultiError"], -) -> Optional[Union[Exception, "MultiError"]]: + handler: _Handler, + root_exc: BaseException, +) -> Optional[BaseException]: # We have a tree of MultiError's, like: # # MultiError([ @@ -94,9 +103,7 @@ def _filter_impl( # Filters a subtree, ignoring tracebacks, while keeping a record of # which MultiErrors were preserved unchanged - def filter_tree( - exc: Union[Exception, "MultiError"], preserved: Set[int] - ) -> Optional[Union[Exception, "MultiError"]]: + def filter_tree(exc: BaseException, preserved: Set[int]) -> Optional[BaseException]: if isinstance(exc, MultiError): new_exceptions = [] changed = False @@ -120,7 +127,11 @@ def filter_tree( new_exc.__context__ = exc return new_exc - def push_tb_down(tb, exc, preserved): + def push_tb_down( + tb: Optional[TracebackType], + exc: BaseException, + preserved: Set[int], + ) -> None: if id(exc) in preserved: return new_tb = concat_tb(tb, exc.__traceback__) @@ -146,13 +157,18 @@ def push_tb_down(tb, exc, preserved): # result: if the exception gets modified, then the 'raise' here makes this # frame show up in the traceback; otherwise, we leave no trace.) @attr.s(frozen=True) -class MultiErrorCatcher: - _handler = attr.ib() +class MultiErrorCatcher(ContextManager[None]): + _handler: _Handler = attr.ib() - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, etype, exc, tb): + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: if exc is not None: filtered_exc = MultiError.filter(self._handler, exc) @@ -173,6 +189,7 @@ def __exit__(self, etype, exc, tb): _, value, _ = sys.exc_info() assert value is filtered_exc value.__context__ = old_context + return None class MultiError(BaseException): @@ -198,7 +215,7 @@ class MultiError(BaseException): """ - def __init__(self, exceptions): + def __init__(self, exceptions: List[BaseException]) -> None: # Avoid recursion when exceptions[0] returned by __new__() happens # to be a MultiError and subsequently __init__() is called. if hasattr(self, "exceptions"): @@ -207,7 +224,9 @@ def __init__(self, exceptions): return self.exceptions = exceptions - def __new__(cls, exceptions): + def __new__( # type: ignore[misc] + cls, exceptions: List[Union[Exception, "MultiError"]] + ) -> Union[Exception, "MultiError"]: exceptions = list(exceptions) for exc in exceptions: if not isinstance(exc, BaseException): @@ -224,20 +243,20 @@ def __new__(cls, exceptions): # In an earlier version of the code, we didn't define __init__ and # simply set the `exceptions` attribute directly on the new object. # However, linters expect attributes to be initialized in __init__. - return BaseException.__new__(cls, exceptions) + return BaseException.__new__(cls, exceptions) # type: ignore[no-any-return, call-arg] - def __str__(self): + def __str__(self) -> str: return ", ".join(repr(exc) for exc in self.exceptions) - def __repr__(self): + def __repr__(self) -> str: return "".format(self) @classmethod def filter( cls, - handler: Callable[[Exception], Optional[Exception]], - root_exc: Union[Exception, "MultiError"], - ) -> Optional[Union[Exception, "MultiError"]]: + handler: _Handler, + root_exc: BaseException, + ) -> Optional[BaseException]: """Apply the given ``handler`` to all the exceptions in ``root_exc``. Args: @@ -256,9 +275,7 @@ def filter( return _filter_impl(handler, root_exc) @classmethod - def catch( - cls, handler: Callable[[Exception], Optional[Exception]] - ) -> ContextManager[None]: + def catch(cls, handler: _Handler) -> ContextManager[None]: """Return a context manager that catches and re-throws exceptions after running :meth:`filter` on them. @@ -399,21 +416,21 @@ def concat_tb(head, tail): def traceback_exception_init( - self, - exc_type, - exc_value, - exc_traceback, + self: traceback.TracebackException, + exc_type: Type[BaseException], + exc_value: BaseException, + exc_traceback: TracebackType, *, - limit=None, - lookup_lines=True, - capture_locals=False, - _seen=None, -): + limit: Optional[int] = None, + lookup_lines: bool = True, + capture_locals: bool = False, + _seen: Optional[set] = None, +) -> None: if _seen is None: _seen = set() # Capture the original exception and its cause and context as TracebackExceptions - traceback_exception_original_init( + traceback_exception_original_init( # type: ignore[call-arg] self, exc_type, exc_value, @@ -430,7 +447,7 @@ def traceback_exception_init( for exc in exc_value.exceptions: if exc_key(exc) not in _seen: embedded.append( - traceback.TracebackException.from_exception( + traceback.TracebackException.from_exception( # type: ignore[call-arg] exc, limit=limit, lookup_lines=lookup_lines, @@ -440,19 +457,19 @@ def traceback_exception_init( _seen=set(_seen), ) ) - self.embedded = embedded + self.embedded = embedded # type: ignore[attr-defined] else: - self.embedded = [] + self.embedded = [] # type: ignore[attr-defined] traceback.TracebackException.__init__ = traceback_exception_init # type: ignore[assignment] traceback_exception_original_format = traceback.TracebackException.format -def traceback_exception_format(self, *, chain=True): +def traceback_exception_format(self: traceback.TracebackException, *, chain: bool = True) -> Iterator[str]: yield from traceback_exception_original_format(self, chain=chain) - for i, exc in enumerate(self.embedded): + for i, exc in enumerate(self.embedded): # type: ignore[attr-defined] yield "\nDetails of embedded exception {}:\n\n".format(i + 1) yield from (textwrap.indent(line, " " * 2) for line in exc.format(chain=chain)) @@ -460,7 +477,7 @@ def traceback_exception_format(self, *, chain=True): traceback.TracebackException.format = traceback_exception_format # type: ignore[assignment] -def trio_excepthook(etype, value, tb): +def trio_excepthook(etype: Type[BaseException], value: BaseException, tb: TracebackType) -> None: for chunk in traceback.format_exception(etype, value, tb): sys.stderr.write(chunk) @@ -483,7 +500,7 @@ def trio_excepthook(etype, value, tb): monkeypatched_or_warned = True else: - def trio_show_traceback(self, etype, value, tb, tb_offset=None): + def trio_show_traceback(self: object, etype: Type[BaseException], value: BaseException, tb: TracebackType, tb_offset: Optional[int] = None) -> None: # XX it would be better to integrate with IPython's fancy # exception formatting stuff (and not ignore tb_offset) trio_excepthook(etype, value, tb) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 7ddf83507a..ecb64f0ed3 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -594,7 +594,7 @@ def __exit__( assert value is remaining_error_after_cancel_scope value.__context__ = old_context - def __repr__(self): + def __repr__(self) -> str: if self._cancel_status is not None: binding = "active" elif self._has_been_entered: @@ -767,7 +767,7 @@ class _TaskStatus: _called_started = attr.ib(default=False) _value = attr.ib(default=None) - def __repr__(self): + def __repr__(self) -> str: return "".format(id(self)) def started(self, value=None): @@ -908,7 +908,7 @@ class Nursery(metaclass=NoPublicConstructor): in response to some external event. """ - def __init__(self, parent_task: "Task", cancel_scope: CancelScope): + def __init__(self, parent_task: "Task", cancel_scope: CancelScope) -> None: self._parent_task = parent_task parent_task._child_nurseries.append(self) # the cancel status that children inherit - we take a snapshot, so it @@ -1144,7 +1144,7 @@ class Task(metaclass=NoPublicConstructor): _cancel_points: int = attr.ib(default=0) _schedule_points: int = attr.ib(default=0) - def __repr__(self): + def __repr__(self) -> str: return "".format(self.name, id(self)) @property @@ -1775,7 +1775,7 @@ async def lock_taker(lock): await lock.acquire() lock.release() - async def test_lock_fairness(): + async def test_lock_fairness() -> None: lock = trio.Lock() await lock.acquire() async with trio.open_nursery() as nursery: @@ -2336,7 +2336,7 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False): class _TaskStatusIgnored: - def __repr__(self): + def __repr__(self) -> str: return "TASK_STATUS_IGNORED" def started(self, value=None): diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py index ae5e8450b9..70768878df 100644 --- a/trio/_core/_thread_cache.py +++ b/trio/_core/_thread_cache.py @@ -1,6 +1,7 @@ from threading import Thread, Lock import outcome from itertools import count +from typing import Callable, Dict, Optional, Tuple # The "thread cache" is a simple unbounded thread pool, i.e., it automatically # spawns as many threads as needed to handle all the requests its given. Its @@ -39,10 +40,14 @@ name_counter = count() +_Fn = Callable[..., object] +_Deliver = Callable[[outcome.Outcome], object] +_Job = Tuple[_Fn, _Deliver] + class WorkerThread: - def __init__(self, thread_cache): - self._job = None + def __init__(self, thread_cache: "ThreadCache") -> None: + self._job: Optional[_Job] = None self._thread_cache = thread_cache # This Lock is used in an unconventional way. # @@ -56,11 +61,14 @@ def __init__(self, thread_cache): thread.name = f"Trio worker thread {next(name_counter)}" thread.start() - def _work(self): + def _work(self) -> None: while True: if self._worker_lock.acquire(timeout=IDLE_TIMEOUT): # We got a job - fn, deliver = self._job + fn: _Fn + deliver: _Deliver + # type ignoring to avoid any runtime cost of casting etc + fn, deliver = self._job # type: ignore[misc] self._job = None result = outcome.capture(fn) # Tell the cache that we're available to be assigned a new @@ -90,10 +98,10 @@ def _work(self): class ThreadCache: - def __init__(self): - self._idle_workers = {} + def __init__(self) -> None: + self._idle_workers: Dict[WorkerThread, None] = {} - def start_thread_soon(self, fn, deliver): + def start_thread_soon(self, fn: _Fn, deliver: _Deliver) -> None: try: worker, _ = self._idle_workers.popitem() except KeyError: @@ -105,7 +113,7 @@ def start_thread_soon(self, fn, deliver): THREAD_CACHE = ThreadCache() -def start_thread_soon(fn, deliver): +def start_thread_soon(fn: _Fn, deliver: _Deliver) -> None: """Runs ``deliver(outcome.capture(fn))`` in a worker thread. Generally ``fn`` does some blocking work, and ``deliver`` delivers the diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index b57cbddf78..f6eb7a9c52 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -58,7 +58,7 @@ def __init__(self) -> None: # used to allow handoff from put to the first task in the lot self._can_get = False - def __repr__(self): + def __repr__(self) -> str: return "".format(len(self._data)) def qsize(self): diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index 121cec584e..77abc2f2ed 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -26,7 +26,7 @@ def _has_warn_on_full_buffer(): class WakeupSocketpair: - def __init__(self): + def __init__(self) -> None: self.wakeup_sock, self.write_sock = socket.socketpair() self.wakeup_sock.setblocking(False) self.write_sock.setblocking(False) diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index a1071519e9..7949932dce 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -1,6 +1,7 @@ import cffi import re import enum +from typing import Optional, Union ################################################################ # Functions and types @@ -301,7 +302,7 @@ class IoControlCodes(enum.IntEnum): ################################################################ -def _handle(obj): +def _handle(obj: Union[int, object]) -> object: # For now, represent handles as either cffi HANDLEs or as ints. If you # try to pass in a file descriptor instead, it's not going to work # out. (For that msvcrt.get_osfhandle does the trick, but I don't know if @@ -314,7 +315,12 @@ def _handle(obj): return obj -def raise_winerror(winerror=None, *, filename=None, filename2=None): +def raise_winerror( + winerror: Optional[object] = None, + *, + filename: Optional[str] = None, + filename2: Optional[str] = None, +) -> None: if winerror is None: winerror, msg = ffi.getwinerror() else: diff --git a/trio/_core/tests/test_asyncgen.py b/trio/_core/tests/test_asyncgen.py index 3635b4bdbc..fd4e26d392 100644 --- a/trio/_core/tests/test_asyncgen.py +++ b/trio/_core/tests/test_asyncgen.py @@ -10,7 +10,7 @@ import _pytest.capture -def test_asyncgen_basics(): +def test_asyncgen_basics() -> None: collected = [] async def example(cause): @@ -85,7 +85,7 @@ async def async_main(): assert agen.ag_frame is None # all should now be exhausted -async def test_asyncgen_throws_during_finalization(caplog): +async def test_asyncgen_throws_during_finalization(caplog) -> None: record = [] async def agen(): @@ -172,7 +172,7 @@ async def async_main(): assert record == ["innermost", *range(100)] -def test_last_minute_gc_edge_case(): +def test_last_minute_gc_edge_case() -> None: saved = [] record = [] needs_retry = True diff --git a/trio/_core/tests/test_guest_mode.py b/trio/_core/tests/test_guest_mode.py index ad0322c794..707688c739 100644 --- a/trio/_core/tests/test_guest_mode.py +++ b/trio/_core/tests/test_guest_mode.py @@ -71,7 +71,7 @@ def done_callback(outcome): del todo, run_sync_soon_threadsafe, done_callback -def test_guest_trivial(): +def test_guest_trivial() -> None: async def trio_return(in_host): await trio.sleep(0) return "ok" @@ -85,7 +85,7 @@ async def trio_fail(in_host): trivial_guest_run(trio_fail) -def test_guest_can_do_io(): +def test_guest_can_do_io() -> None: async def trio_main(in_host): record = [] a, b = trio.socket.socketpair() @@ -105,7 +105,7 @@ async def do_receive(): trivial_guest_run(trio_main) -def test_host_can_directly_wake_trio_task(): +def test_host_can_directly_wake_trio_task() -> None: async def trio_main(in_host): ev = trio.Event() in_host(ev.set) @@ -115,7 +115,7 @@ async def trio_main(in_host): assert trivial_guest_run(trio_main) == "ok" -def test_host_altering_deadlines_wakes_trio_up(): +def test_host_altering_deadlines_wakes_trio_up() -> None: def set_deadline(cscope, new_deadline): cscope.deadline = new_deadline @@ -138,7 +138,7 @@ async def trio_main(in_host): assert trivial_guest_run(trio_main) == "ok" -def test_warn_set_wakeup_fd_overwrite(): +def test_warn_set_wakeup_fd_overwrite() -> None: assert signal.set_wakeup_fd(-1) == -1 async def trio_main(in_host): @@ -206,7 +206,7 @@ async def trio_check_wakeup_fd_unaltered(in_host): assert signal.set_wakeup_fd(-1) == a.fileno() -def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked(): +def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None: # This is designed to hit the branch in unrolled_run where: # idle_primed=True # runner.runq is empty @@ -245,7 +245,7 @@ async def get_woken_by_host_deadline(watb_cscope): # actually end. So in after_io_wait we schedule a second host # call to tear things down. class InstrumentHelper: - def __init__(self): + def __init__(self) -> None: self.primed = False def before_io_wait(self, timeout): @@ -277,7 +277,7 @@ def after_io_wait(self, timeout): assert trivial_guest_run(trio_main) == "ok" -def test_guest_warns_if_abandoned(): +def test_guest_warns_if_abandoned() -> None: # This warning is emitted from the garbage collector. So we have to make # sure that our abandoned run is garbage. The easiest way to do this is to # put it into a function, so that we're sure all the local state, @@ -346,7 +346,7 @@ def trio_done_callback(main_outcome): loop.close() -def test_guest_mode_on_asyncio(): +def test_guest_mode_on_asyncio() -> None: async def trio_main(): print("trio_main!") @@ -405,7 +405,7 @@ async def aio_pingpong(from_trio, to_trio): ) -def test_guest_mode_internal_errors(monkeypatch, recwarn): +def test_guest_mode_internal_errors(monkeypatch, recwarn) -> None: with monkeypatch.context() as m: async def crash_in_run_loop(in_host): @@ -446,7 +446,7 @@ def bad_get_events(*args): gc_collect_harder() -def test_guest_mode_ki(): +def test_guest_mode_ki() -> None: assert signal.getsignal(signal.SIGINT) is signal.default_int_handler # Check SIGINT in Trio func and in host func @@ -478,7 +478,7 @@ async def trio_main_raising(in_host): assert signal.getsignal(signal.SIGINT) is signal.default_int_handler -def test_guest_mode_autojump_clock_threshold_changing(): +def test_guest_mode_autojump_clock_threshold_changing() -> None: # This is super obscure and probably no-one will ever notice, but # technically mutating the MockClock.autojump_threshold from the host # should wake up the guest, so let's test it. diff --git a/trio/_core/tests/test_instrumentation.py b/trio/_core/tests/test_instrumentation.py index 837a3e53b7..2666e6bff3 100644 --- a/trio/_core/tests/test_instrumentation.py +++ b/trio/_core/tests/test_instrumentation.py @@ -33,7 +33,7 @@ def filter_tasks(self, tasks): yield item -def test_instruments(recwarn): +def test_instruments(recwarn) -> None: r1 = TaskRecorder() r2 = TaskRecorder() r3 = TaskRecorder() @@ -77,7 +77,7 @@ async def main(): assert list(r1.filter_tasks([task])) == expected -def test_instruments_interleave(): +def test_instruments_interleave() -> None: tasks = {} async def two_step1(): @@ -120,7 +120,7 @@ async def main(): check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) -def test_null_instrument(): +def test_null_instrument() -> None: # undefined instrument methods are skipped class NullInstrument: def something_unrelated(self): @@ -132,7 +132,7 @@ async def main(): _core.run(main, instruments=[NullInstrument()]) -def test_instrument_before_after_run(): +def test_instrument_before_after_run() -> None: record = [] class BeforeAfterRun: @@ -149,7 +149,7 @@ async def main(): assert record == ["before_run", "after_run"] -def test_instrument_task_spawn_exit(): +def test_instrument_task_spawn_exit() -> None: record = [] class SpawnExitRecorder: @@ -169,7 +169,7 @@ async def main(): # This test also tests having a crash before the initial task is even spawned, # which is very difficult to handle. -def test_instruments_crash(caplog): +def test_instruments_crash(caplog) -> None: record = [] class BrokenInstrument: @@ -200,7 +200,7 @@ async def main(): assert "Instrument has been disabled" in caplog.records[0].message -def test_instruments_monkeypatch(): +def test_instruments_monkeypatch() -> None: class NullInstrument(_abc.Instrument): pass @@ -232,7 +232,7 @@ async def main(): _core.run(main, instruments=[instrument]) -def test_instrument_that_raises_on_getattr(): +def test_instrument_that_raises_on_getattr() -> None: class EvilInstrument: def task_exited(self, task): assert False # pragma: no cover diff --git a/trio/_core/tests/test_ki.py b/trio/_core/tests/test_ki.py index b803c16536..f6a5d88f3a 100644 --- a/trio/_core/tests/test_ki.py +++ b/trio/_core/tests/test_ki.py @@ -26,7 +26,7 @@ def ki_self(): signal_raise(signal.SIGINT) -def test_ki_self(): +def test_ki_self() -> None: with pytest.raises(KeyboardInterrupt): ki_self() @@ -223,7 +223,7 @@ async def agen_unprotected3() -> AsyncIterator[None]: # Test the case where there's no magic local anywhere in the call stack -def test_ki_disabled_out_of_context(): +def test_ki_disabled_out_of_context() -> None: assert _core.currently_ki_protected() @@ -462,7 +462,7 @@ async def main_g() -> None: _core.run(main_g) -def test_ki_is_good_neighbor(): +def test_ki_is_good_neighbor() -> None: # in the unlikely event someone overwrites our signal handler, we leave # the overwritten one be try: diff --git a/trio/_core/tests/test_local.py b/trio/_core/tests/test_local.py index 7f403168ea..65a399e298 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/tests/test_local.py @@ -4,7 +4,7 @@ # scary runvar tests -def test_runvar_smoketest(): +def test_runvar_smoketest() -> None: t1 = _core.RunVar("test1") t2 = _core.RunVar("test2", default="catfish") @@ -33,7 +33,7 @@ async def second_check(): _core.run(second_check) -def test_runvar_resetting(): +def test_runvar_resetting() -> None: t1 = _core.RunVar("test1") t2 = _core.RunVar("test2", default="dogfish") t3 = _core.RunVar("test3") @@ -66,7 +66,7 @@ async def reset_check(): _core.run(reset_check) -def test_runvar_sync(): +def test_runvar_sync() -> None: t1 = _core.RunVar("test1") async def sync_check(): @@ -97,7 +97,7 @@ async def task2(tok): _core.run(sync_check) -def test_accessing_runvar_outside_run_call_fails(): +def test_accessing_runvar_outside_run_call_fails() -> None: t1 = _core.RunVar("test1") with pytest.raises(RuntimeError): diff --git a/trio/_core/tests/test_mock_clock.py b/trio/_core/tests/test_mock_clock.py index 35944760d4..7242497b58 100644 --- a/trio/_core/tests/test_mock_clock.py +++ b/trio/_core/tests/test_mock_clock.py @@ -10,7 +10,7 @@ from .tutil import slow -def test_mock_clock(): +def test_mock_clock() -> None: REAL_NOW = 123.0 c = MockClock() c._real_clock = lambda: REAL_NOW @@ -54,7 +54,7 @@ def test_mock_clock(): assert c2.current_time() < 10 -async def test_mock_clock_autojump(mock_clock): +async def test_mock_clock_autojump(mock_clock) -> None: assert mock_clock.autojump_threshold == inf mock_clock.autojump_threshold = 0 @@ -94,7 +94,7 @@ async def test_mock_clock_autojump(mock_clock): await sleep(100000) -async def test_mock_clock_autojump_interference(mock_clock): +async def test_mock_clock_autojump_interference(mock_clock) -> None: mock_clock.autojump_threshold = 0.02 mock_clock2 = MockClock() @@ -111,7 +111,7 @@ async def test_mock_clock_autojump_interference(mock_clock): await sleep(100000) -def test_mock_clock_autojump_preset(): +def test_mock_clock_autojump_preset() -> None: # Check that we can set the autojump_threshold before the clock is # actually in use, and it gets picked up mock_clock = MockClock(autojump_threshold=0.1) @@ -121,7 +121,7 @@ def test_mock_clock_autojump_preset(): assert time.perf_counter() - real_start < 1 -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock): +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with the default cushion=0. diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/tests/test_multierror.py index c733c0ddd1..b2f1e26a0c 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/tests/test_multierror.py @@ -22,7 +22,7 @@ class NotHashableException(Exception): code = None - def __init__(self, code): + def __init__(self, code) -> None: super().__init__() self.code = code @@ -75,7 +75,7 @@ def einfo(exc): return (type(exc), exc, exc.__traceback__) -def test_concat_tb(): +def test_concat_tb() -> None: tb1 = get_tb(raiser1) tb2 = get_tb(raiser2) @@ -101,7 +101,7 @@ def test_concat_tb(): assert extract_tb(get_tb(raiser2)) == entries2 -def test_MultiError(): +def test_MultiError() -> None: exc1 = get_exc(raiser1) exc2 = get_exc(raiser2) @@ -117,7 +117,7 @@ def test_MultiError(): MultiError([KeyError(), ValueError]) -def test_MultiErrorOfSingleMultiError(): +def test_MultiErrorOfSingleMultiError() -> None: # For MultiError([MultiError]), ensure there is no bad recursion by the # constructor where __init__ is called if __new__ returns a bare MultiError. exceptions = [KeyError(), ValueError()] @@ -127,7 +127,7 @@ def test_MultiErrorOfSingleMultiError(): assert b.exceptions == exceptions -async def test_MultiErrorNotHashable(): +async def test_MultiErrorNotHashable() -> None: exc1 = NotHashableException(42) exc2 = NotHashableException(4242) exc3 = ValueError() @@ -140,7 +140,7 @@ async def test_MultiErrorNotHashable(): nursery.start_soon(raise_nothashable, 4242) -def test_MultiError_filter_NotHashable(): +def test_MultiError_filter_NotHashable() -> None: excs = MultiError([NotHashableException(42), ValueError()]) def handle_ValueError(exc): @@ -153,7 +153,7 @@ def handle_ValueError(exc): assert isinstance(filtered_excs, NotHashableException) -def test_traceback_recursion(): +def test_traceback_recursion() -> None: exc1 = RuntimeError() exc2 = KeyError() exc3 = NotHashableException(42) @@ -205,7 +205,7 @@ def assert_tree_eq(m1, m2): assert_tree_eq(e1, e2) -def test_MultiError_filter(): +def test_MultiError_filter() -> None: def null_handler(exc): return exc @@ -283,7 +283,7 @@ def filter_all(exc): assert MultiError.filter(filter_all, make_tree()) is None -def test_MultiError_catch(): +def test_MultiError_catch() -> None: # No exception to catch def noop(_): @@ -380,14 +380,14 @@ def assert_match_in_seq(pattern_list, string): offset = match.end() -def test_assert_match_in_seq(): +def test_assert_match_in_seq() -> None: assert_match_in_seq(["a", "b"], "xx a xx b xx") assert_match_in_seq(["b", "a"], "xx b xx a xx") with pytest.raises(AssertionError): assert_match_in_seq(["a", "b"], "xx b xx a xx") -def test_format_exception(): +def test_format_exception() -> None: exc = get_exc(raiser1) formatted = "".join(format_exception(*einfo(exc))) assert "raiser1_string" in formatted @@ -565,7 +565,7 @@ def raise2_raiser1(): ) -def test_logging(caplog): +def test_logging(caplog) -> None: exc1 = get_exc(raiser1) exc2 = get_exc(raiser2) @@ -638,12 +638,12 @@ def check_simple_excepthook(completed): ) -def test_simple_excepthook(): +def test_simple_excepthook() -> None: completed = run_script("simple_excepthook.py") check_simple_excepthook(completed) -def test_custom_excepthook(): +def test_custom_excepthook() -> None: # Check that user-defined excepthooks aren't overridden completed = run_script("custom_excepthook.py") assert_match_in_seq( diff --git a/trio/_core/tests/test_parking_lot.py b/trio/_core/tests/test_parking_lot.py index 13ffe0c066..9dcdfb490a 100644 --- a/trio/_core/tests/test_parking_lot.py +++ b/trio/_core/tests/test_parking_lot.py @@ -6,7 +6,7 @@ from .tutil import check_sequence_matches -async def test_parking_lot_basic(): +async def test_parking_lot_basic() -> None: record = [] async def waiter(i, lot): @@ -85,7 +85,7 @@ async def cancellable_waiter(name, lot, scopes, record): record.append("wake {}".format(name)) -async def test_parking_lot_cancel(): +async def test_parking_lot_cancel() -> None: record = [] scopes = {} @@ -111,7 +111,7 @@ async def test_parking_lot_cancel(): ) -async def test_parking_lot_repark(): +async def test_parking_lot_repark() -> None: record = [] scopes = {} lot1 = ParkingLot() @@ -165,7 +165,7 @@ async def test_parking_lot_repark(): ] -async def test_parking_lot_repark_with_count(): +async def test_parking_lot_repark_with_count() -> None: record = [] scopes = {} lot1 = ParkingLot() diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 6def9c8e97..c0fcc0dfb1 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -47,7 +47,7 @@ async def sleep_forever(): return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -def test_basic(): +def test_basic() -> None: async def trivial(x): return x @@ -68,7 +68,7 @@ async def trivial2(x): assert _core.run(trivial2, 1) == 1 -def test_initial_task_error(): +def test_initial_task_error() -> None: async def main(x): raise ValueError(x) @@ -77,7 +77,7 @@ async def main(x): assert excinfo.value.args == (17,) -def test_run_nesting(): +def test_run_nesting() -> None: async def inception(): async def main(): # pragma: no cover pass @@ -89,7 +89,7 @@ async def main(): # pragma: no cover assert "from inside" in str(excinfo.value) -async def test_nursery_warn_use_async_with(): +async def test_nursery_warn_use_async_with() -> None: with pytest.raises(RuntimeError) as excinfo: on = _core.open_nursery() with on: @@ -103,7 +103,7 @@ async def test_nursery_warn_use_async_with(): pass -async def test_nursery_main_block_error_basic(): +async def test_nursery_main_block_error_basic() -> None: exc = ValueError("whoops") with pytest.raises(ValueError) as excinfo: @@ -112,7 +112,7 @@ async def test_nursery_main_block_error_basic(): assert excinfo.value is exc -async def test_child_crash_basic(): +async def test_child_crash_basic() -> None: exc = ValueError("uh oh") async def erroring(): @@ -126,7 +126,7 @@ async def erroring(): assert e is exc -async def test_basic_interleave(): +async def test_basic_interleave() -> None: async def looper(whoami, record): for i in range(3): record.append((whoami, i)) @@ -142,7 +142,7 @@ async def looper(whoami, record): ) -def test_task_crash_propagation(): +def test_task_crash_propagation() -> None: looper_record = [] async def looper(): @@ -168,7 +168,7 @@ async def main(): assert excinfo.value.args == ("argh",) -def test_main_and_task_both_crash(): +def test_main_and_task_both_crash() -> None: # If main crashes and there's also a task crash, then we get both in a # MultiError async def crasher(): @@ -188,7 +188,7 @@ async def main(): } -def test_two_child_crashes(): +def test_two_child_crashes() -> None: async def crasher(etype): raise etype @@ -205,7 +205,7 @@ async def main(): } -async def test_child_crash_wakes_parent(): +async def test_child_crash_wakes_parent() -> None: async def crasher(): raise ValueError @@ -215,7 +215,7 @@ async def crasher(): await sleep_forever() -async def test_reschedule(): +async def test_reschedule() -> None: t1 = None t2 = None @@ -247,7 +247,7 @@ async def child2(): nursery.start_soon(child2) -async def test_current_time(): +async def test_current_time() -> None: t1 = _core.current_time() # Windows clock is pretty low-resolution -- appveyor tests fail unless we # sleep for a bit here. @@ -256,7 +256,7 @@ async def test_current_time(): assert t1 < t2 -async def test_current_time_with_mock_clock(mock_clock): +async def test_current_time_with_mock_clock(mock_clock) -> None: start = mock_clock.current_time() assert mock_clock.current_time() == _core.current_time() assert mock_clock.current_time() == _core.current_time() @@ -264,11 +264,11 @@ async def test_current_time_with_mock_clock(mock_clock): assert start + 3.14 == mock_clock.current_time() == _core.current_time() -async def test_current_clock(mock_clock): +async def test_current_clock(mock_clock) -> None: assert mock_clock is _core.current_clock() -async def test_current_task(): +async def test_current_task() -> None: parent_task = _core.current_task() async def child(): @@ -278,19 +278,19 @@ async def child(): nursery.start_soon(child) -async def test_root_task(): +async def test_root_task() -> None: root = _core.current_root_task() assert root.parent_nursery is root.eventual_parent_nursery is None -def test_out_of_context(): +def test_out_of_context() -> None: with pytest.raises(RuntimeError): _core.current_task() with pytest.raises(RuntimeError): _core.current_time() -async def test_current_statistics(mock_clock): +async def test_current_statistics(mock_clock) -> None: # Make sure all the early startup stuff has settled down await wait_all_tasks_blocked() @@ -342,7 +342,7 @@ async def child(): assert stats.seconds_to_next_deadline == inf -async def test_cancel_scope_repr(mock_clock): +async def test_cancel_scope_repr(mock_clock) -> None: scope = _core.CancelScope() assert "unbound" in repr(scope) with scope: @@ -358,7 +358,7 @@ async def test_cancel_scope_repr(mock_clock): assert "exited" in repr(scope) -def test_cancel_points(): +def test_cancel_points() -> None: async def main1(): with _core.CancelScope() as scope: await _core.checkpoint_if_cancelled() @@ -396,7 +396,7 @@ async def main4(): _core.run(main4) -async def test_cancel_edge_cases(): +async def test_cancel_edge_cases() -> None: with _core.CancelScope() as scope: # Two cancels in a row -- idempotent scope.cancel() @@ -414,7 +414,7 @@ async def test_cancel_edge_cases(): await sleep_forever() -async def test_cancel_scope_multierror_filtering(): +async def test_cancel_scope_multierror_filtering() -> None: async def crasher(): raise KeyError @@ -457,7 +457,7 @@ async def crasher(): assert False -async def test_precancelled_task(): +async def test_precancelled_task() -> None: # a task that gets spawned into an already-cancelled nursery should begin # execution (https://github.com/python-trio/trio/issues/41), but get a # cancelled error at its first blocking call. @@ -473,7 +473,7 @@ async def blocker(): assert record == ["started"] -async def test_cancel_shielding(): +async def test_cancel_shielding() -> None: with _core.CancelScope() as outer: with _core.CancelScope() as inner: await _core.checkpoint() @@ -514,7 +514,7 @@ async def test_cancel_shielding(): # make sure that cancellation propagates immediately to all children -async def test_cancel_inheritance(): +async def test_cancel_inheritance() -> None: record = set() async def leaf(ident): @@ -536,7 +536,7 @@ async def worker(ident): assert record == {"w1-l1", "w1-l2", "w2-l1", "w2-l2"} -async def test_cancel_shield_abort(): +async def test_cancel_shield_abort() -> None: with _core.CancelScope() as outer: async with _core.open_nursery() as nursery: outer.cancel() @@ -567,7 +567,7 @@ async def sleeper(): assert record == ["sleeping", "cancelled"] -async def test_basic_timeout(mock_clock): +async def test_basic_timeout(mock_clock) -> None: start = _core.current_time() with _core.CancelScope() as scope: assert scope.deadline == inf @@ -604,7 +604,7 @@ async def test_basic_timeout(mock_clock): await _core.checkpoint() -async def test_cancel_scope_nesting(): +async def test_cancel_scope_nesting() -> None: # Nested scopes: if two triggering at once, the outer one wins with _core.CancelScope() as scope1: with _core.CancelScope() as scope2: @@ -643,7 +643,7 @@ async def test_cancel_scope_nesting(): # Regression test for https://github.com/python-trio/trio/issues/1175 -async def test_unshield_while_cancel_propagating(): +async def test_unshield_while_cancel_propagating() -> None: with _core.CancelScope() as outer: with _core.CancelScope() as inner: outer.cancel() @@ -654,7 +654,7 @@ async def test_unshield_while_cancel_propagating(): assert outer.cancelled_caught and not inner.cancelled_caught -async def test_cancel_unbound(): +async def test_cancel_unbound() -> None: async def sleep_until_cancelled(scope): with scope, fail_after(1): await sleep_forever() @@ -728,7 +728,7 @@ async def enter_scope(): assert scope.cancel_called # never become un-cancelled -async def test_cancel_scope_misnesting(): +async def test_cancel_scope_misnesting() -> None: outer = _core.CancelScope() inner = _core.CancelScope() with ExitStack() as stack: @@ -817,7 +817,7 @@ async def test_timekeeping() -> None: assert False -async def test_failed_abort(): +async def test_failed_abort() -> None: stubborn_task = [None] stubborn_scope = [None] record = [] @@ -848,7 +848,7 @@ async def stubborn_sleeper(): assert record == ["sleep", "woke", "cancelled"] -def test_broken_abort(): +def test_broken_abort() -> None: async def main(): # These yields are here to work around an annoying warning -- we're # going to crash the main loop, and if we (by chance) do this before @@ -874,7 +874,7 @@ async def main(): gc_collect_harder() -def test_error_in_run_loop(): +def test_error_in_run_loop() -> None: # Blow stuff up real good to check we at least get a TrioInternalError async def main(): task = _core.current_task() @@ -886,7 +886,7 @@ async def main(): _core.run(main) -async def test_spawn_system_task(): +async def test_spawn_system_task() -> None: record = [] async def system_task(x): @@ -900,7 +900,7 @@ async def system_task(x): # intentionally make a system task crash -def test_system_task_crash(): +def test_system_task_crash() -> None: async def crasher(): raise KeyError @@ -912,7 +912,7 @@ async def main(): _core.run(main) -def test_system_task_crash_MultiError(): +def test_system_task_crash_MultiError() -> None: async def crasher1(): raise KeyError @@ -938,7 +938,7 @@ async def main(): assert isinstance(exc, (KeyError, ValueError)) -def test_system_task_crash_plus_Cancelled(): +def test_system_task_crash_plus_Cancelled() -> None: # Set up a situation where a system task crashes with a # MultiError([Cancelled, ValueError]) async def crasher(): @@ -964,7 +964,7 @@ async def main(): assert type(excinfo.value.__cause__) is ValueError -def test_system_task_crash_KeyboardInterrupt(): +def test_system_task_crash_KeyboardInterrupt() -> None: async def ki(): raise KeyboardInterrupt @@ -986,7 +986,7 @@ async def main(): # 4) this task has timed out # 5) ...but it's on the run queue, so the timeout is queued to be delivered # the next time that it's blocked. -async def test_yield_briefly_checks_for_timeout(mock_clock): +async def test_yield_briefly_checks_for_timeout(mock_clock) -> None: with _core.CancelScope(deadline=_core.current_time() + 5): await _core.checkpoint() with pytest.raises(_core.Cancelled): @@ -1000,7 +1000,7 @@ async def test_yield_briefly_checks_for_timeout(mock_clock): # still nice to know that it works :-). # # Update: it turns out I was right to be nervous! see the next test... -async def test_exc_info(): +async def test_exc_info() -> None: record = [] seq = Sequencer() @@ -1061,7 +1061,7 @@ async def child2(): # like re-raising and exception chaining are broken. # # https://bugs.python.org/issue29587 -async def test_exc_info_after_yield_error(): +async def test_exc_info_after_yield_error() -> None: child_task = None async def child(): @@ -1086,7 +1086,7 @@ async def child(): # Similar to previous test -- if the ValueError() gets sent in via 'throw', # then Python's normal implicit chaining stuff is broken. -async def test_exception_chaining_after_yield_error(): +async def test_exception_chaining_after_yield_error() -> None: child_task = None async def child(): @@ -1107,7 +1107,7 @@ async def child(): assert isinstance(excinfo.value.__context__, KeyError) -async def test_nursery_exception_chaining_doesnt_make_context_loops(): +async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None: async def crasher(): raise KeyError @@ -1119,7 +1119,7 @@ async def crasher(): assert excinfo.value.__context__ is None -def test_TrioToken_identity(): +def test_TrioToken_identity() -> None: async def get_and_check_token(): token = _core.current_trio_token() # Two calls in the same run give the same object @@ -1133,7 +1133,7 @@ async def get_and_check_token(): assert hash(t1) != hash(t2) -async def test_TrioToken_run_sync_soon_basic(): +async def test_TrioToken_run_sync_soon_basic() -> None: record = [] def cb(x): @@ -1146,7 +1146,7 @@ def cb(x): assert record == [("cb", 1)] -def test_TrioToken_run_sync_soon_too_late(): +def test_TrioToken_run_sync_soon_too_late() -> None: token = None async def main(): @@ -1159,7 +1159,7 @@ async def main(): token.run_sync_soon(lambda: None) # pragma: no branch -async def test_TrioToken_run_sync_soon_idempotent(): +async def test_TrioToken_run_sync_soon_idempotent() -> None: record = [] def cb(x): @@ -1186,7 +1186,7 @@ def cb(x): assert record == list(range(100)) -def test_TrioToken_run_sync_soon_idempotent_requeue(): +def test_TrioToken_run_sync_soon_idempotent_requeue() -> None: # We guarantee that if a call has finished, queueing it again will call it # again. Due to the lack of synchronization, this effectively means that # we have to guarantee that once a call has *started*, queueing it again @@ -1212,7 +1212,7 @@ async def main(): assert len(record) >= 2 -def test_TrioToken_run_sync_soon_after_main_crash(): +def test_TrioToken_run_sync_soon_after_main_crash() -> None: record = [] async def main(): @@ -1228,7 +1228,7 @@ async def main(): assert record == ["sync-cb"] -def test_TrioToken_run_sync_soon_crashes(): +def test_TrioToken_run_sync_soon_crashes() -> None: record = set() async def main(): @@ -1249,7 +1249,7 @@ async def main(): assert record == {"2nd run_sync_soon ran", "cancelled!"} -async def test_TrioToken_run_sync_soon_FIFO(): +async def test_TrioToken_run_sync_soon_FIFO() -> None: N = 100 record = [] token = _core.current_trio_token() @@ -1259,7 +1259,7 @@ async def test_TrioToken_run_sync_soon_FIFO(): assert record == list(range(N)) -def test_TrioToken_run_sync_soon_starvation_resistance(): +def test_TrioToken_run_sync_soon_starvation_resistance() -> None: # Even if we push callbacks in from callbacks, so that the callback queue # never empties out, then we still can't starve out other tasks from # running. @@ -1288,7 +1288,7 @@ async def main(): assert record[1][1] >= 19 -def test_TrioToken_run_sync_soon_threaded_stress_test(): +def test_TrioToken_run_sync_soon_threaded_stress_test() -> None: cb_counter = 0 def cb(): @@ -1316,7 +1316,7 @@ async def main(): print(cb_counter) -async def test_TrioToken_run_sync_soon_massive_queue(): +async def test_TrioToken_run_sync_soon_massive_queue() -> None: # There are edge cases in the wakeup fd code when the wakeup fd overflows, # so let's try to make that happen. This is also just a good stress test # in general. (With the current-as-of-2017-02-14 code using a socketpair @@ -1365,7 +1365,7 @@ async def main(): assert record == ["main exiting", "2nd ran"] -async def test_slow_abort_basic(): +async def test_slow_abort_basic() -> None: with _core.CancelScope() as scope: scope.cancel() with pytest.raises(_core.Cancelled): @@ -1380,7 +1380,7 @@ def slow_abort(raise_cancel): await _core.wait_task_rescheduled(slow_abort) -async def test_slow_abort_edge_cases(): +async def test_slow_abort_edge_cases() -> None: record = [] async def slow_aborter(): @@ -1423,7 +1423,7 @@ def slow_abort(raise_cancel): assert record == ["sleeping", "abort-called", "cancelled", "done"] -async def test_task_tree_introspection(): +async def test_task_tree_introspection() -> None: tasks = {} nurseries = {} @@ -1489,7 +1489,7 @@ async def child1(task_status=_core.TASK_STATUS_IGNORED): assert task.eventual_parent_nursery is None -async def test_nursery_closure(): +async def test_nursery_closure() -> None: async def child1(nursery): # We can add new tasks to the nursery even after entering __aexit__, # so long as there are still tasks running @@ -1506,7 +1506,7 @@ async def child2(): nursery.start_soon(child2) -async def test_spawn_name(): +async def test_spawn_name() -> None: async def func1(expected): task = _core.current_task() assert expected in task.name @@ -1523,7 +1523,7 @@ async def func2(): # pragma: no cover spawn_fn(func1, "object", name=object()) -async def test_current_effective_deadline(mock_clock): +async def test_current_effective_deadline(mock_clock) -> None: assert _core.current_effective_deadline() == inf with _core.CancelScope(deadline=5) as scope1: @@ -1545,7 +1545,7 @@ async def test_current_effective_deadline(mock_clock): assert _core.current_effective_deadline() == inf -def test_nice_error_on_bad_calls_to_run_or_spawn(): +def test_nice_error_on_bad_calls_to_run_or_spawn() -> None: def bad_call_run(*args): _core.run(*args) @@ -1573,7 +1573,7 @@ async def async_gen(arg): # pragma: no cover bad_call(async_gen, 0) -def test_calling_asyncio_function_gives_nice_error(): +def test_calling_asyncio_function_gives_nice_error() -> None: async def child_xyzzy(): import asyncio @@ -1592,7 +1592,7 @@ async def misguided(): ) -async def test_asyncio_function_inside_nursery_does_not_explode(): +async def test_asyncio_function_inside_nursery_does_not_explode() -> None: # Regression test for https://github.com/python-trio/trio/issues/552 with pytest.raises(TypeError) as excinfo: async with _core.open_nursery() as nursery: @@ -1603,7 +1603,7 @@ async def test_asyncio_function_inside_nursery_does_not_explode(): assert "asyncio" in str(excinfo.value) -async def test_trivial_yields(): +async def test_trivial_yields() -> None: with assert_checkpoints(): await _core.checkpoint() @@ -1627,7 +1627,7 @@ async def test_trivial_yields(): } -async def test_nursery_start(autojump_clock): +async def test_nursery_start(autojump_clock) -> None: async def no_args(): # pragma: no cover pass @@ -1723,7 +1723,7 @@ async def raise_keyerror_after_started(task_status=_core.TASK_STATUS_IGNORED): assert _core.current_time() == t0 -async def test_task_nursery_stack(): +async def test_task_nursery_stack() -> None: task = _core.current_task() assert task._child_nurseries == [] async with _core.open_nursery() as nursery1: @@ -1736,7 +1736,7 @@ async def test_task_nursery_stack(): assert task._child_nurseries == [] -async def test_nursery_start_with_cancelled_nursery(): +async def test_nursery_start_with_cancelled_nursery() -> None: # This function isn't testing task_status, it's using task_status as a # convenient way to get a nursery that we can test spawning stuff into. async def setup_nursery(task_status=_core.TASK_STATUS_IGNORED): @@ -1769,7 +1769,7 @@ async def sleeping_children(fn, *, task_status=_core.TASK_STATUS_IGNORED): target_nursery.cancel_scope.cancel() -async def test_nursery_start_keeps_nursery_open(autojump_clock): +async def test_nursery_start_keeps_nursery_open(autojump_clock) -> None: async def sleep_a_bit(task_status=_core.TASK_STATUS_IGNORED): await sleep(2) task_status.started() @@ -1808,13 +1808,13 @@ async def start_sleep_then_crash(nursery): assert _core.current_time() - t0 == 7 -async def test_nursery_explicit_exception(): +async def test_nursery_explicit_exception() -> None: with pytest.raises(KeyError): async with _core.open_nursery(): raise KeyError() -async def test_nursery_stop_iteration(): +async def test_nursery_stop_iteration() -> None: async def fail(): raise ValueError @@ -1826,9 +1826,9 @@ async def fail(): assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError) -async def test_nursery_stop_async_iteration(): +async def test_nursery_stop_async_iteration() -> None: class it: - def __init__(self, count): + def __init__(self, count) -> None: self.count = count self.val = 0 @@ -1841,7 +1841,7 @@ async def __anext__(self): return val class async_zip: - def __init__(self, *largs): + def __init__(self, *largs) -> None: self.nexts = [obj.__anext__ for obj in largs] async def _accumulate(self, f, items, i): @@ -1878,7 +1878,7 @@ def handle(exc): assert result == [[0, 0], [1, 1]] -async def test_traceback_frame_removal(): +async def test_traceback_frame_removal() -> None: async def my_child_task(): raise KeyError() @@ -1902,7 +1902,7 @@ async def my_child_task(): assert frame.f_code is my_child_task.__code__ -def test_contextvar_support(): +def test_contextvar_support() -> None: var = contextvars.ContextVar("test") var.set("before") @@ -1921,7 +1921,7 @@ async def inner(): assert var.get() == "before" -async def test_contextvar_multitask(): +async def test_contextvar_multitask() -> None: var = contextvars.ContextVar("test", default="hmmm") async def t1(): @@ -1941,7 +1941,7 @@ async def t2(): await wait_all_tasks_blocked() -def test_system_task_contexts(): +def test_system_task_contexts() -> None: cvar = contextvars.ContextVar("qwilfish") cvar.set("water") @@ -1961,25 +1961,25 @@ async def inner(): _core.run(inner) -def test_Nursery_init(): +def test_Nursery_init() -> None: with pytest.raises(TypeError): _core._run.Nursery(None, None) -async def test_Nursery_private_init(): +async def test_Nursery_private_init() -> None: # context manager creation should not raise async with _core.open_nursery() as nursery: assert False == nursery._closed -def test_Nursery_subclass(): +def test_Nursery_subclass() -> None: with pytest.raises(TypeError): class Subclass(_core._run.Nursery): pass -def test_Cancelled_init(): +def test_Cancelled_init() -> None: with pytest.raises(TypeError): raise _core.Cancelled @@ -1990,26 +1990,26 @@ def test_Cancelled_init(): _core.Cancelled._create() -def test_Cancelled_str(): +def test_Cancelled_str() -> None: cancelled = _core.Cancelled._create() assert str(cancelled) == "Cancelled" -def test_Cancelled_subclass(): +def test_Cancelled_subclass() -> None: with pytest.raises(TypeError): class Subclass(_core.Cancelled): pass -def test_CancelScope_subclass(): +def test_CancelScope_subclass() -> None: with pytest.raises(TypeError): class Subclass(_core.CancelScope): pass -def test_sniffio_integration(): +def test_sniffio_integration() -> None: with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() @@ -2022,7 +2022,7 @@ async def check_inside_trio(): sniffio.current_async_library() -async def test_Task_custom_sleep_data(): +async def test_Task_custom_sleep_data() -> None: task = _core.current_task() assert task.custom_sleep_data is None task.custom_sleep_data = 1 @@ -2036,7 +2036,7 @@ def async_yield(value: _T) -> Iterator[_T]: yield value -async def test_permanently_detach_coroutine_object(): +async def test_permanently_detach_coroutine_object() -> None: task = None pdco_outcome = None @@ -2082,7 +2082,7 @@ async def bad_detach(): nursery.start_soon(bad_detach) -async def test_detach_and_reattach_coroutine_object(): +async def test_detach_and_reattach_coroutine_object() -> None: unrelated_task = None task = None @@ -2128,7 +2128,7 @@ def abort_fn(_): # pragma: no cover # Now it's been reattached, and we can leave the nursery -async def test_detached_coroutine_cancellation(): +async def test_detached_coroutine_cancellation() -> None: abort_fn_called = False task = None @@ -2158,7 +2158,7 @@ def abort_fn(_): assert abort_fn_called -def test_async_function_implemented_in_C(): +def test_async_function_implemented_in_C() -> None: # These used to crash because we'd try to mutate the coroutine object's # cr_frame, but C functions don't have Python frames. @@ -2182,7 +2182,7 @@ async def main(): _core.run(main) -async def test_very_deep_cancel_scope_nesting(): +async def test_very_deep_cancel_scope_nesting() -> None: # This used to crash with a RecursionError in CancelStatus.recalculate with ExitStack() as exit_stack: outermost_scope = _core.CancelScope() @@ -2192,7 +2192,7 @@ async def test_very_deep_cancel_scope_nesting(): outermost_scope.cancel() -async def test_cancel_scope_deadline_duplicates(): +async def test_cancel_scope_deadline_duplicates() -> None: # This exercises an assert in Deadlines._prune, by intentionally creating # duplicate entries in the deadline heap. now = _core.current_time() diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/tests/test_thread_cache.py index 9e58d77b9f..234fef251c 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/tests/test_thread_cache.py @@ -12,7 +12,7 @@ from .._thread_cache import start_thread_soon, ThreadCache -def test_thread_cache_basics(): +def test_thread_cache_basics() -> None: q = Queue() def fn(): @@ -28,7 +28,7 @@ def deliver(outcome): outcome.unwrap() -def test_thread_cache_deref(): +def test_thread_cache_deref() -> None: res = [False] class del_me: @@ -103,7 +103,7 @@ def test_idle_threads_exit(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None assert not seen_thread.is_alive() -def test_race_between_idle_exit_and_job_assignment(monkeypatch): +def test_race_between_idle_exit_and_job_assignment(monkeypatch) -> None: # This is a lock where the first few times you try to acquire it with a # timeout, it waits until the lock is available and then pretends to time # out. Using this in our thread cache implementation causes the following @@ -122,7 +122,7 @@ def test_race_between_idle_exit_and_job_assignment(monkeypatch): # everything proceeds as normal. class JankyLock: - def __init__(self): + def __init__(self) -> None: self._lock = threading.Lock() self._counter = 3 diff --git a/trio/_core/tests/test_tutil.py b/trio/_core/tests/test_tutil.py index eb16de883f..07bba9407d 100644 --- a/trio/_core/tests/test_tutil.py +++ b/trio/_core/tests/test_tutil.py @@ -3,7 +3,7 @@ from .tutil import check_sequence_matches -def test_check_sequence_matches(): +def test_check_sequence_matches() -> None: check_sequence_matches([1, 2, 3], [1, 2, 3]) with pytest.raises(AssertionError): check_sequence_matches([1, 3, 2], [1, 2, 3]) diff --git a/trio/_core/tests/test_unbounded_queue.py b/trio/_core/tests/test_unbounded_queue.py index 801c34ce46..b7434c3a13 100644 --- a/trio/_core/tests/test_unbounded_queue.py +++ b/trio/_core/tests/test_unbounded_queue.py @@ -10,7 +10,7 @@ ) -async def test_UnboundedQueue_basic(): +async def test_UnboundedQueue_basic() -> None: q = _core.UnboundedQueue() q.put_nowait("hi") assert await q.get_batch() == ["hi"] @@ -35,7 +35,7 @@ async def test_UnboundedQueue_basic(): repr(q) -async def test_UnboundedQueue_blocking(): +async def test_UnboundedQueue_blocking() -> None: record = [] q = _core.UnboundedQueue() @@ -67,7 +67,7 @@ async def aiter_consumer(): nursery.cancel_scope.cancel() -async def test_UnboundedQueue_fairness(): +async def test_UnboundedQueue_fairness() -> None: q = _core.UnboundedQueue() # If there's no-one else around, we can put stuff in and take it out @@ -114,7 +114,7 @@ async def reader(name): assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)])) -async def test_UnboundedQueue_trivial_yields(): +async def test_UnboundedQueue_trivial_yields() -> None: q = _core.UnboundedQueue() q.put_nowait(None) @@ -127,7 +127,7 @@ async def test_UnboundedQueue_trivial_yields(): break -async def test_UnboundedQueue_no_spurious_wakeups(): +async def test_UnboundedQueue_no_spurious_wakeups() -> None: # If we have two tasks waiting, and put two items into the queue... then # only one task wakes up record = [] diff --git a/trio/_core/tests/test_windows.py b/trio/_core/tests/test_windows.py index 930408dba3..f7d0896019 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/tests/test_windows.py @@ -56,7 +56,7 @@ async def post(key): break print("end loop") - async def test_readinto_overlapped(): + async def test_readinto_overlapped() -> None: data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024 buffer = bytearray(len(data)) @@ -113,7 +113,7 @@ def pipe_with_overlapped_read() -> Iterator[Tuple[BufferedWriter, int]]: kernel32.CloseHandle(ffi.cast("HANDLE", read_handle)) kernel32.CloseHandle(ffi.cast("HANDLE", write_handle)) - def test_forgot_to_register_with_iocp(): + def test_forgot_to_register_with_iocp() -> None: with pipe_with_overlapped_read() as (write_fp, read_handle): with write_fp: write_fp.write(b"test\n") @@ -173,7 +173,7 @@ async def test_too_late_to_cancel() -> None: assert await _core.readinto_overlapped(read_handle, target) == 6 assert target[:6] == b"test2\n" - def test_lsp_that_hooks_select_gives_good_error(monkeypatch): + def test_lsp_that_hooks_select_gives_good_error(monkeypatch) -> None: from .._windows_cffi import WSAIoctls, _handle from .. import _io_windows @@ -193,7 +193,7 @@ def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): ): _core.run(sleep, 0) - def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch): + def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch) -> None: # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to diff --git a/trio/_deprecate.py b/trio/_deprecate.py index 9010d6953e..6b9f2d90d5 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -1,7 +1,7 @@ import sys from functools import wraps from types import ModuleType -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Optional, TypeVar, Union import warnings import attr @@ -33,13 +33,13 @@ class TrioDeprecationWarning(FutureWarning): """ -def _url_for_issue(issue): +def _url_for_issue(issue: int) -> str: return "https://github.com/python-trio/trio/issues/{}".format(issue) -def _stringify(thing): +def _stringify(thing: object) -> str: if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"): - return "{}.{}".format(thing.__module__, thing.__qualname__) + return "{}.{}".format(thing.__module__, thing.__qualname__) # type: ignore[attr-defined] return str(thing) @@ -125,14 +125,16 @@ def wrapper(*args: object, **kwargs: object) -> object: class DeprecatedAttribute: _not_set = object() - value = attr.ib() - version = attr.ib() - issue = attr.ib() - instead = attr.ib(default=_not_set) + value: str = attr.ib() + version: str = attr.ib() + issue: int = attr.ib() + instead: Union[object, str] = attr.ib(default=_not_set) class _ModuleWithDeprecations(ModuleType): - def __getattr__(self, name): + __deprecated_attributes__: Dict[str, DeprecatedAttribute] + + def __getattr__(self, name: str) -> object: if name in self.__deprecated_attributes__: info = self.__deprecated_attributes__[name] instead = info.instead @@ -146,10 +148,10 @@ def __getattr__(self, name): raise AttributeError(msg.format(self.__name__, name)) -def enable_attribute_deprecations(module_name): +def enable_attribute_deprecations(module_name: str) -> None: module = sys.modules[module_name] module.__class__ = _ModuleWithDeprecations # Make sure that this is always defined so that # _ModuleWithDeprecations.__getattr__ can access it without jumping # through hoops or risking infinite recursion. - module.__deprecated_attributes__ = {} + module.__deprecated_attributes__ = {} # type: ignore[attr-defined] diff --git a/trio/_file_io.py b/trio/_file_io.py index 71cebffa6e..b9122305f8 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -83,7 +83,7 @@ class AsyncIOWrapper(AsyncResource): """ - def __init__(self, file: io.IOBase): + def __init__(self, file: io.IOBase) -> None: self._wrapped = file @property diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 5ab52f7c21..0dd05dc7f0 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -60,7 +60,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -333,7 +333,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: diff --git a/trio/_path.py b/trio/_path.py index 2012a7c87b..55d0bed8b6 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -202,7 +202,7 @@ def __lt__(self, other: os.PathLike) -> bool: def __truediv__(self: _P, *args: Union[os.PathLike, str]) -> _P: ... - def __init__(self, *args): + def __init__(self, *args) -> None: self._wrapped = pathlib.Path(*args) def __getattr__(self, name): @@ -214,7 +214,7 @@ def __getattr__(self, name): def __dir__(self): return super().__dir__() + self._forward - def __repr__(self): + def __repr__(self) -> str: return "trio.Path({})".format(repr(str(self))) def __fspath__(self): diff --git a/trio/_signals.py b/trio/_signals.py index 252872d7bd..4cd123b50c 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -63,7 +63,7 @@ def _signal_handler( class SignalReceiver: - def __init__(self): + def __init__(self) -> None: # {signal num: None} self._pending = OrderedDict() self._lot = trio.lowlevel.ParkingLot() diff --git a/trio/_socket.py b/trio/_socket.py index 232efe8f48..ac5d27b063 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -41,7 +41,7 @@ # return await do_it_properly_with_a_check_point() # class _try_sync: - def __init__(self, blocking_exc_override=None): + def __init__(self, blocking_exc_override=None) -> None: self._blocking_exc_override = blocking_exc_override def _is_blocking_io_error(self, exc): @@ -407,7 +407,7 @@ async def wrapper(self, *args, **kwargs): # type: ignore[misc] class SocketType: - def __init__(self): + def __init__(self) -> None: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" @@ -559,7 +559,7 @@ def share(self, process_id: int) -> bytes: class _SocketType(SocketType): - def __init__(self, sock: _stdlib_socket.socket): + def __init__(self, sock: _stdlib_socket.socket) -> None: if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. @@ -632,7 +632,7 @@ def proto(self) -> int: def did_shutdown_SHUT_WR(self) -> bool: return self._did_shutdown_SHUT_WR - def __repr__(self): + def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") def dup(self): diff --git a/trio/_ssl.py b/trio/_ssl.py index a163f9dee0..6dd55f349c 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -199,7 +199,7 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn, *args): + def __init__(self, afn, *args) -> None: self._afn = afn self._args = args self.started: bool = False diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 5756bc64b0..1f84345663 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -165,7 +165,7 @@ def __init__( self.args = self._proc.args self.pid = self._proc.pid - def __repr__(self): + def __repr__(self) -> str: returncode = self.returncode if returncode is None: status = "running with PID {}".format(self.pid) diff --git a/trio/_sync.py b/trio/_sync.py index 526be8e318..83c7925aff 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -170,7 +170,7 @@ class CapacityLimiter(metaclass=Final): _total_tokens: int - def __init__(self, total_tokens): + def __init__(self, total_tokens) -> None: self._lot = ParkingLot() self._borrowers = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of @@ -179,7 +179,7 @@ def __init__(self, total_tokens): self.total_tokens = total_tokens assert self._total_tokens == total_tokens - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), len(self._borrowers), self._total_tokens, len(self._lot) ) @@ -386,7 +386,7 @@ class Semaphore(metaclass=Final): """ - def __init__(self, initial_value: int, *, max_value: Optional[int] = None): + def __init__(self, initial_value: int, *, max_value: Optional[int] = None) -> None: if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -404,7 +404,7 @@ def __init__(self, initial_value: int, *, max_value: Optional[int] = None): self._value = initial_value self._max_value = max_value - def __repr__(self): + def __repr__(self) -> str: if self._max_value is None: max_value_str = "" else: @@ -494,7 +494,7 @@ class _LockImpl: _lot = attr.ib(factory=ParkingLot, init=False) _owner = attr.ib(default=None, init=False) - def __repr__(self): + def __repr__(self) -> str: if self.locked(): s1 = "locked" s2 = " with {} waiters".format(len(self._lot)) @@ -678,7 +678,7 @@ class Condition(metaclass=Final): """ - def __init__(self, lock=None): + def __init__(self, lock=None) -> None: if lock is None: lock = Lock() if not type(lock) is Lock: diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 61e2bf0584..b31f5b36ea 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -10,6 +10,7 @@ import os from pathlib import Path import sys +from typing import Dict, Iterator, List, Tuple, Union from textwrap import indent @@ -57,7 +58,7 @@ """ -def is_function(node): +def is_function(node: ast.AST) -> bool: """Check if the AST node is either a function or an async function """ @@ -66,17 +67,21 @@ def is_function(node): return False -def is_public(node): +def is_public(node: ast.AST) -> bool: """Check if the AST node has a _public decorator""" if not is_function(node): return False + + # the `if` above does this but we have to help out Mypy + assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + for decorator in node.decorator_list: if isinstance(decorator, ast.Name) and decorator.id == "_public": return True return False -def get_public_methods(tree): +def get_public_methods(tree: ast.AST) -> Iterator[Union[ast.FunctionDef, ast.AsyncFunctionDef]]: """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked @@ -84,10 +89,13 @@ def get_public_methods(tree): """ for node in ast.walk(tree): if is_public(node): + # the `if` above does this but we have to help out Mypy + assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + yield node -def create_passthrough_args(funcdef): +def create_passthrough_args(funcdef: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> str: """Given a function definition, create a string that represents taking all the arguments from the function, and passing them through to another invocation of the same function. @@ -107,7 +115,7 @@ def create_passthrough_args(funcdef): return "({})".format(", ".join(call_args)) -def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: +def gen_public_wrappers_source(source_path: Union[Path, str], lookup_path: str) -> str: """Scan the given .py file for @_public decorators, and generate wrapper functions. @@ -132,6 +140,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: contextmanager_decorated = any( decorator.id in {"contextmanager", "contextlib.contextmanager"} for decorator in method.decorator_list + if isinstance(decorator, ast.Name) ) # Remove decorators method.decorator_list = [] @@ -167,7 +176,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: return "\n\n".join(generated) -def matches_disk_files(new_files): +def matches_disk_files(new_files: Dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False @@ -178,7 +187,7 @@ def matches_disk_files(new_files): return True -def process(sources_and_lookups, *, do_test): +def process(sources_and_lookups: List[Tuple[Union[Path, str], str]], *, do_test: bool) -> None: new_files = {} for source_path, lookup_path in sources_and_lookups: print("Scanning:", source_path) @@ -201,7 +210,7 @@ def process(sources_and_lookups, *, do_test): # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. -def main(): # pragma: no cover +def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( description="Generate python code for public api wrappers" ) @@ -214,7 +223,7 @@ def main(): # pragma: no cover # Double-check we found the right directory assert (source_root / "LICENSE").exists() core = source_root / "trio/_core" - to_wrap = [ + to_wrap: List[Tuple[Union[Path, str], str]] = [ (core / "_run.py", "runner"), (core / "_instrumentation.py", "runner.instruments"), (core / "_io_windows.py", "runner.io_manager"), diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 4424550020..4500ce5318 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -33,7 +33,7 @@ class _FdHolder: # impossible to make this mistake – we'll just get an EBADF. # # (This trick was copied from the stdlib socket module.) - def __init__(self, fd: int): + def __init__(self, fd: int) -> None: # make sure self.fd is always initialized to *something*, because even # if we error out here then __del__ will run and access it. self.fd = -1 @@ -106,7 +106,7 @@ class FdStream(Stream, metaclass=Final): A new `FdStream` object. """ - def __init__(self, fd: int): + def __init__(self, fd: int) -> None: self._fd_holder = _FdHolder(fd) self._send_conflict_detector = ConflictDetector( "another task is using this stream for send" diff --git a/trio/_util.py b/trio/_util.py index 17bdf1a673..653cc0a254 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -3,7 +3,6 @@ # Little utilities we use internally from abc import ABCMeta -import os import signal import sys import pathlib @@ -17,7 +16,7 @@ import trio # Equivalent to the C function raise(), which Python doesn't wrap -if os.name == "nt": +if sys.platform == "win32": # On windows, os.kill exists but is really weird. # # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver @@ -61,7 +60,7 @@ signal_raise = getattr(_lib, "raise") else: - def signal_raise(signum): + def signal_raise(signum: int) -> None: signal.pthread_kill(threading.get_ident(), signum) @@ -76,7 +75,7 @@ def signal_raise(signum): # Trying to use signal out of the main thread will fail, so we can then # reliably check if this is the main thread without relying on a # potentially modified threading. -def is_main_thread(): +def is_main_thread() -> bool: """Attempt to reliably check if we are in the main thread.""" try: signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT)) @@ -89,8 +88,10 @@ def is_main_thread(): # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. Returns coroutine object. ###### -def coroutine_or_error(async_fn, *args): - def _return_value_looks_like_wrong_library(value): +def coroutine_or_error( + async_fn: t.Callable[..., t.Awaitable[object]], *args: object +) -> t.Awaitable[object]: + def _return_value_looks_like_wrong_library(value: object) -> bool: # Returned by legacy @asyncio.coroutine functions, which includes # a surprising proportion of asyncio builtins. if isinstance(value, collections.abc.Generator): @@ -184,24 +185,29 @@ class ConflictDetector: """ - def __init__(self, msg): + def __init__(self, msg: str) -> None: self._msg = msg self._held = False - def __enter__(self): + def __enter__(self) -> None: if self._held: raise trio.BusyResourceError(self._msg) else: self._held = True - def __exit__(self, *args): + def __exit__(self, *args: object) -> None: self._held = False -def async_wraps(cls, wrapped_cls, attr_name): +_Fn = t.TypeVar("_Fn", bound=t.Callable) + + +def async_wraps( + cls: t.Type[object], wrapped_cls: t.Type[object], attr_name: str +) -> t.Callable[[_Fn], _Fn]: """Similar to wraps, but for async wrappers of non-async functions.""" - def decorator(func): + def decorator(func: _Fn) -> _Fn: func.__name__ = attr_name func.__qualname__ = ".".join((cls.__qualname__, attr_name)) @@ -216,10 +222,10 @@ def decorator(func): return decorator -def fixup_module_metadata(module_name, namespace): +def fixup_module_metadata(module_name: str, namespace: t.Dict[str, object]) -> None: seen_ids = set() - def fix_one(qualname, name, obj): + def fix_one(qualname: str, name: str, obj: object) -> None: # avoid infinite recursion (relevant when using # typing.Generic, for example) if id(obj) in seen_ids: @@ -232,9 +238,9 @@ def fix_one(qualname, name, obj): # Modules, unlike everything else in Python, put fully-qualitied # names into their __name__ attribute. We check for "." to avoid # rewriting these. - if hasattr(obj, "__name__") and "." not in obj.__name__: - obj.__name__ = name - obj.__qualname__ = qualname + if hasattr(obj, "__name__") and "." not in obj.__name__: # type: ignore[attr-defined] + obj.__name__ = name # type: ignore[attr-defined] + obj.__qualname__ = qualname # type: ignore[attr-defined] if isinstance(obj, type): for attr_name, attr_value in obj.__dict__.items(): fix_one(objname + "." + attr_name, attr_name, attr_value) @@ -267,14 +273,14 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn): + def __init__(self, fn: t.Callable[..., object]) -> None: update_wrapper(self, fn) self._fn = fn - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> object: return self._fn(*args, **kwargs) - def __getitem__(self, _): + def __getitem__(self: _T, _: object) -> _T: return self @@ -307,13 +313,15 @@ class SomeClass(metaclass=Final): - TypeError if a sub class is created """ - def __new__(cls, name, bases, cls_namespace): + def __new__(cls: t.Type[_T], name: str, bases: t.Tuple[type], cls_namespace: t.Dict[str, object]) -> _T: for base in bases: if isinstance(base, Final): raise TypeError( f"{base.__module__}.{base.__qualname__} does not support subclassing" ) - return super().__new__(cls, name, bases, cls_namespace) + + # https://github.com/python/mypy/issues/9282 + return super().__new__(cls, name, bases, cls_namespace) # type: ignore[no-any-return,misc] class NoPublicConstructor(Final): @@ -335,7 +343,7 @@ class SomeClass(metaclass=NoPublicConstructor): - TypeError if a sub class or an instance is created. """ - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: object, **kwargs: object) -> None: raise TypeError( f"{cls.__module__}.{cls.__qualname__} has no public constructor" ) @@ -344,7 +352,7 @@ def _create(cls: t.Type[_T], *args: t.Any, **kwargs: t.Any) -> _T: return super().__call__(*args, **kwargs) # type: ignore[no-any-return,misc] -def name_asyncgen(agen): +def name_asyncgen(agen: t.AsyncGenerator) -> str: """Return the fully-qualified name of the async generator function that produced the async generator iterator *agen*. """ @@ -355,7 +363,7 @@ def name_asyncgen(agen): except (AttributeError, KeyError): module = "<{}>".format(agen.ag_code.co_filename) try: - qualname = agen.__qualname__ + qualname = agen.__qualname__ # type: ignore[attr-defined] except AttributeError: qualname = agen.ag_code.co_name return f"{module}.{qualname}" diff --git a/trio/_wait_for_object.py b/trio/_wait_for_object.py index 07c0461429..bb43e90cf8 100644 --- a/trio/_wait_for_object.py +++ b/trio/_wait_for_object.py @@ -1,4 +1,6 @@ import math +from typing import Union + from . import _timeouts import trio from ._core._windows_cffi import ( @@ -10,7 +12,7 @@ ) -async def WaitForSingleObject(obj): +async def WaitForSingleObject(obj: Union[int, object]) -> None: """Async and cancellable variant of WaitForSingleObject. Windows only. Args: @@ -50,7 +52,7 @@ async def WaitForSingleObject(obj): kernel32.CloseHandle(cancel_handle) -def WaitForMultipleObjects_sync(*handles): +def WaitForMultipleObjects_sync(*handles: object) -> None: """Wait for any of the given Windows handles to be signaled.""" n = len(handles) handle_arr = ffi.new("HANDLE[{}]".format(n)) diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index c1760bd62a..44554aeb3d 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -11,7 +11,7 @@ class _ForceCloseBoth: - def __init__(self, both): + def __init__(self, both) -> None: self._both = list(both) async def __aenter__(self): diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 2f4be323a9..e23c8254c5 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Iterator +from typing import ContextManager, Iterator from .. import _core @@ -23,7 +23,7 @@ def _assert_yields_or_not(expected: bool) -> Iterator[None]: raise AssertionError("assert_no_checkpoints block yielded!") -def assert_checkpoints(): +def assert_checkpoints() -> ContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block either exits with an exception or executes at least one :ref:`checkpoint `. @@ -43,7 +43,7 @@ def assert_checkpoints(): return _assert_yields_or_not(True) -def assert_no_checkpoints(): +def assert_no_checkpoints() -> ContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block does not execute any :ref:`checkpoints `. diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 99ad7dfcaf..21b36f044f 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -11,7 +11,7 @@ class _UnboundedByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._closed = False self._lot = _core.ParkingLot() @@ -204,7 +204,7 @@ class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final): """ - def __init__(self, receive_some_hook=None, close_hook=None): + def __init__(self, receive_some_hook=None, close_hook=None) -> None: self._conflict_detector = _util.ConflictDetector( "another task is using this stream" ) @@ -422,7 +422,7 @@ async def receiver(): class _LockstepByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._sender_closed = False self._receiver_closed = False @@ -516,7 +516,7 @@ async def receive_some(self, max_bytes=None): class _LockstepSendStream(SendStream): - def __init__(self, lbq): + def __init__(self, lbq) -> None: self._lbq = lbq def close(self): @@ -534,7 +534,7 @@ async def wait_send_all_might_not_block(self): class _LockstepReceiveStream(ReceiveStream): - def __init__(self, lbq): + def __init__(self, lbq) -> None: self._lbq = lbq def close(self): diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index c372811a65..d0bd8ab9d6 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -10,7 +10,7 @@ # Use: # # @trio_test -# async def test_whatever(): +# async def test_whatever() -> None: # await ... # # Also: if a pytest fixture is passed in that subclasses the Clock abc, then diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index c445c97103..7e1ae64009 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -6,7 +6,7 @@ from .. import abc as tabc -async def test_AsyncResource_defaults(): +async def test_AsyncResource_defaults() -> None: @attr.s class MyAR(tabc.AsyncResource): record = attr.ib(factory=list) @@ -21,7 +21,7 @@ async def aclose(self): assert myar.record == ["ac"] -def test_abc_generics(): +def test_abc_generics() -> None: # Pythons below 3.5.2 had a typing.Generic that would throw # errors when instantiating or subclassing a parameterized # version of a class with any __slots__. This is why RunVar diff --git a/trio/tests/test_channel.py b/trio/tests/test_channel.py index fd990fb3e3..7ab5490d6d 100644 --- a/trio/tests/test_channel.py +++ b/trio/tests/test_channel.py @@ -5,7 +5,7 @@ from trio import open_memory_channel, EndOfChannel -async def test_channel(): +async def test_channel() -> None: with pytest.raises(TypeError): open_memory_channel(1.0) with pytest.raises(ValueError): @@ -48,7 +48,7 @@ async def test_channel(): await r.aclose() -async def test_553(autojump_clock): +async def test_553(autojump_clock) -> None: s, r = open_memory_channel(1) with trio.move_on_after(10) as timeout_scope: await r.receive() @@ -56,7 +56,7 @@ async def test_553(autojump_clock): await s.send("Test for PR #553") -async def test_channel_multiple_producers(): +async def test_channel_multiple_producers() -> None: async def producer(send_channel, i): # We close our handle when we're done with it async with send_channel: @@ -79,7 +79,7 @@ async def producer(send_channel, i): assert got == list(range(30)) -async def test_channel_multiple_consumers(): +async def test_channel_multiple_consumers() -> None: successful_receivers = set() received = [] @@ -102,7 +102,7 @@ async def consumer(receive_channel, i): assert set(received) == set(range(10)) -async def test_close_basics(): +async def test_close_basics() -> None: async def send_block(s, expect): with pytest.raises(expect): await s.send(None) @@ -157,7 +157,7 @@ async def receive_block(r): await r.receive() -async def test_close_sync(): +async def test_close_sync() -> None: async def send_block(s, expect): with pytest.raises(expect): await s.send(None) @@ -212,7 +212,7 @@ async def receive_block(r): await r.receive() -async def test_receive_channel_clone_and_close(): +async def test_receive_channel_clone_and_close() -> None: s, r = open_memory_channel(10) r2 = r.clone() @@ -239,7 +239,7 @@ async def test_receive_channel_clone_and_close(): s.send_nowait(None) -async def test_close_multiple_send_handles(): +async def test_close_multiple_send_handles() -> None: # With multiple send handles, closing one handle only wakes senders on # that handle, but others can continue just fine s1, r = open_memory_channel(0) @@ -260,7 +260,7 @@ async def send_will_succeed(): assert await r.receive() == "ok" -async def test_close_multiple_receive_handles(): +async def test_close_multiple_receive_handles() -> None: # With multiple receive handles, closing one handle only wakes receivers on # that handle, but others can continue just fine s, r1 = open_memory_channel(0) @@ -281,7 +281,7 @@ async def receive_will_succeed(): await s.send("ok") -async def test_inf_capacity(): +async def test_inf_capacity() -> None: s, r = open_memory_channel(float("inf")) # It's accepted, and we can send all day without blocking @@ -295,7 +295,7 @@ async def test_inf_capacity(): assert got == list(range(10)) -async def test_statistics(): +async def test_statistics() -> None: s, r = open_memory_channel(2) assert s.statistics() == r.statistics() @@ -345,7 +345,7 @@ async def test_statistics(): assert s.statistics().tasks_waiting_receive == 0 -async def test_channel_fairness(): +async def test_channel_fairness() -> None: # We can remove an item we just sent, and send an item back in after, if # no-one else is waiting. @@ -388,7 +388,7 @@ async def do_receive(r): assert (await r.receive()) == 2 -async def test_unbuffered(): +async def test_unbuffered() -> None: s, r = open_memory_channel(0) with pytest.raises(trio.WouldBlock): r.receive_nowait() diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py index ce3fcec887..ba03f74c9c 100644 --- a/trio/tests/test_exports.py +++ b/trio/tests/test_exports.py @@ -13,7 +13,7 @@ from .. import _util -def test_core_is_properly_reexported(): +def test_core_is_properly_reexported() -> None: # Each export from _core should be re-exported by exactly one of these # three modules: sources = [trio, trio.lowlevel, trio.testing] @@ -113,7 +113,7 @@ def no_underscores(symbols): assert False -def test_classes_are_final(): +def test_classes_are_final() -> None: for module in PUBLIC_MODULES: for name, class_ in module.__dict__.items(): if not isinstance(class_, type): diff --git a/trio/tests/test_file_io.py b/trio/tests/test_file_io.py index 70266df699..91cb4a3b5c 100644 --- a/trio/tests/test_file_io.py +++ b/trio/tests/test_file_io.py @@ -37,12 +37,12 @@ def async_file( return trio.wrap_file(wrapped) -def test_wrap_invalid(): +def test_wrap_invalid() -> None: with pytest.raises(TypeError): trio.wrap_file(str()) -def test_wrap_non_iobase(): +def test_wrap_non_iobase() -> None: class FakeFile: def close(self): # pragma: no cover pass @@ -62,11 +62,11 @@ def write(self): # pragma: no cover trio.wrap_file(FakeFile()) -def test_wrapped_property(async_file, wrapped): +def test_wrapped_property(async_file, wrapped) -> None: assert async_file.wrapped is wrapped -def test_dir_matches_wrapped(async_file, wrapped): +def test_dir_matches_wrapped(async_file, wrapped) -> None: attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) # all supported attrs in wrapped should be available in async_file @@ -77,7 +77,7 @@ def test_dir_matches_wrapped(async_file, wrapped): ) -def test_unsupported_not_forwarded(): +def test_unsupported_not_forwarded() -> None: class FakeFile(io.RawIOBase): def unsupported_attr(self): # pragma: no cover pass @@ -90,7 +90,7 @@ def unsupported_attr(self): # pragma: no cover getattr(async_file, "unsupported_attr") -def test_sync_attrs_forwarded(async_file, wrapped): +def test_sync_attrs_forwarded(async_file, wrapped) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name not in dir(async_file): continue @@ -98,7 +98,7 @@ def test_sync_attrs_forwarded(async_file, wrapped): assert getattr(async_file, attr_name) is getattr(wrapped, attr_name) -def test_sync_attrs_match_wrapper(async_file, wrapped): +def test_sync_attrs_match_wrapper(async_file, wrapped) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name in dir(async_file): continue @@ -110,7 +110,7 @@ def test_sync_attrs_match_wrapper(async_file, wrapped): getattr(wrapped, attr_name) -def test_async_methods_generated_once(async_file): +def test_async_methods_generated_once(async_file) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -118,7 +118,7 @@ def test_async_methods_generated_once(async_file): assert getattr(async_file, meth_name) is getattr(async_file, meth_name) -def test_async_methods_signature(async_file): +def test_async_methods_signature(async_file) -> None: # use read as a representative of all async methods assert async_file.read.__name__ == "read" assert async_file.read.__qualname__ == "AsyncIOWrapper.read" @@ -126,7 +126,7 @@ def test_async_methods_signature(async_file): assert "io.StringIO.read" in async_file.read.__doc__ -async def test_async_methods_wrap(async_file, wrapped): +async def test_async_methods_wrap(async_file, wrapped) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -144,7 +144,7 @@ async def test_async_methods_wrap(async_file, wrapped): wrapped.reset_mock() -async def test_async_methods_match_wrapper(async_file, wrapped): +async def test_async_methods_match_wrapper(async_file, wrapped) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name in dir(async_file): continue @@ -156,7 +156,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped): getattr(wrapped, meth_name) -async def test_open(path): +async def test_open(path) -> None: f = await trio.open_file(path, "w") assert isinstance(f, AsyncIOWrapper) @@ -164,7 +164,7 @@ async def test_open(path): await f.aclose() -async def test_open_context_manager(path): +async def test_open_context_manager(path) -> None: async with await trio.open_file(path, "w") as f: assert isinstance(f, AsyncIOWrapper) assert not f.closed @@ -172,7 +172,7 @@ async def test_open_context_manager(path): assert f.closed -async def test_async_iter(): +async def test_async_iter() -> None: async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar")) expected = list(async_file.wrapped) result = [] @@ -184,7 +184,7 @@ async def test_async_iter(): assert result == expected -async def test_aclose_cancelled(path): +async def test_aclose_cancelled(path) -> None: with _core.CancelScope() as cscope: f = await trio.open_file(path, "w") cscope.cancel() @@ -198,7 +198,7 @@ async def test_aclose_cancelled(path): assert f.closed -async def test_detach_rewraps_asynciobase(): +async def test_detach_rewraps_asynciobase() -> None: raw = io.BytesIO() buffered = io.BufferedReader(raw) diff --git a/trio/tests/test_highlevel_generic.py b/trio/tests/test_highlevel_generic.py index df2b2cecf7..32dd86682f 100644 --- a/trio/tests/test_highlevel_generic.py +++ b/trio/tests/test_highlevel_generic.py @@ -31,7 +31,7 @@ async def aclose(self): self.record.append("aclose") -async def test_StapledStream(): +async def test_StapledStream() -> None: send_stream = RecordSendStream() receive_stream = RecordReceiveStream() stapled = StapledStream(send_stream, receive_stream) @@ -71,7 +71,7 @@ async def fake_send_eof(): assert send_stream.record == ["aclose"] -async def test_StapledStream_with_erroring_close(): +async def test_StapledStream_with_erroring_close() -> None: # Make sure that if one of the aclose methods errors out, then the other # one still gets called. class BrokenSendStream(RecordSendStream): diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/tests/test_highlevel_open_tcp_listeners.py index 9c068293db..1107499218 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/tests/test_highlevel_open_tcp_listeners.py @@ -14,7 +14,7 @@ from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6 -async def test_open_tcp_listeners_basic(): +async def test_open_tcp_listeners_basic() -> None: listeners = await open_tcp_listeners(0) assert isinstance(listeners, list) for obj in listeners: @@ -42,7 +42,7 @@ async def test_open_tcp_listeners_basic(): await resource.aclose() -async def test_open_tcp_listeners_specific_port_specific_host(): +async def test_open_tcp_listeners_specific_port_specific_host() -> None: # Pick a port sock = tsocket.socket() await sock.bind(("127.0.0.1", 0)) @@ -65,7 +65,7 @@ async def test_open_tcp_listeners_ipv6_v6only() -> None: await open_tcp_stream("127.0.0.1", port) -async def test_open_tcp_listeners_rebind(): +async def test_open_tcp_listeners_rebind() -> None: (l1,) = await open_tcp_listeners(0, host="127.0.0.1") sockaddr1 = l1.socket.getsockname() @@ -170,7 +170,7 @@ async def getaddrinfo(self, host, port, family, type, proto, flags): ] -async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): +async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None: # If we were trying to bind to multiple hosts and one of them failed, they # call get cleaned up before returning fsf = FakeSocketFactory(3) @@ -193,7 +193,7 @@ async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): assert sock.closed -async def test_open_tcp_listeners_port_checking(): +async def test_open_tcp_listeners_port_checking() -> None: for host in ["127.0.0.1", None]: with pytest.raises(TypeError): await open_tcp_listeners(None, host=host) @@ -203,7 +203,7 @@ async def test_open_tcp_listeners_port_checking(): await open_tcp_listeners("http", host=host) -async def test_serve_tcp(): +async def test_serve_tcp() -> None: async def handler(stream): await stream.send_all(b"x") @@ -255,7 +255,7 @@ async def test_open_tcp_listeners_some_address_families_unavailable( assert not should_succeed -async def test_open_tcp_listeners_socket_fails_not_afnosupport(): +async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None: fsf = FakeSocketFactory( 10, raise_on_family={ @@ -283,7 +283,7 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport(): # effectively is no backlog), sometimes the host might not be enough resources # to give us the full requested backlog... it was a mess. So now we just check # that the backlog argument is passed through correctly. -async def test_open_tcp_listeners_backlog(): +async def test_open_tcp_listeners_backlog() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) for (given, expected) in [ diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/tests/test_highlevel_open_tcp_stream.py index bcd3ef7f5a..3b0f7836b4 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/tests/test_highlevel_open_tcp_stream.py @@ -14,7 +14,7 @@ ) -def test_close_all(): +def test_close_all() -> None: class CloseMe: closed = False @@ -45,7 +45,7 @@ def close(self): assert c.closed -def test_reorder_for_rfc_6555_section_5_4(): +def test_reorder_for_rfc_6555_section_5_4() -> None: def fake4(i): return ( AF_INET, @@ -82,7 +82,7 @@ def fake6(i): assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)] -def test_format_host_port(): +def test_format_host_port() -> None: assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port("example.com", 443) == "example.com:443" @@ -92,7 +92,7 @@ def test_format_host_port(): # Make sure we can connect to localhost using real kernel sockets -async def test_open_tcp_stream_real_socket_smoketest(): +async def test_open_tcp_stream_real_socket_smoketest() -> None: listen_sock = trio.socket.socket() await listen_sock.bind(("127.0.0.1", 0)) _, listen_port = listen_sock.getsockname() @@ -107,7 +107,7 @@ async def test_open_tcp_stream_real_socket_smoketest(): listen_sock.close() -async def test_open_tcp_stream_input_validation(): +async def test_open_tcp_stream_input_validation() -> None: with pytest.raises(ValueError): await open_tcp_stream(None, 80) with pytest.raises(TypeError): @@ -123,7 +123,7 @@ def can_bind_127_0_0_2(): return s.getsockname()[0] == "127.0.0.2" -async def test_local_address_real(): +async def test_local_address_real() -> None: with trio.socket.socket() as listener: await listener.bind(("127.0.0.1", 0)) listener.listen() @@ -212,7 +212,7 @@ def setsockopt(self, *args, **kwargs): class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver): - def __init__(self, port, ip_list, supported_families): + def __init__(self, port, ip_list, supported_families) -> None: # ip_list have to be unique ip_order = [ip for (ip, _, _) in ip_list] assert len(set(ip_order)) == len(ip_list) @@ -310,19 +310,19 @@ async def run_scenario( return (exc, scenario) -async def test_one_host_quick_success(autojump_clock): +async def test_one_host_quick_success(autojump_clock) -> None: sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) assert sock.ip == "1.2.3.4" assert trio.current_time() == 0.123 -async def test_one_host_slow_success(autojump_clock): +async def test_one_host_slow_success(autojump_clock) -> None: sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) assert sock.ip == "1.2.3.4" assert trio.current_time() == 100 -async def test_one_host_quick_fail(autojump_clock): +async def test_one_host_quick_fail(autojump_clock) -> None: exc, scenario = await run_scenario( 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError ) @@ -330,7 +330,7 @@ async def test_one_host_quick_fail(autojump_clock): assert trio.current_time() == 0.123 -async def test_one_host_slow_fail(autojump_clock): +async def test_one_host_slow_fail(autojump_clock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 100, "error")], expect_error=OSError ) @@ -338,7 +338,7 @@ async def test_one_host_slow_fail(autojump_clock): assert trio.current_time() == 100 -async def test_one_host_failed_after_connect(autojump_clock): +async def test_one_host_failed_after_connect(autojump_clock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt ) @@ -346,7 +346,7 @@ async def test_one_host_failed_after_connect(autojump_clock): # With the default 0.250 second delay, the third attempt will win -async def test_basic_fallthrough(autojump_clock): +async def test_basic_fallthrough(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -365,7 +365,7 @@ async def test_basic_fallthrough(autojump_clock): } -async def test_early_success(autojump_clock): +async def test_early_success(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -384,7 +384,7 @@ async def test_early_success(autojump_clock): # With a 0.450 second delay, the first attempt will win -async def test_custom_delay(autojump_clock): +async def test_custom_delay(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -403,7 +403,7 @@ async def test_custom_delay(autojump_clock): } -async def test_custom_errors_expedite(autojump_clock): +async def test_custom_errors_expedite(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -424,7 +424,7 @@ async def test_custom_errors_expedite(autojump_clock): } -async def test_all_fail(autojump_clock): +async def test_all_fail(autojump_clock) -> None: exc, scenario = await run_scenario( 80, [ @@ -447,7 +447,7 @@ async def test_all_fail(autojump_clock): } -async def test_multi_success(autojump_clock): +async def test_multi_success(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -477,7 +477,7 @@ async def test_multi_success(autojump_clock): } -async def test_does_reorder(autojump_clock): +async def test_does_reorder(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -497,7 +497,7 @@ async def test_does_reorder(autojump_clock): } -async def test_handles_no_ipv4(autojump_clock): +async def test_handles_no_ipv4(autojump_clock) -> None: sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -519,7 +519,7 @@ async def test_handles_no_ipv4(autojump_clock): } -async def test_handles_no_ipv6(autojump_clock): +async def test_handles_no_ipv6(autojump_clock) -> None: sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -541,12 +541,12 @@ async def test_handles_no_ipv6(autojump_clock): } -async def test_no_hosts(autojump_clock): +async def test_no_hosts(autojump_clock) -> None: exc, scenario = await run_scenario(80, [], expect_error=OSError) assert "no results found" in str(exc) -async def test_cancel(autojump_clock): +async def test_cancel(autojump_clock) -> None: with trio.move_on_after(5) as cancel_scope: exc, scenario = await run_scenario( 80, diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/tests/test_highlevel_open_unix_stream.py index 58e4cdf5ec..a73324cc64 100644 --- a/trio/tests/test_highlevel_open_unix_stream.py +++ b/trio/tests/test_highlevel_open_unix_stream.py @@ -11,7 +11,7 @@ pytestmark = pytest.mark.skip("Needs unix socket support") -def test_close_on_error(): +def test_close_on_error() -> None: class CloseMe: closed = False @@ -34,7 +34,7 @@ async def test_open_with_bad_filename_type(filename: float) -> None: await open_unix_socket(filename) -async def test_open_bad_socket(): +async def test_open_bad_socket() -> None: # mktemp is marked as insecure, but that's okay, we don't want the file to # exist name = tempfile.mktemp() @@ -42,7 +42,7 @@ async def test_open_bad_socket(): await open_unix_socket(name) -async def test_open_unix_socket(): +async def test_open_unix_socket() -> None: for name_type in [Path, str]: name = tempfile.mktemp() serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/tests/test_highlevel_serve_listeners.py index f488b5e7ff..707e299e34 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/tests/test_highlevel_serve_listeners.py @@ -42,7 +42,7 @@ async def aclose(self): await trio.lowlevel.checkpoint() -async def test_serve_listeners_basic(): +async def test_serve_listeners_basic() -> None: listeners = [MemoryListener(), MemoryListener()] record = [] @@ -87,7 +87,7 @@ async def do_tests(parent_nursery): assert listener.closed -async def test_serve_listeners_accept_unrecognized_error(): +async def test_serve_listeners_accept_unrecognized_error() -> None: for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]: listener = MemoryListener() @@ -101,7 +101,7 @@ async def raise_error(): assert excinfo.value is error -async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog): +async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog) -> None: listener = MemoryListener() async def raise_EMFILE(): @@ -120,7 +120,7 @@ async def raise_EMFILE(): assert record.exc_info[1].errno == errno.EMFILE -async def test_serve_listeners_connection_nursery(autojump_clock): +async def test_serve_listeners_connection_nursery(autojump_clock) -> None: listener = MemoryListener() async def handler(stream): diff --git a/trio/tests/test_highlevel_socket.py b/trio/tests/test_highlevel_socket.py index 9dcb834d2c..a67a435e0a 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/tests/test_highlevel_socket.py @@ -14,7 +14,7 @@ from .. import socket as tsocket -async def test_SocketStream_basics(): +async def test_SocketStream_basics() -> None: # stdlib socket bad (even if connected) a, b = stdlib_socket.socketpair() with a, b: @@ -52,7 +52,7 @@ async def test_SocketStream_basics(): assert isinstance(b, bytes) -async def test_SocketStream_send_all(): +async def test_SocketStream_send_all() -> None: BIG = 10000000 a_sock, b_sock = tsocket.socketpair() @@ -121,7 +121,7 @@ async def waiter(nursery): nursery.start_soon(waiter, nursery) -async def test_SocketStream_generic(): +async def test_SocketStream_generic() -> None: async def stream_maker(): left, right = tsocket.socketpair() return SocketStream(left), SocketStream(right) @@ -135,7 +135,7 @@ async def clogged_stream_maker(): await check_half_closeable_stream(stream_maker, clogged_stream_maker) -async def test_SocketListener(): +async def test_SocketListener() -> None: # Not a Trio socket with stdlib_socket.socket() as s: s.bind(("127.0.0.1", 0)) @@ -188,7 +188,7 @@ async def test_SocketListener(): await server_stream.aclose() -async def test_SocketListener_socket_closed_underfoot(): +async def test_SocketListener_socket_closed_underfoot() -> None: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(10) @@ -203,9 +203,9 @@ async def test_SocketListener_socket_closed_underfoot(): await listener.accept() -async def test_SocketListener_accept_errors(): +async def test_SocketListener_accept_errors() -> None: class FakeSocket(tsocket.SocketType): - def __init__(self, events): + def __init__(self, events) -> None: self._events = iter(events) type = tsocket.SOCK_STREAM @@ -257,7 +257,7 @@ async def accept(self): assert s.socket is fake_server_sock -async def test_socket_stream_works_when_peer_has_already_closed(): +async def test_socket_stream_works_when_peer_has_already_closed() -> None: sock_a, sock_b = tsocket.socketpair() with sock_a, sock_b: await sock_b.send(b"x") diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/tests/test_highlevel_ssl_helpers.py index c00f5dc464..ffd8bca5f4 100644 --- a/trio/tests/test_highlevel_ssl_helpers.py +++ b/trio/tests/test_highlevel_ssl_helpers.py @@ -43,7 +43,7 @@ async def getnameinfo(self, *args): # pragma: no cover # This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... # noqa is needed because flake8 doesn't understand how pytest fixtures work. -async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811 +async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx) -> None: # noqa: F811 async with trio.open_nursery() as nursery: (listener,) = await nursery.start( partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") @@ -96,7 +96,7 @@ async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa nursery.cancel_scope.cancel() -async def test_open_ssl_over_tcp_listeners(): +async def test_open_ssl_over_tcp_listeners() -> None: (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1") async with listener: assert isinstance(listener, trio.SSLListener) diff --git a/trio/tests/test_path.py b/trio/tests/test_path.py index 54fb11dc5d..a6bb53cecc 100644 --- a/trio/tests/test_path.py +++ b/trio/tests/test_path.py @@ -22,14 +22,14 @@ def method_pair(path, method_name): return getattr(path, method_name), getattr(async_path, method_name) -async def test_open_is_async_context_manager(path): +async def test_open_is_async_context_manager(path) -> None: async with await path.open("w") as f: assert isinstance(f, AsyncIOWrapper) assert f.closed -async def test_magic(): +async def test_magic() -> None: path = trio.Path("test") assert str(path) == "test" @@ -98,14 +98,14 @@ async def test_hash_magic( assert hash(a) == hash(b) -async def test_forwarded_properties(path): +async def test_forwarded_properties(path) -> None: # use `name` as a representative of forwarded properties assert "name" in dir(path) assert path.name == "test" -async def test_async_method_signature(path): +async def test_async_method_signature(path) -> None: # use `resolve` as a representative of wrapped methods assert path.resolve.__name__ == "resolve" @@ -125,7 +125,7 @@ async def test_compare_async_stat_methods(method_name: str) -> None: assert result == async_result -async def test_invalid_name_not_wrapped(path): +async def test_invalid_name_not_wrapped(path) -> None: with pytest.raises(AttributeError): getattr(path, "invalid_fake_attr") @@ -142,7 +142,7 @@ async def test_async_methods_rewrap(method_name: str) -> None: assert str(result) == str(async_result) -async def test_forward_methods_rewrap(path, tmpdir): +async def test_forward_methods_rewrap(path, tmpdir) -> None: with_name = path.with_name("foo") with_suffix = path.with_suffix(".py") @@ -152,17 +152,17 @@ async def test_forward_methods_rewrap(path, tmpdir): assert with_suffix == tmpdir.join("test.py") -async def test_forward_properties_rewrap(path): +async def test_forward_properties_rewrap(path) -> None: assert isinstance(path.parent, trio.Path) -async def test_forward_methods_without_rewrap(path, tmpdir): +async def test_forward_methods_without_rewrap(path, tmpdir) -> None: path = await path.parent.resolve() assert path.as_uri().startswith("file:///") -async def test_repr(): +async def test_repr() -> None: path = trio.Path(".") assert repr(path) == "trio.Path('.')" @@ -178,23 +178,23 @@ class MockWrapper: _wraps = MockWrapped -async def test_type_forwards_unsupported(): +async def test_type_forwards_unsupported() -> None: with pytest.raises(TypeError): WrapperType.generate_forwards(MockWrapper, {}) -async def test_type_wraps_unsupported(): +async def test_type_wraps_unsupported() -> None: with pytest.raises(TypeError): WrapperType.generate_wraps(MockWrapper, {}) -async def test_type_forwards_private(): +async def test_type_forwards_private() -> None: WrapperType.generate_forwards(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") -async def test_type_wraps_private(): +async def test_type_wraps_private() -> None: WrapperType.generate_wraps(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") @@ -210,17 +210,17 @@ async def test_path_wraps_path(path: trio.Path, meth: Callable[..., trio.Path]) assert wrapped == result -async def test_path_nonpath(): +async def test_path_nonpath() -> None: with pytest.raises(TypeError): trio.Path(1) -async def test_open_file_can_open_path(path): +async def test_open_file_can_open_path(path) -> None: async with await trio.open_file(path, "w") as f: assert f.name == os.fspath(path) -async def test_globmethods(path): +async def test_globmethods(path) -> None: # Populate a directory tree await path.mkdir() await (path / "foo").mkdir() @@ -249,7 +249,7 @@ async def test_globmethods(path): assert entries == {"_bar.txt", "bar.txt"} -async def test_iterdir(path): +async def test_iterdir(path) -> None: # Populate a directory await path.mkdir() await (path / "foo").mkdir() @@ -263,7 +263,7 @@ async def test_iterdir(path): assert entries == {"bar.txt", "foo"} -async def test_classmethods(): +async def test_classmethods() -> None: assert isinstance(await trio.Path.home(), trio.Path) # pathlib.Path has only two classmethods diff --git a/trio/tests/test_scheduler_determinism.py b/trio/tests/test_scheduler_determinism.py index e2d3167e45..d52c066de8 100644 --- a/trio/tests/test_scheduler_determinism.py +++ b/trio/tests/test_scheduler_determinism.py @@ -17,7 +17,7 @@ async def tracer(name): return tuple(trace) -def test_the_trio_scheduler_is_not_deterministic(): +def test_the_trio_scheduler_is_not_deterministic() -> None: # At least, not yet. See https://github.com/python-trio/trio/issues/32 traces = [] for _ in range(10): @@ -25,7 +25,7 @@ def test_the_trio_scheduler_is_not_deterministic(): assert len(set(traces)) == len(traces) -def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch): +def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch) -> None: monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) traces = [] for _ in range(10): diff --git a/trio/tests/test_signals.py b/trio/tests/test_signals.py index 235772f900..bf2bb14d27 100644 --- a/trio/tests/test_signals.py +++ b/trio/tests/test_signals.py @@ -8,7 +8,7 @@ from .._signals import open_signal_receiver, _signal_handler -async def test_open_signal_receiver(): +async def test_open_signal_receiver() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL) as receiver: # Raise it a few times, to exercise signal coalescing, both at the @@ -32,7 +32,7 @@ async def test_open_signal_receiver(): assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): +async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None: orig = signal.getsignal(signal.SIGILL) with pytest.raises(ValueError): with open_signal_receiver(signal.SIGILL, 1234567): @@ -41,13 +41,13 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_empty_fail(): +async def test_open_signal_receiver_empty_fail() -> None: with pytest.raises(TypeError, match="No signals were provided"): with open_signal_receiver(): pass -async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): +async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL, signal.SIGILL): pass @@ -55,7 +55,7 @@ async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_catch_signals_wrong_thread(): +async def test_catch_signals_wrong_thread() -> None: async def naughty(): with open_signal_receiver(signal.SIGINT): pass # pragma: no cover @@ -64,7 +64,7 @@ async def naughty(): await trio.to_thread.run_sync(trio.run, naughty) -async def test_open_signal_receiver_conflict(): +async def test_open_signal_receiver_conflict() -> None: with pytest.raises(trio.BusyResourceError): with open_signal_receiver(signal.SIGILL) as receiver: async with trio.open_nursery() as nursery: @@ -81,7 +81,7 @@ async def wait_run_sync_soon_idempotent_queue_barrier(): await ev.wait() -async def test_open_signal_receiver_no_starvation(): +async def test_open_signal_receiver_no_starvation() -> None: # Set up a situation where there are always 2 pending signals available to # report, and make sure that instead of getting the same signal reported # over and over, it alternates between reporting both of them. @@ -112,7 +112,7 @@ async def test_open_signal_receiver_no_starvation(): traceback.print_exc() -async def test_catch_signals_race_condition_on_exit(): +async def test_catch_signals_race_condition_on_exit() -> None: delivered_directly = set() def direct_handler(signo, frame): diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index 5b8790772e..8c19476db0 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -23,7 +23,7 @@ class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo): + def __init__(self, orig_getaddrinfo) -> None: self._orig_getaddrinfo = orig_getaddrinfo self._responses = {} self.record = [] @@ -58,7 +58,7 @@ def monkeygai(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> MonkeypatchedGAI: return controller -async def test__try_sync(): +async def test__try_sync() -> None: with assert_checkpoints(): async with _try_sync(): pass @@ -88,7 +88,7 @@ def _is_ValueError(exc): ################################################################ -def test_socket_has_some_reexports(): +def test_socket_has_some_reexports() -> None: assert tsocket.SOL_SOCKET == stdlib_socket.SOL_SOCKET assert tsocket.TCP_NODELAY == stdlib_socket.TCP_NODELAY assert tsocket.gaierror == stdlib_socket.gaierror @@ -100,7 +100,7 @@ def test_socket_has_some_reexports(): ################################################################ -async def test_getaddrinfo(monkeygai): +async def test_getaddrinfo(monkeygai) -> None: def check(got, expected): # win32 returns 0 for the proto field # musl and glibc have inconsistent handling of the canonical name @@ -176,7 +176,7 @@ def filtered(gai_list): await tsocket.getaddrinfo("asdf", "12345") -async def test_getnameinfo(): +async def test_getnameinfo() -> None: # Trivial test: ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV with assert_checkpoints(): @@ -211,7 +211,7 @@ async def test_getnameinfo(): ################################################################ -async def test_from_stdlib_socket(): +async def test_from_stdlib_socket() -> None: sa, sb = stdlib_socket.socketpair() assert not isinstance(sa, tsocket.SocketType) with sa, sb: @@ -233,7 +233,7 @@ class MySocket(stdlib_socket.socket): tsocket.from_stdlib_socket(mysock) -async def test_from_fd(): +async def test_from_fd() -> None: sa, sb = stdlib_socket.socketpair() ta = tsocket.fromfd(sa.fileno(), sa.family, sa.type, sa.proto) with sa, sb, ta: @@ -242,7 +242,7 @@ async def test_from_fd(): assert sb.recv(3) == b"x" -async def test_socketpair_simple(): +async def test_socketpair_simple() -> None: async def child(sock): print("sending hello") await sock.send(b"h") @@ -274,7 +274,7 @@ async def test_fromshare() -> None: assert await b.recv(1) == b"x" -async def test_socket(): +async def test_socket() -> None: with tsocket.socket() as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET @@ -319,7 +319,7 @@ async def test_sniff_sockopts() -> None: ################################################################ -async def test_SocketType_basics(): +async def test_SocketType_basics() -> None: sock = tsocket.socket() with sock as cm_enter_value: assert cm_enter_value is sock @@ -370,7 +370,7 @@ async def test_SocketType_basics(): sock.close() -async def test_SocketType_dup(): +async def test_SocketType_dup() -> None: a, b = tsocket.socketpair() with a, b: a2 = a.dup() @@ -382,7 +382,7 @@ async def test_SocketType_dup(): assert await b.recv(1) == b"x" -async def test_SocketType_shutdown(): +async def test_SocketType_shutdown() -> None: a, b = tsocket.socketpair() with a, b: await a.send(b"x") @@ -435,7 +435,7 @@ async def test_SocketType_simple_server( assert await client.recv(1) == b"x" -async def test_SocketType_is_readable(): +async def test_SocketType_is_readable() -> None: a, b = tsocket.socketpair() with a, b: assert not a.is_readable() @@ -586,7 +586,7 @@ async def res(*args): await res(("1.2.3.4", 80, 0, 0)) -async def test_SocketType_unresolved_names(): +async def test_SocketType_unresolved_names() -> None: with tsocket.socket() as sock: await sock.bind(("localhost", 0)) assert sock.getsockname()[0] == "127.0.0.1" @@ -605,7 +605,7 @@ async def test_SocketType_unresolved_names(): # This tests all the complicated paths through _nonblocking_helper, using recv # as a stand-in for all the methods that use _nonblocking_helper. -async def test_SocketType_non_blocking_paths(): +async def test_SocketType_non_blocking_paths() -> None: a, b = stdlib_socket.socketpair() with a, b: ta = tsocket.from_stdlib_socket(a) @@ -680,7 +680,7 @@ async def t2(): # This tests the complicated paths through connect -async def test_SocketType_connect_paths(): +async def test_SocketType_connect_paths() -> None: with tsocket.socket() as sock: with pytest.raises(ValueError): # Should be a tuple @@ -733,7 +733,7 @@ def connect(self, *args, **kwargs): await sock.connect(("127.0.0.1", 2)) -async def test_resolve_remote_address_exception_closes_socket(): +async def test_resolve_remote_address_exception_closes_socket() -> None: # Here we are testing issue 247, any cancellation will leave the socket closed with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: @@ -749,7 +749,7 @@ async def _resolve_remote_address_nocp(self, *args, **kwargs): assert sock.fileno() == -1 -async def test_send_recv_variants(): +async def test_send_recv_variants() -> None: a, b = tsocket.socketpair() with a, b: # recv, including with flags @@ -845,7 +845,7 @@ async def test_send_recv_variants(): assert await b.recv(10) == b"yyy" -async def test_idna(monkeygai): +async def test_idna(monkeygai) -> None: # This is the encoding for "faß.de", which uses one of the characters that # IDNA 2003 handles incorrectly: monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80) @@ -863,14 +863,14 @@ async def test_idna(monkeygai): assert "ok faß.de" == await tsocket.getaddrinfo(b"xn--fa-hia.de", 80) -async def test_getprotobyname(): +async def test_getprotobyname() -> None: # These are the constants used in IP header fields, so the numeric values # had *better* be stable across systems... assert await tsocket.getprotobyname("udp") == 17 assert await tsocket.getprotobyname("tcp") == 6 -async def test_custom_hostname_resolver(monkeygai): +async def test_custom_hostname_resolver(monkeygai) -> None: class CustomResolver: async def getaddrinfo(self, host, port, family, type, proto, flags): return ("custom_gai", host, port, family, type, proto, flags) @@ -914,7 +914,7 @@ async def getnameinfo(self, sockaddr, flags): assert await tsocket.getaddrinfo("host", "port") == "x" -async def test_custom_socket_factory(): +async def test_custom_socket_factory() -> None: class CustomSocketFactory: def socket(self, family, type, proto): return ("hi", family, type, proto) @@ -941,7 +941,7 @@ def socket(self, family, type, proto): assert tsocket.set_custom_socket_factory(None) is csf -async def test_SocketType_is_abstract(): +async def test_SocketType_is_abstract() -> None: with pytest.raises(TypeError): tsocket.SocketType() @@ -976,7 +976,7 @@ async def check_AF_UNIX(path): pass -async def test_interrupted_by_close(): +async def test_interrupted_by_close() -> None: a_stdlib, b_stdlib = stdlib_socket.socketpair() with a_stdlib, b_stdlib: a_stdlib.setblocking(False) @@ -1006,7 +1006,7 @@ async def receiver(): a.close() -async def test_many_sockets(): +async def test_many_sockets() -> None: total = 5000 # Must be more than MAX_AFD_GROUP_SIZE sockets = [] for x in range(total // 2): diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index aa5a6a463f..30e2f72a3a 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -172,7 +172,7 @@ async def ssl_echo_server( # type: ignore[misc] # Doesn't inherit from Stream because I left out the methods that we don't # actually need. class PyOpenSSLEchoStream: - def __init__(self, sleeper=None): + def __init__(self, sleeper=None) -> None: ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. Which is great for them, but # we still have to support versions before that, and that means we @@ -325,7 +325,7 @@ async def receive_some(self, nbytes=None): print(" <-- transport_stream.receive_some finished") -async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): +async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None: # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all # at the same time, or ditto for receive_some. The tricky cases where SSLStream # might accidentally do this are during renegotiation, which we test using @@ -402,7 +402,7 @@ def ssl_lockstep_stream_pair(client_ctx, **kwargs): # Simple smoke test for handshake/send/receive/shutdown talking to a # synchronous server, plus make sure that we do the bare minimum of # certificate checking (even though this is really Python's responsibility) -async def test_ssl_client_basics(client_ctx): +async def test_ssl_client_basics(client_ctx) -> None: # Everything OK async with ssl_echo_server(client_ctx) as s: assert not s.server_side @@ -428,7 +428,7 @@ async def test_ssl_client_basics(client_ctx): assert isinstance(excinfo.value.__cause__, ssl.CertificateError) -async def test_ssl_server_basics(client_ctx): +async def test_ssl_server_basics(client_ctx) -> None: a, b = stdlib_socket.socketpair() with a, b: server_sock = tsocket.from_stdlib_socket(b) @@ -458,7 +458,7 @@ def client(): t.join() -async def test_attributes(client_ctx): +async def test_attributes(client_ctx) -> None: async with ssl_echo_server_raw(expect_fail=True) as sock: good_ctx = client_ctx bad_ctx = ssl.create_default_context() @@ -527,7 +527,7 @@ async def test_attributes(client_ctx): # I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it... -async def test_full_duplex_basics(client_ctx): +async def test_full_duplex_basics(client_ctx) -> None: CHUNKS = 30 CHUNK_SIZE = 32768 EXPECTED = CHUNKS * CHUNK_SIZE @@ -564,7 +564,7 @@ async def receiver(s): assert sent == received -async def test_renegotiation_simple(client_ctx): +async def test_renegotiation_simple(client_ctx) -> None: with virtual_ssl_echo_server(client_ctx) as s: await s.do_handshake() @@ -701,7 +701,7 @@ async def sleeper_with_slow_wait_writable_and_expect(method): await s.aclose() -async def test_resource_busy_errors(client_ctx): +async def test_resource_busy_errors(client_ctx) -> None: async def do_send_all(): with assert_checkpoints(): await s.send_all(b"x") @@ -743,7 +743,7 @@ async def do_wait_send_all_might_not_block(): assert "another task" in str(excinfo.value) -async def test_wait_writable_calls_underlying_wait_writable(): +async def test_wait_writable_calls_underlying_wait_writable() -> None: record = [] class NotAStream: @@ -756,7 +756,7 @@ async def wait_send_all_might_not_block(self): assert record == ["ok"] -async def test_checkpoints(client_ctx): +async def test_checkpoints(client_ctx) -> None: async with ssl_echo_server(client_ctx) as s: with assert_checkpoints(): await s.do_handshake() @@ -785,7 +785,7 @@ async def test_checkpoints(client_ctx): await s.aclose() -async def test_send_all_empty_string(client_ctx): +async def test_send_all_empty_string(client_ctx) -> None: async with ssl_echo_server(client_ctx) as s: await s.do_handshake() @@ -828,7 +828,7 @@ async def clogged_stream_maker(): await check_two_way_stream(stream_maker, clogged_stream_maker) -async def test_unwrap(client_ctx): +async def test_unwrap(client_ctx) -> None: client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) client_transport = client_ssl.transport_stream server_transport = server_ssl.transport_stream @@ -882,7 +882,7 @@ async def server(): nursery.start_soon(server) -async def test_closing_nice_case(client_ctx): +async def test_closing_nice_case(client_ctx) -> None: # the nice case: graceful closes all around client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) @@ -944,7 +944,7 @@ async def expect_eof_server(): nursery.start_soon(expect_eof_server) -async def test_send_all_fails_in_the_middle(client_ctx): +async def test_send_all_fails_in_the_middle(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -975,7 +975,7 @@ def close_hook(): assert closed == 2 -async def test_ssl_over_ssl(client_ctx): +async def test_ssl_over_ssl(client_ctx) -> None: client_0, server_0 = memory_stream_pair() client_1 = SSLStream( @@ -1001,7 +1001,7 @@ async def server(): nursery.start_soon(server) -async def test_ssl_bad_shutdown(client_ctx): +async def test_ssl_bad_shutdown(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -1018,7 +1018,7 @@ async def test_ssl_bad_shutdown(client_ctx): await server.aclose() -async def test_ssl_bad_shutdown_but_its_ok(client_ctx): +async def test_ssl_bad_shutdown_but_its_ok(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, @@ -1038,7 +1038,7 @@ async def test_ssl_bad_shutdown_but_its_ok(client_ctx): await server.aclose() -async def test_ssl_handshake_failure_during_aclose(): +async def test_ssl_handshake_failure_during_aclose() -> None: # Weird scenario: aclose() triggers an automatic handshake, and this # fails. This also exercises a bit of code in aclose() that was otherwise # uncovered, for re-raising exceptions after calling aclose_forcefully on @@ -1057,7 +1057,7 @@ async def test_ssl_handshake_failure_during_aclose(): await s.aclose() -async def test_ssl_only_closes_stream_once(client_ctx): +async def test_ssl_only_closes_stream_once(client_ctx) -> None: # We used to have a bug where if transport_stream.aclose() raised an # error, we would call it again. This checks that that's fixed. client, server = ssl_memory_stream_pair(client_ctx) @@ -1082,7 +1082,7 @@ def close_hook(): assert transport_close_count == 1 -async def test_ssl_https_compatibility_disagreement(client_ctx): +async def test_ssl_https_compatibility_disagreement(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": False}, @@ -1105,7 +1105,7 @@ async def receive_and_expect_error(): nursery.start_soon(receive_and_expect_error) -async def test_https_mode_eof_before_handshake(client_ctx): +async def test_https_mode_eof_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, @@ -1120,7 +1120,7 @@ async def server_expect_clean_eof(): nursery.start_soon(server_expect_clean_eof) -async def test_send_error_during_handshake(client_ctx): +async def test_send_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async def bad_hook(): @@ -1137,7 +1137,7 @@ async def bad_hook(): await client.do_handshake() -async def test_receive_error_during_handshake(client_ctx): +async def test_receive_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async def bad_hook(): @@ -1160,7 +1160,7 @@ async def client_side(cancel_scope): await client.do_handshake() -async def test_selected_alpn_protocol_before_handshake(client_ctx): +async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1170,7 +1170,7 @@ async def test_selected_alpn_protocol_before_handshake(client_ctx): server.selected_alpn_protocol() -async def test_selected_alpn_protocol_when_not_set(client_ctx): +async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None: # ALPN protocol still returns None when it's not ser, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1185,7 +1185,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx): assert client.selected_alpn_protocol() == server.selected_alpn_protocol() -async def test_selected_npn_protocol_before_handshake(client_ctx): +async def test_selected_npn_protocol_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1195,7 +1195,7 @@ async def test_selected_npn_protocol_before_handshake(client_ctx): server.selected_npn_protocol() -async def test_selected_npn_protocol_when_not_set(client_ctx): +async def test_selected_npn_protocol_when_not_set(client_ctx) -> None: # NPN protocol still returns None when it's not ser, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1210,7 +1210,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx): assert client.selected_npn_protocol() == server.selected_npn_protocol() -async def test_get_channel_binding_before_handshake(client_ctx): +async def test_get_channel_binding_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1220,7 +1220,7 @@ async def test_get_channel_binding_before_handshake(client_ctx): server.get_channel_binding() -async def test_get_channel_binding_after_handshake(client_ctx): +async def test_get_channel_binding_after_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -1233,7 +1233,7 @@ async def test_get_channel_binding_after_handshake(client_ctx): assert client.get_channel_binding() == server.get_channel_binding() -async def test_getpeercert(client_ctx): +async def test_getpeercert(client_ctx) -> None: # Make sure we're not affected by https://bugs.python.org/issue29334 client, server = ssl_memory_stream_pair(client_ctx) @@ -1246,7 +1246,7 @@ async def test_getpeercert(client_ctx): assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] -async def test_SSLListener(client_ctx): +async def test_SSLListener(client_ctx) -> None: async def setup(**kwargs): listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 3f5a69ea12..15efb9853e 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -62,7 +62,7 @@ def got_signal(proc, sig): return proc.returncode != 0 -async def test_basic(): +async def test_basic() -> None: async with await open_process(EXIT_TRUE) as proc: pass assert isinstance(proc, Process) @@ -78,7 +78,7 @@ async def test_basic(): ) -async def test_auto_update_returncode(): +async def test_auto_update_returncode() -> None: p = await open_process(SLEEP(9999)) assert p.returncode is None assert "running" in repr(p) @@ -90,7 +90,7 @@ async def test_auto_update_returncode(): assert p.returncode is not None -async def test_multi_wait(): +async def test_multi_wait() -> None: async with await open_process(SLEEP(10)) as proc: # Check that wait (including multi-wait) tolerates being cancelled async with _core.open_nursery() as nursery: @@ -109,7 +109,7 @@ async def test_multi_wait(): proc.kill() -async def test_kill_when_context_cancelled(): +async def test_kill_when_context_cancelled() -> None: with move_on_after(100) as scope: async with await open_process(SLEEP(10)) as proc: assert proc.poll() is None @@ -129,7 +129,7 @@ async def test_kill_when_context_cancelled(): ) -async def test_pipes(): +async def test_pipes() -> None: async with await open_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, @@ -159,7 +159,7 @@ async def check_output(stream, expected): assert 0 == await proc.wait() -async def test_interactive(): +async def test_interactive() -> None: # Test some back-and-forth with a subprocess. This one works like so: # in: 32\n # out: 0000...0000\n (32 zeroes) @@ -227,7 +227,7 @@ async def drain_one(stream, count, digit): assert proc.returncode == 0 -async def test_run(): +async def test_run() -> None: data = bytes(random.randint(0, 255) for _ in range(2 ** 18)) result = await run_process( @@ -266,7 +266,7 @@ async def test_run(): await run_process(CAT, capture_stderr=True, stderr=None) -async def test_run_check(): +async def test_run_check() -> None: cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)") with pytest.raises(subprocess.CalledProcessError) as excinfo: await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True) @@ -293,7 +293,7 @@ async def test_run_with_broken_pipe() -> None: assert result.stdout is result.stderr is None -async def test_stderr_stdout(): +async def test_stderr_stdout() -> None: async with await open_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, @@ -358,7 +358,7 @@ async def test_stderr_stdout(): os.close(r) -async def test_errors(): +async def test_errors() -> None: with pytest.raises(TypeError) as excinfo: await open_process(["ls"], encoding="utf-8") assert "unbuffered byte streams" in str(excinfo.value) @@ -371,8 +371,8 @@ async def test_errors(): await open_process("ls", shell=False) -async def test_signals(): - async def test_one_signal(send_it, signum): +async def test_signals() -> None: + async def test_one_signal(send_it, signum) -> None: with move_on_after(1.0) as scope: async with await open_process(SLEEP(3600)) as proc: send_it(proc) @@ -459,7 +459,7 @@ def on_alarm(sig, frame): signal.signal(signal.SIGALRM, old_sigalrm) -async def test_custom_deliver_cancel(): +async def test_custom_deliver_cancel() -> None: custom_deliver_cancel_called = False async def custom_deliver_cancel(proc): @@ -483,7 +483,7 @@ async def custom_deliver_cancel(proc): assert custom_deliver_cancel_called -async def test_warn_on_failed_cancel_terminate(monkeypatch): +async def test_warn_on_failed_cancel_terminate(monkeypatch) -> None: original_terminate = Process.terminate def broken_terminate(self): diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index b82351af76..560792c240 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -12,7 +12,7 @@ from .._sync import * -async def test_Event(): +async def test_Event() -> None: e = Event() assert not e.is_set() assert e.statistics().tasks_waiting == 0 @@ -42,7 +42,7 @@ async def child(): assert record == ["sleeping", "sleeping", "woken", "woken"] -async def test_CapacityLimiter(): +async def test_CapacityLimiter() -> None: with pytest.raises(TypeError): CapacityLimiter(1.0) with pytest.raises(ValueError): @@ -111,7 +111,7 @@ async def test_CapacityLimiter(): c.release_on_behalf_of("value 1") -async def test_CapacityLimiter_inf(): +async def test_CapacityLimiter_inf() -> None: from math import inf c = CapacityLimiter(inf) @@ -127,7 +127,7 @@ async def test_CapacityLimiter_inf(): assert c.available_tokens == inf -async def test_CapacityLimiter_change_total_tokens(): +async def test_CapacityLimiter_change_total_tokens() -> None: c = CapacityLimiter(2) with pytest.raises(TypeError): @@ -164,7 +164,7 @@ async def test_CapacityLimiter_change_total_tokens(): # regression test for issue #548 -async def test_CapacityLimiter_memleak_548(): +async def test_CapacityLimiter_memleak_548() -> None: limiter = CapacityLimiter(total_tokens=1) await limiter.acquire() @@ -178,7 +178,7 @@ async def test_CapacityLimiter_memleak_548(): assert len(limiter._pending_borrowers) == 0 -async def test_Semaphore(): +async def test_Semaphore() -> None: with pytest.raises(TypeError): Semaphore(1.0) with pytest.raises(ValueError): @@ -226,7 +226,7 @@ async def do_acquire(s): assert record == ["started", "finished"] -async def test_Semaphore_bounded(): +async def test_Semaphore_bounded() -> None: with pytest.raises(TypeError): Semaphore(1, max_value=1.0) with pytest.raises(ValueError): @@ -326,7 +326,7 @@ async def holder(): assert statistics.tasks_waiting == 0 -async def test_Condition(): +async def test_Condition() -> None: with pytest.raises(TypeError): Condition(Semaphore(1)) with pytest.raises(TypeError): @@ -419,7 +419,7 @@ async def waiter(i): @async_cm class ChannelLock1: - def __init__(self, capacity): + def __init__(self, capacity) -> None: self.s, self.r = open_memory_channel(capacity) for _ in range(capacity - 1): self.s.send_nowait(None) @@ -436,7 +436,7 @@ def release(self): @async_cm class ChannelLock2: - def __init__(self): + def __init__(self) -> None: self.s, self.r = open_memory_channel(10) self.s.send_nowait(None) @@ -452,7 +452,7 @@ def release(self): @async_cm class ChannelLock3: - def __init__(self): + def __init__(self) -> None: self.s, self.r = open_memory_channel(0) # self.acquired is true when one task acquires the lock and # only becomes false when it's released and no tasks are diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py index 0b10ae71e1..812c108137 100644 --- a/trio/tests/test_testing.py +++ b/trio/tests/test_testing.py @@ -15,7 +15,7 @@ from .._highlevel_socket import SocketListener -async def test_wait_all_tasks_blocked(): +async def test_wait_all_tasks_blocked() -> None: record = [] async def busy_bee(): @@ -47,7 +47,7 @@ async def cancelled_while_waiting(): assert record == ["ok"] -async def test_wait_all_tasks_blocked_with_timeouts(mock_clock): +async def test_wait_all_tasks_blocked_with_timeouts(mock_clock) -> None: record = [] async def timeout_task(): @@ -64,7 +64,7 @@ async def timeout_task(): assert record == ["tt start", "tt finished"] -async def test_wait_all_tasks_blocked_with_cushion(): +async def test_wait_all_tasks_blocked_with_cushion() -> None: record = [] async def blink(): @@ -106,7 +106,7 @@ async def wait_big_cushion(): ################################################################ -async def test_assert_checkpoints(recwarn): +async def test_assert_checkpoints(recwarn) -> None: with assert_checkpoints(): await _core.checkpoint() @@ -132,7 +132,7 @@ async def test_assert_checkpoints(recwarn): await _core.cancel_shielded_checkpoint() -async def test_assert_no_checkpoints(recwarn): +async def test_assert_no_checkpoints(recwarn) -> None: with assert_no_checkpoints(): 1 + 1 @@ -162,7 +162,7 @@ async def test_assert_no_checkpoints(recwarn): ################################################################ -async def test_Sequencer(): +async def test_Sequencer() -> None: record = [] def t(val): @@ -200,7 +200,7 @@ async def f2(seq): pass # pragma: no cover -async def test_Sequencer_cancel(): +async def test_Sequencer_cancel() -> None: # Killing a blocked task makes everything blow up record = [] seq = Sequencer() @@ -232,7 +232,7 @@ async def child(i): ################################################################ -async def test__assert_raises(): +async def test__assert_raises() -> None: with pytest.raises(AssertionError): with _assert_raises(RuntimeError): 1 + 1 @@ -247,7 +247,7 @@ async def test__assert_raises(): # This is a private implementation detail, but it's complex enough to be worth # testing directly -async def test__UnboundeByteQueue(): +async def test__UnboundeByteQueue() -> None: ubq = _UnboundedByteQueue() ubq.put(b"123") @@ -319,7 +319,7 @@ async def closer(): nursery.start_soon(closer) -async def test_MemorySendStream(): +async def test_MemorySendStream() -> None: mss = MemorySendStream() async def do_send_all(data): @@ -409,7 +409,7 @@ def close_hook(): ] -async def test_MemoryReceiveStream(): +async def test_MemoryReceiveStream() -> None: mrs = MemoryReceiveStream() async def do_receive_some(max_bytes): @@ -470,7 +470,7 @@ def close_hook(): await mrs2.receive_some(10) -async def test_MemoryRecvStream_closing(): +async def test_MemoryRecvStream_closing() -> None: mrs = MemoryReceiveStream() # close with no pending data mrs.close() @@ -490,7 +490,7 @@ async def test_MemoryRecvStream_closing(): await mrs2.receive_some(10) -async def test_memory_stream_pump(): +async def test_memory_stream_pump() -> None: mss = MemorySendStream() mrs = MemoryReceiveStream() @@ -514,7 +514,7 @@ async def test_memory_stream_pump(): assert await mrs.receive_some(10) == b"" -async def test_memory_stream_one_way_pair(): +async def test_memory_stream_one_way_pair() -> None: s, r = memory_stream_one_way_pair() assert s.send_all_hook is not None assert s.wait_send_all_might_not_block_hook is None @@ -570,7 +570,7 @@ async def check_for_cancel(): assert await r.receive_some(10) == b"456789" -async def test_memory_stream_pair(): +async def test_memory_stream_pair() -> None: a, b = memory_stream_pair() await a.send_all(b"123") await b.send_all(b"abc") @@ -592,7 +592,7 @@ async def receiver(): nursery.start_soon(sender) -async def test_memory_streams_with_generic_tests(): +async def test_memory_streams_with_generic_tests() -> None: async def one_way_stream_maker(): return memory_stream_one_way_pair() @@ -604,7 +604,7 @@ async def half_closeable_stream_maker(): await check_half_closeable_stream(half_closeable_stream_maker, None) -async def test_lockstep_streams_with_generic_tests(): +async def test_lockstep_streams_with_generic_tests() -> None: async def one_way_stream_maker(): return lockstep_stream_one_way_pair() @@ -616,7 +616,7 @@ async def two_way_stream_maker(): await check_two_way_stream(two_way_stream_maker, two_way_stream_maker) -async def test_open_stream_to_socket_listener(): +async def test_open_stream_to_socket_listener() -> None: async def check(listener): async with listener: client_stream = await open_stream_to_socket_listener(listener) diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index a2b1fce23a..0ca94d71c7 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -19,7 +19,7 @@ from .._core.tests.test_ki import ki_self -async def test_do_in_trio_thread(): +async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() async def check_case(do_in_trio_thread, fn, expected, trio_token=None): @@ -74,7 +74,7 @@ async def f(record): await check_case(from_thread_run, f, ("error", KeyError), trio_token=token) -async def test_do_in_trio_thread_from_trio_thread(): +async def test_do_in_trio_thread_from_trio_thread() -> None: with pytest.raises(RuntimeError): from_thread_run_sync(lambda: None) # pragma: no branch @@ -85,7 +85,7 @@ async def foo(): # pragma: no cover from_thread_run(foo) -def test_run_in_trio_thread_ki(): +def test_run_in_trio_thread_ki() -> None: # if we get a control-C during a run_in_trio_thread, then it propagates # back to the caller (slick!) record = set() @@ -133,7 +133,7 @@ def external_thread_fn(): assert record == {"ok1", "ok2"} -def test_await_in_trio_thread_while_main_exits(): +def test_await_in_trio_thread_while_main_exits() -> None: record = [] ev = Event() @@ -161,7 +161,7 @@ async def main(): assert record == ["sleeping", "cancelled"] -async def test_run_in_worker_thread(): +async def test_run_in_worker_thread() -> None: trio_thread = threading.current_thread() def f(x): @@ -180,7 +180,7 @@ def g(): assert excinfo.value.args[0] != trio_thread -async def test_run_in_worker_thread_cancellation(): +async def test_run_in_worker_thread_cancellation() -> None: register = [None] def f(q): @@ -240,7 +240,7 @@ async def child(q, cancellable): # Make sure that if trio.run exits, and then the thread finishes, then that's # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) -def test_run_in_worker_thread_abandoned(capfd, monkeypatch): +def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) q1 = stdlib_queue.Queue() @@ -387,7 +387,7 @@ async def run_thread(event): c.total_tokens = orig_total_tokens -async def test_run_in_worker_thread_custom_limiter(): +async def test_run_in_worker_thread_custom_limiter() -> None: # Basically just checking that we only call acquire_on_behalf_of and # release_on_behalf_of, since that's part of our documented API. record = [] @@ -405,7 +405,7 @@ def release_on_behalf_of(self, borrower): assert record == ["acquire", "release"] -async def test_run_in_worker_thread_limiter_error(): +async def test_run_in_worker_thread_limiter_error() -> None: record = [] class BadCapacityLimiter: @@ -433,7 +433,7 @@ def release_on_behalf_of(self, borrower): assert record == ["acquire", "release"] -async def test_run_in_worker_thread_fail_to_spawn(monkeypatch): +async def test_run_in_worker_thread_fail_to_spawn(monkeypatch) -> None: # Test the unlikely but possible case where trying to spawn a thread fails def bad_start(self, *args): raise RuntimeError("the engines canna take it captain") @@ -451,7 +451,7 @@ def bad_start(self, *args): assert limiter.borrowed_tokens == 0 -async def test_trio_to_thread_run_sync_token(): +async def test_trio_to_thread_run_sync_token() -> None: # Test that to_thread_run_sync automatically injects the current trio token # into a spawned thread def thread_fn(): @@ -463,7 +463,7 @@ def thread_fn(): assert callee_token == caller_token -async def test_trio_to_thread_run_sync_expected_error(): +async def test_trio_to_thread_run_sync_expected_error() -> None: # Test correct error when passed async function async def async_fn(): # pragma: no cover pass @@ -472,7 +472,7 @@ async def async_fn(): # pragma: no cover await to_thread_run_sync(async_fn) -async def test_trio_from_thread_run_sync(): +async def test_trio_from_thread_run_sync() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run_sync() def thread_fn(): @@ -493,7 +493,7 @@ def thread_fn(): await to_thread_run_sync(thread_fn) -async def test_trio_from_thread_run(): +async def test_trio_from_thread_run() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run() record = [] @@ -517,7 +517,7 @@ def sync_fn(): # pragma: no cover await to_thread_run_sync(from_thread_run, sync_fn) -async def test_trio_from_thread_token(): +async def test_trio_from_thread_token() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() # share the same Trio token def thread_fn(): @@ -529,7 +529,7 @@ def thread_fn(): assert callee_token == caller_token -async def test_trio_from_thread_token_kwarg(): +async def test_trio_from_thread_token_kwarg() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can # use an explicitly defined token def thread_fn(token): @@ -541,7 +541,7 @@ def thread_fn(token): assert callee_token == caller_token -async def test_from_thread_no_token(): +async def test_from_thread_no_token() -> None: # Test that a "raw call" to trio.from_thread.run() fails because no token # has been provided @@ -549,12 +549,12 @@ async def test_from_thread_no_token(): from_thread_run_sync(_core.current_time) -def test_run_fn_as_system_task_catched_badly_typed_token(): +def test_run_fn_as_system_task_catched_badly_typed_token() -> None: with pytest.raises(RuntimeError): from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") -async def test_from_thread_inside_trio_thread(): +async def test_from_thread_inside_trio_thread() -> None: def not_called(): # pragma: no cover assert False diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index 4511f08cf4..094cfb0cfd 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -55,7 +55,7 @@ async def make_clogged_pipe(): pass return s, r - async def test_send_pipe(): + async def test_send_pipe() -> None: r, w = os.pipe() async with FdStream(w) as send: assert send.fileno() == w @@ -64,7 +64,7 @@ async def test_send_pipe(): os.close(r) - async def test_receive_pipe(): + async def test_receive_pipe() -> None: r, w = os.pipe() async with FdStream(r) as recv: assert (recv.fileno()) == r @@ -73,7 +73,7 @@ async def test_receive_pipe(): os.close(w) - async def test_pipes_combined(): + async def test_pipes_combined() -> None: write, read = await make_pipe() count = 2 ** 20 @@ -96,7 +96,7 @@ async def reader(): await read.aclose() await write.aclose() - async def test_pipe_errors(): + async def test_pipe_errors() -> None: with pytest.raises(TypeError): FdStream(None) @@ -106,7 +106,7 @@ async def test_pipe_errors(): with pytest.raises(ValueError): await s.receive_some(0) - async def test_del(): + async def test_del() -> None: w, r = await make_pipe() f1, f2 = w.fileno(), r.fileno() del w, r @@ -120,7 +120,7 @@ async def test_del(): os.close(f2) assert excinfo.value.errno == errno.EBADF - async def test_async_with(): + async def test_async_with() -> None: w, r = await make_pipe() async with w, r: pass @@ -136,7 +136,7 @@ async def test_async_with(): os.close(r.fileno()) assert excinfo.value.errno == errno.EBADF - async def test_misdirected_aclose_regression(): + async def test_misdirected_aclose_regression() -> None: # https://github.com/python-trio/trio/issues/661#issuecomment-456582356 w, r = await make_pipe() old_r_fd = r.fileno() @@ -173,7 +173,7 @@ async def expect_eof(): # gets an EOF and can exit cleanly. os.close(w2_fd) - async def test_close_at_bad_time_for_receive_some(monkeypatch): + async def test_close_at_bad_time_for_receive_some(monkeypatch) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: @@ -202,7 +202,7 @@ async def patched_wait_readable(*args, **kwargs): # Trigger everything by waking up the receiver await s.send_all(b"x") - async def test_close_at_bad_time_for_send_all(monkeypatch): + async def test_close_at_bad_time_for_send_all(monkeypatch) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 95c9c5fbb8..627e0b36a7 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -20,7 +20,7 @@ _T = TypeVar("_T") -def test_signal_raise(): +def test_signal_raise() -> None: record = [] def handler(signum, _): @@ -34,7 +34,7 @@ def handler(signum, _): assert record == [signal.SIGFPE] -async def test_ConflictDetector(): +async def test_ConflictDetector() -> None: ul1 = ConflictDetector("ul1") ul2 = ConflictDetector("ul2") @@ -59,7 +59,7 @@ async def wait_with_ul1(): assert "ul1" in str(excinfo.value) -def test_module_metadata_is_fixed_up(): +def test_module_metadata_is_fixed_up() -> None: import trio import trio.testing @@ -83,7 +83,7 @@ def test_module_metadata_is_fixed_up(): assert trio.to_thread.run_sync.__qualname__ == "run_sync" -async def test_is_main_thread(): +async def test_is_main_thread() -> None: assert is_main_thread() def not_main_thread(): @@ -164,7 +164,7 @@ def test_func(arg: _T) -> _T: assert test_func.__module__ == __name__ -def test_final_metaclass(): +def test_final_metaclass() -> None: class FinalClass(metaclass=Final): pass @@ -174,7 +174,7 @@ class SubClass(FinalClass): pass -def test_no_public_constructor_metaclass(): +def test_no_public_constructor_metaclass() -> None: class SpecialClass(metaclass=NoPublicConstructor): pass diff --git a/trio/tests/test_wait_for_object.py b/trio/tests/test_wait_for_object.py index 54af6b873b..00a0b23015 100644 --- a/trio/tests/test_wait_for_object.py +++ b/trio/tests/test_wait_for_object.py @@ -19,7 +19,7 @@ ) -async def test_WaitForMultipleObjects_sync(): +async def test_WaitForMultipleObjects_sync() -> None: # This does a series of tests where we set/close the handle before # initiating the waiting for it. # @@ -128,7 +128,7 @@ async def test_WaitForMultipleObjects_sync_slow() -> None: print("test_WaitForMultipleObjects_sync_slow thread-set second OK") -async def test_WaitForSingleObject(): +async def test_WaitForSingleObject() -> None: # This does a series of test for setting/closing the handle before # initiating the wait. diff --git a/trio/tests/test_windows_pipes.py b/trio/tests/test_windows_pipes.py index 919ca75427..a8654d8b04 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/tests/test_windows_pipes.py @@ -27,14 +27,14 @@ async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]: return PipeSendStream(w), PipeReceiveStream(r) -async def test_pipe_typecheck(): +async def test_pipe_typecheck() -> None: with pytest.raises(TypeError): PipeSendStream(1.0) with pytest.raises(TypeError): PipeReceiveStream(None) -async def test_pipe_error_on_close(): +async def test_pipe_error_on_close() -> None: # Make sure we correctly handle a failure from kernel32.CloseHandle r, w = pipe() @@ -50,7 +50,7 @@ async def test_pipe_error_on_close(): await receive_stream.aclose() -async def test_pipes_combined(): +async def test_pipes_combined() -> None: write, read = await make_pipe() count = 2 ** 20 replicas = 3 @@ -79,7 +79,7 @@ async def reader(): n.start_soon(reader) -async def test_async_with(): +async def test_async_with() -> None: w, r = await make_pipe() async with w, r: pass @@ -90,7 +90,7 @@ async def test_async_with(): await r.receive_some(10) -async def test_close_during_write(): +async def test_close_during_write() -> None: w, r = await make_pipe() async with _core.open_nursery() as nursery: @@ -105,7 +105,7 @@ async def write_forever(): await w.aclose() -async def test_pipe_fully(): +async def test_pipe_fully() -> None: # passing make_clogged_pipe tests wait_send_all_might_not_block, and we # can't implement that on Windows await check_one_way_stream(make_pipe, None) diff --git a/trio/tests/tools/test_gen_exports.py b/trio/tests/tools/test_gen_exports.py index e4e388c226..43cbc3a88a 100644 --- a/trio/tests/tools/test_gen_exports.py +++ b/trio/tests/tools/test_gen_exports.py @@ -1,5 +1,6 @@ import ast import astor +from pathlib import Path import pytest import os import sys @@ -32,12 +33,12 @@ async def not_public_async(self): ''' -def test_get_public_methods(): +def test_get_public_methods() -> None: methods = list(get_public_methods(ast.parse(SOURCE))) assert {m.name for m in methods} == {"public_func", "public_async_func"} -def test_create_pass_through_args(): +def test_create_pass_through_args() -> None: testcases = [ ("def f()", "()"), ("def f(one)", "(one)"), @@ -55,7 +56,7 @@ def test_create_pass_through_args(): assert create_passthrough_args(func_node) == expected -def test_process(tmp_path): +def test_process(tmp_path: Path) -> None: modpath = tmp_path / "_module.py" genpath = tmp_path / "_generated_module.py" modpath.write_text(SOURCE, encoding="utf-8") From 49289ab9de8439a58da427203f678acb47bf3564 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sun, 14 Feb 2021 21:01:18 -0500 Subject: [PATCH 44/50] piles, upon piles --- trio/_abc.py | 6 +- trio/_core/_generated_io_epoll.py | 4 +- trio/_core/_generated_io_kqueue.py | 4 +- trio/_core/_generated_io_windows.py | 4 +- trio/_core/_io_common.py | 12 ++- trio/_core/_io_epoll.py | 5 +- trio/_core/_io_kqueue.py | 5 +- trio/_core/_io_windows.py | 4 +- trio/_core/_multierror.py | 16 +++- trio/_core/_traps.py | 24 ++--- trio/_core/tests/test_windows.py | 4 +- trio/_highlevel_open_tcp_listeners.py | 16 ++-- trio/_highlevel_serve_listeners.py | 9 +- trio/_socket.py | 111 ++++++++++++++++------- trio/_timeouts.py | 12 +-- trio/_tools/gen_exports.py | 12 ++- trio/_typing.py | 7 ++ trio/_util.py | 7 +- trio/tests/test_highlevel_ssl_helpers.py | 4 +- 19 files changed, 186 insertions(+), 80 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index 1abaac743f..ad136ea6e0 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -184,8 +184,10 @@ async def getaddrinfo( @abstractmethod async def getnameinfo( - self, sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], flags: int - ) -> Tuple[str, Union[int, str]]: + self, + sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], + flags: int, + ) -> Tuple[str, str]: """A custom implementation of :func:`~trio.socket.getnameinfo`. Called by :func:`trio.socket.getnameinfo`. diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index cd76014d5c..a6d09fafbc 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -40,7 +40,7 @@ async def wait_readable(fd: Union[int, _HasFileno]) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(fd: Union[int, _HasFileno]) ->None: +async def wait_writable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -48,7 +48,7 @@ async def wait_writable(fd: Union[int, _HasFileno]) ->None: raise RuntimeError("must be called from async context") -def notify_closing(fd: Union[int, _HasFileno]) ->None: +def notify_closing(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 200017ffec..21190a47be 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -66,7 +66,7 @@ async def wait_readable(fd: Union[int, _HasFileno]) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(fd: Union[int, _HasFileno]) ->None: +async def wait_writable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -74,7 +74,7 @@ async def wait_writable(fd: Union[int, _HasFileno]) ->None: raise RuntimeError("must be called from async context") -def notify_closing(fd: Union[int, _HasFileno]) ->None: +def notify_closing(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 85ed1a3f0d..fb4ef94936 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -40,7 +40,7 @@ async def wait_readable(sock: int) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(sock: int) ->None: +async def wait_writable(sock: Union[int, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) @@ -48,7 +48,7 @@ async def wait_writable(sock: int) ->None: raise RuntimeError("must be called from async context") -def notify_closing(handle: int) ->None: +def notify_closing(handle: Union[int, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) diff --git a/trio/_core/_io_common.py b/trio/_core/_io_common.py index 9891849bc9..e9395dbf0b 100644 --- a/trio/_core/_io_common.py +++ b/trio/_core/_io_common.py @@ -1,10 +1,20 @@ import copy +from typing import Optional + import outcome from .. import _core +from . import _run + +from typing_extensions import Protocol + + +class Waiter(Protocol): + read_task: Optional[_run.Task] + write_task: Optional[_run.Task] # Utility function shared between _io_epoll and _io_windows -def wake_all(waiters, exc): +def wake_all(waiters: Waiter, exc: Exception) -> None: try: current_task = _core.current_task() except RuntimeError: diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 24ecff15c7..3f6327347e 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -1,4 +1,5 @@ import select +import socket import sys import attr from collections import defaultdict @@ -305,11 +306,11 @@ async def wait_readable( await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd: Union[int, _HasFileno]) -> None: + async def wait_writable(self, fd: Union[int, _HasFileno, socket.socket]) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd: Union[int, _HasFileno]) -> None: + def notify_closing(self, fd: Union[int, _HasFileno, socket.socket]) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index d9c0773958..48848e16dc 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -1,4 +1,5 @@ import select +import socket import sys from typing import Callable, Dict, Iterator, Optional, Tuple, TYPE_CHECKING, Union @@ -178,11 +179,11 @@ async def wait_readable(self, fd: Union[int, _HasFileno]) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public - async def wait_writable(self, fd: Union[int, _HasFileno]) -> None: + async def wait_writable(self, fd: Union[int, _HasFileno, socket.socket]) -> None: await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public - def notify_closing(self, fd: Union[int, _HasFileno]) -> None: + def notify_closing(self, fd: Union[int, _HasFileno, socket.socket]) -> None: if not isinstance(fd, int): fd = fd.fileno() diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index f67e03b1f7..19f06546d4 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -701,11 +701,11 @@ async def wait_readable(self, sock: int) -> None: await self._afd_poll(sock, "read_task") @_public - async def wait_writable(self, sock: int) -> None: + async def wait_writable(self, sock: Union[int, socket.socket]) -> None: await self._afd_poll(sock, "write_task") @_public - def notify_closing(self, handle: int) -> None: + def notify_closing(self, handle: Union[int, socket.socket]) -> None: handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index da008f8f44..ecbb773962 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -466,7 +466,9 @@ def traceback_exception_init( traceback_exception_original_format = traceback.TracebackException.format -def traceback_exception_format(self: traceback.TracebackException, *, chain: bool = True) -> Iterator[str]: +def traceback_exception_format( + self: traceback.TracebackException, *, chain: bool = True +) -> Iterator[str]: yield from traceback_exception_original_format(self, chain=chain) for i, exc in enumerate(self.embedded): # type: ignore[attr-defined] @@ -477,7 +479,9 @@ def traceback_exception_format(self: traceback.TracebackException, *, chain: boo traceback.TracebackException.format = traceback_exception_format # type: ignore[assignment] -def trio_excepthook(etype: Type[BaseException], value: BaseException, tb: TracebackType) -> None: +def trio_excepthook( + etype: Type[BaseException], value: BaseException, tb: TracebackType +) -> None: for chunk in traceback.format_exception(etype, value, tb): sys.stderr.write(chunk) @@ -500,7 +504,13 @@ def trio_excepthook(etype: Type[BaseException], value: BaseException, tb: Traceb monkeypatched_or_warned = True else: - def trio_show_traceback(self: object, etype: Type[BaseException], value: BaseException, tb: TracebackType, tb_offset: Optional[int] = None) -> None: + def trio_show_traceback( + self: object, + etype: Type[BaseException], + value: BaseException, + tb: TracebackType, + tb_offset: Optional[int] = None, + ) -> None: # XX it would be better to integrate with IPython's fancy # exception formatting stuff (and not ignore tb_offset) trio_excepthook(etype, value, tb) diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 7d6f160434..163a98da25 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -9,6 +9,8 @@ from . import _run +AbortFunc = Callable[[Callable[[], None]], "Abort"] + # Helper for the bottommost 'yield'. You can't use 'yield' inside an async # function, but you can inside a generator, and if you decorate your generator @@ -28,7 +30,7 @@ class CancelShieldedCheckpoint: pass -async def cancel_shielded_checkpoint(): +async def cancel_shielded_checkpoint() -> None: """Introduce a schedule point, but not a cancel point. This is *not* a :ref:`checkpoint `, but it is half of a @@ -41,7 +43,7 @@ async def cancel_shielded_checkpoint(): await trio.lowlevel.checkpoint() """ - return (await _async_yield(CancelShieldedCheckpoint)).unwrap() + return (await _async_yield(CancelShieldedCheckpoint)).unwrap() # type: ignore[no-any-return] # Return values for abort functions @@ -62,12 +64,10 @@ class Abort(enum.Enum): # Not exported in the trio._core namespace, but imported directly by _run. @attr.s(frozen=True) class WaitTaskRescheduled: - abort_func = attr.ib() + abort_func: AbortFunc = attr.ib() -async def wait_task_rescheduled( - abort_func: Callable[[Callable[[], None]], Abort] -) -> object: +async def wait_task_rescheduled(abort_func: AbortFunc) -> object: """Put the current task to sleep, with cancellation support. This is the lowest-level API for blocking in Trio. Every time a @@ -172,10 +172,10 @@ def abort(inner_raise_cancel): # Not exported in the trio._core namespace, but imported directly by _run. @attr.s(frozen=True) class PermanentlyDetachCoroutineObject: - final_outcome = attr.ib() + final_outcome: outcome.Outcome = attr.ib() -async def permanently_detach_coroutine_object(final_outcome): +async def permanently_detach_coroutine_object(final_outcome: outcome.Outcome) -> None: """Permanently detach the current task from the Trio scheduler. Normally, a Trio task doesn't exit until its coroutine object exits. When @@ -203,10 +203,10 @@ async def permanently_detach_coroutine_object(final_outcome): raise RuntimeError( "can't permanently detach a coroutine object with open nurseries" ) - return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) + return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) # type: ignore[no-any-return] -async def temporarily_detach_coroutine_object(abort_func): +async def temporarily_detach_coroutine_object(abort_func: AbortFunc) -> None: """Temporarily detach the current coroutine object from the Trio scheduler. @@ -242,7 +242,9 @@ async def temporarily_detach_coroutine_object(abort_func): return await _async_yield(WaitTaskRescheduled(abort_func)) -async def reattach_detached_coroutine_object(task, yield_value): +async def reattach_detached_coroutine_object( + task: "_run.Task", yield_value: object +) -> None: """Reattach a coroutine object that was detached using :func:`temporarily_detach_coroutine_object`. diff --git a/trio/_core/tests/test_windows.py b/trio/_core/tests/test_windows.py index f7d0896019..b789d74cf9 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/tests/test_windows.py @@ -193,7 +193,9 @@ def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): ): _core.run(sleep, 0) - def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch) -> None: + def test_lsp_that_completely_hides_base_socket_gives_good_error( + monkeypatch, + ) -> None: # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py index 80f2c7a180..402fb2a789 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/trio/_highlevel_open_tcp_listeners.py @@ -1,8 +1,12 @@ import errno import sys from math import inf +from typing import Awaitable, Callable, Optional, Union import trio +from . import Nursery +from .abc import Stream +from ._typing import TaskStatus from . import socket as tsocket @@ -144,13 +148,13 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): async def serve_tcp( - handler, - port, + handler: Callable[[Stream], Awaitable[object]], + port: int, *, - host=None, - backlog=None, - handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED, + host: Optional[Union[str, bytes]] = None, + backlog: Optional[int] = None, + handler_nursery: Optional[Nursery] = None, + task_status: TaskStatus = trio.TASK_STATUS_IGNORED, ): """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index 0585fa516f..bf5bf07a97 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -1,8 +1,11 @@ import errno import logging import os +from typing import Awaitable, Callable, List, Optional import trio +import trio.abc +from ._typing import TaskStatus # Errors that accept(2) can return, and which indicate that the system is # overloaded @@ -49,7 +52,11 @@ async def _serve_one_listener(listener, handler_nursery, handler): async def serve_listeners( - handler, listeners, *, handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED + handler: Callable[[trio.abc.Stream], Awaitable[object]], + listeners: List[trio.abc.Listener], + *, + handler_nursery: Optional[trio.Nursery] = None, + task_status: TaskStatus = trio.TASK_STATUS_IGNORED, ): r"""Listen for incoming connections on ``listeners``, and for each one start a task running ``handler(stream)``. diff --git a/trio/_socket.py b/trio/_socket.py index ac5d27b063..7039c3ad4e 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -4,6 +4,7 @@ import select import socket as _stdlib_socket from functools import wraps as _wraps +from types import TracebackType from typing import ( Awaitable, Callable, @@ -12,8 +13,10 @@ Optional, Iterable, Sequence, + Text, Tuple, TYPE_CHECKING, + Type, TypeVar, List, Any, @@ -41,19 +44,23 @@ # return await do_it_properly_with_a_check_point() # class _try_sync: - def __init__(self, blocking_exc_override=None) -> None: + def __init__( + self, blocking_exc_override: Callable[[BaseException], bool] = None + ) -> None: self._blocking_exc_override = blocking_exc_override - def _is_blocking_io_error(self, exc): + def _is_blocking_io_error(self, exc: BaseException) -> bool: if self._blocking_exc_override is None: return isinstance(exc, BlockingIOError) else: return self._blocking_exc_override(exc) - async def __aenter__(self): + async def __aenter__(self) -> None: await trio.lowlevel.checkpoint_if_cancelled() - async def __aexit__(self, etype, value, tb): + async def __aexit__( + self, etype: Type[BaseException], value: BaseException, tb: TracebackType + ) -> bool: if value is not None and self._is_blocking_io_error(value): # Discard the exception and fall through to the code below the # block @@ -84,7 +91,9 @@ async def __aexit__(self, etype, value, tb): _socket_factory = _core.RunVar[Optional["trio._abc.SocketFactory"]]("socket_factory") -def set_custom_hostname_resolver(hostname_resolver): +def set_custom_hostname_resolver( + hostname_resolver: "trio._abc.HostnameResolver", +) -> Optional["trio._abc.HostnameResolver"]: """Set a custom hostname resolver. By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions @@ -152,7 +161,22 @@ def set_custom_socket_factory( _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV -async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): +async def getaddrinfo( + host: Optional[Union[bytearray, bytes, Text]], + port: Union[str, int, None], + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> List[ + Tuple[ + _stdlib_socket.AddressFamily, + _stdlib_socket.SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] +]: """Look up a numeric address given a name. Arguments and return values are identical to :func:`socket.getaddrinfo`, @@ -173,7 +197,7 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): # skip the whole thread thing, which seems worthwhile. So we try first # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that # fails with EAI_NONAME: - def numeric_only_failure(exc): + def numeric_only_failure(exc: BaseException) -> bool: return ( isinstance(exc, _stdlib_socket.gaierror) and exc.errno == _stdlib_socket.EAI_NONAME @@ -215,7 +239,9 @@ def numeric_only_failure(exc): ) -async def getnameinfo(sockaddr, flags): +async def getnameinfo( + sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], flags: int +) -> Tuple[str, str]: """Look up a name given a numeric address. Arguments and return values are identical to :func:`socket.getnameinfo`, @@ -234,7 +260,7 @@ async def getnameinfo(sockaddr, flags): ) -async def getprotobyname(name): +async def getprotobyname(name: str) -> int: """Look up a protocol number by name. (Rarely used.) Like :func:`socket.getprotobyname`, but async. @@ -344,7 +370,9 @@ def socket( return from_stdlib_socket(stdlib_socket) -def _sniff_sockopts_for_fileno(family, type, proto, fileno): +def _sniff_sockopts_for_fileno( + family: int, type: int, proto: int, fileno: int +) -> Tuple[int, int, int]: """Correct SOCKOPTS for given fileno, falling back to provided values.""" # Wrap the raw fileno into a Python socket object # This object might have the wrong metadata, but it lets us easily call getsockopt @@ -434,7 +462,12 @@ def did_shutdown_SHUT_WR(self) -> bool: def __enter__(self: _T) -> _T: ... - def __exit__(self, *args: Any) -> None: + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: ... def dup(self) -> "SocketType": @@ -443,7 +476,7 @@ def dup(self) -> "SocketType": def close(self) -> None: ... - async def bind(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + async def bind(self, address: Union[_Address, bytes]) -> None: ... def shutdown(self, flag: int) -> None: @@ -599,19 +632,24 @@ def __init__(self, sock: _stdlib_socket.socket) -> None: "share", } - def __getattr__(self, name): + def __getattr__(self, name): # type: ignore if name in self._forward: return getattr(self._sock, name) raise AttributeError(name) - def __dir__(self): - return super().__dir__() + list(self._forward) + def __dir__(self) -> List[str]: + return [*super().__dir__(), *self._forward] - def __enter__(self): + def __enter__(self: _T) -> _T: return self - def __exit__(self, *exc_info): - return self._sock.__exit__(*exc_info) + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + return self._sock.__exit__(etype, exc, tb) # type: ignore[no-any-return,func-returns-value] @property def family(self) -> int: @@ -635,16 +673,16 @@ def did_shutdown_SHUT_WR(self) -> bool: def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self): + def dup(self) -> "_SocketType": """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) - def close(self): + def close(self) -> None: if self._sock.fileno() != -1: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address): + async def bind(self, address: Union[_Address, bytes]) -> None: address = await self._resolve_local_address_nocp(address) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -662,14 +700,14 @@ async def bind(self, address): await trio.lowlevel.checkpoint() return self._sock.bind(address) - def shutdown(self, flag): + def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: self._sock.shutdown(flag) # only do this if the call succeeded: if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: self._did_shutdown_SHUT_WR = True - def is_readable(self): + def is_readable(self) -> bool: # use select.select on Windows, and select.poll everywhere else if sys.platform == "win32": rready, _, _ = select.select([self._sock], [], [], 0) @@ -678,7 +716,7 @@ def is_readable(self): p.register(self._sock, select.POLLIN) return bool(p.poll(0)) - async def wait_writable(self): + async def wait_writable(self) -> None: await _core.wait_writable(self._sock) ################################################################ @@ -690,7 +728,9 @@ async def wait_writable(self): # etc. # # NOTE: this function does not always checkpoint - async def _resolve_address_nocp(self, address, flags): + async def _resolve_address_nocp( + self, address: Union[_Address, bytes], flags: int + ) -> Union[_Address, bytes]: # Do some pre-checking (or exit early for non-IP sockets) if self._sock.family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -702,7 +742,7 @@ async def _resolve_address_nocp(self, address, flags): ) elif self._sock.family == _stdlib_socket.AF_UNIX: # unwrap path-likes - return os.fspath(address) + return os.fspath(address) # type: ignore[arg-type] else: return address @@ -740,6 +780,7 @@ async def _resolve_address_nocp(self, address, flags): # empty list. assert len(gai_res) >= 1 # Address is the last item in the first entry + normed: Union[List, Tuple] (*_, normed), *_ = gai_res # The above ignored any flowid and scopeid in the passed-in address, # so restore them if present: @@ -756,13 +797,17 @@ async def _resolve_address_nocp(self, address, flags): # Returns something appropriate to pass to bind() # # NOTE: this function does not always checkpoint - async def _resolve_local_address_nocp(self, address): + async def _resolve_local_address_nocp( + self, address: Union[_Address, bytes] + ) -> Union[_Address, bytes]: return await self._resolve_address_nocp(address, _stdlib_socket.AI_PASSIVE) # Returns something appropriate to pass to connect()/sendto()/sendmsg() # # NOTE: this function does not always checkpoint - async def _resolve_remote_address_nocp(self, address): + async def _resolve_remote_address_nocp( + self, address: Union[_Address, bytes] + ) -> Union[_Address, bytes]: return await self._resolve_address_nocp(address, 0) async def _nonblocking_helper( @@ -809,7 +854,7 @@ async def _nonblocking_helper( _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable) - async def accept(self): + async def accept(self) -> Tuple["SocketType", _Address]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -818,7 +863,7 @@ async def accept(self): # connect ################################################################ - async def connect(self, address): + async def connect(self, address: Union[_Address, bytes]) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -956,7 +1001,8 @@ async def sendto(self, *args: object) -> int: # args is: data[, flags], address) # and kwargs are not accepted list_args = list(args) - list_args[-1] = await self._resolve_remote_address_nocp(list_args[-1]) + address: _Address = list_args[-1] # type: ignore[assignment] + list_args[-1] = await self._resolve_remote_address_nocp(address) return await self._nonblocking_helper( _stdlib_socket.socket.sendto, list_args, {}, _core.wait_writable ) @@ -980,7 +1026,8 @@ async def sendmsg(self, *args: object) -> int: # args is: buffers[, ancdata[, flags[, address]]] # and kwargs are not accepted if len(args) == 4 and args[-1] is not None: - args = (*args[:-1], await self._resolve_remote_address_nocp(args[-1])) + address: _Address = args[-1] # type: ignore[assignment] + args = (*args[:-1], await self._resolve_remote_address_nocp(address)) return await self._nonblocking_helper( _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable ) diff --git a/trio/_timeouts.py b/trio/_timeouts.py index 1683b6157a..9568ec7cf3 100644 --- a/trio/_timeouts.py +++ b/trio/_timeouts.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Iterator +from typing import ContextManager, Iterator import trio @@ -15,7 +15,7 @@ def move_on_at(deadline: float) -> trio.CancelScope: return trio.CancelScope(deadline=deadline) -def move_on_after(seconds): +def move_on_after(seconds: float) -> trio.CancelScope: """Use as a context manager to create a cancel scope whose deadline is set to now + *seconds*. @@ -32,7 +32,7 @@ def move_on_after(seconds): return move_on_at(trio.current_time() + seconds) -async def sleep_forever(): +async def sleep_forever() -> None: """Pause execution of the current task forever (or until cancelled). Equivalent to calling ``await sleep(math.inf)``. @@ -41,7 +41,7 @@ async def sleep_forever(): await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED) -async def sleep_until(deadline): +async def sleep_until(deadline: float) -> None: """Pause execution of the current task until the given time. The difference between :func:`sleep` and :func:`sleep_until` is that the @@ -58,7 +58,7 @@ async def sleep_until(deadline): await sleep_forever() -async def sleep(seconds): +async def sleep(seconds: float) -> None: """Pause execution of the current task for the given number of seconds. Args: @@ -109,7 +109,7 @@ def fail_at(deadline: float) -> Iterator[trio.CancelScope]: raise TooSlowError -def fail_after(seconds): +def fail_after(seconds: float) -> ContextManager[trio.CancelScope]: """Creates a cancel scope with the given timeout, and raises an error if it is actually cancelled. diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index b31f5b36ea..77041c570f 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -81,7 +81,9 @@ def is_public(node: ast.AST) -> bool: return False -def get_public_methods(tree: ast.AST) -> Iterator[Union[ast.FunctionDef, ast.AsyncFunctionDef]]: +def get_public_methods( + tree: ast.AST, +) -> Iterator[Union[ast.FunctionDef, ast.AsyncFunctionDef]]: """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked @@ -95,7 +97,9 @@ def get_public_methods(tree: ast.AST) -> Iterator[Union[ast.FunctionDef, ast.Asy yield node -def create_passthrough_args(funcdef: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> str: +def create_passthrough_args( + funcdef: Union[ast.FunctionDef, ast.AsyncFunctionDef] +) -> str: """Given a function definition, create a string that represents taking all the arguments from the function, and passing them through to another invocation of the same function. @@ -187,7 +191,9 @@ def matches_disk_files(new_files: Dict[str, str]) -> bool: return True -def process(sources_and_lookups: List[Tuple[Union[Path, str], str]], *, do_test: bool) -> None: +def process( + sources_and_lookups: List[Tuple[Union[Path, str], str]], *, do_test: bool +) -> None: new_files = {} for source_path, lookup_path in sources_and_lookups: print("Scanning:", source_path) diff --git a/trio/_typing.py b/trio/_typing.py index fc614f1613..8a62b9bfa8 100644 --- a/trio/_typing.py +++ b/trio/_typing.py @@ -1,6 +1,13 @@ +from typing import Union + from typing_extensions import Protocol +from ._core._run import _TaskStatus, _TaskStatusIgnored + class _HasFileno(Protocol): def fileno(self) -> int: ... + + +TaskStatus = Union[_TaskStatus, _TaskStatusIgnored] diff --git a/trio/_util.py b/trio/_util.py index 653cc0a254..8ec5393bc9 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -313,7 +313,12 @@ class SomeClass(metaclass=Final): - TypeError if a sub class is created """ - def __new__(cls: t.Type[_T], name: str, bases: t.Tuple[type], cls_namespace: t.Dict[str, object]) -> _T: + def __new__( + cls: t.Type[_T], + name: str, + bases: t.Tuple[type], + cls_namespace: t.Dict[str, object], + ) -> _T: for base in bases: if isinstance(base, Final): raise TypeError( diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/tests/test_highlevel_ssl_helpers.py index ffd8bca5f4..33b08d85f9 100644 --- a/trio/tests/test_highlevel_ssl_helpers.py +++ b/trio/tests/test_highlevel_ssl_helpers.py @@ -43,7 +43,9 @@ async def getnameinfo(self, *args): # pragma: no cover # This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... # noqa is needed because flake8 doesn't understand how pytest fixtures work. -async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx) -> None: # noqa: F811 +async def test_open_ssl_over_tcp_stream_and_everything_else( + client_ctx, +) -> None: # noqa: F811 async with trio.open_nursery() as nursery: (listener,) = await nursery.start( partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") From a0abc2d115dceacefe4a9baa3b51c07f3e7d1d90 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sun, 14 Feb 2021 21:07:44 -0500 Subject: [PATCH 45/50] flake8 --- trio/tests/test_highlevel_ssl_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/tests/test_highlevel_ssl_helpers.py index 33b08d85f9..2c1972b15f 100644 --- a/trio/tests/test_highlevel_ssl_helpers.py +++ b/trio/tests/test_highlevel_ssl_helpers.py @@ -7,7 +7,7 @@ import trio from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP import trio.testing -from .test_ssl import client_ctx, SERVER_CTX +from .test_ssl import SERVER_CTX from .._highlevel_ssl_helpers import ( open_ssl_over_tcp_stream, @@ -45,7 +45,7 @@ async def getnameinfo(self, *args): # pragma: no cover # noqa is needed because flake8 doesn't understand how pytest fixtures work. async def test_open_ssl_over_tcp_stream_and_everything_else( client_ctx, -) -> None: # noqa: F811 +) -> None: async with trio.open_nursery() as nursery: (listener,) = await nursery.start( partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") From 8ca2719615fd3fa08c8c5bc5b44e45d429351359 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sun, 14 Feb 2021 21:12:21 -0500 Subject: [PATCH 46/50] flake8 --- trio/tests/test_highlevel_ssl_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/tests/test_highlevel_ssl_helpers.py index 2c1972b15f..e9cee31cdc 100644 --- a/trio/tests/test_highlevel_ssl_helpers.py +++ b/trio/tests/test_highlevel_ssl_helpers.py @@ -7,7 +7,7 @@ import trio from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP import trio.testing -from .test_ssl import SERVER_CTX +from .test_ssl import client_ctx, SERVER_CTX from .._highlevel_ssl_helpers import ( open_ssl_over_tcp_stream, @@ -44,7 +44,7 @@ async def getnameinfo(self, *args): # pragma: no cover # This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... # noqa is needed because flake8 doesn't understand how pytest fixtures work. async def test_open_ssl_over_tcp_stream_and_everything_else( - client_ctx, + client_ctx, # noqa: F811 ) -> None: async with trio.open_nursery() as nursery: (listener,) = await nursery.start( From ca44f08315a25977c5e49d6f8e4941b184cb7419 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sun, 14 Feb 2021 22:22:45 -0500 Subject: [PATCH 47/50] -> None (and more) --- trio/_abc.py | 27 ++-- trio/_channel.py | 4 +- trio/_core/_asyncgens.py | 2 +- trio/_core/_entry_queue.py | 4 +- trio/_core/_io_epoll.py | 6 +- trio/_core/_io_kqueue.py | 6 +- trio/_core/_io_windows.py | 6 +- trio/_core/_ki.py | 2 +- trio/_core/_mock_clock.py | 6 +- trio/_core/_run.py | 42 +++--- trio/_core/_traps.py | 2 +- trio/_core/_wakeup_socketpair.py | 10 +- trio/_core/tests/test_asyncgen.py | 10 +- trio/_core/tests/test_guest_mode.py | 4 +- trio/_core/tests/test_instrumentation.py | 30 ++-- trio/_core/tests/test_io.py | 10 +- trio/_core/tests/test_ki.py | 34 ++--- trio/_core/tests/test_local.py | 10 +- trio/_core/tests/test_mock_clock.py | 8 +- trio/_core/tests/test_multierror.py | 16 +- trio/_core/tests/test_run.py | 142 +++++++++--------- trio/_core/tests/test_thread_cache.py | 6 +- trio/_core/tests/test_unbounded_queue.py | 4 +- trio/_core/tests/test_windows.py | 2 +- trio/_core/tests/tutil.py | 2 +- trio/_file_io.py | 2 +- trio/_highlevel_generic.py | 2 +- trio/_highlevel_open_unix_stream.py | 2 +- trio/_highlevel_socket.py | 8 +- trio/_signals.py | 4 +- trio/_ssl.py | 14 +- trio/_subprocess.py | 8 +- trio/_sync.py | 10 +- trio/_threads.py | 2 +- trio/_unix_pipes.py | 8 +- trio/_windows_pipes.py | 10 +- trio/testing/_check_streams.py | 10 +- trio/testing/_memory_streams.py | 36 ++--- trio/tests/test_abc.py | 6 +- trio/tests/test_channel.py | 8 +- trio/tests/test_deprecate.py | 10 +- trio/tests/test_file_io.py | 6 +- trio/tests/test_highlevel_generic.py | 12 +- .../test_highlevel_open_tcp_listeners.py | 2 +- trio/tests/test_highlevel_open_tcp_stream.py | 6 +- trio/tests/test_highlevel_open_unix_stream.py | 2 +- trio/tests/test_highlevel_serve_listeners.py | 8 +- trio/tests/test_highlevel_socket.py | 6 +- trio/tests/test_signals.py | 4 +- trio/tests/test_socket.py | 12 +- trio/tests/test_ssl.py | 48 +++--- trio/tests/test_subprocess.py | 4 +- trio/tests/test_sync.py | 24 +-- trio/tests/test_testing.py | 36 ++--- trio/tests/test_threads.py | 36 ++--- trio/tests/test_timeouts.py | 8 +- trio/tests/test_unix_pipes.py | 10 +- trio/tests/test_util.py | 6 +- trio/tests/test_windows_pipes.py | 6 +- 59 files changed, 386 insertions(+), 385 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index ad136ea6e0..c828faec4f 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -4,9 +4,10 @@ from typing import Generic, List, Optional, Text, Tuple, TYPE_CHECKING, TypeVar, Union import socket import trio - +from ._core import _run _T = TypeVar("_T") +_TSelf = TypeVar("_TSelf") # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a @@ -71,13 +72,13 @@ class Instrument(metaclass=ABCMeta): __slots__ = () - def before_run(self): + def before_run(self) -> None: """Called at the beginning of :func:`trio.run`.""" - def after_run(self): + def after_run(self) -> None: """Called just before :func:`trio.run` returns.""" - def task_spawned(self, task): + def task_spawned(self, task: "_run.Task") -> None: """Called when the given task is created. Args: @@ -85,7 +86,7 @@ def task_spawned(self, task): """ - def task_scheduled(self, task): + def task_scheduled(self, task: "_run.Task") -> None: """Called when the given task becomes runnable. It may still be some time before it actually runs, if there are other @@ -96,7 +97,7 @@ def task_scheduled(self, task): """ - def before_task_step(self, task): + def before_task_step(self, task: "Task") -> None: """Called immediately before we resume running the given task. Args: @@ -104,7 +105,7 @@ def before_task_step(self, task): """ - def after_task_step(self, task): + def after_task_step(self, task: "Task") -> None: """Called when we return to the main run loop after a task has yielded. Args: @@ -112,7 +113,7 @@ def after_task_step(self, task): """ - def task_exited(self, task): + def task_exited(self, task: "Task") -> None: """Called when the given task exits. Args: @@ -120,7 +121,7 @@ def task_exited(self, task): """ - def before_io_wait(self, timeout): + def before_io_wait(self, timeout: float) -> None: """Called before blocking to wait for I/O readiness. Args: @@ -128,7 +129,7 @@ def before_io_wait(self, timeout): """ - def after_io_wait(self, timeout): + def after_io_wait(self, timeout: float) -> None: """Called after handling pending I/O. Args: @@ -443,10 +444,10 @@ async def receive_some(self, max_bytes: Optional[int] = ...) -> bytes: """ - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): + async def __anext__(self) -> bytes: data = await self.receive_some() if not data: raise StopAsyncIteration @@ -662,7 +663,7 @@ async def receive(self) -> ReceiveType: """ - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self async def __anext__(self) -> ReceiveType: diff --git a/trio/_channel.py b/trio/_channel.py index eb6a211dad..b37a003d7e 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -129,7 +129,7 @@ class MemorySendChannel(SendChannel[_T_contra], metaclass=NoPublicConstructor): # all clones. _tasks: Set[Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_send_channels += 1 def __repr__(self) -> str: @@ -264,7 +264,7 @@ class MemoryReceiveChannel(ReceiveChannel[_T_co], metaclass=NoPublicConstructor) _closed: bool = attr.ib(default=False) _tasks: Set[Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_receive_channels += 1 def statistics(self): diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py index cb88a1d57b..284dc24e54 100644 --- a/trio/_core/_asyncgens.py +++ b/trio/_core/_asyncgens.py @@ -173,7 +173,7 @@ async def finalize_remaining(self, runner): for agen in batch: await self._finalize_one(agen, name_asyncgen(agen)) - def close(self): + def close(self) -> None: sys.set_asyncgen_hooks(*self.prev_hooks) async def _finalize_one(self, agen, name): diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py index bf887d8b7f..c80f6ba7df 100644 --- a/trio/_core/_entry_queue.py +++ b/trio/_core/_entry_queue.py @@ -70,7 +70,7 @@ async def kill_everything(exc): # This has to be carefully written to be safe in the face of new items # being queued while we iterate, and to do a bounded amount of work on # each pass: - def run_all_bounded(): + def run_all_bounded() -> None: for _ in range(len(self.queue)): run_cb(self.queue.popleft()) for job in list(self.idempotent_queue): @@ -104,7 +104,7 @@ def run_all_bounded(): assert not self.queue assert not self.idempotent_queue - def close(self): + def close(self) -> None: self.wakeup.close() def size(self): diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 3f6327347e..55794d23c4 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -198,7 +198,7 @@ class EpollIOManager: _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) _force_wakeup_fd: int = attr.ib(default=None) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() @@ -215,11 +215,11 @@ def statistics(self): tasks_waiting_write=tasks_waiting_write, ) - def close(self): + def close(self) -> None: self._epoll.close() self._force_wakeup.close() - def force_wakeup(self): + def force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() # Return value must be False-y IFF the timeout expired, NOT if any I/O diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 48848e16dc..66a1847c5e 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -34,7 +34,7 @@ class KqueueIOManager: _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) _force_wakeup_fd: Optional[int] = attr.ib(default=None) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: force_wakeup_event = select.kevent( self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD ) @@ -51,11 +51,11 @@ def statistics(self): monitors += 1 return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors) - def close(self): + def close(self) -> None: self._kqueue.close() self._force_wakeup.close() - def force_wakeup(self): + def force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() def get_events(self, timeout): diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 19f06546d4..6cb2656e17 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -450,7 +450,7 @@ def __init__(self) -> None: "netsh winsock show catalog" ) - def close(self): + def close(self) -> None: try: if self._iocp is not None: iocp = self._iocp @@ -461,7 +461,7 @@ def close(self): afd_handle = self._all_afd_handles.pop() _check(kernel32.CloseHandle(afd_handle)) - def __del__(self): + def __del__(self) -> None: self.close() def statistics(self): @@ -479,7 +479,7 @@ def statistics(self): completion_key_monitors=len(self._completion_key_queues), ) - def force_wakeup(self): + def force_wakeup(self) -> None: _check( kernel32.PostQueuedCompletionStatus( self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index ee0dab4b04..a292b9d2fa 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -195,7 +195,7 @@ def handler(signum, frame): self.handler = handler signal.signal(signal.SIGINT, handler) - def close(self): + def close(self) -> None: if self.handler is not None: if signal.getsignal(signal.SIGINT) is self.handler: signal.signal(signal.SIGINT, signal.default_int_handler) diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index afd8517d4f..c86afc6643 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -112,7 +112,7 @@ def autojump_threshold(self, new_autojump_threshold): # API. Discussion: # # https://github.com/python-trio/trio/issues/1587 - def _try_resync_autojump_threshold(self): + def _try_resync_autojump_threshold(self) -> None: try: runner = GLOBAL_RUN_CONTEXT.runner if runner.is_guest: @@ -124,7 +124,7 @@ def _try_resync_autojump_threshold(self): # Invoked by the run loop when runner.clock_autojump_threshold is # exceeded. - def _autojump(self): + def _autojump(self) -> None: statistics = _core.current_statistics() jump = statistics.seconds_to_next_deadline if 0 < jump < inf: @@ -135,7 +135,7 @@ def _real_to_virtual(self, real): virtual_offset = self._rate * real_offset return self._virtual_base + virtual_offset - def start_clock(self): + def start_clock(self) -> None: self._try_resync_autojump_threshold() def current_time(self): diff --git a/trio/_core/_run.py b/trio/_core/_run.py index ecb64f0ed3..3e2ff79c01 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -106,7 +106,7 @@ def _public(fn: _Fn) -> _Fn: # # This can all be removed once we drop support for 3.6. def _count_context_run_tb_frames(): - def function_with_unique_name_xyzzy(): + def function_with_unique_name_xyzzy() -> None: 1 / 0 ctx = copy_context() @@ -133,7 +133,7 @@ class SystemClock: # between different runs, then they'll notice the bug quickly: offset = attr.ib(factory=lambda: _r.uniform(10000, 200000)) - def start_clock(self): + def start_clock(self) -> None: pass # In cPython 3, on every platform except Windows, perf_counter is @@ -187,7 +187,7 @@ def next_deadline(self): heappop(self._heap) return inf - def _prune(self): + def _prune(self) -> None: # In principle, it's possible for a cancel scope to toggle back and # forth repeatedly between the same two deadlines, and end up with # lots of stale entries that *look* like they're still active, because @@ -301,7 +301,7 @@ class CancelStatus: # recovery to show a useful traceback). abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self._parent is not None: self._parent._children.add(self) self.recalculate() @@ -339,7 +339,7 @@ def encloses(self, other): other = other.parent return False - def close(self): + def close(self) -> None: self.parent = None # now we're not a child of self.parent anymore if self._tasks or self._children: # Cancel scopes weren't exited in opposite order of being @@ -375,7 +375,7 @@ def parent_cancellation_is_visible_to_us(self) -> bool: and self._parent.effectively_cancelled ) - def recalculate(self): + def recalculate(self) -> None: # This does a depth-first traversal over this and descendent cancel # statuses, to ensure their state is up-to-date. It's basically a # recursive algorithm, but we use an explicit stack to avoid any @@ -394,7 +394,7 @@ def recalculate(self): task._attempt_delivery_of_any_pending_cancel() todo.extend(current._children) - def _mark_abandoned(self): + def _mark_abandoned(self) -> None: self.abandoned_by_misnesting = True for child in self._children: child._mark_abandoned() @@ -866,12 +866,12 @@ async def __aexit__( assert value is combined_error_from_nursery value.__context__ = old_context - def __enter__(self): + def __enter__(self) -> None: raise RuntimeError( "use 'async with open_nursery(...)', not 'with open_nursery(...)'" ) - def __exit__(self): # pragma: no cover + def __exit__(self) -> None: # pragma: no cover assert False, """Never called, but should be defined""" @@ -943,7 +943,7 @@ def _add_exc(self, exc): self._pending_excs.append(exc) self.cancel_scope.cancel() - def _check_nursery_closed(self): + def _check_nursery_closed(self) -> None: if not any([self._nested_child_running, self._children, self._pending_starts]): self._closed = True if self._parent_waiting_in_aexit: @@ -1099,7 +1099,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): self._pending_starts -= 1 self._check_nursery_closed() - def __del__(self): + def __del__(self) -> None: assert not self._children @@ -1210,23 +1210,23 @@ def _attempt_abort(self, raise_cancel): if success is Abort.SUCCEEDED: self._runner.reschedule(self, capture(raise_cancel)) - def _attempt_delivery_of_any_pending_cancel(self): + def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return if not self._cancel_status.effectively_cancelled: return - def raise_cancel(): + def raise_cancel() -> None: raise Cancelled._create() self._attempt_abort(raise_cancel) - def _attempt_delivery_of_pending_ki(self): + def _attempt_delivery_of_pending_ki(self) -> None: assert self._runner.ki_pending if self._abort_func is None: return - def raise_cancel(): + def raise_cancel() -> None: self._runner.ki_pending = False raise KeyboardInterrupt @@ -1311,7 +1311,7 @@ def get_events(): return self.runner.io_manager.get_events(timeout) def deliver(events_outcome): - def in_main_thread(): + def in_main_thread() -> None: self.unrolled_run_next_send = events_outcome self.runner.guest_tick_scheduled = True self.guest_tick() @@ -1356,13 +1356,13 @@ class Runner: is_guest: bool = attr.ib(default=False) guest_tick_scheduled: bool = attr.ib(default=False) - def force_guest_tick_asap(self): + def force_guest_tick_asap(self) -> None: if self.guest_tick_scheduled: return self.guest_tick_scheduled = True self.io_manager.force_wakeup() - def close(self): + def close(self) -> None: self.io_manager.close() self.entry_queue.close() self.asyncgens.close() @@ -1713,14 +1713,14 @@ def current_trio_token(self) -> TrioToken: # keep the class public so people can isinstance() it if they want. # This gets called from signal context - def deliver_ki(self): + def deliver_ki(self) -> None: self.ki_pending = True try: self.entry_queue.run_sync_soon(self._deliver_ki_cb) except RunFinishedError: pass - def _deliver_ki_cb(self): + def _deliver_ki_cb(self) -> None: if not self.ki_pending: return # Can't happen because main_task and run_sync_soon_task are created at @@ -2414,7 +2414,7 @@ async def checkpoint(): await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """Issue a :ref:`checkpoint ` if the calling context has been cancelled. diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 163a98da25..341490a201 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -239,7 +239,7 @@ async def temporarily_detach_coroutine_object(abort_func: AbortFunc) -> None: uses to resume the coroutine. """ - return await _async_yield(WaitTaskRescheduled(abort_func)) + return await _async_yield(WaitTaskRescheduled(abort_func)) # type: ignore[no-any-return] async def reattach_detached_coroutine_object( diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index 77abc2f2ed..ab4d5973fd 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -53,24 +53,24 @@ def __init__(self) -> None: pass self.old_wakeup_fd = None - def wakeup_thread_and_signal_safe(self): + def wakeup_thread_and_signal_safe(self) -> None: try: self.write_sock.send(b"\x00") except BlockingIOError: pass - async def wait_woken(self): + async def wait_woken(self) -> None: await _core.wait_readable(self.wakeup_sock) self.drain() - def drain(self): + def drain(self) -> None: try: while True: self.wakeup_sock.recv(2 ** 16) except BlockingIOError: pass - def wakeup_on_signals(self): + def wakeup_on_signals(self) -> None: assert self.old_wakeup_fd is None if not is_main_thread(): return @@ -91,7 +91,7 @@ def wakeup_on_signals(self): ) ) - def close(self): + def close(self) -> None: self.wakeup_sock.close() self.write_sock.close() if self.old_wakeup_fd is not None: diff --git a/trio/_core/tests/test_asyncgen.py b/trio/_core/tests/test_asyncgen.py index fd4e26d392..41f4dd172f 100644 --- a/trio/_core/tests/test_asyncgen.py +++ b/trio/_core/tests/test_asyncgen.py @@ -37,7 +37,7 @@ async def example(cause): saved = [] - async def async_main(): + async def async_main() -> None: # GC'ed before exhausted with pytest.warns( ResourceWarning, match="Async generator.*collected before.*exhausted" @@ -123,7 +123,7 @@ async def funky_agen(): record.append("cleanup 2") await funky_agen().asend(None) - async def async_main(): + async def async_main() -> None: aiter = funky_agen() saved.append(aiter) assert 1 == await aiter.asend(None) @@ -157,7 +157,7 @@ async def agen(label, inner): await inner.asend(None) record.append(label) - async def async_main(): + async def async_main() -> None: # This makes a chain of 101 interdependent asyncgens: # agen(99)'s cleanup will iterate agen(98)'s will iterate # ... agen(0)'s will iterate innermost()'s @@ -199,7 +199,7 @@ def collect_at_opportune_moment(token): nonlocal needs_retry needs_retry = True - async def async_main(): + async def async_main() -> None: token = _core.current_trio_token() token.run_sync_soon(collect_at_opportune_moment, token) saved.append(agen()) @@ -303,7 +303,7 @@ async def example(arg): await _core.checkpoint() record.append("trio collected " + arg) - async def async_main(): + async def async_main() -> None: await step_outside_async_context(example("theirs")) assert 42 == await example("ours").asend(None) gc_collect_harder() diff --git a/trio/_core/tests/test_guest_mode.py b/trio/_core/tests/test_guest_mode.py index 707688c739..920b04797c 100644 --- a/trio/_core/tests/test_guest_mode.py +++ b/trio/_core/tests/test_guest_mode.py @@ -92,7 +92,7 @@ async def trio_main(in_host): with a, b: async with trio.open_nursery() as nursery: - async def do_receive(): + async def do_receive() -> None: record.append(await a.recv(1)) nursery.start_soon(do_receive) @@ -523,7 +523,7 @@ async def agen(label): pass record.add((label, library)) - async def iterate_in_aio(): + async def iterate_in_aio() -> None: # "trio" gets inherited from our Trio caller if we don't set this sniffio.current_async_library_cvar.set("asyncio") await agen("asyncio").asend(None) diff --git a/trio/_core/tests/test_instrumentation.py b/trio/_core/tests/test_instrumentation.py index 2666e6bff3..f50405ccf2 100644 --- a/trio/_core/tests/test_instrumentation.py +++ b/trio/_core/tests/test_instrumentation.py @@ -8,7 +8,7 @@ class TaskRecorder: record = attr.ib(factory=list) - def before_run(self): + def before_run(self) -> None: self.record.append(("before_run",)) def task_scheduled(self, task): @@ -22,7 +22,7 @@ def after_task_step(self, task): assert task is _core.current_task() self.record.append(("after", task)) - def after_run(self): + def after_run(self) -> None: self.record.append(("after_run",)) def filter_tasks(self, tasks): @@ -43,7 +43,7 @@ def test_instruments(recwarn) -> None: # We use a child task for this, because the main task does some extra # bookkeeping stuff that can leak into the instrument results, and we # don't want to deal with it. - async def task_fn(): + async def task_fn() -> None: nonlocal task task = _core.current_task() @@ -59,7 +59,7 @@ async def task_fn(): for _ in range(1): await _core.checkpoint() - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(task_fn) @@ -80,15 +80,15 @@ async def main(): def test_instruments_interleave() -> None: tasks = {} - async def two_step1(): + async def two_step1() -> None: tasks["t1"] = _core.current_task() await _core.checkpoint() - async def two_step2(): + async def two_step2() -> None: tasks["t2"] = _core.current_task() await _core.checkpoint() - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(two_step1) nursery.start_soon(two_step2) @@ -123,10 +123,10 @@ async def main(): def test_null_instrument() -> None: # undefined instrument methods are skipped class NullInstrument: - def something_unrelated(self): + def something_unrelated(self) -> None: pass # pragma: no cover - async def main(): + async def main() -> None: await _core.checkpoint() _core.run(main, instruments=[NullInstrument()]) @@ -136,13 +136,13 @@ def test_instrument_before_after_run() -> None: record = [] class BeforeAfterRun: - def before_run(self): + def before_run(self) -> None: record.append("before_run") - def after_run(self): + def after_run(self) -> None: record.append("after_run") - async def main(): + async def main() -> None: pass _core.run(main, instruments=[BeforeAfterRun()]) @@ -177,7 +177,7 @@ def task_scheduled(self, task): record.append("scheduled") raise ValueError("oops") - def close(self): + def close(self) -> None: # Shouldn't be called -- tests that the instrument disabling logic # works right. record.append("closed") # pragma: no cover @@ -206,7 +206,7 @@ class NullInstrument(_abc.Instrument): instrument = NullInstrument() - async def main(): + async def main() -> None: record = [] # Changing the set of hooks implemented by an instrument after @@ -241,7 +241,7 @@ def task_exited(self, task): def after_run(self) -> None: raise ValueError("oops") - async def main(): + async def main() -> None: with pytest.raises(ValueError): _core.add_instrument(EvilInstrument()) diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index 1418df4080..f8ce11a1b8 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -108,7 +108,7 @@ async def test_wait_basic( # But readable() blocks until data arrives record = [] - async def block_on_read(): + async def block_on_read() -> None: try: with assert_checkpoints(): await wait_readable(a) @@ -131,7 +131,7 @@ async def block_on_read(): await wait_readable(b) record = [] - async def block_on_write(): + async def block_on_write() -> None: try: with assert_checkpoints(): await wait_writable(a) @@ -205,11 +205,11 @@ async def test_interrupted_by_close( ) -> None: a, b = socketpair - async def reader(): + async def reader() -> None: with pytest.raises(_core.ClosedResourceError): await wait_readable(a) - async def writer(): + async def writer() -> None: with pytest.raises(_core.ClosedResourceError): await wait_writable(a) @@ -459,7 +459,7 @@ async def allow_OSError(async_func, *args): # sleep waiting on 'a2', with the idea that the 'a2' notification will # definitely arrive, and when it does then we can assume that whatever # notification was going to arrive for 'a' has also arrived. - async def wait_readable_a2_then_set(): + async def wait_readable_a2_then_set() -> None: await trio.lowlevel.wait_readable(a2) e.set() diff --git a/trio/_core/tests/test_ki.py b/trio/_core/tests/test_ki.py index f6a5d88f3a..e849cf14ef 100644 --- a/trio/_core/tests/test_ki.py +++ b/trio/_core/tests/test_ki.py @@ -22,7 +22,7 @@ from .tutil import slow -def ki_self(): +def ki_self() -> None: signal_raise(signal.SIGINT) @@ -39,7 +39,7 @@ async def test_ki_enabled() -> None: token = _core.current_trio_token() record: Any = [] - def check(): + def check() -> None: record.append(_core.currently_ki_protected()) token.run_sync_soon(check) @@ -146,7 +146,7 @@ def protected_manager() -> Iterator[None]: async def test_agen_protection() -> None: @_core.enable_ki_protection @async_generator - async def agen_protected1(): # type: ignore[misc] + async def agen_protected1() -> None: # type: ignore[misc] assert _core.currently_ki_protected() try: await yield_() @@ -155,7 +155,7 @@ async def agen_protected1(): # type: ignore[misc] @_core.disable_ki_protection @async_generator - async def agen_unprotected1(): # type: ignore[misc] + async def agen_unprotected1() -> None: # type: ignore[misc] assert not _core.currently_ki_protected() try: await yield_() @@ -165,7 +165,7 @@ async def agen_unprotected1(): # type: ignore[misc] # Swap the order of the decorators: @async_generator @_core.enable_ki_protection - async def agen_protected2(): # type: ignore[misc] + async def agen_protected2() -> None: # type: ignore[misc] assert _core.currently_ki_protected() try: await yield_() @@ -174,7 +174,7 @@ async def agen_protected2(): # type: ignore[misc] @async_generator @_core.disable_ki_protection - async def agen_unprotected2(): # type: ignore[misc] + async def agen_unprotected2() -> None: # type: ignore[misc] assert not _core.currently_ki_protected() try: await yield_() @@ -231,7 +231,7 @@ def test_ki_disabled_in_del() -> None: def nestedfunction(): return _core.currently_ki_protected() - def __del__(): + def __del__() -> None: assert _core.currently_ki_protected() assert nestedfunction() @@ -279,7 +279,7 @@ async def raiser(name, record): print("check 1") record: Any = set() - async def check_unprotected_kill(): + async def check_unprotected_kill() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record) nursery.start_soon(sleeper, "s2", record) @@ -294,7 +294,7 @@ async def check_unprotected_kill(): print("check 2") record = set() - async def check_protected_kill(): + async def check_protected_kill() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record) nursery.start_soon(sleeper, "s2", record) @@ -309,10 +309,10 @@ async def check_protected_kill(): # error, then kill) print("check 3") - async def check_kill_during_shutdown(): + async def check_kill_during_shutdown() -> None: token = _core.current_trio_token() - def kill_during_shutdown(): + def kill_during_shutdown() -> None: assert _core.currently_ki_protected() try: token.run_sync_soon(kill_during_shutdown) @@ -330,10 +330,10 @@ def kill_during_shutdown(): print("check 4") class InstrumentOfDeath: - def before_run(self): + def before_run(self) -> None: ki_self() - async def main(): + async def main() -> None: await _core.checkpoint() with pytest.raises(KeyboardInterrupt): @@ -423,7 +423,7 @@ async def main_e() -> None: # restrict_keyboard_interrupt_to_checkpoints=True record = [] - async def main_f(): + async def main_f() -> None: # We're not KI protected... assert not _core.currently_ki_protected() ki_self() @@ -471,7 +471,7 @@ def test_ki_is_good_neighbor() -> None: def my_handler(signum, frame): # pragma: no cover pass - async def main(): + async def main() -> None: signal.signal(signal.SIGINT, my_handler) _core.run(main) @@ -563,7 +563,7 @@ def test_ki_wakes_us_up() -> None: # It will be very nice when the buggy_wakeup_fd bug is fixed. lock = threading.Lock() - def kill_soon(): + def kill_soon() -> None: # We want the signal to be raised after the main thread has entered # the IO manager blocking primitive. There really is no way to # deterministically interlock with that, so we have to use sleep and @@ -576,7 +576,7 @@ def kill_soon(): print("buggy_wakeup_fd =", buggy_wakeup_fd) ki_self() - async def main(): + async def main() -> None: thread = threading.Thread(target=kill_soon) print("Starting thread") thread.start() diff --git a/trio/_core/tests/test_local.py b/trio/_core/tests/test_local.py index 65a399e298..27536fb146 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/tests/test_local.py @@ -10,7 +10,7 @@ def test_runvar_smoketest() -> None: assert "RunVar" in repr(t1) - async def first_check(): + async def first_check() -> None: with pytest.raises(LookupError): t1.get() @@ -23,7 +23,7 @@ async def first_check(): assert t2.get() == "goldfish" assert t2.get(default="tuna") == "goldfish" - async def second_check(): + async def second_check() -> None: with pytest.raises(LookupError): t1.get() @@ -38,7 +38,7 @@ def test_runvar_resetting() -> None: t2 = _core.RunVar("test2", default="dogfish") t3 = _core.RunVar("test3") - async def reset_check(): + async def reset_check() -> None: token = t1.set("moonfish") assert t1.get() == "moonfish" t1.reset(token) @@ -69,8 +69,8 @@ async def reset_check(): def test_runvar_sync() -> None: t1 = _core.RunVar("test1") - async def sync_check(): - async def task1(): + async def sync_check() -> None: + async def task1() -> None: t1.set("plaice") assert t1.get() == "plaice" diff --git a/trio/_core/tests/test_mock_clock.py b/trio/_core/tests/test_mock_clock.py index 7242497b58..321bf94b42 100644 --- a/trio/_core/tests/test_mock_clock.py +++ b/trio/_core/tests/test_mock_clock.py @@ -129,11 +129,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock) -> record = [] - async def sleeper(): + async def sleeper() -> None: await sleep(100) record.append("yawn") - async def waiter(): + async def waiter() -> None: await wait_all_tasks_blocked() record.append("waiter woke") await sleep(1000) @@ -157,11 +157,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero( record = [] - async def sleeper(): + async def sleeper() -> None: await sleep(100) record.append("yawn") - async def waiter(): + async def waiter() -> None: await wait_all_tasks_blocked(1) record.append("waiter done") diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/tests/test_multierror.py index b2f1e26a0c..2872a4266b 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/tests/test_multierror.py @@ -36,27 +36,27 @@ async def raise_nothashable(code): raise NotHashableException(code) -def raiser1(): +def raiser1() -> None: raiser1_2() -def raiser1_2(): +def raiser1_2() -> None: raiser1_3() -def raiser1_3(): +def raiser1_3() -> None: raise ValueError("raiser1_string") -def raiser2(): +def raiser2() -> None: raiser2_2() -def raiser2_2(): +def raiser2_2() -> None: raise KeyError("raiser2_string") -def raiser3(): +def raiser3() -> None: raise NameError @@ -506,13 +506,13 @@ def test_format_exception() -> None: # Prints duplicate exceptions in sub-exceptions exc1 = get_exc(raiser1) - def raise1_raiser1(): + def raise1_raiser1() -> None: try: raise exc1 except: raise ValueError("foo") - def raise2_raiser1(): + def raise2_raiser1() -> None: try: raise exc1 except: diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index c0fcc0dfb1..1990449841 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -79,7 +79,7 @@ async def main(x): def test_run_nesting() -> None: async def inception(): - async def main(): # pragma: no cover + async def main() -> None: # pragma: no cover pass return _core.run(main) @@ -115,7 +115,7 @@ async def test_nursery_main_block_error_basic() -> None: async def test_child_crash_basic() -> None: exc = ValueError("uh oh") - async def erroring(): + async def erroring() -> None: raise exc try: @@ -145,7 +145,7 @@ async def looper(whoami, record): def test_task_crash_propagation() -> None: looper_record = [] - async def looper(): + async def looper() -> None: try: while True: await _core.checkpoint() @@ -153,10 +153,10 @@ async def looper(): print("looper cancelled") looper_record.append("cancelled") - async def crasher(): + async def crasher() -> None: raise ValueError("argh") - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(looper) nursery.start_soon(crasher) @@ -171,10 +171,10 @@ async def main(): def test_main_and_task_both_crash() -> None: # If main crashes and there's also a task crash, then we get both in a # MultiError - async def crasher(): + async def crasher() -> None: raise ValueError - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) raise KeyError @@ -192,7 +192,7 @@ def test_two_child_crashes() -> None: async def crasher(etype): raise etype - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher, KeyError) nursery.start_soon(crasher, ValueError) @@ -206,7 +206,7 @@ async def main(): async def test_child_crash_wakes_parent() -> None: - async def crasher(): + async def crasher() -> None: raise ValueError with pytest.raises(ValueError): @@ -219,7 +219,7 @@ async def test_reschedule() -> None: t1 = None t2 = None - async def child1(): + async def child1() -> None: nonlocal t1, t2 t1 = _core.current_task() print("child1 start") @@ -230,7 +230,7 @@ async def child1(): _core.reschedule(t2, outcome.Error(ValueError())) print("child1 exit") - async def child2(): + async def child2() -> None: nonlocal t1, t2 print("child2 start") t2 = _core.current_task() @@ -271,7 +271,7 @@ async def test_current_clock(mock_clock) -> None: async def test_current_task() -> None: parent_task = _core.current_task() - async def child(): + async def child() -> None: assert _core.current_task().parent_nursery.parent_task is parent_task async with _core.open_nursery() as nursery: @@ -295,7 +295,7 @@ async def test_current_statistics(mock_clock) -> None: await wait_all_tasks_blocked() # A child that sticks around to make some interesting stats: - async def child(): + async def child() -> None: try: await sleep_forever() except _core.Cancelled: @@ -359,7 +359,7 @@ async def test_cancel_scope_repr(mock_clock) -> None: def test_cancel_points() -> None: - async def main1(): + async def main1() -> None: with _core.CancelScope() as scope: await _core.checkpoint_if_cancelled() scope.cancel() @@ -368,7 +368,7 @@ async def main1(): _core.run(main1) - async def main2(): + async def main2() -> None: with _core.CancelScope() as scope: await _core.checkpoint() scope.cancel() @@ -377,7 +377,7 @@ async def main2(): _core.run(main2) - async def main3(): + async def main3() -> None: with _core.CancelScope() as scope: scope.cancel() with pytest.raises(_core.Cancelled): @@ -385,7 +385,7 @@ async def main3(): _core.run(main3) - async def main4(): + async def main4() -> None: with _core.CancelScope() as scope: scope.cancel() await _core.cancel_shielded_checkpoint() @@ -415,7 +415,7 @@ async def test_cancel_edge_cases() -> None: async def test_cancel_scope_multierror_filtering() -> None: - async def crasher(): + async def crasher() -> None: raise KeyError try: @@ -463,7 +463,7 @@ async def test_precancelled_task() -> None: # cancelled error at its first blocking call. record = [] - async def blocker(): + async def blocker() -> None: record.append("started") await sleep_forever() @@ -545,7 +545,7 @@ async def test_cancel_shield_abort() -> None: # shield, so it manages to get to sleep record = [] - async def sleeper(): + async def sleeper() -> None: record.append("sleeping") try: await sleep_forever() @@ -704,7 +704,7 @@ async def sleep_until_cancelled(scope): # Can't enter from multiple tasks simultaneously scope = _core.CancelScope() - async def enter_scope(): + async def enter_scope() -> None: with scope: await sleep_forever() @@ -740,12 +740,12 @@ async def test_cancel_scope_misnesting() -> None: # If there are other tasks inside the abandoned part of the cancel tree, # they get cancelled when the misnesting is detected - async def task1(): + async def task1() -> None: with pytest.raises(_core.Cancelled): await sleep_forever() # Even if inside another cancel scope - async def task2(): + async def task2() -> None: with _core.CancelScope(): with pytest.raises(_core.Cancelled): await sleep_forever() @@ -849,7 +849,7 @@ async def stubborn_sleeper(): def test_broken_abort() -> None: - async def main(): + async def main() -> None: # These yields are here to work around an annoying warning -- we're # going to crash the main loop, and if we (by chance) do this before # the run_sync_soon task runs for the first time, then Python gives us @@ -876,7 +876,7 @@ async def main(): def test_error_in_run_loop() -> None: # Blow stuff up real good to check we at least get a TrioInternalError - async def main(): + async def main() -> None: task = _core.current_task() task._schedule_points = "hello!" await _core.checkpoint() @@ -901,10 +901,10 @@ async def system_task(x): # intentionally make a system task crash def test_system_task_crash() -> None: - async def crasher(): + async def crasher() -> None: raise KeyError - async def main(): + async def main() -> None: _core.spawn_system_task(crasher) await sleep_forever() @@ -913,18 +913,18 @@ async def main(): def test_system_task_crash_MultiError() -> None: - async def crasher1(): + async def crasher1() -> None: raise KeyError - async def crasher2(): + async def crasher2() -> None: raise ValueError - async def system_task(): + async def system_task() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher1) nursery.start_soon(crasher2) - async def main(): + async def main() -> None: _core.spawn_system_task(system_task) await sleep_forever() @@ -941,21 +941,21 @@ async def main(): def test_system_task_crash_plus_Cancelled() -> None: # Set up a situation where a system task crashes with a # MultiError([Cancelled, ValueError]) - async def crasher(): + async def crasher() -> None: try: await sleep_forever() except _core.Cancelled: raise ValueError - async def cancelme(): + async def cancelme() -> None: await sleep_forever() - async def system_task(): + async def system_task() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) nursery.start_soon(cancelme) - async def main(): + async def main() -> None: _core.spawn_system_task(system_task) # then we exit, triggering a cancellation @@ -965,10 +965,10 @@ async def main(): def test_system_task_crash_KeyboardInterrupt() -> None: - async def ki(): + async def ki() -> None: raise KeyboardInterrupt - async def main(): + async def main() -> None: _core.spawn_system_task(ki) await sleep_forever() @@ -1004,7 +1004,7 @@ async def test_exc_info() -> None: record = [] seq = Sequencer() - async def child1(): + async def child1() -> None: with pytest.raises(ValueError) as excinfo: try: async with seq(0): @@ -1021,7 +1021,7 @@ async def child1(): assert excinfo.value.__context__ is None record.append("child1 success") - async def child2(): + async def child2() -> None: with pytest.raises(KeyError) as excinfo: async with seq(1): pass # we don't yield until seq(3) below @@ -1064,7 +1064,7 @@ async def child2(): async def test_exc_info_after_yield_error() -> None: child_task = None - async def child(): + async def child() -> None: nonlocal child_task child_task = _core.current_task() @@ -1089,7 +1089,7 @@ async def child(): async def test_exception_chaining_after_yield_error() -> None: child_task = None - async def child(): + async def child() -> None: nonlocal child_task child_task = _core.current_task() @@ -1108,7 +1108,7 @@ async def child(): async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None: - async def crasher(): + async def crasher() -> None: raise KeyError with pytest.raises(_core.MultiError) as excinfo: @@ -1149,7 +1149,7 @@ def cb(x): def test_TrioToken_run_sync_soon_too_late() -> None: token = None - async def main(): + async def main() -> None: nonlocal token token = _core.current_trio_token() @@ -1200,7 +1200,7 @@ def redo(token): except _core.RunFinishedError: pass - async def main(): + async def main() -> None: token = _core.current_trio_token() token.run_sync_soon(redo, token, idempotent=True) await _core.checkpoint() @@ -1273,7 +1273,7 @@ def naughty_cb(i): except _core.RunFinishedError: record.append(("run finished", i)) - async def main(): + async def main() -> None: nonlocal token token = _core.current_trio_token() token.run_sync_soon(naughty_cb, 0) @@ -1291,7 +1291,7 @@ async def main(): def test_TrioToken_run_sync_soon_threaded_stress_test() -> None: cb_counter = 0 - def cb(): + def cb() -> None: nonlocal cb_counter cb_counter += 1 @@ -1303,7 +1303,7 @@ def stress_thread(token): except _core.RunFinishedError: pass - async def main(): + async def main() -> None: token = _core.current_trio_token() thread = threading.Thread(target=stress_thread, args=(token,)) thread.start() @@ -1353,7 +1353,7 @@ async def agen(): token.run_sync_soon(lambda: {}["nope"]) token.run_sync_soon(lambda: record.append("2nd ran")) - async def main(): + async def main() -> None: saved.append(agen()) await saved[-1].asend(None) record.append("main exiting") @@ -1455,7 +1455,7 @@ async def parent(task_status=_core.TASK_STATUS_IGNORED): t = nursery.parent_task nursery = t.parent_nursery - async def child2(): + async def child2() -> None: tasks["child2"] = _core.current_task() assert tasks["parent"].child_nurseries == [nurseries["parent"]] assert nurseries["parent"].child_tasks == frozenset({tasks["child1"]}) @@ -1495,7 +1495,7 @@ async def child1(nursery): # so long as there are still tasks running nursery.start_soon(child2) - async def child2(): + async def child2() -> None: pass async with _core.open_nursery() as nursery: @@ -1511,7 +1511,7 @@ async def func1(expected): task = _core.current_task() assert expected in task.name - async def func2(): # pragma: no cover + async def func2() -> None: # pragma: no cover pass async with _core.open_nursery() as nursery: @@ -1550,7 +1550,7 @@ def bad_call_run(*args): _core.run(*args) def bad_call_spawn(*args): - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(*args) @@ -1558,7 +1558,7 @@ async def main(): for bad_call in bad_call_run, bad_call_spawn: - async def f(): # pragma: no cover + async def f() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="expecting an async function"): @@ -1574,12 +1574,12 @@ async def async_gen(arg): # pragma: no cover def test_calling_asyncio_function_gives_nice_error() -> None: - async def child_xyzzy(): + async def child_xyzzy() -> None: import asyncio await asyncio.Future() - async def misguided(): + async def misguided() -> None: await child_xyzzy() with pytest.raises(TypeError) as excinfo: @@ -1628,7 +1628,7 @@ async def test_trivial_yields() -> None: async def test_nursery_start(autojump_clock) -> None: - async def no_args(): # pragma: no cover + async def no_args() -> None: # pragma: no cover pass # Errors in calling convention get raised immediately from start @@ -1815,7 +1815,7 @@ async def test_nursery_explicit_exception() -> None: async def test_nursery_stop_iteration() -> None: - async def fail(): + async def fail() -> None: raise ValueError try: @@ -1879,7 +1879,7 @@ def handle(exc): async def test_traceback_frame_removal() -> None: - async def my_child_task(): + async def my_child_task() -> None: raise KeyError() try: @@ -1908,7 +1908,7 @@ def test_contextvar_support() -> None: assert var.get() == "before" - async def inner(): + async def inner() -> None: task = _core.current_task() assert task.context.get(var) == "before" assert var.get() == "before" @@ -1924,12 +1924,12 @@ async def inner(): async def test_contextvar_multitask() -> None: var = contextvars.ContextVar("test", default="hmmm") - async def t1(): + async def t1() -> None: assert var.get() == "hmmm" var.set("hmmmm") assert var.get() == "hmmmm" - async def t2(): + async def t2() -> None: assert var.get() == "hmmmm" async with _core.open_nursery() as n: @@ -1945,13 +1945,13 @@ def test_system_task_contexts() -> None: cvar = contextvars.ContextVar("qwilfish") cvar.set("water") - async def system_task(): + async def system_task() -> None: assert cvar.get() == "water" - async def regular_task(): + async def regular_task() -> None: assert cvar.get() == "poison" - async def inner(): + async def inner() -> None: async with _core.open_nursery() as nursery: cvar.set("poison") nursery.start_soon(regular_task) @@ -2013,7 +2013,7 @@ def test_sniffio_integration() -> None: with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() - async def check_inside_trio(): + async def check_inside_trio() -> None: assert sniffio.current_async_library() == "trio" _core.run(check_inside_trio) @@ -2072,7 +2072,7 @@ async def detachable_coroutine(task_outcome, yield_value): with pytest.raises(StopIteration): task.coro.send(None) - async def bad_detach(): + async def bad_detach() -> None: async with _core.open_nursery(): with pytest.raises(RuntimeError) as excinfo: await _core.permanently_detach_coroutine_object(outcome.Value(None)) @@ -2086,7 +2086,7 @@ async def test_detach_and_reattach_coroutine_object() -> None: unrelated_task = None task = None - async def unrelated_coroutine(): + async def unrelated_coroutine() -> None: nonlocal unrelated_task unrelated_task = _core.current_task() @@ -2172,7 +2172,7 @@ async def agen_fn(record): _core.run(agen.__anext__) assert run_record == ["the generator ran"] - async def main(): + async def main() -> None: start_soon_record = [] agen = agen_fn(start_soon_record) async with _core.open_nursery() as nursery: @@ -2210,7 +2210,7 @@ async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/issues/1770 gc.collect() - async def do_a_cancel(): + async def do_a_cancel() -> None: with _core.CancelScope() as cscope: cscope.cancel() await sleep_forever() @@ -2262,14 +2262,14 @@ async def test_nursery_cancel_doesnt_create_cyclic_garbage() -> None: async def test_locals_destroyed_promptly_on_cancel() -> None: destroyed = False - def finalizer(): + def finalizer() -> None: nonlocal destroyed destroyed = True class A: pass - async def task(): + async def task() -> None: a = A() weakref.finalize(a, finalizer) await _core.checkpoint() diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/tests/test_thread_cache.py index 234fef251c..e2d55e49a3 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/tests/test_thread_cache.py @@ -15,7 +15,7 @@ def test_thread_cache_basics() -> None: q = Queue() - def fn(): + def fn() -> None: raise RuntimeError("hi") def deliver(outcome): @@ -35,7 +35,7 @@ class del_me: def __call__(self): return 42 - def __del__(self): + def __del__(self) -> None: res[0] = True q = Queue() @@ -137,7 +137,7 @@ def acquire(self, timeout=None): return False return True - def release(self): + def release(self) -> None: self._lock.release() monkeypatch.setattr(_thread_cache, "Lock", JankyLock) diff --git a/trio/_core/tests/test_unbounded_queue.py b/trio/_core/tests/test_unbounded_queue.py index b7434c3a13..433aae310c 100644 --- a/trio/_core/tests/test_unbounded_queue.py +++ b/trio/_core/tests/test_unbounded_queue.py @@ -39,13 +39,13 @@ async def test_UnboundedQueue_blocking() -> None: record = [] q = _core.UnboundedQueue() - async def get_batch_consumer(): + async def get_batch_consumer() -> None: while True: batch = await q.get_batch() assert batch record.append(batch) - async def aiter_consumer(): + async def aiter_consumer() -> None: async for batch in q: assert batch record.append(batch) diff --git a/trio/_core/tests/test_windows.py b/trio/_core/tests/test_windows.py index b789d74cf9..55eb9e5613 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/tests/test_windows.py @@ -120,7 +120,7 @@ def test_forgot_to_register_with_iocp() -> None: left_run_yet = False - async def main(): + async def main() -> None: target = bytearray(1) try: async with _core.open_nursery() as nursery: diff --git a/trio/_core/tests/tutil.py b/trio/_core/tests/tutil.py index 498c2cf6b0..2b5cce1403 100644 --- a/trio/_core/tests/tutil.py +++ b/trio/_core/tests/tutil.py @@ -50,7 +50,7 @@ binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6") -def gc_collect_harder(): +def gc_collect_harder() -> None: # In the test suite we sometimes want to call gc.collect() to make sure # that any objects with noisy __del__ methods (e.g. unawaited coroutines) # get collected before we continue, so their noise doesn't leak into diff --git a/trio/_file_io.py b/trio/_file_io.py index b9122305f8..0d1eee61de 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -136,7 +136,7 @@ async def detach(self): raw = await trio.to_thread.run_sync(self._wrapped.detach) return wrap_file(raw) - async def aclose(self): + async def aclose(self) -> None: """Like :meth:`io.IOBase.close`, but async. This is also shielded from cancellation; if a cancellation scope is diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index c31b4fdbf3..c19e4d0276 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -99,7 +99,7 @@ async def receive_some(self, max_bytes=None): """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) - async def aclose(self): + async def aclose(self) -> None: """Calls ``aclose`` on both underlying streams.""" try: await self.send_stream.aclose() diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index 8922d6d68c..e85ba058d7 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -15,7 +15,7 @@ class Closable(Protocol): - def close(self): + def close(self) -> None: ... diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 0dd05dc7f0..548bb002bf 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -111,14 +111,14 @@ async def send_all(self, data): sent = await self.socket.send(remaining) total_sent += sent - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self.socket.fileno() == -1: raise trio.ClosedResourceError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() - async def send_eof(self): + async def send_eof(self) -> None: with self._send_conflict_detector: await trio.lowlevel.checkpoint() # On macOS, calling shutdown a second time raises ENOTCONN, but @@ -136,7 +136,7 @@ async def receive_some(self, max_bytes=None): with _translate_socket_errors_to_stream_errors(): return await self.socket.recv(max_bytes) - async def aclose(self): + async def aclose(self) -> None: self.socket.close() await trio.lowlevel.checkpoint() @@ -377,7 +377,7 @@ async def accept(self): else: return SocketStream(sock) - async def aclose(self): + async def aclose(self) -> None: """Close this listener and its underlying socket.""" self.socket.close() await trio.lowlevel.checkpoint() diff --git a/trio/_signals.py b/trio/_signals.py index 4cd123b50c..48169215d0 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -79,7 +79,7 @@ def _add(self, signum): self._pending[signum] = None self._lot.unpark() - def _redeliver_remaining(self): + def _redeliver_remaining(self) -> None: # First make sure that any signals still in the delivery pipeline will # get redelivered self._closed = True @@ -87,7 +87,7 @@ def _redeliver_remaining(self): # And then redeliver any that are sitting in pending. This is done # using a weird recursive construct to make sure we process everything # even if some of the handlers raise exceptions. - def deliver_next(): + def deliver_next() -> None: if self._pending: signum, _ = self._pending.popitem(last=False) try: diff --git a/trio/_ssl.py b/trio/_ssl.py index 6dd55f349c..313c9d0855 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -409,7 +409,7 @@ def __setattr__(self, name, value): def __dir__(self): return super().__dir__() + list(self._forwarded) - def _check_status(self): + def _check_status(self) -> None: if self._state is _State.OK: return elif self._state is _State.BROKEN: @@ -595,14 +595,14 @@ async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): await trio.lowlevel.cancel_shielded_checkpoint() return ret - async def _do_handshake(self): + async def _do_handshake(self) -> None: try: await self._retry(self._ssl_object.do_handshake, is_handshake=True) except: self._state = _State.BROKEN raise - async def do_handshake(self): + async def do_handshake(self) -> None: """Ensure that the initial handshake has completed. The SSL protocol requires an initial handshake to exchange @@ -691,7 +691,7 @@ async def receive_some(self, max_bytes=None): else: raise - async def send_all(self, data): + async def send_all(self, data) -> None: """Encrypt some data and then send it on the underlying transport. See :meth:`trio.abc.SendStream.send_all` for details. @@ -738,7 +738,7 @@ async def unwrap(self): self._state = _State.CLOSED return (transport_stream, self._incoming.read()) - async def aclose(self): + async def aclose(self) -> None: """Gracefully shut down this connection, and close the underlying transport. @@ -825,7 +825,7 @@ async def aclose(self): finally: self._state = _State.CLOSED - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.""" # This method's implementation is deceptively simple. # @@ -913,6 +913,6 @@ async def accept(self): https_compatible=self._https_compatible, ) - async def aclose(self): + async def aclose(self) -> None: """Close the transport listener.""" await self.transport_listener.aclose() diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 1f84345663..d9d1a67ebb 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -198,7 +198,7 @@ def returncode(self) -> Optional[int]: self._close_pidfd() return result - async def aclose(self): + async def aclose(self) -> None: """Close any pipes we have to the process (both input and output) and wait for it to exit. @@ -220,7 +220,7 @@ async def aclose(self): with trio.CancelScope(shield=True): await self.wait() - def _close_pidfd(self): + def _close_pidfd(self) -> None: if self._pidfd is not None: self._pidfd.close() self._pidfd = None @@ -270,7 +270,7 @@ def send_signal(self, sig): """ self._proc.send_signal(sig) - def terminate(self): + def terminate(self) -> None: """Terminate the process, politely if possible. On UNIX, this is equivalent to @@ -281,7 +281,7 @@ def terminate(self): """ self._proc.terminate() - def kill(self): + def kill(self) -> None: """Immediately terminate the process. On UNIX, this is equivalent to diff --git a/trio/_sync.py b/trio/_sync.py index 83c7925aff..e0be74d850 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -54,7 +54,7 @@ def set(self) -> None: self._flag = True self._lot.unpark_all() - async def wait(self): + async def wait(self) -> None: """Block until the internal flag value becomes True. If it's already True, then this method returns immediately. @@ -208,7 +208,7 @@ def total_tokens(self, new_total_tokens): self._total_tokens = new_total_tokens self._wake_waiters() - def _wake_waiters(self): + def _wake_waiters(self) -> None: available = self._total_tokens - len(self._borrowers) for woken in self._lot.unpark(count=available): self._borrowers.add(self._pending_borrowers.pop(woken)) @@ -704,11 +704,11 @@ def acquire_nowait(self): """ return self._lock.acquire_nowait() - async def acquire(self): + async def acquire(self) -> None: """Acquire the underlying lock, blocking if necessary.""" await self._lock.acquire() - def release(self): + def release(self) -> None: """Release the underlying lock.""" self._lock.release() @@ -762,7 +762,7 @@ def notify(self, n=1): raise RuntimeError("must hold the lock to notify") self._lot.repark(self._lock._lot, count=n) - def notify_all(self): + def notify_all(self) -> None: """Wake all tasks that are currently blocked in :meth:`wait`. Raises: diff --git a/trio/_threads.py b/trio/_threads.py index e3b0f0abe0..8cfd19f560 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -297,7 +297,7 @@ async def unprotected_afn() -> _T: coro = coroutine_or_error(afn, *args) return await coro # type: ignore[no-any-return] - async def await_in_trio_thread_task(): + async def await_in_trio_thread_task() -> None: q.put_nowait(await outcome.acapture(unprotected_afn)) try: diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 4500ce5318..b37b1a10b0 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -48,7 +48,7 @@ def __init__(self, fd: int) -> None: def closed(self) -> bool: return self.fd == -1 - def _raw_close(self): + def _raw_close(self) -> None: # This doesn't assume it's in a Trio context, so it can be called from # __del__. You should never call it from Trio context, because it # skips calling notify_fd_close. But from __del__, skipping that is @@ -63,10 +63,10 @@ def _raw_close(self): os.set_blocking(fd, self._original_is_blocking) os.close(fd) - def __del__(self): + def __del__(self) -> None: self._raw_close() - async def aclose(self): + async def aclose(self) -> None: if not self.closed: trio.lowlevel.notify_closing(self.fd) self._raw_close() @@ -179,7 +179,7 @@ async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: return data - async def aclose(self): + async def aclose(self) -> None: await self._fd_holder.aclose() def fileno(self): diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index 8b3cb93e0f..9a0a1abf6a 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -25,7 +25,7 @@ def __init__(self, handle: int) -> None: def closed(self) -> bool: return self.handle == -1 - def _close(self): + def _close(self) -> None: if self.closed: return handle = self.handle @@ -33,11 +33,11 @@ def _close(self): if not kernel32.CloseHandle(_handle(handle)): raise_winerror() - async def aclose(self): + async def aclose(self) -> None: self._close() await _core.checkpoint() - def __del__(self): + def __del__(self) -> None: self._close() @@ -78,7 +78,7 @@ async def wait_send_all_might_not_block(self) -> None: # not implemented yet, and probably not needed await _core.checkpoint() - async def aclose(self): + async def aclose(self) -> None: await self._handle_holder.aclose() @@ -130,5 +130,5 @@ async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: del buffer[size:] return buffer - async def aclose(self): + async def aclose(self) -> None: await self._handle_holder.aclose() diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 44554aeb3d..21be38e5bd 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -78,7 +78,7 @@ async def do_aclose(resource): nursery.start_soon(do_send_all, b"x") nursery.start_soon(checked_receive_1, b"x") - async def send_empty_then_y(): + async def send_empty_then_y() -> None: # Streams should tolerate sending b"" without giving it any # special meaning. await do_send_all(b"") @@ -137,7 +137,7 @@ async def simple_check_wait_send_all_might_not_block(scope): # closing the r side leads to BrokenResourceError on the s side # (eventually) - async def expect_broken_stream_on_send(): + async def expect_broken_stream_on_send() -> None: with _assert_raises(_core.BrokenResourceError): while True: await do_send_all(b"x" * 100) @@ -180,11 +180,11 @@ async def expect_broken_stream_on_send(): async with _ForceCloseBoth(await stream_maker()) as (s, r): # if send-then-graceful-close, receiver gets data then b"" - async def send_then_close(): + async def send_then_close() -> None: await do_send_all(b"y") await do_aclose(s) - async def receive_send_then_close(): + async def receive_send_then_close() -> None: # We want to make sure that if the sender closes the stream before # we read anything, then we still get all the data. But some # streams might block on the do_send_all call. So we let the @@ -438,7 +438,7 @@ async def receiver(s, data, seed): nursery.start_soon(receiver, s1, test_data[::-1], 2) nursery.start_soon(receiver, s2, test_data, 3) - async def expect_receive_some_empty(): + async def expect_receive_some_empty() -> None: assert await s2.receive_some(10) == b"" await s2.aclose() diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 21b36f044f..fd65ed5f01 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -23,11 +23,11 @@ def __init__(self) -> None: # channel: so after close(), calling put() raises ClosedResourceError, and # calling the get() variants drains the buffer and then returns an empty # bytearray. - def close(self): + def close(self) -> None: self._closed = True self._lot.unpark_all() - def close_and_wipe(self): + def close_and_wipe(self) -> None: self._data = bytearray() self.close() @@ -122,7 +122,7 @@ async def send_all(self, data): if self.send_all_hook is not None: await self.send_all_hook() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and then returns immediately. @@ -137,7 +137,7 @@ async def wait_send_all_might_not_block(self): if self.wait_send_all_might_not_block_hook is not None: await self.wait_send_all_might_not_block_hook() - def close(self): + def close(self) -> None: """Marks this stream as closed, and then calls the :attr:`close_hook` (if any). @@ -154,7 +154,7 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() @@ -236,7 +236,7 @@ async def receive_some(self, max_bytes=None): raise _core.ClosedResourceError return data - def close(self): + def close(self) -> None: """Discards any pending data from the internal buffer, and marks this stream as closed. @@ -246,7 +246,7 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() @@ -255,7 +255,7 @@ def put_data(self, data): """Appends the given data to the internal buffer.""" self._incoming.put(data) - def put_eof(self): + def put_eof(self) -> None: """Adds an end-of-file marker to the internal buffer.""" self._incoming.close() @@ -320,10 +320,10 @@ def memory_stream_one_way_pair(): send_stream = MemorySendStream() recv_stream = MemoryReceiveStream() - def pump_from_send_stream_to_recv_stream(): + def pump_from_send_stream_to_recv_stream() -> None: memory_stream_pump(send_stream, recv_stream) - async def async_pump_from_send_stream_to_recv_stream(): + async def async_pump_from_send_stream_to_recv_stream() -> None: pump_from_send_stream_to_recv_stream() send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream @@ -435,7 +435,7 @@ def __init__(self) -> None: "another task is already receiving" ) - def _something_happened(self): + def _something_happened(self) -> None: self._waiters.unpark_all() # Always wakes up when one side is closed, because everyone always reacts @@ -449,11 +449,11 @@ async def _wait_for(self, fn): await self._waiters.park() await _core.checkpoint() - def close_sender(self): + def close_sender(self) -> None: self._sender_closed = True self._something_happened() - def close_receiver(self): + def close_receiver(self) -> None: self._receiver_closed = True self._something_happened() @@ -519,17 +519,17 @@ class _LockstepSendStream(SendStream): def __init__(self, lbq) -> None: self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_sender() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() async def send_all(self, data): await self._lbq.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: await self._lbq.wait_send_all_might_not_block() @@ -537,10 +537,10 @@ class _LockstepReceiveStream(ReceiveStream): def __init__(self, lbq) -> None: self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_receiver() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index 7e1ae64009..3a114abaf0 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -11,7 +11,7 @@ async def test_AsyncResource_defaults() -> None: class MyAR(tabc.AsyncResource): record = attr.ib(factory=list) - async def aclose(self): + async def aclose(self) -> None: self.record.append("ac") async with MyAR() as myar: @@ -38,10 +38,10 @@ def send_nowait(self, value): async def send(self, value): raise RuntimeError # pragma: no cover - def clone(self): + def clone(self) -> None: raise RuntimeError # pragma: no cover - async def aclose(self): + async def aclose(self) -> None: pass # pragma: no cover channel = SlottedChannel() diff --git a/trio/tests/test_channel.py b/trio/tests/test_channel.py index 7ab5490d6d..f23b58c48c 100644 --- a/trio/tests/test_channel.py +++ b/trio/tests/test_channel.py @@ -245,11 +245,11 @@ async def test_close_multiple_send_handles() -> None: s1, r = open_memory_channel(0) s2 = s1.clone() - async def send_will_close(): + async def send_will_close() -> None: with pytest.raises(trio.ClosedResourceError): await s1.send("nope") - async def send_will_succeed(): + async def send_will_succeed() -> None: await s2.send("ok") async with trio.open_nursery() as nursery: @@ -266,11 +266,11 @@ async def test_close_multiple_receive_handles() -> None: s, r1 = open_memory_channel(0) r2 = r1.clone() - async def receive_will_close(): + async def receive_will_close() -> None: with pytest.raises(trio.ClosedResourceError): await r1.receive() - async def receive_will_succeed(): + async def receive_will_succeed() -> None: assert await r2.receive() == "ok" async with trio.open_nursery() as nursery: diff --git a/trio/tests/test_deprecate.py b/trio/tests/test_deprecate.py index bf3743f395..22a0f433d8 100644 --- a/trio/tests/test_deprecate.py +++ b/trio/tests/test_deprecate.py @@ -31,7 +31,7 @@ def _here(): def test_warn_deprecated(recwarn_always: _pytest.recwarn.WarningsRecorder) -> None: - def deprecated_thing(): + def deprecated_thing() -> None: warn_deprecated("ice", "1.2", issue=1, instead="water") deprecated_thing() @@ -63,10 +63,10 @@ def test_warn_deprecated_no_instead_or_issue( def test_warn_deprecated_stacklevel( recwarn_always: _pytest.recwarn.WarningsRecorder, ) -> None: - def nested1(): + def nested1() -> None: nested2() - def nested2(): + def nested2() -> None: warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3) filename, lineno = _here() @@ -76,11 +76,11 @@ def nested2(): assert got.lineno == lineno + 1 -def old(): # pragma: no cover +def old() -> None: # pragma: no cover pass -def new(): # pragma: no cover +def new() -> None: # pragma: no cover pass diff --git a/trio/tests/test_file_io.py b/trio/tests/test_file_io.py index 91cb4a3b5c..725693cd07 100644 --- a/trio/tests/test_file_io.py +++ b/trio/tests/test_file_io.py @@ -44,10 +44,10 @@ def test_wrap_invalid() -> None: def test_wrap_non_iobase() -> None: class FakeFile: - def close(self): # pragma: no cover + def close(self) -> None: # pragma: no cover pass - def write(self): # pragma: no cover + def write(self) -> None: # pragma: no cover pass wrapped = FakeFile() @@ -79,7 +79,7 @@ def test_dir_matches_wrapped(async_file, wrapped) -> None: def test_unsupported_not_forwarded() -> None: class FakeFile(io.RawIOBase): - def unsupported_attr(self): # pragma: no cover + def unsupported_attr(self) -> None: # pragma: no cover pass async_file = trio.wrap_file(FakeFile()) diff --git a/trio/tests/test_highlevel_generic.py b/trio/tests/test_highlevel_generic.py index 32dd86682f..33f8a7053e 100644 --- a/trio/tests/test_highlevel_generic.py +++ b/trio/tests/test_highlevel_generic.py @@ -13,10 +13,10 @@ class RecordSendStream(SendStream): async def send_all(self, data): self.record.append(("send_all", data)) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: self.record.append("wait_send_all_might_not_block") - async def aclose(self): + async def aclose(self) -> None: self.record.append("aclose") @@ -27,7 +27,7 @@ class RecordReceiveStream(ReceiveStream): async def receive_some(self, max_bytes=None): self.record.append(("receive_some", max_bytes)) - async def aclose(self): + async def aclose(self) -> None: self.record.append("aclose") @@ -51,7 +51,7 @@ async def test_StapledStream() -> None: assert send_stream.record == ["aclose"] send_stream.record.clear() - async def fake_send_eof(): + async def fake_send_eof() -> None: send_stream.record.append("send_eof") send_stream.send_eof = fake_send_eof @@ -75,12 +75,12 @@ async def test_StapledStream_with_erroring_close() -> None: # Make sure that if one of the aclose methods errors out, then the other # one still gets called. class BrokenSendStream(RecordSendStream): - async def aclose(self): + async def aclose(self) -> None: await super().aclose() raise ValueError class BrokenReceiveStream(RecordReceiveStream): - async def aclose(self): + async def aclose(self) -> None: await super().aclose() raise ValueError diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/tests/test_highlevel_open_tcp_listeners.py index 1107499218..524a5ad74c 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/tests/test_highlevel_open_tcp_listeners.py @@ -138,7 +138,7 @@ def listen(self, backlog): if self.poison_listen: raise FakeOSError("whoops") - def close(self): + def close(self) -> None: self.closed = True diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/tests/test_highlevel_open_tcp_stream.py index 3b0f7836b4..7feefed5be 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/tests/test_highlevel_open_tcp_stream.py @@ -18,11 +18,11 @@ def test_close_all() -> None: class CloseMe: closed = False - def close(self): + def close(self) -> None: self.closed = True class CloseKiller: - def close(self): + def close(self) -> None: raise OSError c = CloseMe() @@ -200,7 +200,7 @@ async def connect(self, sockaddr): self.failing = True self.succeeded = True - def close(self): + def close(self) -> None: self.closed = True # called when SocketStream is constructed diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/tests/test_highlevel_open_unix_stream.py index a73324cc64..5ab40e2e12 100644 --- a/trio/tests/test_highlevel_open_unix_stream.py +++ b/trio/tests/test_highlevel_open_unix_stream.py @@ -15,7 +15,7 @@ def test_close_on_error() -> None: class CloseMe: closed = False - def close(self): + def close(self) -> None: self.closed = True with close_on_error(CloseMe()) as c: diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/tests/test_highlevel_serve_listeners.py index 707e299e34..52deb42d5a 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/tests/test_highlevel_serve_listeners.py @@ -37,7 +37,7 @@ async def accept(self): self.accepted_streams.append(stream) return stream - async def aclose(self): + async def aclose(self) -> None: self.closed = True await trio.lowlevel.checkpoint() @@ -47,7 +47,7 @@ async def test_serve_listeners_basic() -> None: record = [] - def close_hook(): + def close_hook() -> None: # Make sure this is a forceful close assert trio.current_effective_deadline() == float("-inf") record.append("closed") @@ -91,7 +91,7 @@ async def test_serve_listeners_accept_unrecognized_error() -> None: for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]: listener = MemoryListener() - async def raise_error(): + async def raise_error() -> None: raise error listener.accept_hook = raise_error @@ -104,7 +104,7 @@ async def raise_error(): async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog) -> None: listener = MemoryListener() - async def raise_EMFILE(): + async def raise_EMFILE() -> None: raise OSError(errno.EMFILE, "out of file descriptors") listener.accept_hook = raise_EMFILE diff --git a/trio/tests/test_highlevel_socket.py b/trio/tests/test_highlevel_socket.py index a67a435e0a..894c99e403 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/tests/test_highlevel_socket.py @@ -63,7 +63,7 @@ async def test_SocketStream_send_all() -> None: # Check a send_all that has to be split into multiple parts (on most # platforms... on Windows every send() either succeeds or fails as a # whole) - async def sender(): + async def sender() -> None: data = bytearray(BIG) await a.send_all(data) # send_all uses memoryviews internally, which temporarily "lock" @@ -87,7 +87,7 @@ async def sender(): # and we break our implementation of send_all, then we'll get some # early warning...) - async def receiver(): + async def receiver() -> None: # Make sure the sender fills up the kernel buffers and blocks await wait_all_tasks_blocked() nbytes = 0 @@ -108,7 +108,7 @@ async def receiver(): async def fill_stream(s): - async def sender(): + async def sender() -> None: while True: await s.send_all(b"x" * 10000) diff --git a/trio/tests/test_signals.py b/trio/tests/test_signals.py index bf2bb14d27..5006444f18 100644 --- a/trio/tests/test_signals.py +++ b/trio/tests/test_signals.py @@ -56,7 +56,7 @@ async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> async def test_catch_signals_wrong_thread() -> None: - async def naughty(): + async def naughty() -> None: with open_signal_receiver(signal.SIGINT): pass # pragma: no cover @@ -74,7 +74,7 @@ async def test_open_signal_receiver_conflict() -> None: # Blocks until all previous calls to run_sync_soon(idempotent=True) have been # processed. -async def wait_run_sync_soon_idempotent_queue_barrier(): +async def wait_run_sync_soon_idempotent_queue_barrier() -> None: ev = trio.Event() token = _core.current_trio_token() token.run_sync_soon(ev.set, idempotent=True) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index 8c19476db0..2a065bbd96 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -628,7 +628,7 @@ async def test_SocketType_non_blocking_paths() -> None: await ta.recv("haha") # block then succeed - async def do_successful_blocking_recv(): + async def do_successful_blocking_recv() -> None: with assert_checkpoints(): assert await ta.recv(10) == b"2" @@ -638,7 +638,7 @@ async def do_successful_blocking_recv(): b.send(b"2") # block then cancelled - async def do_cancelled_blocking_recv(): + async def do_cancelled_blocking_recv() -> None: with assert_checkpoints(): with pytest.raises(_core.Cancelled): await ta.recv(10) @@ -656,13 +656,13 @@ async def do_cancelled_blocking_recv(): # other: tb = tsocket.from_stdlib_socket(b) - async def t1(): + async def t1() -> None: with assert_checkpoints(): assert await ta.recv(1) == b"a" with assert_checkpoints(): assert await tb.recv(1) == b"b" - async def t2(): + async def t2() -> None: with assert_checkpoints(): assert await tb.recv(1) == b"b" with assert_checkpoints(): @@ -991,11 +991,11 @@ async def test_interrupted_by_close() -> None: a = tsocket.from_stdlib_socket(a_stdlib) - async def sender(): + async def sender() -> None: with pytest.raises(_core.ClosedResourceError): await a.send(data) - async def receiver(): + async def receiver() -> None: with pytest.raises(_core.ClosedResourceError): await a.recv(1) diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 30e2f72a3a..073f9ab1fe 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -229,18 +229,18 @@ async def no_op_sleeper(_): else: self.sleeper = sleeper - async def aclose(self): + async def aclose(self) -> None: self._conn.bio_shutdown() def renegotiate_pending(self): return self._conn.renegotiate_pending() - def renegotiate(self): + def renegotiate(self) -> None: # Returns false if a renegotiation is already in progress, meaning # nothing happens. assert self._conn.renegotiate() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_all_conflict_detector: await _core.checkpoint() await _core.checkpoint() @@ -437,7 +437,7 @@ async def test_ssl_server_basics(client_ctx) -> None: ) assert server_transport.server_side - def client(): + def client() -> None: with client_ctx.wrap_socket( a, server_hostname="trio-test-1.example.org" ) as client_sock: @@ -597,7 +597,7 @@ async def test_renegotiation_randomized( async def sleeper(_): await trio.sleep(r.uniform(0, 10)) - async def clear(): + async def clear() -> None: while s.transport_stream.renegotiate_pending(): with assert_checkpoints(): await send(b"-") @@ -663,7 +663,7 @@ async def sleeper_with_slow_send_all(method): # And our wait_send_all_might_not_block call will give it time to get # stuck, and then start - async def sleep_then_wait_writable(): + async def sleep_then_wait_writable() -> None: await trio.sleep(1000) await s.wait_send_all_might_not_block() @@ -702,15 +702,15 @@ async def sleeper_with_slow_wait_writable_and_expect(method): async def test_resource_busy_errors(client_ctx) -> None: - async def do_send_all(): + async def do_send_all() -> None: with assert_checkpoints(): await s.send_all(b"x") - async def do_receive_some(): + async def do_receive_some() -> None: with assert_checkpoints(): await s.receive_some(1) - async def do_wait_send_all_might_not_block(): + async def do_wait_send_all_might_not_block() -> None: with assert_checkpoints(): await s.wait_send_all_might_not_block() @@ -747,7 +747,7 @@ async def test_wait_writable_calls_underlying_wait_writable() -> None: record = [] class NotAStream: - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: record.append("ok") ctx = ssl.create_default_context() @@ -835,7 +835,7 @@ async def test_unwrap(client_ctx) -> None: seq = Sequencer() - async def client(): + async def client() -> None: await client_ssl.do_handshake() await client_ssl.send_all(b"x") assert await client_ssl.receive_some(1) == b"y" @@ -862,7 +862,7 @@ async def client(): client_transport.send_stream.send_all_hook = send_all_hook await client_transport.send_stream.send_all_hook() - async def server(): + async def server() -> None: await server_ssl.do_handshake() assert await server_ssl.receive_some(1) == b"x" await server_ssl.send_all(b"y") @@ -890,11 +890,11 @@ async def test_closing_nice_case(client_ctx) -> None: # Both the handshake and the close require back-and-forth discussion, so # we need to run them concurrently - async def client_closer(): + async def client_closer() -> None: with assert_checkpoints(): await client_ssl.aclose() - async def server_closer(): + async def server_closer() -> None: assert await server_ssl.receive_some(10) == b"" assert await server_ssl.receive_some(10) == b"" with assert_checkpoints(): @@ -933,7 +933,7 @@ async def server_closer(): # the other side client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) - async def expect_eof_server(): + async def expect_eof_server() -> None: with assert_checkpoints(): assert await server_ssl.receive_some(10) == b"" with assert_checkpoints(): @@ -951,7 +951,7 @@ async def test_send_all_fails_in_the_middle(client_ctx) -> None: nursery.start_soon(client.do_handshake) nursery.start_soon(server.do_handshake) - async def bad_hook(): + async def bad_hook() -> None: raise KeyError client.transport_stream.send_stream.send_all_hook = bad_hook @@ -964,7 +964,7 @@ async def bad_hook(): closed = 0 - def close_hook(): + def close_hook() -> None: nonlocal closed closed += 1 @@ -988,11 +988,11 @@ async def test_ssl_over_ssl(client_ctx) -> None: ) server_2 = SSLStream(server_1, SERVER_CTX, server_side=True) - async def client(): + async def client() -> None: await client_2.send_all(b"hi") assert await client_2.receive_some(10) == b"bye" - async def server(): + async def server() -> None: assert await server_2.receive_some(10) == b"hi" await server_2.send_all(b"bye") @@ -1069,7 +1069,7 @@ async def test_ssl_only_closes_stream_once(client_ctx) -> None: client_orig_close_hook = client.transport_stream.send_stream.close_hook transport_close_count = 0 - def close_hook(): + def close_hook() -> None: nonlocal transport_close_count client_orig_close_hook() transport_close_count += 1 @@ -1095,7 +1095,7 @@ async def test_ssl_https_compatibility_disagreement(client_ctx) -> None: # client is in HTTPS-mode, server is not # so client doing graceful_shutdown causes an error on server - async def receive_and_expect_error(): + async def receive_and_expect_error() -> None: with pytest.raises(BrokenResourceError) as excinfo: await server.receive_some(10) assert isinstance(excinfo.value.__cause__, ssl.SSLEOFError) @@ -1112,7 +1112,7 @@ async def test_https_mode_eof_before_handshake(client_ctx) -> None: client_kwargs={"https_compatible": True}, ) - async def server_expect_clean_eof(): + async def server_expect_clean_eof() -> None: assert await server.receive_some(10) == b"" async with _core.open_nursery() as nursery: @@ -1123,7 +1123,7 @@ async def server_expect_clean_eof(): async def test_send_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) - async def bad_hook(): + async def bad_hook() -> None: raise KeyError client.transport_stream.send_stream.send_all_hook = bad_hook @@ -1140,7 +1140,7 @@ async def bad_hook(): async def test_receive_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) - async def bad_hook(): + async def bad_hook() -> None: raise KeyError client.transport_stream.receive_stream.receive_some_hook = bad_hook diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 15efb9853e..793b1df573 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -138,7 +138,7 @@ async def test_pipes() -> None: ) as proc: msg = b"the quick brown fox jumps over the lazy dog" - async def feed_input(): + async def feed_input() -> None: await proc.stdin.send_all(msg) await proc.stdin.aclose() @@ -486,7 +486,7 @@ async def custom_deliver_cancel(proc): async def test_warn_on_failed_cancel_terminate(monkeypatch) -> None: original_terminate = Process.terminate - def broken_terminate(self): + def broken_terminate(self) -> None: original_terminate(self) raise OSError("whoops") diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index 560792c240..b6e7170270 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -26,7 +26,7 @@ async def test_Event() -> None: record = [] - async def child(): + async def child() -> None: record.append("sleeping") await e.wait() record.append("woken") @@ -288,7 +288,7 @@ async def test_Lock_and_StrictFIFOLock( holder_task = None - async def holder(): + async def holder() -> None: nonlocal holder_task holder_task = _core.current_task() async with l: @@ -424,13 +424,13 @@ def __init__(self, capacity) -> None: for _ in range(capacity - 1): self.s.send_nowait(None) - def acquire_nowait(self): + def acquire_nowait(self) -> None: self.s.send_nowait(None) - async def acquire(self): + async def acquire(self) -> None: await self.s.send(None) - def release(self): + def release(self) -> None: self.r.receive_nowait() @@ -440,13 +440,13 @@ def __init__(self) -> None: self.s, self.r = open_memory_channel(10) self.s.send_nowait(None) - def acquire_nowait(self): + def acquire_nowait(self) -> None: self.r.receive_nowait() - async def acquire(self): + async def acquire(self) -> None: await self.r.receive() - def release(self): + def release(self) -> None: self.s.send_nowait(None) @@ -459,18 +459,18 @@ def __init__(self) -> None: # waiting to acquire. self.acquired = False - def acquire_nowait(self): + def acquire_nowait(self) -> None: assert not self.acquired self.acquired = True - async def acquire(self): + async def acquire(self) -> None: if self.acquired: await self.s.send(None) else: self.acquired = True await _core.checkpoint() - def release(self): + def release(self) -> None: try: self.r.receive_nowait() except _core.WouldBlock: @@ -579,7 +579,7 @@ async def test_generic_lock_acquire_nowait_blocks_acquire( record = [] - async def lock_taker(): + async def lock_taker() -> None: record.append("started") async with lock_like: pass diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py index 812c108137..981df52632 100644 --- a/trio/tests/test_testing.py +++ b/trio/tests/test_testing.py @@ -18,12 +18,12 @@ async def test_wait_all_tasks_blocked() -> None: record = [] - async def busy_bee(): + async def busy_bee() -> None: for _ in range(10): await _core.checkpoint() record.append("busy bee exhausted") - async def waiting_for_bee_to_leave(): + async def waiting_for_bee_to_leave() -> None: await wait_all_tasks_blocked() record.append("quiet at last!") @@ -35,7 +35,7 @@ async def waiting_for_bee_to_leave(): # check cancellation record = [] - async def cancelled_while_waiting(): + async def cancelled_while_waiting() -> None: try: await wait_all_tasks_blocked() except _core.Cancelled: @@ -50,7 +50,7 @@ async def cancelled_while_waiting(): async def test_wait_all_tasks_blocked_with_timeouts(mock_clock) -> None: record = [] - async def timeout_task(): + async def timeout_task() -> None: record.append("tt start") await sleep(5) record.append("tt finished") @@ -67,22 +67,22 @@ async def timeout_task(): async def test_wait_all_tasks_blocked_with_cushion() -> None: record = [] - async def blink(): + async def blink() -> None: record.append("blink start") await sleep(0.01) await sleep(0.01) await sleep(0.01) record.append("blink end") - async def wait_no_cushion(): + async def wait_no_cushion() -> None: await wait_all_tasks_blocked() record.append("wait_no_cushion end") - async def wait_small_cushion(): + async def wait_small_cushion() -> None: await wait_all_tasks_blocked(0.02) record.append("wait_small_cushion end") - async def wait_big_cushion(): + async def wait_big_cushion() -> None: await wait_all_tasks_blocked(0.03) record.append("wait_big_cushion end") @@ -310,7 +310,7 @@ async def getter(expect): # close wakes up blocked getters ubq2 = _UnboundedByteQueue() - async def closer(): + async def closer() -> None: await wait_all_tasks_blocked() ubq2.close() @@ -348,7 +348,7 @@ async def do_send_all(data): # and we don't know which one will get the error. resource_busy_count = 0 - async def do_send_all_count_resourcebusy(): + async def do_send_all_count_resourcebusy() -> None: nonlocal resource_busy_count try: await do_send_all(b"xxx") @@ -377,15 +377,15 @@ async def do_send_all_count_resourcebusy(): record = [] - async def send_all_hook(): + async def send_all_hook() -> None: # hook runs after send_all does its work (can pull data out) assert mss2.get_data_nowait() == b"abc" record.append("send_all_hook") - async def wait_send_all_might_not_block_hook(): + async def wait_send_all_might_not_block_hook() -> None: record.append("wait_send_all_might_not_block_hook") - def close_hook(): + def close_hook() -> None: record.append("close_hook") mss2 = MemorySendStream( @@ -440,12 +440,12 @@ async def do_receive_some(max_bytes): with pytest.raises(_core.ClosedResourceError): mrs.put_data(b"---") - async def receive_some_hook(): + async def receive_some_hook() -> None: mrs2.put_data(b"xxx") record = [] - def close_hook(): + def close_hook() -> None: record.append("closed") mrs2 = MemoryReceiveStream(receive_some_hook, close_hook) @@ -555,7 +555,7 @@ async def cancel_after_idle(nursery): await wait_all_tasks_blocked() nursery.cancel_scope.cancel() - async def check_for_cancel(): + async def check_for_cancel() -> None: with pytest.raises(_core.Cancelled): # This should block forever... or until cancelled. Even though we # sent some data on the send stream. @@ -580,11 +580,11 @@ async def test_memory_stream_pair() -> None: await a.send_eof() assert await b.receive_some(10) == b"" - async def sender(): + async def sender() -> None: await wait_all_tasks_blocked() await b.send_all(b"xyz") - async def receiver(): + async def receiver() -> None: assert await a.receive_some(10) == b"xyz" async with _core.open_nursery() as nursery: diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index 0ca94d71c7..f0c2afc1e4 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -25,7 +25,7 @@ async def test_do_in_trio_thread() -> None: async def check_case(do_in_trio_thread, fn, expected, trio_token=None): record = [] - def threadfn(): + def threadfn() -> None: try: record.append(("start", threading.current_thread())) x = do_in_trio_thread(fn, record, trio_token=trio_token) @@ -78,7 +78,7 @@ async def test_do_in_trio_thread_from_trio_thread() -> None: with pytest.raises(RuntimeError): from_thread_run_sync(lambda: None) # pragma: no branch - async def foo(): # pragma: no cover + async def foo() -> None: # pragma: no cover pass with pytest.raises(RuntimeError): @@ -90,10 +90,10 @@ def test_run_in_trio_thread_ki() -> None: # back to the caller (slick!) record = set() - async def check_run_in_trio_thread(): + async def check_run_in_trio_thread() -> None: token = _core.current_trio_token() - def trio_thread_fn(): + def trio_thread_fn() -> None: print("in Trio thread") assert not _core.currently_ki_protected() print("ki_self") @@ -104,10 +104,10 @@ def trio_thread_fn(): print("finally", sys.exc_info()) - async def trio_thread_afn(): + async def trio_thread_afn() -> None: trio_thread_fn() - def external_thread_fn(): + def external_thread_fn() -> None: try: print("running") from_thread_run_sync(trio_thread_fn, trio_token=token) @@ -171,7 +171,7 @@ def f(x): assert x == 1 assert child_thread != trio_thread - def g(): + def g() -> None: raise ValueError(threading.current_thread()) with pytest.raises(ValueError) as excinfo: @@ -246,12 +246,12 @@ def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None: q1 = stdlib_queue.Queue() q2 = stdlib_queue.Queue() - def thread_fn(): + def thread_fn() -> None: q1.get() q2.put(threading.current_thread()) - async def main(): - async def child(): + async def main() -> None: + async def child() -> None: await to_thread_run_sync(thread_fn, cancellable=True) async with _core.open_nursery() as nursery: @@ -465,7 +465,7 @@ def thread_fn(): async def test_trio_to_thread_run_sync_expected_error() -> None: # Test correct error when passed async function - async def async_fn(): # pragma: no cover + async def async_fn() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="expected a sync function"): @@ -483,10 +483,10 @@ def thread_fn(): assert isinstance(trio_time, float) # Test correct error when passed async function - async def async_fn(): # pragma: no cover + async def async_fn() -> None: # pragma: no cover pass - def thread_fn(): + def thread_fn() -> None: from_thread_run_sync(async_fn) with pytest.raises(TypeError, match="expected a sync function"): @@ -498,11 +498,11 @@ async def test_trio_from_thread_run() -> None: # trio.from_thread.run() record = [] - async def back_in_trio_fn(): + async def back_in_trio_fn() -> None: _core.current_time() # implicitly checks that we're in trio record.append("back in trio") - def thread_fn(): + def thread_fn() -> None: record.append("in thread") from_thread_run(back_in_trio_fn) @@ -510,7 +510,7 @@ def thread_fn(): assert record == ["in thread", "back in trio"] # Test correct error when passed sync function - def sync_fn(): # pragma: no cover + def sync_fn() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="appears to be synchronous"): @@ -555,7 +555,7 @@ def test_run_fn_as_system_task_catched_badly_typed_token() -> None: async def test_from_thread_inside_trio_thread() -> None: - def not_called(): # pragma: no cover + def not_called() -> None: # pragma: no cover assert False trio_token = _core.current_trio_token() @@ -576,7 +576,7 @@ async def agen(): await to_thread_run_sync(from_thread_run, sleep, 0) record.append("ok") - async def main(): + async def main() -> None: save.append(agen()) await save[-1].asend(None) diff --git a/trio/tests/test_timeouts.py b/trio/tests/test_timeouts.py index 97cb728954..55bbda5036 100644 --- a/trio/tests/test_timeouts.py +++ b/trio/tests/test_timeouts.py @@ -48,7 +48,7 @@ async def sleep_1() -> None: await check_takes_about(sleep_1, TARGET) - async def sleep_2(): + async def sleep_2() -> None: await sleep(TARGET) await check_takes_about(sleep_2, TARGET) @@ -70,7 +70,7 @@ async def test_move_on_after() -> None: with move_on_after(-1): pass # pragma: no cover - async def sleep_3(): + async def sleep_3() -> None: with move_on_after(TARGET): await sleep(100) @@ -79,7 +79,7 @@ async def sleep_3(): @slow async def test_fail() -> None: - async def sleep_4(): + async def sleep_4() -> None: with fail_at(_core.current_time() + TARGET): await sleep(100) @@ -89,7 +89,7 @@ async def sleep_4(): with fail_at(_core.current_time() + 100): await sleep(0) - async def sleep_5(): + async def sleep_5() -> None: with fail_after(TARGET): await sleep(100) diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index 094cfb0cfd..ff70ad7ffd 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -77,11 +77,11 @@ async def test_pipes_combined() -> None: write, read = await make_pipe() count = 2 ** 20 - async def sender(): + async def sender() -> None: big = bytearray(count) await write.send_all(big) - async def reader(): + async def reader() -> None: await wait_all_tasks_blocked() received = 0 while received < count: @@ -156,7 +156,7 @@ async def test_misdirected_aclose_regression() -> None: # And now set up a background task that's working on the new receive # handle - async def expect_eof(): + async def expect_eof() -> None: assert await r2.receive_some(10) == b"" async with _core.open_nursery() as nursery: @@ -181,7 +181,7 @@ async def test_close_at_bad_time_for_receive_some(monkeypatch) -> None: # # This tests what happens if the pipe gets closed in the moment *between* # when receive_some wakes up, and when it tries to call os.read - async def expect_closedresourceerror(): + async def expect_closedresourceerror() -> None: with pytest.raises(_core.ClosedResourceError): await r.receive_some(10) @@ -210,7 +210,7 @@ async def test_close_at_bad_time_for_send_all(monkeypatch) -> None: # # This tests what happens if the pipe gets closed in the moment *between* # when send_all wakes up, and when it tries to call os.write - async def expect_closedresourceerror(): + async def expect_closedresourceerror() -> None: with pytest.raises(_core.ClosedResourceError): await s.send_all(b"x" * 100) diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 627e0b36a7..cc164989b3 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -48,7 +48,7 @@ async def test_ConflictDetector() -> None: pass # pragma: no cover assert "ul1" in str(excinfo.value) - async def wait_with_ul1(): + async def wait_with_ul1() -> None: with ul1: await wait_all_tasks_blocked() @@ -86,7 +86,7 @@ def test_module_metadata_is_fixed_up() -> None: async def test_is_main_thread() -> None: assert is_main_thread() - def not_main_thread(): + def not_main_thread() -> None: assert not is_main_thread() await trio.to_thread.run_sync(not_main_thread) @@ -100,7 +100,7 @@ class Deferred: with ignore_coroutine_never_awaited_warnings(): - async def f(): # pragma: no cover + async def f() -> None: # pragma: no cover pass with pytest.raises(TypeError) as excinfo: diff --git a/trio/tests/test_windows_pipes.py b/trio/tests/test_windows_pipes.py index a8654d8b04..6c543b1430 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/tests/test_windows_pipes.py @@ -55,13 +55,13 @@ async def test_pipes_combined() -> None: count = 2 ** 20 replicas = 3 - async def sender(): + async def sender() -> None: async with write: big = bytearray(count) for _ in range(replicas): await write.send_all(big) - async def reader(): + async def reader() -> None: async with read: await wait_all_tasks_blocked() total_received = 0 @@ -94,7 +94,7 @@ async def test_close_during_write() -> None: w, r = await make_pipe() async with _core.open_nursery() as nursery: - async def write_forever(): + async def write_forever() -> None: with pytest.raises(_core.ClosedResourceError) as excinfo: while True: await w.send_all(b"x" * 4096) From b1e2fb7811216828a9a2def92cf94e629a0c5d68 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Sun, 14 Feb 2021 22:25:01 -0500 Subject: [PATCH 48/50] black --- trio/_abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_abc.py b/trio/_abc.py index c828faec4f..eea85dc69a 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -4,7 +4,7 @@ from typing import Generic, List, Optional, Text, Tuple, TYPE_CHECKING, TypeVar, Union import socket import trio -from ._core import _run +from ._core import _run _T = TypeVar("_T") _TSelf = TypeVar("_TSelf") From 55d191573b5234210c12ca4671b7014253630be0 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 16 Feb 2021 16:14:36 -0500 Subject: [PATCH 49/50] again --- trio/_abc.py | 8 ++--- trio/_core/_generated_io_epoll.py | 2 +- trio/_core/_generated_io_kqueue.py | 2 +- trio/_core/_generated_io_windows.py | 2 +- trio/_core/_io_epoll.py | 2 +- trio/_core/_io_kqueue.py | 2 +- trio/_core/_io_windows.py | 2 +- trio/_core/_ki.py | 32 ++++++++++++-------- trio/_core/_multierror.py | 3 +- trio/_core/_parking_lot.py | 2 +- trio/_core/_unbounded_queue.py | 21 ++++++------- trio/_core/_wakeup_socketpair.py | 5 ++-- trio/_file_io.py | 46 ++++++++++++++++------------- trio/_highlevel_generic.py | 20 +++++++------ trio/_highlevel_open_tcp_stream.py | 20 +++++++++---- trio/_highlevel_open_unix_stream.py | 4 +-- trio/_highlevel_socket.py | 24 ++++++++++----- trio/_path.py | 3 +- trio/_signals.py | 25 +++++++++++----- trio/_socket.py | 22 +++++++------- trio/_ssl.py | 3 +- 21 files changed, 148 insertions(+), 102 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index eea85dc69a..bac5135bab 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -97,7 +97,7 @@ def task_scheduled(self, task: "_run.Task") -> None: """ - def before_task_step(self, task: "Task") -> None: + def before_task_step(self, task: "_run.Task") -> None: """Called immediately before we resume running the given task. Args: @@ -105,7 +105,7 @@ def before_task_step(self, task: "Task") -> None: """ - def after_task_step(self, task: "Task") -> None: + def after_task_step(self, task: "_run.Task") -> None: """Called when we return to the main run loop after a task has yielded. Args: @@ -113,7 +113,7 @@ def after_task_step(self, task: "Task") -> None: """ - def task_exited(self, task: "Task") -> None: + def task_exited(self, task: "_run.Task") -> None: """Called when the given task exits. Args: @@ -310,7 +310,7 @@ class SendStream(AsyncResource): __slots__ = () @abstractmethod - async def send_all(self, data: Union[bytes, memoryview]) -> None: + async def send_all(self, data: Union[bytes, bytearray, memoryview]) -> None: """Sends the given data through the stream, blocking if necessary. Args: diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index a6d09fafbc..7b85cc837c 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -32,7 +32,7 @@ assert not TYPE_CHECKING or sys.platform == 'linux' -async def wait_readable(fd: Union[int, _HasFileno]) ->None: +async def wait_readable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 21190a47be..a52a5ec41f 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -58,7 +58,7 @@ async def wait_kevent(ident: int, filter: int, abort_func: Callable[[ raise RuntimeError("must be called from async context") -async def wait_readable(fd: Union[int, _HasFileno]) ->None: +async def wait_readable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index fb4ef94936..6b1cd622cd 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -32,7 +32,7 @@ assert not TYPE_CHECKING or sys.platform == 'win32' -async def wait_readable(sock: int) ->None: +async def wait_readable(sock: Union[int, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 55794d23c4..46e1e81bd4 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -301,7 +301,7 @@ def abort(_): @_public async def wait_readable( self, - fd: Union[int, _HasFileno], + fd: Union[int, _HasFileno, socket.socket], ) -> None: await self._epoll_wait(fd, "read_task") diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 66a1847c5e..e5bc31b3ae 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -175,7 +175,7 @@ def abort(_): await self.wait_kevent(fd, filter, abort) @_public - async def wait_readable(self, fd: Union[int, _HasFileno]) -> None: + async def wait_readable(self, fd: Union[int, _HasFileno, socket.socket]) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 6cb2656e17..275d8699b7 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -697,7 +697,7 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) @_public - async def wait_readable(self, sock: int) -> None: + async def wait_readable(self, sock: Union[int, socket.socket]) -> None: await self._afd_poll(sock, "read_task") @_public diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index a292b9d2fa..34de0b9513 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -2,7 +2,8 @@ import signal import sys from functools import wraps -from typing import Any, TypeVar, Callable +from types import FrameType +from typing import Any, TypeVar, Callable, Optional, Union import attr import async_generator @@ -83,17 +84,18 @@ # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: -def ki_protection_enabled(frame): - while frame is not None: - if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals: - return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] - if frame.f_code.co_name == "__del__": +def ki_protection_enabled(frame: FrameType) -> bool: + traversed_frame: Optional[FrameType] = frame + while traversed_frame is not None: + if LOCALS_KEY_KI_PROTECTION_ENABLED in traversed_frame.f_locals: + return traversed_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] # type: ignore[no-any-return] + if traversed_frame.f_code.co_name == "__del__": return True - frame = frame.f_back + traversed_frame = traversed_frame.f_back return True -def currently_ki_protected(): +def currently_ki_protected() -> bool: r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection enabled. @@ -174,9 +176,15 @@ def wrapper(*args: object, **kwargs: object) -> object: @attr.s class KIManager: - handler = attr.ib(default=None) - - def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): + handler: Optional[ + Callable[[Union[int, signal.Signals], FrameType], object] + ] = attr.ib(default=None) + + def install( + self, + deliver_cb: Callable[[], object], + restrict_keyboard_interrupt_to_checkpoints: bool, + ) -> None: assert self.handler is None if ( not is_main_thread() @@ -184,7 +192,7 @@ def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): ): return - def handler(signum, frame): + def handler(signum: Union[int, signal.Signals], frame: FrameType) -> None: assert signum == signal.SIGINT protection_enabled = ki_protection_enabled(frame) if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index ecbb773962..6fb94b437b 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -10,6 +10,7 @@ Optional, overload, Set, + Sequence, Type, Union, ) @@ -215,7 +216,7 @@ class MultiError(BaseException): """ - def __init__(self, exceptions: List[BaseException]) -> None: + def __init__(self, exceptions: Sequence[BaseException]) -> None: # Avoid recursion when exceptions[0] returned by __new__() happens # to be a MultiError and subsequently __init__() is called. if hasattr(self, "exceptions"): diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index b2130417b2..84582f75db 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -84,7 +84,7 @@ @attr.s(frozen=True) class _ParkingLotStatistics: - tasks_waiting = attr.ib() + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False) diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index f6eb7a9c52..8880cbd986 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -8,12 +8,13 @@ _T = TypeVar("_T") +_TSelf = TypeVar("_TSelf") @attr.s(frozen=True) class _UnboundedQueueStats: - qsize = attr.ib() - tasks_waiting = attr.ib() + qsize: int = attr.ib() + tasks_waiting: int = attr.ib() class UnboundedQueue(Generic[_T], metaclass=Final): @@ -61,11 +62,11 @@ def __init__(self) -> None: def __repr__(self) -> str: return "".format(len(self._data)) - def qsize(self): + def qsize(self) -> int: """Returns the number of items currently in the queue.""" return len(self._data) - def empty(self): + def empty(self) -> bool: """Returns True if the queue is empty, False otherwise. There is some subtlety to interpreting this method's return value: see @@ -93,13 +94,13 @@ def put_nowait(self, obj: _T) -> None: self._can_get = True self._data.append(obj) - def _get_batch_protected(self): + def _get_batch_protected(self) -> List[_T]: data = self._data.copy() self._data.clear() self._can_get = False return data - def get_batch_nowait(self): + def get_batch_nowait(self) -> List[_T]: """Attempt to get the next batch from the queue, without blocking. Returns: @@ -115,7 +116,7 @@ def get_batch_nowait(self): raise _core.WouldBlock return self._get_batch_protected() - async def get_batch(self): + async def get_batch(self) -> List[_T]: """Get the next batch from the queue, blocking as necessary. Returns: @@ -133,7 +134,7 @@ async def get_batch(self): finally: await _core.cancel_shielded_checkpoint() - def statistics(self): + def statistics(self) -> _UnboundedQueueStats: """Return an object containing debugging information. Currently the following fields are defined: @@ -147,8 +148,8 @@ def statistics(self): qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting ) - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): + async def __anext__(self) -> List[_T]: return await self.get_batch() diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index ab4d5973fd..77f2e6b8b9 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -1,13 +1,14 @@ import socket import sys import signal +from typing import Optional import warnings from .. import _core from .._util import is_main_thread -def _has_warn_on_full_buffer(): +def _has_warn_on_full_buffer() -> bool: if sys.version_info < (3, 7): return False @@ -51,7 +52,7 @@ def __init__(self) -> None: self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) except OSError: pass - self.old_wakeup_fd = None + self.old_wakeup_fd: Optional[int] = None def wakeup_thread_and_signal_safe(self) -> None: try: diff --git a/trio/_file_io.py b/trio/_file_io.py index 0d1eee61de..3c28d7b5e0 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -1,22 +1,23 @@ from functools import partial import io +import os from typing import ( Any, # AnyStr, # AsyncContextManager, AsyncIterator, # Awaitable, - # Callable, + Callable, # ContextManager, # FrozenSet, # Iterator, # Mapping, # NoReturn, Optional, - # Sequence, + Sequence, Union, # Sequence, - # TypeVar, + TypeVar, Tuple, List, Iterable, @@ -31,6 +32,8 @@ import trio +_TSelf = TypeVar("_TSelf") + # This list is also in the docs, make sure to keep them in sync _FILE_SYNC_ATTRS = { "closed", @@ -92,14 +95,14 @@ def wrapped(self) -> io.IOBase: return self._wrapped - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: if name in _FILE_SYNC_ATTRS: return getattr(self._wrapped, name) if name in _FILE_ASYNC_METHODS: meth = getattr(self._wrapped, name) @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): # type: ignore[misc] + async def wrapper(*args, **kwargs): # type: ignore[misc, no-untyped-def] func = partial(meth, *args, **kwargs) return await trio.to_thread.run_sync(func) @@ -109,23 +112,23 @@ async def wrapper(*args, **kwargs): # type: ignore[misc] raise AttributeError(name) - def __dir__(self): + def __dir__(self) -> Sequence[str]: attrs = set(super().__dir__()) attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a)) attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a)) - return attrs + return attrs # type: ignore[return-value] - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): - line = await self.readline() + async def __anext__(self) -> str: + line: str = await self.readline() # type: ignore[operator] if line: return line else: raise StopAsyncIteration - async def detach(self): + async def detach(self) -> "_AsyncIOBase": """Like :meth:`io.BufferedIOBase.detach`, but async. This also re-wraps the result in a new :term:`asynchronous file object` @@ -133,7 +136,7 @@ async def detach(self): """ - raw = await trio.to_thread.run_sync(self._wrapped.detach) + raw: Union[io.RawIOBase, BinaryIO] = await trio.to_thread.run_sync(self._wrapped.detach) # type: ignore[attr-defined] return wrap_file(raw) async def aclose(self) -> None: @@ -266,14 +269,14 @@ async def tell(self) -> int: async def open_file( - file, - mode="r", - buffering=-1, - encoding=None, - errors=None, - newline=None, - closefd=True, - opener=None, + file: Union[os.PathLike, int], + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + closefd: bool = True, + opener: Optional[Callable[[str, int], int]] = None, ): """Asynchronous version of :func:`io.open`. @@ -320,6 +323,7 @@ def wrap_file(obj: Union[IO[Any], io.IOBase]) -> _AsyncIOBase: ... +# def wrap_file(obj: Union[IO[Any], io.IOBase, io.RawIOBase, BinaryIO, io.BufferedIOBase, TextIO, io.TextIOBase]) -> _AsyncIOBase: def wrap_file(file): """This wraps any file object in a wrapper that provides an asynchronous file object interface. @@ -338,7 +342,7 @@ def wrap_file(file): """ - def has(attr): + def has(attr: str) -> bool: return hasattr(file, attr) and callable(getattr(file, attr)) if not (has("close") and (has("read") or has("write"))): diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index c19e4d0276..ff8db28b66 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,12 +1,14 @@ +from typing import Optional, Union + import attr import trio -from .abc import HalfCloseableStream +from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream from trio._util import Final -async def aclose_forcefully(resource): +async def aclose_forcefully(resource: AsyncResource) -> None: """Close an async resource or async generator immediately, without blocking to do any graceful cleanup. @@ -72,18 +74,18 @@ class StapledStream(HalfCloseableStream, metaclass=Final): """ - send_stream = attr.ib() - receive_stream = attr.ib() + send_stream: SendStream = attr.ib() + receive_stream: ReceiveStream = attr.ib() - async def send_all(self, data): + async def send_all(self, data: Union[bytes, bytearray, memoryview]) -> None: """Calls ``self.send_stream.send_all``.""" return await self.send_stream.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls ``self.send_stream.wait_send_all_might_not_block``.""" return await self.send_stream.wait_send_all_might_not_block() - async def send_eof(self): + async def send_eof(self) -> None: """Shuts down the send side of the stream. If ``self.send_stream.send_eof`` exists, then calls it. Otherwise, @@ -91,11 +93,11 @@ async def send_eof(self): """ if hasattr(self.send_stream, "send_eof"): - return await self.send_stream.send_eof() + return await self.send_stream.send_eof() # type: ignore[no-any-return, attr-defined] else: return await self.send_stream.aclose() - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 172468b4de..e2d7911894 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -1,8 +1,10 @@ from contextlib import contextmanager -from typing import Iterator, Set +from typing import Iterator, Optional, Sequence, Set, Union import trio from trio.socket import getaddrinfo, SOCK_STREAM, socket, SocketType +from trio._socket import _Address, _AddressInfo + # Implementation of RFC 6555 "Happy eyeballs" # https://tools.ietf.org/html/rfc6555 @@ -119,7 +121,7 @@ def close_all() -> Iterator[Set[SocketType]]: raise trio.MultiError(errs) -def reorder_for_rfc_6555_section_5_4(targets): +def reorder_for_rfc_6555_section_5_4(targets: _AddressInfo) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address # families (e.g. IPv4 and IPv6), then you should make sure that your first # and second attempts use different families: @@ -137,7 +139,7 @@ def reorder_for_rfc_6555_section_5_4(targets): break -def format_host_port(host, port): +def format_host_port(host: Union[bytes, str], port: int) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return "[{}]:{}".format(host, port) @@ -166,8 +168,12 @@ def format_host_port(host, port): # AF_INET6: "..."} # this might be simpler after async def open_tcp_stream( - host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None -): + host: Union[bytes, str], + port: int, + *, + happy_eyeballs_delay: float = DEFAULT_DELAY, + local_address: Optional[str] = None, +) -> trio.SocketStream: """Connect to the given host and port over TCP. If the given ``host`` has multiple IP addresses associated with it, then @@ -276,7 +282,9 @@ async def open_tcp_stream( # the next connection attempt to start early # code needs to ensure sockets can be closed appropriately in the # face of crash or cancellation - async def attempt_connect(socket_args, sockaddr, attempt_failed): + async def attempt_connect( + socket_args: Sequence, sockaddr: _Address, attempt_failed: trio.Event + ) -> None: nonlocal winning_socket try: diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index e85ba058d7..47be214c15 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from typing import Iterator, TypeVar +from typing import Iterator, TypeVar, Union from typing_extensions import Protocol import trio @@ -31,7 +31,7 @@ def close_on_error(obj: _CL) -> Iterator[_CL]: raise -async def open_unix_socket(filename): +async def open_unix_socket(filename: Union[bytes, str]) -> trio.SocketStream: """Opens a connection to the specified `Unix domain socket `__. diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 548bb002bf..dee8b00c3e 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -2,7 +2,7 @@ import errno from contextlib import contextmanager -from typing import Iterator +from typing import Iterator, Optional, overload, Union import trio from . import socket as tsocket @@ -60,7 +60,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket) -> None: + def __init__(self, socket: tsocket.SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -94,7 +94,7 @@ def __init__(self, socket) -> None: except OSError: pass - async def send_all(self, data): + async def send_all(self, data: bytes) -> None: if self.socket.did_shutdown_SHUT_WR: raise trio.ClosedResourceError("can't send data after sending EOF") with self._send_conflict_detector: @@ -128,7 +128,7 @@ async def send_eof(self) -> None: with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: Optional[int] =None) -> bytes: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: @@ -142,7 +142,7 @@ async def aclose(self) -> None: # __aenter__, __aexit__ inherited from HalfCloseableStream are OK - def setsockopt(self, level, option, value): + def setsockopt(self, level: int, option: int, value: Union[bytes, int]) -> None: """Set an option on the underlying socket. See :meth:`socket.socket.setsockopt` for details. @@ -150,7 +150,15 @@ def setsockopt(self, level, option, value): """ return self.socket.setsockopt(level, option, value) - def getsockopt(self, level, option, buffersize=0): + @overload + def getsockopt(self, level: int, option: int) -> int: + ... + + @overload + def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: + ... + + def getsockopt(self, level: int , option: int, buffersize: int = 0) -> Union[int, bytes]: """Check the current value of an option on the underlying socket. See :meth:`socket.socket.getsockopt` for details. @@ -333,7 +341,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket) -> None: + def __init__(self, socket: tsocket.SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -349,7 +357,7 @@ def __init__(self, socket) -> None: self.socket = socket - async def accept(self): + async def accept(self) -> SocketStream: """Accept an incoming connection. Returns: diff --git a/trio/_path.py b/trio/_path.py index 55d0bed8b6..fac5e4f65c 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -8,6 +8,7 @@ Iterator, Optional, overload, + Sequence, Type, TYPE_CHECKING, TypeVar, @@ -211,7 +212,7 @@ def __getattr__(self, name): return rewrap_path(value) raise AttributeError(name) - def __dir__(self): + def __dir__(self) -> Sequence[str]: return super().__dir__() + self._forward def __repr__(self) -> str: diff --git a/trio/_signals.py b/trio/_signals.py index 48169215d0..88e125edb1 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -2,7 +2,17 @@ from contextlib import contextmanager from collections import OrderedDict from types import FrameType -from typing import Any, Callable, Iterable, Iterator, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Iterable, + Iterator, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import trio from ._util import signal_raise, is_main_thread, ConflictDetector @@ -10,6 +20,7 @@ # https://github.com/python/typeshed/blob/master/stdlib/3/signal.pyi#L82-L83 _SignalNumber = Union[int, signal.Signals] _Handler = Union[Callable[[signal.Signals, FrameType], Any], int, signal.Handlers, None] +_TSelf = TypeVar("_TSelf") # Discussion of signal handling strategies: @@ -65,14 +76,14 @@ def _signal_handler( class SignalReceiver: def __init__(self) -> None: # {signal num: None} - self._pending = OrderedDict() + self._pending: "OrderedDict[_SignalNumber, None]" = OrderedDict() self._lot = trio.lowlevel.ParkingLot() self._conflict_detector = ConflictDetector( "only one task can iterate on a signal receiver at a time" ) self._closed = False - def _add(self, signum): + def _add(self, signum: _SignalNumber) -> None: if self._closed: signal_raise(signum) else: @@ -98,13 +109,13 @@ def deliver_next() -> None: deliver_next() # Helper for tests, not public or otherwise used - def _pending_signal_count(self): + def _pending_signal_count(self) -> int: return len(self._pending) - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): + async def __anext__(self) -> _SignalNumber: if self._closed: raise RuntimeError("open_signal_receiver block already exited") # In principle it would be possible to support multiple concurrent @@ -166,7 +177,7 @@ def open_signal_receiver(*signals: _SignalNumber) -> Iterator[SignalReceiver]: token = trio.lowlevel.current_trio_token() queue = SignalReceiver() - def handler(signum, _): + def handler(signum: _SignalNumber, _: object) -> None: token.run_sync_soon(queue._add, signum, idempotent=True) try: diff --git a/trio/_socket.py b/trio/_socket.py index 7039c3ad4e..2fa98ac334 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -33,7 +33,15 @@ _Address = Union[tuple, str] - +_AddressInfo = List[ + Tuple[ + _stdlib_socket.AddressFamily, + _stdlib_socket.SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] +] # Usage: # @@ -168,15 +176,7 @@ async def getaddrinfo( type: int = 0, proto: int = 0, flags: int = 0, -) -> List[ - Tuple[ - _stdlib_socket.AddressFamily, - _stdlib_socket.SocketKind, - int, - str, - Union[Tuple[str, int], Tuple[str, int, int, int]], - ] -]: +) -> _AddressInfo: """Look up a numeric address given a name. Arguments and return values are identical to :func:`socket.getaddrinfo`, @@ -637,7 +637,7 @@ def __getattr__(self, name): # type: ignore return getattr(self._sock, name) raise AttributeError(name) - def __dir__(self) -> List[str]: + def __dir__(self) -> Sequence[str]: return [*super().__dir__(), *self._forward] def __enter__(self: _T) -> _T: diff --git a/trio/_ssl.py b/trio/_ssl.py index 313c9d0855..dcc9793be1 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -152,6 +152,7 @@ import operator as _operator import ssl as _stdlib_ssl from enum import Enum as _Enum +from typing import Sequence import trio @@ -406,7 +407,7 @@ def __setattr__(self, name, value): else: super().__setattr__(name, value) - def __dir__(self): + def __dir__(self) -> Sequence[str]: return super().__dir__() + list(self._forwarded) def _check_status(self) -> None: From 9ecbc1d71189f1e6e713898759bc9b7b7a2f24df Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 16 Feb 2021 16:17:10 -0500 Subject: [PATCH 50/50] black --- trio/_highlevel_socket.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index dee8b00c3e..0639f983e3 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -128,7 +128,7 @@ async def send_eof(self) -> None: with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes: Optional[int] =None) -> bytes: + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: @@ -158,7 +158,9 @@ def getsockopt(self, level: int, option: int) -> int: def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: ... - def getsockopt(self, level: int , option: int, buffersize: int = 0) -> Union[int, bytes]: + def getsockopt( + self, level: int, option: int, buffersize: int = 0 + ) -> Union[int, bytes]: """Check the current value of an option on the underlying socket. See :meth:`socket.socket.getsockopt` for details.