diff --git a/docs/source/conf.py b/docs/source/conf.py index b6d5e63043..66aa8dea05 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -59,6 +59,7 @@ # these are not defined in https://docs.python.org/3/objects.inv ("py:class", "socket.AddressFamily"), ("py:class", "socket.SocketKind"), + ("py:class", "Buffer"), # collections.abc.Buffer, in 3.12 ] autodoc_inherit_docstrings = False default_role = "obj" diff --git a/pyproject.toml b/pyproject.toml index 6893927337..ae2e662992 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,11 +52,6 @@ disallow_untyped_calls = false module = [ # 2745 "trio/_ssl", -# 2756 -"trio/_highlevel_open_unix_stream", -"trio/_highlevel_serve_listeners", -"trio/_highlevel_ssl_helpers", -"trio/_highlevel_socket", # 2755 "trio/_core/_windows_cffi", "trio/_wait_for_object", diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index c2c3a3ca7c..c05b8f3fc8 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -1,9 +1,22 @@ +from __future__ import annotations + import os +from collections.abc import Generator from contextlib import contextmanager +from typing import Protocol, TypeVar import trio from trio.socket import SOCK_STREAM, socket + +class Closable(Protocol): + def close(self) -> None: + ... + + +CloseT = TypeVar("CloseT", bound=Closable) + + try: from trio.socket import AF_UNIX @@ -13,7 +26,7 @@ @contextmanager -def close_on_error(obj): +def close_on_error(obj: CloseT) -> Generator[CloseT, None, None]: try: yield obj except: @@ -21,7 +34,9 @@ def close_on_error(obj): raise -async def open_unix_socket(filename): +async def open_unix_socket( + filename: str | bytes | os.PathLike[str] | os.PathLike[bytes], +) -> trio.SocketStream: """Opens a connection to the specified `Unix domain socket `__. diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index 0585fa516f..d5c7a3bdad 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import errno import logging import os +from typing import Any, Awaitable, Callable, NoReturn, TypeVar import trio @@ -20,14 +23,23 @@ LOGGER = logging.getLogger("trio.serve_listeners") -async def _run_handler(stream, handler): +StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource) +ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) +Handler = Callable[[StreamT], Awaitable[object]] + + +async def _run_handler(stream: StreamT, handler: Handler[StreamT]) -> None: try: await handler(stream) finally: await trio.aclose_forcefully(stream) -async def _serve_one_listener(listener, handler_nursery, handler): +async def _serve_one_listener( + listener: trio.abc.Listener[StreamT], + handler_nursery: trio.Nursery, + handler: Handler[StreamT], +) -> NoReturn: async with listener: while True: try: @@ -48,9 +60,21 @@ async def _serve_one_listener(listener, handler_nursery, handler): handler_nursery.start_soon(_run_handler, stream, handler) -async def serve_listeners( - handler, listeners, *, handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED -): +# This cannot be typed correctly, we need generic typevar bounds / HKT to indicate the +# relationship between StreamT & ListenerT. +# https://github.com/python/typing/issues/1226 +# https://github.com/python/typing/issues/548 + + +# It does never return (since _serve_one_listener never completes), but type checkers can't +# understand nurseries. +async def serve_listeners( # type: ignore[misc] + handler: Handler[StreamT], + listeners: list[ListenerT], + *, + handler_nursery: trio.Nursery | None = None, + task_status: trio.TaskStatus[list[ListenerT]] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: r"""Listen for incoming connections on ``listeners``, and for each one start a task running ``handler(stream)``. diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index ce96153805..f8d01cd755 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -2,8 +2,9 @@ from __future__ import annotations import errno +from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload import trio @@ -12,6 +13,8 @@ from .abc import HalfCloseableStream, Listener if TYPE_CHECKING: + from typing_extensions import Buffer + from ._socket import _SocketType as SocketType # XX TODO: this number was picked arbitrarily. We should do experiments to @@ -29,7 +32,7 @@ @contextmanager -def _translate_socket_errors_to_stream_errors(): +def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]: try: yield except OSError as exc: @@ -97,7 +100,7 @@ def __init__(self, socket: SocketType): except OSError: pass - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: if self.socket.did_shutdown_SHUT_WR: raise trio.ClosedResourceError("can't send data after sending EOF") with self._send_conflict_detector: @@ -145,15 +148,47 @@ async def aclose(self) -> None: # __aenter__, __aexit__ inherited from HalfCloseableStream are OK - def setsockopt(self, level, option, value): + @overload + def setsockopt(self, level: int, option: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, level: int, option: int, value: None, length: int) -> None: + ... + + def setsockopt( + self, + level: int, + option: int, + value: int | Buffer | None, + length: int | None = None, + ) -> None: """Set an option on the underlying socket. See :meth:`socket.socket.setsockopt` for details. """ - return self.socket.setsockopt(level, option, value) - - def getsockopt(self, level, option, buffersize=0): + if length is None: + if value is None: + raise TypeError( + "invalid value for argument 'value', must not be None when specifying length" + ) + return self.socket.setsockopt(level, option, value) + if value is not None: + raise TypeError( + f"invalid value for argument 'value': {value!r}, must be None when specifying optlen" + ) + return self.socket.setsockopt(level, option, value, length) + + @overload + def getsockopt(self, level: int, option: int) -> int: + ... + + @overload + def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: + ... + + def getsockopt(self, level: int, option: int, buffersize: int = 0) -> int | bytes: """Check the current value of an option on the underlying socket. See :meth:`socket.socket.getsockopt` for details. @@ -311,7 +346,7 @@ def getsockopt(self, level, option, buffersize=0): ] # Not all errnos are defined on all platforms -_ignorable_accept_errnos = set() +_ignorable_accept_errnos: set[int] = set() for name in _ignorable_accept_errno_names: try: _ignorable_accept_errnos.add(getattr(errno, name)) diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index ad77a302f0..1647f373c2 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import ssl +from collections.abc import Awaitable, Callable +from typing import NoReturn import trio @@ -15,13 +19,13 @@ # So... let's punt on that for now. Hopefully we'll be getting a new Python # TLS API soon and can revisit this then. async def open_ssl_over_tcp_stream( - host, - port, + host: str | bytes, + port: int, *, - https_compatible=False, - ssl_context=None, - happy_eyeballs_delay=DEFAULT_DELAY, -): + https_compatible: bool = False, + ssl_context: ssl.SSLContext | None = None, + happy_eyeballs_delay: float | None = DEFAULT_DELAY, +) -> trio.SSLStream: """Make a TLS-encrypted Connection to the given host and port over TCP. This is a convenience wrapper that calls :func:`open_tcp_stream` and @@ -63,8 +67,13 @@ async def open_ssl_over_tcp_stream( async def open_ssl_over_tcp_listeners( - port, ssl_context, *, host=None, https_compatible=False, backlog=None -): + port: int, + ssl_context: ssl.SSLContext, + *, + host: str | bytes | None = None, + https_compatible: bool = False, + backlog: int | float | None = None, +) -> list[trio.SSLListener]: """Start listening for SSL/TLS-encrypted TCP connections to the given port. Args: @@ -86,16 +95,16 @@ async def open_ssl_over_tcp_listeners( async def serve_ssl_over_tcp( - handler, - port, - ssl_context, + handler: Callable[[trio.SSLStream], Awaitable[object]], + port: int, + ssl_context: ssl.SSLContext, *, - host=None, - https_compatible=False, - backlog=None, - handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED, -): + host: str | bytes | None = None, + https_compatible: bool = False, + backlog: int | float | None = None, + handler_nursery: trio.Nursery | None = None, + task_status: trio.TaskStatus[list[trio.SSLListener]] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. diff --git a/trio/_socket.py b/trio/_socket.py index b0ec1d480d..b6d5966397 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -599,7 +599,7 @@ def setsockopt( return self._sock.setsockopt(level, optname, value) if value is not None: raise TypeError( - "invalid value for argument 'value': {value!r}, must be None when specifying optlen" + f"invalid value for argument 'value': {value!r}, must be None when specifying optlen" ) # Note: PyPy may crash here due to setsockopt only supporting diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py index 14143affe2..1a987df3f3 100644 --- a/trio/_tests/test_highlevel_socket.py +++ b/trio/_tests/test_highlevel_socket.py @@ -11,6 +11,7 @@ check_half_closeable_stream, wait_all_tasks_blocked, ) +from .test_socket import setsockopt_tests async def test_SocketStream_basics(): @@ -50,6 +51,8 @@ async def test_SocketStream_basics(): b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) assert isinstance(b, bytes) + setsockopt_tests(s) + async def test_SocketStream_send_all(): BIG = 10000000 diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py index f01b4fde14..036098b8e5 100644 --- a/trio/_tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -363,21 +363,26 @@ async def test_SocketType_basics(): async def test_SocketType_setsockopt() -> None: sock = tsocket.socket() with sock as _: - # specifying optlen. Not supported on pypy, and I couldn't find - # valid calls on darwin or win32. - if hasattr(tsocket, "SO_BINDTODEVICE"): - sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0) + setsockopt_tests(sock) - # specifying value - sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) - # specifying both - with pytest.raises(TypeError, match="invalid value for argument 'value'"): - sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) # type: ignore[call-overload] +def setsockopt_tests(sock): + """Extract these out, to be reused for SocketStream also.""" + # specifying optlen. Not supported on pypy, and I couldn't find + # valid calls on darwin or win32. + if hasattr(tsocket, "SO_BINDTODEVICE"): + sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0) + + # specifying value + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) + + # specifying both + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) - # specifying neither - with pytest.raises(TypeError, match="invalid value for argument 'value'"): - sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) # type: ignore[call-overload] + # specifying neither + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) async def test_SocketType_dup(): diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index b61b28a428..e8c405d2eb 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9872611464968153, + "completenessScore": 0.9968152866242038, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 620, - "withUnknownType": 8 + "withKnownType": 626, + "withUnknownType": 2 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,14 +46,11 @@ ], "otherSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 662, - "withUnknownType": 19 + "withKnownType": 666, + "withUnknownType": 15 }, "packageName": "trio", "symbols": [ - "trio._highlevel_socket.SocketStream.getsockopt", - "trio._highlevel_socket.SocketStream.send_all", - "trio._highlevel_socket.SocketStream.setsockopt", "trio._ssl.SSLListener.__init__", "trio._ssl.SSLListener.accept", "trio._ssl.SSLListener.aclose", @@ -71,11 +68,6 @@ "trio.lowlevel.notify_closing", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", - "trio.open_ssl_over_tcp_listeners", - "trio.open_ssl_over_tcp_stream", - "trio.open_unix_socket", - "trio.serve_listeners", - "trio.serve_ssl_over_tcp", "trio.tests.TestsDeprecationWrapper" ] }