diff --git a/pyproject.toml b/pyproject.toml index a212393452..6893927337 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,12 +50,6 @@ disallow_untyped_calls = false # files not yet fully typed [[tool.mypy.overrides]] module = [ -# 2747 -"trio/testing/_network", -"trio/testing/_trio_test", -"trio/testing/_checkpoints", -"trio/testing/_check_streams", -"trio/testing/_memory_streams", # 2745 "trio/_ssl", # 2756 @@ -70,7 +64,6 @@ module = [ "trio/_core/_generated_io_windows", "trio/_core/_io_windows", - "trio/_signals", # internal diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 1ba88da85e..3a4751254f 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -2152,7 +2152,7 @@ def setup_runner( def run( - async_fn: Callable[..., RetT], + async_fn: Callable[..., Awaitable[RetT]], *args: object, clock: Clock | None = None, instruments: Sequence[Instrument] = (), diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index e1ac378c6a..4269f90bae 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,16 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Generic, TypeVar import attr import trio from trio._util import Final -if TYPE_CHECKING: - from .abc import SendStream, ReceiveStream, AsyncResource +from .abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream -from .abc import HalfCloseableStream +SendStreamT = TypeVar("SendStreamT", bound=SendStream) +ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream) async def aclose_forcefully(resource: AsyncResource) -> None: @@ -44,7 +44,11 @@ async def aclose_forcefully(resource: AsyncResource) -> None: @attr.s(eq=False, hash=False) -class StapledStream(HalfCloseableStream, metaclass=Final): +class StapledStream( + HalfCloseableStream, + Generic[SendStreamT, ReceiveStreamT], + metaclass=Final, +): """This class `staples `__ together two unidirectional streams to make single bidirectional stream. @@ -79,8 +83,8 @@ class StapledStream(HalfCloseableStream, metaclass=Final): """ - send_stream: SendStream = attr.ib() - receive_stream: ReceiveStream = attr.ib() + send_stream: SendStreamT = attr.ib() + receive_stream: ReceiveStreamT = attr.ib() async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Calls ``self.send_stream.send_all``.""" diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 7cf990fa53..978f7e6188 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -149,7 +149,7 @@ def __init__( self.stdout = stdout self.stderr = stderr - self.stdio: StapledStream | None = None + self.stdio: StapledStream[SendStream, ReceiveStream] | None = None if self.stdin is not None and self.stdout is not None: self.stdio = StapledStream(self.stdin, self.stdout) diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index c5e9c4dc66..b61b28a428 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,16 +7,16 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.964968152866242, + "completenessScore": 0.9872611464968153, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 606, - "withUnknownType": 22 + "withKnownType": 620, + "withUnknownType": 8 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, "missingDefaultParamCount": 0, - "missingFunctionDocStringCount": 4, + "missingFunctionDocStringCount": 3, "moduleName": "trio", "modules": [ { @@ -46,8 +46,8 @@ ], "otherSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 642, - "withUnknownType": 39 + "withKnownType": 662, + "withUnknownType": 19 }, "packageName": "trio", "symbols": [ @@ -76,36 +76,6 @@ "trio.open_unix_socket", "trio.serve_listeners", "trio.serve_ssl_over_tcp", - "trio.testing._memory_streams.MemoryReceiveStream.__init__", - "trio.testing._memory_streams.MemoryReceiveStream.aclose", - "trio.testing._memory_streams.MemoryReceiveStream.close", - "trio.testing._memory_streams.MemoryReceiveStream.close_hook", - "trio.testing._memory_streams.MemoryReceiveStream.put_data", - "trio.testing._memory_streams.MemoryReceiveStream.put_eof", - "trio.testing._memory_streams.MemoryReceiveStream.receive_some", - "trio.testing._memory_streams.MemoryReceiveStream.receive_some_hook", - "trio.testing._memory_streams.MemorySendStream.__init__", - "trio.testing._memory_streams.MemorySendStream.aclose", - "trio.testing._memory_streams.MemorySendStream.close", - "trio.testing._memory_streams.MemorySendStream.close_hook", - "trio.testing._memory_streams.MemorySendStream.get_data", - "trio.testing._memory_streams.MemorySendStream.get_data_nowait", - "trio.testing._memory_streams.MemorySendStream.send_all", - "trio.testing._memory_streams.MemorySendStream.send_all_hook", - "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block", - "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block_hook", - "trio.testing.assert_checkpoints", - "trio.testing.assert_no_checkpoints", - "trio.testing.check_half_closeable_stream", - "trio.testing.check_one_way_stream", - "trio.testing.check_two_way_stream", - "trio.testing.lockstep_stream_one_way_pair", - "trio.testing.lockstep_stream_pair", - "trio.testing.memory_stream_one_way_pair", - "trio.testing.memory_stream_pair", - "trio.testing.memory_stream_pump", - "trio.testing.open_stream_to_socket_listener", - "trio.testing.trio_test", "trio.tests.TestsDeprecationWrapper" ] } diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 401b8ef0c2..33947ccc55 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -2,24 +2,33 @@ from __future__ import annotations import random +from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable, Callable, Generic, Tuple, TypeVar -from .. import _core -from .._abc import HalfCloseableStream, ReceiveStream, SendStream, Stream +from .. import CancelScope, _core +from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream from .._highlevel_generic import aclose_forcefully from ._checkpoints import assert_checkpoints if TYPE_CHECKING: from types import TracebackType + from typing_extensions import ParamSpec, TypeAlias -class _ForceCloseBoth: - def __init__(self, both): - self._both = list(both) + ArgsT = ParamSpec("ArgsT") - async def __aenter__(self): - return self._both +Res1 = TypeVar("Res1", bound=AsyncResource) +Res2 = TypeVar("Res2", bound=AsyncResource) +StreamMaker: TypeAlias = Callable[[], Awaitable[Tuple[Res1, Res2]]] + + +class _ForceCloseBoth(Generic[Res1, Res2]): + def __init__(self, both: tuple[Res1, Res2]) -> None: + self._first, self._second = both + + async def __aenter__(self) -> tuple[Res1, Res2]: + return self._first, self._second async def __aexit__( self, @@ -28,13 +37,13 @@ async def __aexit__( traceback: TracebackType | None, ) -> None: try: - await aclose_forcefully(self._both[0]) + await aclose_forcefully(self._first) finally: - await aclose_forcefully(self._both[1]) + await aclose_forcefully(self._second) @contextmanager -def _assert_raises(exc): +def _assert_raises(exc: type[BaseException]) -> Generator[None, None, None]: __tracebackhide__ = True try: yield @@ -44,7 +53,10 @@ def _assert_raises(exc): raise AssertionError(f"expected exception: {exc}") -async def check_one_way_stream(stream_maker, clogged_stream_maker): +async def check_one_way_stream( + stream_maker: StreamMaker[SendStream, ReceiveStream], + clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None, +) -> None: """Perform a number of generic tests on a custom one-way stream implementation. @@ -67,18 +79,18 @@ async def check_one_way_stream(stream_maker, clogged_stream_maker): assert isinstance(s, SendStream) assert isinstance(r, ReceiveStream) - async def do_send_all(data): - with assert_checkpoints(): - assert await s.send_all(data) is None + async def do_send_all(data: bytes | bytearray | memoryview) -> None: + with assert_checkpoints(): # We're testing that it doesn't return anything. + assert await s.send_all(data) is None # type: ignore[func-returns-value] - async def do_receive_some(*args): + async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray: with assert_checkpoints(): - return await r.receive_some(*args) + return await r.receive_some(max_bytes) - async def checked_receive_1(expected): + async def checked_receive_1(expected: bytes) -> None: assert await do_receive_some(1) == expected - async def do_aclose(resource): + async def do_aclose(resource: AsyncResource) -> None: with assert_checkpoints(): await resource.aclose() @@ -87,7 +99,7 @@ async def do_aclose(resource): nursery.start_soon(do_send_all, b"x") nursery.start_soon(checked_receive_1, b"x") - async def send_empty_then_y(): + async def send_empty_then_y() -> None: # Streams should tolerate sending b"" without giving it any # special meaning. await do_send_all(b"") @@ -114,7 +126,7 @@ async def send_empty_then_y(): with _assert_raises(ValueError): await r.receive_some(0) with _assert_raises(TypeError): - await r.receive_some(1.5) + await r.receive_some(1.5) # type: ignore[arg-type] # it can also be missing or None async with _core.open_nursery() as nursery: nursery.start_soon(do_send_all, b"x") @@ -133,7 +145,9 @@ async def send_empty_then_y(): # for send_all to wait until receive_some is called to run, though; a # stream doesn't *have* to have any internal buffering. That's why we # start a concurrent receive_some call, then cancel it.) - async def simple_check_wait_send_all_might_not_block(scope): + async def simple_check_wait_send_all_might_not_block( + scope: CancelScope, + ) -> None: with assert_checkpoints(): await s.wait_send_all_might_not_block() scope.cancel() @@ -146,7 +160,7 @@ async def simple_check_wait_send_all_might_not_block(scope): # closing the r side leads to BrokenResourceError on the s side # (eventually) - async def expect_broken_stream_on_send(): + async def expect_broken_stream_on_send() -> None: with _assert_raises(_core.BrokenResourceError): while True: await do_send_all(b"x" * 100) @@ -189,11 +203,11 @@ async def expect_broken_stream_on_send(): async with _ForceCloseBoth(await stream_maker()) as (s, r): # if send-then-graceful-close, receiver gets data then b"" - async def send_then_close(): + async def send_then_close() -> None: await do_send_all(b"y") await do_aclose(s) - async def receive_send_then_close(): + async def receive_send_then_close() -> None: # We want to make sure that if the sender closes the stream before # we read anything, then we still get all the data. But some # streams might block on the do_send_all call. So we let the @@ -258,9 +272,13 @@ async def receive_send_then_close(): # https://github.com/python-trio/trio/issues/77 async with _ForceCloseBoth(await stream_maker()) as (s, r): - async def expect_cancelled(afn, *args): + async def expect_cancelled( + afn: Callable[ArgsT, Awaitable[object]], + *args: ArgsT.args, + **kwargs: ArgsT.kwargs, + ) -> None: with _assert_raises(_core.Cancelled): - await afn(*args) + await afn(*args, **kwargs) with _core.CancelScope() as scope: scope.cancel() @@ -288,16 +306,16 @@ async def receive_expecting_closed(): # check wait_send_all_might_not_block, if we can if clogged_stream_maker is not None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): - record = [] + record: list[str] = [] - async def waiter(cancel_scope): + async def waiter(cancel_scope: CancelScope) -> None: record.append("waiter sleeping") with assert_checkpoints(): await s.wait_send_all_might_not_block() record.append("waiter wokeup") cancel_scope.cancel() - async def receiver(): + async def receiver() -> None: # give wait_send_all_might_not_block a chance to block await _core.wait_all_tasks_blocked() record.append("receiver starting") @@ -343,14 +361,14 @@ async def receiver(): # with or without an exception async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): - async def sender(): + async def sender() -> None: try: with assert_checkpoints(): await s.wait_send_all_might_not_block() except _core.BrokenResourceError: # pragma: no cover pass - async def receiver(): + async def receiver() -> None: await _core.wait_all_tasks_blocked() await aclose_forcefully(r) @@ -369,7 +387,7 @@ async def receiver(): # Check that if a task is blocked in a send-side method, then closing # the send stream causes it to wake up. - async def close_soon(s): + async def close_soon(s: SendStream) -> None: await _core.wait_all_tasks_blocked() await aclose_forcefully(s) @@ -386,7 +404,10 @@ async def close_soon(s): await s.wait_send_all_might_not_block() -async def check_two_way_stream(stream_maker, clogged_stream_maker): +async def check_two_way_stream( + stream_maker: StreamMaker[Stream, Stream], + clogged_stream_maker: StreamMaker[Stream, Stream] | None, +) -> None: """Perform a number of generic tests on a custom two-way stream implementation. @@ -401,13 +422,15 @@ async def check_two_way_stream(stream_maker, clogged_stream_maker): """ await check_one_way_stream(stream_maker, clogged_stream_maker) - async def flipped_stream_maker(): - return reversed(await stream_maker()) + async def flipped_stream_maker() -> tuple[Stream, Stream]: + return (await stream_maker())[::-1] + + flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None if clogged_stream_maker is not None: - async def flipped_clogged_stream_maker(): - return reversed(await clogged_stream_maker()) + async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]: + return (await clogged_stream_maker())[::-1] else: flipped_clogged_stream_maker = None @@ -425,7 +448,9 @@ async def flipped_clogged_stream_maker(): i = r.getrandbits(8 * DUPLEX_TEST_SIZE) test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little") - async def sender(s, data, seed): + async def sender( + s: Stream, data: bytes | bytearray | memoryview, seed: int + ) -> None: r = random.Random(seed) m = memoryview(data) while m: @@ -433,7 +458,7 @@ async def sender(s, data, seed): await s.send_all(m[:chunk_size]) m = m[chunk_size:] - async def receiver(s, data, seed): + async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None: r = random.Random(seed) got = bytearray() while len(got) < len(data): @@ -448,7 +473,7 @@ async def receiver(s, data, seed): nursery.start_soon(receiver, s1, test_data[::-1], 2) nursery.start_soon(receiver, s2, test_data, 3) - async def expect_receive_some_empty(): + async def expect_receive_some_empty() -> None: assert await s2.receive_some(10) == b"" await s2.aclose() @@ -457,7 +482,10 @@ async def expect_receive_some_empty(): nursery.start_soon(s1.aclose) -async def check_half_closeable_stream(stream_maker, clogged_stream_maker): +async def check_half_closeable_stream( + stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream], + clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None, +) -> None: """Perform a number of generic tests on a custom half-closeable stream implementation. @@ -476,12 +504,12 @@ async def check_half_closeable_stream(stream_maker, clogged_stream_maker): assert isinstance(s1, HalfCloseableStream) assert isinstance(s2, HalfCloseableStream) - async def send_x_then_eof(s): + async def send_x_then_eof(s: HalfCloseableStream) -> None: await s.send_all(b"x") with assert_checkpoints(): await s.send_eof() - async def expect_x_then_eof(r): + async def expect_x_then_eof(r: HalfCloseableStream) -> None: await _core.wait_all_tasks_blocked() assert await r.receive_some(10) == b"x" assert await r.receive_some(10) == b"" diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 5804295300..4a4047813b 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -1,10 +1,14 @@ -from contextlib import contextmanager +from __future__ import annotations + +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager from .. import _core @contextmanager -def _assert_yields_or_not(expected): +def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]: + """Check if checkpoints are executed in a block of code.""" __tracebackhide__ = True task = _core.current_task() orig_cancel = task._cancel_points @@ -22,7 +26,7 @@ def _assert_yields_or_not(expected): raise AssertionError("assert_no_checkpoints block yielded!") -def assert_checkpoints(): +def assert_checkpoints() -> AbstractContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block either exits with an exception or executes at least one :ref:`checkpoint `. @@ -42,7 +46,7 @@ def assert_checkpoints(): return _assert_yields_or_not(True) -def assert_no_checkpoints(): +def assert_no_checkpoints() -> AbstractContextManager[None]: """Use as a context manager to check that the code inside the ``with`` block does not execute any :ref:`checkpoints `. diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index 38e8e54de8..fc23fae842 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -1,16 +1,30 @@ +from __future__ import annotations + import operator +from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar from .. import _core, _util from .._highlevel_generic import StapledStream from ..abc import ReceiveStream, SendStream +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + +AsyncHook: TypeAlias = Callable[[], Awaitable[object]] +# Would be nice to exclude awaitable here, but currently not possible. +SyncHook: TypeAlias = Callable[[], object] +SendStreamT = TypeVar("SendStreamT", bound=SendStream) +ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream) + + ################################################################ # In-memory streams - Unbounded buffer version ################################################################ class _UnboundedByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._closed = False self._lot = _core.ParkingLot() @@ -22,28 +36,28 @@ def __init__(self): # channel: so after close(), calling put() raises ClosedResourceError, and # calling the get() variants drains the buffer and then returns an empty # bytearray. - def close(self): + def close(self) -> None: self._closed = True self._lot.unpark_all() - def close_and_wipe(self): + def close_and_wipe(self) -> None: self._data = bytearray() self.close() - def put(self, data): + def put(self, data: bytes | bytearray | memoryview) -> None: if self._closed: raise _core.ClosedResourceError("virtual connection closed") self._data += data self._lot.unpark_all() - def _check_max_bytes(self, max_bytes): + def _check_max_bytes(self, max_bytes: int | None) -> None: if max_bytes is None: return max_bytes = operator.index(max_bytes) if max_bytes < 1: raise ValueError("max_bytes must be >= 1") - def _get_impl(self, max_bytes): + def _get_impl(self, max_bytes: int | None) -> bytearray: assert self._closed or self._data if max_bytes is None: max_bytes = len(self._data) @@ -55,14 +69,14 @@ def _get_impl(self, max_bytes): else: return bytearray() - def get_nowait(self, max_bytes=None): + def get_nowait(self, max_bytes: int | None = None) -> bytearray: with self._fetch_lock: self._check_max_bytes(max_bytes) if not self._closed and not self._data: raise _core.WouldBlock return self._get_impl(max_bytes) - async def get(self, max_bytes=None): + async def get(self, max_bytes: int | None = None) -> bytearray: with self._fetch_lock: self._check_max_bytes(max_bytes) if not self._closed and not self._data: @@ -95,9 +109,9 @@ class MemorySendStream(SendStream, metaclass=_util.Final): def __init__( self, - send_all_hook=None, - wait_send_all_might_not_block_hook=None, - close_hook=None, + send_all_hook: AsyncHook | None = None, + wait_send_all_might_not_block_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -107,7 +121,7 @@ def __init__( self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook self.close_hook = close_hook - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Places the given data into the object's internal buffer, and then calls the :attr:`send_all_hook` (if any). @@ -121,12 +135,12 @@ async def send_all(self, data): if self.send_all_hook is not None: await self.send_all_hook() - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and then returns immediately. """ - # Execute two checkpoints so we have more of a chance to detect + # Execute two checkpoints so that we have more of a chance to detect # buggy user code that calls this twice at the same time. with self._conflict_detector: await _core.checkpoint() @@ -136,7 +150,7 @@ async def wait_send_all_might_not_block(self): if self.wait_send_all_might_not_block_hook is not None: await self.wait_send_all_might_not_block_hook() - def close(self): + def close(self) -> None: """Marks this stream as closed, and then calls the :attr:`close_hook` (if any). @@ -153,12 +167,12 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() - async def get_data(self, max_bytes=None): + async def get_data(self, max_bytes: int | None = None) -> bytearray: """Retrieves data from the internal buffer, blocking if necessary. Args: @@ -174,7 +188,7 @@ async def get_data(self, max_bytes=None): """ return await self._outgoing.get(max_bytes) - def get_data_nowait(self, max_bytes=None): + def get_data_nowait(self, max_bytes: int | None = None) -> bytearray: """Retrieves data from the internal buffer, but doesn't block. See :meth:`get_data` for details. @@ -203,7 +217,11 @@ class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final): """ - def __init__(self, receive_some_hook=None, close_hook=None): + def __init__( + self, + receive_some_hook: AsyncHook | None = None, + close_hook: SyncHook | None = None, + ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" ) @@ -212,7 +230,7 @@ def __init__(self, receive_some_hook=None, close_hook=None): self.receive_some_hook = receive_some_hook self.close_hook = close_hook - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytearray: """Calls the :attr:`receive_some_hook` (if any), and then retrieves data from the internal buffer, blocking if necessary. @@ -235,7 +253,7 @@ async def receive_some(self, max_bytes=None): raise _core.ClosedResourceError return data - def close(self): + def close(self) -> None: """Discards any pending data from the internal buffer, and marks this stream as closed. @@ -245,21 +263,26 @@ def close(self): if self.close_hook is not None: self.close_hook() - async def aclose(self): + async def aclose(self) -> None: """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() - def put_data(self, data): + def put_data(self, data: bytes | bytearray | memoryview) -> None: """Appends the given data to the internal buffer.""" self._incoming.put(data) - def put_eof(self): + def put_eof(self) -> None: """Adds an end-of-file marker to the internal buffer.""" self._incoming.close() -def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=None): +def memory_stream_pump( + memory_send_stream: MemorySendStream, + memory_receive_stream: MemoryReceiveStream, + *, + max_bytes: int | None = None, +) -> bool: """Take data out of the given :class:`MemorySendStream`'s internal buffer, and put it into the given :class:`MemoryReceiveStream`'s internal buffer. @@ -292,7 +315,7 @@ def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=N return True -def memory_stream_one_way_pair(): +def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream]: """Create a connected, pure-Python, unidirectional stream with infinite buffering and flexible configuration options. @@ -319,10 +342,10 @@ def memory_stream_one_way_pair(): send_stream = MemorySendStream() recv_stream = MemoryReceiveStream() - def pump_from_send_stream_to_recv_stream(): + def pump_from_send_stream_to_recv_stream() -> None: memory_stream_pump(send_stream, recv_stream) - async def async_pump_from_send_stream_to_recv_stream(): + async def async_pump_from_send_stream_to_recv_stream() -> None: pump_from_send_stream_to_recv_stream() send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream @@ -330,7 +353,12 @@ async def async_pump_from_send_stream_to_recv_stream(): return send_stream, recv_stream -def _make_stapled_pair(one_way_pair): +def _make_stapled_pair( + one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]] +) -> tuple[ + StapledStream[SendStreamT, ReceiveStreamT], + StapledStream[SendStreamT, ReceiveStreamT], +]: pipe1_send, pipe1_recv = one_way_pair() pipe2_send, pipe2_recv = one_way_pair() stream1 = StapledStream(pipe1_send, pipe2_recv) @@ -338,7 +366,12 @@ def _make_stapled_pair(one_way_pair): return stream1, stream2 -def memory_stream_pair(): +def memory_stream_pair() -> ( + tuple[ + StapledStream[MemorySendStream, MemoryReceiveStream], + StapledStream[MemorySendStream, MemoryReceiveStream], + ] +): """Create a connected, pure-Python, bidirectional stream with infinite buffering and flexible configuration options. @@ -421,7 +454,7 @@ async def receiver(): class _LockstepByteQueue: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._sender_closed = False self._receiver_closed = False @@ -434,12 +467,12 @@ def __init__(self): "another task is already receiving" ) - def _something_happened(self): + def _something_happened(self) -> None: self._waiters.unpark_all() # Always wakes up when one side is closed, because everyone always reacts # to that. - async def _wait_for(self, fn): + async def _wait_for(self, fn: Callable[[], bool]) -> None: while True: if fn(): break @@ -448,15 +481,15 @@ async def _wait_for(self, fn): await self._waiters.park() await _core.checkpoint() - def close_sender(self): + def close_sender(self) -> None: self._sender_closed = True self._something_happened() - def close_receiver(self): + def close_receiver(self) -> None: self._receiver_closed = True self._something_happened() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: with self._send_conflict_detector: if self._sender_closed: raise _core.ClosedResourceError @@ -465,13 +498,13 @@ async def send_all(self, data): assert not self._data self._data += data self._something_happened() - await self._wait_for(lambda: not self._data) + await self._wait_for(lambda: self._data == b"") if self._sender_closed: raise _core.ClosedResourceError if self._data and self._receiver_closed: raise _core.BrokenResourceError - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self._sender_closed: raise _core.ClosedResourceError @@ -482,7 +515,7 @@ async def wait_send_all_might_not_block(self): if self._sender_closed: raise _core.ClosedResourceError - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: with self._receive_conflict_detector: # Argument validation if max_bytes is not None: @@ -496,7 +529,7 @@ async def receive_some(self, max_bytes=None): self._receiver_waiting = True self._something_happened() try: - await self._wait_for(lambda: self._data) + await self._wait_for(lambda: self._data != b"") finally: self._receiver_waiting = False if self._receiver_closed: @@ -515,39 +548,39 @@ async def receive_some(self, max_bytes=None): class _LockstepSendStream(SendStream): - def __init__(self, lbq): + def __init__(self, lbq: _LockstepByteQueue): self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_sender() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: await self._lbq.send_all(data) - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: await self._lbq.wait_send_all_might_not_block() class _LockstepReceiveStream(ReceiveStream): - def __init__(self, lbq): + def __init__(self, lbq: _LockstepByteQueue): self._lbq = lbq - def close(self): + def close(self) -> None: self._lbq.close_receiver() - async def aclose(self): + async def aclose(self) -> None: self.close() await _core.checkpoint() - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: return await self._lbq.receive_some(max_bytes) -def lockstep_stream_one_way_pair(): +def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]: """Create a connected, pure Python, unidirectional stream where data flows in lockstep. @@ -574,7 +607,12 @@ def lockstep_stream_one_way_pair(): return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq) -def lockstep_stream_pair(): +def lockstep_stream_pair() -> ( + tuple[ + StapledStream[SendStream, ReceiveStream], + StapledStream[SendStream, ReceiveStream], + ] +): """Create a connected, pure-Python, bidirectional stream where data flows in lockstep. diff --git a/trio/testing/_network.py b/trio/testing/_network.py index 615ce2effb..fddbbf0fdc 100644 --- a/trio/testing/_network.py +++ b/trio/testing/_network.py @@ -1,8 +1,10 @@ from .. import socket as tsocket -from .._highlevel_socket import SocketStream +from .._highlevel_socket import SocketListener, SocketStream -async def open_stream_to_socket_listener(socket_listener): +async def open_stream_to_socket_listener( + socket_listener: SocketListener, +) -> SocketStream: """Connect to the given :class:`~trio.SocketListener`. This is particularly useful in tests when you want to let a server pick diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index b4ef69ef09..5619352846 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -1,20 +1,36 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar from .. import _core from ..abc import Clock, Instrument +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + ArgsT = ParamSpec("ArgsT") + + +RetT = TypeVar("RetT") + + +def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]: + """Converts an async test function to be synchronous, running via Trio. + + Usage:: + + @trio_test + async def test_whatever(): + await ... + + If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or + :class:`~trio.abc.Instrument` ABCs, then those are passed to :meth:`trio.run()`. + """ -# Use: -# -# @trio_test -# async def test_whatever(): -# await ... -# -# Also: if a pytest fixture is passed in that subclasses the Clock abc, then -# that clock is passed to trio.run(). -def trio_test(fn): @wraps(fn) - def wrapper(**kwargs): + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: __tracebackhide__ = True clocks = [c for c in kwargs.values() if isinstance(c, Clock)] if not clocks: @@ -24,6 +40,8 @@ def wrapper(**kwargs): else: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - return _core.run(partial(fn, **kwargs), clock=clock, instruments=instruments) + return _core.run( + partial(fn, *args, **kwargs), clock=clock, instruments=instruments + ) return wrapper