diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e0959c1797..c9ed826d17 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,12 +61,16 @@ jobs: matrix: python: ['pypy-3.6', 'pypy-3.7', '3.6', '3.7', '3.8', '3.9', '3.6-dev', '3.7-dev', '3.8-dev', '3.9-dev'] check_formatting: ['0'] + check_typing: ['0'] pypy_nightly_branch: [''] extra_name: [''] include: - python: '3.8' check_formatting: '1' extra_name: ', check formatting' + - python: '3.8' + check_typing: '1' + extra_name: ', check typing' - python: '3.7' # <- not actually used pypy_nightly_branch: 'py3.7' extra_name: ', pypy 3.7 nightly' @@ -88,6 +92,7 @@ jobs: env: PYPY_NIGHTLY_BRANCH: '${{ matrix.pypy_nightly_branch }}' CHECK_FORMATTING: '${{ matrix.check_formatting }}' + CHECK_TYPING: '${{ matrix.check_typing }}' # Should match 'name:' up above JOB_NAME: 'Ubuntu (${{ matrix.python }}${{ matrix.extra_name }})' diff --git a/check.sh b/check.sh index 57f1e2db40..c1e81986ae 100755 --- a/check.sh +++ b/check.sh @@ -23,11 +23,6 @@ flake8 trio/ \ --ignore=D,E,W,F401,F403,F405,F821,F822\ || EXIT_STATUS=$? -# Run mypy on all supported platforms -mypy -m trio -m trio.testing --platform linux || EXIT_STATUS=$? -mypy -m trio -m trio.testing --platform darwin || EXIT_STATUS=$? # tests FreeBSD too -mypy -m trio -m trio.testing --platform win32 || EXIT_STATUS=$? 
- # Finally, leave a really clear warning of any issues and exit if [ $EXIT_STATUS -ne 0 ]; then cat <= 1.9 idna outcome sniffio +typing-extensions # See note in test-requirements.in immutables >= 0.6 diff --git a/docs-requirements.txt b/docs-requirements.txt index 7d6220f8c0..1c1d2f5cea 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --output-file docs-requirements.txt docs-requirements.in +# pip-compile --output-file=docs-requirements.txt docs-requirements.in # alabaster==0.7.12 # via sphinx @@ -81,5 +81,10 @@ toml==0.10.2 # via towncrier towncrier==19.2.0 # via -r docs-requirements.in +typing-extensions==3.7.4.3 + # via -r docs-requirements.in urllib3==1.26.3 # via requests + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/docs/source/conf.py b/docs/source/conf.py index 6045ffd828..7d09b5900e 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,6 +54,18 @@ # https://github.com/sphinx-doc/sphinx/issues/7722 ("py:class", "SendType"), ("py:class", "ReceiveType"), + ("py:class", "_T_contra"), + ("py:class", "_T_co"), + ("py:class", "_T"), + ("py:class", "T_resource"), + ("py:class", "AbstractContextManager"), + ("py:class", "_socket.socket"), + ("py:class", "signal.Signals"), + ("py:class", "trio._signals.SignalReceiver"), + ("py:class", "socket.socket"), + ("py:class", "trio._core._run._RunStatistics"), + ("py:class", "socket.AddressFamily"), + ("py:class", "socket.SocketKind"), ] autodoc_inherit_docstrings = False default_role = "obj" diff --git a/mypy.ini b/mypy.ini index 31eeef1cd0..147dc059bc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,14 +12,14 @@ warn_redundant_casts = True warn_return_any = True # Avoid subtle backsliding -#disallow_any_decorated = True -#disallow_incomplete_defs = True -#disallow_subclassing_any = True +disallow_any_decorated = True +disallow_incomplete_defs = 
True +disallow_subclassing_any = True # Enable gradually / for new modules check_untyped_defs = False disallow_untyped_calls = False -disallow_untyped_defs = False +disallow_untyped_defs = True # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. diff --git a/setup.py b/setup.py index 11eda8e96f..5f48a4224f 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,7 @@ "idna", "outcome", "sniffio", + "typing-extensions", # cffi 1.12 adds from_buffer(require_writable=True) and ffi.release() # cffi 1.14 fixes memory leak inside ffi.getwinerror() # cffi is required on Windows, except on PyPy where it is built-in diff --git a/test-requirements.txt b/test-requirements.txt index 1a974170a6..81556807c5 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --output-file test-requirements.txt test-requirements.in +# pip-compile --output-file=test-requirements.txt test-requirements.in # appdirs==1.4.4 # via black @@ -139,3 +139,6 @@ wcwidth==0.2.5 # via prompt-toolkit wrapt==1.12.1 # via astroid + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/trio/_abc.py b/trio/_abc.py index b8e341fdaa..bac5135bab 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,8 +1,13 @@ # coding: utf-8 from abc import ABCMeta, abstractmethod -from typing import Generic, TypeVar +from typing import Generic, List, Optional, Text, Tuple, TYPE_CHECKING, TypeVar, Union +import socket import trio +from ._core import _run + +_T = TypeVar("_T") +_TSelf = TypeVar("_TSelf") # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a @@ -13,7 +18,7 @@ class Clock(metaclass=ABCMeta): __slots__ = () @abstractmethod - def start_clock(self): + def start_clock(self) -> None: """Do any setup this clock might need. Called at the beginning of the run. 
@@ -21,7 +26,7 @@ def start_clock(self): """ @abstractmethod - def current_time(self): + def current_time(self) -> float: """Return the current time, according to this clock. This is used to implement functions like :func:`trio.current_time` and @@ -33,7 +38,7 @@ def current_time(self): """ @abstractmethod - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: """Compute the real time until the given deadline. This is called before we enter a system-specific wait function like @@ -67,13 +72,13 @@ class Instrument(metaclass=ABCMeta): __slots__ = () - def before_run(self): + def before_run(self) -> None: """Called at the beginning of :func:`trio.run`.""" - def after_run(self): + def after_run(self) -> None: """Called just before :func:`trio.run` returns.""" - def task_spawned(self, task): + def task_spawned(self, task: "_run.Task") -> None: """Called when the given task is created. Args: @@ -81,7 +86,7 @@ def task_spawned(self, task): """ - def task_scheduled(self, task): + def task_scheduled(self, task: "_run.Task") -> None: """Called when the given task becomes runnable. It may still be some time before it actually runs, if there are other @@ -92,7 +97,7 @@ def task_scheduled(self, task): """ - def before_task_step(self, task): + def before_task_step(self, task: "_run.Task") -> None: """Called immediately before we resume running the given task. Args: @@ -100,7 +105,7 @@ def before_task_step(self, task): """ - def after_task_step(self, task): + def after_task_step(self, task: "_run.Task") -> None: """Called when we return to the main run loop after a task has yielded. Args: @@ -108,7 +113,7 @@ def after_task_step(self, task): """ - def task_exited(self, task): + def task_exited(self, task: "_run.Task") -> None: """Called when the given task exits. 
Args: @@ -116,7 +121,7 @@ def task_exited(self, task): """ - def before_io_wait(self, timeout): + def before_io_wait(self, timeout: float) -> None: """Called before blocking to wait for I/O readiness. Args: @@ -124,7 +129,7 @@ def before_io_wait(self, timeout): """ - def after_io_wait(self, timeout): + def after_io_wait(self, timeout: float) -> None: """Called after handling pending I/O. Args: @@ -146,7 +151,23 @@ class HostnameResolver(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): + async def getaddrinfo( + self, + host: Optional[Union[bytearray, bytes, Text]], + port: Union[str, int, None], + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> List[ + Tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] + ]: """A custom implementation of :func:`~trio.socket.getaddrinfo`. Called by :func:`trio.socket.getaddrinfo`. @@ -163,7 +184,11 @@ async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): """ @abstractmethod - async def getnameinfo(self, sockaddr, flags): + async def getnameinfo( + self, + sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], + flags: int, + ) -> Tuple[str, str]: """A custom implementation of :func:`~trio.socket.getnameinfo`. Called by :func:`trio.socket.getnameinfo`. @@ -180,7 +205,12 @@ class SocketFactory(metaclass=ABCMeta): """ @abstractmethod - def socket(self, family=None, type=None, proto=None): + def socket( + self, + family: Optional[int] = None, + type: Optional[int] = None, + proto: Optional[int] = None, + ) -> "trio.socket.SocketType": """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, @@ -226,7 +256,7 @@ class AsyncResource(metaclass=ABCMeta): __slots__ = () @abstractmethod - async def aclose(self): + async def aclose(self) -> None: """Close this resource, possibly blocking. 
IMPORTANT: This method may block in order to perform a "graceful" @@ -254,10 +284,10 @@ async def aclose(self): """ - async def __aenter__(self): + async def __aenter__(self: _T) -> _T: return self - async def __aexit__(self, *args): + async def __aexit__(self, *args: object) -> None: await self.aclose() @@ -280,7 +310,7 @@ class SendStream(AsyncResource): __slots__ = () @abstractmethod - async def send_all(self, data): + async def send_all(self, data: Union[bytes, bytearray, memoryview]) -> None: """Sends the given data through the stream, blocking if necessary. Args: @@ -306,7 +336,7 @@ async def send_all(self, data): """ @abstractmethod - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Block until it's possible that :meth:`send_all` might not block. This method may return early: it's possible that after it returns, @@ -386,7 +416,7 @@ class ReceiveStream(AsyncResource): __slots__ = () @abstractmethod - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: Optional[int] = ...) -> bytes: """Wait until there is data available on this stream, and then return some of it. @@ -414,10 +444,10 @@ async def receive_some(self, max_bytes=None): """ - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): + async def __anext__(self) -> bytes: data = await self.receive_some() if not data: raise StopAsyncIteration @@ -447,7 +477,7 @@ class HalfCloseableStream(Stream): __slots__ = () @abstractmethod - async def send_eof(self): + async def send_eof(self) -> None: """Send an end-of-file indication on this stream, if possible. The difference between :meth:`send_eof` and @@ -526,7 +556,7 @@ class Listener(AsyncResource, Generic[T_resource]): __slots__ = () @abstractmethod - async def accept(self): + async def accept(self) -> T_resource: """Wait until an incoming connection arrives, and then return it. 
Returns: @@ -633,7 +663,7 @@ async def receive(self) -> ReceiveType: """ - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self async def __anext__(self) -> ReceiveType: diff --git a/trio/_channel.py b/trio/_channel.py index 1cecc55621..b37a003d7e 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,18 +1,28 @@ from collections import deque, OrderedDict from math import inf +import sys +from typing import cast, Callable, Deque, Generic, Set, Tuple, TypeVar, Union import attr from outcome import Error, Value from .abc import SendChannel, ReceiveChannel, Channel from ._util import generic_function, NoPublicConstructor +from ._core._run import Task import trio from ._core import enable_ki_protection +_T_contra = TypeVar("_T_contra", contravariant=True) +_T_co = TypeVar("_T_co", covariant=True) + + @generic_function -def open_memory_channel(max_buffer_size): +def open_memory_channel( # type: ignore[misc] + # TODO: should restrict the float bit to just the inf value + max_buffer_size: Union[int, float], +) -> Tuple["MemorySendChannel", "MemoryReceiveChannel"]: """Open a channel for passing objects between tasks within a process. Memory channels are lightweight, cheap to allocate, and entirely @@ -68,7 +78,7 @@ def open_memory_channel(max_buffer_size): raise TypeError("max_buffer_size must be an integer or math.inf") if max_buffer_size < 0: raise ValueError("max_buffer_size must be >= 0") - state = MemoryChannelState(max_buffer_size) + state = MemoryChannelState(max_buffer_size) # type: ignore[var-annotated] return ( MemorySendChannel._create(state), MemoryReceiveChannel._create(state), @@ -85,17 +95,19 @@ class MemoryChannelStats: tasks_waiting_receive = attr.ib() -@attr.s(slots=True) -class MemoryChannelState: - max_buffer_size = attr.ib() - data = attr.ib(factory=deque) +# TODO: ick... how to handle 3.6? 
+# https://github.com/python-attrs/attrs/issues/313 +@attr.s(slots=sys.version_info >= (3, 7)) +class MemoryChannelState(Generic[_T_contra]): + max_buffer_size: float = attr.ib() + data: Deque[_T_contra] = attr.ib(factory=deque) # Counts of open endpoints using this state - open_send_channels = attr.ib(default=0) - open_receive_channels = attr.ib(default=0) + open_send_channels: int = attr.ib(default=0) + open_receive_channels: int = attr.ib(default=0) # {task: value} - send_tasks = attr.ib(factory=OrderedDict) + send_tasks: "OrderedDict[Task, _T_contra]" = attr.ib(factory=OrderedDict) # {task: None} - receive_tasks = attr.ib(factory=OrderedDict) + receive_tasks: "OrderedDict[Task, None]" = attr.ib(factory=OrderedDict) def statistics(self): return MemoryChannelStats( @@ -109,18 +121,18 @@ def statistics(self): @attr.s(eq=False, repr=False) -class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) +class MemorySendChannel(SendChannel[_T_contra], metaclass=NoPublicConstructor): + _state: MemoryChannelState[_T_contra] = attr.ib() + _closed: bool = attr.ib(default=False) # This is just the tasks waiting on *this* object. As compared to # self._state.send_tasks, which includes tasks from this object and # all clones. - _tasks = attr.ib(factory=set) + _tasks: Set[Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_send_channels += 1 - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), id(self._state) ) @@ -130,7 +142,7 @@ def statistics(self): return self._state.statistics() @enable_ki_protection - def send_nowait(self, value): + def send_nowait(self, value: _T_contra) -> None: """Like `~trio.abc.SendChannel.send`, but if the channel's buffer is full, raises `WouldBlock` instead of blocking. 
@@ -150,7 +162,7 @@ def send_nowait(self, value): raise trio.WouldBlock @enable_ki_protection - async def send(self, value): + async def send(self, value: _T_contra) -> None: """See `SendChannel.send `. Memory channels allow multiple tasks to call `send` at the same time. @@ -178,7 +190,7 @@ def abort_fn(_): await trio.lowlevel.wait_task_rescheduled(abort_fn) @enable_ki_protection - def clone(self): + def clone(self) -> "MemorySendChannel[_T_contra]": """Clone this send channel object. This returns a new `MemorySendChannel` object, which acts as a @@ -213,7 +225,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() @enable_ki_protection - def close(self): + def close(self) -> None: """Close this send channel object synchronously. All channel objects have an asynchronous `~.AsyncResource.aclose` method. @@ -241,30 +253,30 @@ def close(self): self._state.receive_tasks.clear() @enable_ki_protection - async def aclose(self): + async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() @attr.s(eq=False, repr=False) -class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) - _tasks = attr.ib(factory=set) +class MemoryReceiveChannel(ReceiveChannel[_T_co], metaclass=NoPublicConstructor): + _state: MemoryChannelState[_T_co] = attr.ib() + _closed: bool = attr.ib(default=False) + _tasks: Set[Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_receive_channels += 1 def statistics(self): return self._state.statistics() - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), id(self._state) ) @enable_ki_protection - def receive_nowait(self): + def receive_nowait(self) -> _T_co: """Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing ready to receive, raises `WouldBlock` instead of blocking. 
@@ -284,7 +296,7 @@ def receive_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def receive(self): + async def receive(self) -> _T_co: """See `ReceiveChannel.receive `. Memory channels allow multiple tasks to call `receive` at the same @@ -311,10 +323,10 @@ def abort_fn(_): del self._state.receive_tasks[task] return trio.lowlevel.Abort.SUCCEEDED - return await trio.lowlevel.wait_task_rescheduled(abort_fn) + return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[return-value] @enable_ki_protection - def clone(self): + def clone(self) -> "MemoryReceiveChannel[_T_co]": """Clone this receive channel object. This returns a new `MemoryReceiveChannel` object, which acts as a @@ -352,7 +364,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() @enable_ki_protection - def close(self): + def close(self) -> None: """Close this receive channel object synchronously. All channel objects have an asynchronous `~.AsyncResource.aclose` method. @@ -381,6 +393,6 @@ def close(self): self._state.data.clear() @enable_ki_protection - async def aclose(self): + async def aclose(self) -> None: self.close() await trio.lowlevel.checkpoint() diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py index cb88a1d57b..284dc24e54 100644 --- a/trio/_core/_asyncgens.py +++ b/trio/_core/_asyncgens.py @@ -173,7 +173,7 @@ async def finalize_remaining(self, runner): for agen in batch: await self._finalize_one(agen, name_asyncgen(agen)) - def close(self): + def close(self) -> None: sys.set_asyncgen_hooks(*self.prev_hooks) async def _finalize_one(self, agen, name): diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py index a1587a18cd..c80f6ba7df 100644 --- a/trio/_core/_entry_queue.py +++ b/trio/_core/_entry_queue.py @@ -70,7 +70,7 @@ async def kill_everything(exc): # This has to be carefully written to be safe in the face of new items # being queued while we iterate, and to do a bounded amount of work on # each pass: - def 
run_all_bounded(): + def run_all_bounded() -> None: for _ in range(len(self.queue)): run_cb(self.queue.popleft()) for job in list(self.idempotent_queue): @@ -104,7 +104,7 @@ def run_all_bounded(): assert not self.queue assert not self.idempotent_queue - def close(self): + def close(self) -> None: self.wakeup.close() def size(self): @@ -147,7 +147,7 @@ class TrioToken(metaclass=NoPublicConstructor): __slots__ = ("_reentry_queue",) - def __init__(self, reentry_queue): + def __init__(self, reentry_queue) -> None: self._reentry_queue = reentry_queue def run_sync_soon(self, sync_fn, *args, idempotent=False): diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py index 6189c484b4..3d5c1d831d 100644 --- a/trio/_core/_exceptions.py +++ b/trio/_core/_exceptions.py @@ -61,7 +61,7 @@ class Cancelled(BaseException, metaclass=NoPublicConstructor): """ - def __str__(self): + def __str__(self) -> str: return "Cancelled" diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py index 986ab2c7f5..dc0495fae7 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/trio/_core/_generated_instrumentation.py @@ -1,10 +1,31 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +import select +import socket +import sys +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from .. 
import _core +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index 9ae54e4f68..7b85cc837c 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -1,14 +1,38 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +import select +import socket +import sys +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from .. 
import _core +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off -async def wait_readable(fd): +assert not TYPE_CHECKING or sys.platform == 'linux' + + +async def wait_readable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -16,7 +40,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -24,7 +48,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 7549899dbe..a52a5ec41f 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -1,14 +1,38 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +import select +import socket +import sys +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from .. 
import _core +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off -def current_kqueue(): +assert not TYPE_CHECKING or sys.platform != 'linux' and sys.platform != 'win32' + + +def current_kqueue() ->select.kqueue: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() @@ -16,7 +40,8 @@ def current_kqueue(): raise RuntimeError("must be called from async context") -def monitor_kevent(ident, filter): +def monitor_kevent(ident: int, filter: int) ->ContextManager[ + '_core.UnboundedQueue[Task]']: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) @@ -24,7 +49,8 @@ def monitor_kevent(ident, filter): raise RuntimeError("must be called from async context") -async def wait_kevent(ident, filter, abort_func): +async def wait_kevent(ident: int, filter: int, abort_func: Callable[[ + Callable[[], None]], '_core.Abort']) ->object: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func) @@ -32,7 +58,7 @@ async def wait_kevent(ident, filter, abort_func): raise RuntimeError("must be called from async context") -async def wait_readable(fd): +async def wait_readable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -40,7 +66,7 @@ async def wait_readable(fd): raise RuntimeError("must be called from async context") -async def wait_writable(fd): +async def wait_writable(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await 
GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -48,7 +74,7 @@ async def wait_writable(fd): raise RuntimeError("must be called from async context") -def notify_closing(fd): +def notify_closing(fd: Union[int, _HasFileno, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index e6337e94b0..6b1cd622cd 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -1,14 +1,38 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +import select +import socket +import sys +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from .. 
import _core +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off -async def wait_readable(sock): +assert not TYPE_CHECKING or sys.platform == 'win32' + + +async def wait_readable(sock: Union[int, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) @@ -16,7 +40,7 @@ async def wait_readable(sock): raise RuntimeError("must be called from async context") -async def wait_writable(sock): +async def wait_writable(sock: Union[int, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) @@ -24,7 +48,7 @@ async def wait_writable(sock): raise RuntimeError("must be called from async context") -def notify_closing(handle): +def notify_closing(handle: Union[int, socket.socket]) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) @@ -32,7 +56,7 @@ def notify_closing(handle): raise RuntimeError("must be called from async context") -def register_with_iocp(handle): +def register_with_iocp(handle: int) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) @@ -40,7 +64,8 @@ def register_with_iocp(handle): raise RuntimeError("must be called from async context") -async def wait_overlapped(handle, lpOverlapped): +async def wait_overlapped(handle: socket.socket, lpOverlapped: Union[int, + object]) ->object: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped) @@ -48,7 +73,7 @@ async def wait_overlapped(handle, lpOverlapped): raise 
RuntimeError("must be called from async context") -async def write_overlapped(handle, data, file_offset=0): +async def write_overlapped(handle: int, data: bytes, file_offset: int=0) ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset) @@ -56,7 +81,8 @@ async def write_overlapped(handle, data, file_offset=0): raise RuntimeError("must be called from async context") -async def readinto_overlapped(handle, buffer, file_offset=0): +async def readinto_overlapped(handle: int, buffer: bytearray, file_offset: + int=0) ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset) @@ -64,7 +90,7 @@ async def readinto_overlapped(handle, buffer, file_offset=0): raise RuntimeError("must be called from async context") -def current_iocp(): +def current_iocp() ->int: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() @@ -72,7 +98,8 @@ def current_iocp(): raise RuntimeError("must be called from async context") -def monitor_completion_key(): +def monitor_completion_key() ->ContextManager[Tuple[int, + '_core.UnboundedQueue[CompletionKeyEventInfo]']]: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 1272b4c73c..8789f52b91 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -1,14 +1,35 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! 
ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +import select +import socket +import sys +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from .. import _core +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off -def current_statistics(): +def current_statistics() ->_RunStatistics: """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -38,7 +59,7 @@ def current_statistics(): raise RuntimeError("must be called from async context") -def current_time(): +def current_time() ->float: """Returns the current time according to Trio's internal clock. Returns: @@ -55,7 +76,7 @@ def current_time(): raise RuntimeError("must be called from async context") -def current_clock(): +def current_clock() ->Clock: """Returns the current :class:`~trio.abc.Clock`.""" locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: @@ -64,7 +85,7 @@ def current_clock(): raise RuntimeError("must be called from async context") -def current_root_task(): +def current_root_task() ->Optional[Task]: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -77,7 +98,7 @@ def current_root_task(): raise RuntimeError("must be called from async context") -def reschedule(task, next_send=_NO_SEND): +def reschedule(task: Task, next_send: object=_NO_SEND) ->None: """Reschedule the given task with the given :class:`outcome.Outcome`. 
@@ -102,7 +123,8 @@ def reschedule(task, next_send=_NO_SEND): raise RuntimeError("must be called from async context") -def spawn_system_task(async_fn, *args, name=None): +def spawn_system_task(async_fn: Callable[..., Awaitable[object]], *args: + object, name: Optional[Union[str, Callable]]=None) ->Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -157,7 +179,7 @@ def spawn_system_task(async_fn, *args, name=None): raise RuntimeError("must be called from async context") -def current_trio_token(): +def current_trio_token() ->TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -169,7 +191,7 @@ def current_trio_token(): raise RuntimeError("must be called from async context") -async def wait_all_tasks_blocked(cushion=0.0): +async def wait_all_tasks_blocked(cushion: float=0.0) ->None: """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -203,7 +225,7 @@ async def lock_taker(lock): await lock.acquire() lock.release() - async def test_lock_fairness(): + async def test_lock_fairness() -> None: lock = trio.Lock() await lock.acquire() async with trio.open_nursery() as nursery: diff --git a/trio/_core/_instrumentation.py b/trio/_core/_instrumentation.py index e14c1ef1e0..ea990c0b95 100644 --- a/trio/_core/_instrumentation.py +++ b/trio/_core/_instrumentation.py @@ -29,7 +29,7 @@ class Instruments(Dict[str, Dict[Instrument, None]]): __slots__ = () - def __init__(self, incoming: Sequence[Instrument]): + def __init__(self, incoming: Sequence[Instrument]) -> None: self["_all"] = {} for instrument in incoming: self.add_instrument(instrument) diff --git a/trio/_core/_io_common.py b/trio/_core/_io_common.py index 9891849bc9..e9395dbf0b 100644 --- a/trio/_core/_io_common.py +++ b/trio/_core/_io_common.py @@ -1,10 +1,20 @@ import copy +from typing import Optional + import outcome from .. import _core +from . 
import _run + +from typing_extensions import Protocol + + +class Waiter(Protocol): + read_task: Optional[_run.Task] + write_task: Optional[_run.Task] # Utility function shared between _io_epoll and _io_windows -def wake_all(waiters, exc): +def wake_all(waiters: Waiter, exc: Exception) -> None: try: current_task = _core.current_task() except RuntimeError: diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index c1537cf53e..46e1e81bd4 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -1,10 +1,14 @@ import select +import socket import sys import attr from collections import defaultdict -from typing import Dict, TYPE_CHECKING +from typing import DefaultDict, Dict, TYPE_CHECKING, Union + +from typing_extensions import Protocol from .. import _core +from .._typing import _HasFileno from ._run import _public from ._io_common import wake_all from ._wakeup_socketpair import WakeupSocketpair @@ -186,15 +190,15 @@ class EpollWaiters: @attr.s(slots=True, eq=False, hash=False) class EpollIOManager: - _epoll = attr.ib(factory=select.epoll) + _epoll: select.epoll = attr.ib(factory=select.epoll) # {fd: EpollWaiters} - _registered = attr.ib( - factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters] + _registered: DefaultDict[int, EpollWaiters] = attr.ib( + factory=lambda: defaultdict(EpollWaiters) ) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: int = attr.ib(default=None) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() @@ -211,11 +215,11 @@ def statistics(self): tasks_waiting_write=tasks_waiting_write, ) - def close(self): + def close(self) -> None: self._epoll.close() self._force_wakeup.close() - def force_wakeup(self): + def 
force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() # Return value must be False-y IFF the timeout expired, NOT if any I/O @@ -295,15 +299,18 @@ def abort(_): await _core.wait_task_rescheduled(abort) @_public - async def wait_readable(self, fd): + async def wait_readable( + self, + fd: Union[int, _HasFileno, socket.socket], + ) -> None: await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: Union[int, _HasFileno, socket.socket]) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd): + def notify_closing(self, fd: Union[int, _HasFileno, socket.socket]) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 31940d5694..e5bc31b3ae 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -1,6 +1,7 @@ import select +import socket import sys -from typing import TYPE_CHECKING +from typing import Callable, Dict, Iterator, Optional, Tuple, TYPE_CHECKING, Union import outcome from contextlib import contextmanager @@ -8,7 +9,9 @@ import errno from .. 
import _core -from ._run import _public +from .._typing import _HasFileno +from ._run import _public, Task +from ._unbounded_queue import UnboundedQueue from ._wakeup_socketpair import WakeupSocketpair assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") @@ -23,13 +26,15 @@ class _KqueueStatistics: @attr.s(slots=True, eq=False) class KqueueIOManager: - _kqueue = attr.ib(factory=select.kqueue) + _kqueue: select.kqueue = attr.ib(factory=select.kqueue) # {(ident, filter): Task or UnboundedQueue} - _registered = attr.ib(factory=dict) - _force_wakeup = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd = attr.ib(default=None) + _registered: Dict[Tuple[int, int], Union[Task, "UnboundedQueue[Task]"]] = attr.ib( + factory=dict + ) + _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd: Optional[int] = attr.ib(default=None) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: force_wakeup_event = select.kevent( self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD ) @@ -46,11 +51,11 @@ def statistics(self): monitors += 1 return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors) - def close(self): + def close(self) -> None: self._kqueue.close() self._force_wakeup.close() - def force_wakeup(self): + def force_wakeup(self) -> None: self._force_wakeup.wakeup_thread_and_signal_safe() def get_events(self, timeout): @@ -96,18 +101,20 @@ def process_events(self, events): # be more ergonomic... 
@_public - def current_kqueue(self): + def current_kqueue(self) -> select.kqueue: return self._kqueue @contextmanager @_public - def monitor_kevent(self, ident, filter): + def monitor_kevent( + self, ident: int, filter: int + ) -> Iterator["_core.UnboundedQueue[Task]"]: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( "attempt to register multiple listeners for same ident/filter pair" ) - q = _core.UnboundedQueue() + q = _core.UnboundedQueue[Task]() self._registered[key] = q try: yield q @@ -115,7 +122,12 @@ def monitor_kevent(self, ident, filter): del self._registered[key] @_public - async def wait_kevent(self, ident, filter, abort_func): + async def wait_kevent( + self, + ident: int, + filter: int, + abort_func: Callable[[Callable[[], None]], "_core.Abort"], + ) -> object: key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( @@ -131,7 +143,7 @@ def abort(raise_cancel): return await _core.wait_task_rescheduled(abort) - async def _wait_common(self, fd, filter): + async def _wait_common(self, fd: Union[int, _HasFileno], filter: int) -> None: if not isinstance(fd, int): fd = fd.fileno() flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT @@ -163,15 +175,15 @@ def abort(_): await self.wait_kevent(fd, filter, abort) @_public - async def wait_readable(self, fd): + async def wait_readable(self, fd: Union[int, _HasFileno, socket.socket]) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public - async def wait_writable(self, fd): + async def wait_writable(self, fd: Union[int, _HasFileno, socket.socket]) -> None: await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public - def notify_closing(self, fd): + def notify_closing(self, fd: Union[int, _HasFileno, socket.socket]) -> None: if not isinstance(fd, int): fd = fd.fileno() @@ -182,7 +194,9 @@ def notify_closing(self, fd): if receiver is None: continue - if type(receiver) is _core.Task: + # if type(receiver) is _core.Task: + # TODO: is this unacceptably 
less specific? + if isinstance(receiver, _core.Task): event = select.kevent(fd, filter, select.KQ_EV_DELETE) self._kqueue.control([event], 0) exc = _core.ClosedResourceError("another task closed this fd") diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 6d3994499f..275d8699b7 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -3,7 +3,7 @@ import enum import socket import sys -from typing import TYPE_CHECKING +from typing import Iterator, Tuple, TYPE_CHECKING, Union import attr from outcome import Value @@ -387,7 +387,7 @@ class CompletionKeyEventInfo: class WindowsIOManager: - def __init__(self): + def __init__(self) -> None: # If this method raises an exception, then __del__ could run on a # half-initialized object. So we initialize everything that __del__ # touches to safe values up front, before we do anything that can @@ -450,7 +450,7 @@ def __init__(self): "netsh winsock show catalog" ) - def close(self): + def close(self) -> None: try: if self._iocp is not None: iocp = self._iocp @@ -461,7 +461,7 @@ def close(self): afd_handle = self._all_afd_handles.pop() _check(kernel32.CloseHandle(afd_handle)) - def __del__(self): + def __del__(self) -> None: self.close() def statistics(self): @@ -479,7 +479,7 @@ def statistics(self): completion_key_monitors=len(self._completion_key_queues), ) - def force_wakeup(self): + def force_wakeup(self) -> None: _check( kernel32.PostQueuedCompletionStatus( self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL @@ -697,15 +697,15 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) @_public - async def wait_readable(self, sock): + async def wait_readable(self, sock: Union[int, socket.socket]) -> None: await self._afd_poll(sock, "read_task") @_public - async def wait_writable(self, sock): + async def wait_writable(self, sock: Union[int, socket.socket]) -> None: await self._afd_poll(sock, "write_task") @_public - def notify_closing(self, handle): + def notify_closing(self, handle: 
Union[int, socket.socket]) -> None: handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: @@ -717,13 +717,17 @@ def notify_closing(self, handle): ################################################################ @_public - def register_with_iocp(self, handle): + def register_with_iocp(self, handle: int) -> None: self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) + # TODO: what else can lpOverlapped be? @_public - async def wait_overlapped(self, handle, lpOverlapped): + async def wait_overlapped( + self, handle: socket.socket, lpOverlapped: Union[int, object] + ) -> object: handle = _handle(handle) if isinstance(lpOverlapped, int): + # TODO: figure out how to hint this? lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) if lpOverlapped in self._overlapped_waiters: raise _core.BusyResourceError( @@ -769,11 +773,11 @@ def abort(raise_cancel_): return _core.Abort.FAILED info = await _core.wait_task_rescheduled(abort) - if lpOverlapped.Internal != 0: + if lpOverlapped.Internal != 0: # type: ignore[attr-defined] # the lpOverlapped reports the error as an NT status code, # which we must convert back to a Win32 error code before # it will produce the right sorts of exceptions - code = ntdll.RtlNtStatusToDosError(lpOverlapped.Internal) + code = ntdll.RtlNtStatusToDosError(lpOverlapped.Internal) # type: ignore[attr-defined] if code == ErrorCodes.ERROR_OPERATION_ABORTED: if raise_cancel is not None: raise_cancel() @@ -805,7 +809,9 @@ async def _perform_overlapped(self, handle, submit_fn): return lpOverlapped @_public - async def write_overlapped(self, handle, data, file_offset=0): + async def write_overlapped( + self, handle: int, data: bytes, file_offset: int = 0 + ) -> int: with ffi.from_buffer(data) as cbuf: def submit_write(lpOverlapped): @@ -825,10 +831,12 @@ def submit_write(lpOverlapped): lpOverlapped = await self._perform_overlapped(handle, submit_write) # this is "number of bytes transferred" - return 
lpOverlapped.InternalHigh + return lpOverlapped.InternalHigh # type: ignore[no-any-return] @_public - async def readinto_overlapped(self, handle, buffer, file_offset=0): + async def readinto_overlapped( + self, handle: int, buffer: bytearray, file_offset: int = 0 + ) -> int: with ffi.from_buffer(buffer, require_writable=True) as cbuf: def submit_read(lpOverlapped): @@ -846,21 +854,23 @@ def submit_read(lpOverlapped): ) lpOverlapped = await self._perform_overlapped(handle, submit_read) - return lpOverlapped.InternalHigh + return lpOverlapped.InternalHigh # type: ignore[no-any-return] ################################################################ # Raw IOCP operations ################################################################ @_public - def current_iocp(self): + def current_iocp(self) -> int: return int(ffi.cast("uintptr_t", self._iocp)) @contextmanager @_public - def monitor_completion_key(self): + def monitor_completion_key( + self, + ) -> Iterator[Tuple[int, "_core.UnboundedQueue[CompletionKeyEventInfo]"]]: key = next(self._completion_key_counter) - queue = _core.UnboundedQueue() + queue: "_core.UnboundedQueue[CompletionKeyEventInfo]" = _core.UnboundedQueue() self._completion_key_queues[key] = queue try: yield (key, queue) diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index 36aacecd96..34de0b9513 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -2,16 +2,17 @@ import signal import sys from functools import wraps +from types import FrameType +from typing import Any, TypeVar, Callable, Optional, Union import attr import async_generator from .._util import is_main_thread -if False: - from typing import Any, TypeVar, Callable - F = TypeVar("F", bound=Callable[..., Any]) +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. 
@@ -83,17 +84,18 @@ # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: -def ki_protection_enabled(frame): - while frame is not None: - if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals: - return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] - if frame.f_code.co_name == "__del__": +def ki_protection_enabled(frame: FrameType) -> bool: + traversed_frame: Optional[FrameType] = frame + while traversed_frame is not None: + if LOCALS_KEY_KI_PROTECTION_ENABLED in traversed_frame.f_locals: + return traversed_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] # type: ignore[no-any-return] + if traversed_frame.f_code.co_name == "__del__": return True - frame = frame.f_back + traversed_frame = traversed_frame.f_back return True -def currently_ki_protected(): +def currently_ki_protected() -> bool: r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection enabled. @@ -109,15 +111,17 @@ def currently_ki_protected(): return ki_protection_enabled(sys._getframe()) -def _ki_protection_decorator(enabled): - def decorator(fn): +def _ki_protection_decorator(enabled: bool) -> Callable[[_Fn], _Fn]: + def decorator(fn: _Fn) -> _Fn: # In some version of Python, isgeneratorfunction returns true for # coroutine functions, so we have to check for coroutine functions # first. + wrapper: _Fn + if inspect.iscoroutinefunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: # See the comment for regular generators below coro = fn(*args, **kwargs) coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled @@ -127,7 +131,7 @@ def wrapper(*args, **kwargs): elif inspect.isgeneratorfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: # It's important that we inject this directly into the # generator's locals, as opposed to setting it here and then # doing 'yield from'. 
The reason is, if a generator is @@ -144,7 +148,7 @@ def wrapper(*args, **kwargs): elif async_generator.isasyncgenfunction(fn): @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: # See the comment for regular generators above agen = fn(*args, **kwargs) agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled @@ -154,7 +158,7 @@ def wrapper(*args, **kwargs): else: @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return fn(*args, **kwargs) @@ -163,18 +167,24 @@ def wrapper(*args, **kwargs): return decorator -enable_ki_protection = _ki_protection_decorator(True) # type: Callable[[F], F] +enable_ki_protection: Callable[[_Fn], _Fn] = _ki_protection_decorator(True) enable_ki_protection.__name__ = "enable_ki_protection" -disable_ki_protection = _ki_protection_decorator(False) # type: Callable[[F], F] +disable_ki_protection: Callable[[_Fn], _Fn] = _ki_protection_decorator(False) disable_ki_protection.__name__ = "disable_ki_protection" @attr.s class KIManager: - handler = attr.ib(default=None) - - def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): + handler: Optional[ + Callable[[Union[int, signal.Signals], FrameType], object] + ] = attr.ib(default=None) + + def install( + self, + deliver_cb: Callable[[], object], + restrict_keyboard_interrupt_to_checkpoints: bool, + ) -> None: assert self.handler is None if ( not is_main_thread() @@ -182,7 +192,7 @@ def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): ): return - def handler(signum, frame): + def handler(signum: Union[int, signal.Signals], frame: FrameType) -> None: assert signum == signal.SIGINT protection_enabled = ki_protection_enabled(frame) if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: @@ -193,7 +203,7 @@ def handler(signum, frame): self.handler = handler signal.signal(signal.SIGINT, 
handler) - def close(self): + def close(self) -> None: if self.handler is not None: if signal.getsignal(signal.SIGINT) is self.handler: signal.signal(signal.SIGINT, signal.default_int_handler) diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 1f64d4ce85..b9a1974b8d 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,25 +1,59 @@ # Runvar implementations +from typing import Generic, overload, TypeVar, Union + from . import _run from .._util import Final -class _RunVarToken: - _no_value = object() +_T = TypeVar("_T") + + +class _NoValue: + pass + + +class _RunVarToken(Generic[_T]): + _no_value = _NoValue() __slots__ = ("_var", "previous_value", "redeemed") @classmethod - def empty(cls, var): + def empty(cls, var: "RunVar[_T]") -> "_RunVarToken[_T]": return cls(var, value=cls._no_value) - def __init__(self, var, value): + def __init__(self, var: "RunVar", value: Union[_NoValue, _T]) -> None: self._var = var self.previous_value = value self.redeemed = False -class RunVar(metaclass=Final): +class _NoDefault: + pass + + +# TODO: ack! this is... not pleasant. But otherwise we hit the exception below when +# testing in 3.6. Part of cleaning this up is undoing the skip in +# test_classes_are_final(). +# ImportError while loading conftest '/home/altendky/repos/trio/trio/tests/conftest.py'. +# trio/__init__.py:67: in +# from ._highlevel_socket import SocketStream, SocketListener +# trio/_highlevel_socket.py:8: in +# from . import socket as tsocket +# trio/socket.py:9: in +# from . 
import _socket +# trio/_socket.py:83: in +# _resolver = _core.RunVar[Optional[HostnameResolver]]("hostname_resolver") +# ../../.pyenv/versions/3.6.12/lib/python3.6/typing.py:682: in inner +# return func(*args, **kwds) +# ../../.pyenv/versions/3.6.12/lib/python3.6/typing.py:1143: in __getitem__ +# orig_bases=self.__orig_bases__) +# E TypeError: __new__() got an unexpected keyword argument 'tvars' +import sys +from .._util import BaseMeta + + +class RunVar(Generic[_T], metaclass=Final if sys.version_info >= (3, 7) else BaseMeta): # type: ignore[misc] """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, @@ -28,30 +62,43 @@ class RunVar(metaclass=Final): """ - _NO_DEFAULT = object() + _NO_DEFAULT = _NoDefault() __slots__ = ("_name", "_default") - def __init__(self, name, default=_NO_DEFAULT): + @overload + def __init__(self, name: str) -> None: + ... + + @overload + def __init__(self, name: str, default: _T) -> None: + ... + + def __init__(self, name: str, default: object = _NO_DEFAULT) -> None: self._name = name self._default = default - def get(self, default=_NO_DEFAULT): + def get(self, default: Union[_NoDefault, _T] = _NO_DEFAULT) -> _T: """Gets the value of this :class:`RunVar` for the current run call.""" + + # Ignoring type hint return complaints since the underlying dict can't really + # be hinted per run local and other options including casting and checking + # instance types would result in runtime overhead. 
+ try: - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] + return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: # contextvars consistency if default is not self._NO_DEFAULT: - return default + return default # type: ignore[return-value] if self._default is not self._NO_DEFAULT: - return self._default + return self._default # type: ignore[return-value] raise LookupError(self) from None - def set(self, value): + def set(self, value: _T) -> _RunVarToken[_T]: """Sets the value of this :class:`RunVar` for this current run call. @@ -59,16 +106,16 @@ def set(self, value): try: old_value = self.get() except LookupError: - token = _RunVarToken.empty(self) + token = _RunVarToken[_T].empty(self) else: - token = _RunVarToken(self, old_value) + token = _RunVarToken[_T](self, old_value) # This can't fail, because if we weren't in Trio context then the # get() above would have failed. _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value return token - def reset(self, token): + def reset(self, token: _RunVarToken) -> None: """Resets the value of this :class:`RunVar` to what it was previously specified by the token. @@ -93,5 +140,5 @@ def reset(self, token): token.redeemed = True - def __repr__(self): + def __repr__(self) -> str: return "".format(self._name) diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py index 0e95e4e5c5..c86afc6643 100644 --- a/trio/_core/_mock_clock.py +++ b/trio/_core/_mock_clock.py @@ -62,7 +62,7 @@ class MockClock(Clock, metaclass=Final): """ - def __init__(self, rate=0.0, autojump_threshold=inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf) -> None: # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. 
@@ -77,13 +77,13 @@ def __init__(self, rate=0.0, autojump_threshold=inf): self.rate = rate self.autojump_threshold = autojump_threshold - def __repr__(self): + def __repr__(self) -> str: return "".format( self.current_time(), self._rate, id(self) ) @property - def rate(self): + def rate(self) -> float: return self._rate @rate.setter @@ -98,7 +98,7 @@ def rate(self, new_rate): self._rate = float(new_rate) @property - def autojump_threshold(self): + def autojump_threshold(self) -> float: return self._autojump_threshold @autojump_threshold.setter @@ -112,7 +112,7 @@ def autojump_threshold(self, new_autojump_threshold): # API. Discussion: # # https://github.com/python-trio/trio/issues/1587 - def _try_resync_autojump_threshold(self): + def _try_resync_autojump_threshold(self) -> None: try: runner = GLOBAL_RUN_CONTEXT.runner if runner.is_guest: @@ -124,7 +124,7 @@ def _try_resync_autojump_threshold(self): # Invoked by the run loop when runner.clock_autojump_threshold is # exceeded. - def _autojump(self): + def _autojump(self) -> None: statistics = _core.current_statistics() jump = statistics.seconds_to_next_deadline if 0 < jump < inf: @@ -135,7 +135,7 @@ def _real_to_virtual(self, real): virtual_offset = self._rate * real_offset return self._virtual_base + virtual_offset - def start_clock(self): + def start_clock(self) -> None: self._try_resync_autojump_threshold() def current_time(self): diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index 13fc3f3d0f..6fb94b437b 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,6 +1,19 @@ import sys import traceback import textwrap +from types import TracebackType +from typing import ( + Callable, + ContextManager, + Iterator, + List, + Optional, + overload, + Set, + Sequence, + Type, + Union, +) import warnings import attr @@ -16,8 +29,23 @@ # MultiError ################################################################ +_Handler = Callable[[BaseException], Optional[BaseException]] -def 
_filter_impl(handler, root_exc): + +@overload +def _filter_impl(handler: _Handler, root_exc: "MultiError") -> Optional[BaseException]: + ... + + +@overload +def _filter_impl(handler: _Handler, root_exc: BaseException) -> Optional[BaseException]: + ... + + +def _filter_impl( + handler: _Handler, + root_exc: BaseException, +) -> Optional[BaseException]: # We have a tree of MultiError's, like: # # MultiError([ @@ -76,7 +104,7 @@ def _filter_impl(handler, root_exc): # Filters a subtree, ignoring tracebacks, while keeping a record of # which MultiErrors were preserved unchanged - def filter_tree(exc, preserved): + def filter_tree(exc: BaseException, preserved: Set[int]) -> Optional[BaseException]: if isinstance(exc, MultiError): new_exceptions = [] changed = False @@ -100,7 +128,11 @@ def filter_tree(exc, preserved): new_exc.__context__ = exc return new_exc - def push_tb_down(tb, exc, preserved): + def push_tb_down( + tb: Optional[TracebackType], + exc: BaseException, + preserved: Set[int], + ) -> None: if id(exc) in preserved: return new_tb = concat_tb(tb, exc.__traceback__) @@ -111,7 +143,7 @@ def push_tb_down(tb, exc, preserved): else: exc.__traceback__ = new_tb - preserved = set() + preserved: Set[int] = set() new_root_exc = filter_tree(root_exc, preserved) push_tb_down(None, root_exc, preserved) # Delete the local functions to avoid a reference cycle (see @@ -126,13 +158,18 @@ def push_tb_down(tb, exc, preserved): # result: if the exception gets modified, then the 'raise' here makes this # frame show up in the traceback; otherwise, we leave no trace.) 
@attr.s(frozen=True) -class MultiErrorCatcher: - _handler = attr.ib() +class MultiErrorCatcher(ContextManager[None]): + _handler: _Handler = attr.ib() - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, etype, exc, tb): + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: if exc is not None: filtered_exc = MultiError.filter(self._handler, exc) @@ -153,6 +190,7 @@ def __exit__(self, etype, exc, tb): _, value, _ = sys.exc_info() assert value is filtered_exc value.__context__ = old_context + return None class MultiError(BaseException): @@ -178,7 +216,7 @@ class MultiError(BaseException): """ - def __init__(self, exceptions): + def __init__(self, exceptions: Sequence[BaseException]) -> None: # Avoid recursion when exceptions[0] returned by __new__() happens # to be a MultiError and subsequently __init__() is called. if hasattr(self, "exceptions"): @@ -187,7 +225,9 @@ def __init__(self, exceptions): return self.exceptions = exceptions - def __new__(cls, exceptions): + def __new__( # type: ignore[misc] + cls, exceptions: List[Union[Exception, "MultiError"]] + ) -> Union[Exception, "MultiError"]: exceptions = list(exceptions) for exc in exceptions: if not isinstance(exc, BaseException): @@ -204,16 +244,20 @@ def __new__(cls, exceptions): # In an earlier version of the code, we didn't define __init__ and # simply set the `exceptions` attribute directly on the new object. # However, linters expect attributes to be initialized in __init__. 
- return BaseException.__new__(cls, exceptions) + return BaseException.__new__(cls, exceptions) # type: ignore[no-any-return, call-arg] - def __str__(self): + def __str__(self) -> str: return ", ".join(repr(exc) for exc in self.exceptions) - def __repr__(self): + def __repr__(self) -> str: return "".format(self) @classmethod - def filter(cls, handler, root_exc): + def filter( + cls, + handler: _Handler, + root_exc: BaseException, + ) -> Optional[BaseException]: """Apply the given ``handler`` to all the exceptions in ``root_exc``. Args: @@ -232,7 +276,7 @@ def filter(cls, handler, root_exc): return _filter_impl(handler, root_exc) @classmethod - def catch(cls, handler): + def catch(cls, handler: _Handler) -> ContextManager[None]: """Return a context manager that catches and re-throws exceptions after running :meth:`filter` on them. @@ -373,21 +417,21 @@ def concat_tb(head, tail): def traceback_exception_init( - self, - exc_type, - exc_value, - exc_traceback, + self: traceback.TracebackException, + exc_type: Type[BaseException], + exc_value: BaseException, + exc_traceback: TracebackType, *, - limit=None, - lookup_lines=True, - capture_locals=False, - _seen=None, -): + limit: Optional[int] = None, + lookup_lines: bool = True, + capture_locals: bool = False, + _seen: Optional[set] = None, +) -> None: if _seen is None: _seen = set() # Capture the original exception and its cause and context as TracebackExceptions - traceback_exception_original_init( + traceback_exception_original_init( # type: ignore[call-arg] self, exc_type, exc_value, @@ -404,7 +448,7 @@ def traceback_exception_init( for exc in exc_value.exceptions: if exc_key(exc) not in _seen: embedded.append( - traceback.TracebackException.from_exception( + traceback.TracebackException.from_exception( # type: ignore[call-arg] exc, limit=limit, lookup_lines=lookup_lines, @@ -414,27 +458,31 @@ def traceback_exception_init( _seen=set(_seen), ) ) - self.embedded = embedded + self.embedded = embedded # type: 
ignore[attr-defined] else: - self.embedded = [] + self.embedded = [] # type: ignore[attr-defined] -traceback.TracebackException.__init__ = traceback_exception_init # type: ignore +traceback.TracebackException.__init__ = traceback_exception_init # type: ignore[assignment] traceback_exception_original_format = traceback.TracebackException.format -def traceback_exception_format(self, *, chain=True): +def traceback_exception_format( + self: traceback.TracebackException, *, chain: bool = True +) -> Iterator[str]: yield from traceback_exception_original_format(self, chain=chain) - for i, exc in enumerate(self.embedded): + for i, exc in enumerate(self.embedded): # type: ignore[attr-defined] yield "\nDetails of embedded exception {}:\n\n".format(i + 1) yield from (textwrap.indent(line, " " * 2) for line in exc.format(chain=chain)) -traceback.TracebackException.format = traceback_exception_format # type: ignore +traceback.TracebackException.format = traceback_exception_format # type: ignore[assignment] -def trio_excepthook(etype, value, tb): +def trio_excepthook( + etype: Type[BaseException], value: BaseException, tb: TracebackType +) -> None: for chunk in traceback.format_exception(etype, value, tb): sys.stderr.write(chunk) @@ -457,7 +505,13 @@ def trio_excepthook(etype, value, tb): monkeypatched_or_warned = True else: - def trio_show_traceback(self, etype, value, tb, tb_offset=None): + def trio_show_traceback( + self: object, + etype: Type[BaseException], + value: BaseException, + tb: TracebackType, + tb_offset: Optional[int] = None, + ) -> None: # XX it would be better to integrate with IPython's fancy # exception formatting stuff (and not ignore tb_offset) trio_excepthook(etype, value, tb) @@ -493,7 +547,7 @@ class TrioFakeSysModuleForApport: fake_sys = TrioFakeSysModuleForApport() fake_sys.__dict__.update(sys.__dict__) - fake_sys.__excepthook__ = trio_excepthook # type: ignore + fake_sys.__excepthook__ = trio_excepthook # type: ignore[attr-defined] 
apport_python_hook.sys = fake_sys monkeypatched_or_warned = True diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index 8b114b5230..84582f75db 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -73,7 +73,9 @@ from itertools import count import attr from collections import OrderedDict +import typing +from ._run import Task from .. import _core from .._util import Final @@ -82,7 +84,7 @@ @attr.s(frozen=True) class _ParkingLotStatistics: - tasks_waiting = attr.ib() + tasks_waiting: int = attr.ib() @attr.s(eq=False, hash=False) @@ -101,7 +103,7 @@ class ParkingLot(metaclass=Final): # {task: None}, we just want a deque where we can quickly delete random # items - _parked = attr.ib(factory=OrderedDict, init=False) + _parked: "OrderedDict[Task, None]" = attr.ib(factory=OrderedDict, init=False) def __len__(self): """Returns the number of parked tasks.""" @@ -116,7 +118,7 @@ def __bool__(self): # line (for false wakeups), then we could have it return a ticket that # abstracts the "place in line" concept. @_core.enable_ki_protection - async def park(self): + async def park(self) -> None: """Park the current task until woken by a call to :meth:`unpark` or :meth:`unpark_all`. @@ -137,7 +139,7 @@ def _pop_several(self, count): yield task @_core.enable_ki_protection - def unpark(self, *, count=1): + def unpark(self, *, count: int = 1) -> typing.List[Task]: """Unpark one or more tasks. This wakes up ``count`` tasks that are blocked in :meth:`park`. If @@ -158,7 +160,7 @@ def unpark_all(self): return self.unpark(count=len(self)) @_core.enable_ki_protection - def repark(self, new_lot, *, count=1): + def repark(self, new_lot: "ParkingLot", *, count: int = 1) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. 
This dequeues ``count`` tasks from one lot, and requeues them on diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 807e330c1d..3e2ff79c01 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1,5 +1,6 @@ # coding: utf-8 +from contextvars import Context import functools import itertools import logging @@ -18,7 +19,30 @@ from contextvars import copy_context from math import inf from time import perf_counter -from typing import Callable, TYPE_CHECKING +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Deque, + Dict, + FrozenSet, + Generator, + Iterator, + List, + Optional, + overload, + Sequence, + Set, + Tuple, + Type, + TypeVar, + TYPE_CHECKING, + Union, +) + +from typing_extensions import Protocol from sniffio import current_async_library_cvar @@ -46,7 +70,9 @@ from ._asyncgens import AsyncGenerators from ._thread_cache import start_thread_soon from ._instrumentation import Instruments +from ._local import RunVar from .. import _core +from ..abc import Clock from .._deprecate import warn_deprecated from .._util import Final, NoPublicConstructor, coroutine_or_error @@ -54,10 +80,12 @@ _NO_SEND = object() +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn): +def _public(fn: _Fn) -> _Fn: return fn @@ -78,7 +106,7 @@ def _public(fn): # # This can all be removed once we drop support for 3.6. 
def _count_context_run_tb_frames(): - def function_with_unique_name_xyzzy(): + def function_with_unique_name_xyzzy() -> None: 1 / 0 ctx = copy_context() @@ -105,7 +133,7 @@ class SystemClock: # between different runs, then they'll notice the bug quickly: offset = attr.ib(factory=lambda: _r.uniform(10000, 200000)) - def start_clock(self): + def start_clock(self) -> None: pass # In cPython 3, on every platform except Windows, perf_counter is @@ -159,7 +187,7 @@ def next_deadline(self): heappop(self._heap) return inf - def _prune(self): + def _prune(self) -> None: # In principle, it's possible for a cancel scope to toggle back and # forth repeatedly between the same two deadlines, and end up with # lots of stale entries that *look* like they're still active, because @@ -197,6 +225,15 @@ def expire(self, now): return did_something +class Scope(Protocol): + deadline: float + shield: bool + + @property + def cancel_called(self) -> bool: + ... + + @attr.s(eq=False, slots=True) class CancelStatus: """Tracks the cancellation status for a contiguous extent @@ -230,7 +267,7 @@ class CancelStatus: # Our associated cancel scope. Can be any object with attributes # `deadline`, `shield`, and `cancel_called`, but in current usage # is always a CancelScope object. Must not be None. - _scope = attr.ib() + _scope: Scope = attr.ib() # True iff the tasks in self._tasks should receive cancellations # when they checkpoint. Always True when scope.cancel_called is True; @@ -240,31 +277,31 @@ class CancelStatus: # effectively cancelled due to the cancel scope two levels out # becoming cancelled, but then the cancel scope one level out # becomes shielded so we're not effectively cancelled anymore. - effectively_cancelled = attr.ib(default=False) + effectively_cancelled: bool = attr.ib(default=False) # The CancelStatus whose cancellations can propagate to us; we # become effectively cancelled when they do, unless scope.shield # is True. 
May be None (for the outermost CancelStatus in a call # to trio.run(), briefly during TaskStatus.started(), or during # recovery from mis-nesting of cancel scopes). - _parent = attr.ib(default=None, repr=False) + _parent: Optional["CancelStatus"] = attr.ib(default=None, repr=False) # All of the CancelStatuses that have this CancelStatus as their parent. - _children = attr.ib(factory=set, init=False, repr=False) + _children: Set["CancelStatus"] = attr.ib(factory=set, init=False, repr=False) # Tasks whose cancellation state is currently tied directly to # the cancellation state of this CancelStatus object. Don't modify # this directly; instead, use Task._activate_cancel_status(). # Invariant: all(task._cancel_status is self for task in self._tasks) - _tasks = attr.ib(factory=set, init=False, repr=False) + _tasks: Set["Task"] = attr.ib(factory=set, init=False, repr=False) # Set to True on still-active cancel statuses that are children # of a cancel status that's been closed. This is used to permit # recovery from mis-nested cancel scopes (well, at least enough # recovery to show a useful traceback). 
- abandoned_by_misnesting = attr.ib(default=False, init=False, repr=False) + abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self._parent is not None: self._parent._children.add(self) self.recalculate() @@ -272,11 +309,11 @@ def __attrs_post_init__(self): # parent/children/tasks accessors are used by TaskStatus.started() @property - def parent(self): + def parent(self) -> Optional["CancelStatus"]: return self._parent @parent.setter - def parent(self, parent): + def parent(self, parent: Optional["CancelStatus"]) -> None: if self._parent is not None: self._parent._children.remove(self) self._parent = parent @@ -285,11 +322,11 @@ def parent(self, parent): self.recalculate() @property - def children(self): + def children(self) -> FrozenSet["CancelStatus"]: return frozenset(self._children) @property - def tasks(self): + def tasks(self) -> FrozenSet["Task"]: return frozenset(self._tasks) def encloses(self, other): @@ -302,7 +339,7 @@ def encloses(self, other): other = other.parent return False - def close(self): + def close(self) -> None: self.parent = None # now we're not a child of self.parent anymore if self._tasks or self._children: # Cancel scopes weren't exited in opposite order of being @@ -331,14 +368,14 @@ def close(self): child.recalculate() @property - def parent_cancellation_is_visible_to_us(self): + def parent_cancellation_is_visible_to_us(self) -> bool: return ( self._parent is not None and not self._scope.shield and self._parent.effectively_cancelled ) - def recalculate(self): + def recalculate(self) -> None: # This does a depth-first traversal over this and descendent cancel # statuses, to ensure their state is up-to-date. 
It's basically a # recursive algorithm, but we use an explicit stack to avoid any @@ -357,7 +394,7 @@ def recalculate(self): task._attempt_delivery_of_any_pending_cancel() todo.extend(current._children) - def _mark_abandoned(self): + def _mark_abandoned(self) -> None: self.abandoned_by_misnesting = True for child in self._children: child._mark_abandoned() @@ -445,7 +482,7 @@ class CancelScope(metaclass=Final): _shield = attr.ib(default=False, kw_only=True) @enable_ki_protection - def __enter__(self): + def __enter__(self) -> "CancelScope": task = _core.current_task() if self._has_been_entered: raise RuntimeError( @@ -529,7 +566,12 @@ def _close(self, exc): return exc @enable_ki_protection - def __exit__(self, etype, exc, tb): + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: # NB: NurseryManager calls _close() directly rather than __exit__(), # so __exit__() must be just _close() plus this logic for adapting # the exception-filtering result to the context manager API. @@ -552,7 +594,7 @@ def __exit__(self, etype, exc, tb): assert value is remaining_error_after_cancel_scope value.__context__ = old_context - def __repr__(self): + def __repr__(self) -> str: if self._cancel_status is not None: binding = "active" elif self._has_been_entered: @@ -579,7 +621,7 @@ def __repr__(self): @contextmanager @enable_ki_protection - def _might_change_registered_deadline(self): + def _might_change_registered_deadline(self) -> Iterator[None]: try: yield finally: @@ -603,7 +645,7 @@ def _might_change_registered_deadline(self): runner.force_guest_tick_asap() @property - def deadline(self): + def deadline(self) -> float: """Read-write, :class:`float`. An absolute time on the current run's clock at which this scope will automatically become cancelled. 
You can adjust the deadline by modifying this @@ -629,12 +671,12 @@ def deadline(self): return self._deadline @deadline.setter - def deadline(self, new_deadline): + def deadline(self, new_deadline: float) -> None: with self._might_change_registered_deadline(): self._deadline = float(new_deadline) @property - def shield(self): + def shield(self) -> bool: """Read-write, :class:`bool`, default :data:`False`. So long as this is set to :data:`True`, then the code inside this scope will not receive :exc:`~trio.Cancelled` exceptions from scopes @@ -657,9 +699,10 @@ def shield(self): """ return self._shield - @shield.setter # type: ignore # "decorated property not supported" + # ignore for "decorated property not supported" + @shield.setter # type: ignore[misc] @enable_ki_protection - def shield(self, new_value): + def shield(self, new_value: bool) -> None: if not isinstance(new_value, bool): raise TypeError("shield must be a bool") self._shield = new_value @@ -667,7 +710,7 @@ def shield(self, new_value): self._cancel_status.recalculate() @enable_ki_protection - def cancel(self): + def cancel(self) -> None: """Cancels this scope immediately. This method is idempotent, i.e., if the scope was already @@ -681,7 +724,7 @@ def cancel(self): self._cancel_status.recalculate() @property - def cancel_called(self): + def cancel_called(self) -> bool: """Readonly :class:`bool`. Records whether cancellation has been requested for this scope, either by an explicit call to :meth:`cancel` or by the deadline expiring. 
@@ -724,7 +767,7 @@ class _TaskStatus: _called_started = attr.ib(default=False) _value = attr.ib(default=None) - def __repr__(self): + def __repr__(self) -> str: return "".format(id(self)) def started(self, value=None): @@ -791,14 +834,19 @@ class NurseryManager: """ @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self) -> "Nursery": self._scope = CancelScope() self._scope.__enter__() self._nursery = Nursery._create(current_task(), self._scope) return self._nursery @enable_ki_protection - async def __aexit__(self, etype, exc, tb): + async def __aexit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: new_exc = await self._nursery._nested_child_finished(exc) # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context. @@ -818,12 +866,12 @@ async def __aexit__(self, etype, exc, tb): assert value is combined_error_from_nursery value.__context__ = old_context - def __enter__(self): + def __enter__(self) -> None: raise RuntimeError( "use 'async with open_nursery(...)', not 'with open_nursery(...)'" ) - def __exit__(self): # pragma: no cover + def __exit__(self) -> None: # pragma: no cover assert False, """Never called, but should be defined""" @@ -860,7 +908,7 @@ class Nursery(metaclass=NoPublicConstructor): in response to some external event. """ - def __init__(self, parent_task, cancel_scope): + def __init__(self, parent_task: "Task", cancel_scope: CancelScope) -> None: self._parent_task = parent_task parent_task._child_nurseries.append(self) # the cancel status that children inherit - we take a snapshot, so it @@ -870,8 +918,8 @@ def __init__(self, parent_task, cancel_scope): # children. 
self.cancel_scope = cancel_scope assert self.cancel_scope._cancel_status is self._cancel_status - self._children = set() - self._pending_excs = [] + self._children: Set["Task"] = set() + self._pending_excs: List[Exception] = [] # The "nested child" is how this code refers to the contents of the # nursery's 'async with' block, which acts like a child Task in all # the ways we can make it. @@ -881,13 +929,13 @@ def __init__(self, parent_task, cancel_scope): self._closed = False @property - def child_tasks(self): + def child_tasks(self) -> FrozenSet["Task"]: """(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task` objects which are still running.""" return frozenset(self._children) @property - def parent_task(self): + def parent_task(self) -> "Task": "(`~trio.lowlevel.Task`): The Task that opened this nursery." return self._parent_task @@ -895,7 +943,7 @@ def _add_exc(self, exc): self._pending_excs.append(exc) self.cancel_scope.cancel() - def _check_nursery_closed(self): + def _check_nursery_closed(self) -> None: if not any([self._nested_child_running, self._children, self._pending_starts]): self._closed = True if self._parent_waiting_in_aexit: @@ -1051,7 +1099,7 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): self._pending_starts -= 1 self._check_nursery_closed() - def __del__(self): + def __del__(self) -> None: assert not self._children @@ -1062,13 +1110,13 @@ def __del__(self): @attr.s(eq=False, hash=False, repr=False) class Task(metaclass=NoPublicConstructor): - _parent_nursery = attr.ib() - coro = attr.ib() - _runner = attr.ib() - name = attr.ib() + _parent_nursery: Nursery = attr.ib() + coro: Coroutine = attr.ib() + _runner: "Runner" = attr.ib() + name: str = attr.ib() # PEP 567 contextvars context - context = attr.ib() - _counter = attr.ib(init=False, factory=itertools.count().__next__) + context: Context = attr.ib() + _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: # - for 
unscheduled tasks, _next_send_fn and _next_send are both None @@ -1081,26 +1129,26 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn = attr.ib(default=None) - _next_send = attr.ib(default=None) - _abort_func = attr.ib(default=None) - custom_sleep_data = attr.ib(default=None) + _next_send_fn: Callable = attr.ib(default=None) + _next_send: Optional[Union[Outcome, Exception, MultiError]] = attr.ib(default=None) + _abort_func: Callable = attr.ib(default=None) + custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() - _child_nurseries = attr.ib(factory=list) - _eventual_parent_nursery = attr.ib(default=None) + _child_nurseries: List[Nursery] = attr.ib(factory=list) + _eventual_parent_nursery: Optional[Nursery] = attr.ib(default=None) # these are counts of how many cancel/schedule points this task has # executed, for assert{_no,}_checkpoints # XX maybe these should be exposed as part of a statistics() method? - _cancel_points = attr.ib(default=0) - _schedule_points = attr.ib(default=0) + _cancel_points: int = attr.ib(default=0) + _schedule_points: int = attr.ib(default=0) - def __repr__(self): + def __repr__(self) -> str: return "".format(self.name, id(self)) @property - def parent_nursery(self): + def parent_nursery(self) -> Nursery: """The nursery this task is inside (or None if this is the "init" task). @@ -1111,7 +1159,7 @@ def parent_nursery(self): return self._parent_nursery @property - def eventual_parent_nursery(self): + def eventual_parent_nursery(self) -> Optional[Nursery]: """The nursery this task will be inside after it calls ``task_status.started()``. @@ -1123,7 +1171,7 @@ def eventual_parent_nursery(self): return self._eventual_parent_nursery @property - def child_nurseries(self): + def child_nurseries(self) -> List[Nursery]: """The nurseries this task contains. 
This is a list, with outer nurseries before inner nurseries. @@ -1137,7 +1185,7 @@ def child_nurseries(self): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). - _cancel_status = attr.ib(default=None, repr=False) + _cancel_status: Optional[CancelStatus] = attr.ib(default=None, repr=False) def _activate_cancel_status(self, cancel_status): if self._cancel_status is not None: @@ -1162,23 +1210,23 @@ def _attempt_abort(self, raise_cancel): if success is Abort.SUCCEEDED: self._runner.reschedule(self, capture(raise_cancel)) - def _attempt_delivery_of_any_pending_cancel(self): + def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return if not self._cancel_status.effectively_cancelled: return - def raise_cancel(): + def raise_cancel() -> None: raise Cancelled._create() self._attempt_abort(raise_cancel) - def _attempt_delivery_of_pending_ki(self): + def _attempt_delivery_of_pending_ki(self) -> None: assert self._runner.ki_pending if self._abort_func is None: return - def raise_cancel(): + def raise_cancel() -> None: self._runner.ki_pending = False raise KeyboardInterrupt @@ -1228,13 +1276,15 @@ class _RunStatistics: # worker thread. 
@attr.s(eq=False, hash=False, slots=True) class GuestState: - runner = attr.ib() - run_sync_soon_threadsafe = attr.ib() - run_sync_soon_not_threadsafe = attr.ib() - done_callback = attr.ib() - unrolled_run_gen = attr.ib() + runner: "Runner" = attr.ib() + run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] = attr.ib() + run_sync_soon_not_threadsafe: Optional[ + Callable[[Callable[[], object]], object] + ] = attr.ib() + done_callback: Callable[[Outcome], object] = attr.ib() + unrolled_run_gen: Generator[int, List[Tuple[int, int]], None] = attr.ib() _value_factory: Callable[[], Value] = lambda: Value(None) - unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) + unrolled_run_next_send: Outcome = attr.ib(factory=_value_factory) def guest_tick(self): try: @@ -1261,7 +1311,7 @@ def get_events(): return self.runner.io_manager.get_events(timeout) def deliver(events_outcome): - def in_main_thread(): + def in_main_thread() -> None: self.unrolled_run_next_send = events_outcome self.runner.guest_tick_scheduled = True self.guest_tick() @@ -1273,43 +1323,46 @@ def in_main_thread(): @attr.s(eq=False, hash=False, slots=True) class Runner: - clock = attr.ib() + clock: Clock = attr.ib() instruments: Instruments = attr.ib() - io_manager = attr.ib() - ki_manager = attr.ib() + # TODO: It seems that down at the bottom kqueue is the IO manager chosen for + # type checking. Seems like there ought to be a protocol or union here. 
+ # io_manager: Union["KqueueIOManager", "EpollIOManager", "WindowsIOManager"] = attr.ib() + io_manager: "TheIOManager" = attr.ib() + ki_manager: KIManager = attr.ib() # Run-local values, see _local.py - _locals = attr.ib(factory=dict) + _locals: Dict[RunVar, object] = attr.ib(factory=dict) - runq = attr.ib(factory=deque) - tasks = attr.ib(factory=set) + runq: Deque[Task] = attr.ib(factory=deque) + tasks: Set[Task] = attr.ib(factory=set) - deadlines = attr.ib(factory=Deadlines) + deadlines: Deadlines = attr.ib(factory=Deadlines) - init_task = attr.ib(default=None) - system_nursery = attr.ib(default=None) - system_context = attr.ib(default=None) - main_task = attr.ib(default=None) - main_task_outcome = attr.ib(default=None) + init_task: Optional[Task] = attr.ib(default=None) + system_nursery: Optional[Nursery] = attr.ib(default=None) + system_context: Optional[Context] = attr.ib(default=None) + main_task: Optional[Task] = attr.ib(default=None) + main_task_outcome: Optional[Outcome] = attr.ib(default=None) - entry_queue = attr.ib(factory=EntryQueue) - trio_token = attr.ib(default=None) - asyncgens = attr.ib(factory=AsyncGenerators) + entry_queue: EntryQueue = attr.ib(factory=EntryQueue) + trio_token: TrioToken = attr.ib(default=None) + asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators) # If everything goes idle for this long, we call clock._autojump() - clock_autojump_threshold = attr.ib(default=inf) + clock_autojump_threshold: float = attr.ib(default=inf) # Guest mode stuff - is_guest = attr.ib(default=False) - guest_tick_scheduled = attr.ib(default=False) + is_guest: bool = attr.ib(default=False) + guest_tick_scheduled: bool = attr.ib(default=False) - def force_guest_tick_asap(self): + def force_guest_tick_asap(self) -> None: if self.guest_tick_scheduled: return self.guest_tick_scheduled = True self.io_manager.force_wakeup() - def close(self): + def close(self) -> None: self.io_manager.close() self.entry_queue.close() self.asyncgens.close() @@ -1319,7 
+1372,7 @@ def close(self): self.ki_manager.close() @_public - def current_statistics(self): + def current_statistics(self) -> _RunStatistics: """Returns an object containing run-loop-level debugging information. Currently the following fields are defined: @@ -1352,7 +1405,7 @@ def current_statistics(self): ) @_public - def current_time(self): + def current_time(self) -> float: """Returns the current time according to Trio's internal clock. Returns: @@ -1365,12 +1418,12 @@ def current_time(self): return self.clock.current_time() @_public - def current_clock(self): + def current_clock(self) -> Clock: """Returns the current :class:`~trio.abc.Clock`.""" return self.clock @_public - def current_root_task(self): + def current_root_task(self) -> Optional[Task]: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. @@ -1382,8 +1435,16 @@ def current_root_task(self): # Core task handling primitives ################ + @overload + def reschedule(self, task: Task) -> None: + ... + + @overload + def reschedule(self, task: Task, next_send: Outcome) -> None: + ... + @_public - def reschedule(self, task, next_send=_NO_SEND): + def reschedule(self, task: Task, next_send: object = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. 
@@ -1416,7 +1477,15 @@ def reschedule(self, task, next_send=_NO_SEND): if "task_scheduled" in self.instruments: self.instruments.call("task_scheduled", task) - def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): + def spawn_impl( + self, + async_fn: Callable[..., Awaitable[object]], + args: Sequence[object], + nursery: Optional[Nursery], + name: Optional[Union[str, Callable]], + *, + system_task: bool = False, + ) -> Task: ###### # Make sure the nursery is in working order @@ -1447,7 +1516,7 @@ def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): name = repr(name) if system_task: - context = self.system_context.copy() + context = self.system_context.copy() # type: ignore[union-attr] else: context = copy_context() @@ -1527,7 +1596,12 @@ def task_exited(self, task, outcome): ################ @_public - def spawn_system_task(self, async_fn, *args, name=None): + def spawn_system_task( # type: ignore[misc] + self, + async_fn: Callable[..., Awaitable[object]], + *args: object, + name: Optional[Union[str, Callable]] = None, + ) -> Task: """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -1617,7 +1691,7 @@ async def init(self, async_fn, args): ################ @_public - def current_trio_token(self): + def current_trio_token(self) -> TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. @@ -1630,7 +1704,7 @@ def current_trio_token(self): # KI handling ################ - ki_pending = attr.ib(default=False) + ki_pending: bool = attr.ib(default=False) # deliver_ki is broke. Maybe move all the actual logic and state into # RunToken, and we'll only have one instance per runner? But then we can't @@ -1639,14 +1713,14 @@ def current_trio_token(self): # keep the class public so people can isinstance() it if they want. 
# This gets called from signal context - def deliver_ki(self): + def deliver_ki(self) -> None: self.ki_pending = True try: self.entry_queue.run_sync_soon(self._deliver_ki_cb) except RunFinishedError: pass - def _deliver_ki_cb(self): + def _deliver_ki_cb(self) -> None: if not self.ki_pending: return # Can't happen because main_task and run_sync_soon_task are created at @@ -1663,10 +1737,11 @@ def _deliver_ki_cb(self): # Quiescing ################ - waiting_for_idle = attr.ib(factory=SortedDict) + # TODO: how to hint a SortedDict with its content type as well? + waiting_for_idle: SortedDict = attr.ib(factory=SortedDict) @_public - async def wait_all_tasks_blocked(self, cushion=0.0): + async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -1700,7 +1775,7 @@ async def lock_taker(lock): await lock.acquire() lock.release() - async def test_lock_fairness(): + async def test_lock_fairness() -> None: lock = trio.Lock() await lock.acquire() async with trio.open_nursery() as nursery: @@ -2261,7 +2336,7 @@ def unrolled_run(runner, async_fn, args, host_uses_signal_set_wakeup_fd=False): class _TaskStatusIgnored: - def __repr__(self): + def __repr__(self) -> str: return "TASK_STATUS_IGNORED" def started(self, value=None): @@ -2339,7 +2414,7 @@ async def checkpoint(): await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """Issue a :ref:`checkpoint ` if the calling context has been cancelled. 
diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py index ae5e8450b9..70768878df 100644 --- a/trio/_core/_thread_cache.py +++ b/trio/_core/_thread_cache.py @@ -1,6 +1,7 @@ from threading import Thread, Lock import outcome from itertools import count +from typing import Callable, Dict, Optional, Tuple # The "thread cache" is a simple unbounded thread pool, i.e., it automatically # spawns as many threads as needed to handle all the requests its given. Its @@ -39,10 +40,14 @@ name_counter = count() +_Fn = Callable[..., object] +_Deliver = Callable[[outcome.Outcome], object] +_Job = Tuple[_Fn, _Deliver] + class WorkerThread: - def __init__(self, thread_cache): - self._job = None + def __init__(self, thread_cache: "ThreadCache") -> None: + self._job: Optional[_Job] = None self._thread_cache = thread_cache # This Lock is used in an unconventional way. # @@ -56,11 +61,14 @@ def __init__(self, thread_cache): thread.name = f"Trio worker thread {next(name_counter)}" thread.start() - def _work(self): + def _work(self) -> None: while True: if self._worker_lock.acquire(timeout=IDLE_TIMEOUT): # We got a job - fn, deliver = self._job + fn: _Fn + deliver: _Deliver + # type ignoring to avoid any runtime cost of casting etc + fn, deliver = self._job # type: ignore[misc] self._job = None result = outcome.capture(fn) # Tell the cache that we're available to be assigned a new @@ -90,10 +98,10 @@ def _work(self): class ThreadCache: - def __init__(self): - self._idle_workers = {} + def __init__(self) -> None: + self._idle_workers: Dict[WorkerThread, None] = {} - def start_thread_soon(self, fn, deliver): + def start_thread_soon(self, fn: _Fn, deliver: _Deliver) -> None: try: worker, _ = self._idle_workers.popitem() except KeyError: @@ -105,7 +113,7 @@ def start_thread_soon(self, fn, deliver): THREAD_CACHE = ThreadCache() -def start_thread_soon(fn, deliver): +def start_thread_soon(fn: _Fn, deliver: _Deliver) -> None: """Runs ``deliver(outcome.capture(fn))`` in a worker 
thread. Generally ``fn`` does some blocking work, and ``deliver`` delivers the diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index 95cf46de9b..341490a201 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -2,12 +2,15 @@ import types import enum +from typing import Callable import attr import outcome from . import _run +AbortFunc = Callable[[Callable[[], None]], "Abort"] + # Helper for the bottommost 'yield'. You can't use 'yield' inside an async # function, but you can inside a generator, and if you decorate your generator @@ -17,7 +20,7 @@ # tracking machinery. Since our traps are public APIs, we make them real async # functions, and then this helper takes care of the actual yield: @types.coroutine -def _async_yield(obj): +def _async_yield(obj: object) -> object: return (yield obj) @@ -27,7 +30,7 @@ class CancelShieldedCheckpoint: pass -async def cancel_shielded_checkpoint(): +async def cancel_shielded_checkpoint() -> None: """Introduce a schedule point, but not a cancel point. This is *not* a :ref:`checkpoint `, but it is half of a @@ -40,7 +43,7 @@ async def cancel_shielded_checkpoint(): await trio.lowlevel.checkpoint() """ - return (await _async_yield(CancelShieldedCheckpoint)).unwrap() + return (await _async_yield(CancelShieldedCheckpoint)).unwrap() # type: ignore[no-any-return] # Return values for abort functions @@ -61,10 +64,10 @@ class Abort(enum.Enum): # Not exported in the trio._core namespace, but imported directly by _run. @attr.s(frozen=True) class WaitTaskRescheduled: - abort_func = attr.ib() + abort_func: AbortFunc = attr.ib() -async def wait_task_rescheduled(abort_func): +async def wait_task_rescheduled(abort_func: AbortFunc) -> object: """Put the current task to sleep, with cancellation support. This is the lowest-level API for blocking in Trio. Every time a @@ -169,10 +172,10 @@ def abort(inner_raise_cancel): # Not exported in the trio._core namespace, but imported directly by _run. 
@attr.s(frozen=True) class PermanentlyDetachCoroutineObject: - final_outcome = attr.ib() + final_outcome: outcome.Outcome = attr.ib() -async def permanently_detach_coroutine_object(final_outcome): +async def permanently_detach_coroutine_object(final_outcome: outcome.Outcome) -> None: """Permanently detach the current task from the Trio scheduler. Normally, a Trio task doesn't exit until its coroutine object exits. When @@ -200,10 +203,10 @@ async def permanently_detach_coroutine_object(final_outcome): raise RuntimeError( "can't permanently detach a coroutine object with open nurseries" ) - return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) + return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome)) # type: ignore[no-any-return] -async def temporarily_detach_coroutine_object(abort_func): +async def temporarily_detach_coroutine_object(abort_func: AbortFunc) -> None: """Temporarily detach the current coroutine object from the Trio scheduler. @@ -236,10 +239,12 @@ async def temporarily_detach_coroutine_object(abort_func): uses to resume the coroutine. """ - return await _async_yield(WaitTaskRescheduled(abort_func)) + return await _async_yield(WaitTaskRescheduled(abort_func)) # type: ignore[no-any-return] -async def reattach_detached_coroutine_object(task, yield_value): +async def reattach_detached_coroutine_object( + task: "_run.Task", yield_value: object +) -> None: """Reattach a coroutine object that was detached using :func:`temporarily_detach_coroutine_object`. diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index f877e42a0c..8880cbd986 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -1,3 +1,5 @@ +from typing import Generic, List, TypeVar + import attr from .. 
import _core @@ -5,13 +7,17 @@ from .._util import Final +_T = TypeVar("_T") +_TSelf = TypeVar("_TSelf") + + @attr.s(frozen=True) class _UnboundedQueueStats: - qsize = attr.ib() - tasks_waiting = attr.ib() + qsize: int = attr.ib() + tasks_waiting: int = attr.ib() -class UnboundedQueue(metaclass=Final): +class UnboundedQueue(Generic[_T], metaclass=Final): """An unbounded queue suitable for certain unusual forms of inter-task communication. @@ -47,20 +53,20 @@ class UnboundedQueue(metaclass=Final): thing="trio.lowlevel.UnboundedQueue", instead="trio.open_memory_channel(math.inf)", ) - def __init__(self): + def __init__(self) -> None: self._lot = _core.ParkingLot() - self._data = [] + self._data: List[_T] = [] # used to allow handoff from put to the first task in the lot self._can_get = False - def __repr__(self): + def __repr__(self) -> str: return "".format(len(self._data)) - def qsize(self): + def qsize(self) -> int: """Returns the number of items currently in the queue.""" return len(self._data) - def empty(self): + def empty(self) -> bool: """Returns True if the queue is empty, False otherwise. There is some subtlety to interpreting this method's return value: see @@ -70,7 +76,7 @@ def empty(self): return not self._data @_core.enable_ki_protection - def put_nowait(self, obj): + def put_nowait(self, obj: _T) -> None: """Put an object into the queue, without blocking. This always succeeds, because the queue is unbounded. We don't provide @@ -88,13 +94,13 @@ def put_nowait(self, obj): self._can_get = True self._data.append(obj) - def _get_batch_protected(self): + def _get_batch_protected(self) -> List[_T]: data = self._data.copy() self._data.clear() self._can_get = False return data - def get_batch_nowait(self): + def get_batch_nowait(self) -> List[_T]: """Attempt to get the next batch from the queue, without blocking. 
Returns: @@ -110,7 +116,7 @@ def get_batch_nowait(self): raise _core.WouldBlock return self._get_batch_protected() - async def get_batch(self): + async def get_batch(self) -> List[_T]: """Get the next batch from the queue, blocking as necessary. Returns: @@ -128,7 +134,7 @@ async def get_batch(self): finally: await _core.cancel_shielded_checkpoint() - def statistics(self): + def statistics(self) -> _UnboundedQueueStats: """Return an object containing debugging information. Currently the following fields are defined: @@ -142,8 +148,8 @@ def statistics(self): qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting ) - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): + async def __anext__(self) -> List[_T]: return await self.get_batch() diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index 121cec584e..77f2e6b8b9 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -1,13 +1,14 @@ import socket import sys import signal +from typing import Optional import warnings from .. 
import _core from .._util import is_main_thread -def _has_warn_on_full_buffer(): +def _has_warn_on_full_buffer() -> bool: if sys.version_info < (3, 7): return False @@ -26,7 +27,7 @@ def _has_warn_on_full_buffer(): class WakeupSocketpair: - def __init__(self): + def __init__(self) -> None: self.wakeup_sock, self.write_sock = socket.socketpair() self.wakeup_sock.setblocking(False) self.write_sock.setblocking(False) @@ -51,26 +52,26 @@ def __init__(self): self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) except OSError: pass - self.old_wakeup_fd = None + self.old_wakeup_fd: Optional[int] = None - def wakeup_thread_and_signal_safe(self): + def wakeup_thread_and_signal_safe(self) -> None: try: self.write_sock.send(b"\x00") except BlockingIOError: pass - async def wait_woken(self): + async def wait_woken(self) -> None: await _core.wait_readable(self.wakeup_sock) self.drain() - def drain(self): + def drain(self) -> None: try: while True: self.wakeup_sock.recv(2 ** 16) except BlockingIOError: pass - def wakeup_on_signals(self): + def wakeup_on_signals(self) -> None: assert self.old_wakeup_fd is None if not is_main_thread(): return @@ -91,7 +92,7 @@ def wakeup_on_signals(self): ) ) - def close(self): + def close(self) -> None: self.wakeup_sock.close() self.write_sock.close() if self.old_wakeup_fd is not None: diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index a1071519e9..7949932dce 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -1,6 +1,7 @@ import cffi import re import enum +from typing import Optional, Union ################################################################ # Functions and types @@ -301,7 +302,7 @@ class IoControlCodes(enum.IntEnum): ################################################################ -def _handle(obj): +def _handle(obj: Union[int, object]) -> object: # For now, represent handles as either cffi HANDLEs or as ints. 
If you # try to pass in a file descriptor instead, it's not going to work # out. (For that msvcrt.get_osfhandle does the trick, but I don't know if @@ -314,7 +315,12 @@ def _handle(obj): return obj -def raise_winerror(winerror=None, *, filename=None, filename2=None): +def raise_winerror( + winerror: Optional[object] = None, + *, + filename: Optional[str] = None, + filename2: Optional[str] = None, +) -> None: if winerror is None: winerror, msg = ffi.getwinerror() else: diff --git a/trio/_core/tests/conftest.py b/trio/_core/tests/conftest.py index aca1f98a65..bdb9a9a04f 100644 --- a/trio/_core/tests/conftest.py +++ b/trio/_core/tests/conftest.py @@ -1,4 +1,6 @@ import pytest +import _pytest.python + import inspect # XX this should move into a global something @@ -6,12 +8,12 @@ @pytest.fixture -def mock_clock(): +def mock_clock() -> MockClock: return MockClock() @pytest.fixture -def autojump_clock(): +def autojump_clock() -> MockClock: return MockClock(autojump_threshold=0) @@ -20,6 +22,6 @@ def autojump_clock(): # guess it's useful with the class- and file-level marking machinery (where # the raw @trio_test decorator isn't enough). @pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): +def pytest_pyfunc_call(pyfuncitem: _pytest.python.Function) -> None: # type: ignore[misc] if inspect.iscoroutinefunction(pyfuncitem.obj): pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/trio/_core/tests/test_asyncgen.py b/trio/_core/tests/test_asyncgen.py index 1f886e11ab..41f4dd172f 100644 --- a/trio/_core/tests/test_asyncgen.py +++ b/trio/_core/tests/test_asyncgen.py @@ -7,8 +7,10 @@ from ... 
import _core from .tutil import gc_collect_harder, buggy_pypy_asyncgens +import _pytest.capture -def test_asyncgen_basics(): + +def test_asyncgen_basics() -> None: collected = [] async def example(cause): @@ -35,7 +37,7 @@ async def example(cause): saved = [] - async def async_main(): + async def async_main() -> None: # GC'ed before exhausted with pytest.warns( ResourceWarning, match="Async generator.*collected before.*exhausted" @@ -83,7 +85,7 @@ async def async_main(): assert agen.ag_frame is None # all should now be exhausted -async def test_asyncgen_throws_during_finalization(caplog): +async def test_asyncgen_throws_during_finalization(caplog) -> None: record = [] async def agen(): @@ -105,7 +107,7 @@ async def agen(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_firstiter_after_closing(): +def test_firstiter_after_closing() -> None: saved = [] record = [] @@ -121,7 +123,7 @@ async def funky_agen(): record.append("cleanup 2") await funky_agen().asend(None) - async def async_main(): + async def async_main() -> None: aiter = funky_agen() saved.append(aiter) assert 1 == await aiter.asend(None) @@ -132,7 +134,7 @@ async def async_main(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_interdependent_asyncgen_cleanup_order(): +def test_interdependent_asyncgen_cleanup_order() -> None: saved = [] record = [] @@ -155,7 +157,7 @@ async def agen(label, inner): await inner.asend(None) record.append(label) - async def async_main(): + async def async_main() -> None: # This makes a chain of 101 interdependent asyncgens: # agen(99)'s cleanup will iterate agen(98)'s will iterate # ... 
agen(0)'s will iterate innermost()'s @@ -167,10 +169,10 @@ async def async_main(): assert record == [] _core.run(async_main) - assert record == ["innermost"] + list(range(100)) + assert record == ["innermost", *range(100)] -def test_last_minute_gc_edge_case(): +def test_last_minute_gc_edge_case() -> None: saved = [] record = [] needs_retry = True @@ -197,7 +199,7 @@ def collect_at_opportune_moment(token): nonlocal needs_retry needs_retry = True - async def async_main(): + async def async_main() -> None: token = _core.current_trio_token() token.run_sync_soon(collect_at_opportune_moment, token) saved.append(agen()) @@ -252,8 +254,11 @@ def abort_fn(_): nursery.cancel_scope.deadline = _core.current_time() +# can switch to annotating from pytest directly as of 6.2.0 @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -async def test_fallback_when_no_hook_claims_it(capsys): +async def test_fallback_when_no_hook_claims_it( + capsys: _pytest.capture.CaptureFixture[str], +) -> None: async def well_behaved(): yield 42 @@ -281,7 +286,7 @@ async def awaits_after_yield(): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_delegation_to_existing_hooks(): +def test_delegation_to_existing_hooks() -> None: record = [] def my_firstiter(agen): @@ -298,7 +303,7 @@ async def example(arg): await _core.checkpoint() record.append("trio collected " + arg) - async def async_main(): + async def async_main() -> None: await step_outside_async_context(example("theirs")) assert 42 == await example("ours").asend(None) gc_collect_harder() diff --git a/trio/_core/tests/test_guest_mode.py b/trio/_core/tests/test_guest_mode.py index c9701e7cdd..920b04797c 100644 --- a/trio/_core/tests/test_guest_mode.py +++ b/trio/_core/tests/test_guest_mode.py @@ -71,7 +71,7 @@ def done_callback(outcome): del todo, run_sync_soon_threadsafe, done_callback -def test_guest_trivial(): +def test_guest_trivial() -> None: async def trio_return(in_host): await 
trio.sleep(0) return "ok" @@ -85,14 +85,14 @@ async def trio_fail(in_host): trivial_guest_run(trio_fail) -def test_guest_can_do_io(): +def test_guest_can_do_io() -> None: async def trio_main(in_host): record = [] a, b = trio.socket.socketpair() with a, b: async with trio.open_nursery() as nursery: - async def do_receive(): + async def do_receive() -> None: record.append(await a.recv(1)) nursery.start_soon(do_receive) @@ -105,7 +105,7 @@ async def do_receive(): trivial_guest_run(trio_main) -def test_host_can_directly_wake_trio_task(): +def test_host_can_directly_wake_trio_task() -> None: async def trio_main(in_host): ev = trio.Event() in_host(ev.set) @@ -115,7 +115,7 @@ async def trio_main(in_host): assert trivial_guest_run(trio_main) == "ok" -def test_host_altering_deadlines_wakes_trio_up(): +def test_host_altering_deadlines_wakes_trio_up() -> None: def set_deadline(cscope, new_deadline): cscope.deadline = new_deadline @@ -138,7 +138,7 @@ async def trio_main(in_host): assert trivial_guest_run(trio_main) == "ok" -def test_warn_set_wakeup_fd_overwrite(): +def test_warn_set_wakeup_fd_overwrite() -> None: assert signal.set_wakeup_fd(-1) == -1 async def trio_main(in_host): @@ -206,7 +206,7 @@ async def trio_check_wakeup_fd_unaltered(in_host): assert signal.set_wakeup_fd(-1) == a.fileno() -def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked(): +def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None: # This is designed to hit the branch in unrolled_run where: # idle_primed=True # runner.runq is empty @@ -245,7 +245,7 @@ async def get_woken_by_host_deadline(watb_cscope): # actually end. So in after_io_wait we schedule a second host # call to tear things down. 
class InstrumentHelper: - def __init__(self): + def __init__(self) -> None: self.primed = False def before_io_wait(self, timeout): @@ -277,7 +277,7 @@ def after_io_wait(self, timeout): assert trivial_guest_run(trio_main) == "ok" -def test_guest_warns_if_abandoned(): +def test_guest_warns_if_abandoned() -> None: # This warning is emitted from the garbage collector. So we have to make # sure that our abandoned run is garbage. The easiest way to do this is to # put it into a function, so that we're sure all the local state, @@ -346,7 +346,7 @@ def trio_done_callback(main_outcome): loop.close() -def test_guest_mode_on_asyncio(): +def test_guest_mode_on_asyncio() -> None: async def trio_main(): print("trio_main!") @@ -405,7 +405,7 @@ async def aio_pingpong(from_trio, to_trio): ) -def test_guest_mode_internal_errors(monkeypatch, recwarn): +def test_guest_mode_internal_errors(monkeypatch, recwarn) -> None: with monkeypatch.context() as m: async def crash_in_run_loop(in_host): @@ -446,7 +446,7 @@ def bad_get_events(*args): gc_collect_harder() -def test_guest_mode_ki(): +def test_guest_mode_ki() -> None: assert signal.getsignal(signal.SIGINT) is signal.default_int_handler # Check SIGINT in Trio func and in host func @@ -478,7 +478,7 @@ async def trio_main_raising(in_host): assert signal.getsignal(signal.SIGINT) is signal.default_int_handler -def test_guest_mode_autojump_clock_threshold_changing(): +def test_guest_mode_autojump_clock_threshold_changing() -> None: # This is super obscure and probably no-one will ever notice, but # technically mutating the MockClock.autojump_threshold from the host # should wake up the guest, so let's test it. 
@@ -506,7 +506,7 @@ async def trio_main(in_host): sys.implementation.name == "pypy" and sys.version_info >= (3, 7), reason="async generator issue under investigation", ) -def test_guest_mode_asyncgens(): +def test_guest_mode_asyncgens() -> None: import sniffio record = set() @@ -523,7 +523,7 @@ async def agen(label): pass record.add((label, library)) - async def iterate_in_aio(): + async def iterate_in_aio() -> None: # "trio" gets inherited from our Trio caller if we don't set this sniffio.current_async_library_cvar.set("asyncio") await agen("asyncio").asend(None) diff --git a/trio/_core/tests/test_instrumentation.py b/trio/_core/tests/test_instrumentation.py index 57d3461d3b..f50405ccf2 100644 --- a/trio/_core/tests/test_instrumentation.py +++ b/trio/_core/tests/test_instrumentation.py @@ -8,7 +8,7 @@ class TaskRecorder: record = attr.ib(factory=list) - def before_run(self): + def before_run(self) -> None: self.record.append(("before_run",)) def task_scheduled(self, task): @@ -22,7 +22,7 @@ def after_task_step(self, task): assert task is _core.current_task() self.record.append(("after", task)) - def after_run(self): + def after_run(self) -> None: self.record.append(("after_run",)) def filter_tasks(self, tasks): @@ -33,7 +33,7 @@ def filter_tasks(self, tasks): yield item -def test_instruments(recwarn): +def test_instruments(recwarn) -> None: r1 = TaskRecorder() r2 = TaskRecorder() r3 = TaskRecorder() @@ -43,7 +43,7 @@ def test_instruments(recwarn): # We use a child task for this, because the main task does some extra # bookkeeping stuff that can leak into the instrument results, and we # don't want to deal with it. 
- async def task_fn(): + async def task_fn() -> None: nonlocal task task = _core.current_task() @@ -59,7 +59,7 @@ async def task_fn(): for _ in range(1): await _core.checkpoint() - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(task_fn) @@ -77,18 +77,18 @@ async def main(): assert list(r1.filter_tasks([task])) == expected -def test_instruments_interleave(): +def test_instruments_interleave() -> None: tasks = {} - async def two_step1(): + async def two_step1() -> None: tasks["t1"] = _core.current_task() await _core.checkpoint() - async def two_step2(): + async def two_step2() -> None: tasks["t2"] = _core.current_task() await _core.checkpoint() - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(two_step1) nursery.start_soon(two_step2) @@ -120,36 +120,36 @@ async def main(): check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) -def test_null_instrument(): +def test_null_instrument() -> None: # undefined instrument methods are skipped class NullInstrument: - def something_unrelated(self): + def something_unrelated(self) -> None: pass # pragma: no cover - async def main(): + async def main() -> None: await _core.checkpoint() _core.run(main, instruments=[NullInstrument()]) -def test_instrument_before_after_run(): +def test_instrument_before_after_run() -> None: record = [] class BeforeAfterRun: - def before_run(self): + def before_run(self) -> None: record.append("before_run") - def after_run(self): + def after_run(self) -> None: record.append("after_run") - async def main(): + async def main() -> None: pass _core.run(main, instruments=[BeforeAfterRun()]) assert record == ["before_run", "after_run"] -def test_instrument_task_spawn_exit(): +def test_instrument_task_spawn_exit() -> None: record = [] class SpawnExitRecorder: @@ -169,7 +169,7 @@ async def main(): # This test also tests having a crash before the initial task is even 
spawned, # which is very difficult to handle. -def test_instruments_crash(caplog): +def test_instruments_crash(caplog) -> None: record = [] class BrokenInstrument: @@ -177,7 +177,7 @@ def task_scheduled(self, task): record.append("scheduled") raise ValueError("oops") - def close(self): + def close(self) -> None: # Shouldn't be called -- tests that the instrument disabling logic # works right. record.append("closed") # pragma: no cover @@ -200,13 +200,13 @@ async def main(): assert "Instrument has been disabled" in caplog.records[0].message -def test_instruments_monkeypatch(): +def test_instruments_monkeypatch() -> None: class NullInstrument(_abc.Instrument): pass instrument = NullInstrument() - async def main(): + async def main() -> None: record = [] # Changing the set of hooks implemented by an instrument after @@ -232,16 +232,16 @@ async def main(): _core.run(main, instruments=[instrument]) -def test_instrument_that_raises_on_getattr(): +def test_instrument_that_raises_on_getattr() -> None: class EvilInstrument: def task_exited(self, task): assert False # pragma: no cover @property - def after_run(self): + def after_run(self) -> None: raise ValueError("oops") - async def main(): + async def main() -> None: with pytest.raises(ValueError): _core.add_instrument(EvilInstrument()) diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index 397375503d..f8ce11a1b8 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -5,8 +5,10 @@ import random import errno from contextlib import suppress +from typing import Awaitable, Callable, Iterator, List, Tuple, Union from ... 
import _core +from ..._typing import _HasFileno from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints import trio @@ -29,8 +31,11 @@ def drain_socket(sock): pass +_SocketPair = Tuple[stdlib_socket.socket, stdlib_socket.socket] + + @pytest.fixture -def socketpair(): +def socketpair() -> Iterator[_SocketPair]: pair = stdlib_socket.socketpair() for sock in pair: sock.setblocking(False) @@ -48,9 +53,16 @@ def fileno_wrapper(fileobj): return fileno_wrapper -wait_readable_options = [trio.lowlevel.wait_readable] -wait_writable_options = [trio.lowlevel.wait_writable] -notify_closing_options = [trio.lowlevel.notify_closing] +_WaitReadable = Callable[[stdlib_socket.socket], Awaitable[None]] +_WaitWritable = Callable[[stdlib_socket.socket], Awaitable[None]] +_NotifyClosing = Callable[[stdlib_socket.socket], None] + +# OptionsList = List[Callable[[Union[int, _HasFileno]], Union[Awaitable[None], None]]] + +wait_readable_options: List = [trio.lowlevel.wait_readable] +wait_writable_options: List = [trio.lowlevel.wait_writable] +notify_closing_options: List = [trio.lowlevel.notify_closing] + for options_list in [ wait_readable_options, @@ -59,18 +71,23 @@ def fileno_wrapper(fileobj): ]: options_list += [using_fileno(f) for f in options_list] + +def get__name__(fn: Callable) -> str: + return fn.__name__ + + # Decorators that feed in different settings for wait_readable / wait_writable # / notify_closing. 
# Note that if you use all three decorators on the same test, it will run all # N**3 *combinations* read_socket_test = pytest.mark.parametrize( - "wait_readable", wait_readable_options, ids=lambda fn: fn.__name__ + "wait_readable", wait_readable_options, ids=get__name__ ) write_socket_test = pytest.mark.parametrize( - "wait_writable", wait_writable_options, ids=lambda fn: fn.__name__ + "wait_writable", wait_writable_options, ids=get__name__ ) notify_closing_test = pytest.mark.parametrize( - "notify_closing", notify_closing_options, ids=lambda fn: fn.__name__ + "notify_closing", notify_closing_options, ids=get__name__ ) @@ -79,7 +96,9 @@ def fileno_wrapper(fileobj): # momentarily and then immediately resuming. @read_socket_test @write_socket_test -async def test_wait_basic(socketpair, wait_readable, wait_writable): +async def test_wait_basic( + socketpair: _SocketPair, wait_readable: _WaitReadable, wait_writable: _WaitWritable +) -> None: a, b = socketpair # They start out writable() @@ -89,7 +108,7 @@ async def test_wait_basic(socketpair, wait_readable, wait_writable): # But readable() blocks until data arrives record = [] - async def block_on_read(): + async def block_on_read() -> None: try: with assert_checkpoints(): await wait_readable(a) @@ -112,7 +131,7 @@ async def block_on_read(): await wait_readable(b) record = [] - async def block_on_write(): + async def block_on_write() -> None: try: with assert_checkpoints(): await wait_writable(a) @@ -145,7 +164,9 @@ async def block_on_write(): @read_socket_test -async def test_double_read(socketpair, wait_readable): +async def test_double_read( + socketpair: _SocketPair, wait_readable: _WaitWritable +) -> None: a, b = socketpair # You can't have two tasks trying to read from a socket at the same time @@ -158,7 +179,9 @@ async def test_double_read(socketpair, wait_readable): @write_socket_test -async def test_double_write(socketpair, wait_writable): +async def test_double_write( + socketpair: _SocketPair, wait_writable: 
_WaitWritable +) -> None: a, b = socketpair # You can't have two tasks trying to write to a socket at the same time @@ -175,15 +198,18 @@ async def test_double_write(socketpair, wait_writable): @write_socket_test @notify_closing_test async def test_interrupted_by_close( - socketpair, wait_readable, wait_writable, notify_closing -): + socketpair: _SocketPair, + wait_readable: _WaitReadable, + wait_writable: _WaitWritable, + notify_closing: _NotifyClosing, +) -> None: a, b = socketpair - async def reader(): + async def reader() -> None: with pytest.raises(_core.ClosedResourceError): await wait_readable(a) - async def writer(): + async def writer() -> None: with pytest.raises(_core.ClosedResourceError): await wait_writable(a) @@ -198,7 +224,9 @@ async def writer(): @read_socket_test @write_socket_test -async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable): +async def test_socket_simultaneous_read_write( + socketpair: _SocketPair, wait_readable: _WaitReadable, wait_writable: _WaitWritable +) -> None: record = [] async def r_task(sock): @@ -226,7 +254,9 @@ async def w_task(sock): @read_socket_test @write_socket_test -async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable): +async def test_socket_actual_streaming( + socketpair: _SocketPair, wait_readable: _WaitReadable, wait_writable: _WaitWritable +) -> None: a, b = socketpair # Use a small send buffer on one of the sockets to increase the chance of @@ -275,7 +305,7 @@ async def receiver(sock, key): assert results["send_b"] == results["recv_a"] -async def test_notify_closing_on_invalid_object(): +async def test_notify_closing_on_invalid_object() -> None: # It should either be a no-op (generally on Unix, where we don't know # which fds are valid), or an OSError (on Windows, where we currently only # support sockets, so we have to do some validation to figure out whether @@ -291,7 +321,7 @@ async def test_notify_closing_on_invalid_object(): assert got_oserror or 
got_no_error -async def test_wait_on_invalid_object(): +async def test_wait_on_invalid_object() -> None: # We definitely want to raise an error everywhere if you pass in an # invalid fd to wait_* for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]: @@ -303,7 +333,7 @@ async def test_wait_on_invalid_object(): await wait(fileno) -async def test_io_manager_statistics(): +async def test_io_manager_statistics() -> None: def check(*, expected_readers, expected_writers): statistics = _core.current_statistics() print(statistics) @@ -351,7 +381,7 @@ def check(*, expected_readers, expected_writers): check(expected_readers=1, expected_writers=0) -async def test_can_survive_unnotified_close(): +async def test_can_survive_unnotified_close() -> None: # An "unnotified" close is when the user closes an fd/socket/handle # directly, without calling notify_closing first. This should never happen # -- users should call notify_closing before closing things. But, just in @@ -429,7 +459,7 @@ async def allow_OSError(async_func, *args): # sleep waiting on 'a2', with the idea that the 'a2' notification will # definitely arrive, and when it does then we can assume that whatever # notification was going to arrive for 'a' has also arrived. 
- async def wait_readable_a2_then_set(): + async def wait_readable_a2_then_set() -> None: await trio.lowlevel.wait_readable(a2) e.set() diff --git a/trio/_core/tests/test_ki.py b/trio/_core/tests/test_ki.py index 0e4db4af49..e849cf14ef 100644 --- a/trio/_core/tests/test_ki.py +++ b/trio/_core/tests/test_ki.py @@ -6,6 +6,7 @@ import threading import contextlib import time +from typing import Any, AsyncIterator, Iterator from async_generator import ( async_generator, @@ -21,24 +22,24 @@ from .tutil import slow -def ki_self(): +def ki_self() -> None: signal_raise(signal.SIGINT) -def test_ki_self(): +def test_ki_self() -> None: with pytest.raises(KeyboardInterrupt): ki_self() -async def test_ki_enabled(): +async def test_ki_enabled() -> None: # Regular tasks aren't KI-protected assert not _core.currently_ki_protected() # Low-level call-soon callbacks are KI-protected token = _core.current_trio_token() - record = [] + record: Any = [] - def check(): + def check() -> None: record.append(_core.currently_ki_protected()) token.run_sync_soon(check) @@ -46,23 +47,23 @@ def check(): assert record == [True] @_core.enable_ki_protection - def protected(): + def protected() -> None: assert _core.currently_ki_protected() unprotected() @_core.disable_ki_protection - def unprotected(): + def unprotected() -> None: assert not _core.currently_ki_protected() protected() @_core.enable_ki_protection - async def aprotected(): + async def aprotected() -> None: assert _core.currently_ki_protected() await aunprotected() @_core.disable_ki_protection - async def aunprotected(): + async def aunprotected() -> None: assert not _core.currently_ki_protected() await aprotected() @@ -74,7 +75,7 @@ async def aunprotected(): nursery.start_soon(aunprotected) @_core.enable_ki_protection - def gen_protected(): + def gen_protected() -> Iterator[None]: assert _core.currently_ki_protected() yield @@ -82,7 +83,7 @@ def gen_protected(): pass @_core.disable_ki_protection - def gen_unprotected(): + def 
gen_unprotected() -> Iterator[None]: assert not _core.currently_ki_protected() yield @@ -99,16 +100,16 @@ def gen_unprotected(): # .throw(), not the actual caller. So child() here would have a caller deep in # the guts of the run loop, and always be protected, even when it shouldn't # have been. (Solution: we don't use .throw() anymore.) -async def test_ki_enabled_after_yield_briefly(): +async def test_ki_enabled_after_yield_briefly() -> None: @_core.enable_ki_protection - async def protected(): + async def protected() -> None: await child(True) @_core.disable_ki_protection - async def unprotected(): + async def unprotected() -> None: await child(False) - async def child(expected): + async def child(expected: bool) -> None: import traceback traceback.print_stack() @@ -123,10 +124,10 @@ async def child(expected): # This also used to be broken due to # https://bugs.python.org/issue29590 -async def test_generator_based_context_manager_throw(): +async def test_generator_based_context_manager_throw() -> None: @contextlib.contextmanager @_core.enable_ki_protection - def protected_manager(): + def protected_manager() -> Iterator[None]: assert _core.currently_ki_protected() try: yield @@ -142,10 +143,10 @@ def protected_manager(): raise KeyError -async def test_agen_protection(): +async def test_agen_protection() -> None: @_core.enable_ki_protection @async_generator - async def agen_protected1(): + async def agen_protected1() -> None: # type: ignore[misc] assert _core.currently_ki_protected() try: await yield_() @@ -154,7 +155,7 @@ async def agen_protected1(): @_core.disable_ki_protection @async_generator - async def agen_unprotected1(): + async def agen_unprotected1() -> None: # type: ignore[misc] assert not _core.currently_ki_protected() try: await yield_() @@ -164,7 +165,7 @@ async def agen_unprotected1(): # Swap the order of the decorators: @async_generator @_core.enable_ki_protection - async def agen_protected2(): + async def agen_protected2() -> None: # type: 
ignore[misc] assert _core.currently_ki_protected() try: await yield_() @@ -173,7 +174,7 @@ async def agen_protected2(): @async_generator @_core.disable_ki_protection - async def agen_unprotected2(): + async def agen_unprotected2() -> None: # type: ignore[misc] assert not _core.currently_ki_protected() try: await yield_() @@ -182,7 +183,7 @@ async def agen_unprotected2(): # Native async generators @_core.enable_ki_protection - async def agen_protected3(): + async def agen_protected3() -> AsyncIterator[None]: assert _core.currently_ki_protected() try: yield @@ -190,7 +191,7 @@ async def agen_protected3(): assert _core.currently_ki_protected() @_core.disable_ki_protection - async def agen_unprotected3(): + async def agen_unprotected3() -> AsyncIterator[None]: assert not _core.currently_ki_protected() try: yield @@ -222,20 +223,20 @@ async def agen_unprotected3(): # Test the case where there's no magic local anywhere in the call stack -def test_ki_disabled_out_of_context(): +def test_ki_disabled_out_of_context() -> None: assert _core.currently_ki_protected() -def test_ki_disabled_in_del(): +def test_ki_disabled_in_del() -> None: def nestedfunction(): return _core.currently_ki_protected() - def __del__(): + def __del__() -> None: assert _core.currently_ki_protected() assert nestedfunction() @_core.disable_ki_protection - def outerfunction(): + def outerfunction() -> None: assert not _core.currently_ki_protected() assert not nestedfunction() __del__() @@ -245,7 +246,7 @@ def outerfunction(): assert nestedfunction() -def test_ki_protection_works(): +def test_ki_protection_works() -> None: async def sleeper(name, record): try: while True: @@ -276,9 +277,9 @@ async def raiser(name, record): # simulated control-C during raiser, which is *unprotected* print("check 1") - record = set() + record: Any = set() - async def check_unprotected_kill(): + async def check_unprotected_kill() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record) 
nursery.start_soon(sleeper, "s2", record) @@ -293,7 +294,7 @@ async def check_unprotected_kill(): print("check 2") record = set() - async def check_protected_kill(): + async def check_protected_kill() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record) nursery.start_soon(sleeper, "s2", record) @@ -308,10 +309,10 @@ async def check_protected_kill(): # error, then kill) print("check 3") - async def check_kill_during_shutdown(): + async def check_kill_during_shutdown() -> None: token = _core.current_trio_token() - def kill_during_shutdown(): + def kill_during_shutdown() -> None: assert _core.currently_ki_protected() try: token.run_sync_soon(kill_during_shutdown) @@ -329,10 +330,10 @@ def kill_during_shutdown(): print("check 4") class InstrumentOfDeath: - def before_run(self): + def before_run(self) -> None: ki_self() - async def main(): + async def main() -> None: await _core.checkpoint() with pytest.raises(KeyboardInterrupt): @@ -342,19 +343,19 @@ async def main(): print("check 5") @_core.enable_ki_protection - async def main(): + async def main_a() -> None: assert _core.currently_ki_protected() ki_self() with pytest.raises(KeyboardInterrupt): await _core.checkpoint_if_cancelled() - _core.run(main) + _core.run(main_a) # KI arrives while main task is not abortable, b/c already scheduled print("check 6") @_core.enable_ki_protection - async def main(): + async def main_b() -> None: assert _core.currently_ki_protected() ki_self() await _core.cancel_shielded_checkpoint() @@ -363,13 +364,13 @@ async def main(): with pytest.raises(KeyboardInterrupt): await _core.checkpoint() - _core.run(main) + _core.run(main_b) # KI arrives while main task is not abortable, b/c refuses to be aborted print("check 7") @_core.enable_ki_protection - async def main(): + async def main_c() -> None: assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -382,13 +383,13 @@ def abort(_): with pytest.raises(KeyboardInterrupt): await 
_core.checkpoint() - _core.run(main) + _core.run(main_c) # KI delivered via slow abort print("check 8") @_core.enable_ki_protection - async def main(): + async def main_d() -> None: assert _core.currently_ki_protected() ki_self() task = _core.current_task() @@ -402,7 +403,7 @@ def abort(raise_cancel): assert await _core.wait_task_rescheduled(abort) await _core.checkpoint() - _core.run(main) + _core.run(main_d) # KI arrives just before main task exits, so the run_sync_soon machinery # is still functioning and will accept the callback to deliver the KI, but @@ -411,18 +412,18 @@ def abort(raise_cancel): print("check 9") @_core.enable_ki_protection - async def main(): + async def main_e() -> None: ki_self() with pytest.raises(KeyboardInterrupt): - _core.run(main) + _core.run(main_e) print("check 10") # KI in unprotected code, with # restrict_keyboard_interrupt_to_checkpoints=True record = [] - async def main(): + async def main_f() -> None: # We're not KI protected... assert not _core.currently_ki_protected() ki_self() @@ -432,13 +433,13 @@ async def main(): with pytest.raises(KeyboardInterrupt): await sleep(10) - _core.run(main, restrict_keyboard_interrupt_to_checkpoints=True) + _core.run(main_f, restrict_keyboard_interrupt_to_checkpoints=True) assert record == ["ok"] record = [] # Exact same code raises KI early if we leave off the argument, doesn't # even reach the record.append call: with pytest.raises(KeyboardInterrupt): - _core.run(main) + _core.run(main_f) assert record == [] # KI arrives while main task is inside a cancelled cancellation scope @@ -446,7 +447,7 @@ async def main(): print("check 11") @_core.enable_ki_protection - async def main(): + async def main_g() -> None: assert _core.currently_ki_protected() with _core.CancelScope() as cancel_scope: cancel_scope.cancel() @@ -458,10 +459,10 @@ async def main(): with pytest.raises(_core.Cancelled): await _core.checkpoint() - _core.run(main) + _core.run(main_g) -def test_ki_is_good_neighbor(): +def 
test_ki_is_good_neighbor() -> None: # in the unlikely event someone overwrites our signal handler, we leave # the overwritten one be try: @@ -470,7 +471,7 @@ def test_ki_is_good_neighbor(): def my_handler(signum, frame): # pragma: no cover pass - async def main(): + async def main() -> None: signal.signal(signal.SIGINT, my_handler) _core.run(main) @@ -481,31 +482,31 @@ async def main(): # Regression test for #461 -def test_ki_with_broken_threads(): +def test_ki_with_broken_threads() -> None: thread = threading.main_thread() # scary! - original = threading._active[thread.ident] + original = threading._active[thread.ident] # type: ignore[attr-defined] # put this in a try finally so we don't have a chance of cascading a # breakage down to everything else try: - del threading._active[thread.ident] + del threading._active[thread.ident] # type: ignore[attr-defined] @_core.enable_ki_protection - async def inner(): + async def inner() -> None: assert signal.getsignal(signal.SIGINT) != signal.default_int_handler _core.run(inner) finally: - threading._active[thread.ident] = original + threading._active[thread.ident] = original # type: ignore[attr-defined] # For details on why this test is non-trivial, see: # https://github.com/python-trio/trio/issues/42 # https://github.com/python-trio/trio/issues/109 @slow -def test_ki_wakes_us_up(): +def test_ki_wakes_us_up() -> None: assert is_main_thread() # This test is flaky due to a race condition on Windows; see: @@ -562,7 +563,7 @@ def test_ki_wakes_us_up(): # It will be very nice when the buggy_wakeup_fd bug is fixed. lock = threading.Lock() - def kill_soon(): + def kill_soon() -> None: # We want the signal to be raised after the main thread has entered # the IO manager blocking primitive. 
There really is no way to # deterministically interlock with that, so we have to use sleep and @@ -575,7 +576,7 @@ def kill_soon(): print("buggy_wakeup_fd =", buggy_wakeup_fd) ki_self() - async def main(): + async def main() -> None: thread = threading.Thread(target=kill_soon) print("Starting thread") thread.start() diff --git a/trio/_core/tests/test_local.py b/trio/_core/tests/test_local.py index 7f403168ea..27536fb146 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/tests/test_local.py @@ -4,13 +4,13 @@ # scary runvar tests -def test_runvar_smoketest(): +def test_runvar_smoketest() -> None: t1 = _core.RunVar("test1") t2 = _core.RunVar("test2", default="catfish") assert "RunVar" in repr(t1) - async def first_check(): + async def first_check() -> None: with pytest.raises(LookupError): t1.get() @@ -23,7 +23,7 @@ async def first_check(): assert t2.get() == "goldfish" assert t2.get(default="tuna") == "goldfish" - async def second_check(): + async def second_check() -> None: with pytest.raises(LookupError): t1.get() @@ -33,12 +33,12 @@ async def second_check(): _core.run(second_check) -def test_runvar_resetting(): +def test_runvar_resetting() -> None: t1 = _core.RunVar("test1") t2 = _core.RunVar("test2", default="dogfish") t3 = _core.RunVar("test3") - async def reset_check(): + async def reset_check() -> None: token = t1.set("moonfish") assert t1.get() == "moonfish" t1.reset(token) @@ -66,11 +66,11 @@ async def reset_check(): _core.run(reset_check) -def test_runvar_sync(): +def test_runvar_sync() -> None: t1 = _core.RunVar("test1") - async def sync_check(): - async def task1(): + async def sync_check() -> None: + async def task1() -> None: t1.set("plaice") assert t1.get() == "plaice" @@ -97,7 +97,7 @@ async def task2(tok): _core.run(sync_check) -def test_accessing_runvar_outside_run_call_fails(): +def test_accessing_runvar_outside_run_call_fails() -> None: t1 = _core.RunVar("test1") with pytest.raises(RuntimeError): diff --git 
a/trio/_core/tests/test_mock_clock.py b/trio/_core/tests/test_mock_clock.py index bea9509686..321bf94b42 100644 --- a/trio/_core/tests/test_mock_clock.py +++ b/trio/_core/tests/test_mock_clock.py @@ -10,7 +10,7 @@ from .tutil import slow -def test_mock_clock(): +def test_mock_clock() -> None: REAL_NOW = 123.0 c = MockClock() c._real_clock = lambda: REAL_NOW @@ -54,7 +54,7 @@ def test_mock_clock(): assert c2.current_time() < 10 -async def test_mock_clock_autojump(mock_clock): +async def test_mock_clock_autojump(mock_clock) -> None: assert mock_clock.autojump_threshold == inf mock_clock.autojump_threshold = 0 @@ -94,7 +94,7 @@ async def test_mock_clock_autojump(mock_clock): await sleep(100000) -async def test_mock_clock_autojump_interference(mock_clock): +async def test_mock_clock_autojump_interference(mock_clock) -> None: mock_clock.autojump_threshold = 0.02 mock_clock2 = MockClock() @@ -111,7 +111,7 @@ async def test_mock_clock_autojump_interference(mock_clock): await sleep(100000) -def test_mock_clock_autojump_preset(): +def test_mock_clock_autojump_preset() -> None: # Check that we can set the autojump_threshold before the clock is # actually in use, and it gets picked up mock_clock = MockClock(autojump_threshold=0.1) @@ -121,7 +121,7 @@ def test_mock_clock_autojump_preset(): assert time.perf_counter() - real_start < 1 -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock): +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with the default cushion=0. 
@@ -129,11 +129,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock): record = [] - async def sleeper(): + async def sleeper() -> None: await sleep(100) record.append("yawn") - async def waiter(): + async def waiter() -> None: await wait_all_tasks_blocked() record.append("waiter woke") await sleep(1000) @@ -147,7 +147,9 @@ async def waiter(): @slow -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clock): +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero( + mock_clock: MockClock, +) -> None: # Checks that autojump_threshold=0 doesn't interfere with # calling wait_all_tasks_blocked with a non-zero cushion. @@ -155,11 +157,11 @@ async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clo record = [] - async def sleeper(): + async def sleeper() -> None: await sleep(100) record.append("yawn") - async def waiter(): + async def waiter() -> None: await wait_all_tasks_blocked(1) record.append("waiter done") diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/tests/test_multierror.py index 14eab22df7..2872a4266b 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/tests/test_multierror.py @@ -6,7 +6,7 @@ print_exception, format_exception, ) -from traceback import _cause_message # type: ignore +from traceback import _cause_message # type: ignore[attr-defined] import sys import os import re @@ -22,7 +22,7 @@ class NotHashableException(Exception): code = None - def __init__(self, code): + def __init__(self, code) -> None: super().__init__() self.code = code @@ -36,27 +36,27 @@ async def raise_nothashable(code): raise NotHashableException(code) -def raiser1(): +def raiser1() -> None: raiser1_2() -def raiser1_2(): +def raiser1_2() -> None: raiser1_3() -def raiser1_3(): +def raiser1_3() -> None: raise ValueError("raiser1_string") -def raiser2(): +def raiser2() -> None: raiser2_2() -def raiser2_2(): +def raiser2_2() -> None: raise KeyError("raiser2_string") 
-def raiser3(): +def raiser3() -> None: raise NameError @@ -75,7 +75,7 @@ def einfo(exc): return (type(exc), exc, exc.__traceback__) -def test_concat_tb(): +def test_concat_tb() -> None: tb1 = get_tb(raiser1) tb2 = get_tb(raiser2) @@ -101,7 +101,7 @@ def test_concat_tb(): assert extract_tb(get_tb(raiser2)) == entries2 -def test_MultiError(): +def test_MultiError() -> None: exc1 = get_exc(raiser1) exc2 = get_exc(raiser2) @@ -117,7 +117,7 @@ def test_MultiError(): MultiError([KeyError(), ValueError]) -def test_MultiErrorOfSingleMultiError(): +def test_MultiErrorOfSingleMultiError() -> None: # For MultiError([MultiError]), ensure there is no bad recursion by the # constructor where __init__ is called if __new__ returns a bare MultiError. exceptions = [KeyError(), ValueError()] @@ -127,7 +127,7 @@ def test_MultiErrorOfSingleMultiError(): assert b.exceptions == exceptions -async def test_MultiErrorNotHashable(): +async def test_MultiErrorNotHashable() -> None: exc1 = NotHashableException(42) exc2 = NotHashableException(4242) exc3 = ValueError() @@ -140,7 +140,7 @@ async def test_MultiErrorNotHashable(): nursery.start_soon(raise_nothashable, 4242) -def test_MultiError_filter_NotHashable(): +def test_MultiError_filter_NotHashable() -> None: excs = MultiError([NotHashableException(42), ValueError()]) def handle_ValueError(exc): @@ -153,7 +153,7 @@ def handle_ValueError(exc): assert isinstance(filtered_excs, NotHashableException) -def test_traceback_recursion(): +def test_traceback_recursion() -> None: exc1 = RuntimeError() exc2 = KeyError() exc3 = NotHashableException(42) @@ -205,7 +205,7 @@ def assert_tree_eq(m1, m2): assert_tree_eq(e1, e2) -def test_MultiError_filter(): +def test_MultiError_filter() -> None: def null_handler(exc): return exc @@ -283,7 +283,7 @@ def filter_all(exc): assert MultiError.filter(filter_all, make_tree()) is None -def test_MultiError_catch(): +def test_MultiError_catch() -> None: # No exception to catch def noop(_): @@ -380,14 +380,14 @@ def 
assert_match_in_seq(pattern_list, string): offset = match.end() -def test_assert_match_in_seq(): +def test_assert_match_in_seq() -> None: assert_match_in_seq(["a", "b"], "xx a xx b xx") assert_match_in_seq(["b", "a"], "xx b xx a xx") with pytest.raises(AssertionError): assert_match_in_seq(["a", "b"], "xx b xx a xx") -def test_format_exception(): +def test_format_exception() -> None: exc = get_exc(raiser1) formatted = "".join(format_exception(*einfo(exc))) assert "raiser1_string" in formatted @@ -506,13 +506,13 @@ def test_format_exception(): # Prints duplicate exceptions in sub-exceptions exc1 = get_exc(raiser1) - def raise1_raiser1(): + def raise1_raiser1() -> None: try: raise exc1 except: raise ValueError("foo") - def raise2_raiser1(): + def raise2_raiser1() -> None: try: raise exc1 except: @@ -565,7 +565,7 @@ def raise2_raiser1(): ) -def test_logging(caplog): +def test_logging(caplog) -> None: exc1 = get_exc(raiser1) exc2 = get_exc(raiser2) @@ -638,12 +638,12 @@ def check_simple_excepthook(completed): ) -def test_simple_excepthook(): +def test_simple_excepthook() -> None: completed = run_script("simple_excepthook.py") check_simple_excepthook(completed) -def test_custom_excepthook(): +def test_custom_excepthook() -> None: # Check that user-defined excepthooks aren't overridden completed = run_script("custom_excepthook.py") assert_match_in_seq( @@ -682,20 +682,20 @@ def test_custom_excepthook(): @slow @need_ipython -def test_ipython_exc_handler(): +def test_ipython_exc_handler() -> None: completed = run_script("simple_excepthook.py", use_ipython=True) check_simple_excepthook(completed) @slow @need_ipython -def test_ipython_imported_but_unused(): +def test_ipython_imported_but_unused() -> None: completed = run_script("simple_excepthook_IPython.py") check_simple_excepthook(completed) @slow -def test_partial_imported_but_unused(): +def test_partial_imported_but_unused() -> None: # Check that a functools.partial as sys.excepthook doesn't cause an exception when # 
importing trio. This was a problem due to the lack of a .__name__ attribute and # happens when inside a pytest-qt test case for example. @@ -705,7 +705,7 @@ def test_partial_imported_but_unused(): @slow @need_ipython -def test_ipython_custom_exc_handler(): +def test_ipython_custom_exc_handler() -> None: # Check we get a nice warning (but only one!) if the user is using IPython # and already has some other set_custom_exc handler installed. completed = run_script("ipython_custom_exc.py", use_ipython=True) @@ -731,7 +731,7 @@ def test_ipython_custom_exc_handler(): not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(), reason="need Ubuntu with python3-apport installed", ) -def test_apport_excepthook_monkeypatch_interaction(): +def test_apport_excepthook_monkeypatch_interaction() -> None: completed = run_script("apport_excepthook.py") stdout = completed.stdout.decode("utf-8") diff --git a/trio/_core/tests/test_parking_lot.py b/trio/_core/tests/test_parking_lot.py index 13ffe0c066..9dcdfb490a 100644 --- a/trio/_core/tests/test_parking_lot.py +++ b/trio/_core/tests/test_parking_lot.py @@ -6,7 +6,7 @@ from .tutil import check_sequence_matches -async def test_parking_lot_basic(): +async def test_parking_lot_basic() -> None: record = [] async def waiter(i, lot): @@ -85,7 +85,7 @@ async def cancellable_waiter(name, lot, scopes, record): record.append("wake {}".format(name)) -async def test_parking_lot_cancel(): +async def test_parking_lot_cancel() -> None: record = [] scopes = {} @@ -111,7 +111,7 @@ async def test_parking_lot_cancel(): ) -async def test_parking_lot_repark(): +async def test_parking_lot_repark() -> None: record = [] scopes = {} lot1 = ParkingLot() @@ -165,7 +165,7 @@ async def test_parking_lot_repark(): ] -async def test_parking_lot_repark_with_count(): +async def test_parking_lot_repark_with_count() -> None: record = [] scopes = {} lot1 = ParkingLot() diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 
d2bcdfd740..1990449841 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -11,6 +11,7 @@ from math import inf from textwrap import dedent import gc +from typing import Iterator, TypeVar import attr import outcome @@ -36,6 +37,9 @@ ) +_T = TypeVar("_T") + + # slightly different from _timeouts.sleep_forever because it returns the value # its rescheduled with, which is really only useful for tests of # rescheduling... @@ -43,7 +47,7 @@ async def sleep_forever(): return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -def test_basic(): +def test_basic() -> None: async def trivial(x): return x @@ -64,7 +68,7 @@ async def trivial2(x): assert _core.run(trivial2, 1) == 1 -def test_initial_task_error(): +def test_initial_task_error() -> None: async def main(x): raise ValueError(x) @@ -73,9 +77,9 @@ async def main(x): assert excinfo.value.args == (17,) -def test_run_nesting(): +def test_run_nesting() -> None: async def inception(): - async def main(): # pragma: no cover + async def main() -> None: # pragma: no cover pass return _core.run(main) @@ -85,7 +89,7 @@ async def main(): # pragma: no cover assert "from inside" in str(excinfo.value) -async def test_nursery_warn_use_async_with(): +async def test_nursery_warn_use_async_with() -> None: with pytest.raises(RuntimeError) as excinfo: on = _core.open_nursery() with on: @@ -99,7 +103,7 @@ async def test_nursery_warn_use_async_with(): pass -async def test_nursery_main_block_error_basic(): +async def test_nursery_main_block_error_basic() -> None: exc = ValueError("whoops") with pytest.raises(ValueError) as excinfo: @@ -108,10 +112,10 @@ async def test_nursery_main_block_error_basic(): assert excinfo.value is exc -async def test_child_crash_basic(): +async def test_child_crash_basic() -> None: exc = ValueError("uh oh") - async def erroring(): + async def erroring() -> None: raise exc try: @@ -122,7 +126,7 @@ async def erroring(): assert e is exc -async def test_basic_interleave(): 
+async def test_basic_interleave() -> None: async def looper(whoami, record): for i in range(3): record.append((whoami, i)) @@ -138,10 +142,10 @@ async def looper(whoami, record): ) -def test_task_crash_propagation(): +def test_task_crash_propagation() -> None: looper_record = [] - async def looper(): + async def looper() -> None: try: while True: await _core.checkpoint() @@ -149,10 +153,10 @@ async def looper(): print("looper cancelled") looper_record.append("cancelled") - async def crasher(): + async def crasher() -> None: raise ValueError("argh") - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(looper) nursery.start_soon(crasher) @@ -164,13 +168,13 @@ async def main(): assert excinfo.value.args == ("argh",) -def test_main_and_task_both_crash(): +def test_main_and_task_both_crash() -> None: # If main crashes and there's also a task crash, then we get both in a # MultiError - async def crasher(): + async def crasher() -> None: raise ValueError - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) raise KeyError @@ -184,11 +188,11 @@ async def main(): } -def test_two_child_crashes(): +def test_two_child_crashes() -> None: async def crasher(etype): raise etype - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher, KeyError) nursery.start_soon(crasher, ValueError) @@ -201,8 +205,8 @@ async def main(): } -async def test_child_crash_wakes_parent(): - async def crasher(): +async def test_child_crash_wakes_parent() -> None: + async def crasher() -> None: raise ValueError with pytest.raises(ValueError): @@ -211,11 +215,11 @@ async def crasher(): await sleep_forever() -async def test_reschedule(): +async def test_reschedule() -> None: t1 = None t2 = None - async def child1(): + async def child1() -> None: nonlocal t1, t2 t1 = _core.current_task() print("child1 start") @@ -226,7 
+230,7 @@ async def child1(): _core.reschedule(t2, outcome.Error(ValueError())) print("child1 exit") - async def child2(): + async def child2() -> None: nonlocal t1, t2 print("child2 start") t2 = _core.current_task() @@ -243,7 +247,7 @@ async def child2(): nursery.start_soon(child2) -async def test_current_time(): +async def test_current_time() -> None: t1 = _core.current_time() # Windows clock is pretty low-resolution -- appveyor tests fail unless we # sleep for a bit here. @@ -252,7 +256,7 @@ async def test_current_time(): assert t1 < t2 -async def test_current_time_with_mock_clock(mock_clock): +async def test_current_time_with_mock_clock(mock_clock) -> None: start = mock_clock.current_time() assert mock_clock.current_time() == _core.current_time() assert mock_clock.current_time() == _core.current_time() @@ -260,38 +264,38 @@ async def test_current_time_with_mock_clock(mock_clock): assert start + 3.14 == mock_clock.current_time() == _core.current_time() -async def test_current_clock(mock_clock): +async def test_current_clock(mock_clock) -> None: assert mock_clock is _core.current_clock() -async def test_current_task(): +async def test_current_task() -> None: parent_task = _core.current_task() - async def child(): + async def child() -> None: assert _core.current_task().parent_nursery.parent_task is parent_task async with _core.open_nursery() as nursery: nursery.start_soon(child) -async def test_root_task(): +async def test_root_task() -> None: root = _core.current_root_task() assert root.parent_nursery is root.eventual_parent_nursery is None -def test_out_of_context(): +def test_out_of_context() -> None: with pytest.raises(RuntimeError): _core.current_task() with pytest.raises(RuntimeError): _core.current_time() -async def test_current_statistics(mock_clock): +async def test_current_statistics(mock_clock) -> None: # Make sure all the early startup stuff has settled down await wait_all_tasks_blocked() # A child that sticks around to make some interesting stats: - 
async def child(): + async def child() -> None: try: await sleep_forever() except _core.Cancelled: @@ -338,7 +342,7 @@ async def child(): assert stats.seconds_to_next_deadline == inf -async def test_cancel_scope_repr(mock_clock): +async def test_cancel_scope_repr(mock_clock) -> None: scope = _core.CancelScope() assert "unbound" in repr(scope) with scope: @@ -354,8 +358,8 @@ async def test_cancel_scope_repr(mock_clock): assert "exited" in repr(scope) -def test_cancel_points(): - async def main1(): +def test_cancel_points() -> None: + async def main1() -> None: with _core.CancelScope() as scope: await _core.checkpoint_if_cancelled() scope.cancel() @@ -364,7 +368,7 @@ async def main1(): _core.run(main1) - async def main2(): + async def main2() -> None: with _core.CancelScope() as scope: await _core.checkpoint() scope.cancel() @@ -373,7 +377,7 @@ async def main2(): _core.run(main2) - async def main3(): + async def main3() -> None: with _core.CancelScope() as scope: scope.cancel() with pytest.raises(_core.Cancelled): @@ -381,7 +385,7 @@ async def main3(): _core.run(main3) - async def main4(): + async def main4() -> None: with _core.CancelScope() as scope: scope.cancel() await _core.cancel_shielded_checkpoint() @@ -392,7 +396,7 @@ async def main4(): _core.run(main4) -async def test_cancel_edge_cases(): +async def test_cancel_edge_cases() -> None: with _core.CancelScope() as scope: # Two cancels in a row -- idempotent scope.cancel() @@ -410,8 +414,8 @@ async def test_cancel_edge_cases(): await sleep_forever() -async def test_cancel_scope_multierror_filtering(): - async def crasher(): +async def test_cancel_scope_multierror_filtering() -> None: + async def crasher() -> None: raise KeyError try: @@ -453,13 +457,13 @@ async def crasher(): assert False -async def test_precancelled_task(): +async def test_precancelled_task() -> None: # a task that gets spawned into an already-cancelled nursery should begin # execution (https://github.com/python-trio/trio/issues/41), but get a 
# cancelled error at its first blocking call. record = [] - async def blocker(): + async def blocker() -> None: record.append("started") await sleep_forever() @@ -469,7 +473,7 @@ async def blocker(): assert record == ["started"] -async def test_cancel_shielding(): +async def test_cancel_shielding() -> None: with _core.CancelScope() as outer: with _core.CancelScope() as inner: await _core.checkpoint() @@ -510,7 +514,7 @@ async def test_cancel_shielding(): # make sure that cancellation propagates immediately to all children -async def test_cancel_inheritance(): +async def test_cancel_inheritance() -> None: record = set() async def leaf(ident): @@ -532,7 +536,7 @@ async def worker(ident): assert record == {"w1-l1", "w1-l2", "w2-l1", "w2-l2"} -async def test_cancel_shield_abort(): +async def test_cancel_shield_abort() -> None: with _core.CancelScope() as outer: async with _core.open_nursery() as nursery: outer.cancel() @@ -541,7 +545,7 @@ async def test_cancel_shield_abort(): # shield, so it manages to get to sleep record = [] - async def sleeper(): + async def sleeper() -> None: record.append("sleeping") try: await sleep_forever() @@ -563,7 +567,7 @@ async def sleeper(): assert record == ["sleeping", "cancelled"] -async def test_basic_timeout(mock_clock): +async def test_basic_timeout(mock_clock) -> None: start = _core.current_time() with _core.CancelScope() as scope: assert scope.deadline == inf @@ -600,7 +604,7 @@ async def test_basic_timeout(mock_clock): await _core.checkpoint() -async def test_cancel_scope_nesting(): +async def test_cancel_scope_nesting() -> None: # Nested scopes: if two triggering at once, the outer one wins with _core.CancelScope() as scope1: with _core.CancelScope() as scope2: @@ -639,7 +643,7 @@ async def test_cancel_scope_nesting(): # Regression test for https://github.com/python-trio/trio/issues/1175 -async def test_unshield_while_cancel_propagating(): +async def test_unshield_while_cancel_propagating() -> None: with _core.CancelScope() as 
outer: with _core.CancelScope() as inner: outer.cancel() @@ -650,7 +654,7 @@ async def test_unshield_while_cancel_propagating(): assert outer.cancelled_caught and not inner.cancelled_caught -async def test_cancel_unbound(): +async def test_cancel_unbound() -> None: async def sleep_until_cancelled(scope): with scope, fail_after(1): await sleep_forever() @@ -700,7 +704,7 @@ async def sleep_until_cancelled(scope): # Can't enter from multiple tasks simultaneously scope = _core.CancelScope() - async def enter_scope(): + async def enter_scope() -> None: with scope: await sleep_forever() @@ -724,7 +728,7 @@ async def enter_scope(): assert scope.cancel_called # never become un-cancelled -async def test_cancel_scope_misnesting(): +async def test_cancel_scope_misnesting() -> None: outer = _core.CancelScope() inner = _core.CancelScope() with ExitStack() as stack: @@ -736,12 +740,12 @@ async def test_cancel_scope_misnesting(): # If there are other tasks inside the abandoned part of the cancel tree, # they get cancelled when the misnesting is detected - async def task1(): + async def task1() -> None: with pytest.raises(_core.Cancelled): await sleep_forever() # Even if inside another cancel scope - async def task2(): + async def task2() -> None: with _core.CancelScope(): with pytest.raises(_core.Cancelled): await sleep_forever() @@ -793,7 +797,7 @@ async def task3(task_status): @slow -async def test_timekeeping(): +async def test_timekeeping() -> None: # probably a good idea to use a real clock for *one* test anyway... 
TARGET = 1.0 # give it a few tries in case of random CI server flakiness @@ -813,7 +817,7 @@ async def test_timekeeping(): assert False -async def test_failed_abort(): +async def test_failed_abort() -> None: stubborn_task = [None] stubborn_scope = [None] record = [] @@ -844,8 +848,8 @@ async def stubborn_sleeper(): assert record == ["sleep", "woke", "cancelled"] -def test_broken_abort(): - async def main(): +def test_broken_abort() -> None: + async def main() -> None: # These yields are here to work around an annoying warning -- we're # going to crash the main loop, and if we (by chance) do this before # the run_sync_soon task runs for the first time, then Python gives us @@ -870,9 +874,9 @@ async def main(): gc_collect_harder() -def test_error_in_run_loop(): +def test_error_in_run_loop() -> None: # Blow stuff up real good to check we at least get a TrioInternalError - async def main(): + async def main() -> None: task = _core.current_task() task._schedule_points = "hello!" await _core.checkpoint() @@ -882,7 +886,7 @@ async def main(): _core.run(main) -async def test_spawn_system_task(): +async def test_spawn_system_task() -> None: record = [] async def system_task(x): @@ -896,11 +900,11 @@ async def system_task(x): # intentionally make a system task crash -def test_system_task_crash(): - async def crasher(): +def test_system_task_crash() -> None: + async def crasher() -> None: raise KeyError - async def main(): + async def main() -> None: _core.spawn_system_task(crasher) await sleep_forever() @@ -908,19 +912,19 @@ async def main(): _core.run(main) -def test_system_task_crash_MultiError(): - async def crasher1(): +def test_system_task_crash_MultiError() -> None: + async def crasher1() -> None: raise KeyError - async def crasher2(): + async def crasher2() -> None: raise ValueError - async def system_task(): + async def system_task() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher1) nursery.start_soon(crasher2) - async def main(): + 
async def main() -> None: _core.spawn_system_task(system_task) await sleep_forever() @@ -934,24 +938,24 @@ async def main(): assert isinstance(exc, (KeyError, ValueError)) -def test_system_task_crash_plus_Cancelled(): +def test_system_task_crash_plus_Cancelled() -> None: # Set up a situation where a system task crashes with a # MultiError([Cancelled, ValueError]) - async def crasher(): + async def crasher() -> None: try: await sleep_forever() except _core.Cancelled: raise ValueError - async def cancelme(): + async def cancelme() -> None: await sleep_forever() - async def system_task(): + async def system_task() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) nursery.start_soon(cancelme) - async def main(): + async def main() -> None: _core.spawn_system_task(system_task) # then we exit, triggering a cancellation @@ -960,11 +964,11 @@ async def main(): assert type(excinfo.value.__cause__) is ValueError -def test_system_task_crash_KeyboardInterrupt(): - async def ki(): +def test_system_task_crash_KeyboardInterrupt() -> None: + async def ki() -> None: raise KeyboardInterrupt - async def main(): + async def main() -> None: _core.spawn_system_task(ki) await sleep_forever() @@ -982,7 +986,7 @@ async def main(): # 4) this task has timed out # 5) ...but it's on the run queue, so the timeout is queued to be delivered # the next time that it's blocked. -async def test_yield_briefly_checks_for_timeout(mock_clock): +async def test_yield_briefly_checks_for_timeout(mock_clock) -> None: with _core.CancelScope(deadline=_core.current_time() + 5): await _core.checkpoint() with pytest.raises(_core.Cancelled): @@ -996,11 +1000,11 @@ async def test_yield_briefly_checks_for_timeout(mock_clock): # still nice to know that it works :-). # # Update: it turns out I was right to be nervous! see the next test... 
-async def test_exc_info(): +async def test_exc_info() -> None: record = [] seq = Sequencer() - async def child1(): + async def child1() -> None: with pytest.raises(ValueError) as excinfo: try: async with seq(0): @@ -1017,7 +1021,7 @@ async def child1(): assert excinfo.value.__context__ is None record.append("child1 success") - async def child2(): + async def child2() -> None: with pytest.raises(KeyError) as excinfo: async with seq(1): pass # we don't yield until seq(3) below @@ -1057,10 +1061,10 @@ async def child2(): # like re-raising and exception chaining are broken. # # https://bugs.python.org/issue29587 -async def test_exc_info_after_yield_error(): +async def test_exc_info_after_yield_error() -> None: child_task = None - async def child(): + async def child() -> None: nonlocal child_task child_task = _core.current_task() @@ -1082,10 +1086,10 @@ async def child(): # Similar to previous test -- if the ValueError() gets sent in via 'throw', # then Python's normal implicit chaining stuff is broken. 
-async def test_exception_chaining_after_yield_error(): +async def test_exception_chaining_after_yield_error() -> None: child_task = None - async def child(): + async def child() -> None: nonlocal child_task child_task = _core.current_task() @@ -1103,8 +1107,8 @@ async def child(): assert isinstance(excinfo.value.__context__, KeyError) -async def test_nursery_exception_chaining_doesnt_make_context_loops(): - async def crasher(): +async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None: + async def crasher() -> None: raise KeyError with pytest.raises(_core.MultiError) as excinfo: @@ -1115,7 +1119,7 @@ async def crasher(): assert excinfo.value.__context__ is None -def test_TrioToken_identity(): +def test_TrioToken_identity() -> None: async def get_and_check_token(): token = _core.current_trio_token() # Two calls in the same run give the same object @@ -1129,7 +1133,7 @@ async def get_and_check_token(): assert hash(t1) != hash(t2) -async def test_TrioToken_run_sync_soon_basic(): +async def test_TrioToken_run_sync_soon_basic() -> None: record = [] def cb(x): @@ -1142,10 +1146,10 @@ def cb(x): assert record == [("cb", 1)] -def test_TrioToken_run_sync_soon_too_late(): +def test_TrioToken_run_sync_soon_too_late() -> None: token = None - async def main(): + async def main() -> None: nonlocal token token = _core.current_trio_token() @@ -1155,7 +1159,7 @@ async def main(): token.run_sync_soon(lambda: None) # pragma: no branch -async def test_TrioToken_run_sync_soon_idempotent(): +async def test_TrioToken_run_sync_soon_idempotent() -> None: record = [] def cb(x): @@ -1182,7 +1186,7 @@ def cb(x): assert record == list(range(100)) -def test_TrioToken_run_sync_soon_idempotent_requeue(): +def test_TrioToken_run_sync_soon_idempotent_requeue() -> None: # We guarantee that if a call has finished, queueing it again will call it # again. 
Due to the lack of synchronization, this effectively means that # we have to guarantee that once a call has *started*, queueing it again @@ -1196,7 +1200,7 @@ def redo(token): except _core.RunFinishedError: pass - async def main(): + async def main() -> None: token = _core.current_trio_token() token.run_sync_soon(redo, token, idempotent=True) await _core.checkpoint() @@ -1208,7 +1212,7 @@ async def main(): assert len(record) >= 2 -def test_TrioToken_run_sync_soon_after_main_crash(): +def test_TrioToken_run_sync_soon_after_main_crash() -> None: record = [] async def main(): @@ -1224,7 +1228,7 @@ async def main(): assert record == ["sync-cb"] -def test_TrioToken_run_sync_soon_crashes(): +def test_TrioToken_run_sync_soon_crashes() -> None: record = set() async def main(): @@ -1245,7 +1249,7 @@ async def main(): assert record == {"2nd run_sync_soon ran", "cancelled!"} -async def test_TrioToken_run_sync_soon_FIFO(): +async def test_TrioToken_run_sync_soon_FIFO() -> None: N = 100 record = [] token = _core.current_trio_token() @@ -1255,7 +1259,7 @@ async def test_TrioToken_run_sync_soon_FIFO(): assert record == list(range(N)) -def test_TrioToken_run_sync_soon_starvation_resistance(): +def test_TrioToken_run_sync_soon_starvation_resistance() -> None: # Even if we push callbacks in from callbacks, so that the callback queue # never empties out, then we still can't starve out other tasks from # running. 
@@ -1269,7 +1273,7 @@ def naughty_cb(i): except _core.RunFinishedError: record.append(("run finished", i)) - async def main(): + async def main() -> None: nonlocal token token = _core.current_trio_token() token.run_sync_soon(naughty_cb, 0) @@ -1284,10 +1288,10 @@ async def main(): assert record[1][1] >= 19 -def test_TrioToken_run_sync_soon_threaded_stress_test(): +def test_TrioToken_run_sync_soon_threaded_stress_test() -> None: cb_counter = 0 - def cb(): + def cb() -> None: nonlocal cb_counter cb_counter += 1 @@ -1299,7 +1303,7 @@ def stress_thread(token): except _core.RunFinishedError: pass - async def main(): + async def main() -> None: token = _core.current_trio_token() thread = threading.Thread(target=stress_thread, args=(token,)) thread.start() @@ -1312,7 +1316,7 @@ async def main(): print(cb_counter) -async def test_TrioToken_run_sync_soon_massive_queue(): +async def test_TrioToken_run_sync_soon_massive_queue() -> None: # There are edge cases in the wakeup fd code when the wakeup fd overflows, # so let's try to make that happen. This is also just a good stress test # in general. (With the current-as-of-2017-02-14 code using a socketpair @@ -1335,7 +1339,7 @@ def cb(i): @pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") -def test_TrioToken_run_sync_soon_late_crash(): +def test_TrioToken_run_sync_soon_late_crash() -> None: # Crash after system nursery is closed -- easiest way to do that is # from an async generator finalizer. 
record = [] @@ -1349,7 +1353,7 @@ async def agen(): token.run_sync_soon(lambda: {}["nope"]) token.run_sync_soon(lambda: record.append("2nd ran")) - async def main(): + async def main() -> None: saved.append(agen()) await saved[-1].asend(None) record.append("main exiting") @@ -1361,7 +1365,7 @@ async def main(): assert record == ["main exiting", "2nd ran"] -async def test_slow_abort_basic(): +async def test_slow_abort_basic() -> None: with _core.CancelScope() as scope: scope.cancel() with pytest.raises(_core.Cancelled): @@ -1376,7 +1380,7 @@ def slow_abort(raise_cancel): await _core.wait_task_rescheduled(slow_abort) -async def test_slow_abort_edge_cases(): +async def test_slow_abort_edge_cases() -> None: record = [] async def slow_aborter(): @@ -1419,7 +1423,7 @@ def slow_abort(raise_cancel): assert record == ["sleeping", "abort-called", "cancelled", "done"] -async def test_task_tree_introspection(): +async def test_task_tree_introspection() -> None: tasks = {} nurseries = {} @@ -1451,7 +1455,7 @@ async def parent(task_status=_core.TASK_STATUS_IGNORED): t = nursery.parent_task nursery = t.parent_nursery - async def child2(): + async def child2() -> None: tasks["child2"] = _core.current_task() assert tasks["parent"].child_nurseries == [nurseries["parent"]] assert nurseries["parent"].child_tasks == frozenset({tasks["child1"]}) @@ -1485,13 +1489,13 @@ async def child1(task_status=_core.TASK_STATUS_IGNORED): assert task.eventual_parent_nursery is None -async def test_nursery_closure(): +async def test_nursery_closure() -> None: async def child1(nursery): # We can add new tasks to the nursery even after entering __aexit__, # so long as there are still tasks running nursery.start_soon(child2) - async def child2(): + async def child2() -> None: pass async with _core.open_nursery() as nursery: @@ -1502,12 +1506,12 @@ async def child2(): nursery.start_soon(child2) -async def test_spawn_name(): +async def test_spawn_name() -> None: async def func1(expected): task = 
_core.current_task() assert expected in task.name - async def func2(): # pragma: no cover + async def func2() -> None: # pragma: no cover pass async with _core.open_nursery() as nursery: @@ -1519,7 +1523,7 @@ async def func2(): # pragma: no cover spawn_fn(func1, "object", name=object()) -async def test_current_effective_deadline(mock_clock): +async def test_current_effective_deadline(mock_clock) -> None: assert _core.current_effective_deadline() == inf with _core.CancelScope(deadline=5) as scope1: @@ -1541,12 +1545,12 @@ async def test_current_effective_deadline(mock_clock): assert _core.current_effective_deadline() == inf -def test_nice_error_on_bad_calls_to_run_or_spawn(): +def test_nice_error_on_bad_calls_to_run_or_spawn() -> None: def bad_call_run(*args): _core.run(*args) def bad_call_spawn(*args): - async def main(): + async def main() -> None: async with _core.open_nursery() as nursery: nursery.start_soon(*args) @@ -1554,7 +1558,7 @@ async def main(): for bad_call in bad_call_run, bad_call_spawn: - async def f(): # pragma: no cover + async def f() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="expecting an async function"): @@ -1569,13 +1573,13 @@ async def async_gen(arg): # pragma: no cover bad_call(async_gen, 0) -def test_calling_asyncio_function_gives_nice_error(): - async def child_xyzzy(): +def test_calling_asyncio_function_gives_nice_error() -> None: + async def child_xyzzy() -> None: import asyncio await asyncio.Future() - async def misguided(): + async def misguided() -> None: await child_xyzzy() with pytest.raises(TypeError) as excinfo: @@ -1588,7 +1592,7 @@ async def misguided(): ) -async def test_asyncio_function_inside_nursery_does_not_explode(): +async def test_asyncio_function_inside_nursery_does_not_explode() -> None: # Regression test for https://github.com/python-trio/trio/issues/552 with pytest.raises(TypeError) as excinfo: async with _core.open_nursery() as nursery: @@ -1599,7 +1603,7 @@ async def 
test_asyncio_function_inside_nursery_does_not_explode(): assert "asyncio" in str(excinfo.value) -async def test_trivial_yields(): +async def test_trivial_yields() -> None: with assert_checkpoints(): await _core.checkpoint() @@ -1623,8 +1627,8 @@ async def test_trivial_yields(): } -async def test_nursery_start(autojump_clock): - async def no_args(): # pragma: no cover +async def test_nursery_start(autojump_clock) -> None: + async def no_args() -> None: # pragma: no cover pass # Errors in calling convention get raised immediately from start @@ -1719,7 +1723,7 @@ async def raise_keyerror_after_started(task_status=_core.TASK_STATUS_IGNORED): assert _core.current_time() == t0 -async def test_task_nursery_stack(): +async def test_task_nursery_stack() -> None: task = _core.current_task() assert task._child_nurseries == [] async with _core.open_nursery() as nursery1: @@ -1732,7 +1736,7 @@ async def test_task_nursery_stack(): assert task._child_nurseries == [] -async def test_nursery_start_with_cancelled_nursery(): +async def test_nursery_start_with_cancelled_nursery() -> None: # This function isn't testing task_status, it's using task_status as a # convenient way to get a nursery that we can test spawning stuff into. 
async def setup_nursery(task_status=_core.TASK_STATUS_IGNORED): @@ -1765,7 +1769,7 @@ async def sleeping_children(fn, *, task_status=_core.TASK_STATUS_IGNORED): target_nursery.cancel_scope.cancel() -async def test_nursery_start_keeps_nursery_open(autojump_clock): +async def test_nursery_start_keeps_nursery_open(autojump_clock) -> None: async def sleep_a_bit(task_status=_core.TASK_STATUS_IGNORED): await sleep(2) task_status.started() @@ -1804,14 +1808,14 @@ async def start_sleep_then_crash(nursery): assert _core.current_time() - t0 == 7 -async def test_nursery_explicit_exception(): +async def test_nursery_explicit_exception() -> None: with pytest.raises(KeyError): async with _core.open_nursery(): raise KeyError() -async def test_nursery_stop_iteration(): - async def fail(): +async def test_nursery_stop_iteration() -> None: + async def fail() -> None: raise ValueError try: @@ -1822,9 +1826,9 @@ async def fail(): assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError) -async def test_nursery_stop_async_iteration(): +async def test_nursery_stop_async_iteration() -> None: class it: - def __init__(self, count): + def __init__(self, count) -> None: self.count = count self.val = 0 @@ -1837,7 +1841,7 @@ async def __anext__(self): return val class async_zip: - def __init__(self, *largs): + def __init__(self, *largs) -> None: self.nexts = [obj.__anext__ for obj in largs] async def _accumulate(self, f, items, i): @@ -1874,8 +1878,8 @@ def handle(exc): assert result == [[0, 0], [1, 1]] -async def test_traceback_frame_removal(): - async def my_child_task(): +async def test_traceback_frame_removal() -> None: + async def my_child_task() -> None: raise KeyError() try: @@ -1898,13 +1902,13 @@ async def my_child_task(): assert frame.f_code is my_child_task.__code__ -def test_contextvar_support(): +def test_contextvar_support() -> None: var = contextvars.ContextVar("test") var.set("before") assert var.get() == "before" - async def inner(): + async def inner() -> None: 
task = _core.current_task() assert task.context.get(var) == "before" assert var.get() == "before" @@ -1917,15 +1921,15 @@ async def inner(): assert var.get() == "before" -async def test_contextvar_multitask(): +async def test_contextvar_multitask() -> None: var = contextvars.ContextVar("test", default="hmmm") - async def t1(): + async def t1() -> None: assert var.get() == "hmmm" var.set("hmmmm") assert var.get() == "hmmmm" - async def t2(): + async def t2() -> None: assert var.get() == "hmmmm" async with _core.open_nursery() as n: @@ -1937,17 +1941,17 @@ async def t2(): await wait_all_tasks_blocked() -def test_system_task_contexts(): +def test_system_task_contexts() -> None: cvar = contextvars.ContextVar("qwilfish") cvar.set("water") - async def system_task(): + async def system_task() -> None: assert cvar.get() == "water" - async def regular_task(): + async def regular_task() -> None: assert cvar.get() == "poison" - async def inner(): + async def inner() -> None: async with _core.open_nursery() as nursery: cvar.set("poison") nursery.start_soon(regular_task) @@ -1957,25 +1961,25 @@ async def inner(): _core.run(inner) -def test_Nursery_init(): +def test_Nursery_init() -> None: with pytest.raises(TypeError): _core._run.Nursery(None, None) -async def test_Nursery_private_init(): +async def test_Nursery_private_init() -> None: # context manager creation should not raise async with _core.open_nursery() as nursery: assert False == nursery._closed -def test_Nursery_subclass(): +def test_Nursery_subclass() -> None: with pytest.raises(TypeError): class Subclass(_core._run.Nursery): pass -def test_Cancelled_init(): +def test_Cancelled_init() -> None: with pytest.raises(TypeError): raise _core.Cancelled @@ -1986,30 +1990,30 @@ def test_Cancelled_init(): _core.Cancelled._create() -def test_Cancelled_str(): +def test_Cancelled_str() -> None: cancelled = _core.Cancelled._create() assert str(cancelled) == "Cancelled" -def test_Cancelled_subclass(): +def test_Cancelled_subclass() 
-> None: with pytest.raises(TypeError): class Subclass(_core.Cancelled): pass -def test_CancelScope_subclass(): +def test_CancelScope_subclass() -> None: with pytest.raises(TypeError): class Subclass(_core.CancelScope): pass -def test_sniffio_integration(): +def test_sniffio_integration() -> None: with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() - async def check_inside_trio(): + async def check_inside_trio() -> None: assert sniffio.current_async_library() == "trio" _core.run(check_inside_trio) @@ -2018,7 +2022,7 @@ async def check_inside_trio(): sniffio.current_async_library() -async def test_Task_custom_sleep_data(): +async def test_Task_custom_sleep_data() -> None: task = _core.current_task() assert task.custom_sleep_data is None task.custom_sleep_data = 1 @@ -2028,11 +2032,11 @@ async def test_Task_custom_sleep_data(): @types.coroutine -def async_yield(value): +def async_yield(value: _T) -> Iterator[_T]: yield value -async def test_permanently_detach_coroutine_object(): +async def test_permanently_detach_coroutine_object() -> None: task = None pdco_outcome = None @@ -2068,7 +2072,7 @@ async def detachable_coroutine(task_outcome, yield_value): with pytest.raises(StopIteration): task.coro.send(None) - async def bad_detach(): + async def bad_detach() -> None: async with _core.open_nursery(): with pytest.raises(RuntimeError) as excinfo: await _core.permanently_detach_coroutine_object(outcome.Value(None)) @@ -2078,11 +2082,11 @@ async def bad_detach(): nursery.start_soon(bad_detach) -async def test_detach_and_reattach_coroutine_object(): +async def test_detach_and_reattach_coroutine_object() -> None: unrelated_task = None task = None - async def unrelated_coroutine(): + async def unrelated_coroutine() -> None: nonlocal unrelated_task unrelated_task = _core.current_task() @@ -2124,7 +2128,7 @@ def abort_fn(_): # pragma: no cover # Now it's been reattached, and we can leave the nursery -async def 
test_detached_coroutine_cancellation(): +async def test_detached_coroutine_cancellation() -> None: abort_fn_called = False task = None @@ -2154,7 +2158,7 @@ def abort_fn(_): assert abort_fn_called -def test_async_function_implemented_in_C(): +def test_async_function_implemented_in_C() -> None: # These used to crash because we'd try to mutate the coroutine object's # cr_frame, but C functions don't have Python frames. @@ -2168,7 +2172,7 @@ async def agen_fn(record): _core.run(agen.__anext__) assert run_record == ["the generator ran"] - async def main(): + async def main() -> None: start_soon_record = [] agen = agen_fn(start_soon_record) async with _core.open_nursery() as nursery: @@ -2178,7 +2182,7 @@ async def main(): _core.run(main) -async def test_very_deep_cancel_scope_nesting(): +async def test_very_deep_cancel_scope_nesting() -> None: # This used to crash with a RecursionError in CancelStatus.recalculate with ExitStack() as exit_stack: outermost_scope = _core.CancelScope() @@ -2188,7 +2192,7 @@ async def test_very_deep_cancel_scope_nesting(): outermost_scope.cancel() -async def test_cancel_scope_deadline_duplicates(): +async def test_cancel_scope_deadline_duplicates() -> None: # This exercises an assert in Deadlines._prune, by intentionally creating # duplicate entries in the deadline heap. 
now = _core.current_time() @@ -2202,11 +2206,11 @@ async def test_cancel_scope_deadline_duplicates(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage(): +async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/issues/1770 gc.collect() - async def do_a_cancel(): + async def do_a_cancel() -> None: with _core.CancelScope() as cscope: cscope.cancel() await sleep_forever() @@ -2232,7 +2236,7 @@ async def do_a_cancel(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_nursery_cancel_doesnt_create_cyclic_garbage(): +async def test_nursery_cancel_doesnt_create_cyclic_garbage() -> None: # https://github.com/python-trio/trio/issues/1770#issuecomment-730229423 gc.collect() @@ -2255,17 +2259,17 @@ async def test_nursery_cancel_doesnt_create_cyclic_garbage(): @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) -async def test_locals_destroyed_promptly_on_cancel(): +async def test_locals_destroyed_promptly_on_cancel() -> None: destroyed = False - def finalizer(): + def finalizer() -> None: nonlocal destroyed destroyed = True class A: pass - async def task(): + async def task() -> None: a = A() weakref.finalize(a, finalizer) await _core.checkpoint() diff --git a/trio/_core/tests/test_thread_cache.py b/trio/_core/tests/test_thread_cache.py index 0f6e0a0715..e2d55e49a3 100644 --- a/trio/_core/tests/test_thread_cache.py +++ b/trio/_core/tests/test_thread_cache.py @@ -1,18 +1,21 @@ import pytest +import _pytest.monkeypatch import threading from queue import Queue import time import sys +from outcome import Outcome + from .tutil import slow, gc_collect_harder from .. 
import _thread_cache from .._thread_cache import start_thread_soon, ThreadCache -def test_thread_cache_basics(): +def test_thread_cache_basics() -> None: q = Queue() - def fn(): + def fn() -> None: raise RuntimeError("hi") def deliver(outcome): @@ -25,14 +28,14 @@ def deliver(outcome): outcome.unwrap() -def test_thread_cache_deref(): +def test_thread_cache_deref() -> None: res = [False] class del_me: def __call__(self): return 42 - def __del__(self): + def __del__(self) -> None: res[0] = True q = Queue() @@ -49,7 +52,7 @@ def deliver(outcome): @slow -def test_spawning_new_thread_from_deliver_reuses_starting_thread(): +def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None: # We know that no-one else is using the thread cache, so if we keep # submitting new jobs the instant the previous one is finished, we should # keep getting the same thread over and over. This tests both that the @@ -58,7 +61,7 @@ def test_spawning_new_thread_from_deliver_reuses_starting_thread(): # Make sure there are a few threads running, so if we weren't LIFO then we # could grab the wrong one. - q = Queue() + q: "Queue[Outcome]" = Queue() COUNT = 5 for _ in range(COUNT): start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) @@ -83,14 +86,15 @@ def deliver(n, _): assert len(seen_threads) == 1 +# can switch to annotating from pytest directly as of 6.2.0 @slow -def test_idle_threads_exit(monkeypatch): +def test_idle_threads_exit(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> None: # Temporarily set the idle timeout to something tiny, to speed up the # test. (But non-zero, so that the worker loop will at least yield the # CPU.) 
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) - q = Queue() + q: "Queue[threading.Thread]" = Queue() start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) seen_thread = q.get() # Since the idle timeout is 0, after sleeping for 1 second, the thread @@ -99,7 +103,7 @@ def test_idle_threads_exit(monkeypatch): assert not seen_thread.is_alive() -def test_race_between_idle_exit_and_job_assignment(monkeypatch): +def test_race_between_idle_exit_and_job_assignment(monkeypatch) -> None: # This is a lock where the first few times you try to acquire it with a # timeout, it waits until the lock is available and then pretends to time # out. Using this in our thread cache implementation causes the following @@ -118,7 +122,7 @@ def test_race_between_idle_exit_and_job_assignment(monkeypatch): # everything proceeds as normal. class JankyLock: - def __init__(self): + def __init__(self) -> None: self._lock = threading.Lock() self._counter = 3 @@ -133,7 +137,7 @@ def acquire(self, timeout=None): return False return True - def release(self): + def release(self) -> None: self._lock.release() monkeypatch.setattr(_thread_cache, "Lock", JankyLock) diff --git a/trio/_core/tests/test_tutil.py b/trio/_core/tests/test_tutil.py index eb16de883f..07bba9407d 100644 --- a/trio/_core/tests/test_tutil.py +++ b/trio/_core/tests/test_tutil.py @@ -3,7 +3,7 @@ from .tutil import check_sequence_matches -def test_check_sequence_matches(): +def test_check_sequence_matches() -> None: check_sequence_matches([1, 2, 3], [1, 2, 3]) with pytest.raises(AssertionError): check_sequence_matches([1, 3, 2], [1, 2, 3]) diff --git a/trio/_core/tests/test_unbounded_queue.py b/trio/_core/tests/test_unbounded_queue.py index 801c34ce46..433aae310c 100644 --- a/trio/_core/tests/test_unbounded_queue.py +++ b/trio/_core/tests/test_unbounded_queue.py @@ -10,7 +10,7 @@ ) -async def test_UnboundedQueue_basic(): +async def test_UnboundedQueue_basic() -> None: q = _core.UnboundedQueue() 
q.put_nowait("hi") assert await q.get_batch() == ["hi"] @@ -35,17 +35,17 @@ async def test_UnboundedQueue_basic(): repr(q) -async def test_UnboundedQueue_blocking(): +async def test_UnboundedQueue_blocking() -> None: record = [] q = _core.UnboundedQueue() - async def get_batch_consumer(): + async def get_batch_consumer() -> None: while True: batch = await q.get_batch() assert batch record.append(batch) - async def aiter_consumer(): + async def aiter_consumer() -> None: async for batch in q: assert batch record.append(batch) @@ -67,7 +67,7 @@ async def aiter_consumer(): nursery.cancel_scope.cancel() -async def test_UnboundedQueue_fairness(): +async def test_UnboundedQueue_fairness() -> None: q = _core.UnboundedQueue() # If there's no-one else around, we can put stuff in and take it out @@ -114,7 +114,7 @@ async def reader(name): assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)])) -async def test_UnboundedQueue_trivial_yields(): +async def test_UnboundedQueue_trivial_yields() -> None: q = _core.UnboundedQueue() q.put_nowait(None) @@ -127,7 +127,7 @@ async def test_UnboundedQueue_trivial_yields(): break -async def test_UnboundedQueue_no_spurious_wakeups(): +async def test_UnboundedQueue_no_spurious_wakeups() -> None: # If we have two tasks waiting, and put two items into the queue... then # only one task wakes up record = [] diff --git a/trio/_core/tests/test_windows.py b/trio/_core/tests/test_windows.py index e6bab82204..55eb9e5613 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/tests/test_windows.py @@ -1,6 +1,9 @@ +from io import BufferedWriter import os import tempfile from contextlib import contextmanager +import sys +from typing import Iterator, Tuple import pytest @@ -22,196 +25,198 @@ ) -# The undocumented API that this is testing should be changed to stop using -# UnboundedQueue (or just removed until we have time to redo it), but until -# then we filter out the warning. 
-@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") -async def test_completion_key_listen(): - async def post(key): - iocp = ffi.cast("HANDLE", _core.current_iocp()) - for i in range(10): - print("post", i) - if i % 3 == 0: - await _core.checkpoint() - success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL) - assert success - - with _core.monitor_completion_key() as (key, queue): - async with _core.open_nursery() as nursery: - nursery.start_soon(post, key) - i = 0 - print("loop") - async for batch in queue: # pragma: no branch - print("got some", batch) - for info in batch: - assert info.lpOverlapped == 0 - assert info.dwNumberOfBytesTransferred == i - i += 1 - if i == 10: - break - print("end loop") - - -async def test_readinto_overlapped(): - data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024 - buffer = bytearray(len(data)) - - with tempfile.TemporaryDirectory() as tdir: - tfile = os.path.join(tdir, "numbers.txt") - with open(tfile, "wb") as fp: - fp.write(data) - fp.flush() - - rawname = tfile.encode("utf-16le") + b"\0\0" - rawname_buf = ffi.from_buffer(rawname) - handle = kernel32.CreateFileW( - ffi.cast("LPCWSTR", rawname_buf), - FileFlags.GENERIC_READ, - FileFlags.FILE_SHARE_READ, - ffi.NULL, # no security attributes - FileFlags.OPEN_EXISTING, - FileFlags.FILE_FLAG_OVERLAPPED, - ffi.NULL, # no template file - ) - if handle == INVALID_HANDLE_VALUE: # pragma: no cover - raise_winerror() - - try: - with memoryview(buffer) as buffer_view: - - async def read_region(start, end): - await _core.readinto_overlapped( - handle, buffer_view[start:end], start - ) +# mypy recognizes this. an assert would break the pytest skipif +if sys.platform == "win32": + # The undocumented API that this is testing should be changed to stop using + # UnboundedQueue (or just removed until we have time to redo it), but until + # then we filter out the warning. 
+ @pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") + async def test_completion_key_listen() -> None: + async def post(key): + iocp = ffi.cast("HANDLE", _core.current_iocp()) + for i in range(10): + print("post", i) + if i % 3 == 0: + await _core.checkpoint() + success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL) + assert success + + with _core.monitor_completion_key() as (key, queue): + async with _core.open_nursery() as nursery: + nursery.start_soon(post, key) + i = 0 + print("loop") + async for batch in queue: # pragma: no branch + print("got some", batch) + for info in batch: + assert info.lpOverlapped == 0 + assert info.dwNumberOfBytesTransferred == i + i += 1 + if i == 10: + break + print("end loop") + + async def test_readinto_overlapped() -> None: + data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024 + buffer = bytearray(len(data)) + + with tempfile.TemporaryDirectory() as tdir: + tfile = os.path.join(tdir, "numbers.txt") + with open(tfile, "wb") as fp: + fp.write(data) + fp.flush() + + rawname = tfile.encode("utf-16le") + b"\0\0" + rawname_buf = ffi.from_buffer(rawname) + handle = kernel32.CreateFileW( + ffi.cast("LPCWSTR", rawname_buf), + FileFlags.GENERIC_READ, + FileFlags.FILE_SHARE_READ, + ffi.NULL, # no security attributes + FileFlags.OPEN_EXISTING, + FileFlags.FILE_FLAG_OVERLAPPED, + ffi.NULL, # no template file + ) + if handle == INVALID_HANDLE_VALUE: # pragma: no cover + raise_winerror() - _core.register_with_iocp(handle) - async with _core.open_nursery() as nursery: - for start in range(0, 4096, 512): - nursery.start_soon(read_region, start, start + 512) - - assert buffer == data - - with pytest.raises(BufferError): - await _core.readinto_overlapped(handle, b"immutable") - finally: - kernel32.CloseHandle(handle) + try: + with memoryview(buffer) as buffer_view: + async def read_region(start, end): + await _core.readinto_overlapped( + handle, buffer_view[start:end], start + ) 
-@contextmanager -def pipe_with_overlapped_read(): - from asyncio.windows_utils import pipe - import msvcrt + _core.register_with_iocp(handle) + async with _core.open_nursery() as nursery: + for start in range(0, 4096, 512): + nursery.start_soon(read_region, start, start + 512) - read_handle, write_handle = pipe(overlapped=(True, False)) - try: - write_fd = msvcrt.open_osfhandle(write_handle, 0) - yield os.fdopen(write_fd, "wb", closefd=False), read_handle - finally: - kernel32.CloseHandle(ffi.cast("HANDLE", read_handle)) - kernel32.CloseHandle(ffi.cast("HANDLE", write_handle)) + assert buffer == data + with pytest.raises(BufferError): + await _core.readinto_overlapped(handle, b"immutable") + finally: + kernel32.CloseHandle(handle) -def test_forgot_to_register_with_iocp(): - with pipe_with_overlapped_read() as (write_fp, read_handle): - with write_fp: - write_fp.write(b"test\n") + @contextmanager + def pipe_with_overlapped_read() -> Iterator[Tuple[BufferedWriter, int]]: + from asyncio.windows_utils import pipe + import msvcrt - left_run_yet = False + read_handle, write_handle = pipe(overlapped=(True, False)) + try: + write_fd = msvcrt.open_osfhandle(write_handle, 0) + yield os.fdopen(write_fd, "wb", closefd=False), read_handle + finally: + kernel32.CloseHandle(ffi.cast("HANDLE", read_handle)) + kernel32.CloseHandle(ffi.cast("HANDLE", write_handle)) - async def main(): - target = bytearray(1) - try: - async with _core.open_nursery() as nursery: - nursery.start_soon( - _core.readinto_overlapped, read_handle, target, name="xyz" - ) - await wait_all_tasks_blocked() - nursery.cancel_scope.cancel() - finally: - # Run loop is exited without unwinding running tasks, so - # we don't get here until the main() coroutine is GC'ed - assert left_run_yet - - with pytest.raises(_core.TrioInternalError) as exc_info: - _core.run(main) - left_run_yet = True - assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value) - assert "forget to call register_with_iocp()?" 
in str(exc_info.value) - - # Make sure the Nursery.__del__ assertion about dangling children - # gets put with the correct test - del exc_info - gc_collect_harder() - - -@slow -async def test_too_late_to_cancel(): - import time - - with pipe_with_overlapped_read() as (write_fp, read_handle): - _core.register_with_iocp(read_handle) - target = bytearray(6) - async with _core.open_nursery() as nursery: - # Start an async read in the background - nursery.start_soon(_core.readinto_overlapped, read_handle, target) - await wait_all_tasks_blocked() - - # Synchronous write to the other end of the pipe + def test_forgot_to_register_with_iocp() -> None: + with pipe_with_overlapped_read() as (write_fp, read_handle): with write_fp: - write_fp.write(b"test1\ntest2\n") - - # Note: not trio.sleep! We're making sure the OS level - # ReadFile completes, before Trio has a chance to execute - # another checkpoint and notice it completed. - time.sleep(1) - nursery.cancel_scope.cancel() - assert target[:6] == b"test1\n" - - # Do another I/O to make sure we've actually processed the - # fallback completion that was posted when CancelIoEx failed. - assert await _core.readinto_overlapped(read_handle, target) == 6 - assert target[:6] == b"test2\n" - - -def test_lsp_that_hooks_select_gives_good_error(monkeypatch): - from .._windows_cffi import WSAIoctls, _handle - from .. 
import _io_windows - - def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): - if hasattr(sock, "fileno"): # pragma: no branch - sock = sock.fileno() - if which == WSAIoctls.SIO_BSP_HANDLE_SELECT: - return _handle(sock + 1) - else: - return _handle(sock) - - monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) - with pytest.raises( - RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ" - ): - _core.run(sleep, 0) - - -def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch): - # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns - # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns - # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to - # make sure we get an error rather than an infinite loop. - - from .._windows_cffi import WSAIoctls, _handle - from .. import _io_windows - - def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): - if hasattr(sock, "fileno"): # pragma: no branch - sock = sock.fileno() - if which == WSAIoctls.SIO_BASE_HANDLE: - raise OSError("nope") - else: - return _handle(sock) - - monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) - with pytest.raises( - RuntimeError, - match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff", - ): - _core.run(sleep, 0) + write_fp.write(b"test\n") + + left_run_yet = False + + async def main() -> None: + target = bytearray(1) + try: + async with _core.open_nursery() as nursery: + nursery.start_soon( + _core.readinto_overlapped, read_handle, target, name="xyz" + ) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + finally: + # Run loop is exited without unwinding running tasks, so + # we don't get here until the main() coroutine is GC'ed + assert left_run_yet + + with pytest.raises(_core.TrioInternalError) as exc_info: + _core.run(main) + left_run_yet = True + assert "Failed to cancel overlapped I/O 
in xyz " in str(exc_info.value) + assert "forget to call register_with_iocp()?" in str(exc_info.value) + + # Make sure the Nursery.__del__ assertion about dangling children + # gets put with the correct test + del exc_info + gc_collect_harder() + + @slow + async def test_too_late_to_cancel() -> None: + import time + + with pipe_with_overlapped_read() as (write_fp, read_handle): + _core.register_with_iocp(read_handle) + target = bytearray(6) + async with _core.open_nursery() as nursery: + # Start an async read in the background + nursery.start_soon(_core.readinto_overlapped, read_handle, target) + await wait_all_tasks_blocked() + + # Synchronous write to the other end of the pipe + with write_fp: + write_fp.write(b"test1\ntest2\n") + + # Note: not trio.sleep! We're making sure the OS level + # ReadFile completes, before Trio has a chance to execute + # another checkpoint and notice it completed. + time.sleep(1) + nursery.cancel_scope.cancel() + assert target[:6] == b"test1\n" + + # Do another I/O to make sure we've actually processed the + # fallback completion that was posted when CancelIoEx failed. + assert await _core.readinto_overlapped(read_handle, target) == 6 + assert target[:6] == b"test2\n" + + def test_lsp_that_hooks_select_gives_good_error(monkeypatch) -> None: + from .._windows_cffi import WSAIoctls, _handle + from .. 
import _io_windows + + def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + if hasattr(sock, "fileno"): # pragma: no branch + sock = sock.fileno() + if which == WSAIoctls.SIO_BSP_HANDLE_SELECT: + return _handle(sock + 1) + else: + return _handle(sock) + + monkeypatch.setattr( + _io_windows, "_get_underlying_socket", patched_get_underlying + ) + with pytest.raises( + RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ" + ): + _core.run(sleep, 0) + + def test_lsp_that_completely_hides_base_socket_gives_good_error( + monkeypatch, + ) -> None: + # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns + # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns + # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to + # make sure we get an error rather than an infinite loop. + + from .._windows_cffi import WSAIoctls, _handle + from .. import _io_windows + + def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + if hasattr(sock, "fileno"): # pragma: no branch + sock = sock.fileno() + if which == WSAIoctls.SIO_BASE_HANDLE: + raise OSError("nope") + else: + return _handle(sock) + + monkeypatch.setattr( + _io_windows, "_get_underlying_socket", patched_get_underlying + ) + with pytest.raises( + RuntimeError, + match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff", + ): + _core.run(sleep, 0) diff --git a/trio/_core/tests/tutil.py b/trio/_core/tests/tutil.py index 00669e883e..2b5cce1403 100644 --- a/trio/_core/tests/tutil.py +++ b/trio/_core/tests/tutil.py @@ -2,7 +2,7 @@ import socket as stdlib_socket import os import sys -from typing import TYPE_CHECKING +from typing import Iterator, TYPE_CHECKING import pytest import warnings @@ -50,7 +50,7 @@ binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6") -def gc_collect_harder(): +def gc_collect_harder() -> None: # In the test suite we sometimes want to call gc.collect() to make sure # that any 
objects with noisy __del__ methods (e.g. unawaited coroutines) # get collected before we continue, so their noise doesn't leak into @@ -69,7 +69,7 @@ def gc_collect_harder(): # manager should be used anywhere this happens to hide those messages, because # when expected they're clutter. @contextmanager -def ignore_coroutine_never_awaited_warnings(): +def ignore_coroutine_never_awaited_warnings() -> Iterator[None]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited") try: diff --git a/trio/_deprecate.py b/trio/_deprecate.py index 4f9f15ec35..6b9f2d90d5 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -1,11 +1,15 @@ import sys from functools import wraps from types import ModuleType +from typing import Any, Callable, Dict, Optional, TypeVar, Union import warnings import attr +_T = TypeVar("_T", bound=Callable[..., Any]) + + # We want our warnings to be visible by default (at least for now), but we # also want it to be possible to override that using the -W switch. 
AFAICT # this means we cannot inherit from DeprecationWarning, because the only way @@ -29,17 +33,24 @@ class TrioDeprecationWarning(FutureWarning): """ -def _url_for_issue(issue): +def _url_for_issue(issue: int) -> str: return "https://github.com/python-trio/trio/issues/{}".format(issue) -def _stringify(thing): +def _stringify(thing: object) -> str: if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"): - return "{}.{}".format(thing.__module__, thing.__qualname__) + return "{}.{}".format(thing.__module__, thing.__qualname__) # type: ignore[attr-defined] return str(thing) -def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): +def warn_deprecated( + thing: object, + version: str, + *, + issue: Optional[int], + instead: Optional[object], + stacklevel: int = 2, +) -> None: stacklevel += 1 msg = "{} is deprecated since Trio {}".format(_stringify(thing), version) if instead is None: @@ -53,20 +64,29 @@ def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): # @deprecated("0.2.0", issue=..., instead=...) # def ... 
-def deprecated(version, *, thing=None, issue, instead): - def do_wrap(fn): - nonlocal thing - - @wraps(fn) - def wrapper(*args, **kwargs): - warn_deprecated(thing, version, instead=instead, issue=issue) +def deprecated( + version: str, + *, + thing: Optional[str] = None, + issue: Optional[int], + instead: object, +) -> Callable[[_T], _T]: + def do_wrap(fn: _T) -> _T: + wrapper: _T + + @wraps(fn) # type: ignore[no-redef] + def wrapper(*args: object, **kwargs: object) -> object: + warn_deprecated(final_thing, version, instead=instead, issue=issue) return fn(*args, **kwargs) # If our __module__ or __qualname__ get modified, we want to pick up # on that, so we read them off the wrapper object instead of the (now # hidden) fn object + final_thing: Union[str, _T] if thing is None: - thing = wrapper + final_thing = wrapper + else: + final_thing = thing if wrapper.__doc__ is not None: doc = wrapper.__doc__ @@ -87,10 +107,12 @@ def wrapper(*args, **kwargs): return do_wrap -def deprecated_alias(old_qualname, new_fn, version, *, issue): - @deprecated(version, issue=issue, instead=new_fn) +def deprecated_alias(old_qualname: str, new_fn: _T, version: str, *, issue: int) -> _T: + wrapper: _T + + @deprecated(version, issue=issue, instead=new_fn) # type: ignore[no-redef] @wraps(new_fn, assigned=("__module__", "__annotations__")) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> object: "Deprecated alias." 
return new_fn(*args, **kwargs) @@ -103,14 +125,16 @@ def wrapper(*args, **kwargs): class DeprecatedAttribute: _not_set = object() - value = attr.ib() - version = attr.ib() - issue = attr.ib() - instead = attr.ib(default=_not_set) + value: str = attr.ib() + version: str = attr.ib() + issue: int = attr.ib() + instead: Union[object, str] = attr.ib(default=_not_set) class _ModuleWithDeprecations(ModuleType): - def __getattr__(self, name): + __deprecated_attributes__: Dict[str, DeprecatedAttribute] + + def __getattr__(self, name: str) -> object: if name in self.__deprecated_attributes__: info = self.__deprecated_attributes__[name] instead = info.instead @@ -124,10 +148,10 @@ def __getattr__(self, name): raise AttributeError(msg.format(self.__name__, name)) -def enable_attribute_deprecations(module_name): +def enable_attribute_deprecations(module_name: str) -> None: module = sys.modules[module_name] module.__class__ = _ModuleWithDeprecations # Make sure that this is always defined so that # _ModuleWithDeprecations.__getattr__ can access it without jumping # through hoops or risking infinite recursion. 
- module.__deprecated_attributes__ = {} + module.__deprecated_attributes__ = {} # type: ignore[attr-defined] diff --git a/trio/_file_io.py b/trio/_file_io.py index 8c8425c775..3c28d7b5e0 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -1,11 +1,39 @@ from functools import partial import io +import os +from typing import ( + Any, + # AnyStr, + # AsyncContextManager, + AsyncIterator, + # Awaitable, + Callable, + # ContextManager, + # FrozenSet, + # Iterator, + # Mapping, + # NoReturn, + Optional, + Sequence, + Union, + # Sequence, + TypeVar, + Tuple, + List, + Iterable, + TextIO, + BinaryIO, + IO, + overload, +) from .abc import AsyncResource from ._util import async_wraps import trio +_TSelf = TypeVar("_TSelf") + # This list is also in the docs, make sure to keep them in sync _FILE_SYNC_ATTRS = { "closed", @@ -58,23 +86,23 @@ class AsyncIOWrapper(AsyncResource): """ - def __init__(self, file): + def __init__(self, file: io.IOBase) -> None: self._wrapped = file @property - def wrapped(self): + def wrapped(self) -> io.IOBase: """object: A reference to the wrapped file object""" return self._wrapped - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: if name in _FILE_SYNC_ATTRS: return getattr(self._wrapped, name) if name in _FILE_ASYNC_METHODS: meth = getattr(self._wrapped, name) @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): + async def wrapper(*args, **kwargs): # type: ignore[misc, no-untyped-def] func = partial(meth, *args, **kwargs) return await trio.to_thread.run_sync(func) @@ -84,23 +112,23 @@ async def wrapper(*args, **kwargs): raise AttributeError(name) - def __dir__(self): + def __dir__(self) -> Sequence[str]: attrs = set(super().__dir__()) attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a)) attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a)) - return attrs + return attrs # type: ignore[return-value] - def __aiter__(self): + def 
__aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): - line = await self.readline() + async def __anext__(self) -> str: + line: str = await self.readline() # type: ignore[operator] if line: return line else: raise StopAsyncIteration - async def detach(self): + async def detach(self) -> "_AsyncIOBase": """Like :meth:`io.BufferedIOBase.detach`, but async. This also re-wraps the result in a new :term:`asynchronous file object` @@ -108,10 +136,10 @@ async def detach(self): """ - raw = await trio.to_thread.run_sync(self._wrapped.detach) + raw: Union[io.RawIOBase, BinaryIO] = await trio.to_thread.run_sync(self._wrapped.detach) # type: ignore[attr-defined] return wrap_file(raw) - async def aclose(self): + async def aclose(self) -> None: """Like :meth:`io.IOBase.close`, but async. This is also shielded from cancellation; if a cancellation scope is @@ -126,15 +154,129 @@ async def aclose(self): await trio.lowlevel.checkpoint_if_cancelled() +# _file_io +class _AsyncIOBase(trio.abc.AsyncResource): + closed: bool + + def __aiter__(self) -> AsyncIterator[bytes]: + ... + + async def __anext__(self) -> bytes: + ... + + async def aclose(self) -> None: + ... + + def fileno(self) -> int: + ... + + async def flush(self) -> None: + ... + + def isatty(self) -> bool: + ... + + def readable(self) -> bool: + ... + + async def readlines(self, hint: int = ...) -> List[bytes]: + ... + + async def seek(self, offset: int, whence: int = ...) -> int: + ... + + def seekable(self) -> bool: + ... + + async def tell(self) -> int: + ... + + async def truncate(self, size: Optional[int] = ...) -> int: + ... + + def writable(self) -> bool: + ... + + async def writelines(self, lines: Iterable[bytes]) -> None: + ... + + async def readline(self, size: int = ...) -> bytes: + ... + + +class _AsyncRawIOBase(_AsyncIOBase): + async def readall(self) -> bytes: + ... + + async def readinto(self, b: bytearray) -> Optional[int]: + ... 
+ + async def write(self, b: bytes) -> Optional[int]: + ... + + async def read(self, size: int = ...) -> Optional[bytes]: + ... + + +class _AsyncBufferedIOBase(_AsyncIOBase): + async def detach(self) -> _AsyncRawIOBase: + ... + + async def readinto(self, b: bytearray) -> int: + ... + + async def write(self, b: bytes) -> int: + ... + + async def readinto1(self, b: bytearray) -> int: + ... + + async def read(self, size: Optional[int] = ...) -> bytes: + ... + + async def read1(self, size: int = ...) -> bytes: + ... + + +class _AsyncTextIOBase(_AsyncIOBase): + encoding: str + errors: Optional[str] + newlines: Union[str, Tuple[str, ...], None] + + def __aiter__(self) -> AsyncIterator[str]: # type: ignore + ... + + async def __anext__(self) -> str: # type: ignore + ... + + async def detach(self) -> _AsyncRawIOBase: + ... + + async def write(self, s: str) -> int: + ... + + async def readline(self, size: int = ...) -> str: # type: ignore + ... + + async def read(self, size: Optional[int] = ...) -> str: + ... + + async def seek(self, offset: int, whence: int = ...) -> int: + ... + + async def tell(self) -> int: + ... + + async def open_file( - file, - mode="r", - buffering=-1, - encoding=None, - errors=None, - newline=None, - closefd=True, - opener=None, + file: Union[os.PathLike, int], + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + closefd: bool = True, + opener: Optional[Callable[[str, int], int]] = None, ): """Asynchronous version of :func:`io.open`. @@ -161,6 +303,27 @@ async def open_file( return _file +@overload +def wrap_file(obj: Union[TextIO, io.TextIOBase]) -> _AsyncTextIOBase: + ... + + +@overload +def wrap_file(obj: Union[BinaryIO, io.BufferedIOBase]) -> _AsyncBufferedIOBase: + ... + + +@overload +def wrap_file(obj: io.RawIOBase) -> _AsyncRawIOBase: + ... + + +@overload +def wrap_file(obj: Union[IO[Any], io.IOBase]) -> _AsyncIOBase: + ... 
+ + +# def wrap_file(obj: Union[IO[Any], io.IOBase, io.RawIOBase, BinaryIO, io.BufferedIOBase, TextIO, io.TextIOBase]) -> _AsyncIOBase: def wrap_file(file): """This wraps any file object in a wrapper that provides an asynchronous file object interface. @@ -179,7 +342,7 @@ def wrap_file(file): """ - def has(attr): + def has(attr: str) -> bool: return hasattr(file, attr) and callable(getattr(file, attr)) if not (has("close") and (has("read") or has("write"))): diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index c31b4fdbf3..ff8db28b66 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,12 +1,14 @@ +from typing import Optional, Union + import attr import trio -from .abc import HalfCloseableStream +from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream from trio._util import Final -async def aclose_forcefully(resource): +async def aclose_forcefully(resource: AsyncResource) -> None: """Close an async resource or async generator immediately, without blocking to do any graceful cleanup. @@ -72,18 +74,18 @@ class StapledStream(HalfCloseableStream, metaclass=Final): """ - send_stream = attr.ib() - receive_stream = attr.ib() + send_stream: SendStream = attr.ib() + receive_stream: ReceiveStream = attr.ib() - async def send_all(self, data): + async def send_all(self, data: Union[bytes, bytearray, memoryview]) -> None: """Calls ``self.send_stream.send_all``.""" return await self.send_stream.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls ``self.send_stream.wait_send_all_might_not_block``.""" return await self.send_stream.wait_send_all_might_not_block() - async def send_eof(self): + async def send_eof(self) -> None: """Shuts down the send side of the stream. If ``self.send_stream.send_eof`` exists, then calls it. 
Otherwise, @@ -91,15 +93,15 @@ async def send_eof(self): """ if hasattr(self.send_stream, "send_eof"): - return await self.send_stream.send_eof() + return await self.send_stream.send_eof() # type: ignore[no-any-return, attr-defined] else: return await self.send_stream.aclose() - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) - async def aclose(self): + async def aclose(self) -> None: """Calls ``aclose`` on both underlying streams.""" try: await self.send_stream.aclose() diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py index 80f2c7a180..402fb2a789 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/trio/_highlevel_open_tcp_listeners.py @@ -1,8 +1,12 @@ import errno import sys from math import inf +from typing import Awaitable, Callable, Optional, Union import trio +from . import Nursery +from .abc import Stream +from ._typing import TaskStatus from . import socket as tsocket @@ -144,13 +148,13 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): async def serve_tcp( - handler, - port, + handler: Callable[[Stream], Awaitable[object]], + port: int, *, - host=None, - backlog=None, - handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED, + host: Optional[Union[str, bytes]] = None, + backlog: Optional[int] = None, + handler_nursery: Optional[Nursery] = None, + task_status: TaskStatus = trio.TASK_STATUS_IGNORED, ): """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. 
diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 545fac8641..e2d7911894 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -1,7 +1,10 @@ from contextlib import contextmanager +from typing import Iterator, Optional, Sequence, Set, Union import trio -from trio.socket import getaddrinfo, SOCK_STREAM, socket +from trio.socket import getaddrinfo, SOCK_STREAM, socket, SocketType +from trio._socket import _Address, _AddressInfo + # Implementation of RFC 6555 "Happy eyeballs" # https://tools.ietf.org/html/rfc6555 @@ -103,8 +106,8 @@ @contextmanager -def close_all(): - sockets_to_close = set() +def close_all() -> Iterator[Set[SocketType]]: + sockets_to_close: Set[SocketType] = set() try: yield sockets_to_close finally: @@ -118,7 +121,7 @@ def close_all(): raise trio.MultiError(errs) -def reorder_for_rfc_6555_section_5_4(targets): +def reorder_for_rfc_6555_section_5_4(targets: _AddressInfo) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address # families (e.g. IPv4 and IPv6), then you should make sure that your first # and second attempts use different families: @@ -136,7 +139,7 @@ def reorder_for_rfc_6555_section_5_4(targets): break -def format_host_port(host, port): +def format_host_port(host: Union[bytes, str], port: int) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return "[{}]:{}".format(host, port) @@ -165,8 +168,12 @@ def format_host_port(host, port): # AF_INET6: "..."} # this might be simpler after async def open_tcp_stream( - host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None -): + host: Union[bytes, str], + port: int, + *, + happy_eyeballs_delay: float = DEFAULT_DELAY, + local_address: Optional[str] = None, +) -> trio.SocketStream: """Connect to the given host and port over TCP. 
If the given ``host`` has multiple IP addresses associated with it, then @@ -275,7 +282,9 @@ async def open_tcp_stream( # the next connection attempt to start early # code needs to ensure sockets can be closed appropriately in the # face of crash or cancellation - async def attempt_connect(socket_args, sockaddr, attempt_failed): + async def attempt_connect( + socket_args: Sequence, sockaddr: _Address, attempt_failed: trio.Event + ) -> None: nonlocal winning_socket try: diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index e5aba4695f..47be214c15 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -1,5 +1,7 @@ import os from contextlib import contextmanager +from typing import Iterator, TypeVar, Union +from typing_extensions import Protocol import trio from trio.socket import socket, SOCK_STREAM @@ -12,8 +14,16 @@ has_unix = False +class Closable(Protocol): + def close(self) -> None: + ... + + +_CL = TypeVar("_CL", bound=Closable) + + @contextmanager -def close_on_error(obj): +def close_on_error(obj: _CL) -> Iterator[_CL]: try: yield obj except: @@ -21,7 +31,7 @@ def close_on_error(obj): raise -async def open_unix_socket(filename): +async def open_unix_socket(filename: Union[bytes, str]) -> trio.SocketStream: """Opens a connection to the specified `Unix domain socket `__. 
diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index 0585fa516f..bf5bf07a97 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -1,8 +1,11 @@ import errno import logging import os +from typing import Awaitable, Callable, List, Optional import trio +import trio.abc +from ._typing import TaskStatus # Errors that accept(2) can return, and which indicate that the system is # overloaded @@ -49,7 +52,11 @@ async def _serve_one_listener(listener, handler_nursery, handler): async def serve_listeners( - handler, listeners, *, handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED + handler: Callable[[trio.abc.Stream], Awaitable[object]], + listeners: List[trio.abc.Listener], + *, + handler_nursery: Optional[trio.Nursery] = None, + task_status: TaskStatus = trio.TASK_STATUS_IGNORED, ): r"""Listen for incoming connections on ``listeners``, and for each one start a task running ``handler(stream)``. diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 0d9dbc0e92..0639f983e3 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -2,6 +2,7 @@ import errno from contextlib import contextmanager +from typing import Iterator, Optional, overload, Union import trio from . 
import socket as tsocket @@ -23,7 +24,7 @@ @contextmanager -def _translate_socket_errors_to_stream_errors(): +def _translate_socket_errors_to_stream_errors() -> Iterator[None]: try: yield except OSError as exc: @@ -59,7 +60,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: tsocket.SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -93,7 +94,7 @@ def __init__(self, socket): except OSError: pass - async def send_all(self, data): + async def send_all(self, data: bytes) -> None: if self.socket.did_shutdown_SHUT_WR: raise trio.ClosedResourceError("can't send data after sending EOF") with self._send_conflict_detector: @@ -110,14 +111,14 @@ async def send_all(self, data): sent = await self.socket.send(remaining) total_sent += sent - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self.socket.fileno() == -1: raise trio.ClosedResourceError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() - async def send_eof(self): + async def send_eof(self) -> None: with self._send_conflict_detector: await trio.lowlevel.checkpoint() # On macOS, calling shutdown a second time raises ENOTCONN, but @@ -127,7 +128,7 @@ async def send_eof(self): with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: @@ -135,13 +136,13 @@ async def receive_some(self, max_bytes=None): with _translate_socket_errors_to_stream_errors(): return await self.socket.recv(max_bytes) - async def aclose(self): + async def aclose(self) -> None: self.socket.close() await 
trio.lowlevel.checkpoint() # __aenter__, __aexit__ inherited from HalfCloseableStream are OK - def setsockopt(self, level, option, value): + def setsockopt(self, level: int, option: int, value: Union[bytes, int]) -> None: """Set an option on the underlying socket. See :meth:`socket.socket.setsockopt` for details. @@ -149,7 +150,17 @@ def setsockopt(self, level, option, value): """ return self.socket.setsockopt(level, option, value) - def getsockopt(self, level, option, buffersize=0): + @overload + def getsockopt(self, level: int, option: int) -> int: + ... + + @overload + def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: + ... + + def getsockopt( + self, level: int, option: int, buffersize: int = 0 + ) -> Union[int, bytes]: """Check the current value of an option on the underlying socket. See :meth:`socket.socket.getsockopt` for details. @@ -332,7 +343,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket): + def __init__(self, socket: tsocket.SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -348,7 +359,7 @@ def __init__(self, socket): self.socket = socket - async def accept(self): + async def accept(self) -> SocketStream: """Accept an incoming connection. 
Returns: @@ -376,7 +387,7 @@ async def accept(self): else: return SocketStream(sock) - async def aclose(self): + async def aclose(self) -> None: """Close this listener and its underlying socket.""" self.socket.close() await trio.lowlevel.checkpoint() diff --git a/trio/_path.py b/trio/_path.py index 4077c449d7..fac5e4f65c 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -1,12 +1,29 @@ -# type: ignore - from functools import wraps, partial import os import types import pathlib +from typing import ( + Any, + Callable, + Iterator, + Optional, + overload, + Sequence, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +from typing_extensions import Protocol import trio from trio._util import async_wraps, Final +from ._file_io import _AsyncIOBase + + +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) +_P = TypeVar("_P", bound="Path") # re-wrap return value from methods that return new instances of pathlib.Path @@ -16,9 +33,15 @@ def rewrap_path(value): return value -def _forward_factory(cls, attr_name, attr): - @wraps(attr) - def wrapper(self, *args, **kwargs): +class _Wrapper(Protocol): + _wrapped: object + + +def _forward_factory(cls: object, attr_name: str, attr: _Fn) -> _Fn: + wrapper: _Fn + + @wraps(attr) # type: ignore[no-redef] + def wrapper(self: _Wrapper, *args: object, **kwargs: object) -> object: attr = getattr(self._wrapped, attr_name) value = attr(*args, **kwargs) return rewrap_path(value) @@ -26,11 +49,13 @@ def wrapper(self, *args, **kwargs): return wrapper -def _forward_magic(cls, attr): +def _forward_magic(cls: Type, attr: _Fn) -> _Fn: sentinel = object() - @wraps(attr) - def wrapper(self, other=sentinel): + wrapper: _Fn + + @wraps(attr) # type: ignore[no-redef] + def wrapper(self: _Wrapper, other: object = sentinel) -> object: if other is sentinel: return attr(self._wrapped) if isinstance(other, cls): @@ -41,9 +66,9 @@ def wrapper(self, other=sentinel): return wrapper -def iter_wrapper_factory(cls, meth_name): +def iter_wrapper_factory(cls: Type, meth_name: 
str): # type: ignore[no-untyped-def] @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): # type: ignore[misc] meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) # Make sure that the full iteration is performed in the thread @@ -54,9 +79,9 @@ async def wrapper(self, *args, **kwargs): return wrapper -def thread_wrapper_factory(cls, meth_name): +def thread_wrapper_factory(cls: Type, meth_name: str): # type: ignore[no-untyped-def] @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): # type: ignore[misc] meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) @@ -65,10 +90,10 @@ async def wrapper(self, *args, **kwargs): return wrapper -def classmethod_wrapper_factory(cls, meth_name): - @classmethod +def classmethod_wrapper_factory(cls: Type, meth_name: str): # type: ignore[no-untyped-def] + @classmethod # type: ignore[misc] @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(cls, *args, **kwargs): + async def wrapper(cls, *args, **kwargs): # type: ignore[misc] meth = getattr(cls._wraps, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) @@ -153,7 +178,32 @@ class Path(metaclass=AsyncAutoWrapperType): ] _wrap_iter = ["glob", "rglob", "iterdir"] - def __init__(self, *args): + if TYPE_CHECKING: + # TODO: fill out the rest. Just copy from typeshed? Maybe this design won't pan + # out cleaner than a stub .pyi in the long run. + # https://github.com/python/typeshed/blob/58032a701811093d7bd24f9f75ad5e5de07e7723/stdlib/3/pathlib.pyi#L17-L53 + + # NOTE: These are effectively type hints compensating for Mypy not being able to + # see through AsyncAutoWrapperType. They are inline here such that the rest + # of the file can be hinted regularly rather than in a separate stub .pyi. 
+ + # TODO: Can we handle os.PathLike[str] at least for 3.9+? + def joinpath(self: _P, *other: Union[os.PathLike, str]) -> _P: + ... + + def iterdir(self: _P) -> Iterator[_P]: + ... + + def __gt__(self, other: os.PathLike) -> bool: + ... + + def __lt__(self, other: os.PathLike) -> bool: + ... + + def __truediv__(self: _P, *args: Union[os.PathLike, str]) -> _P: + ... + + def __init__(self, *args) -> None: self._wrapped = pathlib.Path(*args) def __getattr__(self, name): @@ -162,17 +212,28 @@ def __getattr__(self, name): return rewrap_path(value) raise AttributeError(name) - def __dir__(self): + def __dir__(self) -> Sequence[str]: return super().__dir__() + self._forward - def __repr__(self): + def __repr__(self) -> str: return "trio.Path({})".format(repr(str(self))) def __fspath__(self): return os.fspath(self._wrapped) + @overload # type: ignore[misc] + async def open( + self, + mode: str = ..., + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + ) -> _AsyncIOBase: + ... + @wraps(pathlib.Path.open) - async def open(self, *args, **kwargs): + async def open(self, *args: object, **kwargs: object) -> object: """Open the file pointed to by the path, like the :func:`trio.open_file` function does. @@ -201,6 +262,6 @@ async def open(self, *args, **kwargs): # The value of Path.absolute.__doc__ makes a reference to # :meth:~pathlib.Path.absolute, which does not exist. Removing this makes more # sense than inventing our own special docstring for this. -del Path.absolute.__doc__ +del Path.absolute.__doc__ # type: ignore[attr-defined] os.PathLike.register(Path) diff --git a/trio/_path.pyi b/trio/_path.pyi deleted file mode 100644 index 85a8e1f960..0000000000 --- a/trio/_path.pyi +++ /dev/null @@ -1 +0,0 @@ -class Path: ... 
diff --git a/trio/_signals.py b/trio/_signals.py index cee3b7db53..88e125edb1 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -1,10 +1,28 @@ import signal from contextlib import contextmanager from collections import OrderedDict +from types import FrameType +from typing import ( + Any, + Callable, + Iterable, + Iterator, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import trio from ._util import signal_raise, is_main_thread, ConflictDetector +# https://github.com/python/typeshed/blob/master/stdlib/3/signal.pyi#L82-L83 +_SignalNumber = Union[int, signal.Signals] +_Handler = Union[Callable[[signal.Signals, FrameType], Any], int, signal.Handlers, None] +_TSelf = TypeVar("_TSelf") + + # Discussion of signal handling strategies: # # - On Windows signals barely exist. There are no options; signal handlers are @@ -42,7 +60,9 @@ @contextmanager -def _signal_handler(signals, handler): +def _signal_handler( + signals: Iterable[_SignalNumber], handler: _Handler +) -> Iterator[None]: original_handlers = {} try: for signum in set(signals): @@ -54,23 +74,23 @@ def _signal_handler(signals, handler): class SignalReceiver: - def __init__(self): + def __init__(self) -> None: # {signal num: None} - self._pending = OrderedDict() + self._pending: "OrderedDict[_SignalNumber, None]" = OrderedDict() self._lot = trio.lowlevel.ParkingLot() self._conflict_detector = ConflictDetector( "only one task can iterate on a signal receiver at a time" ) self._closed = False - def _add(self, signum): + def _add(self, signum: _SignalNumber) -> None: if self._closed: signal_raise(signum) else: self._pending[signum] = None self._lot.unpark() - def _redeliver_remaining(self): + def _redeliver_remaining(self) -> None: # First make sure that any signals still in the delivery pipeline will # get redelivered self._closed = True @@ -78,7 +98,7 @@ def _redeliver_remaining(self): # And then redeliver any that are sitting in pending. 
This is done # using a weird recursive construct to make sure we process everything # even if some of the handlers raise exceptions. - def deliver_next(): + def deliver_next() -> None: if self._pending: signum, _ = self._pending.popitem(last=False) try: @@ -89,13 +109,13 @@ def deliver_next(): deliver_next() # Helper for tests, not public or otherwise used - def _pending_signal_count(self): + def _pending_signal_count(self) -> int: return len(self._pending) - def __aiter__(self): + def __aiter__(self: _TSelf) -> _TSelf: return self - async def __anext__(self): + async def __anext__(self) -> _SignalNumber: if self._closed: raise RuntimeError("open_signal_receiver block already exited") # In principle it would be possible to support multiple concurrent @@ -111,7 +131,7 @@ async def __anext__(self): @contextmanager -def open_signal_receiver(*signals): +def open_signal_receiver(*signals: _SignalNumber) -> Iterator[SignalReceiver]: """A context manager for catching signals. Entering this context manager starts listening for the given signals and @@ -157,7 +177,7 @@ def open_signal_receiver(*signals): token = trio.lowlevel.current_trio_token() queue = SignalReceiver() - def handler(signum, _): + def handler(signum: _SignalNumber, _: object) -> None: token.run_sync_soon(queue._add, signum, idempotent=True) try: diff --git a/trio/_socket.py b/trio/_socket.py index fcf26e072b..2fa98ac334 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,9 +1,27 @@ +from abc import ABCMeta import os import sys import select import socket as _stdlib_socket from functools import wraps as _wraps -from typing import TYPE_CHECKING +from types import TracebackType +from typing import ( + Awaitable, + Callable, + Mapping, + Union, + Optional, + Iterable, + Sequence, + Text, + Tuple, + TYPE_CHECKING, + Type, + TypeVar, + List, + Any, + overload, +) import idna as _idna @@ -11,6 +29,20 @@ from . 
import _core +_T = TypeVar("_T") + + +_Address = Union[tuple, str] +_AddressInfo = List[ + Tuple[ + _stdlib_socket.AddressFamily, + _stdlib_socket.SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] +] + # Usage: # # async with _try_sync(): @@ -20,19 +52,23 @@ # return await do_it_properly_with_a_check_point() # class _try_sync: - def __init__(self, blocking_exc_override=None): + def __init__( + self, blocking_exc_override: Callable[[BaseException], bool] = None + ) -> None: self._blocking_exc_override = blocking_exc_override - def _is_blocking_io_error(self, exc): + def _is_blocking_io_error(self, exc: BaseException) -> bool: if self._blocking_exc_override is None: return isinstance(exc, BlockingIOError) else: return self._blocking_exc_override(exc) - async def __aenter__(self): + async def __aenter__(self) -> None: await trio.lowlevel.checkpoint_if_cancelled() - async def __aexit__(self, etype, value, tb): + async def __aexit__( + self, etype: Type[BaseException], value: BaseException, tb: TracebackType + ) -> bool: if value is not None and self._is_blocking_io_error(value): # Discard the exception and fall through to the code below the # block @@ -59,11 +95,13 @@ async def __aexit__(self, etype, value, tb): # Overrides ################################################################ -_resolver = _core.RunVar("hostname_resolver") -_socket_factory = _core.RunVar("socket_factory") +_resolver = _core.RunVar[Optional["trio._abc.HostnameResolver"]]("hostname_resolver") +_socket_factory = _core.RunVar[Optional["trio._abc.SocketFactory"]]("socket_factory") -def set_custom_hostname_resolver(hostname_resolver): +def set_custom_hostname_resolver( + hostname_resolver: "trio._abc.HostnameResolver", +) -> Optional["trio._abc.HostnameResolver"]: """Set a custom hostname resolver. 
By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions @@ -95,7 +133,9 @@ def set_custom_hostname_resolver(hostname_resolver): return old -def set_custom_socket_factory(socket_factory): +def set_custom_socket_factory( + socket_factory: Optional["trio.abc.SocketFactory"], +) -> Optional["trio.abc.SocketFactory"]: """Set a custom socket object factory. This function allows you to replace Trio's normal socket class with a @@ -129,7 +169,14 @@ def set_custom_socket_factory(socket_factory): _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV -async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): +async def getaddrinfo( + host: Optional[Union[bytearray, bytes, Text]], + port: Union[str, int, None], + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> _AddressInfo: """Look up a numeric address given a name. Arguments and return values are identical to :func:`socket.getaddrinfo`, @@ -150,7 +197,7 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): # skip the whole thread thing, which seems worthwhile. So we try first # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that # fails with EAI_NONAME: - def numeric_only_failure(exc): + def numeric_only_failure(exc: BaseException) -> bool: return ( isinstance(exc, _stdlib_socket.gaierror) and exc.errno == _stdlib_socket.EAI_NONAME @@ -192,7 +239,9 @@ def numeric_only_failure(exc): ) -async def getnameinfo(sockaddr, flags): +async def getnameinfo( + sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], flags: int +) -> Tuple[str, str]: """Look up a name given a numeric address. Arguments and return values are identical to :func:`socket.getnameinfo`, @@ -211,7 +260,7 @@ async def getnameinfo(sockaddr, flags): ) -async def getprotobyname(name): +async def getprotobyname(name: str) -> int: """Look up a protocol number by name. (Rarely used.) Like :func:`socket.getprotobyname`, but async. 
@@ -230,7 +279,7 @@ async def getprotobyname(name): ################################################################ -def from_stdlib_socket(sock): +def from_stdlib_socket(sock: _stdlib_socket.socket) -> "SocketType": """Convert a standard library :func:`socket.socket` object into a Trio socket object. @@ -239,7 +288,7 @@ def from_stdlib_socket(sock): @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd(fd, family, type, proto=0): +def fromfd(fd: int, family: int, type: int, proto: int = 0) -> "SocketType": """Like :func:`socket.fromfd`, but returns a Trio socket object.""" family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -250,27 +299,61 @@ def fromfd(fd, family, type, proto=0): ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(*args, **kwargs): - return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) + def fromshare(data: bytes) -> "SocketType": + # Not using *args, **kwargs to make mypy happy. + # trio/_socket.py:277: error: Argument 1 to "fromshare" has incompatible type "*Tuple[object, ...]"; expected "bytes" [arg-type] + # trio/_socket.py:277: error: Argument 2 to "fromshare" has incompatible type "**Dict[str, object]"; expected "bytes" [arg-type] + # So, we will just have to keep this in sync with the stdlib function in such + # case as it ever changes in the future. + # https://docs.python.org/3.9/library/socket.html#socket.fromshare + return from_stdlib_socket(_stdlib_socket.fromshare(data)) + + +# @overload +# def socketpair() -> Tuple["SocketType", "SocketType"]: +# ... +# +# +# @overload +# def socketpair(family: int = ...) -> Tuple["SocketType", "SocketType"]: +# ... +# +# +# @overload +# def socketpair(family: int = ..., type: int = ...) -> Tuple["SocketType", "SocketType"]: +# ... + +# @overload # type: ignore[misc] +# def socketpair( +# family: int = ..., type: int = ..., proto: int = ... 
+# ) -> Tuple["SocketType", "SocketType"]: +# ... +# TODO: uh... stuff... comments... @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair(*args, **kwargs): +# def socketpair( +# family: int = _stdlib_socket.AF_INET, +## type: int = _stdlib_socket.SOCK_STREAM, +# proto: int = 0, +# ) -> Tuple["SocketType", "SocketType"]: +def socketpair(*args: object, **kwargs: object) -> Tuple["SocketType", "SocketType"]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. """ - left, right = _stdlib_socket.socketpair(*args, **kwargs) + # left, right = _stdlib_socket.socketpair(family=family, type=type, proto=proto) + left, right = _stdlib_socket.socketpair(*args, **kwargs) # type: ignore[arg-type] return (from_stdlib_socket(left), from_stdlib_socket(right)) @_wraps(_stdlib_socket.socket, assigned=(), updated=()) def socket( - family=_stdlib_socket.AF_INET, - type=_stdlib_socket.SOCK_STREAM, - proto=0, - fileno=None, -): + family: int = _stdlib_socket.AF_INET, + type: int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, + fileno: Optional[int] = None, +) -> "SocketType": """Create a new Trio socket, like :func:`socket.socket`. This function's behavior can be customized using @@ -287,7 +370,9 @@ def socket( return from_stdlib_socket(stdlib_socket) -def _sniff_sockopts_for_fileno(family, type, proto, fileno): +def _sniff_sockopts_for_fileno( + family: int, type: int, proto: int, fileno: int +) -> Tuple[int, int, int]: """Correct SOCKOPTS for given fileno, falling back to provided values.""" # Wrap the raw fileno into a Python socket object # This object might have the wrong metadata, but it lets us easily call getsockopt @@ -318,7 +403,7 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): # But on other platforms (e.g. Windows) SOCK_NONBLOCK and SOCK_CLOEXEC aren't # even defined. To recover the actual socket type (e.g. 
SOCK_STREAM) from a # socket.type attribute, mask with this: -_SOCK_TYPE_MASK = ~( +_SOCK_TYPE_MASK: int = ~( getattr(_stdlib_socket, "SOCK_NONBLOCK", 0) | getattr(_stdlib_socket, "SOCK_CLOEXEC", 0) ) @@ -327,7 +412,7 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): # This function will modify the given socket to match the behavior in python # 3.7. This will become unecessary and can be removed when support for versions # older than 3.7 is dropped. -def real_socket_type(type_num): +def real_socket_type(type_num: int) -> int: return type_num & _SOCK_TYPE_MASK @@ -335,7 +420,7 @@ def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False): fn = getattr(_stdlib_socket.socket, methname) @_wraps(fn, assigned=("__name__",), updated=()) - async def wrapper(self, *args, **kwargs): + async def wrapper(self, *args, **kwargs): # type: ignore[misc] return await self._nonblocking_helper(fn, args, kwargs, wait_fn) wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async. @@ -350,15 +435,164 @@ async def wrapper(self, *args, **kwargs): class SocketType: - def __init__(self): + def __init__(self) -> None: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" ) + if TYPE_CHECKING: + + @property + def family(self) -> int: + ... + + @property + def type(self) -> int: + ... + + @property + def proto(self) -> int: + ... + + @property + def did_shutdown_SHUT_WR(self) -> bool: + ... + + def __enter__(self: _T) -> _T: + ... + + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + ... + + def dup(self) -> "SocketType": + ... + + def close(self) -> None: + ... + + async def bind(self, address: Union[_Address, bytes]) -> None: + ... + + def shutdown(self, flag: int) -> None: + ... + + def is_readable(self) -> bool: + ... + + async def wait_writable(self) -> None: + ... 
+ + async def accept(self) -> Tuple["SocketType", Any]: + ... + + async def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: + ... + + async def recv(self, bufsize: int, flags: int = ...) -> bytes: + ... + + async def recv_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> int: + ... + + async def recvfrom(self, bufsize: int, flags: int = ...) -> Tuple[bytes, Any]: + ... + + async def recvfrom_into( + self, buffer: Union[bytearray, memoryview], nbytes: int, flags: int = ... + ) -> Tuple[int, Any]: + ... + + async def recvmsg( + self, bufsize: int, ancbufsize: int = ..., flags: int = ... + ) -> Tuple[bytes, List[Tuple[int, int, bytes]], int, Any]: + ... + + async def recvmsg_into( + self, + buffers: Iterable[Union[bytearray, memoryview]], + ancbufsize: int = ..., + flags: int = ..., + ) -> Tuple[int, List[Tuple[int, int, bytes]], int, Any]: + ... + + async def send(self, data: bytes, flags: int = ...) -> int: + ... + + async def sendmsg( + self, + buffers: Iterable[Union[bytes, memoryview]], + ancdata: Iterable[Tuple[int, int, Union[bytes, memoryview]]] = ..., + flags: int = ..., + address: Union[Tuple[Any, ...], str] = ..., + ) -> int: + ... + + @overload + async def sendto( + self, data: bytes, address: Union[Tuple[Any, ...], str] + ) -> int: + ... + + @overload + async def sendto( + self, data: bytes, flags: int, address: Union[Tuple[Any, ...], str] + ) -> int: + ... + + async def sendto(self, *args: object, **kwargs: object) -> int: + ... + + def detach(self) -> int: + ... + + def get_inheritable(self) -> bool: + ... + + def set_inheritable(self, inheritable: bool) -> None: + ... + + def fileno(self) -> int: + ... + + def getpeername(self) -> Any: + ... + + def getsockname(self) -> Any: + ... + + @overload + def getsockopt(self, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: + ... 
+ + def getsockopt(self, *args: object, **kwargs: object) -> object: + ... + + def setsockopt( + self, level: int, optname: int, value: Union[int, bytes] + ) -> None: + ... + + def listen(self, backlog: int) -> None: + ... + + def share(self, process_id: int) -> bytes: + ... + class _SocketType(SocketType): - def __init__(self, sock): + def __init__(self, sock: _stdlib_socket.socket) -> None: if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. @@ -398,52 +632,57 @@ def __init__(self, sock): "share", } - def __getattr__(self, name): + def __getattr__(self, name): # type: ignore if name in self._forward: return getattr(self._sock, name) raise AttributeError(name) - def __dir__(self): - return super().__dir__() + list(self._forward) + def __dir__(self) -> Sequence[str]: + return [*super().__dir__(), *self._forward] - def __enter__(self): + def __enter__(self: _T) -> _T: return self - def __exit__(self, *exc_info): - return self._sock.__exit__(*exc_info) + def __exit__( + self, + etype: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + return self._sock.__exit__(etype, exc, tb) # type: ignore[no-any-return,func-returns-value] @property - def family(self): + def family(self) -> int: return self._sock.family @property - def type(self): + def type(self) -> int: # Modify the socket type do match what is done on python 3.7. 
When # support for versions older than 3.7 is dropped, this can be updated # to just return self._sock.type return real_socket_type(self._sock.type) @property - def proto(self): + def proto(self) -> int: return self._sock.proto @property - def did_shutdown_SHUT_WR(self): + def did_shutdown_SHUT_WR(self) -> bool: return self._did_shutdown_SHUT_WR - def __repr__(self): + def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self): + def dup(self) -> "_SocketType": """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) - def close(self): + def close(self) -> None: if self._sock.fileno() != -1: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address): + async def bind(self, address: Union[_Address, bytes]) -> None: address = await self._resolve_local_address_nocp(address) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -461,14 +700,14 @@ async def bind(self, address): await trio.lowlevel.checkpoint() return self._sock.bind(address) - def shutdown(self, flag): + def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: self._sock.shutdown(flag) # only do this if the call succeeded: if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: self._did_shutdown_SHUT_WR = True - def is_readable(self): + def is_readable(self) -> bool: # use select.select on Windows, and select.poll everywhere else if sys.platform == "win32": rready, _, _ = select.select([self._sock], [], [], 0) @@ -477,7 +716,7 @@ def is_readable(self): p.register(self._sock, select.POLLIN) return bool(p.poll(0)) - async def wait_writable(self): + async def wait_writable(self) -> None: await _core.wait_writable(self._sock) ################################################################ @@ -489,7 +728,9 @@ async def wait_writable(self): # etc. 
# # NOTE: this function does not always checkpoint - async def _resolve_address_nocp(self, address, flags): + async def _resolve_address_nocp( + self, address: Union[_Address, bytes], flags: int + ) -> Union[_Address, bytes]: # Do some pre-checking (or exit early for non-IP sockets) if self._sock.family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -501,7 +742,7 @@ async def _resolve_address_nocp(self, address, flags): ) elif self._sock.family == _stdlib_socket.AF_UNIX: # unwrap path-likes - return os.fspath(address) + return os.fspath(address) # type: ignore[arg-type] else: return address @@ -539,6 +780,7 @@ async def _resolve_address_nocp(self, address, flags): # empty list. assert len(gai_res) >= 1 # Address is the last item in the first entry + normed: Union[List, Tuple] (*_, normed), *_ = gai_res # The above ignored any flowid and scopeid in the passed-in address, # so restore them if present: @@ -555,16 +797,26 @@ async def _resolve_address_nocp(self, address, flags): # Returns something appropriate to pass to bind() # # NOTE: this function does not always checkpoint - async def _resolve_local_address_nocp(self, address): + async def _resolve_local_address_nocp( + self, address: Union[_Address, bytes] + ) -> Union[_Address, bytes]: return await self._resolve_address_nocp(address, _stdlib_socket.AI_PASSIVE) # Returns something appropriate to pass to connect()/sendto()/sendmsg() # # NOTE: this function does not always checkpoint - async def _resolve_remote_address_nocp(self, address): + async def _resolve_remote_address_nocp( + self, address: Union[_Address, bytes] + ) -> Union[_Address, bytes]: return await self._resolve_address_nocp(address, 0) - async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): + async def _nonblocking_helper( + self, + fn: Callable[..., _T], + args: Sequence[object], + kwargs: Mapping[str, object], + wait_fn: Callable[..., Awaitable[object]], + ) -> _T: # We have to reconcile two 
conflicting goals: # - We want to make it look like we always blocked in doing these # operations. The obvious way is to always do an IO wait before @@ -602,7 +854,7 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable) - async def accept(self): + async def accept(self) -> Tuple["SocketType", _Address]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -611,7 +863,7 @@ async def accept(self): # connect ################################################################ - async def connect(self, address): + async def connect(self, address: Union[_Address, bytes]) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -735,15 +987,24 @@ async def connect(self, address): # sendto ################################################################ - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) - async def sendto(self, *args): + @overload + async def sendto(self, data: bytes, address: _Address) -> int: + ... + + @overload + async def sendto(self, data: bytes, flags: int, address: _Address) -> int: + ... 
+ + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] + async def sendto(self, *args: object) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address) # and kwargs are not accepted - args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + list_args = list(args) + address: _Address = list_args[-1] # type: ignore[assignment] + list_args[-1] = await self._resolve_remote_address_nocp(address) return await self._nonblocking_helper( - _stdlib_socket.socket.sendto, args, {}, _core.wait_writable + _stdlib_socket.socket.sendto, list_args, {}, _core.wait_writable ) ################################################################ @@ -755,7 +1016,7 @@ async def sendto(self, *args): ): @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg(self, *args): + async def sendmsg(self, *args: object) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. Only available on platforms where :meth:`socket.socket.sendmsg` is @@ -765,8 +1026,8 @@ async def sendmsg(self, *args): # args is: buffers[, ancdata[, flags[, address]]] # and kwargs are not accepted if len(args) == 4 and args[-1] is not None: - args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + address: _Address = args[-1] # type: ignore[assignment] + args = (*args[:-1], await self._resolve_remote_address_nocp(address)) return await self._nonblocking_helper( _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable ) diff --git a/trio/_ssl.py b/trio/_ssl.py index c4ffa3ddbe..dcc9793be1 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -152,6 +152,7 @@ import operator as _operator import ssl as _stdlib_ssl from enum import Enum as _Enum +from typing import Sequence import trio @@ -199,11 +200,11 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn, *args): + def __init__(self, afn, *args) -> None: self._afn = afn self._args = args - 
self.started = False - self._done = _sync.Event() + self.started: bool = False + self._done: _sync.Event = _sync.Event() async def ensure(self, *, checkpoint): if not self.started: @@ -216,7 +217,7 @@ async def ensure(self, *, checkpoint): await self._done.wait() @property - def done(self): + def done(self) -> bool: return self._done.is_set() @@ -406,10 +407,10 @@ def __setattr__(self, name, value): else: super().__setattr__(name, value) - def __dir__(self): + def __dir__(self) -> Sequence[str]: return super().__dir__() + list(self._forwarded) - def _check_status(self): + def _check_status(self) -> None: if self._state is _State.OK: return elif self._state is _State.BROKEN: @@ -595,14 +596,14 @@ async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): await trio.lowlevel.cancel_shielded_checkpoint() return ret - async def _do_handshake(self): + async def _do_handshake(self) -> None: try: await self._retry(self._ssl_object.do_handshake, is_handshake=True) except: self._state = _State.BROKEN raise - async def do_handshake(self): + async def do_handshake(self) -> None: """Ensure that the initial handshake has completed. The SSL protocol requires an initial handshake to exchange @@ -691,7 +692,7 @@ async def receive_some(self, max_bytes=None): else: raise - async def send_all(self, data): + async def send_all(self, data) -> None: """Encrypt some data and then send it on the underlying transport. See :meth:`trio.abc.SendStream.send_all` for details. @@ -738,7 +739,7 @@ async def unwrap(self): self._state = _State.CLOSED return (transport_stream, self._incoming.read()) - async def aclose(self): + async def aclose(self) -> None: """Gracefully shut down this connection, and close the underlying transport. 
@@ -825,7 +826,7 @@ async def aclose(self): finally: self._state = _State.CLOSED - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.""" # This method's implementation is deceptively simple. # @@ -913,6 +914,6 @@ async def accept(self): https_compatible=self._https_compatible, ) - async def aclose(self): + async def aclose(self) -> None: """Close the transport listener.""" await self.transport_listener.aclose() diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 876cc0d7c9..d9d1a67ebb 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -3,10 +3,21 @@ import os import subprocess import sys -from typing import Optional from functools import partial import warnings -from typing import TYPE_CHECKING +from typing import ( + Any, + Awaitable, + Callable, + Mapping, + Optional, + overload, + Union, + Sequence, + TYPE_CHECKING, +) + +from typing_extensions import Literal, Protocol from ._abc import AsyncResource, SendStream, ReceiveStream from ._highlevel_generic import StapledStream @@ -16,6 +27,7 @@ create_pipe_to_child_stdin, create_pipe_from_child_output, ) +from ._typing import _HasFileno from ._util import NoPublicConstructor import trio @@ -117,11 +129,17 @@ class Process(AsyncResource, metaclass=NoPublicConstructor): # arbitrarily many threads if wait() keeps getting cancelled. 
_wait_for_exit_data = None - def __init__(self, popen, stdin, stdout, stderr): + def __init__( + self, + popen: subprocess.Popen, + stdin: Optional[SendStream], + stdout: Optional[ReceiveStream], + stderr: Optional[ReceiveStream], + ): self._proc = popen - self.stdin = stdin # type: Optional[SendStream] - self.stdout = stdout # type: Optional[ReceiveStream] - self.stderr = stderr # type: Optional[ReceiveStream] + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr self.stdio = None # type: Optional[StapledStream] if self.stdin is not None and self.stdout is not None: @@ -147,7 +165,7 @@ def __init__(self, popen, stdin, stdout, stderr): self.args = self._proc.args self.pid = self._proc.pid - def __repr__(self): + def __repr__(self) -> str: returncode = self.returncode if returncode is None: status = "running with PID {}".format(self.pid) @@ -159,7 +177,7 @@ def __repr__(self): return "".format(self.args, status) @property - def returncode(self): + def returncode(self) -> Optional[int]: """The exit status of the process (an integer), or ``None`` if it's still running. @@ -180,7 +198,7 @@ def returncode(self): self._close_pidfd() return result - async def aclose(self): + async def aclose(self) -> None: """Close any pipes we have to the process (both input and output) and wait for it to exit. @@ -202,7 +220,7 @@ async def aclose(self): with trio.CancelScope(shield=True): await self.wait() - def _close_pidfd(self): + def _close_pidfd(self) -> None: if self._pidfd is not None: self._pidfd.close() self._pidfd = None @@ -252,7 +270,7 @@ def send_signal(self, sig): """ self._proc.send_signal(sig) - def terminate(self): + def terminate(self) -> None: """Terminate the process, politely if possible. On UNIX, this is equivalent to @@ -263,7 +281,7 @@ def terminate(self): """ self._proc.terminate() - def kill(self): + def kill(self) -> None: """Immediately terminate the process. 
On UNIX, this is equivalent to @@ -276,119 +294,527 @@ def kill(self): self._proc.kill() -async def open_process( - command, *, stdin=None, stdout=None, stderr=None, **options -) -> Process: - r"""Execute a child program in a new process. - - After construction, you can interact with the child process by writing - data to its `~Process.stdin` stream (a `~trio.abc.SendStream`), reading - data from its `~Process.stdout` and/or `~Process.stderr` streams (both - `~trio.abc.ReceiveStream`\s), sending it signals using - `~Process.terminate`, `~Process.kill`, or `~Process.send_signal`, and - waiting for it to exit using `~Process.wait`. See `Process` for details. - - Each standard stream is only available if you specify that a pipe should - be created for it. For example, if you pass ``stdin=subprocess.PIPE``, you - can write to the `~Process.stdin` stream, else `~Process.stdin` will be - ``None``. - - Args: - command (list or str): The command to run. Typically this is a - sequence of strings such as ``['ls', '-l', 'directory with spaces']``, - where the first element names the executable to invoke and the other - elements specify its arguments. With ``shell=True`` in the - ``**options``, or on Windows, ``command`` may alternatively - be a string, which will be parsed following platform-dependent - :ref:`quoting rules `. - stdin: Specifies what the child process's standard input - stream should connect to: output written by the parent - (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), - or an open file (pass a file descriptor or something whose - ``fileno`` method returns one). If ``stdin`` is unspecified, - the child process will have the same standard input stream - as its parent. - stdout: Like ``stdin``, but for the child process's standard output - stream. - stderr: Like ``stdin``, but for the child process's standard error - stream. 
An additional value ``subprocess.STDOUT`` is supported, - which causes the child's standard output and standard error - messages to be intermixed on a single standard output stream, - attached to whatever the ``stdout`` option says to attach it to. - **options: Other :ref:`general subprocess options ` - are also accepted. - - Returns: - A new `Process` object. - - Raises: - OSError: if the process spawning fails, for example because the - specified command could not be found. +_Redirect = Union[int, _HasFileno, None] - """ - for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"): - if options.get(key): - raise TypeError( - "trio.Process only supports communicating over " - "unbuffered byte streams; the '{}' option is not supported".format(key) - ) +# There's a lot of duplication here because mypy doesn't +# have a good way to represent overloads that differ only +# slightly. A cheat sheet: +# - on Windows, command is Union[str, Sequence[str]]; +# on Unix, command is str if shell=True and Sequence[str] otherwise +# - on Windows, there are startupinfo and creationflags options; +# on Unix, there are preexec_fn, restore_signals, start_new_session, and pass_fds +# - run_process() has the signature of open_process() plus arguments +# capture_stdout, capture_stderr, check, deliver_cancel, and the ability to pass +# bytes as stdin - if os.name == "posix": - if isinstance(command, str) and not options.get("shell"): - raise TypeError( - "command must be a sequence (not a string) if shell=False " - "on UNIX systems" - ) - if not isinstance(command, str) and options.get("shell"): - raise TypeError( - "command must be a string (not a sequence) if shell=True " - "on UNIX systems" +if TYPE_CHECKING: + if sys.platform == "win32": + + async def open_process( + command: Union[str, Sequence[str]], + *, + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, 
str] = ..., + startupinfo: subprocess.STARTUPINFO = ..., + creationflags: int = ..., + ) -> Process: + ... + + async def run_process( + command: Union[str, Sequence[str]], + *, + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + startupinfo: subprocess.STARTUPINFO = ..., + creationflags: int = ..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + else: + + @overload + async def open_process( + command: str, + *, + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Literal[True], + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> Process: + ... + + @overload + async def open_process( + command: Sequence[str], + *, + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> Process: + ... + + async def open_process( + command: Union[str, Sequence[str]], + *, + stdin: _Redirect = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Union[Literal[True], bool] = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> Process: + ... 
+ + @overload + async def run_process( + command: str, + *, + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Literal[True], + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + @overload + async def run_process( + command: Sequence[str], + *, + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: bool = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + async def run_process( + command: Union[str, Sequence[str]], + *, + stdin: Union[bytes, _Redirect] = ..., + capture_stdout: bool = ..., + capture_stderr: bool = ..., + check: bool = ..., + deliver_cancel: Callable[[Process], Awaitable[None]] = ..., + stdout: _Redirect = ..., + stderr: _Redirect = ..., + close_fds: bool = ..., + shell: Union[Literal[True], bool] = ..., + cwd: str = ..., + env: Mapping[str, str] = ..., + preexec_fn: Optional[Callable[[], Any]] = ..., + restore_signals: bool = ..., + start_new_session: bool = ..., + pass_fds: Sequence[int] = ..., + ) -> subprocess.CompletedProcess[bytes]: + ... + + +else: + + async def open_process(command, *, stdin=None, stdout=None, stderr=None, **options): + r"""Execute a child program in a new process. 
+ + After construction, you can interact with the child process by writing + data to its `~Process.stdin` stream (a `~trio.abc.SendStream`), reading + data from its `~Process.stdout` and/or `~Process.stderr` streams (both + `~trio.abc.ReceiveStream`\s), sending it signals using + `~Process.terminate`, `~Process.kill`, or `~Process.send_signal`, and + waiting for it to exit using `~Process.wait`. See `Process` for details. + + Each standard stream is only available if you specify that a pipe should + be created for it. For example, if you pass ``stdin=subprocess.PIPE``, you + can write to the `~Process.stdin` stream, else `~Process.stdin` will be + ``None``. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules `. + stdin: Specifies what the child process's standard input + stream should connect to: output written by the parent + (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), + or an open file (pass a file descriptor or something whose + ``fileno`` method returns one). If ``stdin`` is unspecified, + the child process will have the same standard input stream + as its parent. + stdout: Like ``stdin``, but for the child process's standard output + stream. + stderr: Like ``stdin``, but for the child process's standard error + stream. An additional value ``subprocess.STDOUT`` is supported, + which causes the child's standard output and standard error + messages to be intermixed on a single standard output stream, + attached to whatever the ``stdout`` option says to attach it to. + **options: Other :ref:`general subprocess options ` + are also accepted. + + Returns: + A new `Process` object. 
+ + Raises: + OSError: if the process spawning fails, for example because the + specified command could not be found. + + """ + for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"): + if options.get(key): + raise TypeError( + "trio.Process only supports communicating over " + "unbuffered byte streams; the '{}' option is not supported".format( + key + ) + ) + + if os.name == "posix": + if isinstance(command, str) and not options.get("shell"): + raise TypeError( + "command must be a sequence (not a string) if shell=False " + "on UNIX systems" + ) + if not isinstance(command, str) and options.get("shell"): + raise TypeError( + "command must be a string (not a sequence) if shell=True " + "on UNIX systems" + ) + + trio_stdin = None # type: Optional[SendStream] + trio_stdout = None # type: Optional[ReceiveStream] + trio_stderr = None # type: Optional[ReceiveStream] + + if stdin == subprocess.PIPE: + trio_stdin, stdin = create_pipe_to_child_stdin() + if stdout == subprocess.PIPE: + trio_stdout, stdout = create_pipe_from_child_output() + if stderr == subprocess.STDOUT: + # If we created a pipe for stdout, pass the same pipe for + # stderr. If stdout was some non-pipe thing (DEVNULL or a + # given FD), pass the same thing. If stdout was passed as + # None, keep stderr as STDOUT to allow subprocess to dup + # our stdout. Regardless of which of these is applicable, + # don't create a new Trio stream for stderr -- if stdout + # is piped, stderr will be intermixed on the stdout stream. + if stdout is not None: + stderr = stdout + elif stderr == subprocess.PIPE: + trio_stderr, stderr = create_pipe_from_child_output() + + try: + popen = await trio.to_thread.run_sync( + partial( + subprocess.Popen, + command, + stdin=stdin, + stdout=stdout, + stderr=stderr, + **options, + ) ) + finally: + # Close the parent's handle for each child side of a pipe; + # we want the child to have the only copy, so that when + # it exits we can read EOF on our side. 
+ if trio_stdin is not None: + os.close(stdin) + if trio_stdout is not None: + os.close(stdout) + if trio_stderr is not None: + os.close(stderr) + + return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) + + async def run_process( + command, + *, + stdin=b"", + capture_stdout=False, + capture_stderr=False, + check=True, + deliver_cancel=None, + **options, + ): + """Run ``command`` in a subprocess, wait for it to complete, and + return a :class:`subprocess.CompletedProcess` instance describing + the results. + + If cancelled, :func:`run_process` terminates the subprocess and + waits for it to exit before propagating the cancellation, like + :meth:`Process.aclose`. + + **Input:** The subprocess's standard input stream is set up to + receive the bytes provided as ``stdin``. Once the given input has + been fully delivered, or if none is provided, the subprocess will + receive end-of-file when reading from its standard input. + Alternatively, if you want the subprocess to read its + standard input from the same place as the parent Trio process, you + can pass ``stdin=None``. + + **Output:** By default, any output produced by the subprocess is + passed through to the standard output and error streams of the + parent Trio process. If you would like to capture this output and + do something with it, you can pass ``capture_stdout=True`` to + capture the subprocess's standard output, and/or + ``capture_stderr=True`` to capture its standard error. Captured + data is provided as the + :attr:`~subprocess.CompletedProcess.stdout` and/or + :attr:`~subprocess.CompletedProcess.stderr` attributes of the + returned :class:`~subprocess.CompletedProcess` object. The value + for any stream that was not captured will be ``None``. + + If you want to capture both stdout and stderr while keeping them + separate, pass ``capture_stdout=True, capture_stderr=True``. 
+ + If you want to capture both stdout and stderr but mixed together + in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. + This directs the child's stderr into its stdout, so the combined + output will be available in the `~subprocess.CompletedProcess.stdout` + attribute. + + **Error checking:** If the subprocess exits with a nonzero status + code, indicating failure, :func:`run_process` raises a + :exc:`subprocess.CalledProcessError` exception rather than + returning normally. The captured outputs are still available as + the :attr:`~subprocess.CalledProcessError.stdout` and + :attr:`~subprocess.CalledProcessError.stderr` attributes of that + exception. To disable this behavior, so that :func:`run_process` + returns normally even if the subprocess exits abnormally, pass + ``check=False``. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules `. + + stdin (:obj:`bytes`, file descriptor, or None): The bytes to provide to + the subprocess on its standard input stream, or ``None`` if the + subprocess's standard input should come from the same place as + the parent Trio process's standard input. As is the case with + the :mod:`subprocess` module, you can also pass a + file descriptor or an object with a ``fileno()`` method, + in which case the subprocess's standard input will come from + that file. + + capture_stdout (bool): If true, capture the bytes that the subprocess + writes to its standard output stream and return them in the + :attr:`~subprocess.CompletedProcess.stdout` attribute + of the returned :class:`~subprocess.CompletedProcess` object. 
+ + capture_stderr (bool): If true, capture the bytes that the subprocess + writes to its standard error stream and return them in the + :attr:`~subprocess.CompletedProcess.stderr` attribute + of the returned :class:`~subprocess.CompletedProcess` object. + + check (bool): If false, don't validate that the subprocess exits + successfully. You should be sure to check the + ``returncode`` attribute of the returned object if you pass + ``check=False``, so that errors don't pass silently. + + deliver_cancel (async function or None): If `run_process` is cancelled, + then it needs to kill the child process. There are multiple ways to + do this, so we let you customize it. + + If you pass None (the default), then the behavior depends on the + platform: + + - On Windows, Trio calls ``TerminateProcess``, which should kill the + process immediately. + + - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait + 5 seconds, and send a ``SIGKILL``. + + Alternatively, you can customize this behavior by passing in an + arbitrary async function, which will be called with the `Process` + object as an argument. For example, the default Unix behavior could + be implemented like this:: + + async def my_deliver_cancel(process): + process.send_signal(signal.SIGTERM) + await trio.sleep(5) + process.send_signal(signal.SIGKILL) + + When the process actually exits, the ``deliver_cancel`` function + will automatically be cancelled – so if the process exits after + ``SIGTERM``, then we'll never reach the ``SIGKILL``. + + In any case, `run_process` will always wait for the child process to + exit before raising `Cancelled`. + + **options: :func:`run_process` also accepts any :ref:`general subprocess + options ` and passes them on to the + :class:`~trio.Process` constructor. This includes the + ``stdout`` and ``stderr`` options, which provide additional + redirection possibilities such as ``stderr=subprocess.STDOUT``, + ``stdout=subprocess.DEVNULL``, or file descriptors. 
+ + Returns: + A :class:`subprocess.CompletedProcess` instance describing the + return code and outputs. + + Raises: + UnicodeError: if ``stdin`` is specified as a Unicode string, rather + than bytes + ValueError: if multiple redirections are specified for the same + stream, e.g., both ``capture_stdout=True`` and + ``stdout=subprocess.DEVNULL`` + subprocess.CalledProcessError: if ``check=False`` is not passed + and the process exits with a nonzero exit status + OSError: if an error is encountered starting or communicating with + the process + + .. note:: The child process runs in the same process group as the parent + Trio process, so a Ctrl+C will be delivered simultaneously to both + parent and child. If you don't want this behavior, consult your + platform's documentation for starting child processes in a different + process group. - trio_stdin = None # type: Optional[SendStream] - trio_stdout = None # type: Optional[ReceiveStream] - trio_stderr = None # type: Optional[ReceiveStream] - - if stdin == subprocess.PIPE: - trio_stdin, stdin = create_pipe_to_child_stdin() - if stdout == subprocess.PIPE: - trio_stdout, stdout = create_pipe_from_child_output() - if stderr == subprocess.STDOUT: - # If we created a pipe for stdout, pass the same pipe for - # stderr. If stdout was some non-pipe thing (DEVNULL or a - # given FD), pass the same thing. If stdout was passed as - # None, keep stderr as STDOUT to allow subprocess to dup - # our stdout. Regardless of which of these is applicable, - # don't create a new Trio stream for stderr -- if stdout - # is piped, stderr will be intermixed on the stdout stream. 
- if stdout is not None: - stderr = stdout - elif stderr == subprocess.PIPE: - trio_stderr, stderr = create_pipe_from_child_output() + """ - try: - popen = await trio.to_thread.run_sync( - partial( - subprocess.Popen, - command, - stdin=stdin, - stdout=stdout, - stderr=stderr, - **options, + if isinstance(stdin, str): + raise UnicodeError("process stdin must be bytes, not str") + if stdin == subprocess.PIPE: + raise ValueError( + "stdin=subprocess.PIPE doesn't make sense since the pipe " + "is internal to run_process(); pass the actual data you " + "want to send over that pipe instead" ) - ) - finally: - # Close the parent's handle for each child side of a pipe; - # we want the child to have the only copy, so that when - # it exits we can read EOF on our side. - if trio_stdin is not None: - os.close(stdin) - if trio_stdout is not None: - os.close(stdout) - if trio_stderr is not None: - os.close(stderr) + if isinstance(stdin, (bytes, bytearray, memoryview)): + input = stdin + options["stdin"] = subprocess.PIPE + else: + # stdin should be something acceptable to Process + # (None, DEVNULL, a file descriptor, etc) and Process + # will raise if it's not + input = None + options["stdin"] = stdin + + if capture_stdout: + if "stdout" in options: + raise ValueError("can't specify both stdout and capture_stdout") + options["stdout"] = subprocess.PIPE + if capture_stderr: + if "stderr" in options: + raise ValueError("can't specify both stderr and capture_stderr") + options["stderr"] = subprocess.PIPE + + if deliver_cancel is None: + if os.name == "nt": + deliver_cancel = _windows_deliver_cancel + else: + assert os.name == "posix" + deliver_cancel = _posix_deliver_cancel + + stdout_chunks = [] + stderr_chunks = [] + + async with await open_process(command, **options) as proc: + + async def feed_input(): + async with proc.stdin: + try: + await proc.stdin.send_all(input) + except trio.BrokenResourceError: + pass + + async def read_output(stream, chunks): + async with stream: + 
async for chunk in stream: + chunks.append(chunk) + + async with trio.open_nursery() as nursery: + if proc.stdin is not None: + nursery.start_soon(feed_input) + if proc.stdout is not None: + nursery.start_soon(read_output, proc.stdout, stdout_chunks) + if proc.stderr is not None: + nursery.start_soon(read_output, proc.stderr, stderr_chunks) + try: + await proc.wait() + except trio.Cancelled: + with trio.CancelScope(shield=True): + killer_cscope = trio.CancelScope(shield=True) + + async def killer(): + with killer_cscope: + await deliver_cancel(proc) + + nursery.start_soon(killer) + await proc.wait() + killer_cscope.cancel() + raise + + stdout = b"".join(stdout_chunks) if proc.stdout is not None else None + stderr = b"".join(stderr_chunks) if proc.stderr is not None else None - return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) + if proc.returncode and check: + raise subprocess.CalledProcessError( + proc.returncode, proc.args, output=stdout, stderr=stderr + ) + else: + return subprocess.CompletedProcess( + proc.args, proc.returncode, stdout, stderr + ) async def _windows_deliver_cancel(p): @@ -414,237 +840,3 @@ async def _posix_deliver_cancel(p): warnings.warn( RuntimeWarning(f"tried to kill process {p!r}, but failed with: {exc!r}") ) - - -async def run_process( - command, - *, - stdin=b"", - capture_stdout=False, - capture_stderr=False, - check=True, - deliver_cancel=None, - **options, -): - """Run ``command`` in a subprocess, wait for it to complete, and - return a :class:`subprocess.CompletedProcess` instance describing - the results. - - If cancelled, :func:`run_process` terminates the subprocess and - waits for it to exit before propagating the cancellation, like - :meth:`Process.aclose`. - - **Input:** The subprocess's standard input stream is set up to - receive the bytes provided as ``stdin``. 
Once the given input has - been fully delivered, or if none is provided, the subprocess will - receive end-of-file when reading from its standard input. - Alternatively, if you want the subprocess to read its - standard input from the same place as the parent Trio process, you - can pass ``stdin=None``. - - **Output:** By default, any output produced by the subprocess is - passed through to the standard output and error streams of the - parent Trio process. If you would like to capture this output and - do something with it, you can pass ``capture_stdout=True`` to - capture the subprocess's standard output, and/or - ``capture_stderr=True`` to capture its standard error. Captured - data is provided as the - :attr:`~subprocess.CompletedProcess.stdout` and/or - :attr:`~subprocess.CompletedProcess.stderr` attributes of the - returned :class:`~subprocess.CompletedProcess` object. The value - for any stream that was not captured will be ``None``. - - If you want to capture both stdout and stderr while keeping them - separate, pass ``capture_stdout=True, capture_stderr=True``. - - If you want to capture both stdout and stderr but mixed together - in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. - This directs the child's stderr into its stdout, so the combined - output will be available in the `~subprocess.CompletedProcess.stdout` - attribute. - - **Error checking:** If the subprocess exits with a nonzero status - code, indicating failure, :func:`run_process` raises a - :exc:`subprocess.CalledProcessError` exception rather than - returning normally. The captured outputs are still available as - the :attr:`~subprocess.CalledProcessError.stdout` and - :attr:`~subprocess.CalledProcessError.stderr` attributes of that - exception. To disable this behavior, so that :func:`run_process` - returns normally even if the subprocess exits abnormally, pass - ``check=False``. - - Args: - command (list or str): The command to run. 
Typically this is a - sequence of strings such as ``['ls', '-l', 'directory with spaces']``, - where the first element names the executable to invoke and the other - elements specify its arguments. With ``shell=True`` in the - ``**options``, or on Windows, ``command`` may alternatively - be a string, which will be parsed following platform-dependent - :ref:`quoting rules `. - - stdin (:obj:`bytes`, file descriptor, or None): The bytes to provide to - the subprocess on its standard input stream, or ``None`` if the - subprocess's standard input should come from the same place as - the parent Trio process's standard input. As is the case with - the :mod:`subprocess` module, you can also pass a - file descriptor or an object with a ``fileno()`` method, - in which case the subprocess's standard input will come from - that file. - - capture_stdout (bool): If true, capture the bytes that the subprocess - writes to its standard output stream and return them in the - :attr:`~subprocess.CompletedProcess.stdout` attribute - of the returned :class:`~subprocess.CompletedProcess` object. - - capture_stderr (bool): If true, capture the bytes that the subprocess - writes to its standard error stream and return them in the - :attr:`~subprocess.CompletedProcess.stderr` attribute - of the returned :class:`~subprocess.CompletedProcess` object. - - check (bool): If false, don't validate that the subprocess exits - successfully. You should be sure to check the - ``returncode`` attribute of the returned object if you pass - ``check=False``, so that errors don't pass silently. - - deliver_cancel (async function or None): If `run_process` is cancelled, - then it needs to kill the child process. There are multiple ways to - do this, so we let you customize it. - - If you pass None (the default), then the behavior depends on the - platform: - - - On Windows, Trio calls ``TerminateProcess``, which should kill the - process immediately. 
- - - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait - 5 seconds, and send a ``SIGKILL``. - - Alternatively, you can customize this behavior by passing in an - arbitrary async function, which will be called with the `Process` - object as an argument. For example, the default Unix behavior could - be implemented like this:: - - async def my_deliver_cancel(process): - process.send_signal(signal.SIGTERM) - await trio.sleep(5) - process.send_signal(signal.SIGKILL) - - When the process actually exits, the ``deliver_cancel`` function - will automatically be cancelled – so if the process exits after - ``SIGTERM``, then we'll never reach the ``SIGKILL``. - - In any case, `run_process` will always wait for the child process to - exit before raising `Cancelled`. - - **options: :func:`run_process` also accepts any :ref:`general subprocess - options ` and passes them on to the - :class:`~trio.Process` constructor. This includes the - ``stdout`` and ``stderr`` options, which provide additional - redirection possibilities such as ``stderr=subprocess.STDOUT``, - ``stdout=subprocess.DEVNULL``, or file descriptors. - - Returns: - A :class:`subprocess.CompletedProcess` instance describing the - return code and outputs. - - Raises: - UnicodeError: if ``stdin`` is specified as a Unicode string, rather - than bytes - ValueError: if multiple redirections are specified for the same - stream, e.g., both ``capture_stdout=True`` and - ``stdout=subprocess.DEVNULL`` - subprocess.CalledProcessError: if ``check=False`` is not passed - and the process exits with a nonzero exit status - OSError: if an error is encountered starting or communicating with - the process - - .. note:: The child process runs in the same process group as the parent - Trio process, so a Ctrl+C will be delivered simultaneously to both - parent and child. If you don't want this behavior, consult your - platform's documentation for starting child processes in a different - process group. 
- - """ - - if isinstance(stdin, str): - raise UnicodeError("process stdin must be bytes, not str") - if stdin == subprocess.PIPE: - raise ValueError( - "stdin=subprocess.PIPE doesn't make sense since the pipe " - "is internal to run_process(); pass the actual data you " - "want to send over that pipe instead" - ) - if isinstance(stdin, (bytes, bytearray, memoryview)): - input = stdin - options["stdin"] = subprocess.PIPE - else: - # stdin should be something acceptable to Process - # (None, DEVNULL, a file descriptor, etc) and Process - # will raise if it's not - input = None - options["stdin"] = stdin - - if capture_stdout: - if "stdout" in options: - raise ValueError("can't specify both stdout and capture_stdout") - options["stdout"] = subprocess.PIPE - if capture_stderr: - if "stderr" in options: - raise ValueError("can't specify both stderr and capture_stderr") - options["stderr"] = subprocess.PIPE - - if deliver_cancel is None: - if os.name == "nt": - deliver_cancel = _windows_deliver_cancel - else: - assert os.name == "posix" - deliver_cancel = _posix_deliver_cancel - - stdout_chunks = [] - stderr_chunks = [] - - async with await open_process(command, **options) as proc: - - async def feed_input(): - async with proc.stdin: - try: - await proc.stdin.send_all(input) - except trio.BrokenResourceError: - pass - - async def read_output(stream, chunks): - async with stream: - async for chunk in stream: - chunks.append(chunk) - - async with trio.open_nursery() as nursery: - if proc.stdin is not None: - nursery.start_soon(feed_input) - if proc.stdout is not None: - nursery.start_soon(read_output, proc.stdout, stdout_chunks) - if proc.stderr is not None: - nursery.start_soon(read_output, proc.stderr, stderr_chunks) - try: - await proc.wait() - except trio.Cancelled: - with trio.CancelScope(shield=True): - killer_cscope = trio.CancelScope(shield=True) - - async def killer(): - with killer_cscope: - await deliver_cancel(proc) - - nursery.start_soon(killer) - await 
proc.wait() - killer_cscope.cancel() - raise - - stdout = b"".join(stdout_chunks) if proc.stdout is not None else None - stderr = b"".join(stderr_chunks) if proc.stderr is not None else None - - if proc.returncode and check: - raise subprocess.CalledProcessError( - proc.returncode, proc.args, output=stdout, stderr=stderr - ) - else: - return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 91ba224546..dad765a0f1 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -7,14 +7,17 @@ from .._sync import CapacityLimiter, Event from .._threads import to_thread_run_sync -try: - from os import waitid + +waitid = getattr(os, "waitid", None) + + +if waitid is not None: def sync_wait_reapable(pid): waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT) -except ImportError: +else: # pypy doesn't define os.waitid so we need to pull it out ourselves # using cffi: https://bitbucket.org/pypy/pypy/issues/2922/ import cffi @@ -102,7 +105,8 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: # process. 
if process._wait_for_exit_data is None: - process._wait_for_exit_data = event = Event() # type: ignore + event = Event() + process._wait_for_exit_data = event # type: ignore[assignment] _core.spawn_system_task(_waitid_system_task, process.pid, event) assert isinstance(process._wait_for_exit_data, Event) await process._wait_for_exit_data.wait() diff --git a/trio/_subprocess_platform/windows.py b/trio/_subprocess_platform/windows.py index 958be8675c..816da4b203 100644 --- a/trio/_subprocess_platform/windows.py +++ b/trio/_subprocess_platform/windows.py @@ -3,4 +3,4 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: - await WaitForSingleObject(int(process._proc._handle)) + await WaitForSingleObject(int(process._proc._handle)) # type: ignore[attr-defined] diff --git a/trio/_sync.py b/trio/_sync.py index bed339ef6b..e0be74d850 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,4 +1,7 @@ import math +from typing import Optional, Type, TypeVar, Union + +from typing_extensions import Protocol import attr import outcome @@ -6,6 +9,7 @@ import trio from ._core import enable_ki_protection, ParkingLot +from ._core._run import Task from ._deprecate import deprecated from ._util import Final @@ -40,17 +44,17 @@ class Event(metaclass=Final): _lot = attr.ib(factory=ParkingLot, init=False) _flag = attr.ib(default=False, init=False) - def is_set(self): + def is_set(self) -> bool: """Return the current value of the internal flag.""" return self._flag @enable_ki_protection - def set(self): + def set(self) -> None: """Set the internal flag value to True, and wake any waiting tasks.""" self._flag = True self._lot.unpark_all() - async def wait(self): + async def wait(self) -> None: """Block until the internal flag value becomes True. If it's already True, then this method returns immediately. @@ -73,20 +77,31 @@ def statistics(self): return self._lot.statistics() -def async_cm(cls): +class _HasAcquire(Protocol): + async def acquire(self) -> object: + ... 
+ + def release(self) -> object: + ... + + +_TA = TypeVar("_TA", bound=_HasAcquire) + + +def async_cm(cls: Type[_TA]) -> Type[_TA]: @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self: _TA) -> None: await self.acquire() __aenter__.__qualname__ = cls.__qualname__ + ".__aenter__" - cls.__aenter__ = __aenter__ + cls.__aenter__ = __aenter__ # type: ignore[attr-defined] @enable_ki_protection - async def __aexit__(self, *args): + async def __aexit__(self: _TA, *args: object) -> None: self.release() __aexit__.__qualname__ = cls.__qualname__ + ".__aexit__" - cls.__aexit__ = __aexit__ + cls.__aexit__ = __aexit__ # type: ignore[attr-defined] return cls @@ -153,7 +168,9 @@ class CapacityLimiter(metaclass=Final): """ - def __init__(self, total_tokens): + _total_tokens: int + + def __init__(self, total_tokens) -> None: self._lot = ParkingLot() self._borrowers = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of @@ -162,13 +179,13 @@ def __init__(self, total_tokens): self.total_tokens = total_tokens assert self._total_tokens == total_tokens - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), len(self._borrowers), self._total_tokens, len(self._lot) ) @property - def total_tokens(self): + def total_tokens(self) -> int: """The total capacity available. You can change :attr:`total_tokens` by assigning to this attribute. 
If @@ -191,23 +208,23 @@ def total_tokens(self, new_total_tokens): self._total_tokens = new_total_tokens self._wake_waiters() - def _wake_waiters(self): + def _wake_waiters(self) -> None: available = self._total_tokens - len(self._borrowers) for woken in self._lot.unpark(count=available): self._borrowers.add(self._pending_borrowers.pop(woken)) @property - def borrowed_tokens(self): + def borrowed_tokens(self) -> int: """The amount of capacity that's currently in use.""" return len(self._borrowers) @property - def available_tokens(self): + def available_tokens(self) -> int: """The amount of capacity that's available to use.""" return self.total_tokens - self.borrowed_tokens @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Borrow a token from the sack, without blocking. Raises: @@ -219,7 +236,7 @@ def acquire_nowait(self): self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower): + def acquire_on_behalf_of_nowait(self, borrower: Union[object, Task]) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. @@ -248,7 +265,7 @@ def acquire_on_behalf_of_nowait(self, borrower): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Borrow a token from the sack, blocking if necessary. Raises: @@ -259,7 +276,7 @@ async def acquire(self): await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower: Union[object, Task]) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -288,7 +305,7 @@ async def acquire_on_behalf_of(self, borrower): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Put a token back into the sack. 
Raises: @@ -299,7 +316,7 @@ def release(self): self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower: Union[object, Task]) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: @@ -369,7 +386,7 @@ class Semaphore(metaclass=Final): """ - def __init__(self, initial_value, *, max_value=None): + def __init__(self, initial_value: int, *, max_value: Optional[int] = None) -> None: if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -387,7 +404,7 @@ def __init__(self, initial_value, *, max_value=None): self._value = initial_value self._max_value = max_value - def __repr__(self): + def __repr__(self) -> str: if self._max_value is None: max_value_str = "" else: @@ -397,17 +414,17 @@ def __repr__(self): ) @property - def value(self): + def value(self) -> int: """The current value of the semaphore.""" return self._value @property - def max_value(self): + def max_value(self) -> Optional[int]: """The maximum allowed value. May be None to indicate no limit.""" return self._max_value @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to decrement the semaphore value, without blocking. Raises: @@ -421,7 +438,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Decrement the semaphore value, blocking if necessary to avoid letting it drop below zero. @@ -435,7 +452,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Increment the semaphore value, possibly waking a task blocked in :meth:`acquire`. 
@@ -477,7 +494,7 @@ class _LockImpl: _lot = attr.ib(factory=ParkingLot, init=False) _owner = attr.ib(default=None, init=False) - def __repr__(self): + def __repr__(self) -> str: if self.locked(): s1 = "locked" s2 = " with {} waiters".format(len(self._lot)) @@ -498,7 +515,7 @@ def locked(self): return self._owner is not None @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the lock, without blocking. Raises: @@ -516,7 +533,7 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Acquire the lock, blocking if necessary.""" await trio.lowlevel.checkpoint_if_cancelled() try: @@ -530,7 +547,7 @@ async def acquire(self): await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Release the lock. Raises: @@ -661,7 +678,7 @@ class Condition(metaclass=Final): """ - def __init__(self, lock=None): + def __init__(self, lock=None) -> None: if lock is None: lock = Lock() if not type(lock) is Lock: @@ -687,16 +704,16 @@ def acquire_nowait(self): """ return self._lock.acquire_nowait() - async def acquire(self): + async def acquire(self) -> None: """Acquire the underlying lock, blocking if necessary.""" await self._lock.acquire() - def release(self): + def release(self) -> None: """Release the underlying lock.""" self._lock.release() @enable_ki_protection - async def wait(self): + async def wait(self) -> None: """Wait for another task to call :meth:`notify` or :meth:`notify_all`. @@ -745,7 +762,7 @@ def notify(self, n=1): raise RuntimeError("must hold the lock to notify") self._lot.repark(self._lock._lot, count=n) - def notify_all(self): + def notify_all(self) -> None: """Wake all tasks that are currently blocked in :meth:`wait`. 
Raises: diff --git a/trio/_threads.py b/trio/_threads.py index 648b87d801..8cfd19f560 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -3,6 +3,7 @@ import threading import queue as stdlib_queue from itertools import count +from typing import Awaitable, Callable, Optional, Sequence, TypeVar import attr import inspect @@ -23,13 +24,16 @@ # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() -_limiter_local = RunVar("limiter") +_limiter_local = RunVar[CapacityLimiter]("limiter") # I pulled this number out of the air; it isn't based on anything. Probably we # should make some kind of measurements to pick a good value. DEFAULT_LIMIT = 40 _thread_counter = count() +_T = TypeVar("_T") + + def current_default_thread_limiter(): """Get the default `~trio.CapacityLimiter` used by `trio.to_thread.run_sync`. @@ -55,8 +59,16 @@ class ThreadPlaceholder: name = attr.ib() +# TODO: maybe we don't want to ban Any in decorated functions? just any? maybe +# then we would just ignore the line with the Any? (it is the callable's +# unspecified parameter list) @enable_ki_protection -async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None): +async def to_thread_run_sync( # type: ignore[misc] + sync_fn: Callable[..., _T], + *args: object, + cancellable: bool = False, + limiter: Optional[CapacityLimiter] = None, +) -> _T: """Convert a blocking operation into an async operation using a thread. 
These two lines are equivalent:: @@ -204,7 +216,7 @@ def abort(_): else: return trio.lowlevel.Abort.FAILED - return await trio.lowlevel.wait_task_rescheduled(abort) + return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[return-value] def _run_fn_as_system_task(cb, fn, *args, trio_token=None): @@ -238,7 +250,11 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None): return q.get().unwrap() -def from_thread_run(afn, *args, trio_token=None): +def from_thread_run( + afn: Callable[..., Awaitable[_T]], + *args: object, + trio_token: Optional[TrioToken] = None, +) -> _T: """Run the given async function in the parent Trio thread, blocking until it is complete. @@ -273,13 +289,15 @@ def from_thread_run(afn, *args, trio_token=None): to enter Trio. """ - def callback(q, afn, args): + def callback( + q: stdlib_queue.Queue, afn: Callable[..., Awaitable[_T]], args: Sequence[object] + ) -> None: @disable_ki_protection - async def unprotected_afn(): + async def unprotected_afn() -> _T: coro = coroutine_or_error(afn, *args) - return await coro + return await coro # type: ignore[no-any-return] - async def await_in_trio_thread_task(): + async def await_in_trio_thread_task() -> None: q.put_nowait(await outcome.acapture(unprotected_afn)) try: @@ -289,10 +307,12 @@ async def await_in_trio_thread_task(): outcome.Error(trio.RunFinishedError("system nursery is closed")) ) - return _run_fn_as_system_task(callback, afn, *args, trio_token=trio_token) + return _run_fn_as_system_task(callback, afn, *args, trio_token=trio_token) # type: ignore[no-any-return] -def from_thread_run_sync(fn, *args, trio_token=None): +def from_thread_run_sync( + fn: Callable[..., _T], *args: object, trio_token: Optional[TrioToken] = None +) -> _T: """Run the given sync function in the parent Trio thread, blocking until it is complete. @@ -323,14 +343,16 @@ def from_thread_run_sync(fn, *args, trio_token=None): to enter Trio. 
""" - def callback(q, fn, args): + def callback( + q: stdlib_queue.Queue, fn: Callable[..., _T], args: Sequence[object] + ) -> None: @disable_ki_protection - def unprotected_fn(): + def unprotected_fn() -> _T: ret = fn(*args) if inspect.iscoroutine(ret): # Manually close coroutine to avoid RuntimeWarnings - ret.close() + ret.close() # type: ignore[attr-defined] raise TypeError( "Trio expected a sync function, but {!r} appears to be " "asynchronous".format(getattr(fn, "__qualname__", fn)) @@ -341,4 +363,4 @@ def unprotected_fn(): res = outcome.capture(unprotected_fn) q.put_nowait(res) - return _run_fn_as_system_task(callback, fn, *args, trio_token=trio_token) + return _run_fn_as_system_task(callback, fn, *args, trio_token=trio_token) # type: ignore[no-any-return] diff --git a/trio/_timeouts.py b/trio/_timeouts.py index 1f7878f89e..9568ec7cf3 100644 --- a/trio/_timeouts.py +++ b/trio/_timeouts.py @@ -1,9 +1,10 @@ from contextlib import contextmanager +from typing import ContextManager, Iterator import trio -def move_on_at(deadline): +def move_on_at(deadline: float) -> trio.CancelScope: """Use as a context manager to create a cancel scope with the given absolute deadline. @@ -14,7 +15,7 @@ def move_on_at(deadline): return trio.CancelScope(deadline=deadline) -def move_on_after(seconds): +def move_on_after(seconds: float) -> trio.CancelScope: """Use as a context manager to create a cancel scope whose deadline is set to now + *seconds*. @@ -31,7 +32,7 @@ def move_on_after(seconds): return move_on_at(trio.current_time() + seconds) -async def sleep_forever(): +async def sleep_forever() -> None: """Pause execution of the current task forever (or until cancelled). Equivalent to calling ``await sleep(math.inf)``. 
@@ -40,7 +41,7 @@ async def sleep_forever(): await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED) -async def sleep_until(deadline): +async def sleep_until(deadline: float) -> None: """Pause execution of the current task until the given time. The difference between :func:`sleep` and :func:`sleep_until` is that the @@ -57,7 +58,7 @@ async def sleep_until(deadline): await sleep_forever() -async def sleep(seconds): +async def sleep(seconds: float) -> None: """Pause execution of the current task for the given number of seconds. Args: @@ -84,7 +85,7 @@ class TooSlowError(Exception): @contextmanager -def fail_at(deadline): +def fail_at(deadline: float) -> Iterator[trio.CancelScope]: """Creates a cancel scope with the given deadline, and raises an error if it is actually cancelled. @@ -108,7 +109,7 @@ def fail_at(deadline): raise TooSlowError -def fail_after(seconds): +def fail_after(seconds: float) -> ContextManager[trio.CancelScope]: """Creates a cancel scope with the given timeout, and raises an error if it is actually cancelled. diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index d7e6326ce6..77041c570f 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -10,6 +10,7 @@ import os from pathlib import Path import sys +from typing import Dict, Iterator, List, Tuple, Union from textwrap import indent @@ -18,10 +19,31 @@ HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +import select +import socket +import sys +from typing import ( + Awaitable, + Callable, + ContextManager, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, +) + +from .._abc import Clock +from .._typing import _HasFileno +from .._core._entry_queue import TrioToken +from .. 
import _core +from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND, _RunStatistics, Task from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._instrumentation import Instrument +if TYPE_CHECKING and sys.platform == "win32": + from ._io_windows import CompletionKeyEventInfo + # fmt: off """ @@ -36,7 +58,7 @@ """ -def is_function(node): +def is_function(node: ast.AST) -> bool: """Check if the AST node is either a function or an async function """ @@ -45,17 +67,23 @@ def is_function(node): return False -def is_public(node): +def is_public(node: ast.AST) -> bool: """Check if the AST node has a _public decorator""" if not is_function(node): return False + + # the `if` above does this but we have to help out Mypy + assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + for decorator in node.decorator_list: if isinstance(decorator, ast.Name) and decorator.id == "_public": return True return False -def get_public_methods(tree): +def get_public_methods( + tree: ast.AST, +) -> Iterator[Union[ast.FunctionDef, ast.AsyncFunctionDef]]: """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked @@ -63,10 +91,15 @@ def get_public_methods(tree): """ for node in ast.walk(tree): if is_public(node): + # the `if` above does this but we have to help out Mypy + assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + yield node -def create_passthrough_args(funcdef): +def create_passthrough_args( + funcdef: Union[ast.FunctionDef, ast.AsyncFunctionDef] +) -> str: """Given a function definition, create a string that represents taking all the arguments from the function, and passing them through to another invocation of the same function. 
@@ -86,18 +119,33 @@ def create_passthrough_args(funcdef): return "({})".format(", ".join(call_args)) -def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: +def gen_public_wrappers_source(source_path: Union[Path, str], lookup_path: str) -> str: """Scan the given .py file for @_public decorators, and generate wrapper functions. """ generated = [HEADER] + # source_string = source_path.read_text("utf-8") + # source = astor.code_to_ast.parse_string(source_string) source = astor.code_to_ast.parse_file(source_path) + + asserts = [ + node for node in ast.iter_child_nodes(source) if isinstance(node, ast.Assert) + ] + if len(asserts) > 0: + the_assert = asserts[0] + generated.append(astor.to_source(the_assert)) + for method in get_public_methods(source): # Remove self from arguments assert method.args.args[0].arg == "self" del method.args.args[0] + contextmanager_decorated = any( + decorator.id in {"contextmanager", "contextlib.contextmanager"} + for decorator in method.decorator_list + if isinstance(decorator, ast.Name) + ) # Remove decorators method.decorator_list = [] @@ -113,6 +161,8 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: # Create the function definition including the body func = astor.to_source(method, indent_with=" " * 4) + if contextmanager_decorated: + func = func.replace("->Iterator[", "->ContextManager[") # Create export function body template = TEMPLATE.format( @@ -130,7 +180,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: return "\n\n".join(generated) -def matches_disk_files(new_files): +def matches_disk_files(new_files: Dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False @@ -141,7 +191,9 @@ def matches_disk_files(new_files): return True -def process(sources_and_lookups, *, do_test): +def process( + sources_and_lookups: List[Tuple[Union[Path, str], str]], *, do_test: bool +) -> None: new_files = {} for 
source_path, lookup_path in sources_and_lookups: print("Scanning:", source_path) @@ -164,7 +216,7 @@ def process(sources_and_lookups, *, do_test): # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. -def main(): # pragma: no cover +def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( description="Generate python code for public api wrappers" ) @@ -177,7 +229,7 @@ def main(): # pragma: no cover # Double-check we found the right directory assert (source_root / "LICENSE").exists() core = source_root / "trio/_core" - to_wrap = [ + to_wrap: List[Tuple[Union[Path, str], str]] = [ (core / "_run.py", "runner"), (core / "_instrumentation.py", "runner.instruments"), (core / "_io_windows.py", "runner.io_manager"), diff --git a/trio/_typing.py b/trio/_typing.py new file mode 100644 index 0000000000..8a62b9bfa8 --- /dev/null +++ b/trio/_typing.py @@ -0,0 +1,13 @@ +from typing import Union + +from typing_extensions import Protocol + +from ._core._run import _TaskStatus, _TaskStatusIgnored + + +class _HasFileno(Protocol): + def fileno(self) -> int: + ... + + +TaskStatus = Union[_TaskStatus, _TaskStatusIgnored] diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index 0d2d11c53c..b37b1a10b0 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -1,15 +1,14 @@ import os import errno +import sys +from typing import Optional from ._abc import Stream from ._util import ConflictDetector, Final import trio -if os.name != "posix": - # We raise an error here rather than gating the import in lowlevel.py - # in order to keep jedi static analysis happy. - raise ImportError +assert sys.platform != "win32" # XX TODO: is this a good number? who knows... it does match the default Linux # pipe capacity though. @@ -34,7 +33,7 @@ class _FdHolder: # impossible to make this mistake – we'll just get an EBADF. # # (This trick was copied from the stdlib socket module.) 
- def __init__(self, fd: int): + def __init__(self, fd: int) -> None: # make sure self.fd is always initialized to *something*, because even # if we error out here then __del__ will run and access it. self.fd = -1 @@ -46,10 +45,10 @@ def __init__(self, fd: int): os.set_blocking(fd, False) @property - def closed(self): + def closed(self) -> bool: return self.fd == -1 - def _raw_close(self): + def _raw_close(self) -> None: # This doesn't assume it's in a Trio context, so it can be called from # __del__. You should never call it from Trio context, because it # skips calling notify_fd_close. But from __del__, skipping that is @@ -64,10 +63,10 @@ def _raw_close(self): os.set_blocking(fd, self._original_is_blocking) os.close(fd) - def __del__(self): + def __del__(self) -> None: self._raw_close() - async def aclose(self): + async def aclose(self) -> None: if not self.closed: trio.lowlevel.notify_closing(self.fd) self._raw_close() @@ -107,7 +106,7 @@ class FdStream(Stream, metaclass=Final): A new `FdStream` object. 
""" - def __init__(self, fd: int): + def __init__(self, fd: int) -> None: self._fd_holder = _FdHolder(fd) self._send_conflict_detector = ConflictDetector( "another task is using this stream for send" @@ -116,7 +115,7 @@ def __init__(self, fd: int): "another task is using this stream for receive" ) - async def send_all(self, data: bytes): + async def send_all(self, data: bytes) -> None: with self._send_conflict_detector: # have to check up front, because send_all(b"") on a closed pipe # should raise @@ -152,7 +151,7 @@ async def wait_send_all_might_not_block(self) -> None: # of sending, which is annoying raise trio.BrokenResourceError from e - async def receive_some(self, max_bytes=None) -> bytes: + async def receive_some(self, max_bytes: Optional[int] = None) -> bytes: with self._receive_conflict_detector: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE @@ -180,7 +179,7 @@ async def receive_some(self, max_bytes=None) -> bytes: return data - async def aclose(self): + async def aclose(self) -> None: await self._fd_holder.aclose() def fileno(self): diff --git a/trio/_util.py b/trio/_util.py index ec0350b305..8ec5393bc9 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -3,7 +3,6 @@ # Little utilities we use internally from abc import ABCMeta -import os import signal import sys import pathlib @@ -17,7 +16,7 @@ import trio # Equivalent to the C function raise(), which Python doesn't wrap -if os.name == "nt": +if sys.platform == "win32": # On windows, os.kill exists but is really weird. # # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver @@ -61,10 +60,13 @@ signal_raise = getattr(_lib, "raise") else: - def signal_raise(signum): + def signal_raise(signum: int) -> None: signal.pthread_kill(threading.get_ident(), signum) +_T = t.TypeVar("_T") + + # See: #461 as to why this is needed. 
# The gist is that threading.main_thread() has the capability to lie to us # if somebody else edits the threading ident cache to replace the main @@ -73,7 +75,7 @@ def signal_raise(signum): # Trying to use signal out of the main thread will fail, so we can then # reliably check if this is the main thread without relying on a # potentially modified threading. -def is_main_thread(): +def is_main_thread() -> bool: """Attempt to reliably check if we are in the main thread.""" try: signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT)) @@ -86,8 +88,10 @@ def is_main_thread(): # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. Returns coroutine object. ###### -def coroutine_or_error(async_fn, *args): - def _return_value_looks_like_wrong_library(value): +def coroutine_or_error( + async_fn: t.Callable[..., t.Awaitable[object]], *args: object +) -> t.Awaitable[object]: + def _return_value_looks_like_wrong_library(value: object) -> bool: # Returned by legacy @asyncio.coroutine functions, which includes # a surprising proportion of asyncio builtins. 
if isinstance(value, collections.abc.Generator): @@ -181,24 +185,29 @@ class ConflictDetector: """ - def __init__(self, msg): + def __init__(self, msg: str) -> None: self._msg = msg self._held = False - def __enter__(self): + def __enter__(self) -> None: if self._held: raise trio.BusyResourceError(self._msg) else: self._held = True - def __exit__(self, *args): + def __exit__(self, *args: object) -> None: self._held = False -def async_wraps(cls, wrapped_cls, attr_name): +_Fn = t.TypeVar("_Fn", bound=t.Callable) + + +def async_wraps( + cls: t.Type[object], wrapped_cls: t.Type[object], attr_name: str +) -> t.Callable[[_Fn], _Fn]: """Similar to wraps, but for async wrappers of non-async functions.""" - def decorator(func): + def decorator(func: _Fn) -> _Fn: func.__name__ = attr_name func.__qualname__ = ".".join((cls.__qualname__, attr_name)) @@ -213,10 +222,10 @@ def decorator(func): return decorator -def fixup_module_metadata(module_name, namespace): +def fixup_module_metadata(module_name: str, namespace: t.Dict[str, object]) -> None: seen_ids = set() - def fix_one(qualname, name, obj): + def fix_one(qualname: str, name: str, obj: object) -> None: # avoid infinite recursion (relevant when using # typing.Generic, for example) if id(obj) in seen_ids: @@ -229,9 +238,9 @@ def fix_one(qualname, name, obj): # Modules, unlike everything else in Python, put fully-qualitied # names into their __name__ attribute. We check for "." to avoid # rewriting these. - if hasattr(obj, "__name__") and "." not in obj.__name__: - obj.__name__ = name - obj.__qualname__ = qualname + if hasattr(obj, "__name__") and "." not in obj.__name__: # type: ignore[attr-defined] + obj.__name__ = name # type: ignore[attr-defined] + obj.__qualname__ = qualname # type: ignore[attr-defined] if isinstance(obj, type): for attr_name, attr_value in obj.__dict__.items(): fix_one(objname + "." 
+ attr_name, attr_name, attr_value) @@ -241,7 +250,13 @@ def fix_one(qualname, name, obj): fix_one(objname, objname, obj) -class generic_function: +# TODO: This does not account for the generic parametrization via __getitem__. +# Presumably this will not work out in the long run even though it helped +# with trio._channel.open_memory_channel right now. +generic_function: t.Callable[[_T], _T] + + +class generic_function: # type: ignore[no-redef] """Decorator that makes a function indexable, to communicate non-inferrable generic type parameters to a static type checker. @@ -258,14 +273,14 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn): + def __init__(self, fn: t.Callable[..., object]) -> None: update_wrapper(self, fn) self._fn = fn - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> object: return self._fn(*args, **kwargs) - def __getitem__(self, _): + def __getitem__(self: _T, _: object) -> _T: return self @@ -298,16 +313,20 @@ class SomeClass(metaclass=Final): - TypeError if a sub class is created """ - def __new__(cls, name, bases, cls_namespace): + def __new__( + cls: t.Type[_T], + name: str, + bases: t.Tuple[type], + cls_namespace: t.Dict[str, object], + ) -> _T: for base in bases: if isinstance(base, Final): raise TypeError( f"{base.__module__}.{base.__qualname__} does not support subclassing" ) - return super().__new__(cls, name, bases, cls_namespace) - -T = t.TypeVar("T") + # https://github.com/python/mypy/issues/9282 + return super().__new__(cls, name, bases, cls_namespace) # type: ignore[no-any-return,misc] class NoPublicConstructor(Final): @@ -329,16 +348,16 @@ class SomeClass(metaclass=NoPublicConstructor): - TypeError if a sub class or an instance is created. 
""" - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: object, **kwargs: object) -> None: raise TypeError( f"{cls.__module__}.{cls.__qualname__} has no public constructor" ) - def _create(cls: t.Type[T], *args: t.Any, **kwargs: t.Any) -> T: - return super().__call__(*args, **kwargs) # type: ignore + def _create(cls: t.Type[_T], *args: t.Any, **kwargs: t.Any) -> _T: + return super().__call__(*args, **kwargs) # type: ignore[no-any-return,misc] -def name_asyncgen(agen): +def name_asyncgen(agen: t.AsyncGenerator) -> str: """Return the fully-qualified name of the async generator function that produced the async generator iterator *agen*. """ @@ -349,7 +368,7 @@ def name_asyncgen(agen): except (AttributeError, KeyError): module = "<{}>".format(agen.ag_code.co_filename) try: - qualname = agen.__qualname__ + qualname = agen.__qualname__ # type: ignore[attr-defined] except AttributeError: qualname = agen.ag_code.co_name return f"{module}.{qualname}" diff --git a/trio/_wait_for_object.py b/trio/_wait_for_object.py index 07c0461429..bb43e90cf8 100644 --- a/trio/_wait_for_object.py +++ b/trio/_wait_for_object.py @@ -1,4 +1,6 @@ import math +from typing import Union + from . import _timeouts import trio from ._core._windows_cffi import ( @@ -10,7 +12,7 @@ ) -async def WaitForSingleObject(obj): +async def WaitForSingleObject(obj: Union[int, object]) -> None: """Async and cancellable variant of WaitForSingleObject. Windows only. 
Args: @@ -50,7 +52,7 @@ async def WaitForSingleObject(obj): kernel32.CloseHandle(cancel_handle) -def WaitForMultipleObjects_sync(*handles): +def WaitForMultipleObjects_sync(*handles: object) -> None: """Wait for any of the given Windows handles to be signaled.""" n = len(handles) handle_arr = ffi.new("HANDLE[{}]".format(n)) diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index fb420535f4..9a0a1abf6a 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -1,5 +1,5 @@ import sys -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from . import _core from ._abc import SendStream, ReceiveStream from ._util import ConflictDetector, Final @@ -22,10 +22,10 @@ def __init__(self, handle: int) -> None: _core.register_with_iocp(self.handle) @property - def closed(self): + def closed(self) -> bool: return self.handle == -1 - def _close(self): + def _close(self) -> None: if self.closed: return handle = self.handle @@ -33,11 +33,11 @@ def _close(self): if not kernel32.CloseHandle(_handle(handle)): raise_winerror() - async def aclose(self): + async def aclose(self) -> None: self._close() await _core.checkpoint() - def __del__(self): + def __del__(self) -> None: self._close() @@ -52,7 +52,7 @@ def __init__(self, handle: int) -> None: "another task is currently using this pipe" ) - async def send_all(self, data: bytes): + async def send_all(self, data: bytes) -> None: with self._conflict_detector: if self._handle_holder.closed: raise _core.ClosedResourceError("this pipe is already closed") @@ -78,7 +78,7 @@ async def wait_send_all_might_not_block(self) -> None: # not implemented yet, and probably not needed await _core.checkpoint() - async def aclose(self): + async def aclose(self) -> None: await self._handle_holder.aclose() @@ -91,7 +91,7 @@ def __init__(self, handle: int) -> None: "another task is currently using this pipe" ) - async def receive_some(self, max_bytes=None) -> bytes: + async def receive_some(self, max_bytes: 
Optional[int] = None) -> bytes: with self._conflict_detector: if self._handle_holder.closed: raise _core.ClosedResourceError("this pipe is already closed") @@ -130,5 +130,5 @@ async def receive_some(self, max_bytes=None) -> bytes: del buffer[size:] return buffer - async def aclose(self): + async def aclose(self) -> None: await self._handle_holder.aclose() diff --git a/trio/socket.py b/trio/socket.py index 5402f5bc73..b4c9649502 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -22,7 +22,7 @@ # kept up to date. try: # fmt: off - from socket import ( # type: ignore + from socket import ( # type: ignore[attr-defined] CMSG_LEN, CMSG_SPACE, CAPI, AF_UNSPEC, AF_INET, AF_UNIX, AF_IPX, AF_APPLETALK, AF_INET6, AF_ROUTE, AF_LINK, AF_SNA, PF_SYSTEM, AF_SYSTEM, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, SOCK_SEQPACKET, SOCK_RDM, @@ -137,7 +137,7 @@ globals().update( { _name: getattr(_stdlib_socket, _name) - for _name in _stdlib_socket.__all__ # type: ignore + for _name in _stdlib_socket.__all__ # type: ignore[attr-defined] if _name.isupper() and _name not in _bad_symbols } ) diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 7a9006ff43..21be38e5bd 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -2,6 +2,7 @@ from contextlib import contextmanager import random +from typing import Iterator, Type from .. 
import _core from .._highlevel_generic import aclose_forcefully @@ -10,7 +11,7 @@ class _ForceCloseBoth: - def __init__(self, both): + def __init__(self, both) -> None: self._both = list(both) async def __aenter__(self): @@ -24,7 +25,7 @@ async def __aexit__(self, *args): @contextmanager -def _assert_raises(exc): +def _assert_raises(exc: Type[Exception]) -> Iterator[None]: __tracebackhide__ = True try: yield @@ -77,7 +78,7 @@ async def do_aclose(resource): nursery.start_soon(do_send_all, b"x") nursery.start_soon(checked_receive_1, b"x") - async def send_empty_then_y(): + async def send_empty_then_y() -> None: # Streams should tolerate sending b"" without giving it any # special meaning. await do_send_all(b"") @@ -136,7 +137,7 @@ async def simple_check_wait_send_all_might_not_block(scope): # closing the r side leads to BrokenResourceError on the s side # (eventually) - async def expect_broken_stream_on_send(): + async def expect_broken_stream_on_send() -> None: with _assert_raises(_core.BrokenResourceError): while True: await do_send_all(b"x" * 100) @@ -179,11 +180,11 @@ async def expect_broken_stream_on_send(): async with _ForceCloseBoth(await stream_maker()) as (s, r): # if send-then-graceful-close, receiver gets data then b"" - async def send_then_close(): + async def send_then_close() -> None: await do_send_all(b"y") await do_aclose(s) - async def receive_send_then_close(): + async def receive_send_then_close() -> None: # We want to make sure that if the sender closes the stream before # we read anything, then we still get all the data. But some # streams might block on the do_send_all call. 
So we let the @@ -437,7 +438,7 @@ async def receiver(s, data, seed): nursery.start_soon(receiver, s1, test_data[::-1], 2) nursery.start_soon(receiver, s2, test_data, 3) - async def expect_receive_some_empty(): + async def expect_receive_some_empty() -> None: assert await s2.receive_some(10) == b"" await s2.aclose() diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 5804295300..e23c8254c5 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -1,10 +1,11 @@ from contextlib import contextmanager +from typing import ContextManager, Iterator from .. import _core @contextmanager -def _assert_yields_or_not(expected): +def _assert_yields_or_not(expected: bool) -> Iterator[None]: __tracebackhide__ = True task = _core.current_task() orig_cancel = task._cancel_points @@ -22,7 +23,7 @@ def _assert_yields_or_not(expected): raise AssertionError("assert_no_checkpoints block yielded!") -def assert_checkpoints(): +def assert_checkpoints() -> ContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block either exits with an exception or executes at least one :ref:`checkpoint `. @@ -42,7 +43,7 @@ def assert_checkpoints(): return _assert_yields_or_not(True) -def assert_no_checkpoints(): +def assert_no_checkpoints() -> ContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block does not execute any :ref:`checkpoints `. 
diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 99ad7dfcaf..fd65ed5f01 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -11,7 +11,7 @@ class _UnboundedByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._closed = False self._lot = _core.ParkingLot() @@ -23,11 +23,11 @@ def __init__(self): # channel: so after close(), calling put() raises ClosedResourceError, and # calling the get() variants drains the buffer and then returns an empty # bytearray. - def close(self): + def close(self) -> None: self._closed = True self._lot.unpark_all() - def close_and_wipe(self): + def close_and_wipe(self) -> None: self._data = bytearray() self.close() @@ -122,7 +122,7 @@ async def send_all(self, data): if self.send_all_hook is not None: await self.send_all_hook() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and then returns immediately. @@ -137,7 +137,7 @@ async def wait_send_all_might_not_block(self): if self.wait_send_all_might_not_block_hook is not None: await self.wait_send_all_might_not_block_hook() - def close(self): + def close(self) -> None: """Marks this stream as closed, and then calls the :attr:`close_hook` (if any). 
@@ -154,7 +154,7 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() @@ -204,7 +204,7 @@ class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final): """ - def __init__(self, receive_some_hook=None, close_hook=None): + def __init__(self, receive_some_hook=None, close_hook=None) -> None: self._conflict_detector = _util.ConflictDetector( "another task is using this stream" ) @@ -236,7 +236,7 @@ async def receive_some(self, max_bytes=None): raise _core.ClosedResourceError return data - def close(self): + def close(self) -> None: """Discards any pending data from the internal buffer, and marks this stream as closed. @@ -246,7 +246,7 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() @@ -255,7 +255,7 @@ def put_data(self, data): """Appends the given data to the internal buffer.""" self._incoming.put(data) - def put_eof(self): + def put_eof(self) -> None: """Adds an end-of-file marker to the internal buffer.""" self._incoming.close() @@ -320,10 +320,10 @@ def memory_stream_one_way_pair(): send_stream = MemorySendStream() recv_stream = MemoryReceiveStream() - def pump_from_send_stream_to_recv_stream(): + def pump_from_send_stream_to_recv_stream() -> None: memory_stream_pump(send_stream, recv_stream) - async def async_pump_from_send_stream_to_recv_stream(): + async def async_pump_from_send_stream_to_recv_stream() -> None: pump_from_send_stream_to_recv_stream() send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream @@ -422,7 +422,7 @@ async def receiver(): class _LockstepByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._sender_closed = False self._receiver_closed = False @@ -435,7 +435,7 @@ def 
__init__(self): "another task is already receiving" ) - def _something_happened(self): + def _something_happened(self) -> None: self._waiters.unpark_all() # Always wakes up when one side is closed, because everyone always reacts @@ -449,11 +449,11 @@ async def _wait_for(self, fn): await self._waiters.park() await _core.checkpoint() - def close_sender(self): + def close_sender(self) -> None: self._sender_closed = True self._something_happened() - def close_receiver(self): + def close_receiver(self) -> None: self._receiver_closed = True self._something_happened() @@ -516,31 +516,31 @@ async def receive_some(self, max_bytes=None): class _LockstepSendStream(SendStream): - def __init__(self, lbq): + def __init__(self, lbq) -> None: self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_sender() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() async def send_all(self, data): await self._lbq.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: await self._lbq.wait_send_all_might_not_block() class _LockstepReceiveStream(ReceiveStream): - def __init__(self, lbq): + def __init__(self, lbq) -> None: self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_receiver() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() diff --git a/trio/testing/_sequencer.py b/trio/testing/_sequencer.py index a7e6e50ff0..4a1d42eca8 100644 --- a/trio/testing/_sequencer.py +++ b/trio/testing/_sequencer.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import AsyncIterator, DefaultDict, Set import attr from async_generator import asynccontextmanager @@ -7,9 +8,6 @@ from .. import _util from .. 
import Event -if False: - from typing import DefaultDict, Set - @attr.s(eq=False, hash=False) class Sequencer(metaclass=_util.Final): @@ -59,7 +57,7 @@ async def main(): _broken = attr.ib(default=False, init=False) @asynccontextmanager - async def __call__(self, position: int): + async def __call__(self, position: int) -> AsyncIterator[None]: # type: ignore[misc] if position in self._claimed: raise RuntimeError("Attempted to re-use sequence point {}".format(position)) if self._broken: diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 4fcaeae372..d0bd8ab9d6 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -1,20 +1,25 @@ from functools import wraps, partial +from typing import Any, Callable, TypeVar from .. import _core from ..abc import Clock, Instrument +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + # Use: # # @trio_test -# async def test_whatever(): +# async def test_whatever() -> None: # await ... # # Also: if a pytest fixture is passed in that subclasses the Clock abc, then # that clock is passed to trio.run(). 
-def trio_test(fn): - @wraps(fn) - def wrapper(**kwargs): +def trio_test(fn: _Fn) -> _Fn: + wrapper: _Fn + + @wraps(fn) # type: ignore[no-redef] + def wrapper(**kwargs: object) -> object: __tracebackhide__ = True clocks = [c for c in kwargs.values() if isinstance(c, Clock)] if not clocks: diff --git a/trio/tests/conftest.py b/trio/tests/conftest.py index 772486e1eb..f7e92662d8 100644 --- a/trio/tests/conftest.py +++ b/trio/tests/conftest.py @@ -5,7 +5,9 @@ # this stuff should become a proper pytest plugin import pytest +import _pytest.python import inspect +from typing import Callable from ..testing import trio_test, MockClock @@ -22,12 +24,12 @@ def pytest_configure(config): @pytest.fixture -def mock_clock(): +def mock_clock() -> MockClock: return MockClock() @pytest.fixture -def autojump_clock(): +def autojump_clock() -> MockClock: return MockClock(autojump_threshold=0) @@ -36,6 +38,6 @@ def autojump_clock(): # guess it's useful with the class- and file-level marking machinery (where # the raw @trio_test decorator isn't enough). @pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): +def pytest_pyfunc_call(pyfuncitem: _pytest.python.Function) -> None: # type: ignore[misc] if inspect.iscoroutinefunction(pyfuncitem.obj): pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/trio/tests/module_with_deprecations.py b/trio/tests/module_with_deprecations.py index 73184d11e8..ed51f150c3 100644 --- a/trio/tests/module_with_deprecations.py +++ b/trio/tests/module_with_deprecations.py @@ -10,7 +10,7 @@ import sys this_mod = sys.modules[__name__] -assert this_mod.regular == "hi" +assert this_mod.regular == "hi" # type: ignore[attr-defined] assert not hasattr(this_mod, "dep1") __deprecated_attributes__ = { diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index c445c97103..3a114abaf0 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -6,12 +6,12 @@ from .. 
import abc as tabc -async def test_AsyncResource_defaults(): +async def test_AsyncResource_defaults() -> None: @attr.s class MyAR(tabc.AsyncResource): record = attr.ib(factory=list) - async def aclose(self): + async def aclose(self) -> None: self.record.append("ac") async with MyAR() as myar: @@ -21,7 +21,7 @@ async def aclose(self): assert myar.record == ["ac"] -def test_abc_generics(): +def test_abc_generics() -> None: # Pythons below 3.5.2 had a typing.Generic that would throw # errors when instantiating or subclassing a parameterized # version of a class with any __slots__. This is why RunVar @@ -38,10 +38,10 @@ def send_nowait(self, value): async def send(self, value): raise RuntimeError # pragma: no cover - def clone(self): + def clone(self) -> None: raise RuntimeError # pragma: no cover - async def aclose(self): + async def aclose(self) -> None: pass # pragma: no cover channel = SlottedChannel() diff --git a/trio/tests/test_channel.py b/trio/tests/test_channel.py index fd990fb3e3..f23b58c48c 100644 --- a/trio/tests/test_channel.py +++ b/trio/tests/test_channel.py @@ -5,7 +5,7 @@ from trio import open_memory_channel, EndOfChannel -async def test_channel(): +async def test_channel() -> None: with pytest.raises(TypeError): open_memory_channel(1.0) with pytest.raises(ValueError): @@ -48,7 +48,7 @@ async def test_channel(): await r.aclose() -async def test_553(autojump_clock): +async def test_553(autojump_clock) -> None: s, r = open_memory_channel(1) with trio.move_on_after(10) as timeout_scope: await r.receive() @@ -56,7 +56,7 @@ async def test_553(autojump_clock): await s.send("Test for PR #553") -async def test_channel_multiple_producers(): +async def test_channel_multiple_producers() -> None: async def producer(send_channel, i): # We close our handle when we're done with it async with send_channel: @@ -79,7 +79,7 @@ async def producer(send_channel, i): assert got == list(range(30)) -async def test_channel_multiple_consumers(): +async def 
test_channel_multiple_consumers() -> None: successful_receivers = set() received = [] @@ -102,7 +102,7 @@ async def consumer(receive_channel, i): assert set(received) == set(range(10)) -async def test_close_basics(): +async def test_close_basics() -> None: async def send_block(s, expect): with pytest.raises(expect): await s.send(None) @@ -157,7 +157,7 @@ async def receive_block(r): await r.receive() -async def test_close_sync(): +async def test_close_sync() -> None: async def send_block(s, expect): with pytest.raises(expect): await s.send(None) @@ -212,7 +212,7 @@ async def receive_block(r): await r.receive() -async def test_receive_channel_clone_and_close(): +async def test_receive_channel_clone_and_close() -> None: s, r = open_memory_channel(10) r2 = r.clone() @@ -239,17 +239,17 @@ async def test_receive_channel_clone_and_close(): s.send_nowait(None) -async def test_close_multiple_send_handles(): +async def test_close_multiple_send_handles() -> None: # With multiple send handles, closing one handle only wakes senders on # that handle, but others can continue just fine s1, r = open_memory_channel(0) s2 = s1.clone() - async def send_will_close(): + async def send_will_close() -> None: with pytest.raises(trio.ClosedResourceError): await s1.send("nope") - async def send_will_succeed(): + async def send_will_succeed() -> None: await s2.send("ok") async with trio.open_nursery() as nursery: @@ -260,17 +260,17 @@ async def send_will_succeed(): assert await r.receive() == "ok" -async def test_close_multiple_receive_handles(): +async def test_close_multiple_receive_handles() -> None: # With multiple receive handles, closing one handle only wakes receivers on # that handle, but others can continue just fine s, r1 = open_memory_channel(0) r2 = r1.clone() - async def receive_will_close(): + async def receive_will_close() -> None: with pytest.raises(trio.ClosedResourceError): await r1.receive() - async def receive_will_succeed(): + async def receive_will_succeed() -> None: 
assert await r2.receive() == "ok" async with trio.open_nursery() as nursery: @@ -281,7 +281,7 @@ async def receive_will_succeed(): await s.send("ok") -async def test_inf_capacity(): +async def test_inf_capacity() -> None: s, r = open_memory_channel(float("inf")) # It's accepted, and we can send all day without blocking @@ -295,7 +295,7 @@ async def test_inf_capacity(): assert got == list(range(10)) -async def test_statistics(): +async def test_statistics() -> None: s, r = open_memory_channel(2) assert s.statistics() == r.statistics() @@ -345,7 +345,7 @@ async def test_statistics(): assert s.statistics().tasks_waiting_receive == 0 -async def test_channel_fairness(): +async def test_channel_fairness() -> None: # We can remove an item we just sent, and send an item back in after, if # no-one else is waiting. @@ -388,7 +388,7 @@ async def do_receive(r): assert (await r.receive()) == 2 -async def test_unbuffered(): +async def test_unbuffered() -> None: s, r = open_memory_channel(0) with pytest.raises(trio.WouldBlock): r.receive_nowait() diff --git a/trio/tests/test_deprecate.py b/trio/tests/test_deprecate.py index e5e1da8c5f..22a0f433d8 100644 --- a/trio/tests/test_deprecate.py +++ b/trio/tests/test_deprecate.py @@ -1,4 +1,5 @@ import pytest +import _pytest.recwarn import inspect import warnings @@ -14,7 +15,9 @@ @pytest.fixture -def recwarn_always(recwarn): +def recwarn_always( + recwarn: _pytest.recwarn.WarningsRecorder, +) -> _pytest.recwarn.WarningsRecorder: warnings.simplefilter("always") # ResourceWarnings about unclosed sockets can occur nondeterministically # (during GC) which throws off the tests in this file @@ -27,14 +30,15 @@ def _here(): return (info.filename, info.lineno) -def test_warn_deprecated(recwarn_always): - def deprecated_thing(): +def test_warn_deprecated(recwarn_always: _pytest.recwarn.WarningsRecorder) -> None: + def deprecated_thing() -> None: warn_deprecated("ice", "1.2", issue=1, instead="water") deprecated_thing() filename, lineno = _here() 
assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "ice is deprecated" in got.message.args[0] assert "Trio 1.2" in got.message.args[0] assert "water instead" in got.message.args[0] @@ -43,21 +47,26 @@ def deprecated_thing(): assert got.lineno == lineno - 1 -def test_warn_deprecated_no_instead_or_issue(recwarn_always): +def test_warn_deprecated_no_instead_or_issue( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: # Explicitly no instead or issue warn_deprecated("water", "1.3", issue=None, instead=None) assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "water is deprecated" in got.message.args[0] assert "no replacement" in got.message.args[0] assert "Trio 1.3" in got.message.args[0] -def test_warn_deprecated_stacklevel(recwarn_always): - def nested1(): +def test_warn_deprecated_stacklevel( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: + def nested1() -> None: nested2() - def nested2(): + def nested2() -> None: warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3) filename, lineno = _here() @@ -67,29 +76,33 @@ def nested2(): assert got.lineno == lineno + 1 -def old(): # pragma: no cover +def old() -> None: # pragma: no cover pass -def new(): # pragma: no cover +def new() -> None: # pragma: no cover pass -def test_warn_deprecated_formatting(recwarn_always): +def test_warn_deprecated_formatting( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: warn_deprecated(old, "1.0", issue=1, instead=new) got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.old is deprecated" in got.message.args[0] assert "test_deprecate.new instead" in got.message.args[0] @deprecated("1.5", issue=123, instead=new) -def deprecated_old(): +def deprecated_old() -> int: return 3 -def test_deprecated_decorator(recwarn_always): 
+def test_deprecated_decorator(recwarn_always: _pytest.recwarn.WarningsRecorder) -> None: assert deprecated_old() == 3 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0] assert "1.5" in got.message.args[0] assert "test_deprecate.new" in got.message.args[0] @@ -98,25 +111,31 @@ def test_deprecated_decorator(recwarn_always): class Foo: @deprecated("1.0", issue=123, instead="crying") - def method(self): + def method(self) -> int: return 7 -def test_deprecated_decorator_method(recwarn_always): +def test_deprecated_decorator_method( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: f = Foo() assert f.method() == 7 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.Foo.method is deprecated" in got.message.args[0] @deprecated("1.2", thing="the thing", issue=None, instead=None) -def deprecated_with_thing(): +def deprecated_with_thing() -> int: return 72 -def test_deprecated_decorator_with_explicit_thing(recwarn_always): +def test_deprecated_decorator_with_explicit_thing( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: assert deprecated_with_thing() == 72 got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "the thing is deprecated" in got.message.args[0] @@ -127,14 +146,16 @@ def new_hotness(): old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1) -def test_deprecated_alias(recwarn_always): +def test_deprecated_alias(recwarn_always: _pytest.recwarn.WarningsRecorder) -> None: assert old_hotness() == "new hotness" got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "test_deprecate.old_hotness is deprecated" in got.message.args[0] assert "1.23" in got.message.args[0] assert "test_deprecate.new_hotness instead" in got.message.args[0] assert "issues/1" in 
got.message.args[0] + assert old_hotness.__doc__ is not None assert ".. deprecated:: 1.23" in old_hotness.__doc__ assert "test_deprecate.new_hotness instead" in old_hotness.__doc__ assert "issues/1>`__" in old_hotness.__doc__ @@ -149,36 +170,39 @@ def new_hotness_method(self): ) -def test_deprecated_alias_method(recwarn_always): +def test_deprecated_alias_method( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: obj = Alias() assert obj.old_hotness_method() == "new hotness method" got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) msg = got.message.args[0] assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg assert "test_deprecate.Alias.new_hotness_method instead" in msg @deprecated("2.1", issue=1, instead="hi") -def docstring_test1(): # pragma: no cover +def docstring_test1() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=None, instead="hi") -def docstring_test2(): # pragma: no cover +def docstring_test2() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=1, instead=None) -def docstring_test3(): # pragma: no cover +def docstring_test3() -> None: # pragma: no cover """Hello!""" @deprecated("2.1", issue=None, instead=None) -def docstring_test4(): # pragma: no cover +def docstring_test4() -> None: # pragma: no cover """Hello!""" -def test_deprecated_docstring_munging(): +def test_deprecated_docstring_munging() -> None: assert ( docstring_test1.__doc__ == """Hello! 
@@ -220,24 +244,28 @@ def test_deprecated_docstring_munging(): ) -def test_module_with_deprecations(recwarn_always): +def test_module_with_deprecations( + recwarn_always: _pytest.recwarn.WarningsRecorder, +) -> None: assert module_with_deprecations.regular == "hi" assert len(recwarn_always) == 0 filename, lineno = _here() - assert module_with_deprecations.dep1 == "value1" + assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined] got = recwarn_always.pop(TrioDeprecationWarning) assert got.filename == filename assert got.lineno == lineno + 1 + assert isinstance(got.message, Warning) assert "module_with_deprecations.dep1" in got.message.args[0] assert "Trio 1.1" in got.message.args[0] assert "/issues/1" in got.message.args[0] assert "value1 instead" in got.message.args[0] - assert module_with_deprecations.dep2 == "value2" + assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined] got = recwarn_always.pop(TrioDeprecationWarning) + assert isinstance(got.message, Warning) assert "instead-string instead" in got.message.args[0] with pytest.raises(AttributeError): - module_with_deprecations.asdf + module_with_deprecations.asdf # type: ignore [attr-defined] diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py index 374ce8c044..ba03f74c9c 100644 --- a/trio/tests/test_exports.py +++ b/trio/tests/test_exports.py @@ -13,7 +13,7 @@ from .. 
import _util -def test_core_is_properly_reexported(): +def test_core_is_properly_reexported() -> None: # Each export from _core should be re-exported by exactly one of these # three modules: sources = [trio, trio.lowlevel, trio.testing] @@ -69,7 +69,7 @@ def public_modules(module): ) @pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES) @pytest.mark.parametrize("tool", ["pylint", "jedi"]) -def test_static_tool_sees_all_symbols(tool, modname): +def test_static_tool_sees_all_symbols(tool: str, modname: str) -> None: module = importlib.import_module(modname) def no_underscores(symbols): @@ -113,7 +113,7 @@ def no_underscores(symbols): assert False -def test_classes_are_final(): +def test_classes_are_final() -> None: for module in PUBLIC_MODULES: for name, class_ in module.__dict__.items(): if not isinstance(class_, type): @@ -121,6 +121,9 @@ def test_classes_are_final(): # Deprecated classes are exported with a leading underscore if name.startswith("_"): # pragma: no cover continue + # TODO: fix RunVar as a generic to work in 3.6 + if name == "RunVar": + continue # Abstract classes can be subclassed, because that's the whole # point of ABCs diff --git a/trio/tests/test_file_io.py b/trio/tests/test_file_io.py index b40f7518a9..725693cd07 100644 --- a/trio/tests/test_file_io.py +++ b/trio/tests/test_file_io.py @@ -1,41 +1,53 @@ import io import os +from typing import Union +import py.path import pytest from unittest import mock from unittest.mock import sentinel import trio from trio import _core -from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS +from trio._file_io import ( + _AsyncTextIOBase, + _AsyncBufferedIOBase, + _AsyncRawIOBase, + _AsyncIOBase, + AsyncIOWrapper, + _FILE_SYNC_ATTRS, + _FILE_ASYNC_METHODS, +) @pytest.fixture -def path(tmpdir): +def path(tmpdir: py.path.local) -> str: return os.fspath(tmpdir.join("test")) @pytest.fixture -def wrapped(): +def wrapped() -> mock.Mock: return mock.Mock(spec_set=io.StringIO) 
@pytest.fixture -def async_file(wrapped): +def async_file( + wrapped: mock.Mock, +) -> Union[_AsyncTextIOBase, _AsyncBufferedIOBase, _AsyncRawIOBase, _AsyncIOBase]: return trio.wrap_file(wrapped) -def test_wrap_invalid(): +def test_wrap_invalid() -> None: with pytest.raises(TypeError): trio.wrap_file(str()) -def test_wrap_non_iobase(): +def test_wrap_non_iobase() -> None: class FakeFile: - def close(self): # pragma: no cover + def close(self) -> None: # pragma: no cover pass - def write(self): # pragma: no cover + def write(self) -> None: # pragma: no cover pass wrapped = FakeFile() @@ -50,11 +62,11 @@ def write(self): # pragma: no cover trio.wrap_file(FakeFile()) -def test_wrapped_property(async_file, wrapped): +def test_wrapped_property(async_file, wrapped) -> None: assert async_file.wrapped is wrapped -def test_dir_matches_wrapped(async_file, wrapped): +def test_dir_matches_wrapped(async_file, wrapped) -> None: attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) # all supported attrs in wrapped should be available in async_file @@ -65,9 +77,9 @@ def test_dir_matches_wrapped(async_file, wrapped): ) -def test_unsupported_not_forwarded(): +def test_unsupported_not_forwarded() -> None: class FakeFile(io.RawIOBase): - def unsupported_attr(self): # pragma: no cover + def unsupported_attr(self) -> None: # pragma: no cover pass async_file = trio.wrap_file(FakeFile()) @@ -78,7 +90,7 @@ def unsupported_attr(self): # pragma: no cover getattr(async_file, "unsupported_attr") -def test_sync_attrs_forwarded(async_file, wrapped): +def test_sync_attrs_forwarded(async_file, wrapped) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name not in dir(async_file): continue @@ -86,7 +98,7 @@ def test_sync_attrs_forwarded(async_file, wrapped): assert getattr(async_file, attr_name) is getattr(wrapped, attr_name) -def test_sync_attrs_match_wrapper(async_file, wrapped): +def test_sync_attrs_match_wrapper(async_file, wrapped) -> None: for attr_name in _FILE_SYNC_ATTRS: if attr_name in 
dir(async_file): continue @@ -98,7 +110,7 @@ def test_sync_attrs_match_wrapper(async_file, wrapped): getattr(wrapped, attr_name) -def test_async_methods_generated_once(async_file): +def test_async_methods_generated_once(async_file) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -106,7 +118,7 @@ def test_async_methods_generated_once(async_file): assert getattr(async_file, meth_name) is getattr(async_file, meth_name) -def test_async_methods_signature(async_file): +def test_async_methods_signature(async_file) -> None: # use read as a representative of all async methods assert async_file.read.__name__ == "read" assert async_file.read.__qualname__ == "AsyncIOWrapper.read" @@ -114,7 +126,7 @@ def test_async_methods_signature(async_file): assert "io.StringIO.read" in async_file.read.__doc__ -async def test_async_methods_wrap(async_file, wrapped): +async def test_async_methods_wrap(async_file, wrapped) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name not in dir(async_file): continue @@ -132,7 +144,7 @@ async def test_async_methods_wrap(async_file, wrapped): wrapped.reset_mock() -async def test_async_methods_match_wrapper(async_file, wrapped): +async def test_async_methods_match_wrapper(async_file, wrapped) -> None: for meth_name in _FILE_ASYNC_METHODS: if meth_name in dir(async_file): continue @@ -144,7 +156,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped): getattr(wrapped, meth_name) -async def test_open(path): +async def test_open(path) -> None: f = await trio.open_file(path, "w") assert isinstance(f, AsyncIOWrapper) @@ -152,7 +164,7 @@ async def test_open(path): await f.aclose() -async def test_open_context_manager(path): +async def test_open_context_manager(path) -> None: async with await trio.open_file(path, "w") as f: assert isinstance(f, AsyncIOWrapper) assert not f.closed @@ -160,7 +172,7 @@ async def test_open_context_manager(path): assert f.closed -async def test_async_iter(): +async 
def test_async_iter() -> None: async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar")) expected = list(async_file.wrapped) result = [] @@ -172,7 +184,7 @@ async def test_async_iter(): assert result == expected -async def test_aclose_cancelled(path): +async def test_aclose_cancelled(path) -> None: with _core.CancelScope() as cscope: f = await trio.open_file(path, "w") cscope.cancel() @@ -186,7 +198,7 @@ async def test_aclose_cancelled(path): assert f.closed -async def test_detach_rewraps_asynciobase(): +async def test_detach_rewraps_asynciobase() -> None: raw = io.BytesIO() buffered = io.BufferedReader(raw) diff --git a/trio/tests/test_highlevel_generic.py b/trio/tests/test_highlevel_generic.py index df2b2cecf7..33f8a7053e 100644 --- a/trio/tests/test_highlevel_generic.py +++ b/trio/tests/test_highlevel_generic.py @@ -13,10 +13,10 @@ class RecordSendStream(SendStream): async def send_all(self, data): self.record.append(("send_all", data)) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: self.record.append("wait_send_all_might_not_block") - async def aclose(self): + async def aclose(self) -> None: self.record.append("aclose") @@ -27,11 +27,11 @@ class RecordReceiveStream(ReceiveStream): async def receive_some(self, max_bytes=None): self.record.append(("receive_some", max_bytes)) - async def aclose(self): + async def aclose(self) -> None: self.record.append("aclose") -async def test_StapledStream(): +async def test_StapledStream() -> None: send_stream = RecordSendStream() receive_stream = RecordReceiveStream() stapled = StapledStream(send_stream, receive_stream) @@ -51,7 +51,7 @@ async def test_StapledStream(): assert send_stream.record == ["aclose"] send_stream.record.clear() - async def fake_send_eof(): + async def fake_send_eof() -> None: send_stream.record.append("send_eof") send_stream.send_eof = fake_send_eof @@ -71,16 +71,16 @@ async def fake_send_eof(): assert send_stream.record == ["aclose"] -async 
def test_StapledStream_with_erroring_close(): +async def test_StapledStream_with_erroring_close() -> None: # Make sure that if one of the aclose methods errors out, then the other # one still gets called. class BrokenSendStream(RecordSendStream): - async def aclose(self): + async def aclose(self) -> None: await super().aclose() raise ValueError class BrokenReceiveStream(RecordReceiveStream): - async def aclose(self): + async def aclose(self) -> None: await super().aclose() raise ValueError diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/tests/test_highlevel_open_tcp_listeners.py index d5fc576ec5..524a5ad74c 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/tests/test_highlevel_open_tcp_listeners.py @@ -2,6 +2,7 @@ import socket as stdlib_socket import errno +from typing import Set import attr @@ -9,10 +10,11 @@ from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream from trio.testing import open_stream_to_socket_listener from .. 
import socket as tsocket +from .._abc import SocketFactory from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6 -async def test_open_tcp_listeners_basic(): +async def test_open_tcp_listeners_basic() -> None: listeners = await open_tcp_listeners(0) assert isinstance(listeners, list) for obj in listeners: @@ -40,7 +42,7 @@ async def test_open_tcp_listeners_basic(): await resource.aclose() -async def test_open_tcp_listeners_specific_port_specific_host(): +async def test_open_tcp_listeners_specific_port_specific_host() -> None: # Pick a port sock = tsocket.socket() await sock.bind(("127.0.0.1", 0)) @@ -53,7 +55,7 @@ async def test_open_tcp_listeners_specific_port_specific_host(): @binds_ipv6 -async def test_open_tcp_listeners_ipv6_v6only(): +async def test_open_tcp_listeners_ipv6_v6only() -> None: # Check IPV6_V6ONLY is working properly (ipv6_listener,) = await open_tcp_listeners(0, host="::1") async with ipv6_listener: @@ -63,7 +65,7 @@ async def test_open_tcp_listeners_ipv6_v6only(): await open_tcp_stream("127.0.0.1", port) -async def test_open_tcp_listeners_rebind(): +async def test_open_tcp_listeners_rebind() -> None: (l1,) = await open_tcp_listeners(0, host="127.0.0.1") sockaddr1 = l1.socket.getsockname() @@ -136,12 +138,12 @@ def listen(self, backlog): if self.poison_listen: raise FakeOSError("whoops") - def close(self): + def close(self) -> None: self.closed = True @attr.s -class FakeSocketFactory: +class FakeSocketFactory(SocketFactory): poison_after = attr.ib() sockets = attr.ib(factory=list) raise_on_family = attr.ib(factory=dict) # family => errno @@ -168,7 +170,7 @@ async def getaddrinfo(self, host, port, family, type, proto, flags): ] -async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): +async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None: # If we were trying to bind to multiple hosts and one of them failed, they # call get cleaned up before returning fsf = FakeSocketFactory(3) @@ -191,7 +193,7 @@ async def 
test_open_tcp_listeners_multiple_host_cleanup_on_error(): assert sock.closed -async def test_open_tcp_listeners_port_checking(): +async def test_open_tcp_listeners_port_checking() -> None: for host in ["127.0.0.1", None]: with pytest.raises(TypeError): await open_tcp_listeners(None, host=host) @@ -201,7 +203,7 @@ async def test_open_tcp_listeners_port_checking(): await open_tcp_listeners("http", host=host) -async def test_serve_tcp(): +async def test_serve_tcp() -> None: async def handler(stream): await stream.send_all(b"x") @@ -222,8 +224,9 @@ async def handler(stream): [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], ) async def test_open_tcp_listeners_some_address_families_unavailable( - try_families, fail_families -): + try_families: Set[stdlib_socket.AddressFamily], + fail_families: Set[stdlib_socket.AddressFamily], +) -> None: fsf = FakeSocketFactory( 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} ) @@ -252,7 +255,7 @@ async def test_open_tcp_listeners_some_address_families_unavailable( assert not should_succeed -async def test_open_tcp_listeners_socket_fails_not_afnosupport(): +async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None: fsf = FakeSocketFactory( 10, raise_on_family={ @@ -280,7 +283,7 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport(): # effectively is no backlog), sometimes the host might not be enough resources # to give us the full requested backlog... it was a mess. So now we just check # that the backlog argument is passed through correctly. 
-async def test_open_tcp_listeners_backlog(): +async def test_open_tcp_listeners_backlog() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) for (given, expected) in [ diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/tests/test_highlevel_open_tcp_stream.py index bcd3ef7f5a..7feefed5be 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/tests/test_highlevel_open_tcp_stream.py @@ -14,15 +14,15 @@ ) -def test_close_all(): +def test_close_all() -> None: class CloseMe: closed = False - def close(self): + def close(self) -> None: self.closed = True class CloseKiller: - def close(self): + def close(self) -> None: raise OSError c = CloseMe() @@ -45,7 +45,7 @@ def close(self): assert c.closed -def test_reorder_for_rfc_6555_section_5_4(): +def test_reorder_for_rfc_6555_section_5_4() -> None: def fake4(i): return ( AF_INET, @@ -82,7 +82,7 @@ def fake6(i): assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)] -def test_format_host_port(): +def test_format_host_port() -> None: assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port("example.com", 443) == "example.com:443" @@ -92,7 +92,7 @@ def test_format_host_port(): # Make sure we can connect to localhost using real kernel sockets -async def test_open_tcp_stream_real_socket_smoketest(): +async def test_open_tcp_stream_real_socket_smoketest() -> None: listen_sock = trio.socket.socket() await listen_sock.bind(("127.0.0.1", 0)) _, listen_port = listen_sock.getsockname() @@ -107,7 +107,7 @@ async def test_open_tcp_stream_real_socket_smoketest(): listen_sock.close() -async def test_open_tcp_stream_input_validation(): +async def test_open_tcp_stream_input_validation() -> None: with pytest.raises(ValueError): await open_tcp_stream(None, 80) with pytest.raises(TypeError): @@ -123,7 +123,7 @@ def can_bind_127_0_0_2(): return s.getsockname()[0] == "127.0.0.2" -async def 
test_local_address_real(): +async def test_local_address_real() -> None: with trio.socket.socket() as listener: await listener.bind(("127.0.0.1", 0)) listener.listen() @@ -200,7 +200,7 @@ async def connect(self, sockaddr): self.failing = True self.succeeded = True - def close(self): + def close(self) -> None: self.closed = True # called when SocketStream is constructed @@ -212,7 +212,7 @@ def setsockopt(self, *args, **kwargs): class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver): - def __init__(self, port, ip_list, supported_families): + def __init__(self, port, ip_list, supported_families) -> None: # ip_list have to be unique ip_order = [ip for (ip, _, _) in ip_list] assert len(set(ip_order)) == len(ip_list) @@ -310,19 +310,19 @@ async def run_scenario( return (exc, scenario) -async def test_one_host_quick_success(autojump_clock): +async def test_one_host_quick_success(autojump_clock) -> None: sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) assert sock.ip == "1.2.3.4" assert trio.current_time() == 0.123 -async def test_one_host_slow_success(autojump_clock): +async def test_one_host_slow_success(autojump_clock) -> None: sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) assert sock.ip == "1.2.3.4" assert trio.current_time() == 100 -async def test_one_host_quick_fail(autojump_clock): +async def test_one_host_quick_fail(autojump_clock) -> None: exc, scenario = await run_scenario( 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError ) @@ -330,7 +330,7 @@ async def test_one_host_quick_fail(autojump_clock): assert trio.current_time() == 0.123 -async def test_one_host_slow_fail(autojump_clock): +async def test_one_host_slow_fail(autojump_clock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 100, "error")], expect_error=OSError ) @@ -338,7 +338,7 @@ async def test_one_host_slow_fail(autojump_clock): assert trio.current_time() == 100 -async def test_one_host_failed_after_connect(autojump_clock): 
+async def test_one_host_failed_after_connect(autojump_clock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt ) @@ -346,7 +346,7 @@ async def test_one_host_failed_after_connect(autojump_clock): # With the default 0.250 second delay, the third attempt will win -async def test_basic_fallthrough(autojump_clock): +async def test_basic_fallthrough(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -365,7 +365,7 @@ async def test_basic_fallthrough(autojump_clock): } -async def test_early_success(autojump_clock): +async def test_early_success(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -384,7 +384,7 @@ async def test_early_success(autojump_clock): # With a 0.450 second delay, the first attempt will win -async def test_custom_delay(autojump_clock): +async def test_custom_delay(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -403,7 +403,7 @@ async def test_custom_delay(autojump_clock): } -async def test_custom_errors_expedite(autojump_clock): +async def test_custom_errors_expedite(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -424,7 +424,7 @@ async def test_custom_errors_expedite(autojump_clock): } -async def test_all_fail(autojump_clock): +async def test_all_fail(autojump_clock) -> None: exc, scenario = await run_scenario( 80, [ @@ -447,7 +447,7 @@ async def test_all_fail(autojump_clock): } -async def test_multi_success(autojump_clock): +async def test_multi_success(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -477,7 +477,7 @@ async def test_multi_success(autojump_clock): } -async def test_does_reorder(autojump_clock): +async def test_does_reorder(autojump_clock) -> None: sock, scenario = await run_scenario( 80, [ @@ -497,7 +497,7 @@ async def test_does_reorder(autojump_clock): } -async def test_handles_no_ipv4(autojump_clock): +async def test_handles_no_ipv4(autojump_clock) -> None: 
sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -519,7 +519,7 @@ async def test_handles_no_ipv4(autojump_clock): } -async def test_handles_no_ipv6(autojump_clock): +async def test_handles_no_ipv6(autojump_clock) -> None: sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -541,12 +541,12 @@ async def test_handles_no_ipv6(autojump_clock): } -async def test_no_hosts(autojump_clock): +async def test_no_hosts(autojump_clock) -> None: exc, scenario = await run_scenario(80, [], expect_error=OSError) assert "no results found" in str(exc) -async def test_cancel(autojump_clock): +async def test_cancel(autojump_clock) -> None: with trio.move_on_after(5) as cancel_scope: exc, scenario = await run_scenario( 80, diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/tests/test_highlevel_open_unix_stream.py index 211aff3e70..5ab40e2e12 100644 --- a/trio/tests/test_highlevel_open_unix_stream.py +++ b/trio/tests/test_highlevel_open_unix_stream.py @@ -11,11 +11,11 @@ pytestmark = pytest.mark.skip("Needs unix socket support") -def test_close_on_error(): +def test_close_on_error() -> None: class CloseMe: closed = False - def close(self): + def close(self) -> None: self.closed = True with close_on_error(CloseMe()) as c: @@ -29,12 +29,12 @@ def close(self): @pytest.mark.parametrize("filename", [4, 4.5]) -async def test_open_with_bad_filename_type(filename): +async def test_open_with_bad_filename_type(filename: float) -> None: with pytest.raises(TypeError): await open_unix_socket(filename) -async def test_open_bad_socket(): +async def test_open_bad_socket() -> None: # mktemp is marked as insecure, but that's okay, we don't want the file to # exist name = tempfile.mktemp() @@ -42,7 +42,7 @@ async def test_open_bad_socket(): await open_unix_socket(name) -async def test_open_unix_socket(): +async def test_open_unix_socket() -> None: for name_type in 
[Path, str]: name = tempfile.mktemp() serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/tests/test_highlevel_serve_listeners.py index b028092eb9..52deb42d5a 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/tests/test_highlevel_serve_listeners.py @@ -2,18 +2,24 @@ from functools import partial import errno +from typing import Tuple import attr import trio from trio.testing import memory_stream_pair, wait_all_tasks_blocked +from trio._channel import MemorySendChannel, MemoryReceiveChannel + + +def _open_memory_channel_1() -> Tuple[MemorySendChannel, MemoryReceiveChannel]: + return trio.open_memory_channel(1) @attr.s(hash=False, eq=False) class MemoryListener(trio.abc.Listener): closed = attr.ib(default=False) accepted_streams = attr.ib(factory=list) - queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1))) + queued_streams = attr.ib(factory=_open_memory_channel_1) accept_hook = attr.ib(default=None) async def connect(self): @@ -31,17 +37,17 @@ async def accept(self): self.accepted_streams.append(stream) return stream - async def aclose(self): + async def aclose(self) -> None: self.closed = True await trio.lowlevel.checkpoint() -async def test_serve_listeners_basic(): +async def test_serve_listeners_basic() -> None: listeners = [MemoryListener(), MemoryListener()] record = [] - def close_hook(): + def close_hook() -> None: # Make sure this is a forceful close assert trio.current_effective_deadline() == float("-inf") record.append("closed") @@ -81,11 +87,11 @@ async def do_tests(parent_nursery): assert listener.closed -async def test_serve_listeners_accept_unrecognized_error(): +async def test_serve_listeners_accept_unrecognized_error() -> None: for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]: listener = MemoryListener() - async def raise_error(): + async def raise_error() -> None: raise error listener.accept_hook = raise_error @@ 
-95,10 +101,10 @@ async def raise_error(): assert excinfo.value is error -async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog): +async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog) -> None: listener = MemoryListener() - async def raise_EMFILE(): + async def raise_EMFILE() -> None: raise OSError(errno.EMFILE, "out of file descriptors") listener.accept_hook = raise_EMFILE @@ -114,7 +120,7 @@ async def raise_EMFILE(): assert record.exc_info[1].errno == errno.EMFILE -async def test_serve_listeners_connection_nursery(autojump_clock): +async def test_serve_listeners_connection_nursery(autojump_clock) -> None: listener = MemoryListener() async def handler(stream): diff --git a/trio/tests/test_highlevel_socket.py b/trio/tests/test_highlevel_socket.py index 9dcb834d2c..894c99e403 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/tests/test_highlevel_socket.py @@ -14,7 +14,7 @@ from .. import socket as tsocket -async def test_SocketStream_basics(): +async def test_SocketStream_basics() -> None: # stdlib socket bad (even if connected) a, b = stdlib_socket.socketpair() with a, b: @@ -52,7 +52,7 @@ async def test_SocketStream_basics(): assert isinstance(b, bytes) -async def test_SocketStream_send_all(): +async def test_SocketStream_send_all() -> None: BIG = 10000000 a_sock, b_sock = tsocket.socketpair() @@ -63,7 +63,7 @@ async def test_SocketStream_send_all(): # Check a send_all that has to be split into multiple parts (on most # platforms... on Windows every send() either succeeds or fails as a # whole) - async def sender(): + async def sender() -> None: data = bytearray(BIG) await a.send_all(data) # send_all uses memoryviews internally, which temporarily "lock" @@ -87,7 +87,7 @@ async def sender(): # and we break our implementation of send_all, then we'll get some # early warning...) 
- async def receiver(): + async def receiver() -> None: # Make sure the sender fills up the kernel buffers and blocks await wait_all_tasks_blocked() nbytes = 0 @@ -108,7 +108,7 @@ async def receiver(): async def fill_stream(s): - async def sender(): + async def sender() -> None: while True: await s.send_all(b"x" * 10000) @@ -121,7 +121,7 @@ async def waiter(nursery): nursery.start_soon(waiter, nursery) -async def test_SocketStream_generic(): +async def test_SocketStream_generic() -> None: async def stream_maker(): left, right = tsocket.socketpair() return SocketStream(left), SocketStream(right) @@ -135,7 +135,7 @@ async def clogged_stream_maker(): await check_half_closeable_stream(stream_maker, clogged_stream_maker) -async def test_SocketListener(): +async def test_SocketListener() -> None: # Not a Trio socket with stdlib_socket.socket() as s: s.bind(("127.0.0.1", 0)) @@ -188,7 +188,7 @@ async def test_SocketListener(): await server_stream.aclose() -async def test_SocketListener_socket_closed_underfoot(): +async def test_SocketListener_socket_closed_underfoot() -> None: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(10) @@ -203,9 +203,9 @@ async def test_SocketListener_socket_closed_underfoot(): await listener.accept() -async def test_SocketListener_accept_errors(): +async def test_SocketListener_accept_errors() -> None: class FakeSocket(tsocket.SocketType): - def __init__(self, events): + def __init__(self, events) -> None: self._events = iter(events) type = tsocket.SOCK_STREAM @@ -257,7 +257,7 @@ async def accept(self): assert s.socket is fake_server_sock -async def test_socket_stream_works_when_peer_has_already_closed(): +async def test_socket_stream_works_when_peer_has_already_closed() -> None: sock_a, sock_b = tsocket.socketpair() with sock_a, sock_b: await sock_b.send(b"x") diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/tests/test_highlevel_ssl_helpers.py index c00f5dc464..e9cee31cdc 100644 --- 
a/trio/tests/test_highlevel_ssl_helpers.py +++ b/trio/tests/test_highlevel_ssl_helpers.py @@ -43,7 +43,9 @@ async def getnameinfo(self, *args): # pragma: no cover # This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... # noqa is needed because flake8 doesn't understand how pytest fixtures work. -async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811 +async def test_open_ssl_over_tcp_stream_and_everything_else( + client_ctx, # noqa: F811 +) -> None: async with trio.open_nursery() as nursery: (listener,) = await nursery.start( partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") @@ -96,7 +98,7 @@ async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa nursery.cancel_scope.cancel() -async def test_open_ssl_over_tcp_listeners(): +async def test_open_ssl_over_tcp_listeners() -> None: (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1") async with listener: assert isinstance(listener, trio.SSLListener) diff --git a/trio/tests/test_path.py b/trio/tests/test_path.py index 284bcf82dd..a6bb53cecc 100644 --- a/trio/tests/test_path.py +++ b/trio/tests/test_path.py @@ -1,15 +1,17 @@ import os import pathlib +from typing import Callable, Type, Union +import py.path import pytest import trio -from trio._path import AsyncAutoWrapperType as Type +from trio._path import AsyncAutoWrapperType as WrapperType from trio._file_io import AsyncIOWrapper @pytest.fixture -def path(tmpdir): +def path(tmpdir: py.path.local) -> trio.Path: p = str(tmpdir.join("test")) return trio.Path(p) @@ -20,14 +22,14 @@ def method_pair(path, method_name): return getattr(path, method_name), getattr(async_path, method_name) -async def test_open_is_async_context_manager(path): +async def test_open_is_async_context_manager(path) -> None: async with await path.open("w") as f: assert isinstance(f, AsyncIOWrapper) assert f.closed -async def test_magic(): +async def test_magic() -> None: path = 
trio.Path("test") assert str(path) == "test" @@ -42,7 +44,10 @@ async def test_magic(): @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_cmp_magic(cls_a, cls_b): +async def test_cmp_magic( + cls_a: Union[Type[pathlib.Path], Type[trio.Path]], + cls_b: Union[Type[pathlib.Path], Type[trio.Path]], +) -> None: a, b = cls_a(""), cls_b("") assert a == b assert not a != b @@ -69,7 +74,10 @@ async def test_cmp_magic(cls_a, cls_b): @pytest.mark.parametrize("cls_a,cls_b", cls_pairs) -async def test_div_magic(cls_a, cls_b): +async def test_div_magic( + cls_a: Union[Type[pathlib.Path], Type[trio.Path]], + cls_b: Union[Type[pathlib.Path], Type[trio.Path]], +) -> None: a, b = cls_a("a"), cls_b("b") result = a / b @@ -81,19 +89,23 @@ async def test_div_magic(cls_a, cls_b): "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] ) @pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) -async def test_hash_magic(cls_a, cls_b, path): +async def test_hash_magic( + cls_a: Union[Type[pathlib.Path], Type[trio.Path]], + cls_b: Union[Type[pathlib.Path], Type[trio.Path]], + path: str, +) -> None: a, b = cls_a(path), cls_b(path) assert hash(a) == hash(b) -async def test_forwarded_properties(path): +async def test_forwarded_properties(path) -> None: # use `name` as a representative of forwarded properties assert "name" in dir(path) assert path.name == "test" -async def test_async_method_signature(path): +async def test_async_method_signature(path) -> None: # use `resolve` as a representative of wrapped methods assert path.resolve.__name__ == "resolve" @@ -103,7 +115,7 @@ async def test_async_method_signature(path): @pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) -async def test_compare_async_stat_methods(method_name): +async def test_compare_async_stat_methods(method_name: str) -> None: method, async_method = method_pair(".", method_name) @@ -113,13 +125,13 @@ async def test_compare_async_stat_methods(method_name): assert result == 
async_result -async def test_invalid_name_not_wrapped(path): +async def test_invalid_name_not_wrapped(path) -> None: with pytest.raises(AttributeError): getattr(path, "invalid_fake_attr") @pytest.mark.parametrize("method_name", ["absolute", "resolve"]) -async def test_async_methods_rewrap(method_name): +async def test_async_methods_rewrap(method_name: str) -> None: method, async_method = method_pair(".", method_name) @@ -130,7 +142,7 @@ async def test_async_methods_rewrap(method_name): assert str(result) == str(async_result) -async def test_forward_methods_rewrap(path, tmpdir): +async def test_forward_methods_rewrap(path, tmpdir) -> None: with_name = path.with_name("foo") with_suffix = path.with_suffix(".py") @@ -140,17 +152,17 @@ async def test_forward_methods_rewrap(path, tmpdir): assert with_suffix == tmpdir.join("test.py") -async def test_forward_properties_rewrap(path): +async def test_forward_properties_rewrap(path) -> None: assert isinstance(path.parent, trio.Path) -async def test_forward_methods_without_rewrap(path, tmpdir): +async def test_forward_methods_without_rewrap(path, tmpdir) -> None: path = await path.parent.resolve() assert path.as_uri().startswith("file:///") -async def test_repr(): +async def test_repr() -> None: path = trio.Path(".") assert repr(path) == "trio.Path('.')" @@ -166,30 +178,30 @@ class MockWrapper: _wraps = MockWrapped -async def test_type_forwards_unsupported(): +async def test_type_forwards_unsupported() -> None: with pytest.raises(TypeError): - Type.generate_forwards(MockWrapper, {}) + WrapperType.generate_forwards(MockWrapper, {}) -async def test_type_wraps_unsupported(): +async def test_type_wraps_unsupported() -> None: with pytest.raises(TypeError): - Type.generate_wraps(MockWrapper, {}) + WrapperType.generate_wraps(MockWrapper, {}) -async def test_type_forwards_private(): - Type.generate_forwards(MockWrapper, {"unsupported": None}) +async def test_type_forwards_private() -> None: + WrapperType.generate_forwards(MockWrapper, 
{"unsupported": None}) assert not hasattr(MockWrapper, "_private") -async def test_type_wraps_private(): - Type.generate_wraps(MockWrapper, {"unsupported": None}) +async def test_type_wraps_private() -> None: + WrapperType.generate_wraps(MockWrapper, {"unsupported": None}) assert not hasattr(MockWrapper, "_private") @pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) -async def test_path_wraps_path(path, meth): +async def test_path_wraps_path(path: trio.Path, meth: Callable[..., trio.Path]) -> None: # type: ignore[misc] wrapped = await path.absolute() result = meth(path, wrapped) if result is None: @@ -198,17 +210,17 @@ async def test_path_wraps_path(path, meth): assert wrapped == result -async def test_path_nonpath(): +async def test_path_nonpath() -> None: with pytest.raises(TypeError): trio.Path(1) -async def test_open_file_can_open_path(path): +async def test_open_file_can_open_path(path) -> None: async with await trio.open_file(path, "w") as f: assert f.name == os.fspath(path) -async def test_globmethods(path): +async def test_globmethods(path) -> None: # Populate a directory tree await path.mkdir() await (path / "foo").mkdir() @@ -237,7 +249,7 @@ async def test_globmethods(path): assert entries == {"_bar.txt", "bar.txt"} -async def test_iterdir(path): +async def test_iterdir(path) -> None: # Populate a directory await path.mkdir() await (path / "foo").mkdir() @@ -251,7 +263,7 @@ async def test_iterdir(path): assert entries == {"bar.txt", "foo"} -async def test_classmethods(): +async def test_classmethods() -> None: assert isinstance(await trio.Path.home(), trio.Path) # pathlib.Path has only two classmethods diff --git a/trio/tests/test_scheduler_determinism.py b/trio/tests/test_scheduler_determinism.py index e2d3167e45..d52c066de8 100644 --- a/trio/tests/test_scheduler_determinism.py +++ b/trio/tests/test_scheduler_determinism.py @@ -17,7 +17,7 @@ async def tracer(name): return tuple(trace) -def 
test_the_trio_scheduler_is_not_deterministic(): +def test_the_trio_scheduler_is_not_deterministic() -> None: # At least, not yet. See https://github.com/python-trio/trio/issues/32 traces = [] for _ in range(10): @@ -25,7 +25,7 @@ def test_the_trio_scheduler_is_not_deterministic(): assert len(set(traces)) == len(traces) -def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch): +def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch) -> None: monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) traces = [] for _ in range(10): diff --git a/trio/tests/test_signals.py b/trio/tests/test_signals.py index 235772f900..5006444f18 100644 --- a/trio/tests/test_signals.py +++ b/trio/tests/test_signals.py @@ -8,7 +8,7 @@ from .._signals import open_signal_receiver, _signal_handler -async def test_open_signal_receiver(): +async def test_open_signal_receiver() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL) as receiver: # Raise it a few times, to exercise signal coalescing, both at the @@ -32,7 +32,7 @@ async def test_open_signal_receiver(): assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): +async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None: orig = signal.getsignal(signal.SIGILL) with pytest.raises(ValueError): with open_signal_receiver(signal.SIGILL, 1234567): @@ -41,13 +41,13 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_empty_fail(): +async def test_open_signal_receiver_empty_fail() -> None: with pytest.raises(TypeError, match="No signals were provided"): with open_signal_receiver(): pass -async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): +async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None: orig = 
signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL, signal.SIGILL): pass @@ -55,8 +55,8 @@ async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_catch_signals_wrong_thread(): - async def naughty(): +async def test_catch_signals_wrong_thread() -> None: + async def naughty() -> None: with open_signal_receiver(signal.SIGINT): pass # pragma: no cover @@ -64,7 +64,7 @@ async def naughty(): await trio.to_thread.run_sync(trio.run, naughty) -async def test_open_signal_receiver_conflict(): +async def test_open_signal_receiver_conflict() -> None: with pytest.raises(trio.BusyResourceError): with open_signal_receiver(signal.SIGILL) as receiver: async with trio.open_nursery() as nursery: @@ -74,14 +74,14 @@ async def test_open_signal_receiver_conflict(): # Blocks until all previous calls to run_sync_soon(idempotent=True) have been # processed. -async def wait_run_sync_soon_idempotent_queue_barrier(): +async def wait_run_sync_soon_idempotent_queue_barrier() -> None: ev = trio.Event() token = _core.current_trio_token() token.run_sync_soon(ev.set, idempotent=True) await ev.wait() -async def test_open_signal_receiver_no_starvation(): +async def test_open_signal_receiver_no_starvation() -> None: # Set up a situation where there are always 2 pending signals available to # report, and make sure that instead of getting the same signal reported # over and over, it alternates between reporting both of them. 
@@ -112,7 +112,7 @@ async def test_open_signal_receiver_no_starvation(): traceback.print_exc() -async def test_catch_signals_race_condition_on_exit(): +async def test_catch_signals_race_condition_on_exit() -> None: delivered_directly = set() def direct_handler(signo, frame): diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index f8c061ffd3..2a065bbd96 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -1,10 +1,12 @@ import errno import pytest +import _pytest.monkeypatch import attr import os import socket as stdlib_socket +import sys import inspect import tempfile import sys as _sys @@ -21,7 +23,7 @@ class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo): + def __init__(self, orig_getaddrinfo) -> None: self._orig_getaddrinfo = orig_getaddrinfo self._responses = {} self.record = [] @@ -50,13 +52,13 @@ def getaddrinfo(self, *args, **kwargs): @pytest.fixture -def monkeygai(monkeypatch): +def monkeygai(monkeypatch: _pytest.monkeypatch.MonkeyPatch) -> MonkeypatchedGAI: controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo) monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo) return controller -async def test__try_sync(): +async def test__try_sync() -> None: with assert_checkpoints(): async with _try_sync(): pass @@ -86,7 +88,7 @@ def _is_ValueError(exc): ################################################################ -def test_socket_has_some_reexports(): +def test_socket_has_some_reexports() -> None: assert tsocket.SOL_SOCKET == stdlib_socket.SOL_SOCKET assert tsocket.TCP_NODELAY == stdlib_socket.TCP_NODELAY assert tsocket.gaierror == stdlib_socket.gaierror @@ -98,7 +100,7 @@ def test_socket_has_some_reexports(): ################################################################ -async def test_getaddrinfo(monkeygai): +async def test_getaddrinfo(monkeygai) -> None: def check(got, expected): # win32 returns 0 for the proto field # musl and glibc have inconsistent handling of the canonical name @@ 
-174,7 +176,7 @@ def filtered(gai_list): await tsocket.getaddrinfo("asdf", "12345") -async def test_getnameinfo(): +async def test_getnameinfo() -> None: # Trivial test: ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV with assert_checkpoints(): @@ -209,7 +211,7 @@ async def test_getnameinfo(): ################################################################ -async def test_from_stdlib_socket(): +async def test_from_stdlib_socket() -> None: sa, sb = stdlib_socket.socketpair() assert not isinstance(sa, tsocket.SocketType) with sa, sb: @@ -231,7 +233,7 @@ class MySocket(stdlib_socket.socket): tsocket.from_stdlib_socket(mysock) -async def test_from_fd(): +async def test_from_fd() -> None: sa, sb = stdlib_socket.socketpair() ta = tsocket.fromfd(sa.fileno(), sa.family, sa.type, sa.proto) with sa, sb, ta: @@ -240,7 +242,7 @@ async def test_from_fd(): assert sb.recv(3) == b"x" -async def test_socketpair_simple(): +async def test_socketpair_simple() -> None: async def child(sock): print("sending hello") await sock.send(b"h") @@ -254,33 +256,39 @@ async def child(sock): @pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") -async def test_fromshare(): - a, b = tsocket.socketpair() - with a, b: - # share with ourselves - shared = a.share(os.getpid()) - a2 = tsocket.fromshare(shared) - with a2: - assert a.fileno() != a2.fileno() - await a2.send(b"x") - assert await b.recv(1) == b"x" - - -async def test_socket(): +async def test_fromshare() -> None: + if sys.platform != "win32": + # mypy doesn't recognize the pytest.mark.skipif and ignores an assert inside + # this function. 
+ # https://github.com/python/mypy/issues/9025 + assert False # we should have been skipped, if not then fail + else: + a, b = tsocket.socketpair() + with a, b: + # share with ourselves + shared = a.share(os.getpid()) + a2 = tsocket.fromshare(shared) + with a2: + assert a.fileno() != a2.fileno() + await a2.send(b"x") + assert await b.recv(1) == b"x" + + +async def test_socket() -> None: with tsocket.socket() as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET @creates_ipv6 -async def test_socket_v6(): +async def test_socket_v6() -> None: with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET6 @pytest.mark.skipif(not _sys.platform == "linux", reason="linux only") -async def test_sniff_sockopts(): +async def test_sniff_sockopts() -> None: from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM # generate the combinations of families/types we're testing: @@ -311,7 +319,7 @@ async def test_sniff_sockopts(): ################################################################ -async def test_SocketType_basics(): +async def test_SocketType_basics() -> None: sock = tsocket.socket() with sock as cm_enter_value: assert cm_enter_value is sock @@ -362,7 +370,7 @@ async def test_SocketType_basics(): sock.close() -async def test_SocketType_dup(): +async def test_SocketType_dup() -> None: a, b = tsocket.socketpair() with a, b: a2 = a.dup() @@ -374,7 +382,7 @@ async def test_SocketType_dup(): assert await b.recv(1) == b"x" -async def test_SocketType_shutdown(): +async def test_SocketType_shutdown() -> None: a, b = tsocket.socketpair() with a, b: await a.send(b"x") @@ -408,7 +416,9 @@ async def test_SocketType_shutdown(): pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), ], ) -async def test_SocketType_simple_server(address, socket_type): +async def test_SocketType_simple_server( + address: str, socket_type: stdlib_socket.AddressFamily +) -> None: # listen, 
bind, accept, connect, getpeername, getsockname listener = tsocket.socket(socket_type) client = tsocket.socket(socket_type) @@ -425,7 +435,7 @@ async def test_SocketType_simple_server(address, socket_type): assert await client.recv(1) == b"x" -async def test_SocketType_is_readable(): +async def test_SocketType_is_readable() -> None: a, b = tsocket.socketpair() with a, b: assert not a.is_readable() @@ -481,7 +491,9 @@ class Addresses: ), ], ) -async def test_SocketType_resolve(socket_type, addrs): +async def test_SocketType_resolve( + socket_type: stdlib_socket.AddressFamily, addrs: Addresses +) -> None: v6 = socket_type == tsocket.AF_INET6 def pad(addr): @@ -498,9 +510,9 @@ def assert_eq(actual, expected): # getaddrinfo They also error out on None, but whatever, None is much # more consistent, so we accept it too. for null in [None, ""]: - got = await sock._resolve_local_address_nocp((null, 80)) + got = await sock._resolve_local_address_nocp((null, 80)) # type: ignore[attr-defined] assert_eq(got, (addrs.bind_all, 80)) - got = await sock._resolve_remote_address_nocp((null, 80)) + got = await sock._resolve_remote_address_nocp((null, 80)) # type: ignore[attr-defined] assert_eq(got, (addrs.localhost, 80)) # AI_PASSIVE only affects the wildcard address, so for everything else @@ -574,7 +586,7 @@ async def res(*args): await res(("1.2.3.4", 80, 0, 0)) -async def test_SocketType_unresolved_names(): +async def test_SocketType_unresolved_names() -> None: with tsocket.socket() as sock: await sock.bind(("localhost", 0)) assert sock.getsockname()[0] == "127.0.0.1" @@ -593,7 +605,7 @@ async def test_SocketType_unresolved_names(): # This tests all the complicated paths through _nonblocking_helper, using recv # as a stand-in for all the methods that use _nonblocking_helper. 
-async def test_SocketType_non_blocking_paths(): +async def test_SocketType_non_blocking_paths() -> None: a, b = stdlib_socket.socketpair() with a, b: ta = tsocket.from_stdlib_socket(a) @@ -616,7 +628,7 @@ async def test_SocketType_non_blocking_paths(): await ta.recv("haha") # block then succeed - async def do_successful_blocking_recv(): + async def do_successful_blocking_recv() -> None: with assert_checkpoints(): assert await ta.recv(10) == b"2" @@ -626,7 +638,7 @@ async def do_successful_blocking_recv(): b.send(b"2") # block then cancelled - async def do_cancelled_blocking_recv(): + async def do_cancelled_blocking_recv() -> None: with assert_checkpoints(): with pytest.raises(_core.Cancelled): await ta.recv(10) @@ -644,13 +656,13 @@ async def do_cancelled_blocking_recv(): # other: tb = tsocket.from_stdlib_socket(b) - async def t1(): + async def t1() -> None: with assert_checkpoints(): assert await ta.recv(1) == b"a" with assert_checkpoints(): assert await tb.recv(1) == b"b" - async def t2(): + async def t2() -> None: with assert_checkpoints(): assert await tb.recv(1) == b"b" with assert_checkpoints(): @@ -668,7 +680,7 @@ async def t2(): # This tests the complicated paths through connect -async def test_SocketType_connect_paths(): +async def test_SocketType_connect_paths() -> None: with tsocket.socket() as sock: with pytest.raises(ValueError): # Should be a tuple @@ -721,7 +733,7 @@ def connect(self, *args, **kwargs): await sock.connect(("127.0.0.1", 2)) -async def test_resolve_remote_address_exception_closes_socket(): +async def test_resolve_remote_address_exception_closes_socket() -> None: # Here we are testing issue 247, any cancellation will leave the socket closed with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: @@ -737,7 +749,7 @@ async def _resolve_remote_address_nocp(self, *args, **kwargs): assert sock.fileno() == -1 -async def test_send_recv_variants(): +async def test_send_recv_variants() -> None: a, b = tsocket.socketpair() with 
a, b: # recv, including with flags @@ -833,7 +845,7 @@ async def test_send_recv_variants(): assert await b.recv(10) == b"yyy" -async def test_idna(monkeygai): +async def test_idna(monkeygai) -> None: # This is the encoding for "faß.de", which uses one of the characters that # IDNA 2003 handles incorrectly: monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80) @@ -851,14 +863,14 @@ async def test_idna(monkeygai): assert "ok faß.de" == await tsocket.getaddrinfo(b"xn--fa-hia.de", 80) -async def test_getprotobyname(): +async def test_getprotobyname() -> None: # These are the constants used in IP header fields, so the numeric values # had *better* be stable across systems... assert await tsocket.getprotobyname("udp") == 17 assert await tsocket.getprotobyname("tcp") == 6 -async def test_custom_hostname_resolver(monkeygai): +async def test_custom_hostname_resolver(monkeygai) -> None: class CustomResolver: async def getaddrinfo(self, host, port, family, type, proto, flags): return ("custom_gai", host, port, family, type, proto, flags) @@ -902,7 +914,7 @@ async def getnameinfo(self, sockaddr, flags): assert await tsocket.getaddrinfo("host", "port") == "x" -async def test_custom_socket_factory(): +async def test_custom_socket_factory() -> None: class CustomSocketFactory: def socket(self, family, type, proto): return ("hi", family, type, proto) @@ -929,13 +941,13 @@ def socket(self, family, type, proto): assert tsocket.set_custom_socket_factory(None) is csf -async def test_SocketType_is_abstract(): +async def test_SocketType_is_abstract() -> None: with pytest.raises(TypeError): tsocket.SocketType() @pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets") -async def test_unix_domain_socket(): +async def test_unix_domain_socket() -> None: # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. 
@@ -964,7 +976,7 @@ async def check_AF_UNIX(path): pass -async def test_interrupted_by_close(): +async def test_interrupted_by_close() -> None: a_stdlib, b_stdlib = stdlib_socket.socketpair() with a_stdlib, b_stdlib: a_stdlib.setblocking(False) @@ -979,11 +991,11 @@ async def test_interrupted_by_close(): a = tsocket.from_stdlib_socket(a_stdlib) - async def sender(): + async def sender() -> None: with pytest.raises(_core.ClosedResourceError): await a.send(data) - async def receiver(): + async def receiver() -> None: with pytest.raises(_core.ClosedResourceError): await a.recv(1) @@ -994,7 +1006,7 @@ async def receiver(): a.close() -async def test_many_sockets(): +async def test_many_sockets() -> None: total = 5000 # Must be more than MAX_AFD_GROUP_SIZE sockets = [] for x in range(total // 2): diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index f160af4999..073f9ab1fe 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -1,10 +1,12 @@ import pytest +import _pytest.fixtures import threading import socket as stdlib_socket import ssl from contextlib import contextmanager from functools import partial +from typing import AsyncIterator, Iterator from OpenSSL import SSL import trustme @@ -26,6 +28,7 @@ assert_checkpoints, Sequencer, memory_stream_pair, + MockClock, lockstep_stream_pair, check_two_way_stream, ) @@ -71,7 +74,7 @@ @pytest.fixture(scope="module", params=client_ctx_params) -def client_ctx(request): +def client_ctx(request: _pytest.fixtures.SubRequest) -> ssl.SSLContext: ctx = ssl.create_default_context() TRIO_TEST_CA.configure_trust(ctx) if request.param in ["default", "tls13"]: @@ -141,7 +144,7 @@ def ssl_echo_serve_sync(sock, *, expect_fail=False): # (running in a thread). Useful for testing making connections with different # SSLContexts. 
@asynccontextmanager -async def ssl_echo_server_raw(**kwargs): +async def ssl_echo_server_raw(**kwargs: object) -> AsyncIterator[SocketStream]: # type: ignore[misc] a, b = stdlib_socket.socketpair() async with trio.open_nursery() as nursery: # Exiting the 'with a, b' context manager closes the sockets, which @@ -158,7 +161,9 @@ async def ssl_echo_server_raw(**kwargs): # Fixture that gives a properly set up SSLStream connected to a trio-test-1 # echo server (running in a thread) @asynccontextmanager -async def ssl_echo_server(client_ctx, **kwargs): +async def ssl_echo_server( # type: ignore[misc] + client_ctx: ssl.SSLContext, **kwargs: object +) -> AsyncIterator[SSLStream]: async with ssl_echo_server_raw(**kwargs) as sock: yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") @@ -167,7 +172,7 @@ async def ssl_echo_server(client_ctx, **kwargs): # Doesn't inherit from Stream because I left out the methods that we don't # actually need. class PyOpenSSLEchoStream: - def __init__(self, sleeper=None): + def __init__(self, sleeper=None) -> None: ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. Which is great for them, but # we still have to support versions before that, and that means we @@ -224,18 +229,18 @@ async def no_op_sleeper(_): else: self.sleeper = sleeper - async def aclose(self): + async def aclose(self) -> None: self._conn.bio_shutdown() def renegotiate_pending(self): return self._conn.renegotiate_pending() - def renegotiate(self): + def renegotiate(self) -> None: # Returns false if a renegotiation is already in progress, meaning # nothing happens. 
assert self._conn.renegotiate() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_all_conflict_detector: await _core.checkpoint() await _core.checkpoint() @@ -320,7 +325,7 @@ async def receive_some(self, nbytes=None): print(" <-- transport_stream.receive_some finished") -async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): +async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None: # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all # at the same time, or ditto for receive_some. The tricky cases where SSLStream # might accidentally do this are during renegotiation, which we test using @@ -357,7 +362,9 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): @contextmanager -def virtual_ssl_echo_server(client_ctx, **kwargs): +def virtual_ssl_echo_server( + client_ctx: ssl.SSLContext, **kwargs: object +) -> Iterator[SSLStream]: fakesock = PyOpenSSLEchoStream(**kwargs) yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") @@ -395,7 +402,7 @@ def ssl_lockstep_stream_pair(client_ctx, **kwargs): # Simple smoke test for handshake/send/receive/shutdown talking to a # synchronous server, plus make sure that we do the bare minimum of # certificate checking (even though this is really Python's responsibility) -async def test_ssl_client_basics(client_ctx): +async def test_ssl_client_basics(client_ctx) -> None: # Everything OK async with ssl_echo_server(client_ctx) as s: assert not s.server_side @@ -421,7 +428,7 @@ async def test_ssl_client_basics(client_ctx): assert isinstance(excinfo.value.__cause__, ssl.CertificateError) -async def test_ssl_server_basics(client_ctx): +async def test_ssl_server_basics(client_ctx) -> None: a, b = stdlib_socket.socketpair() with a, b: server_sock = tsocket.from_stdlib_socket(b) @@ -430,7 +437,7 @@ async def test_ssl_server_basics(client_ctx): ) assert server_transport.server_side - def 
client(): + def client() -> None: with client_ctx.wrap_socket( a, server_hostname="trio-test-1.example.org" ) as client_sock: @@ -451,7 +458,7 @@ def client(): t.join() -async def test_attributes(client_ctx): +async def test_attributes(client_ctx) -> None: async with ssl_echo_server_raw(expect_fail=True) as sock: good_ctx = client_ctx bad_ctx = ssl.create_default_context() @@ -520,7 +527,7 @@ async def test_attributes(client_ctx): # I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it... -async def test_full_duplex_basics(client_ctx): +async def test_full_duplex_basics(client_ctx) -> None: CHUNKS = 30 CHUNK_SIZE = 32768 EXPECTED = CHUNKS * CHUNK_SIZE @@ -557,7 +564,7 @@ async def receiver(s): assert sent == received -async def test_renegotiation_simple(client_ctx): +async def test_renegotiation_simple(client_ctx) -> None: with virtual_ssl_echo_server(client_ctx) as s: await s.do_handshake() @@ -576,7 +583,9 @@ async def test_renegotiation_simple(client_ctx): @slow -async def test_renegotiation_randomized(mock_clock, client_ctx): +async def test_renegotiation_randomized( + mock_clock: MockClock, client_ctx: ssl.SSLContext +) -> None: # The only blocking things in this function are our random sleeps, so 0 is # a good threshold. 
mock_clock.autojump_threshold = 0 @@ -588,7 +597,7 @@ async def test_renegotiation_randomized(mock_clock, client_ctx): async def sleeper(_): await trio.sleep(r.uniform(0, 10)) - async def clear(): + async def clear() -> None: while s.transport_stream.renegotiate_pending(): with assert_checkpoints(): await send(b"-") @@ -654,7 +663,7 @@ async def sleeper_with_slow_send_all(method): # And our wait_send_all_might_not_block call will give it time to get # stuck, and then start - async def sleep_then_wait_writable(): + async def sleep_then_wait_writable() -> None: await trio.sleep(1000) await s.wait_send_all_might_not_block() @@ -692,16 +701,16 @@ async def sleeper_with_slow_wait_writable_and_expect(method): await s.aclose() -async def test_resource_busy_errors(client_ctx): - async def do_send_all(): +async def test_resource_busy_errors(client_ctx) -> None: + async def do_send_all() -> None: with assert_checkpoints(): await s.send_all(b"x") - async def do_receive_some(): + async def do_receive_some() -> None: with assert_checkpoints(): await s.receive_some(1) - async def do_wait_send_all_might_not_block(): + async def do_wait_send_all_might_not_block() -> None: with assert_checkpoints(): await s.wait_send_all_might_not_block() @@ -734,11 +743,11 @@ async def do_wait_send_all_might_not_block(): assert "another task" in str(excinfo.value) -async def test_wait_writable_calls_underlying_wait_writable(): +async def test_wait_writable_calls_underlying_wait_writable() -> None: record = [] class NotAStream: - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: record.append("ok") ctx = ssl.create_default_context() @@ -747,7 +756,7 @@ async def wait_send_all_might_not_block(self): assert record == ["ok"] -async def test_checkpoints(client_ctx): +async def test_checkpoints(client_ctx) -> None: async with ssl_echo_server(client_ctx) as s: with assert_checkpoints(): await s.do_handshake() @@ -776,7 +785,7 @@ async def 
test_checkpoints(client_ctx): await s.aclose() -async def test_send_all_empty_string(client_ctx): +async def test_send_all_empty_string(client_ctx) -> None: async with ssl_echo_server(client_ctx) as s: await s.do_handshake() @@ -793,7 +802,9 @@ async def test_send_all_empty_string(client_ctx): @pytest.mark.parametrize("https_compatible", [False, True]) -async def test_SSLStream_generic(client_ctx, https_compatible): +async def test_SSLStream_generic( + client_ctx: ssl.SSLContext, https_compatible: bool +) -> None: async def stream_maker(): return ssl_memory_stream_pair( client_ctx, @@ -817,14 +828,14 @@ async def clogged_stream_maker(): await check_two_way_stream(stream_maker, clogged_stream_maker) -async def test_unwrap(client_ctx): +async def test_unwrap(client_ctx) -> None: client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) client_transport = client_ssl.transport_stream server_transport = server_ssl.transport_stream seq = Sequencer() - async def client(): + async def client() -> None: await client_ssl.do_handshake() await client_ssl.send_all(b"x") assert await client_ssl.receive_some(1) == b"y" @@ -851,7 +862,7 @@ async def client(): client_transport.send_stream.send_all_hook = send_all_hook await client_transport.send_stream.send_all_hook() - async def server(): + async def server() -> None: await server_ssl.do_handshake() assert await server_ssl.receive_some(1) == b"x" await server_ssl.send_all(b"y") @@ -871,7 +882,7 @@ async def server(): nursery.start_soon(server) -async def test_closing_nice_case(client_ctx): +async def test_closing_nice_case(client_ctx) -> None: # the nice case: graceful closes all around client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) @@ -879,11 +890,11 @@ async def test_closing_nice_case(client_ctx): # Both the handshake and the close require back-and-forth discussion, so # we need to run them concurrently - async def client_closer(): + async def client_closer() -> None: with assert_checkpoints(): await 
client_ssl.aclose() - async def server_closer(): + async def server_closer() -> None: assert await server_ssl.receive_some(10) == b"" assert await server_ssl.receive_some(10) == b"" with assert_checkpoints(): @@ -922,7 +933,7 @@ async def server_closer(): # the other side client_ssl, server_ssl = ssl_memory_stream_pair(client_ctx) - async def expect_eof_server(): + async def expect_eof_server() -> None: with assert_checkpoints(): assert await server_ssl.receive_some(10) == b"" with assert_checkpoints(): @@ -933,14 +944,14 @@ async def expect_eof_server(): nursery.start_soon(expect_eof_server) -async def test_send_all_fails_in_the_middle(client_ctx): +async def test_send_all_fails_in_the_middle(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: nursery.start_soon(client.do_handshake) nursery.start_soon(server.do_handshake) - async def bad_hook(): + async def bad_hook() -> None: raise KeyError client.transport_stream.send_stream.send_all_hook = bad_hook @@ -953,7 +964,7 @@ async def bad_hook(): closed = 0 - def close_hook(): + def close_hook() -> None: nonlocal closed closed += 1 @@ -964,7 +975,7 @@ def close_hook(): assert closed == 2 -async def test_ssl_over_ssl(client_ctx): +async def test_ssl_over_ssl(client_ctx) -> None: client_0, server_0 = memory_stream_pair() client_1 = SSLStream( @@ -977,11 +988,11 @@ async def test_ssl_over_ssl(client_ctx): ) server_2 = SSLStream(server_1, SERVER_CTX, server_side=True) - async def client(): + async def client() -> None: await client_2.send_all(b"hi") assert await client_2.receive_some(10) == b"bye" - async def server(): + async def server() -> None: assert await server_2.receive_some(10) == b"hi" await server_2.send_all(b"bye") @@ -990,7 +1001,7 @@ async def server(): nursery.start_soon(server) -async def test_ssl_bad_shutdown(client_ctx): +async def test_ssl_bad_shutdown(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with 
_core.open_nursery() as nursery: @@ -1007,7 +1018,7 @@ async def test_ssl_bad_shutdown(client_ctx): await server.aclose() -async def test_ssl_bad_shutdown_but_its_ok(client_ctx): +async def test_ssl_bad_shutdown_but_its_ok(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, @@ -1027,7 +1038,7 @@ async def test_ssl_bad_shutdown_but_its_ok(client_ctx): await server.aclose() -async def test_ssl_handshake_failure_during_aclose(): +async def test_ssl_handshake_failure_during_aclose() -> None: # Weird scenario: aclose() triggers an automatic handshake, and this # fails. This also exercises a bit of code in aclose() that was otherwise # uncovered, for re-raising exceptions after calling aclose_forcefully on @@ -1046,7 +1057,7 @@ async def test_ssl_handshake_failure_during_aclose(): await s.aclose() -async def test_ssl_only_closes_stream_once(client_ctx): +async def test_ssl_only_closes_stream_once(client_ctx) -> None: # We used to have a bug where if transport_stream.aclose() raised an # error, we would call it again. This checks that that's fixed. 
client, server = ssl_memory_stream_pair(client_ctx) @@ -1058,7 +1069,7 @@ async def test_ssl_only_closes_stream_once(client_ctx): client_orig_close_hook = client.transport_stream.send_stream.close_hook transport_close_count = 0 - def close_hook(): + def close_hook() -> None: nonlocal transport_close_count client_orig_close_hook() transport_close_count += 1 @@ -1071,7 +1082,7 @@ def close_hook(): assert transport_close_count == 1 -async def test_ssl_https_compatibility_disagreement(client_ctx): +async def test_ssl_https_compatibility_disagreement(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": False}, @@ -1084,7 +1095,7 @@ async def test_ssl_https_compatibility_disagreement(client_ctx): # client is in HTTPS-mode, server is not # so client doing graceful_shutdown causes an error on server - async def receive_and_expect_error(): + async def receive_and_expect_error() -> None: with pytest.raises(BrokenResourceError) as excinfo: await server.receive_some(10) assert isinstance(excinfo.value.__cause__, ssl.SSLEOFError) @@ -1094,14 +1105,14 @@ async def receive_and_expect_error(): nursery.start_soon(receive_and_expect_error) -async def test_https_mode_eof_before_handshake(client_ctx): +async def test_https_mode_eof_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, client_kwargs={"https_compatible": True}, ) - async def server_expect_clean_eof(): + async def server_expect_clean_eof() -> None: assert await server.receive_some(10) == b"" async with _core.open_nursery() as nursery: @@ -1109,10 +1120,10 @@ async def server_expect_clean_eof(): nursery.start_soon(server_expect_clean_eof) -async def test_send_error_during_handshake(client_ctx): +async def test_send_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) - async def bad_hook(): + async def bad_hook() -> None: raise KeyError 
client.transport_stream.send_stream.send_all_hook = bad_hook @@ -1126,10 +1137,10 @@ async def bad_hook(): await client.do_handshake() -async def test_receive_error_during_handshake(client_ctx): +async def test_receive_error_during_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) - async def bad_hook(): + async def bad_hook() -> None: raise KeyError client.transport_stream.receive_stream.receive_some_hook = bad_hook @@ -1149,7 +1160,7 @@ async def client_side(cancel_scope): await client.do_handshake() -async def test_selected_alpn_protocol_before_handshake(client_ctx): +async def test_selected_alpn_protocol_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1159,7 +1170,7 @@ async def test_selected_alpn_protocol_before_handshake(client_ctx): server.selected_alpn_protocol() -async def test_selected_alpn_protocol_when_not_set(client_ctx): +async def test_selected_alpn_protocol_when_not_set(client_ctx) -> None: # ALPN protocol still returns None when it's not ser, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1174,7 +1185,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx): assert client.selected_alpn_protocol() == server.selected_alpn_protocol() -async def test_selected_npn_protocol_before_handshake(client_ctx): +async def test_selected_npn_protocol_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1184,7 +1195,7 @@ async def test_selected_npn_protocol_before_handshake(client_ctx): server.selected_npn_protocol() -async def test_selected_npn_protocol_when_not_set(client_ctx): +async def test_selected_npn_protocol_when_not_set(client_ctx) -> None: # NPN protocol still returns None when it's not ser, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1199,7 +1210,7 @@ async 
def test_selected_npn_protocol_when_not_set(client_ctx): assert client.selected_npn_protocol() == server.selected_npn_protocol() -async def test_get_channel_binding_before_handshake(client_ctx): +async def test_get_channel_binding_before_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) with pytest.raises(NeedHandshakeError): @@ -1209,7 +1220,7 @@ async def test_get_channel_binding_before_handshake(client_ctx): server.get_channel_binding() -async def test_get_channel_binding_after_handshake(client_ctx): +async def test_get_channel_binding_after_handshake(client_ctx) -> None: client, server = ssl_memory_stream_pair(client_ctx) async with _core.open_nursery() as nursery: @@ -1222,7 +1233,7 @@ async def test_get_channel_binding_after_handshake(client_ctx): assert client.get_channel_binding() == server.get_channel_binding() -async def test_getpeercert(client_ctx): +async def test_getpeercert(client_ctx) -> None: # Make sure we're not affected by https://bugs.python.org/issue29334 client, server = ssl_memory_stream_pair(client_ctx) @@ -1235,7 +1246,7 @@ async def test_getpeercert(client_ctx): assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] -async def test_SSLListener(client_ctx): +async def test_SSLListener(client_ctx) -> None: async def setup(**kwargs): listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) diff --git a/trio/tests/test_subprocess.py b/trio/tests/test_subprocess.py index 7ba794a428..793b1df573 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/tests/test_subprocess.py @@ -3,7 +3,9 @@ import subprocess import sys import pytest +import _pytest.monkeypatch import random +from typing import Optional from functools import partial from .. 
import ( @@ -18,13 +20,26 @@ TrioDeprecationWarning, ) from .._core.tests.tutil import slow, skip_if_fbsd_pipes_broken -from ..testing import wait_all_tasks_blocked +from ..testing import MockClock, wait_all_tasks_blocked -posix = os.name == "posix" -if posix: - from signal import SIGKILL, SIGTERM, SIGUSR1 +SIGKILL: Optional[signal.Signals] +SIGTERM: Optional[signal.Signals] +SIGUSR1: Optional[signal.Signals] +SIGCHLD: Optional[signal.Signals] + +# TODO: is this the proper translation from os.name to sys.platform? +# Mypy understands sys.platform but not os.name +posix = sys.platform != "win32" + +if sys.platform != "win32": + import signal + + SIGKILL = signal.SIGKILL + SIGTERM = signal.SIGTERM + SIGUSR1 = signal.SIGUSR1 + SIGCHLD = signal.SIGCHLD else: - SIGKILL, SIGTERM, SIGUSR1 = None, None, None + SIGKILL, SIGTERM, SIGUSR1, SIGCHLD = None, None, None, None # Since Windows has very few command-line utilities generally available, @@ -47,7 +62,7 @@ def got_signal(proc, sig): return proc.returncode != 0 -async def test_basic(): +async def test_basic() -> None: async with await open_process(EXIT_TRUE) as proc: pass assert isinstance(proc, Process) @@ -63,7 +78,7 @@ async def test_basic(): ) -async def test_auto_update_returncode(): +async def test_auto_update_returncode() -> None: p = await open_process(SLEEP(9999)) assert p.returncode is None assert "running" in repr(p) @@ -75,7 +90,7 @@ async def test_auto_update_returncode(): assert p.returncode is not None -async def test_multi_wait(): +async def test_multi_wait() -> None: async with await open_process(SLEEP(10)) as proc: # Check that wait (including multi-wait) tolerates being cancelled async with _core.open_nursery() as nursery: @@ -94,7 +109,7 @@ async def test_multi_wait(): proc.kill() -async def test_kill_when_context_cancelled(): +async def test_kill_when_context_cancelled() -> None: with move_on_after(100) as scope: async with await open_process(SLEEP(10)) as proc: assert proc.poll() is None @@ -114,7 
+129,7 @@ async def test_kill_when_context_cancelled(): ) -async def test_pipes(): +async def test_pipes() -> None: async with await open_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, @@ -123,7 +138,7 @@ async def test_pipes(): ) as proc: msg = b"the quick brown fox jumps over the lazy dog" - async def feed_input(): + async def feed_input() -> None: await proc.stdin.send_all(msg) await proc.stdin.aclose() @@ -144,7 +159,7 @@ async def check_output(stream, expected): assert 0 == await proc.wait() -async def test_interactive(): +async def test_interactive() -> None: # Test some back-and-forth with a subprocess. This one works like so: # in: 32\n # out: 0000...0000\n (32 zeroes) @@ -212,7 +227,7 @@ async def drain_one(stream, count, digit): assert proc.returncode == 0 -async def test_run(): +async def test_run() -> None: data = bytes(random.randint(0, 255) for _ in range(2 ** 18)) result = await run_process( @@ -251,7 +266,7 @@ async def test_run(): await run_process(CAT, capture_stderr=True, stderr=None) -async def test_run_check(): +async def test_run_check() -> None: cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)") with pytest.raises(subprocess.CalledProcessError) as excinfo: await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True) @@ -270,7 +285,7 @@ async def test_run_check(): @skip_if_fbsd_pipes_broken -async def test_run_with_broken_pipe(): +async def test_run_with_broken_pipe() -> None: result = await run_process( [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072 ) @@ -278,7 +293,7 @@ async def test_run_with_broken_pipe(): assert result.stdout is result.stderr is None -async def test_stderr_stdout(): +async def test_stderr_stdout() -> None: async with await open_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, @@ -343,7 +358,7 @@ async def test_stderr_stdout(): os.close(r) -async def test_errors(): +async def test_errors() -> None: with 
pytest.raises(TypeError) as excinfo: await open_process(["ls"], encoding="utf-8") assert "unbuffered byte streams" in str(excinfo.value) @@ -356,8 +371,8 @@ async def test_errors(): await open_process("ls", shell=False) -async def test_signals(): - async def test_one_signal(send_it, signum): +async def test_signals() -> None: + async def test_one_signal(send_it, signum) -> None: with move_on_after(1.0) as scope: async with await open_process(SLEEP(3600)) as proc: send_it(proc) @@ -381,26 +396,33 @@ async def test_one_signal(send_it, signum): @pytest.mark.skipif(not posix, reason="POSIX specific") -async def test_wait_reapable_fails(): - old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) - try: - # With SIGCHLD disabled, the wait() syscall will wait for the - # process to exit but then fail with ECHILD. Make sure we - # support this case as the stdlib subprocess module does. - async with await open_process(SLEEP(3600)) as proc: - async with _core.open_nursery() as nursery: - nursery.start_soon(proc.wait) - await wait_all_tasks_blocked() - proc.kill() - nursery.cancel_scope.deadline = _core.current_time() + 1.0 - assert not nursery.cancel_scope.cancelled_caught - assert proc.returncode == 0 # exit status unknowable, so... - finally: - signal.signal(signal.SIGCHLD, old_sigchld) +async def test_wait_reapable_fails() -> None: + if sys.platform == "win32": + # mypy doesn't recognize the pytest.mark.skipif and ignores an assert inside + # this function. + # https://github.com/python/mypy/issues/9025 + assert False # we should have been skipped, if not then fail + else: + assert SIGCHLD is not None # for mypy + old_sigchld = signal.signal(SIGCHLD, signal.SIG_IGN) + try: + # With SIGCHLD disabled, the wait() syscall will wait for the + # process to exit but then fail with ECHILD. Make sure we + # support this case as the stdlib subprocess module does. 
+ async with await open_process(SLEEP(3600)) as proc: + async with _core.open_nursery() as nursery: + nursery.start_soon(proc.wait) + await wait_all_tasks_blocked() + proc.kill() + nursery.cancel_scope.deadline = _core.current_time() + 1.0 + assert not nursery.cancel_scope.cancelled_caught + assert proc.returncode == 0 # exit status unknowable, so... + finally: + signal.signal(SIGCHLD, old_sigchld) @slow -def test_waitid_eintr(): +def test_waitid_eintr() -> None: # This only matters on PyPy (where we're coding EINTR handling # ourselves) but the test works on all waitid platforms. from .._subprocess_platform import wait_child_exiting @@ -409,29 +431,35 @@ def test_waitid_eintr(): pytest.skip("waitid only") from .._subprocess_platform.waitid import sync_wait_reapable - got_alarm = False - sleeper = subprocess.Popen(["sleep", "3600"]) - - def on_alarm(sig, frame): - nonlocal got_alarm - got_alarm = True - sleeper.kill() - - old_sigalrm = signal.signal(signal.SIGALRM, on_alarm) - try: - signal.alarm(1) - sync_wait_reapable(sleeper.pid) - assert sleeper.wait(timeout=1) == -9 - finally: - if sleeper.returncode is None: # pragma: no cover - # We only get here if something fails in the above; - # if the test passes, wait() will reap the process + if sys.platform == "win32": + # mypy doesn't recognize the waitid checks above as representing not-Windows + # and ignores an assert inside this function. 
+ # https://github.com/python/mypy/issues/9025 + assert False # we should have been skipped, if not then fail + else: + got_alarm = False + sleeper = subprocess.Popen(["sleep", "3600"]) + + def on_alarm(sig, frame): + nonlocal got_alarm + got_alarm = True sleeper.kill() - sleeper.wait() - signal.signal(signal.SIGALRM, old_sigalrm) + + old_sigalrm = signal.signal(signal.SIGALRM, on_alarm) + try: + signal.alarm(1) + sync_wait_reapable(sleeper.pid) + assert sleeper.wait(timeout=1) == -9 + finally: + if sleeper.returncode is None: # pragma: no cover + # We only get here if something fails in the above; + # if the test passes, wait() will reap the process + sleeper.kill() + sleeper.wait() + signal.signal(signal.SIGALRM, old_sigalrm) -async def test_custom_deliver_cancel(): +async def test_custom_deliver_cancel() -> None: custom_deliver_cancel_called = False async def custom_deliver_cancel(proc): @@ -455,10 +483,10 @@ async def custom_deliver_cancel(proc): assert custom_deliver_cancel_called -async def test_warn_on_failed_cancel_terminate(monkeypatch): +async def test_warn_on_failed_cancel_terminate(monkeypatch) -> None: original_terminate = Process.terminate - def broken_terminate(self): + def broken_terminate(self) -> None: original_terminate(self) raise OSError("whoops") @@ -472,7 +500,9 @@ def broken_terminate(self): @pytest.mark.skipif(os.name != "posix", reason="posix only") -async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch): +async def test_warn_on_cancel_SIGKILL_escalation( + autojump_clock: MockClock, monkeypatch: _pytest.monkeypatch.MonkeyPatch +) -> None: monkeypatch.setattr(Process, "terminate", lambda *args: None) with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"): diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index 229dea301c..b6e7170270 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -1,3 +1,5 @@ +from typing import cast, Callable + import pytest import weakref @@ -10,7 +12,7 @@ 
from .._sync import * -async def test_Event(): +async def test_Event() -> None: e = Event() assert not e.is_set() assert e.statistics().tasks_waiting == 0 @@ -24,7 +26,7 @@ async def test_Event(): record = [] - async def child(): + async def child() -> None: record.append("sleeping") await e.wait() record.append("woken") @@ -40,7 +42,7 @@ async def child(): assert record == ["sleeping", "sleeping", "woken", "woken"] -async def test_CapacityLimiter(): +async def test_CapacityLimiter() -> None: with pytest.raises(TypeError): CapacityLimiter(1.0) with pytest.raises(ValueError): @@ -109,7 +111,7 @@ async def test_CapacityLimiter(): c.release_on_behalf_of("value 1") -async def test_CapacityLimiter_inf(): +async def test_CapacityLimiter_inf() -> None: from math import inf c = CapacityLimiter(inf) @@ -125,7 +127,7 @@ async def test_CapacityLimiter_inf(): assert c.available_tokens == inf -async def test_CapacityLimiter_change_total_tokens(): +async def test_CapacityLimiter_change_total_tokens() -> None: c = CapacityLimiter(2) with pytest.raises(TypeError): @@ -162,7 +164,7 @@ async def test_CapacityLimiter_change_total_tokens(): # regression test for issue #548 -async def test_CapacityLimiter_memleak_548(): +async def test_CapacityLimiter_memleak_548() -> None: limiter = CapacityLimiter(total_tokens=1) await limiter.acquire() @@ -176,7 +178,7 @@ async def test_CapacityLimiter_memleak_548(): assert len(limiter._pending_borrowers) == 0 -async def test_Semaphore(): +async def test_Semaphore() -> None: with pytest.raises(TypeError): Semaphore(1.0) with pytest.raises(ValueError): @@ -224,7 +226,7 @@ async def do_acquire(s): assert record == ["started", "finished"] -async def test_Semaphore_bounded(): +async def test_Semaphore_bounded() -> None: with pytest.raises(TypeError): Semaphore(1, max_value=1.0) with pytest.raises(ValueError): @@ -241,8 +243,14 @@ async def test_Semaphore_bounded(): assert bs.value == 1 -@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], 
ids=lambda fn: fn.__name__) -async def test_Lock_and_StrictFIFOLock(lockcls): +def get__name__(fn: Callable) -> str: + return fn.__name__ + + +@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=get__name__) +async def test_Lock_and_StrictFIFOLock( + lockcls: Union[Type[Lock], Type[StrictFIFOLock]] +) -> None: l = lockcls() # noqa assert not l.locked() @@ -254,7 +262,8 @@ async def test_Lock_and_StrictFIFOLock(lockcls): # make sure repr uses the right name for subclasses assert lockcls.__name__ in repr(l) with assert_checkpoints(): - async with l: + # TODO: hint async_cm + async with l: # type: ignore[union-attr] assert l.locked() repr(l) # smoke test (repr branches on locked/unlocked) assert not l.locked() @@ -279,7 +288,7 @@ async def test_Lock_and_StrictFIFOLock(lockcls): holder_task = None - async def holder(): + async def holder() -> None: nonlocal holder_task holder_task = _core.current_task() async with l: @@ -317,7 +326,7 @@ async def holder(): assert statistics.tasks_waiting == 0 -async def test_Condition(): +async def test_Condition() -> None: with pytest.raises(TypeError): Condition(Semaphore(1)) with pytest.raises(TypeError): @@ -410,58 +419,58 @@ async def waiter(i): @async_cm class ChannelLock1: - def __init__(self, capacity): + def __init__(self, capacity) -> None: self.s, self.r = open_memory_channel(capacity) for _ in range(capacity - 1): self.s.send_nowait(None) - def acquire_nowait(self): + def acquire_nowait(self) -> None: self.s.send_nowait(None) - async def acquire(self): + async def acquire(self) -> None: await self.s.send(None) - def release(self): + def release(self) -> None: self.r.receive_nowait() @async_cm class ChannelLock2: - def __init__(self): + def __init__(self) -> None: self.s, self.r = open_memory_channel(10) self.s.send_nowait(None) - def acquire_nowait(self): + def acquire_nowait(self) -> None: self.r.receive_nowait() - async def acquire(self): + async def acquire(self) -> None: await self.r.receive() - def 
release(self): + def release(self) -> None: self.s.send_nowait(None) @async_cm class ChannelLock3: - def __init__(self): + def __init__(self) -> None: self.s, self.r = open_memory_channel(0) # self.acquired is true when one task acquires the lock and # only becomes false when it's released and no tasks are # waiting to acquire. self.acquired = False - def acquire_nowait(self): + def acquire_nowait(self) -> None: assert not self.acquired self.acquired = True - async def acquire(self): + async def acquire(self) -> None: if self.acquired: await self.s.send(None) else: self.acquired = True await _core.checkpoint() - def release(self): + def release(self) -> None: try: self.r.receive_nowait() except _core.WouldBlock: @@ -494,11 +503,23 @@ def release(self): "lock_factory", lock_factories, ids=lock_factory_names ) +_LockFactory = Callable[ + [], + Union[ + CapacityLimiter, + Semaphore, + Lock, + StrictFIFOLock, + ChannelLock1, + ChannelLock2, + ChannelLock3, + ], +] # Spawn a bunch of workers that take a lock and then yield; make sure that # only one worker is ever in the critical section at a time. @generic_lock_test -async def test_generic_lock_exclusion(lock_factory): +async def test_generic_lock_exclusion(lock_factory: _LockFactory) -> None: LOOPS = 10 WORKERS = 5 in_critical_section = False @@ -527,7 +548,7 @@ async def worker(lock_like): # Several workers queue on the same lock; make sure they each get it, in # order. 
@generic_lock_test -async def test_generic_lock_fifo_fairness(lock_factory): +async def test_generic_lock_fifo_fairness(lock_factory: _LockFactory) -> None: initial_order = [] record = [] LOOPS = 5 @@ -551,12 +572,14 @@ async def loopy(name, lock_like): @generic_lock_test -async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory): +async def test_generic_lock_acquire_nowait_blocks_acquire( + lock_factory: _LockFactory, +) -> None: lock_like = lock_factory() record = [] - async def lock_taker(): + async def lock_taker() -> None: record.append("started") async with lock_like: pass diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py index 0b10ae71e1..981df52632 100644 --- a/trio/tests/test_testing.py +++ b/trio/tests/test_testing.py @@ -15,15 +15,15 @@ from .._highlevel_socket import SocketListener -async def test_wait_all_tasks_blocked(): +async def test_wait_all_tasks_blocked() -> None: record = [] - async def busy_bee(): + async def busy_bee() -> None: for _ in range(10): await _core.checkpoint() record.append("busy bee exhausted") - async def waiting_for_bee_to_leave(): + async def waiting_for_bee_to_leave() -> None: await wait_all_tasks_blocked() record.append("quiet at last!") @@ -35,7 +35,7 @@ async def waiting_for_bee_to_leave(): # check cancellation record = [] - async def cancelled_while_waiting(): + async def cancelled_while_waiting() -> None: try: await wait_all_tasks_blocked() except _core.Cancelled: @@ -47,10 +47,10 @@ async def cancelled_while_waiting(): assert record == ["ok"] -async def test_wait_all_tasks_blocked_with_timeouts(mock_clock): +async def test_wait_all_tasks_blocked_with_timeouts(mock_clock) -> None: record = [] - async def timeout_task(): + async def timeout_task() -> None: record.append("tt start") await sleep(5) record.append("tt finished") @@ -64,25 +64,25 @@ async def timeout_task(): assert record == ["tt start", "tt finished"] -async def test_wait_all_tasks_blocked_with_cushion(): +async def 
test_wait_all_tasks_blocked_with_cushion() -> None: record = [] - async def blink(): + async def blink() -> None: record.append("blink start") await sleep(0.01) await sleep(0.01) await sleep(0.01) record.append("blink end") - async def wait_no_cushion(): + async def wait_no_cushion() -> None: await wait_all_tasks_blocked() record.append("wait_no_cushion end") - async def wait_small_cushion(): + async def wait_small_cushion() -> None: await wait_all_tasks_blocked(0.02) record.append("wait_small_cushion end") - async def wait_big_cushion(): + async def wait_big_cushion() -> None: await wait_all_tasks_blocked(0.03) record.append("wait_big_cushion end") @@ -106,7 +106,7 @@ async def wait_big_cushion(): ################################################################ -async def test_assert_checkpoints(recwarn): +async def test_assert_checkpoints(recwarn) -> None: with assert_checkpoints(): await _core.checkpoint() @@ -132,7 +132,7 @@ async def test_assert_checkpoints(recwarn): await _core.cancel_shielded_checkpoint() -async def test_assert_no_checkpoints(recwarn): +async def test_assert_no_checkpoints(recwarn) -> None: with assert_no_checkpoints(): 1 + 1 @@ -162,7 +162,7 @@ async def test_assert_no_checkpoints(recwarn): ################################################################ -async def test_Sequencer(): +async def test_Sequencer() -> None: record = [] def t(val): @@ -200,7 +200,7 @@ async def f2(seq): pass # pragma: no cover -async def test_Sequencer_cancel(): +async def test_Sequencer_cancel() -> None: # Killing a blocked task makes everything blow up record = [] seq = Sequencer() @@ -232,7 +232,7 @@ async def child(i): ################################################################ -async def test__assert_raises(): +async def test__assert_raises() -> None: with pytest.raises(AssertionError): with _assert_raises(RuntimeError): 1 + 1 @@ -247,7 +247,7 @@ async def test__assert_raises(): # This is a private implementation detail, but it's complex enough to be 
worth # testing directly -async def test__UnboundeByteQueue(): +async def test__UnboundeByteQueue() -> None: ubq = _UnboundedByteQueue() ubq.put(b"123") @@ -310,7 +310,7 @@ async def getter(expect): # close wakes up blocked getters ubq2 = _UnboundedByteQueue() - async def closer(): + async def closer() -> None: await wait_all_tasks_blocked() ubq2.close() @@ -319,7 +319,7 @@ async def closer(): nursery.start_soon(closer) -async def test_MemorySendStream(): +async def test_MemorySendStream() -> None: mss = MemorySendStream() async def do_send_all(data): @@ -348,7 +348,7 @@ async def do_send_all(data): # and we don't know which one will get the error. resource_busy_count = 0 - async def do_send_all_count_resourcebusy(): + async def do_send_all_count_resourcebusy() -> None: nonlocal resource_busy_count try: await do_send_all(b"xxx") @@ -377,15 +377,15 @@ async def do_send_all_count_resourcebusy(): record = [] - async def send_all_hook(): + async def send_all_hook() -> None: # hook runs after send_all does its work (can pull data out) assert mss2.get_data_nowait() == b"abc" record.append("send_all_hook") - async def wait_send_all_might_not_block_hook(): + async def wait_send_all_might_not_block_hook() -> None: record.append("wait_send_all_might_not_block_hook") - def close_hook(): + def close_hook() -> None: record.append("close_hook") mss2 = MemorySendStream( @@ -409,7 +409,7 @@ def close_hook(): ] -async def test_MemoryReceiveStream(): +async def test_MemoryReceiveStream() -> None: mrs = MemoryReceiveStream() async def do_receive_some(max_bytes): @@ -440,12 +440,12 @@ async def do_receive_some(max_bytes): with pytest.raises(_core.ClosedResourceError): mrs.put_data(b"---") - async def receive_some_hook(): + async def receive_some_hook() -> None: mrs2.put_data(b"xxx") record = [] - def close_hook(): + def close_hook() -> None: record.append("closed") mrs2 = MemoryReceiveStream(receive_some_hook, close_hook) @@ -470,7 +470,7 @@ def close_hook(): await 
mrs2.receive_some(10) -async def test_MemoryRecvStream_closing(): +async def test_MemoryRecvStream_closing() -> None: mrs = MemoryReceiveStream() # close with no pending data mrs.close() @@ -490,7 +490,7 @@ async def test_MemoryRecvStream_closing(): await mrs2.receive_some(10) -async def test_memory_stream_pump(): +async def test_memory_stream_pump() -> None: mss = MemorySendStream() mrs = MemoryReceiveStream() @@ -514,7 +514,7 @@ async def test_memory_stream_pump(): assert await mrs.receive_some(10) == b"" -async def test_memory_stream_one_way_pair(): +async def test_memory_stream_one_way_pair() -> None: s, r = memory_stream_one_way_pair() assert s.send_all_hook is not None assert s.wait_send_all_might_not_block_hook is None @@ -555,7 +555,7 @@ async def cancel_after_idle(nursery): await wait_all_tasks_blocked() nursery.cancel_scope.cancel() - async def check_for_cancel(): + async def check_for_cancel() -> None: with pytest.raises(_core.Cancelled): # This should block forever... or until cancelled. Even though we # sent some data on the send stream. 
@@ -570,7 +570,7 @@ async def check_for_cancel(): assert await r.receive_some(10) == b"456789" -async def test_memory_stream_pair(): +async def test_memory_stream_pair() -> None: a, b = memory_stream_pair() await a.send_all(b"123") await b.send_all(b"abc") @@ -580,11 +580,11 @@ async def test_memory_stream_pair(): await a.send_eof() assert await b.receive_some(10) == b"" - async def sender(): + async def sender() -> None: await wait_all_tasks_blocked() await b.send_all(b"xyz") - async def receiver(): + async def receiver() -> None: assert await a.receive_some(10) == b"xyz" async with _core.open_nursery() as nursery: @@ -592,7 +592,7 @@ async def receiver(): nursery.start_soon(sender) -async def test_memory_streams_with_generic_tests(): +async def test_memory_streams_with_generic_tests() -> None: async def one_way_stream_maker(): return memory_stream_one_way_pair() @@ -604,7 +604,7 @@ async def half_closeable_stream_maker(): await check_half_closeable_stream(half_closeable_stream_maker, None) -async def test_lockstep_streams_with_generic_tests(): +async def test_lockstep_streams_with_generic_tests() -> None: async def one_way_stream_maker(): return lockstep_stream_one_way_pair() @@ -616,7 +616,7 @@ async def two_way_stream_maker(): await check_two_way_stream(two_way_stream_maker, two_way_stream_maker) -async def test_open_stream_to_socket_listener(): +async def test_open_stream_to_socket_listener() -> None: async def check(listener): async with listener: client_stream = await open_stream_to_socket_listener(listener) diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index 9da2838cbd..f0c2afc1e4 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -2,6 +2,7 @@ import queue as stdlib_queue import time +import attr import pytest from .. 
import _core @@ -18,13 +19,13 @@ from .._core.tests.test_ki import ki_self -async def test_do_in_trio_thread(): +async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() async def check_case(do_in_trio_thread, fn, expected, trio_token=None): record = [] - def threadfn(): + def threadfn() -> None: try: record.append(("start", threading.current_thread())) x = do_in_trio_thread(fn, record, trio_token=trio_token) @@ -73,26 +74,26 @@ async def f(record): await check_case(from_thread_run, f, ("error", KeyError), trio_token=token) -async def test_do_in_trio_thread_from_trio_thread(): +async def test_do_in_trio_thread_from_trio_thread() -> None: with pytest.raises(RuntimeError): from_thread_run_sync(lambda: None) # pragma: no branch - async def foo(): # pragma: no cover + async def foo() -> None: # pragma: no cover pass with pytest.raises(RuntimeError): from_thread_run(foo) -def test_run_in_trio_thread_ki(): +def test_run_in_trio_thread_ki() -> None: # if we get a control-C during a run_in_trio_thread, then it propagates # back to the caller (slick!) 
record = set() - async def check_run_in_trio_thread(): + async def check_run_in_trio_thread() -> None: token = _core.current_trio_token() - def trio_thread_fn(): + def trio_thread_fn() -> None: print("in Trio thread") assert not _core.currently_ki_protected() print("ki_self") @@ -103,10 +104,10 @@ def trio_thread_fn(): print("finally", sys.exc_info()) - async def trio_thread_afn(): + async def trio_thread_afn() -> None: trio_thread_fn() - def external_thread_fn(): + def external_thread_fn() -> None: try: print("running") from_thread_run_sync(trio_thread_fn, trio_token=token) @@ -132,7 +133,7 @@ def external_thread_fn(): assert record == {"ok1", "ok2"} -def test_await_in_trio_thread_while_main_exits(): +def test_await_in_trio_thread_while_main_exits() -> None: record = [] ev = Event() @@ -160,7 +161,7 @@ async def main(): assert record == ["sleeping", "cancelled"] -async def test_run_in_worker_thread(): +async def test_run_in_worker_thread() -> None: trio_thread = threading.current_thread() def f(x): @@ -170,7 +171,7 @@ def f(x): assert x == 1 assert child_thread != trio_thread - def g(): + def g() -> None: raise ValueError(threading.current_thread()) with pytest.raises(ValueError) as excinfo: @@ -179,7 +180,7 @@ def g(): assert excinfo.value.args[0] != trio_thread -async def test_run_in_worker_thread_cancellation(): +async def test_run_in_worker_thread_cancellation() -> None: register = [None] def f(q): @@ -239,18 +240,18 @@ async def child(q, cancellable): # Make sure that if trio.run exits, and then the thread finishes, then that's # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) 
-def test_run_in_worker_thread_abandoned(capfd, monkeypatch): +def test_run_in_worker_thread_abandoned(capfd, monkeypatch) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) q1 = stdlib_queue.Queue() q2 = stdlib_queue.Queue() - def thread_fn(): + def thread_fn() -> None: q1.get() q2.put(threading.current_thread()) - async def main(): - async def child(): + async def main() -> None: + async def child() -> None: await to_thread_run_sync(thread_fn, cancellable=True) async with _core.open_nursery() as nursery: @@ -277,7 +278,9 @@ async def child(): @pytest.mark.parametrize("MAX", [3, 5, 10]) @pytest.mark.parametrize("cancel", [False, True]) @pytest.mark.parametrize("use_default_limiter", [False, True]) -async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter): +async def test_run_in_worker_thread_limiter( + MAX: int, cancel: bool, use_default_limiter: bool +) -> None: # This test is a bit tricky. The goal is to make sure that if we set # limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever # running at a time, even if there are more concurrent calls to @@ -306,13 +309,16 @@ async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter): # # Mutating them in-place is OK though (as long as you use proper # locking etc.). - class state: - pass - state.ran = 0 - state.high_water = 0 - state.running = 0 - state.parked = 0 + # TODO: does this break the concerns explained above...? 
+ @attr.s() + class State: + ran: int = attr.ib(default=0) + high_water: int = attr.ib(default=0) + running: int = attr.ib(default=0) + parked: int = attr.ib(default=0) + + state = State() token = _core.current_trio_token() @@ -381,7 +387,7 @@ async def run_thread(event): c.total_tokens = orig_total_tokens -async def test_run_in_worker_thread_custom_limiter(): +async def test_run_in_worker_thread_custom_limiter() -> None: # Basically just checking that we only call acquire_on_behalf_of and # release_on_behalf_of, since that's part of our documented API. record = [] @@ -399,7 +405,7 @@ def release_on_behalf_of(self, borrower): assert record == ["acquire", "release"] -async def test_run_in_worker_thread_limiter_error(): +async def test_run_in_worker_thread_limiter_error() -> None: record = [] class BadCapacityLimiter: @@ -427,7 +433,7 @@ def release_on_behalf_of(self, borrower): assert record == ["acquire", "release"] -async def test_run_in_worker_thread_fail_to_spawn(monkeypatch): +async def test_run_in_worker_thread_fail_to_spawn(monkeypatch) -> None: # Test the unlikely but possible case where trying to spawn a thread fails def bad_start(self, *args): raise RuntimeError("the engines canna take it captain") @@ -445,7 +451,7 @@ def bad_start(self, *args): assert limiter.borrowed_tokens == 0 -async def test_trio_to_thread_run_sync_token(): +async def test_trio_to_thread_run_sync_token() -> None: # Test that to_thread_run_sync automatically injects the current trio token # into a spawned thread def thread_fn(): @@ -457,16 +463,16 @@ def thread_fn(): assert callee_token == caller_token -async def test_trio_to_thread_run_sync_expected_error(): +async def test_trio_to_thread_run_sync_expected_error() -> None: # Test correct error when passed async function - async def async_fn(): # pragma: no cover + async def async_fn() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="expected a sync function"): await to_thread_run_sync(async_fn) -async def 
test_trio_from_thread_run_sync(): +async def test_trio_from_thread_run_sync() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run_sync() def thread_fn(): @@ -477,26 +483,26 @@ def thread_fn(): assert isinstance(trio_time, float) # Test correct error when passed async function - async def async_fn(): # pragma: no cover + async def async_fn() -> None: # pragma: no cover pass - def thread_fn(): + def thread_fn() -> None: from_thread_run_sync(async_fn) with pytest.raises(TypeError, match="expected a sync function"): await to_thread_run_sync(thread_fn) -async def test_trio_from_thread_run(): +async def test_trio_from_thread_run() -> None: # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run() record = [] - async def back_in_trio_fn(): + async def back_in_trio_fn() -> None: _core.current_time() # implicitly checks that we're in trio record.append("back in trio") - def thread_fn(): + def thread_fn() -> None: record.append("in thread") from_thread_run(back_in_trio_fn) @@ -504,14 +510,14 @@ def thread_fn(): assert record == ["in thread", "back in trio"] # Test correct error when passed sync function - def sync_fn(): # pragma: no cover + def sync_fn() -> None: # pragma: no cover pass with pytest.raises(TypeError, match="appears to be synchronous"): await to_thread_run_sync(from_thread_run, sync_fn) -async def test_trio_from_thread_token(): +async def test_trio_from_thread_token() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() # share the same Trio token def thread_fn(): @@ -523,7 +529,7 @@ def thread_fn(): assert callee_token == caller_token -async def test_trio_from_thread_token_kwarg(): +async def test_trio_from_thread_token_kwarg() -> None: # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can # use an explicitly defined token def thread_fn(token): @@ -535,7 +541,7 @@ def thread_fn(token): assert callee_token == caller_token 
-async def test_from_thread_no_token(): +async def test_from_thread_no_token() -> None: # Test that a "raw call" to trio.from_thread.run() fails because no token # has been provided @@ -543,13 +549,13 @@ async def test_from_thread_no_token(): from_thread_run_sync(_core.current_time) -def test_run_fn_as_system_task_catched_badly_typed_token(): +def test_run_fn_as_system_task_catched_badly_typed_token() -> None: with pytest.raises(RuntimeError): from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") -async def test_from_thread_inside_trio_thread(): - def not_called(): # pragma: no cover +async def test_from_thread_inside_trio_thread() -> None: + def not_called() -> None: # pragma: no cover assert False trio_token = _core.current_trio_token() @@ -558,7 +564,7 @@ def not_called(): # pragma: no cover @pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") -def test_from_thread_run_during_shutdown(): +def test_from_thread_run_during_shutdown() -> None: save = [] record = [] @@ -570,7 +576,7 @@ async def agen(): await to_thread_run_sync(from_thread_run, sleep, 0) record.append("ok") - async def main(): + async def main() -> None: save.append(agen()) await save[-1].asend(None) diff --git a/trio/tests/test_timeouts.py b/trio/tests/test_timeouts.py index 382c015b1d..55bbda5036 100644 --- a/trio/tests/test_timeouts.py +++ b/trio/tests/test_timeouts.py @@ -42,13 +42,13 @@ async def check_takes_about(f, expected_dur): @slow -async def test_sleep(): - async def sleep_1(): +async def test_sleep() -> None: + async def sleep_1() -> None: await sleep_until(_core.current_time() + TARGET) await check_takes_about(sleep_1, TARGET) - async def sleep_2(): + async def sleep_2() -> None: await sleep(TARGET) await check_takes_about(sleep_2, TARGET) @@ -65,12 +65,12 @@ async def sleep_2(): @slow -async def test_move_on_after(): +async def test_move_on_after() -> None: with pytest.raises(ValueError): with move_on_after(-1): pass # pragma: no cover - async def 
sleep_3(): + async def sleep_3() -> None: with move_on_after(TARGET): await sleep(100) @@ -78,8 +78,8 @@ async def sleep_3(): @slow -async def test_fail(): - async def sleep_4(): +async def test_fail() -> None: + async def sleep_4() -> None: with fail_at(_core.current_time() + TARGET): await sleep(100) @@ -89,7 +89,7 @@ async def sleep_4(): with fail_at(_core.current_time() + 100): await sleep(0) - async def sleep_5(): + async def sleep_5() -> None: with fail_after(TARGET): await sleep(100) diff --git a/trio/tests/test_unix_pipes.py b/trio/tests/test_unix_pipes.py index 55dd4e3734..ff70ad7ffd 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/tests/test_unix_pipes.py @@ -3,6 +3,7 @@ import os import tempfile import sys +from typing import Tuple import pytest @@ -10,256 +11,252 @@ from .. import _core, move_on_after from ..testing import wait_all_tasks_blocked, check_one_way_stream -posix = os.name == "posix" -pytestmark = pytest.mark.skipif(not posix, reason="posix only") -if posix: - from .._unix_pipes import FdStream -else: - with pytest.raises(ImportError): - from .._unix_pipes import FdStream - - -# Have to use quoted types so import doesn't crash on windows -async def make_pipe() -> "Tuple[FdStream, FdStream]": - """Makes a new pair of pipes.""" - (r, w) = os.pipe() - return FdStream(w), FdStream(r) - - -async def make_clogged_pipe(): - s, r = await make_pipe() - try: - while True: - # We want to totally fill up the pipe buffer. - # This requires working around a weird feature that POSIX pipes - # have. - # If you do a write of <= PIPE_BUF bytes, then it's guaranteed - # to either complete entirely, or not at all. So if we tried to - # write PIPE_BUF bytes, and the buffer's free space is only - # PIPE_BUF/2, then the write will raise BlockingIOError... even - # though a smaller write could still succeed! To avoid this, - # make sure to write >PIPE_BUF bytes each time, which disables - # the special behavior. 
- # For details, search for PIPE_BUF here: - # http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html - - # for the getattr: - # https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3 - buf_size = getattr(select, "PIPE_BUF", 8192) - os.write(s.fileno(), b"x" * buf_size * 2) - except BlockingIOError: - pass - return s, r - - -async def test_send_pipe(): - r, w = os.pipe() - async with FdStream(w) as send: - assert send.fileno() == w - await send.send_all(b"123") - assert (os.read(r, 8)) == b"123" - - os.close(r) - - -async def test_receive_pipe(): - r, w = os.pipe() - async with FdStream(r) as recv: - assert (recv.fileno()) == r - os.write(w, b"123") - assert (await recv.receive_some(8)) == b"123" - - os.close(w) - - -async def test_pipes_combined(): - write, read = await make_pipe() - count = 2 ** 20 - - async def sender(): - big = bytearray(count) - await write.send_all(big) - - async def reader(): - await wait_all_tasks_blocked() - received = 0 - while received < count: - received += len(await read.receive_some(4096)) - - assert received == count - - async with _core.open_nursery() as n: - n.start_soon(sender) - n.start_soon(reader) - - await read.aclose() - await write.aclose() - - -async def test_pipe_errors(): - with pytest.raises(TypeError): - FdStream(None) - - r, w = os.pipe() - os.close(w) - async with FdStream(r) as s: - with pytest.raises(ValueError): - await s.receive_some(0) - - -async def test_del(): - w, r = await make_pipe() - f1, f2 = w.fileno(), r.fileno() - del w, r - gc_collect_harder() - - with pytest.raises(OSError) as excinfo: - os.close(f1) - assert excinfo.value.errno == errno.EBADF - - with pytest.raises(OSError) as excinfo: - os.close(f2) - assert excinfo.value.errno == errno.EBADF - +pytestmark = pytest.mark.skipif(sys.platform == "win32", reason="posix only") -async def test_async_with(): - w, r = await make_pipe() - async with w, r: - pass - - assert w.fileno() == -1 - assert r.fileno() == -1 - - 
with pytest.raises(OSError) as excinfo: - os.close(w.fileno()) - assert excinfo.value.errno == errno.EBADF - - with pytest.raises(OSError) as excinfo: - os.close(r.fileno()) - assert excinfo.value.errno == errno.EBADF - - -async def test_misdirected_aclose_regression(): - # https://github.com/python-trio/trio/issues/661#issuecomment-456582356 - w, r = await make_pipe() - old_r_fd = r.fileno() - - # Close the original objects - await w.aclose() - await r.aclose() - - # Do a little dance to get a new pipe whose receive handle matches the old - # receive handle. - r2_fd, w2_fd = os.pipe() - if r2_fd != old_r_fd: # pragma: no cover - os.dup2(r2_fd, old_r_fd) - os.close(r2_fd) - async with FdStream(old_r_fd) as r2: - assert r2.fileno() == old_r_fd +# mypy recognizes this. an assert would break the pytest skipif +if sys.platform == "win32": + with pytest.raises(AssertionError): + # Using sys instead of FdStream since sys is created before the assertion that + # terminates the import of _unix_pipes and this makes Mypy happier. The type + # warning can't be ignored since it is not present on all platforms and thus + # triggers a warning about being unneeded on other platforms. + from .._unix_pipes import sys +else: + from .._unix_pipes import FdStream - # And now set up a background task that's working on the new receive - # handle - async def expect_eof(): - assert await r2.receive_some(10) == b"" + # Have to use quoted types so import doesn't crash on windows + async def make_pipe() -> Tuple["FdStream", "FdStream"]: + """Makes a new pair of pipes.""" + (r, w) = os.pipe() + return FdStream(w), FdStream(r) - async with _core.open_nursery() as nursery: - nursery.start_soon(expect_eof) - await wait_all_tasks_blocked() - - # Here's the key test: does calling aclose() again on the *old* - # handle, cause the task blocked on the *new* handle to raise - # ClosedResourceError? 
- await r.aclose() + async def make_clogged_pipe(): + s, r = await make_pipe() + try: + while True: + # We want to totally fill up the pipe buffer. + # This requires working around a weird feature that POSIX pipes + # have. + # If you do a write of <= PIPE_BUF bytes, then it's guaranteed + # to either complete entirely, or not at all. So if we tried to + # write PIPE_BUF bytes, and the buffer's free space is only + # PIPE_BUF/2, then the write will raise BlockingIOError... even + # though a smaller write could still succeed! To avoid this, + # make sure to write >PIPE_BUF bytes each time, which disables + # the special behavior. + # For details, search for PIPE_BUF here: + # http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html + + # for the getattr: + # https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3 + buf_size = getattr(select, "PIPE_BUF", 8192) + os.write(s.fileno(), b"x" * buf_size * 2) + except BlockingIOError: + pass + return s, r + + async def test_send_pipe() -> None: + r, w = os.pipe() + async with FdStream(w) as send: + assert send.fileno() == w + await send.send_all(b"123") + assert (os.read(r, 8)) == b"123" + + os.close(r) + + async def test_receive_pipe() -> None: + r, w = os.pipe() + async with FdStream(r) as recv: + assert (recv.fileno()) == r + os.write(w, b"123") + assert (await recv.receive_some(8)) == b"123" + + os.close(w) + + async def test_pipes_combined() -> None: + write, read = await make_pipe() + count = 2 ** 20 + + async def sender() -> None: + big = bytearray(count) + await write.send_all(big) + + async def reader() -> None: await wait_all_tasks_blocked() + received = 0 + while received < count: + received += len(await read.receive_some(4096)) - # Guess we survived! Close the new write handle so that the task - # gets an EOF and can exit cleanly. 
- os.close(w2_fd) + assert received == count + async with _core.open_nursery() as n: + n.start_soon(sender) + n.start_soon(reader) -async def test_close_at_bad_time_for_receive_some(monkeypatch): - # We used to have race conditions where if one task was using the pipe, - # and another closed it at *just* the wrong moment, it would give an - # unexpected error instead of ClosedResourceError: - # https://github.com/python-trio/trio/issues/661 - # - # This tests what happens if the pipe gets closed in the moment *between* - # when receive_some wakes up, and when it tries to call os.read - async def expect_closedresourceerror(): - with pytest.raises(_core.ClosedResourceError): - await r.receive_some(10) + await read.aclose() + await write.aclose() - orig_wait_readable = _core._run.TheIOManager.wait_readable + async def test_pipe_errors() -> None: + with pytest.raises(TypeError): + FdStream(None) - async def patched_wait_readable(*args, **kwargs): - await orig_wait_readable(*args, **kwargs) + r, w = os.pipe() + os.close(w) + async with FdStream(r) as s: + with pytest.raises(ValueError): + await s.receive_some(0) + + async def test_del() -> None: + w, r = await make_pipe() + f1, f2 = w.fileno(), r.fileno() + del w, r + gc_collect_harder() + + with pytest.raises(OSError) as excinfo: + os.close(f1) + assert excinfo.value.errno == errno.EBADF + + with pytest.raises(OSError) as excinfo: + os.close(f2) + assert excinfo.value.errno == errno.EBADF + + async def test_async_with() -> None: + w, r = await make_pipe() + async with w, r: + pass + + assert w.fileno() == -1 + assert r.fileno() == -1 + + with pytest.raises(OSError) as excinfo: + os.close(w.fileno()) + assert excinfo.value.errno == errno.EBADF + + with pytest.raises(OSError) as excinfo: + os.close(r.fileno()) + assert excinfo.value.errno == errno.EBADF + + async def test_misdirected_aclose_regression() -> None: + # https://github.com/python-trio/trio/issues/661#issuecomment-456582356 + w, r = await make_pipe() + 
old_r_fd = r.fileno() + + # Close the original objects + await w.aclose() await r.aclose() - monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable) - s, r = await make_pipe() - async with s, r: - async with _core.open_nursery() as nursery: - nursery.start_soon(expect_closedresourceerror) - await wait_all_tasks_blocked() - # Trigger everything by waking up the receiver - await s.send_all(b"x") - - -async def test_close_at_bad_time_for_send_all(monkeypatch): - # We used to have race conditions where if one task was using the pipe, - # and another closed it at *just* the wrong moment, it would give an - # unexpected error instead of ClosedResourceError: - # https://github.com/python-trio/trio/issues/661 - # - # This tests what happens if the pipe gets closed in the moment *between* - # when send_all wakes up, and when it tries to call os.write - async def expect_closedresourceerror(): - with pytest.raises(_core.ClosedResourceError): - await s.send_all(b"x" * 100) - - orig_wait_writable = _core._run.TheIOManager.wait_writable - - async def patched_wait_writable(*args, **kwargs): - await orig_wait_writable(*args, **kwargs) - await s.aclose() - - monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable) - s, r = await make_clogged_pipe() - async with s, r: - async with _core.open_nursery() as nursery: - nursery.start_soon(expect_closedresourceerror) - await wait_all_tasks_blocked() - # Trigger everything by waking up the sender - await r.receive_some(10000) - - -# On FreeBSD, directories are readable, and we haven't found any other trick -# for making an unreadable fd, so there's no way to run this test. Fortunately -# the logic this is testing doesn't depend on the platform, so testing on -# other platforms is probably good enough. 
-@pytest.mark.skipif( - sys.platform.startswith("freebsd"), - reason="no way to make read() return a bizarro error on FreeBSD", -) -async def test_bizarro_OSError_from_receive(): - # Make sure that if the read syscall returns some bizarro error, then we - # get a BrokenResourceError. This is incredibly unlikely; there's almost - # no way to trigger a failure here intentionally (except for EBADF, but we - # exploit that to detect file closure, so it takes a different path). So - # we set up a strange scenario where the pipe fd somehow transmutes into a - # directory fd, causing os.read to raise IsADirectoryError (yes, that's a - # real built-in exception type). - s, r = await make_pipe() - async with s, r: - dir_fd = os.open("/", os.O_DIRECTORY, 0) - try: - os.dup2(dir_fd, r.fileno()) - with pytest.raises(_core.BrokenResourceError): + # Do a little dance to get a new pipe whose receive handle matches the old + # receive handle. + r2_fd, w2_fd = os.pipe() + if r2_fd != old_r_fd: # pragma: no cover + os.dup2(r2_fd, old_r_fd) + os.close(r2_fd) + async with FdStream(old_r_fd) as r2: + assert r2.fileno() == old_r_fd + + # And now set up a background task that's working on the new receive + # handle + async def expect_eof() -> None: + assert await r2.receive_some(10) == b"" + + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_eof) + await wait_all_tasks_blocked() + + # Here's the key test: does calling aclose() again on the *old* + # handle, cause the task blocked on the *new* handle to raise + # ClosedResourceError? + await r.aclose() + await wait_all_tasks_blocked() + + # Guess we survived! Close the new write handle so that the task + # gets an EOF and can exit cleanly. 
+ os.close(w2_fd) + + async def test_close_at_bad_time_for_receive_some(monkeypatch) -> None: + # We used to have race conditions where if one task was using the pipe, + # and another closed it at *just* the wrong moment, it would give an + # unexpected error instead of ClosedResourceError: + # https://github.com/python-trio/trio/issues/661 + # + # This tests what happens if the pipe gets closed in the moment *between* + # when receive_some wakes up, and when it tries to call os.read + async def expect_closedresourceerror() -> None: + with pytest.raises(_core.ClosedResourceError): await r.receive_some(10) - finally: - os.close(dir_fd) + orig_wait_readable = _core._run.TheIOManager.wait_readable + + async def patched_wait_readable(*args, **kwargs): + await orig_wait_readable(*args, **kwargs) + await r.aclose() -@skip_if_fbsd_pipes_broken -async def test_pipe_fully(): - await check_one_way_stream(make_pipe, make_clogged_pipe) + monkeypatch.setattr( + _core._run.TheIOManager, "wait_readable", patched_wait_readable + ) + s, r = await make_pipe() + async with s, r: + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_closedresourceerror) + await wait_all_tasks_blocked() + # Trigger everything by waking up the receiver + await s.send_all(b"x") + + async def test_close_at_bad_time_for_send_all(monkeypatch) -> None: + # We used to have race conditions where if one task was using the pipe, + # and another closed it at *just* the wrong moment, it would give an + # unexpected error instead of ClosedResourceError: + # https://github.com/python-trio/trio/issues/661 + # + # This tests what happens if the pipe gets closed in the moment *between* + # when send_all wakes up, and when it tries to call os.write + async def expect_closedresourceerror() -> None: + with pytest.raises(_core.ClosedResourceError): + await s.send_all(b"x" * 100) + + orig_wait_writable = _core._run.TheIOManager.wait_writable + + async def patched_wait_writable(*args, **kwargs): + await 
orig_wait_writable(*args, **kwargs) + await s.aclose() + + monkeypatch.setattr( + _core._run.TheIOManager, "wait_writable", patched_wait_writable + ) + s, r = await make_clogged_pipe() + async with s, r: + async with _core.open_nursery() as nursery: + nursery.start_soon(expect_closedresourceerror) + await wait_all_tasks_blocked() + # Trigger everything by waking up the sender + await r.receive_some(10000) + + # On FreeBSD, directories are readable, and we haven't found any other trick + # for making an unreadable fd, so there's no way to run this test. Fortunately + # the logic this is testing doesn't depend on the platform, so testing on + # other platforms is probably good enough. + @pytest.mark.skipif( + sys.platform.startswith("freebsd"), + reason="no way to make read() return a bizarro error on FreeBSD", + ) + async def test_bizarro_OSError_from_receive() -> None: + # Make sure that if the read syscall returns some bizarro error, then we + # get a BrokenResourceError. This is incredibly unlikely; there's almost + # no way to trigger a failure here intentionally (except for EBADF, but we + # exploit that to detect file closure, so it takes a different path). So + # we set up a strange scenario where the pipe fd somehow transmutes into a + # directory fd, causing os.read to raise IsADirectoryError (yes, that's a + # real built-in exception type). 
+ s, r = await make_pipe() + async with s, r: + dir_fd = os.open("/", os.O_DIRECTORY, 0) + try: + os.dup2(dir_fd, r.fileno()) + with pytest.raises(_core.BrokenResourceError): + await r.receive_some(10) + finally: + os.close(dir_fd) + + @skip_if_fbsd_pipes_broken + async def test_pipe_fully() -> None: + await check_one_way_stream(make_pipe, make_clogged_pipe) diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 2ea0a1e287..cc164989b3 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -1,4 +1,5 @@ import signal +from typing import Iterator, TypeVar import pytest import trio @@ -16,7 +17,10 @@ from ..testing import wait_all_tasks_blocked -def test_signal_raise(): +_T = TypeVar("_T") + + +def test_signal_raise() -> None: record = [] def handler(signum, _): @@ -30,7 +34,7 @@ def handler(signum, _): assert record == [signal.SIGFPE] -async def test_ConflictDetector(): +async def test_ConflictDetector() -> None: ul1 = ConflictDetector("ul1") ul2 = ConflictDetector("ul2") @@ -44,7 +48,7 @@ async def test_ConflictDetector(): pass # pragma: no cover assert "ul1" in str(excinfo.value) - async def wait_with_ul1(): + async def wait_with_ul1() -> None: with ul1: await wait_all_tasks_blocked() @@ -55,7 +59,7 @@ async def wait_with_ul1(): assert "ul1" in str(excinfo.value) -def test_module_metadata_is_fixed_up(): +def test_module_metadata_is_fixed_up() -> None: import trio import trio.testing @@ -79,10 +83,10 @@ def test_module_metadata_is_fixed_up(): assert trio.to_thread.run_sync.__qualname__ == "run_sync" -async def test_is_main_thread(): +async def test_is_main_thread() -> None: assert is_main_thread() - def not_main_thread(): + def not_main_thread() -> None: assert not is_main_thread() await trio.to_thread.run_sync(not_main_thread) @@ -90,13 +94,13 @@ def not_main_thread(): # @coroutine is deprecated since python 3.8, which is fine with us. 
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") -def test_coroutine_or_error(): +def test_coroutine_or_error() -> None: class Deferred: "Just kidding" with ignore_coroutine_never_awaited_warnings(): - async def f(): # pragma: no cover + async def f() -> None: # pragma: no cover pass with pytest.raises(TypeError) as excinfo: @@ -106,7 +110,7 @@ async def f(): # pragma: no cover import asyncio @asyncio.coroutine - def generator_based_coro(): # pragma: no cover + def generator_based_coro() -> Iterator[None]: # pragma: no cover yield from asyncio.sleep(1) with pytest.raises(TypeError) as excinfo: @@ -146,21 +150,21 @@ async def async_gen(arg): # pragma: no cover del excinfo -def test_generic_function(): +def test_generic_function() -> None: @generic_function - def test_func(arg): + def test_func(arg: _T) -> _T: """Look, a docstring!""" return arg - assert test_func is test_func[int] is test_func[int, str] - assert test_func(42) == test_func[int](42) == 42 + assert test_func is test_func[int] is test_func[int, str] # type: ignore[index] + assert test_func(42) == test_func[int](42) == 42 # type: ignore[index] assert test_func.__doc__ == "Look, a docstring!" 
assert test_func.__qualname__ == "test_generic_function..test_func" assert test_func.__name__ == "test_func" assert test_func.__module__ == __name__ -def test_final_metaclass(): +def test_final_metaclass() -> None: class FinalClass(metaclass=Final): pass @@ -170,7 +174,7 @@ class SubClass(FinalClass): pass -def test_no_public_constructor_metaclass(): +def test_no_public_constructor_metaclass() -> None: class SpecialClass(metaclass=NoPublicConstructor): pass diff --git a/trio/tests/test_wait_for_object.py b/trio/tests/test_wait_for_object.py index 38acfa802d..00a0b23015 100644 --- a/trio/tests/test_wait_for_object.py +++ b/trio/tests/test_wait_for_object.py @@ -19,7 +19,7 @@ ) -async def test_WaitForMultipleObjects_sync(): +async def test_WaitForMultipleObjects_sync() -> None: # This does a series of tests where we set/close the handle before # initiating the waiting for it. # @@ -73,7 +73,7 @@ async def test_WaitForMultipleObjects_sync(): @slow -async def test_WaitForMultipleObjects_sync_slow(): +async def test_WaitForMultipleObjects_sync_slow() -> None: # This does a series of test in which the main thread sync-waits for # handles, while we spawn a thread to set the handles after a short while. @@ -128,7 +128,7 @@ async def test_WaitForMultipleObjects_sync_slow(): print("test_WaitForMultipleObjects_sync_slow thread-set second OK") -async def test_WaitForSingleObject(): +async def test_WaitForSingleObject() -> None: # This does a series of test for setting/closing the handle before # initiating the wait. @@ -163,7 +163,7 @@ async def test_WaitForSingleObject(): @slow -async def test_WaitForSingleObject_slow(): +async def test_WaitForSingleObject_slow() -> None: # This does a series of test for setting the handle in another task, # and cancelling the wait task. 
diff --git a/trio/tests/test_windows_pipes.py b/trio/tests/test_windows_pipes.py index 361cd64ce2..6c543b1430 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/tests/test_windows_pipes.py @@ -3,6 +3,7 @@ import os import sys +from typing import Any, Tuple import pytest from .._core.tests.tutil import gc_collect_harder @@ -15,25 +16,25 @@ from asyncio.windows_utils import pipe else: pytestmark = pytest.mark.skip(reason="windows only") - pipe = None # type: Any - PipeSendStream = None # type: Any - PipeReceiveStream = None # type: Any + pipe: Any = None + PipeSendStream: Any = None + PipeReceiveStream: Any = None -async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]": +async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]: """Makes a new pair of pipes.""" (r, w) = pipe() return PipeSendStream(w), PipeReceiveStream(r) -async def test_pipe_typecheck(): +async def test_pipe_typecheck() -> None: with pytest.raises(TypeError): PipeSendStream(1.0) with pytest.raises(TypeError): PipeReceiveStream(None) -async def test_pipe_error_on_close(): +async def test_pipe_error_on_close() -> None: # Make sure we correctly handle a failure from kernel32.CloseHandle r, w = pipe() @@ -49,18 +50,18 @@ async def test_pipe_error_on_close(): await receive_stream.aclose() -async def test_pipes_combined(): +async def test_pipes_combined() -> None: write, read = await make_pipe() count = 2 ** 20 replicas = 3 - async def sender(): + async def sender() -> None: async with write: big = bytearray(count) for _ in range(replicas): await write.send_all(big) - async def reader(): + async def reader() -> None: async with read: await wait_all_tasks_blocked() total_received = 0 @@ -78,7 +79,7 @@ async def reader(): n.start_soon(reader) -async def test_async_with(): +async def test_async_with() -> None: w, r = await make_pipe() async with w, r: pass @@ -89,11 +90,11 @@ async def test_async_with(): await r.receive_some(10) -async def test_close_during_write(): +async def 
test_close_during_write() -> None: w, r = await make_pipe() async with _core.open_nursery() as nursery: - async def write_forever(): + async def write_forever() -> None: with pytest.raises(_core.ClosedResourceError) as excinfo: while True: await w.send_all(b"x" * 4096) @@ -104,7 +105,7 @@ async def write_forever(): await w.aclose() -async def test_pipe_fully(): +async def test_pipe_fully() -> None: # passing make_clogged_pipe tests wait_send_all_might_not_block, and we # can't implement that on Windows await check_one_way_stream(make_pipe, None) diff --git a/trio/tests/tools/test_gen_exports.py b/trio/tests/tools/test_gen_exports.py index e4e388c226..43cbc3a88a 100644 --- a/trio/tests/tools/test_gen_exports.py +++ b/trio/tests/tools/test_gen_exports.py @@ -1,5 +1,6 @@ import ast import astor +from pathlib import Path import pytest import os import sys @@ -32,12 +33,12 @@ async def not_public_async(self): ''' -def test_get_public_methods(): +def test_get_public_methods() -> None: methods = list(get_public_methods(ast.parse(SOURCE))) assert {m.name for m in methods} == {"public_func", "public_async_func"} -def test_create_pass_through_args(): +def test_create_pass_through_args() -> None: testcases = [ ("def f()", "()"), ("def f(one)", "(one)"), @@ -55,7 +56,7 @@ def test_create_pass_through_args(): assert create_passthrough_args(func_node) == expected -def test_process(tmp_path): +def test_process(tmp_path: Path) -> None: modpath = tmp_path / "_module.py" genpath = tmp_path / "_generated_module.py" modpath.write_text(SOURCE, encoding="utf-8") diff --git a/typing.sh b/typing.sh new file mode 100755 index 0000000000..fe156450fd --- /dev/null +++ b/typing.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +set -ex + +EXIT_STATUS=0 + +# Test if the generated code is still up to date +python ./trio/_tools/gen_exports.py --test \ + || EXIT_STATUS=$? 
+ +# Run mypy on all supported platforms +for PLATFORM in linux darwin win32; do + for VERSION in 3.6 3.7 3.8 3.9; do + mypy -p trio --platform $PLATFORM --python-version $VERSION || EXIT_STATUS=$? + done +done + +# Finally, leave a really clear warning of any issues and exit +if [ $EXIT_STATUS -ne 0 ]; then + cat <