From fbae13dd98f64a369385c922b5367fff0156ad41 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 23 Aug 2023 18:53:22 +0200 Subject: [PATCH 01/20] wip --- docs/source/conf.py | 2 - docs/source/reference-io.rst | 3 +- pyproject.toml | 5 +- trio/_abc.py | 4 +- trio/_dtls.py | 14 +- trio/_highlevel_open_tcp_stream.py | 32 ++- trio/_highlevel_socket.py | 2 +- trio/_socket.py | 122 +++++++++- .../test_highlevel_open_tcp_listeners.py | 19 +- trio/_tests/test_highlevel_open_tcp_stream.py | 214 ++++++++++++------ trio/_tests/test_socket.py | 135 ++++++----- trio/socket.py | 1 - trio/testing/_fake_net.py | 23 +- 13 files changed, 407 insertions(+), 169 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 66aa8dea05..ee23ce587f 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,8 +54,6 @@ ("py:class", "sync function"), # why aren't these found in stdlib? ("py:class", "types.FrameType"), - # TODO: temporary type - ("py:class", "_SocketType"), # these are not defined in https://docs.python.org/3/objects.inv ("py:class", "socket.AddressFamily"), ("py:class", "socket.SocketKind"), diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index e270033b46..61bbef78c2 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -506,7 +506,8 @@ Socket objects The internal SocketType ~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: _SocketType +.. + .. autoclass:: _SocketType .. TODO: adding `:members:` here gives error due to overload+_wraps on `sendto` TODO: rewrite ... all of the above when fixing _SocketType vs SocketType diff --git a/pyproject.toml b/pyproject.toml index 24be2d07bf..64e15a4c71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ module = [ "trio/_core/_tests/test_tutil", "trio/_core/_tests/test_unbounded_queue", "trio/_core/_tests/tutil", +"trio/_tests/check_type_completeness", "trio/_tests/pytest_plugin", "trio/_tests/test_abc", "trio/_tests/test_channel", @@ -95,16 +96,14 @@ module = [ "trio/_tests/test_exports", "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", +"trio/_tests/test_highlevel_socket", "trio/_tests/test_highlevel_open_tcp_listeners", -"trio/_tests/test_highlevel_open_tcp_stream", "trio/_tests/test_highlevel_open_unix_stream", "trio/_tests/test_highlevel_serve_listeners", -"trio/_tests/test_highlevel_socket", "trio/_tests/test_highlevel_ssl_helpers", "trio/_tests/test_path", "trio/_tests/test_scheduler_determinism", "trio/_tests/test_signals", -"trio/_tests/test_socket", "trio/_tests/test_ssl", "trio/_tests/test_subprocess", "trio/_tests/test_sync", diff --git a/trio/_abc.py b/trio/_abc.py index 746360c8f8..6a99ea0842 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -12,7 +12,7 @@ from typing_extensions import Self # both of these introduce circular imports if outside a TYPE_CHECKING guard - from ._socket import _SocketType + from ._socket import SocketType from .lowlevel import Task @@ -214,7 +214,7 @@ def socket( family: socket.AddressFamily | int | None = None, type: socket.SocketKind | int | None = None, proto: int | None = None, - ) -> _SocketType: + ) -> SocketType: """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, diff --git a/trio/_dtls.py b/trio/_dtls.py index 08b7672a2f..cc940cf206 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -42,26 +42,26 @@ from OpenSSL.SSL import Context from typing_extensions import Self, TypeAlias - from trio.socket import Address, _SocketType + from trio.socket import Address, SocketType MAX_UDP_PACKET_SIZE = 65527 -def packet_header_overhead(sock: _SocketType) -> int: +def packet_header_overhead(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 28 else: return 48 -def worst_case_mtu(sock: _SocketType) -> int: +def worst_case_mtu(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: return 1280 - packet_header_overhead(sock) -def best_guess_mtu(sock: _SocketType) -> int: +def best_guess_mtu(sock: SocketType) -> int: return 1500 - packet_header_overhead(sock) @@ -738,7 +738,7 @@ async def handle_client_hello_untrusted( async def dtls_receive_loop( - endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType + endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType ) -> None: try: while True: @@ -1177,7 +1177,7 @@ class DTLSEndpoint(metaclass=Final): """ - def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): + def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL @@ -1188,7 +1188,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") self._initialized = True - self.socket: _SocketType = socket + self.socket: SocketType = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 0c4e8a4a8d..322ae4006e 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -8,7 +8,7 @@ import trio from trio._core._multierror import MultiError -from trio.socket import SOCK_STREAM, Address, _SocketType, getaddrinfo, socket +from trio.socket import SOCK_STREAM, Address, SocketType, getaddrinfo, socket if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -114,8 +114,8 @@ @contextmanager -def close_all() -> Generator[set[_SocketType], None, None]: - sockets_to_close: set[_SocketType] = set() +def close_all() -> Generator[set[SocketType], None, None]: + sockets_to_close: set[SocketType] = set() try: yield sockets_to_close finally: @@ -131,6 +131,7 @@ def close_all() -> Generator[set[_SocketType], None, None]: raise MultiError(errs) +# workaround for list being invariant def reorder_for_rfc_6555_section_5_4( targets: list[ tuple[ @@ -140,6 +141,22 @@ def reorder_for_rfc_6555_section_5_4( str, tuple[str, int] | tuple[str, int, int, int], ] + ] | list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] + ] + ] | list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int, int, int], + ] ] ) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address @@ -155,11 +172,12 @@ def reorder_for_rfc_6555_section_5_4( # Found the first entry with a different address family; move it # so that it becomes the second item on the list. if i != 1: - targets.insert(1, targets.pop(i)) + # invariant workaround in arguments leads to type issues here + targets.insert(1, targets.pop(i)) # type: ignore[arg-type] break -def format_host_port(host: str | bytes, port: int) -> str: +def format_host_port(host: str | bytes, port: int|str) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return f"[{host}]:{port}" @@ -193,7 +211,7 @@ async def open_tcp_stream( *, happy_eyeballs_delay: float | None = DEFAULT_DELAY, local_address: str | None = None, -) -> trio.abc.Stream: +) -> trio.SocketStream: """Connect to the given host and port over TCP. If the given ``host`` has multiple IP addresses associated with it, then @@ -292,7 +310,7 @@ async def open_tcp_stream( # Keeps track of the socket that we're going to complete with, # need to make sure this isn't automatically closed - winning_socket: _SocketType | None = None + winning_socket: SocketType | None = None # Try connecting to the specified address. Possible outcomes: # - success: record connected socket in winning_socket and cancel diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index f8d01cd755..d733537752 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from typing_extensions import Buffer - from ._socket import _SocketType as SocketType + from ._socket import SocketType # XX TODO: this number was picked arbitrarily. We should do experiments to # tune it. (Or make it dynamic -- one idea is to start small and increase it diff --git a/trio/_socket.py b/trio/_socket.py index 2834a5b055..0b64e5adec 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -282,7 +282,7 @@ async def getprotobyname(name: str) -> int: ################################################################ -def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType: +def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType: """Convert a standard library :class:`socket.socket` object into a Trio socket object. @@ -296,7 +296,7 @@ def fromfd( family: AddressFamily | int = _stdlib_socket.AF_INET, type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, -) -> _SocketType: +) -> SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -307,7 +307,7 @@ def fromfd( ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(info: bytes) -> _SocketType: + def fromshare(info: bytes) -> SocketType: return from_stdlib_socket(_stdlib_socket.fromshare(info)) @@ -326,7 +326,7 @@ def socketpair( family: FamilyT = FamilyDefault, type: TypeT = SocketKind.SOCK_STREAM, proto: int = 0, -) -> tuple[_SocketType, _SocketType]: +) -> tuple[SocketType, SocketType]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. @@ -341,7 +341,7 @@ def socket( type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, fileno: int | None = None, -) -> _SocketType: +) -> SocketType: """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using @@ -530,7 +530,117 @@ def __init__(self) -> NoReturn: "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" ) + def __enter__(self) -> Self: + raise NotImplementedError() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + raise NotImplementedError() + @property + def type(self) -> SocketKind: + raise NotImplementedError() + @property + def family(self) -> AddressFamily: + raise NotImplementedError() + @property + def proto(self) -> int: + raise NotImplementedError() + @property + def did_shutdown_SHUT_WR(self) -> bool: + raise NotImplementedError() + def is_readable(self) -> bool: + raise NotImplementedError() + def fileno(self) -> int: + raise NotImplementedError() + async def wait_writable(self) -> None: + raise NotImplementedError() + def shutdown(self, flag: int) -> None: + raise NotImplementedError() + async def connect(self, address: Address) -> None: + raise NotImplementedError() + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + raise NotImplementedError() + async def bind(self, address: Address) -> None: + raise NotImplementedError() + def close(self) -> None: + raise NotImplementedError() + def getsockname(self) -> Any: + raise NotImplementedError() + async def accept(self) -> tuple[SocketType, object]: + raise NotImplementedError() + + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( + self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] + async def sendto(self, *args: Any) -> int: + """Similar to :meth:`socket.socket.sendto`, but async.""" + raise NotImplementedError() + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + raise NotImplementedError() + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + raise NotImplementedError() + + def recvfrom( + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: + raise NotImplementedError() + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + raise NotImplementedError() + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: + raise NotImplementedError() + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, /, process_id: int) -> bytes: + raise NotImplementedError() + def detach(self) -> int: + raise NotImplementedError() + def get_inheritable(self) -> bool: + raise NotImplementedError() + + def set_inheritable(self, inheritable: bool) -> None: + raise NotImplementedError() class _SocketType(SocketType): def __init__(self, sock: _stdlib_socket.socket): @@ -772,7 +882,7 @@ async def _nonblocking_helper( _stdlib_socket.socket.accept, _core.wait_readable ) - async def accept(self) -> tuple[_SocketType, object]: + async def accept(self) -> tuple[SocketType, object]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py index 6eca844f0c..71b4c72a98 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -2,6 +2,7 @@ import socket as stdlib_socket import sys from math import inf +from socket import AddressFamily, SocketKind import attr import pytest @@ -115,14 +116,26 @@ class FakeOSError(OSError): @attr.s class FakeSocket(tsocket.SocketType): - family = attr.ib() - type = attr.ib() - proto = attr.ib() + _family: SocketKind = attr.ib() + _type: AddressFamily = attr.ib() + _proto: int = attr.ib() closed = attr.ib(default=False) poison_listen = attr.ib(default=False) backlog = attr.ib(default=None) + @property + def type(self) -> SocketKind: + return self._type + + @property + def family(self) -> AddressFamily: + return self._family + + @property + def proto(self) -> int: + return self._proto + def getsockopt(self, level, option): if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): return True diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index 24f82bddd5..a5f16ad777 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -1,9 +1,11 @@ +from __future__ import annotations import socket import sys import attr import pytest +from typing import Any, Sequence, TYPE_CHECKING import trio from trio._highlevel_open_tcp_stream import ( close_all, @@ -11,24 +13,32 @@ open_tcp_stream, reorder_for_rfc_6555_section_5_4, ) -from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM +from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType, Address +from socket import AddressFamily, SocketKind + +if TYPE_CHECKING: + from trio.testing import MockClock if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup -def test_close_all(): - class CloseMe: +def test_close_all() -> None: + class CloseMe(SocketType): + def __init__(self) -> None: + ... closed = False - def close(self): + def close(self) -> None: self.closed = True - class CloseKiller: - def close(self): + class CloseKiller(SocketType): + def __init__(self) -> None: + ... + def close(self) -> None: raise OSError - c = CloseMe() + c: CloseMe = CloseMe() with close_all() as to_close: to_close.add(c) assert c.closed @@ -48,8 +58,10 @@ def close(self): assert c.closed -def test_reorder_for_rfc_6555_section_5_4(): - def fake4(i): +def test_reorder_for_rfc_6555_section_5_4() -> None: + def fake4( + i: int, + ) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]: return ( AF_INET, SOCK_STREAM, @@ -58,7 +70,9 @@ def fake4(i): (f"10.0.0.{i}", 80), ) - def fake6(i): + def fake6( + i: int, + ) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]: return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80)) for fake in fake4, fake6: @@ -85,7 +99,7 @@ def fake6(i): assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)] -def test_format_host_port(): +def test_format_host_port() -> None: assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port("example.com", 443) == "example.com:443" @@ -95,7 +109,7 @@ def test_format_host_port(): # Make sure we can connect to localhost using real kernel sockets -async def test_open_tcp_stream_real_socket_smoketest(): +async def test_open_tcp_stream_real_socket_smoketest() -> None: listen_sock = trio.socket.socket() await listen_sock.bind(("127.0.0.1", 0)) _, listen_port = listen_sock.getsockname() @@ -110,23 +124,23 @@ async def test_open_tcp_stream_real_socket_smoketest(): listen_sock.close() -async def test_open_tcp_stream_input_validation(): +async def test_open_tcp_stream_input_validation() -> None: with pytest.raises(ValueError): - await open_tcp_stream(None, 80) + await open_tcp_stream(None, 80) # type: ignore[arg-type] with pytest.raises(TypeError): - await open_tcp_stream("127.0.0.1", b"80") + await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type] -def can_bind_127_0_0_2(): +def can_bind_127_0_0_2() -> bool: with socket.socket() as s: try: s.bind(("127.0.0.2", 0)) except OSError: return False - return s.getsockname()[0] == "127.0.0.2" + return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return] -async def test_local_address_real(): +async def test_local_address_real() -> None: with trio.socket.socket() as listener: await listener.bind(("127.0.0.1", 0)) listener.listen() @@ -153,11 +167,11 @@ async def test_local_address_real(): assert client_stream.socket.getsockopt( trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT ) - server_sock, remote_addr = await listener.accept() await client_stream.aclose() server_sock.close() - assert remote_addr[0] == local_address + # accept returns tuple[SocketType, object] + assert remote_addr[0] == local_address # type: ignore[index] # Trying to connect to an ipv4 address with the ipv6 wildcard # local_address should fail @@ -178,18 +192,30 @@ async def test_local_address_real(): @attr.s(eq=False) class FakeSocket(trio.socket.SocketType): - scenario = attr.ib() - family = attr.ib() - type = attr.ib() - proto = attr.ib() - - ip = attr.ib(default=None) - port = attr.ib(default=None) - succeeded = attr.ib(default=False) - closed = attr.ib(default=False) - failing = attr.ib(default=False) - - async def connect(self, sockaddr): + scenario: Scenario = attr.ib() + _family: AddressFamily = attr.ib(alias="_family") + _type: SocketKind = attr.ib(alias="_type") + _proto: int = attr.ib(alias="_proto") + + ip: str | int | None = attr.ib(default=None) + port: str | int | None = attr.ib(default=None) + succeeded: bool = attr.ib(default=False) + closed: bool = attr.ib(default=False) + failing: bool = attr.ib(default=False) + + @property + def type(self) -> SocketKind: + return self._type + + @property + def family(self) -> AddressFamily: + return self._family + + @property + def proto(self) -> int: + return self._proto + + async def connect(self, sockaddr: Address) -> None: self.ip = sockaddr[0] self.port = sockaddr[1] assert self.ip not in self.scenario.sockets @@ -203,11 +229,11 @@ async def connect(self, sockaddr): self.failing = True self.succeeded = True - def close(self): + def close(self) -> None: self.closed = True # called when SocketStream is constructed - def setsockopt(self, *args, **kwargs): + def setsockopt(self, *args: object, **kwargs: object) -> None: if self.failing: # raise something that isn't OSError as SocketStream # ignores those @@ -215,11 +241,16 @@ def setsockopt(self, *args, **kwargs): class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver): - def __init__(self, port, ip_list, supported_families): + def __init__( + self, + port: int, + ip_list: Sequence[tuple[str, float, str]], + supported_families: set[AddressFamily], + ): # ip_list have to be unique ip_order = [ip for (ip, _, _) in ip_list] assert len(set(ip_order)) == len(ip_list) - ip_dict = {} + ip_dict: dict[str | int, tuple[float, str]] = {} for ip, delay, result in ip_list: assert 0 <= delay assert result in ["error", "success", "postconnect_fail"] @@ -230,16 +261,32 @@ def __init__(self, port, ip_list, supported_families): self.ip_dict = ip_dict self.supported_families = supported_families self.socket_count = 0 - self.sockets = {} - self.connect_times = {} - - def socket(self, family, type, proto): + self.sockets: dict[str | int, FakeSocket] = {} + self.connect_times: dict[str | int, float] = {} + + def socket( + self, + family: AddressFamily | int | None = None, + type: SocketKind | int | None = None, + proto: int | None = None, + ) -> SocketType: + assert isinstance(family, AddressFamily) + assert isinstance(type, SocketKind) if family not in self.supported_families: raise OSError("pretending not to support this family") self.socket_count += 1 return FakeSocket(self, family, type, proto) - def _ip_to_gai_entry(self, ip): + def _ip_to_gai_entry( + self, ip: str + ) -> tuple[ + AddressFamily, + SocketKind, + int | None, + str, + tuple[int | str, int, int, int] | tuple[int | str, int], + ]: + sockaddr: tuple[int | str, int] | tuple[int | str, int, int, int] if ":" in ip: family = trio.socket.AF_INET6 sockaddr = (ip, self.port, 0, 0) @@ -248,7 +295,25 @@ def _ip_to_gai_entry(self, ip): sockaddr = (ip, self.port) return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr) - async def getaddrinfo(self, host, port, family, type, proto, flags): + # should hostnameresolver use AddressFamily and SocketKind, instead of int&int? + # the return type in supertype is ... wildly incompatible with what this returns + async def getaddrinfo( # type: ignore[override] + self, + host: str | bytes | None, + port: bytes | str | int | None, + family: int = -1, + type: int = -1, + proto: int = -1, + flags: int = -1, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int | None, + str, + tuple[int | str, int, int, int] | tuple[int | str, int], + ] + ]: assert host == b"test.example.com" assert port == self.port assert family == trio.socket.AF_UNSPEC @@ -257,10 +322,12 @@ async def getaddrinfo(self, host, port, family, type, proto, flags): assert flags == 0 return [self._ip_to_gai_entry(ip) for ip in self.ip_order] - async def getnameinfo(self, sockaddr, flags): # pragma: no cover + async def getnameinfo( # pragma: no cover + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: raise NotImplementedError - def check(self, succeeded): + def check(self, succeeded: SocketType | None) -> None: # sockets only go into self.sockets when connect is called; make sure # all the sockets that were created did in fact go in there. assert self.socket_count == len(self.sockets) @@ -274,24 +341,24 @@ def check(self, succeeded): async def run_scenario( # The port to connect to - port, + port: int, # A list of # (ip, delay, result) # tuples, where delay is in seconds and result is "success" or "error" # The ip's will be returned from getaddrinfo in this order, and then # connect() calls to them will have the given result. - ip_list, + ip_list: Sequence[tuple[str, float, str]], *, # If False, AF_INET4/6 sockets error out on creation, before connect is # even called. - ipv4_supported=True, - ipv6_supported=True, + ipv4_supported: bool = True, + ipv6_supported: bool = True, # Normally, we return (winning_sock, scenario object) # If this is True, we require there to be an exception, and return # (exception, scenario object) - expect_error=(), - **kwargs, -): + expect_error: tuple[type[BaseException], ...] | type[BaseException] = (), + **kwargs: Any, +) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]: supported_families = set() if ipv4_supported: supported_families.add(trio.socket.AF_INET) @@ -313,19 +380,21 @@ async def run_scenario( return (exc, scenario) -async def test_one_host_quick_success(autojump_clock): +async def test_one_host_quick_success(autojump_clock: trio.testing.MockClock) -> None: sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) + assert isinstance(sock, FakeSocket) assert sock.ip == "1.2.3.4" assert trio.current_time() == 0.123 -async def test_one_host_slow_success(autojump_clock): +async def test_one_host_slow_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) + assert isinstance(sock, FakeSocket) assert sock.ip == "1.2.3.4" assert trio.current_time() == 100 -async def test_one_host_quick_fail(autojump_clock): +async def test_one_host_quick_fail(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError ) @@ -333,7 +402,7 @@ async def test_one_host_quick_fail(autojump_clock): assert trio.current_time() == 0.123 -async def test_one_host_slow_fail(autojump_clock): +async def test_one_host_slow_fail(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 100, "error")], expect_error=OSError ) @@ -341,7 +410,7 @@ async def test_one_host_slow_fail(autojump_clock): assert trio.current_time() == 100 -async def test_one_host_failed_after_connect(autojump_clock): +async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt ) @@ -349,7 +418,7 @@ async def test_one_host_failed_after_connect(autojump_clock): # With the default 0.250 second delay, the third attempt will win -async def test_basic_fallthrough(autojump_clock): +async def test_basic_fallthrough(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -358,6 +427,7 @@ async def test_basic_fallthrough(autojump_clock): ("3.3.3.3", 0.2, "success"), ], ) + assert isinstance(sock, FakeSocket) assert sock.ip == "3.3.3.3" # current time is default time + default time + connection time assert trio.current_time() == (0.250 + 0.250 + 0.2) @@ -368,7 +438,7 @@ async def test_basic_fallthrough(autojump_clock): } -async def test_early_success(autojump_clock): +async def test_early_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -377,6 +447,7 @@ async def test_early_success(autojump_clock): ("3.3.3.3", 0.2, "success"), ], ) + assert isinstance(sock, FakeSocket) assert sock.ip == "2.2.2.2" assert trio.current_time() == (0.250 + 0.1) assert scenario.connect_times == { @@ -387,7 +458,7 @@ async def test_early_success(autojump_clock): # With a 0.450 second delay, the first attempt will win -async def test_custom_delay(autojump_clock): +async def test_custom_delay(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -397,6 +468,7 @@ async def test_custom_delay(autojump_clock): ], happy_eyeballs_delay=0.450, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "1.1.1.1" assert trio.current_time() == 1 assert scenario.connect_times == { @@ -406,7 +478,7 @@ async def test_custom_delay(autojump_clock): } -async def test_custom_errors_expedite(autojump_clock): +async def test_custom_errors_expedite(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -417,6 +489,7 @@ async def test_custom_errors_expedite(autojump_clock): ("4.4.4.4", 0.25, "success"), ], ) + assert isinstance(sock, FakeSocket) assert sock.ip == "4.4.4.4" assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25) assert scenario.connect_times == { @@ -427,7 +500,7 @@ async def test_custom_errors_expedite(autojump_clock): } -async def test_all_fail(autojump_clock): +async def test_all_fail(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 80, [ @@ -450,7 +523,7 @@ async def test_all_fail(autojump_clock): } -async def test_multi_success(autojump_clock): +async def test_multi_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -469,6 +542,7 @@ async def test_multi_success(autojump_clock): or scenario.sockets["4.4.4.4"].succeeded ) assert not scenario.sockets["5.5.5.5"].succeeded + assert isinstance(sock, FakeSocket) assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"] assert trio.current_time() == (0.5 + 10) assert scenario.connect_times == { @@ -480,7 +554,7 @@ async def test_multi_success(autojump_clock): } -async def test_does_reorder(autojump_clock): +async def test_does_reorder(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -492,6 +566,7 @@ async def test_does_reorder(autojump_clock): ], happy_eyeballs_delay=1, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "::3" assert trio.current_time() == 1 + 0.5 assert scenario.connect_times == { @@ -500,7 +575,7 @@ async def test_does_reorder(autojump_clock): } -async def test_handles_no_ipv4(autojump_clock): +async def test_handles_no_ipv4(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -514,6 +589,7 @@ async def test_handles_no_ipv4(autojump_clock): happy_eyeballs_delay=1, ipv4_supported=False, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "::3" assert trio.current_time() == 1 + 0.1 assert scenario.connect_times == { @@ -522,7 +598,7 @@ async def test_handles_no_ipv4(autojump_clock): } -async def test_handles_no_ipv6(autojump_clock): +async def test_handles_no_ipv6(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -536,6 +612,7 @@ async def test_handles_no_ipv6(autojump_clock): happy_eyeballs_delay=1, ipv6_supported=False, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "4.4.4.4" assert trio.current_time() == 1 + 0.1 assert scenario.connect_times == { @@ -544,12 +621,12 @@ async def test_handles_no_ipv6(autojump_clock): } -async def test_no_hosts(autojump_clock): +async def test_no_hosts(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario(80, [], expect_error=OSError) assert "no results found" in str(exc) -async def test_cancel(autojump_clock): +async def test_cancel(autojump_clock: MockClock) -> None: with trio.move_on_after(5) as cancel_scope: exc, scenario = await run_scenario( 80, @@ -561,6 +638,7 @@ async def test_cancel(autojump_clock): ], expect_error=BaseExceptionGroup, ) + assert isinstance(exc, BaseException) # What comes out should be 1 or more Cancelled errors that all belong # to this cancel_scope; this is the easiest way to check that raise exc @@ -571,4 +649,4 @@ async def test_cancel(autojump_clock): # This should have been called already, but just to make sure, since the # exception-handling logic in run_scenario is a bit complicated and the # main thing we care about here is that all the sockets were cleaned up. - scenario.check(succeeded=False) + scenario.check(succeeded=None) diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py index 036098b8e5..8536c66f89 100644 --- a/trio/_tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -1,31 +1,41 @@ +from __future__ import annotations import errno import inspect import os import socket as stdlib_socket import sys import tempfile +from typing import Callable, Any, TYPE_CHECKING, Tuple, List, Union +from socket import SocketKind, AddressFamily import attr import pytest from .. import _core, socket as tsocket from .._core._tests.tutil import binds_ipv6, creates_ipv6 -from .._socket import _NUMERIC_ONLY, _try_sync +from .._socket import _NUMERIC_ONLY, _try_sync, SocketType from ..testing import assert_checkpoints, wait_all_tasks_blocked +if TYPE_CHECKING: + from typing_extensions import TypeAlias + GaiTuple: TypeAlias = Tuple[AddressFamily, SocketKind, int, str, Union[Tuple[str, int],Tuple[str, int, int, int]]] + getaddrinfoResponse: TypeAlias = List[GaiTuple] +else: + GaiTuple: object + getaddrinfoResponse = object + ################################################################ # utils ################################################################ - class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo): + def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]): self._orig_getaddrinfo = orig_getaddrinfo - self._responses = {} - self.record = [] + self._responses: dict[tuple[Any, ...], getaddrinfoResponse|str] = {} + self.record: list[tuple[Any, ...]] = [] # get a normalized getaddrinfo argument tuple - def _frozenbind(self, *args, **kwargs): + def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]: sig = inspect.signature(self._orig_getaddrinfo) bound = sig.bind(*args, **kwargs) bound.apply_defaults() @@ -33,10 +43,10 @@ def _frozenbind(self, *args, **kwargs): assert not bound.kwargs return frozenbound - def set(self, response, *args, **kwargs): + def set(self, response: getaddrinfoResponse|str, *args: Any, **kwargs: Any) -> None: self._responses[self._frozenbind(*args, **kwargs)] = response - def getaddrinfo(self, *args, **kwargs): + def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse|str: bound = self._frozenbind(*args, **kwargs) self.record.append(bound) if bound in self._responses: @@ -48,13 +58,13 @@ def getaddrinfo(self, *args, **kwargs): @pytest.fixture -def monkeygai(monkeypatch): +def monkeygai(monkeypatch: pytest.MonkeyPatch) -> MonkeypatchedGAI: controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo) monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo) return controller -async def test__try_sync(): +async def test__try_sync() -> None: with assert_checkpoints(): async with _try_sync(): pass @@ -67,7 +77,7 @@ async def test__try_sync(): async with _try_sync(): raise BlockingIOError - def _is_ValueError(exc): + def _is_ValueError(exc: BaseException) -> bool: return isinstance(exc, ValueError) async with _try_sync(_is_ValueError): @@ -84,7 +94,7 @@ def _is_ValueError(exc): ################################################################ -def test_socket_has_some_reexports(): +def test_socket_has_some_reexports() -> None: assert tsocket.SOL_SOCKET == stdlib_socket.SOL_SOCKET assert tsocket.TCP_NODELAY == stdlib_socket.TCP_NODELAY assert tsocket.gaierror == stdlib_socket.gaierror @@ -96,19 +106,19 @@ def test_socket_has_some_reexports(): ################################################################ -async def test_getaddrinfo(monkeygai): - def check(got, expected): +async def test_getaddrinfo(monkeygai: MonkeypatchedGAI) -> None: + def check(got: getaddrinfoResponse, expected: getaddrinfoResponse) -> None: # win32 returns 0 for the proto field # musl and glibc have inconsistent handling of the canonical name # field (https://github.com/python-trio/trio/issues/1499) # Neither field gets used much and there isn't much opportunity for us # to mess them up, so we don't bother checking them here - def interesting_fields(gai_tup): + def interesting_fields(gai_tup: GaiTuple) -> tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]: # (family, type, proto, canonname, sockaddr) family, type, proto, canonname, sockaddr = gai_tup return (family, type, sockaddr) - def filtered(gai_list): + def filtered(gai_list: getaddrinfoResponse) -> list[tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]]: return [interesting_fields(gai_tup) for gai_tup in gai_list] assert filtered(got) == filtered(expected) @@ -172,7 +182,7 @@ def filtered(gai_list): await tsocket.getaddrinfo("asdf", "12345") -async def test_getnameinfo(): +async def test_getnameinfo() -> None: # Trivial test: ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV with assert_checkpoints(): @@ -207,7 +217,7 @@ async def test_getnameinfo(): ################################################################ -async def test_from_stdlib_socket(): +async def test_from_stdlib_socket() -> None: sa, sb = stdlib_socket.socketpair() assert not isinstance(sa, tsocket.SocketType) with sa, sb: @@ -219,7 +229,7 @@ async def test_from_stdlib_socket(): # rejects other types with pytest.raises(TypeError): - tsocket.from_stdlib_socket(1) + tsocket.from_stdlib_socket(1) # type: ignore[arg-type] class MySocket(stdlib_socket.socket): pass @@ -229,7 +239,7 @@ class MySocket(stdlib_socket.socket): tsocket.from_stdlib_socket(mysock) -async def test_from_fd(): +async def test_from_fd() -> None: sa, sb = stdlib_socket.socketpair() ta = tsocket.fromfd(sa.fileno(), sa.family, sa.type, sa.proto) with sa, sb, ta: @@ -238,8 +248,8 @@ async def test_from_fd(): assert sb.recv(3) == b"x" -async def test_socketpair_simple(): - async def child(sock): +async def test_socketpair_simple() -> None: + async def child(sock: SocketType) -> None: print("sending hello") await sock.send(b"h") assert await sock.recv(1) == b"h" @@ -252,7 +262,8 @@ async def child(sock): @pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") -async def test_fromshare(): +async def test_fromshare() -> None: + assert not TYPE_CHECKING or sys.platform == "win32" a, b = tsocket.socketpair() with a, b: # share with ourselves @@ -264,21 +275,21 @@ async def test_fromshare(): assert await b.recv(1) == b"x" -async def test_socket(): +async def test_socket() -> None: with tsocket.socket() as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET @creates_ipv6 -async def test_socket_v6(): +async def test_socket_v6() -> None: with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET6 @pytest.mark.skipif(not sys.platform == "linux", reason="linux only") -async def test_sniff_sockopts(): +async def test_sniff_sockopts() -> None: from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM # generate the combinations of families/types we're testing: @@ -309,7 +320,7 @@ async def test_sniff_sockopts(): ################################################################ -async def test_SocketType_basics(): +async def test_SocketType_basics() -> None: sock = tsocket.socket() with sock as cm_enter_value: assert cm_enter_value is sock @@ -349,7 +360,7 @@ async def test_SocketType_basics(): # our __getattr__ handles unknown names with pytest.raises(AttributeError): - sock.asdf + sock.asdf # type: ignore[attr-defined] # type family proto stdlib_sock = stdlib_socket.socket() @@ -366,7 +377,7 @@ async def test_SocketType_setsockopt() -> None: setsockopt_tests(sock) -def setsockopt_tests(sock): +def setsockopt_tests(sock: SocketType) -> None: """Extract these out, to be reused for SocketStream also.""" # specifying optlen. Not supported on pypy, and I couldn't find # valid calls on darwin or win32. @@ -378,14 +389,14 @@ def setsockopt_tests(sock): # specifying both with pytest.raises(TypeError, match="invalid value for argument 'value'"): - sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) # type: ignore[call-overload] # specifying neither with pytest.raises(TypeError, match="invalid value for argument 'value'"): - sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) # type: ignore[call-overload] -async def test_SocketType_dup(): +async def test_SocketType_dup() -> None: a, b = tsocket.socketpair() with a, b: a2 = a.dup() @@ -397,7 +408,7 @@ async def test_SocketType_dup(): assert await b.recv(1) == b"x" -async def test_SocketType_shutdown(): +async def test_SocketType_shutdown() -> None: a, b = tsocket.socketpair() with a, b: await a.send(b"x") @@ -431,7 +442,7 @@ async def test_SocketType_shutdown(): pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), ], ) -async def test_SocketType_simple_server(address, socket_type): +async def test_SocketType_simple_server(address, socket_type) -> None: # listen, bind, accept, connect, getpeername, getsockname listener = tsocket.socket(socket_type) client = tsocket.socket(socket_type) @@ -448,7 +459,7 @@ async def test_SocketType_simple_server(address, socket_type): assert await client.recv(1) == b"x" -async def test_SocketType_is_readable(): +async def test_SocketType_is_readable() -> None: a, b = tsocket.socketpair() with a, b: assert not a.is_readable() @@ -462,7 +473,7 @@ async def test_SocketType_is_readable(): # On some macOS systems, getaddrinfo likes to return V4-mapped addresses even # when we *don't* pass AI_V4MAPPED. # https://github.com/python-trio/trio/issues/580 -def gai_without_v4mapped_is_buggy(): # pragma: no cover +def gai_without_v4mapped_is_buggy() -> bool: # pragma: no cover try: stdlib_socket.getaddrinfo("1.2.3.4", 0, family=stdlib_socket.AF_INET6) except stdlib_socket.gaierror: @@ -504,16 +515,16 @@ class Addresses: ), ], ) -async def test_SocketType_resolve(socket_type, addrs): +async def test_SocketType_resolve(socket_type, addrs) -> None: v6 = socket_type == tsocket.AF_INET6 - def pad(addr): + def pad(addr) -> tuple[str, int]|tuple[str, int,int,int]: if v6: while len(addr) < 4: addr += (0,) return addr - def assert_eq(actual, expected): + def assert_eq(actual, expected) -> None: assert pad(expected) == pad(actual) with tsocket.socket(family=socket_type) as sock: @@ -599,7 +610,7 @@ async def res(*args): await res(("1.2.3.4", 80, 0, 0)) -async def test_SocketType_unresolved_names(): +async def test_SocketType_unresolved_names() -> None: with tsocket.socket() as sock: await sock.bind(("localhost", 0)) assert sock.getsockname()[0] == "127.0.0.1" @@ -618,7 +629,7 @@ async def test_SocketType_unresolved_names(): # This tests all the complicated paths through _nonblocking_helper, using recv # as a stand-in for all the methods that use _nonblocking_helper. -async def test_SocketType_non_blocking_paths(): +async def test_SocketType_non_blocking_paths() -> None: a, b = stdlib_socket.socketpair() with a, b: ta = tsocket.from_stdlib_socket(a) @@ -641,7 +652,7 @@ async def test_SocketType_non_blocking_paths(): await ta.recv("haha") # block then succeed - async def do_successful_blocking_recv(): + async def do_successful_blocking_recv() -> None: with assert_checkpoints(): assert await ta.recv(10) == b"2" @@ -651,7 +662,7 @@ async def do_successful_blocking_recv(): b.send(b"2") # block then cancelled - async def do_cancelled_blocking_recv(): + async def do_cancelled_blocking_recv() -> None: with assert_checkpoints(): with pytest.raises(_core.Cancelled): await ta.recv(10) @@ -669,13 +680,13 @@ async def do_cancelled_blocking_recv(): # other: tb = tsocket.from_stdlib_socket(b) - async def t1(): + async def t1() -> None: with assert_checkpoints(): assert await ta.recv(1) == b"a" with assert_checkpoints(): assert await tb.recv(1) == b"b" - async def t2(): + async def t2() -> None: with assert_checkpoints(): assert await tb.recv(1) == b"b" with assert_checkpoints(): @@ -693,7 +704,7 @@ async def t2(): # This tests the complicated paths through connect -async def test_SocketType_connect_paths(): +async def test_SocketType_connect_paths() -> None: with tsocket.socket() as sock: with pytest.raises(ValueError): # Should be a tuple @@ -716,7 +727,7 @@ async def test_SocketType_connect_paths(): # nose -- and then swap it back out again before we hit # wait_socket_writable, which insists on a real socket. class CancelSocket(stdlib_socket.socket): - def connect(self, *args, **kwargs): + def connect(self, *args, **kwargs) -> None: cancel_scope.cancel() sock._sock = stdlib_socket.fromfd( self.detach(), self.family, self.type @@ -747,7 +758,7 @@ def connect(self, *args, **kwargs): # Fix issue #1810 -async def test_address_in_socket_error(): +async def test_address_in_socket_error() -> None: address = "127.0.0.1" with tsocket.socket() as sock: try: @@ -756,12 +767,12 @@ async def test_address_in_socket_error(): assert any(address in str(arg) for arg in e.args) -async def test_resolve_address_exception_in_connect_closes_socket(): +async def test_resolve_address_exception_in_connect_closes_socket() -> None: # Here we are testing issue 247, any cancellation will leave the socket closed with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: - async def _resolve_address_nocp(self, *args, **kwargs): + async def _resolve_address_nocp(self, *args, **kwargs) -> None: cancel_scope.cancel() await _core.checkpoint() @@ -772,7 +783,7 @@ async def _resolve_address_nocp(self, *args, **kwargs): assert sock.fileno() == -1 -async def test_send_recv_variants(): +async def test_send_recv_variants() -> None: a, b = tsocket.socketpair() with a, b: # recv, including with flags @@ -868,7 +879,7 @@ async def test_send_recv_variants(): assert await b.recv(10) == b"yyy" -async def test_idna(monkeygai): +async def test_idna(monkeygai) -> None: # This is the encoding for "faß.de", which uses one of the characters that # IDNA 2003 handles incorrectly: monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80) @@ -886,14 +897,14 @@ async def test_idna(monkeygai): assert "ok faß.de" == await tsocket.getaddrinfo(b"xn--fa-hia.de", 80) -async def test_getprotobyname(): +async def test_getprotobyname() -> None: # These are the constants used in IP header fields, so the numeric values # had *better* be stable across systems... assert await tsocket.getprotobyname("udp") == 17 assert await tsocket.getprotobyname("tcp") == 6 -async def test_custom_hostname_resolver(monkeygai): +async def test_custom_hostname_resolver(monkeygai) -> None: class CustomResolver: async def getaddrinfo(self, host, port, family, type, proto, flags): return ("custom_gai", host, port, family, type, proto, flags) @@ -937,7 +948,7 @@ async def getnameinfo(self, sockaddr, flags): assert await tsocket.getaddrinfo("host", "port") == "x" -async def test_custom_socket_factory(): +async def test_custom_socket_factory() -> None: class CustomSocketFactory: def socket(self, family, type, proto): return ("hi", family, type, proto) @@ -964,17 +975,17 @@ def socket(self, family, type, proto): assert tsocket.set_custom_socket_factory(None) is csf -async def test_SocketType_is_abstract(): +async def test_SocketType_is_abstract() -> None: with pytest.raises(TypeError): tsocket.SocketType() @pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets") -async def test_unix_domain_socket(): +async def test_unix_domain_socket() -> None: # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. - async def check_AF_UNIX(path): + async def check_AF_UNIX(path) -> None: with tsocket.socket(family=tsocket.AF_UNIX) as lsock: await lsock.bind(path) lsock.listen(10) @@ -999,7 +1010,7 @@ async def check_AF_UNIX(path): pass -async def test_interrupted_by_close(): +async def test_interrupted_by_close() -> None: a_stdlib, b_stdlib = stdlib_socket.socketpair() with a_stdlib, b_stdlib: a_stdlib.setblocking(False) @@ -1014,11 +1025,11 @@ async def test_interrupted_by_close(): a = tsocket.from_stdlib_socket(a_stdlib) - async def sender(): + async def sender() -> None: with pytest.raises(_core.ClosedResourceError): await a.send(data) - async def receiver(): + async def receiver() -> None: with pytest.raises(_core.ClosedResourceError): await a.recv(1) @@ -1029,7 +1040,7 @@ async def receiver(): a.close() -async def test_many_sockets(): +async def test_many_sockets() -> None: total = 5000 # Must be more than MAX_AFD_GROUP_SIZE sockets = [] for x in range(total // 2): diff --git a/trio/socket.py b/trio/socket.py index f8d0bc3fc2..ddefc72649 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -36,7 +36,6 @@ from ._socket import ( Address as Address, SocketType as SocketType, - _SocketType as _SocketType, from_stdlib_socket as from_stdlib_socket, fromfd as fromfd, getaddrinfo as getaddrinfo, diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index ddf46174f3..2a358119f3 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -11,7 +11,7 @@ import errno import ipaddress import os -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union, Type import attr @@ -21,6 +21,7 @@ if TYPE_CHECKING: from socket import AddressFamily, SocketKind from types import TracebackType + from typing_extensions import TypeAlias IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -174,7 +175,7 @@ def deliver_packet(self, packet) -> None: class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor): - def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): + def __init__(self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int): self._fake_net = fake_net if not family: @@ -187,9 +188,9 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): if type != trio.socket.SOCK_DGRAM: raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}") - self.family = family - self.type = type - self.proto = proto + self._family = family + self._type = type + self._proto = proto self._closed = False @@ -199,6 +200,15 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): # This is the source-of-truth for what port etc. this socket is bound to self._binding: Optional[UDPBinding] = None + @property + def type(self) -> SocketKind: + return self._type + @property + def family(self) -> AddressFamily: + return self._family + @property + def proto(self) -> int: + return self._proto def _check_closed(self): if self._closed: @@ -362,7 +372,8 @@ def __enter__(self): def __exit__( self, - exc_type: type[BaseException] | None, + # builtin `type` is shadowed by the property + exc_type: Type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: From 1a38b7e4eda4e2702401e1bb651fd386a0eda1ca Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 25 Aug 2023 17:40:54 +0200 Subject: [PATCH 02/20] wip2 --- pyproject.toml | 4 +- trio/_core/_generated_io_epoll.py | 9 +- trio/_core/_io_epoll.py | 9 +- trio/_socket.py | 213 ++++++++++++------ .../test_highlevel_open_tcp_listeners.py | 74 ++++-- trio/_tests/test_highlevel_open_tcp_stream.py | 10 +- trio/_tests/test_highlevel_socket.py | 87 ++++--- trio/_tests/test_socket.py | 143 +++++++++--- trio/_tools/gen_exports.py | 4 + 9 files changed, 388 insertions(+), 165 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 64e15a4c71..aeb35977a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,8 +96,8 @@ module = [ "trio/_tests/test_exports", "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", -"trio/_tests/test_highlevel_socket", -"trio/_tests/test_highlevel_open_tcp_listeners", +#"trio/_tests/test_highlevel_socket", # +#"trio/_tests/test_highlevel_open_tcp_listeners", # "trio/_tests/test_highlevel_open_unix_stream", "trio/_tests/test_highlevel_serve_listeners", "trio/_tests/test_highlevel_ssl_helpers", diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index abe49ed3ff..54957952be 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -10,12 +10,15 @@ from ._run import GLOBAL_RUN_CONTEXT from socket import socket from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .._file_io import _HasFileNo import sys assert not TYPE_CHECKING or sys.platform=="linux" -async def wait_readable(fd: (int | socket)) ->None: +async def wait_readable(fd: (int | _HasFileNo)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -23,7 +26,7 @@ async def wait_readable(fd: (int | socket)) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(fd: (int | socket)) ->None: +async def wait_writable(fd: (int | _HasFileNo)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -31,7 +34,7 @@ async def wait_writable(fd: (int | socket)) ->None: raise RuntimeError("must be called from async context") -def notify_closing(fd: (int | socket)) ->None: +def notify_closing(fd: (int | _HasFileNo)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 0d247cae64..b39ad547f1 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -18,6 +18,7 @@ from typing_extensions import TypeAlias from .._core import Abort, RaiseCancelT + from .._file_io import _HasFileNo @attr.s(slots=True, eq=False) @@ -290,7 +291,7 @@ def _update_registrations(self, fd: int) -> None: if not wanted_flags: del self._registered[fd] - async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None: + async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None: if not isinstance(fd, int): fd = fd.fileno() waiters = self._registered[fd] @@ -309,15 +310,15 @@ def abort(_: RaiseCancelT) -> Abort: await _core.wait_task_rescheduled(abort) @_public - async def wait_readable(self, fd: int | socket) -> None: + async def wait_readable(self, fd: int | _HasFileNo) -> None: await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd: int | socket) -> None: + async def wait_writable(self, fd: int | _HasFileNo) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd: int | socket) -> None: + def notify_closing(self, fd: int | _HasFileNo) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_socket.py b/trio/_socket.py index 0b64e5adec..ace35ff253 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -14,6 +14,7 @@ Callable, Literal, NoReturn, + Optional, SupportsIndex, Tuple, TypeVar, @@ -44,6 +45,13 @@ Address: TypeAlias = Union[ str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int] ] +AddressWithNoneHost: TypeAlias = Union[ + str, + bytes, + Tuple[Optional[str], int], + Tuple[Optional[str], int, int], + Tuple[Optional[str], int, int, int], +] # Usage: @@ -450,7 +458,7 @@ async def _resolve_address_nocp( proto: int, *, ipv6_v6only: bool | int, - address: Address, + address: AddressWithNoneHost, local: bool, ) -> Address: # Do some pre-checking (or exit early for non-IP sockets) @@ -467,7 +475,8 @@ async def _resolve_address_nocp( assert isinstance(address, (str, bytes)) return os.fspath(address) else: - return address + # TODO: check for host is None? + return address # type: ignore[return-value] # -- From here on we know we have IPv4 or IPV6 -- host: str | None @@ -475,13 +484,13 @@ async def _resolve_address_nocp( # Fast path for the simple case: already-resolved IP address, # already-resolved port. This is particularly important for UDP, since # every sendto call goes through here. - if isinstance(port, int): + if isinstance(port, int) and host is not None: try: - _stdlib_socket.inet_pton(family, address[0]) + _stdlib_socket.inet_pton(family, host) except (OSError, TypeError): pass else: - return address + return address # type: ignore[return-value] # Special cases to match the stdlib, see gh-277 if host == "": host = None @@ -530,6 +539,69 @@ def __init__(self) -> NoReturn: "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" ) + + ################################################################ + # Simple + portable methods and attributes + ################################################################ + def detach(self) -> int: + raise NotImplementedError() + + def fileno(self) -> int: + raise NotImplementedError() + + def getpeername(self) -> Any: + raise NotImplementedError() + + def getsockname(self) -> Any: + raise NotImplementedError() + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + raise NotImplementedError() + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + raise NotImplementedError() + + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + raise NotImplementedError() + + def get_inheritable(self) -> bool: + raise NotImplementedError() + + def set_inheritable(self, inheritable: bool) -> None: + raise NotImplementedError() + + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, /, process_id: int) -> bytes: + raise NotImplementedError() + def __enter__(self) -> Self: raise NotImplementedError() @@ -540,38 +612,85 @@ def __exit__( traceback: TracebackType | None, ) -> None: raise NotImplementedError() + @property - def type(self) -> SocketKind: + def family(self) -> AddressFamily: raise NotImplementedError() + @property - def family(self) -> AddressFamily: + def type(self) -> SocketKind: raise NotImplementedError() + @property def proto(self) -> int: raise NotImplementedError() + @property def did_shutdown_SHUT_WR(self) -> bool: raise NotImplementedError() - def is_readable(self) -> bool: + + # def __repr__(self) -> str: + # raise NotImplementedError() + def dup(self) -> SocketType: raise NotImplementedError() - def fileno(self) -> int: + + def close(self) -> None: raise NotImplementedError() - async def wait_writable(self) -> None: + + async def bind(self, address: Address) -> None: raise NotImplementedError() + def shutdown(self, flag: int) -> None: raise NotImplementedError() + def is_readable(self) -> bool: + raise NotImplementedError() + + async def wait_writable(self) -> None: + raise NotImplementedError() + + async def accept(self) -> tuple[SocketType, object]: + raise NotImplementedError() + async def connect(self, address: Address) -> None: raise NotImplementedError() - def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: raise NotImplementedError() - async def bind(self, address: Address) -> None: + + def recv_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: raise NotImplementedError() - def close(self) -> None: + + ### + def recvfrom( + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: raise NotImplementedError() - def getsockname(self) -> Any: + + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, Address]]: raise NotImplementedError() - async def accept(self) -> tuple[SocketType, object]: + + # if hasattr(_stdlib_socket.socket, "recvmsg"): + # def recvmsg( + # __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 + # ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + # raise NotImplementedError() + + # if hasattr(_stdlib_socket.socket, "recvmsg_into"): + # def recvmsg_into( + # __self, + # __buffers: Iterable[Buffer], + # __ancbufsize: int = 0, + # __flags: int = 0, + # ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + # raise NotImplementedError() + + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: raise NotImplementedError() @overload @@ -590,57 +709,21 @@ async def sendto( async def sendto(self, *args: Any) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" raise NotImplementedError() - @overload - def getsockopt(self, /, level: int, optname: int) -> int: - ... - - @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: - ... - - def getsockopt( - self, /, level: int, optname: int, buflen: int | None = None - ) -> int | bytes: - raise NotImplementedError() - @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: - ... - - @overload - def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: - ... - - def setsockopt( - self, - /, - level: int, - optname: int, - value: int | Buffer | None, - optlen: int | None = None, - ) -> None: - raise NotImplementedError() - def recvfrom( - __self, __bufsize: int, __flags: int = 0 - ) -> Awaitable[tuple[bytes, Address]]: - raise NotImplementedError() - def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: - raise NotImplementedError() - def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: - raise NotImplementedError() - if sys.platform == "win32" or ( - not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") ): - def share(self, /, process_id: int) -> bytes: + @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: Address | None = None, + ) -> int: raise NotImplementedError() - def detach(self) -> int: - raise NotImplementedError() - def get_inheritable(self) -> bool: - raise NotImplementedError() - def set_inheritable(self, inheritable: bool) -> None: - raise NotImplementedError() class _SocketType(SocketType): def __init__(self, sock: _stdlib_socket.socket): @@ -772,7 +855,7 @@ def close(self) -> None: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address: Address) -> None: + async def bind(self, address: AddressWithNoneHost) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -811,7 +894,7 @@ async def wait_writable(self) -> None: async def _resolve_address_nocp( self, - address: Address, + address: AddressWithNoneHost, *, local: bool, ) -> Address: @@ -891,7 +974,7 @@ async def accept(self) -> tuple[SocketType, object]: # connect ################################################################ - async def connect(self, address: Address) -> None: + async def connect(self, address: AddressWithNoneHost) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -1115,7 +1198,7 @@ async def sendmsg( __buffers: Iterable[Buffer], __ancdata: Iterable[tuple[int, int, Buffer]] = (), __flags: int = 0, - __address: Address | None = None, + __address: AddressWithNoneHost | None = None, ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py index 71b4c72a98..0d9e9316fa 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import errno import socket as stdlib_socket import sys from math import inf from socket import AddressFamily, SocketKind +from typing import overload import attr import pytest @@ -18,7 +21,7 @@ from exceptiongroup import BaseExceptionGroup -async def test_open_tcp_listeners_basic(): +async def test_open_tcp_listeners_basic() -> None: listeners = await open_tcp_listeners(0) assert isinstance(listeners, list) for obj in listeners: @@ -46,7 +49,7 @@ async def test_open_tcp_listeners_basic(): await resource.aclose() -async def test_open_tcp_listeners_specific_port_specific_host(): +async def test_open_tcp_listeners_specific_port_specific_host() -> None: # Pick a port sock = tsocket.socket() await sock.bind(("127.0.0.1", 0)) @@ -59,7 +62,7 @@ async def test_open_tcp_listeners_specific_port_specific_host(): @binds_ipv6 -async def test_open_tcp_listeners_ipv6_v6only(): +async def test_open_tcp_listeners_ipv6_v6only() -> None: # Check IPV6_V6ONLY is working properly (ipv6_listener,) = await open_tcp_listeners(0, host="::1") async with ipv6_listener: @@ -69,7 +72,7 @@ async def test_open_tcp_listeners_ipv6_v6only(): await open_tcp_stream("127.0.0.1", port) -async def test_open_tcp_listeners_rebind(): +async def test_open_tcp_listeners_rebind() -> None: (l1,) = await open_tcp_listeners(0, host="127.0.0.1") sockaddr1 = l1.socket.getsockname() @@ -116,13 +119,13 @@ class FakeOSError(OSError): @attr.s class FakeSocket(tsocket.SocketType): - _family: SocketKind = attr.ib() - _type: AddressFamily = attr.ib() + _family: AddressFamily = attr.ib() + _type: SocketKind = attr.ib() _proto: int = attr.ib() - closed = attr.ib(default=False) - poison_listen = attr.ib(default=False) - backlog = attr.ib(default=None) + closed: bool = attr.ib(default=False) + poison_listen: bool = attr.ib(default=False) + backlog: int | None = attr.ib(default=None) @property def type(self) -> SocketKind: @@ -136,25 +139,50 @@ def family(self) -> AddressFamily: def proto(self) -> int: return self._proto - def getsockopt(self, level, option): - if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): return True assert False # pragma: no cover - def setsockopt(self, level, option, value): + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: pass - async def bind(self, sockaddr): + async def bind(self, address: AddressWithNoneHost) -> None: pass - def listen(self, backlog): + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: assert self.backlog is None assert backlog is not None self.backlog = backlog if self.poison_listen: raise FakeOSError("whoops") - def close(self): + def close(self) -> None: self.closed = True @@ -186,7 +214,7 @@ async def getaddrinfo(self, host, port, family, type, proto, flags): ] -async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): +async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None: # If we were trying to bind to multiple hosts and one of them failed, they # call get cleaned up before returning fsf = FakeSocketFactory(3) @@ -209,7 +237,7 @@ async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): assert sock.closed -async def test_open_tcp_listeners_port_checking(): +async def test_open_tcp_listeners_port_checking() -> None: for host in ["127.0.0.1", None]: with pytest.raises(TypeError): await open_tcp_listeners(None, host=host) @@ -219,8 +247,8 @@ async def test_open_tcp_listeners_port_checking(): await open_tcp_listeners("http", host=host) -async def test_serve_tcp(): - async def handler(stream): +async def test_serve_tcp() -> None: + async def handler(stream) -> None: await stream.send_all(b"x") async with trio.open_nursery() as nursery: @@ -241,7 +269,7 @@ async def handler(stream): ) async def test_open_tcp_listeners_some_address_families_unavailable( try_families, fail_families -): +) -> None: fsf = FakeSocketFactory( 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} ) @@ -270,7 +298,7 @@ async def test_open_tcp_listeners_some_address_families_unavailable( assert not should_succeed -async def test_open_tcp_listeners_socket_fails_not_afnosupport(): +async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None: fsf = FakeSocketFactory( 10, raise_on_family={ @@ -298,7 +326,7 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport(): # effectively is no backlog), sometimes the host might not be enough resources # to give us the full requested backlog... it was a mess. So now we just check # that the backlog argument is passed through correctly. -async def test_open_tcp_listeners_backlog(): +async def test_open_tcp_listeners_backlog() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) for given, expected in [ @@ -314,7 +342,7 @@ async def test_open_tcp_listeners_backlog(): assert listener.socket.backlog == expected -async def test_open_tcp_listeners_backlog_float_error(): +async def test_open_tcp_listeners_backlog_float_error() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) for should_fail in (0.0, 2.18, 3.14, 9.75): diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index a5f16ad777..fc1bf4c006 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -1,11 +1,13 @@ from __future__ import annotations + import socket import sys +from socket import AddressFamily, SocketKind +from typing import TYPE_CHECKING, Any, Sequence import attr import pytest -from typing import Any, Sequence, TYPE_CHECKING import trio from trio._highlevel_open_tcp_stream import ( close_all, @@ -13,8 +15,7 @@ open_tcp_stream, reorder_for_rfc_6555_section_5_4, ) -from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType, Address -from socket import AddressFamily, SocketKind +from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, Address, SocketType if TYPE_CHECKING: from trio.testing import MockClock @@ -27,6 +28,7 @@ def test_close_all() -> None: class CloseMe(SocketType): def __init__(self) -> None: ... + closed = False def close(self) -> None: @@ -35,6 +37,7 @@ def close(self) -> None: class CloseKiller(SocketType): def __init__(self) -> None: ... + def close(self) -> None: raise OSError @@ -272,6 +275,7 @@ def socket( ) -> SocketType: assert isinstance(family, AddressFamily) assert isinstance(type, SocketKind) + assert proto is not None if family not in self.supported_families: raise OSError("pretending not to support this family") self.socket_count += 1 diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py index 1a987df3f3..c5d46b6c6a 100644 --- a/trio/_tests/test_highlevel_socket.py +++ b/trio/_tests/test_highlevel_socket.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import errno import socket as stdlib_socket import sys +from typing import Sequence import pytest @@ -14,16 +17,17 @@ from .test_socket import setsockopt_tests -async def test_SocketStream_basics(): +async def test_SocketStream_basics() -> None: # stdlib socket bad (even if connected) - a, b = stdlib_socket.socketpair() - with a, b: + stdlib_a, stdlib_b = stdlib_socket.socketpair() + with stdlib_a, stdlib_b: with pytest.raises(TypeError): - SocketStream(a) + SocketStream(stdlib_a) # type: ignore[arg-type] # DGRAM socket bad with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock: with pytest.raises(ValueError): + # TODO: does not raise an error? SocketStream(sock) a, b = tsocket.socketpair() @@ -48,13 +52,13 @@ async def test_SocketStream_basics(): s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) - b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) - assert isinstance(b, bytes) + res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) + assert isinstance(res, bytes) setsockopt_tests(s) -async def test_SocketStream_send_all(): +async def test_SocketStream_send_all() -> None: BIG = 10000000 a_sock, b_sock = tsocket.socketpair() @@ -65,7 +69,7 @@ async def test_SocketStream_send_all(): # Check a send_all that has to be split into multiple parts (on most # platforms... on Windows every send() either succeeds or fails as a # whole) - async def sender(): + async def sender() -> None: data = bytearray(BIG) await a.send_all(data) # send_all uses memoryviews internally, which temporarily "lock" @@ -89,7 +93,7 @@ async def sender(): # and we break our implementation of send_all, then we'll get some # early warning...) - async def receiver(): + async def receiver() -> None: # Make sure the sender fills up the kernel buffers and blocks await wait_all_tasks_blocked() nbytes = 0 @@ -109,12 +113,12 @@ async def receiver(): assert await b.receive_some(10) == b"" -async def fill_stream(s): - async def sender(): +async def fill_stream(s: SocketStream) -> None: + async def sender() -> None: while True: await s.send_all(b"x" * 10000) - async def waiter(nursery): + async def waiter(nursery: _core.Nursery) -> None: await wait_all_tasks_blocked() nursery.cancel_scope.cancel() @@ -123,12 +127,12 @@ async def waiter(nursery): nursery.start_soon(waiter, nursery) -async def test_SocketStream_generic(): - async def stream_maker(): +async def test_SocketStream_generic() -> None: + async def stream_maker() -> tuple[SocketStream, SocketStream]: left, right = tsocket.socketpair() return SocketStream(left), SocketStream(right) - async def clogged_stream_maker(): + async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]: left, right = await stream_maker() await fill_stream(left) await fill_stream(right) @@ -137,13 +141,13 @@ async def clogged_stream_maker(): await check_half_closeable_stream(stream_maker, clogged_stream_maker) -async def test_SocketListener(): +async def test_SocketListener() -> None: # Not a Trio socket with stdlib_socket.socket() as s: s.bind(("127.0.0.1", 0)) s.listen(10) with pytest.raises(TypeError): - SocketListener(s) + SocketListener(s) # type: ignore[arg-type] # Not a SOCK_STREAM with tsocket.socket(type=tsocket.SOCK_DGRAM) as s: @@ -190,7 +194,7 @@ async def test_SocketListener(): await server_stream.aclose() -async def test_SocketListener_socket_closed_underfoot(): +async def test_SocketListener_socket_closed_underfoot() -> None: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(10) @@ -205,21 +209,48 @@ async def test_SocketListener_socket_closed_underfoot(): await listener.accept() -async def test_SocketListener_accept_errors(): +async def test_SocketListener_accept_errors() -> None: class FakeSocket(tsocket.SocketType): - def __init__(self, events): + def __init__(self, events: Sequence[SocketType | BaseException]): self._events = iter(events) type = tsocket.SOCK_STREAM # Fool the check for SO_ACCEPTCONN in SocketListener.__init__ - def getsockopt(self, level, opt): + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: return True - def setsockopt(self, level, opt, value): + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt( + self, /, level: int, optname: int, value: None, optlen: int + ) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: pass - async def accept(self): + async def accept(self) -> tuple[SocketType, object]: await _core.checkpoint() event = next(self._events) if isinstance(event, BaseException): @@ -242,24 +273,24 @@ async def accept(self): ] ) - l = SocketListener(fake_listen_sock) + listener = SocketListener(fake_listen_sock) with assert_checkpoints(): - s = await l.accept() + s = await listener.accept() assert s.socket is fake_server_sock for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]: with assert_checkpoints(): with pytest.raises(OSError) as excinfo: - await l.accept() + await listener.accept() assert excinfo.value.errno == code with assert_checkpoints(): - s = await l.accept() + s = await listener.accept() assert s.socket is fake_server_sock -async def test_socket_stream_works_when_peer_has_already_closed(): +async def test_socket_stream_works_when_peer_has_already_closed() -> None: sock_a, sock_b = tsocket.socketpair() with sock_a, sock_b: await sock_b.send(b"x") diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py index 8536c66f89..b405e97e72 100644 --- a/trio/_tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -1,24 +1,33 @@ from __future__ import annotations + import errno import inspect import os import socket as stdlib_socket import sys import tempfile -from typing import Callable, Any, TYPE_CHECKING, Tuple, List, Union -from socket import SocketKind, AddressFamily +from socket import AddressFamily, SocketKind +from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union import attr import pytest from .. import _core, socket as tsocket from .._core._tests.tutil import binds_ipv6, creates_ipv6 -from .._socket import _NUMERIC_ONLY, _try_sync, SocketType +from .._highlevel_socket import SocketStream +from .._socket import _NUMERIC_ONLY, SocketType, _SocketType, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked if TYPE_CHECKING: from typing_extensions import TypeAlias - GaiTuple: TypeAlias = Tuple[AddressFamily, SocketKind, int, str, Union[Tuple[str, int],Tuple[str, int, int, int]]] + + GaiTuple: TypeAlias = Tuple[ + AddressFamily, + SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] getaddrinfoResponse: TypeAlias = List[GaiTuple] else: GaiTuple: object @@ -28,10 +37,11 @@ # utils ################################################################ + class MonkeypatchedGAI: def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]): self._orig_getaddrinfo = orig_getaddrinfo - self._responses: dict[tuple[Any, ...], getaddrinfoResponse|str] = {} + self._responses: dict[tuple[Any, ...], getaddrinfoResponse | str] = {} self.record: list[tuple[Any, ...]] = [] # get a normalized getaddrinfo argument tuple @@ -43,10 +53,12 @@ def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]: assert not bound.kwargs return frozenbound - def set(self, response: getaddrinfoResponse|str, *args: Any, **kwargs: Any) -> None: + def set( + self, response: getaddrinfoResponse | str, *args: Any, **kwargs: Any + ) -> None: self._responses[self._frozenbind(*args, **kwargs)] = response - def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse|str: + def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse | str: bound = self._frozenbind(*args, **kwargs) self.record.append(bound) if bound in self._responses: @@ -113,12 +125,26 @@ def check(got: getaddrinfoResponse, expected: getaddrinfoResponse) -> None: # field (https://github.com/python-trio/trio/issues/1499) # Neither field gets used much and there isn't much opportunity for us # to mess them up, so we don't bother checking them here - def interesting_fields(gai_tup: GaiTuple) -> tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]: + def interesting_fields( + gai_tup: GaiTuple, + ) -> tuple[ + AddressFamily, + SocketKind, + tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int], + ]: # (family, type, proto, canonname, sockaddr) family, type, proto, canonname, sockaddr = gai_tup return (family, type, sockaddr) - def filtered(gai_list: getaddrinfoResponse) -> list[tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]]: + def filtered( + gai_list: getaddrinfoResponse, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int], + ] + ]: return [interesting_fields(gai_tup) for gai_tup in gai_list] assert filtered(got) == filtered(expected) @@ -263,7 +289,8 @@ async def child(sock: SocketType) -> None: @pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") async def test_fromshare() -> None: - assert not TYPE_CHECKING or sys.platform == "win32" + if TYPE_CHECKING and sys.platform != "win32": + return a, b = tsocket.socketpair() with a, b: # share with ourselves @@ -377,7 +404,7 @@ async def test_SocketType_setsockopt() -> None: setsockopt_tests(sock) -def setsockopt_tests(sock: SocketType) -> None: +def setsockopt_tests(sock: SocketType | SocketStream) -> None: """Extract these out, to be reused for SocketStream also.""" # specifying optlen. Not supported on pypy, and I couldn't find # valid calls on darwin or win32. @@ -442,7 +469,9 @@ async def test_SocketType_shutdown() -> None: pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), ], ) -async def test_SocketType_simple_server(address, socket_type) -> None: +async def test_SocketType_simple_server( + address: str, socket_type: AddressFamily +) -> None: # listen, bind, accept, connect, getpeername, getsockname listener = tsocket.socket(socket_type) client = tsocket.socket(socket_type) @@ -484,10 +513,10 @@ def gai_without_v4mapped_is_buggy() -> bool: # pragma: no cover @attr.s class Addresses: - bind_all = attr.ib() - localhost = attr.ib() - arbitrary = attr.ib() - broadcast = attr.ib() + bind_all: str = attr.ib() + localhost: str = attr.ib() + arbitrary: str = attr.ib() + broadcast: str = attr.ib() # Direct thorough tests of the implicit resolver helpers @@ -515,34 +544,53 @@ class Addresses: ), ], ) -async def test_SocketType_resolve(socket_type, addrs) -> None: +async def test_SocketType_resolve(socket_type: AddressFamily, addrs: Addresses) -> None: v6 = socket_type == tsocket.AF_INET6 - def pad(addr) -> tuple[str, int]|tuple[str, int,int,int]: + def pad(addr: tuple[str | int, ...]) -> tuple[str | int, ...]: if v6: while len(addr) < 4: addr += (0,) return addr - def assert_eq(actual, expected) -> None: + def assert_eq( + actual: tuple[str | int, ...], expected: tuple[str | int, ...] + ) -> None: assert pad(expected) == pad(actual) with tsocket.socket(family=socket_type) as sock: + # testing internal functionality, so we check it against the internal type + assert isinstance(sock, _SocketType) + # For some reason the stdlib special-cases "" to pass NULL to # getaddrinfo. They also error out on None, but whatever, None is much # more consistent, so we accept it too. + # TODO: this implies that we can send host=None, but what does that imply for the return value, and other stuff? for null in [None, ""]: got = await sock._resolve_address_nocp((null, 80), local=True) + assert not isinstance(got, (str, bytes)) assert_eq(got, (addrs.bind_all, 80)) got = await sock._resolve_address_nocp((null, 80), local=False) + assert not isinstance(got, (str, bytes)) assert_eq(got, (addrs.localhost, 80)) # AI_PASSIVE only affects the wildcard address, so for everything else # local=True/local=False should work the same: for local in [False, True]: - async def res(*args): - return await sock._resolve_address_nocp(*args, local=local) + async def res( + args: tuple[str, int] + | tuple[str, int, int] + | tuple[str, int, int, int] + | tuple[str, str] + | tuple[str, str, int] + | tuple[str, str, int, int] + ) -> tuple[str, int] | tuple[str, int, int, int]: + # we're only passing IP sockets, so we ignore the str/bytes return type + # But what about when port/family is a string? Should that be part of the public API? + res = await sock._resolve_address_nocp(args, local=local) # type: ignore[arg-type] + # no str/bytes + return res # type: ignore[return-value] assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: @@ -593,6 +641,7 @@ async def res(*args): except (AttributeError, OSError): pass else: + assert isinstance(netlink_sock, _SocketType) assert ( await netlink_sock._resolve_address_nocp("asdf", local=local) == "asdf" @@ -600,13 +649,14 @@ async def res(*args): netlink_sock.close() with pytest.raises(ValueError): - await res("1.2.3.4") + await res("1.2.3.4") # type: ignore[arg-type] with pytest.raises(ValueError): - await res(("1.2.3.4",)) + await res(("1.2.3.4",)) # type: ignore[arg-type] with pytest.raises(ValueError): if v6: - await res(("1.2.3.4", 80, 0, 0, 0)) + await res(("1.2.3.4", 80, 0, 0, 0)) # type: ignore[arg-type] else: + # I guess in theory there could be enough overloads that this could error? await res(("1.2.3.4", 80, 0, 0)) @@ -649,7 +699,7 @@ async def test_SocketType_non_blocking_paths() -> None: # immediate failure with assert_checkpoints(): with pytest.raises(TypeError): - await ta.recv("haha") + await ta.recv("haha") # type: ignore[arg-type] # block then succeed async def do_successful_blocking_recv() -> None: @@ -727,7 +777,10 @@ async def test_SocketType_connect_paths() -> None: # nose -- and then swap it back out again before we hit # wait_socket_writable, which insists on a real socket. class CancelSocket(stdlib_socket.socket): - def connect(self, *args, **kwargs) -> None: + def connect(self, *args: Any, **kwargs: Any) -> None: + # accessing private method only available in _SocketType + assert isinstance(sock, _SocketType) + cancel_scope.cancel() sock._sock = stdlib_socket.fromfd( self.detach(), self.family, self.type @@ -736,6 +789,8 @@ def connect(self, *args, **kwargs) -> None: # If connect *doesn't* raise, then pretend it did raise BlockingIOError # pragma: no cover + # accessing private method only available in _SocketType + assert isinstance(sock, _SocketType) sock._sock.close() sock._sock = CancelSocket() @@ -772,11 +827,14 @@ async def test_resolve_address_exception_in_connect_closes_socket() -> None: with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: - async def _resolve_address_nocp(self, *args, **kwargs) -> None: + async def _resolve_address_nocp( + self: Any, *args: Any, **kwargs: Any + ) -> None: cancel_scope.cancel() await _core.checkpoint() - sock._resolve_address_nocp = _resolve_address_nocp + assert isinstance(sock, _SocketType) + sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign, assignment] with assert_checkpoints(): with pytest.raises(_core.Cancelled): await sock.connect("") @@ -879,7 +937,7 @@ async def test_send_recv_variants() -> None: assert await b.recv(10) == b"yyy" -async def test_idna(monkeygai) -> None: +async def test_idna(monkeygai: MonkeypatchedGAI) -> None: # This is the encoding for "faß.de", which uses one of the characters that # IDNA 2003 handles incorrectly: monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80) @@ -904,17 +962,22 @@ async def test_getprotobyname() -> None: assert await tsocket.getprotobyname("tcp") == 6 -async def test_custom_hostname_resolver(monkeygai) -> None: +async def test_custom_hostname_resolver(monkeygai: MonkeypatchedGAI) -> None: + # This intentionally breaks the signatures used in HostnameResolver class CustomResolver: - async def getaddrinfo(self, host, port, family, type, proto, flags): + async def getaddrinfo( + self, host: str, port: str, family: int, type: int, proto: int, flags: int + ) -> tuple[str, str, str, int, int, int, int]: return ("custom_gai", host, port, family, type, proto, flags) - async def getnameinfo(self, sockaddr, flags): + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, tuple[str, int] | tuple[str, int, int, int], int]: return ("custom_gni", sockaddr, flags) cr = CustomResolver() - assert tsocket.set_custom_hostname_resolver(cr) is None + assert tsocket.set_custom_hostname_resolver(cr) is None # type: ignore[arg-type] # Check that the arguments are all getting passed through. # We have to use valid calls to avoid making the underlying system @@ -937,7 +1000,11 @@ async def getnameinfo(self, sockaddr, flags): expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0) assert got == expected - assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0) + assert await tsocket.getnameinfo("a", 0) == ( # type: ignore[arg-type] + "custom_gni", + "a", + 0, + ) # We can set it back to None assert tsocket.set_custom_hostname_resolver(None) is cr @@ -950,12 +1017,14 @@ async def getnameinfo(self, sockaddr, flags): async def test_custom_socket_factory() -> None: class CustomSocketFactory: - def socket(self, family, type, proto): + def socket( + self, family: AddressFamily, type: SocketKind, proto: int + ) -> tuple[str, AddressFamily, SocketKind, int]: return ("hi", family, type, proto) csf = CustomSocketFactory() - assert tsocket.set_custom_socket_factory(csf) is None + assert tsocket.set_custom_socket_factory(csf) is None # type: ignore[arg-type] assert tsocket.socket() == ("hi", tsocket.AF_INET, tsocket.SOCK_STREAM, 0) assert tsocket.socket(1, 2, 3) == ("hi", 1, 2, 3) @@ -985,7 +1054,7 @@ async def test_unix_domain_socket() -> None: # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. - async def check_AF_UNIX(path) -> None: + async def check_AF_UNIX(path: str | bytes) -> None: with tsocket.socket(family=tsocket.AF_UNIX) as lsock: await lsock.bind(path) lsock.listen(10) diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 3c598e8eae..3e8c300d53 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -263,6 +263,10 @@ def main() -> None: # pragma: no cover IMPORTS_EPOLL = """\ from socket import socket +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .._file_io import _HasFileNo """ IMPORTS_KQUEUE = """\ From 201da1a26b07a4eeb82c77d5fc843e8596936596 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 28 Aug 2023 17:04:01 +0200 Subject: [PATCH 03/20] hide methods in SocketType behind a TYPE_CHECKING guard, update _io_kqueue to use _HasFileNo --- .coveragerc | 1 + pyproject.toml | 2 - trio/_abc.py | 6 +- trio/_core/_generated_io_kqueue.py | 8 +- trio/_core/_io_epoll.py | 2 - trio/_core/_io_kqueue.py | 11 +- trio/_highlevel_open_tcp_listeners.py | 5 +- trio/_socket.py | 330 +++++++++--------- .../test_highlevel_open_tcp_listeners.py | 79 +++-- trio/_tests/test_highlevel_open_tcp_stream.py | 27 +- trio/_tools/gen_exports.py | 2 +- 11 files changed, 257 insertions(+), 216 deletions(-) diff --git a/.coveragerc b/.coveragerc index 5272237caf..07709cd482 100644 --- a/.coveragerc +++ b/.coveragerc @@ -24,6 +24,7 @@ exclude_lines = if t.TYPE_CHECKING: @overload class .*\bProtocol\b.*\): + raise NotImplementedError partial_branches = pragma: no branch diff --git a/pyproject.toml b/pyproject.toml index aeb35977a7..ec2ef1ffe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,8 +96,6 @@ module = [ "trio/_tests/test_exports", "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", -#"trio/_tests/test_highlevel_socket", # -#"trio/_tests/test_highlevel_open_tcp_listeners", # "trio/_tests/test_highlevel_open_unix_stream", "trio/_tests/test_highlevel_serve_listeners", "trio/_tests/test_highlevel_ssl_helpers", diff --git a/trio/_abc.py b/trio/_abc.py index 6a99ea0842..a839bd380a 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -211,9 +211,9 @@ class SocketFactory(metaclass=ABCMeta): @abstractmethod def socket( self, - family: socket.AddressFamily | int | None = None, - type: socket.SocketKind | int | None = None, - proto: int | None = None, + family: socket.AddressFamily | int = ..., + type: socket.SocketKind | int = ..., + proto: int = ..., ) -> SocketType: """Create and return a socket object. diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index cfcf6354c7..b8b58c2edc 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -12,11 +12,11 @@ if TYPE_CHECKING: import select - from socket import socket from ._traps import Abort, RaiseCancelT from .. import _core + from .._file_io import _HasFileNo import sys @@ -49,7 +49,7 @@ async def wait_kevent(ident: int, filter: int, abort_func: Callable[[ raise RuntimeError("must be called from async context") -async def wait_readable(fd: (int | socket)) ->None: +async def wait_readable(fd: (int | _HasFileNo)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -57,7 +57,7 @@ async def wait_readable(fd: (int | socket)) ->None: raise RuntimeError("must be called from async context") -async def wait_writable(fd: (int | socket)) ->None: +async def wait_writable(fd: (int | _HasFileNo)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -65,7 +65,7 @@ async def wait_writable(fd: (int | socket)) ->None: raise RuntimeError("must be called from async context") -def notify_closing(fd: (int | socket)) ->None: +def notify_closing(fd: (int | _HasFileNo)) ->None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index b39ad547f1..a0373fb8fa 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -13,8 +13,6 @@ from ._wakeup_socketpair import WakeupSocketpair if TYPE_CHECKING: - from socket import socket - from typing_extensions import TypeAlias from .._core import Abort, RaiseCancelT diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 56a6559091..4faa382eca 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -14,11 +14,10 @@ from ._wakeup_socketpair import WakeupSocketpair if TYPE_CHECKING: - from socket import socket - from typing_extensions import TypeAlias from .._core import Abort, RaiseCancelT, Task, UnboundedQueue + from .._file_io import _HasFileNo assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") @@ -149,7 +148,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: # wait_task_rescheduled does not have its return type typed return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return] - async def _wait_common(self, fd: int | socket, filter: int) -> None: + async def _wait_common(self, fd: int | _HasFileNo, filter: int) -> None: if not isinstance(fd, int): fd = fd.fileno() flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT @@ -181,15 +180,15 @@ def abort(_: RaiseCancelT) -> Abort: await self.wait_kevent(fd, filter, abort) @_public - async def wait_readable(self, fd: int | socket) -> None: + async def wait_readable(self, fd: int | _HasFileNo) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public - async def wait_writable(self, fd: int | socket) -> None: + async def wait_writable(self, fd: int | _HasFileNo) -> None: await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public - def notify_closing(self, fd: int | socket) -> None: + def notify_closing(self, fd: int | _HasFileNo) -> None: if not isinstance(fd, int): fd = fd.fileno() diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py index e6840eae97..019505c827 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/trio/_highlevel_open_tcp_listeners.py @@ -192,11 +192,12 @@ async def serve_tcp( connect to it to check that it's working properly, you can use something like:: + from trio import SocketListener, SocketStream from trio.testing import open_stream_to_socket_listener async with trio.open_nursery() as nursery: - listeners = await nursery.start(serve_tcp, handler, 0) - client_stream = await open_stream_to_socket_listener(listeners[0]) + listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) + client_stream: SocketStream = await open_stream_to_socket_listener(listeners[0]) # Then send and receive data on 'client_stream', for example: await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n") diff --git a/trio/_socket.py b/trio/_socket.py index ace35ff253..8321de1d63 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -13,7 +13,6 @@ Awaitable, Callable, Literal, - NoReturn, Optional, SupportsIndex, Tuple, @@ -397,7 +396,7 @@ def _sniff_sockopts_for_fileno( ################################################################ -# _SocketType +# SocketType ################################################################ # sock.type gets weird stuff set in it, in particular on Linux: @@ -529,200 +528,211 @@ async def _resolve_address_nocp( return normed -# TODO: stopping users from initializing this type should be done in a different way, -# so SocketType can be used as a type. Note that this is *far* from trivial without -# breaking subclasses of SocketType. Can maybe add abstract methods to SocketType, -# or rename _SocketType. class SocketType: - def __init__(self) -> NoReturn: - raise TypeError( - "SocketType is an abstract class; use trio.socket.socket if you " - "want to construct a socket object" - ) + def __init__(self) -> None: + if type(self) == SocketType: + raise TypeError( + "SocketType is an abstract class; use trio.socket.socket if you " + "want to construct a socket object" + ) - ################################################################ - # Simple + portable methods and attributes - ################################################################ - def detach(self) -> int: - raise NotImplementedError() + if TYPE_CHECKING: - def fileno(self) -> int: - raise NotImplementedError() + def detach(self: SocketType) -> int: + ... - def getpeername(self) -> Any: - raise NotImplementedError() + def fileno(self: SocketType) -> int: + ... - def getsockname(self) -> Any: - raise NotImplementedError() + def getpeername(self: SocketType) -> Any: + ... - @overload - def getsockopt(self, /, level: int, optname: int) -> int: - ... + def getsockname(self: SocketType) -> Any: + ... - @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: - ... + @overload + def getsockopt(self: SocketType, /, level: int, optname: int) -> int: + ... - def getsockopt( - self, /, level: int, optname: int, buflen: int | None = None - ) -> int | bytes: - raise NotImplementedError() + @overload + def getsockopt( + self: SocketType, /, level: int, optname: int, buflen: int + ) -> bytes: + ... - @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: - ... + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + ... - @overload - def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: - ... + @overload + def setsockopt( + self: SocketType, /, level: int, optname: int, value: int | Buffer + ) -> None: + ... - def setsockopt( - self, - /, - level: int, - optname: int, - value: int | Buffer | None, - optlen: int | None = None, - ) -> None: - raise NotImplementedError() + @overload + def setsockopt( + self: SocketType, /, level: int, optname: int, value: None, optlen: int + ) -> None: + ... - def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: - raise NotImplementedError() + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + ... - def get_inheritable(self) -> bool: - raise NotImplementedError() + def listen( + self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128) + ) -> None: + ... - def set_inheritable(self, inheritable: bool) -> None: - raise NotImplementedError() + def get_inheritable(self: SocketType) -> bool: + ... - if sys.platform == "win32" or ( - not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") - ): + def set_inheritable(self: SocketType, inheritable: bool) -> None: + ... - def share(self, /, process_id: int) -> bytes: - raise NotImplementedError() + if sys.platform == "win32": - def __enter__(self) -> Self: - raise NotImplementedError() + def share(self: SocketType, /, process_id: int) -> bytes: + ... - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - raise NotImplementedError() + def __enter__(self) -> Self: + ... - @property - def family(self) -> AddressFamily: - raise NotImplementedError() + def __exit__( + self: SocketType, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + ... - @property - def type(self) -> SocketKind: - raise NotImplementedError() + @property + def family(self: SocketType) -> AddressFamily: + ... - @property - def proto(self) -> int: - raise NotImplementedError() + @property + def type(self: SocketType) -> SocketKind: + ... - @property - def did_shutdown_SHUT_WR(self) -> bool: - raise NotImplementedError() + @property + def proto(self: SocketType) -> int: + ... - # def __repr__(self) -> str: - # raise NotImplementedError() - def dup(self) -> SocketType: - raise NotImplementedError() + @property + def did_shutdown_SHUT_WR(self) -> bool: + ... - def close(self) -> None: - raise NotImplementedError() + def __repr__(self) -> str: + ... - async def bind(self, address: Address) -> None: - raise NotImplementedError() + def dup(self: SocketType) -> SocketType: + ... - def shutdown(self, flag: int) -> None: - raise NotImplementedError() + def close(self) -> None: + ... - def is_readable(self) -> bool: - raise NotImplementedError() + async def bind(self, address: Address) -> None: + ... - async def wait_writable(self) -> None: - raise NotImplementedError() + def shutdown(self, flag: int) -> None: + ... - async def accept(self) -> tuple[SocketType, object]: - raise NotImplementedError() - - async def connect(self, address: Address) -> None: - raise NotImplementedError() - - def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: - raise NotImplementedError() - - def recv_into( - __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[int]: - raise NotImplementedError() - - ### - def recvfrom( - __self, __bufsize: int, __flags: int = 0 - ) -> Awaitable[tuple[bytes, Address]]: - raise NotImplementedError() - - # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] - def recvfrom_into( - __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[tuple[int, Address]]: - raise NotImplementedError() - - # if hasattr(_stdlib_socket.socket, "recvmsg"): - # def recvmsg( - # __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 - # ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: - # raise NotImplementedError() - - # if hasattr(_stdlib_socket.socket, "recvmsg_into"): - # def recvmsg_into( - # __self, - # __buffers: Iterable[Buffer], - # __ancbufsize: int = 0, - # __flags: int = 0, - # ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: - # raise NotImplementedError() - - def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: - raise NotImplementedError() + def is_readable(self) -> bool: + ... - @overload - async def sendto( - self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer - ) -> int: - ... + async def wait_writable(self) -> None: + ... - @overload - async def sendto( - self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer - ) -> int: - ... + async def accept(self: SocketType) -> tuple[SocketType, object]: + ... - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] - async def sendto(self, *args: Any) -> int: - """Similar to :meth:`socket.socket.sendto`, but async.""" - raise NotImplementedError() + async def connect(self, address: Address) -> None: + ... - if sys.platform != "win32" or ( - not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") - ): + def recv( + __self: SocketType, __buflen: int, __flags: int = 0 + ) -> Awaitable[bytes]: + ... - @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg( + def recv_into( + __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + ... + + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + __self: SocketType, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: + ... + + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, Address]]: + ... + + if hasattr(_stdlib_socket.socket, "recvmsg"): + + def recvmsg( + __self: SocketType, + __bufsize: int, + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + ... + + if hasattr(_stdlib_socket.socket, "recvmsg_into"): + + def recvmsg_into( + __self: SocketType, + __buffers: Iterable[Buffer], + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + ... + + def send( + __self: SocketType, __bytes: Buffer, __flags: int = 0 + ) -> Awaitable[int]: + ... + + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( self, - __buffers: Iterable[Buffer], - __ancdata: Iterable[tuple[int, int, Buffer]] = (), - __flags: int = 0, - __address: Address | None = None, + __data: Buffer, + __flags: int, + __address: tuple[Any, ...] | str | Buffer, ) -> int: - raise NotImplementedError() + ... + + async def sendto(self, *args: Any) -> int: + ... + + if sys.platform != "win32": + + @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: AddressWithNoneHost | None = None, + ) -> int: + ... class _SocketType(SocketType): diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py index 0d9e9316fa..fe7c55483d 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -5,21 +5,26 @@ import sys from math import inf from socket import AddressFamily, SocketKind -from typing import overload +from typing import TYPE_CHECKING, Sequence, overload import attr import pytest import trio from trio import SocketListener, open_tcp_listeners, open_tcp_stream, serve_tcp +from trio.abc import HostnameResolver, SendStream, SocketFactory from trio.testing import open_stream_to_socket_listener from .. import socket as tsocket from .._core._tests.tutil import binds_ipv6 +from .._socket import AddressWithNoneHost if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup +if TYPE_CHECKING: + from typing_extensions import Buffer + async def test_open_tcp_listeners_basic() -> None: listeners = await open_tcp_listeners(0) @@ -119,8 +124,8 @@ class FakeOSError(OSError): @attr.s class FakeSocket(tsocket.SocketType): - _family: AddressFamily = attr.ib() - _type: SocketKind = attr.ib() + _family: AddressFamily = attr.ib(converter=AddressFamily) + _type: SocketKind = attr.ib(converter=SocketKind) _proto: int = attr.ib() closed: bool = attr.ib(default=False) @@ -175,7 +180,7 @@ def setsockopt( async def bind(self, address: AddressWithNoneHost) -> None: pass - def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None: assert self.backlog is None assert backlog is not None self.backlog = backlog @@ -187,12 +192,21 @@ def close(self) -> None: @attr.s -class FakeSocketFactory: - poison_after = attr.ib() - sockets = attr.ib(factory=list) - raise_on_family = attr.ib(factory=dict) # family => errno +class FakeSocketFactory(SocketFactory): + poison_after: int = attr.ib() + sockets: list[tsocket.SocketType] = attr.ib(factory=list) + raise_on_family: dict[AddressFamily, int] = attr.ib(factory=dict) # family => errno - def socket(self, family, type, proto): + def socket( + self, + family: AddressFamily | int | None = None, + type: SocketKind | int | None = None, + proto: int = 0, + ) -> tsocket.SocketType: + assert family is not None + assert type is not None + if isinstance(family, int): + family = AddressFamily(family) if family in self.raise_on_family: raise OSError(self.raise_on_family[family], "nope") sock = FakeSocket(family, type, proto) @@ -204,15 +218,37 @@ def socket(self, family, type, proto): @attr.s -class FakeHostnameResolver: - family_addr_pairs = attr.ib() +class FakeHostnameResolver(HostnameResolver): + family_addr_pairs: Sequence[tuple[AddressFamily, str]] = attr.ib() - async def getaddrinfo(self, host, port, family, type, proto, flags): + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: + assert isinstance(port, int) return [ (family, tsocket.SOCK_STREAM, 0, "", (addr, port)) for family, addr in self.family_addr_pairs ] + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: + raise NotImplementedError() + async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None: # If we were trying to bind to multiple hosts and one of them failed, they @@ -234,25 +270,27 @@ async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None: assert len(fsf.sockets) == 3 for sock in fsf.sockets: - assert sock.closed + # property only exists on FakeSocket + assert sock.closed # type: ignore[attr-defined] async def test_open_tcp_listeners_port_checking() -> None: for host in ["127.0.0.1", None]: with pytest.raises(TypeError): - await open_tcp_listeners(None, host=host) + await open_tcp_listeners(None, host=host) # type: ignore[arg-type] with pytest.raises(TypeError): - await open_tcp_listeners(b"80", host=host) + await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type] with pytest.raises(TypeError): - await open_tcp_listeners("http", host=host) + await open_tcp_listeners("http", host=host) # type: ignore[arg-type] async def test_serve_tcp() -> None: - async def handler(stream) -> None: + async def handler(stream: SendStream) -> None: await stream.send_all(b"x") async with trio.open_nursery() as nursery: - listeners = await nursery.start(serve_tcp, handler, 0) + # nursery.start is incorrectly typed, awaiting #2773 + listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) # type: ignore[arg-type] stream = await open_stream_to_socket_listener(listeners[0]) async with stream: await stream.receive_some(1) == b"x" @@ -268,7 +306,7 @@ async def handler(stream) -> None: [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], ) async def test_open_tcp_listeners_some_address_families_unavailable( - try_families, fail_families + try_families: set[AddressFamily], fail_families: set[AddressFamily] ) -> None: fsf = FakeSocketFactory( 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} @@ -339,7 +377,8 @@ async def test_open_tcp_listeners_backlog() -> None: listeners = await open_tcp_listeners(0, backlog=given) assert listeners for listener in listeners: - assert listener.socket.backlog == expected + # `backlog` only exists on FakeSocket + assert listener.socket.backlog == expected # type: ignore[attr-defined] async def test_open_tcp_listeners_backlog_float_error() -> None: diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index fc1bf4c006..0fa142ba4b 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -26,18 +26,12 @@ def test_close_all() -> None: class CloseMe(SocketType): - def __init__(self) -> None: - ... - closed = False def close(self) -> None: self.closed = True class CloseKiller(SocketType): - def __init__(self) -> None: - ... - def close(self) -> None: raise OSError @@ -140,6 +134,7 @@ def can_bind_127_0_0_2() -> bool: s.bind(("127.0.0.2", 0)) except OSError: return False + # s.getsockname() is typed as returning Any return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return] @@ -173,7 +168,7 @@ async def test_local_address_real() -> None: server_sock, remote_addr = await listener.accept() await client_stream.aclose() server_sock.close() - # accept returns tuple[SocketType, object] + # accept returns tuple[SocketType, object], due to typeshed returning `Any` assert remote_addr[0] == local_address # type: ignore[index] # Trying to connect to an ipv4 address with the ipv6 wildcard @@ -196,9 +191,9 @@ async def test_local_address_real() -> None: @attr.s(eq=False) class FakeSocket(trio.socket.SocketType): scenario: Scenario = attr.ib() - _family: AddressFamily = attr.ib(alias="_family") - _type: SocketKind = attr.ib(alias="_type") - _proto: int = attr.ib(alias="_proto") + _family: AddressFamily = attr.ib() + _type: SocketKind = attr.ib() + _proto: int = attr.ib() ip: str | int | None = attr.ib(default=None) port: str | int | None = attr.ib(default=None) @@ -286,11 +281,11 @@ def _ip_to_gai_entry( ) -> tuple[ AddressFamily, SocketKind, - int | None, + int, str, - tuple[int | str, int, int, int] | tuple[int | str, int], + tuple[str, int, int, int] | tuple[str, int], ]: - sockaddr: tuple[int | str, int] | tuple[int | str, int, int, int] + sockaddr: tuple[str, int] | tuple[str, int, int, int] if ":" in ip: family = trio.socket.AF_INET6 sockaddr = (ip, self.port, 0, 0) @@ -301,7 +296,7 @@ def _ip_to_gai_entry( # should hostnameresolver use AddressFamily and SocketKind, instead of int&int? # the return type in supertype is ... wildly incompatible with what this returns - async def getaddrinfo( # type: ignore[override] + async def getaddrinfo( self, host: str | bytes | None, port: bytes | str | int | None, @@ -313,9 +308,9 @@ async def getaddrinfo( # type: ignore[override] tuple[ AddressFamily, SocketKind, - int | None, + int, str, - tuple[int | str, int, int, int] | tuple[int | str, int], + tuple[str, int, int, int] | tuple[str, int], ] ]: assert host == b"test.example.com" diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index 3e8c300d53..63fd38d5d2 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -274,11 +274,11 @@ def main() -> None: # pragma: no cover if TYPE_CHECKING: import select - from socket import socket from ._traps import Abort, RaiseCancelT from .. import _core + from .._file_io import _HasFileNo """ From c8fc73ce7af51aaea0d63e1b58bef25352cc6e44 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 29 Aug 2023 15:24:32 +0200 Subject: [PATCH 04/20] remove unneeded pragma: no cover --- trio/_tests/test_highlevel_open_tcp_stream.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index 0fa142ba4b..775ab1f644 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -294,8 +294,6 @@ def _ip_to_gai_entry( sockaddr = (ip, self.port) return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr) - # should hostnameresolver use AddressFamily and SocketKind, instead of int&int? - # the return type in supertype is ... wildly incompatible with what this returns async def getaddrinfo( self, host: str | bytes | None, @@ -321,7 +319,7 @@ async def getaddrinfo( assert flags == 0 return [self._ip_to_gai_entry(ip) for ip in self.ip_order] - async def getnameinfo( # pragma: no cover + async def getnameinfo( self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int ) -> tuple[str, str]: raise NotImplementedError From e3fc6ae5e2f81bcc7e2d6d89edec775f497a64a2 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 29 Aug 2023 16:17:07 +0200 Subject: [PATCH 05/20] fix default parameters --- trio/_abc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index a839bd380a..dbc7966cee 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -211,9 +211,9 @@ class SocketFactory(metaclass=ABCMeta): @abstractmethod def socket( self, - family: socket.AddressFamily | int = ..., - type: socket.SocketKind | int = ..., - proto: int = ..., + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, ) -> SocketType: """Create and return a socket object. From d60d1a141360f9ea0d29c1a6ca0cb620ce1baf6c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 2 Sep 2023 12:54:01 +0200 Subject: [PATCH 06/20] raise NotImplementedError instead of methods not being defined at runtime --- trio/_socket.py | 308 ++++++++++++++++++++++++------------------------ 1 file changed, 154 insertions(+), 154 deletions(-) diff --git a/trio/_socket.py b/trio/_socket.py index 8321de1d63..58d1d83cba 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -536,203 +536,203 @@ def __init__(self) -> None: "want to construct a socket object" ) - if TYPE_CHECKING: + def detach(self: SocketType) -> int: + raise NotImplementedError - def detach(self: SocketType) -> int: - ... + def fileno(self: SocketType) -> int: + raise NotImplementedError - def fileno(self: SocketType) -> int: - ... + def getpeername(self: SocketType) -> Any: + raise NotImplementedError - def getpeername(self: SocketType) -> Any: - ... + def getsockname(self: SocketType) -> Any: + raise NotImplementedError - def getsockname(self: SocketType) -> Any: - ... - - @overload - def getsockopt(self: SocketType, /, level: int, optname: int) -> int: - ... + @overload + def getsockopt(self: SocketType, /, level: int, optname: int) -> int: + ... - @overload - def getsockopt( - self: SocketType, /, level: int, optname: int, buflen: int - ) -> bytes: - ... + @overload + def getsockopt(self: SocketType, /, level: int, optname: int, buflen: int) -> bytes: + ... - def getsockopt( - self, /, level: int, optname: int, buflen: int | None = None - ) -> int | bytes: - ... + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + raise NotImplementedError - @overload - def setsockopt( - self: SocketType, /, level: int, optname: int, value: int | Buffer - ) -> None: - ... + @overload + def setsockopt( + self: SocketType, /, level: int, optname: int, value: int | Buffer + ) -> None: + ... - @overload - def setsockopt( - self: SocketType, /, level: int, optname: int, value: None, optlen: int - ) -> None: - ... + @overload + def setsockopt( + self: SocketType, /, level: int, optname: int, value: None, optlen: int + ) -> None: + ... - def setsockopt( - self, - /, - level: int, - optname: int, - value: int | Buffer | None, - optlen: int | None = None, - ) -> None: - ... + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + raise NotImplementedError - def listen( - self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128) - ) -> None: - ... + def listen( + self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128) + ) -> None: + raise NotImplementedError - def get_inheritable(self: SocketType) -> bool: - ... + def get_inheritable(self: SocketType) -> bool: + raise NotImplementedError - def set_inheritable(self: SocketType, inheritable: bool) -> None: - ... + def set_inheritable(self: SocketType, inheritable: bool) -> None: + raise NotImplementedError - if sys.platform == "win32": + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): - def share(self: SocketType, /, process_id: int) -> bytes: - ... + def share(self: SocketType, /, process_id: int) -> bytes: + raise NotImplementedError - def __enter__(self) -> Self: - ... + def __enter__(self) -> Self: + raise NotImplementedError - def __exit__( - self: SocketType, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - ... + def __exit__( + self: SocketType, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + raise NotImplementedError - @property - def family(self: SocketType) -> AddressFamily: - ... + @property + def family(self: SocketType) -> AddressFamily: + raise NotImplementedError - @property - def type(self: SocketType) -> SocketKind: - ... + @property + def type(self: SocketType) -> SocketKind: + raise NotImplementedError - @property - def proto(self: SocketType) -> int: - ... + @property + def proto(self: SocketType) -> int: + raise NotImplementedError - @property - def did_shutdown_SHUT_WR(self) -> bool: - ... + @property + def did_shutdown_SHUT_WR(self) -> bool: + raise NotImplementedError - def __repr__(self) -> str: - ... + def __repr__(self) -> str: + raise NotImplementedError - def dup(self: SocketType) -> SocketType: - ... + def dup(self: SocketType) -> SocketType: + raise NotImplementedError - def close(self) -> None: - ... + def close(self) -> None: + raise NotImplementedError - async def bind(self, address: Address) -> None: - ... + async def bind(self, address: Address) -> None: + raise NotImplementedError - def shutdown(self, flag: int) -> None: - ... + def shutdown(self, flag: int) -> None: + raise NotImplementedError - def is_readable(self) -> bool: - ... + def is_readable(self) -> bool: + raise NotImplementedError - async def wait_writable(self) -> None: - ... + async def wait_writable(self) -> None: + raise NotImplementedError - async def accept(self: SocketType) -> tuple[SocketType, object]: - ... + async def accept(self: SocketType) -> tuple[SocketType, object]: + raise NotImplementedError - async def connect(self, address: Address) -> None: - ... + async def connect(self, address: Address) -> None: + raise NotImplementedError - def recv( - __self: SocketType, __buflen: int, __flags: int = 0 - ) -> Awaitable[bytes]: - ... + def recv(__self: SocketType, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: + raise NotImplementedError - def recv_into( - __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[int]: - ... + def recv_into( + __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + raise NotImplementedError - # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] - def recvfrom( - __self: SocketType, __bufsize: int, __flags: int = 0 - ) -> Awaitable[tuple[bytes, Address]]: - ... + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + __self: SocketType, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: + raise NotImplementedError - # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] - def recvfrom_into( - __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[tuple[int, Address]]: - ... + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, Address]]: + raise NotImplementedError - if hasattr(_stdlib_socket.socket, "recvmsg"): + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg") + ): - def recvmsg( - __self: SocketType, - __bufsize: int, - __ancbufsize: int = 0, - __flags: int = 0, - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: - ... + def recvmsg( + __self: SocketType, + __bufsize: int, + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + raise NotImplementedError - if hasattr(_stdlib_socket.socket, "recvmsg_into"): + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into") + ): - def recvmsg_into( - __self: SocketType, - __buffers: Iterable[Buffer], - __ancbufsize: int = 0, - __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: - ... + def recvmsg_into( + __self: SocketType, + __buffers: Iterable[Buffer], + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + raise NotImplementedError - def send( - __self: SocketType, __bytes: Buffer, __flags: int = 0 - ) -> Awaitable[int]: - ... + def send(__self: SocketType, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + raise NotImplementedError - @overload - async def sendto( - self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer - ) -> int: - ... + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... - @overload - async def sendto( - self, - __data: Buffer, - __flags: int, - __address: tuple[Any, ...] | str | Buffer, - ) -> int: - ... + @overload + async def sendto( + self, + __data: Buffer, + __flags: int, + __address: tuple[Any, ...] | str | Buffer, + ) -> int: + ... - async def sendto(self, *args: Any) -> int: - ... + async def sendto(self, *args: Any) -> int: + raise NotImplementedError - if sys.platform != "win32": + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") + ): - @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg( - self, - __buffers: Iterable[Buffer], - __ancdata: Iterable[tuple[int, int, Buffer]] = (), - __flags: int = 0, - __address: AddressWithNoneHost | None = None, - ) -> int: - ... + @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: AddressWithNoneHost | None = None, + ) -> int: + raise NotImplementedError class _SocketType(SocketType): From b213602bc05271a1d9c28443a3792ec1cb5e8f62 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 2 Sep 2023 13:18:55 +0200 Subject: [PATCH 07/20] remove docs reference --- docs/source/reference-io.rst | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 61bbef78c2..e234d255c1 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -504,14 +504,6 @@ Socket objects * :meth:`~socket.socket.set_inheritable` * :meth:`~socket.socket.get_inheritable` -The internal SocketType -~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. - .. autoclass:: _SocketType -.. - TODO: adding `:members:` here gives error due to overload+_wraps on `sendto` - TODO: rewrite ... all of the above when fixing _SocketType vs SocketType - .. currentmodule:: trio From 7eb099d7e5d7c63c5e3cfff511c668d30eb10216 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 2 Sep 2023 14:08:15 +0200 Subject: [PATCH 08/20] experimentally changing SocketType to be a generic --- trio/_abc.py | 6 +- trio/_dtls.py | 26 ++-- trio/_highlevel_open_tcp_stream.py | 36 ++--- trio/_highlevel_socket.py | 6 +- trio/_socket.py | 183 ++++++++++++-------------- trio/_tests/verify_types_darwin.json | 2 +- trio/_tests/verify_types_linux.json | 2 +- trio/_tests/verify_types_windows.json | 6 +- trio/socket.py | 1 - 9 files changed, 115 insertions(+), 153 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index dbc7966cee..724df92855 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -2,7 +2,7 @@ import socket from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import trio @@ -209,12 +209,12 @@ class SocketFactory(metaclass=ABCMeta): """ @abstractmethod - def socket( + def socket( # type: ignore[misc] self, family: socket.AddressFamily | int = socket.AF_INET, type: socket.SocketKind | int = socket.SOCK_STREAM, proto: int = 0, - ) -> SocketType: + ) -> SocketType[Any]: """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, diff --git a/trio/_dtls.py b/trio/_dtls.py index ff6a61ba88..63ca83db1f 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -43,26 +43,26 @@ from OpenSSL.SSL import Context from typing_extensions import Self, TypeAlias - from trio.socket import Address, SocketType + from trio.socket import SocketType MAX_UDP_PACKET_SIZE = 65527 -def packet_header_overhead(sock: SocketType) -> int: +def packet_header_overhead(sock: SocketType[Any]) -> int: if sock.family == trio.socket.AF_INET: return 28 else: return 48 -def worst_case_mtu(sock: SocketType) -> int: +def worst_case_mtu(sock: SocketType[Any]) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: return 1280 - packet_header_overhead(sock) -def best_guess_mtu(sock: SocketType) -> int: +def best_guess_mtu(sock: SocketType[Any]) -> int: return 1500 - packet_header_overhead(sock) @@ -563,7 +563,7 @@ def _signable(*fields: bytes) -> bytes: def _make_cookie( - key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes + key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes ) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -581,7 +581,7 @@ def _make_cookie( def valid_cookie( - key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes + key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes ) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -603,7 +603,7 @@ def valid_cookie( def challenge_for( - key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes + key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes ) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() @@ -664,7 +664,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: async def handle_client_hello_untrusted( - endpoint: DTLSEndpoint, address: Address, packet: bytes + endpoint: DTLSEndpoint, address: Any, packet: bytes ) -> None: if endpoint._listening_context is None: return @@ -739,7 +739,7 @@ async def handle_client_hello_untrusted( async def dtls_receive_loop( - endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType + endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType[Any] ) -> None: try: while True: @@ -828,7 +828,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context): + def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -1178,7 +1178,7 @@ class DTLSEndpoint(metaclass=Final): """ - def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): + def __init__(self, socket: SocketType[Any], *, incoming_packets_buffer: int = 10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL @@ -1189,7 +1189,7 @@ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") self._initialized = True - self.socket: SocketType = socket + self.socket: SocketType[Any] = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() @@ -1198,7 +1198,7 @@ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary() + self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary() self._listening_context: Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 322ae4006e..358d9f32f0 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -4,11 +4,11 @@ from collections.abc import Generator from contextlib import contextmanager from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import trio from trio._core._multierror import MultiError -from trio.socket import SOCK_STREAM, Address, SocketType, getaddrinfo, socket +from trio.socket import SOCK_STREAM, SocketType, getaddrinfo, socket if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -114,8 +114,8 @@ @contextmanager -def close_all() -> Generator[set[SocketType], None, None]: - sockets_to_close: set[SocketType] = set() +def close_all() -> Generator[set[SocketType[Any]], None, None]: # type: ignore[misc] + sockets_to_close: set[SocketType[Any]] = set() try: yield sockets_to_close finally: @@ -131,7 +131,6 @@ def close_all() -> Generator[set[SocketType], None, None]: raise MultiError(errs) -# workaround for list being invariant def reorder_for_rfc_6555_section_5_4( targets: list[ tuple[ @@ -139,23 +138,7 @@ def reorder_for_rfc_6555_section_5_4( SocketKind, int, str, - tuple[str, int] | tuple[str, int, int, int], - ] - ] | list[ - tuple[ - AddressFamily, - SocketKind, - int, - str, - tuple[str, int] - ] - ] | list[ - tuple[ - AddressFamily, - SocketKind, - int, - str, - tuple[str, int, int, int], + Any, ] ] ) -> None: @@ -172,12 +155,11 @@ def reorder_for_rfc_6555_section_5_4( # Found the first entry with a different address family; move it # so that it becomes the second item on the list. if i != 1: - # invariant workaround in arguments leads to type issues here - targets.insert(1, targets.pop(i)) # type: ignore[arg-type] + targets.insert(1, targets.pop(i)) break -def format_host_port(host: str | bytes, port: int|str) -> str: +def format_host_port(host: str | bytes, port: int | str) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return f"[{host}]:{port}" @@ -310,7 +292,7 @@ async def open_tcp_stream( # Keeps track of the socket that we're going to complete with, # need to make sure this isn't automatically closed - winning_socket: SocketType | None = None + winning_socket: SocketType[Any] | None = None # Try connecting to the specified address. Possible outcomes: # - success: record connected socket in winning_socket and cancel @@ -321,7 +303,7 @@ async def open_tcp_stream( # face of crash or cancellation async def attempt_connect( socket_args: tuple[AddressFamily, SocketKind, int], - sockaddr: Address, + sockaddr: Any, attempt_failed: trio.Event, ) -> None: nonlocal winning_socket diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index d733537752..edb07f2e82 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -4,7 +4,7 @@ import errno from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload import trio @@ -66,7 +66,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket: SocketType): + def __init__(self, socket: SocketType[Any]): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -371,7 +371,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket: SocketType): + def __init__(self, socket: SocketType[Any]): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: diff --git a/trio/_socket.py b/trio/_socket.py index 58d1d83cba..ca97a0ef37 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -12,10 +12,9 @@ Any, Awaitable, Callable, + Generic, Literal, - Optional, SupportsIndex, - Tuple, TypeVar, Union, overload, @@ -39,18 +38,7 @@ T = TypeVar("T") - -# must use old-style typing because it's evaluated at runtime -Address: TypeAlias = Union[ - str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int] -] -AddressWithNoneHost: TypeAlias = Union[ - str, - bytes, - Tuple[Optional[str], int], - Tuple[Optional[str], int, int], - Tuple[Optional[str], int, int, int], -] +AddressFormat = TypeVar("AddressFormat") # Usage: @@ -180,11 +168,7 @@ async def getaddrinfo( flags: int = 0, ) -> list[ tuple[ - AddressFamily, - SocketKind, - int, - str, - tuple[str, int] | tuple[str, int, int, int], + AddressFamily, SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int] ] ]: """Look up a numeric address given a name. @@ -289,7 +273,7 @@ async def getprotobyname(name: str) -> int: ################################################################ -def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType: +def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType[Any]: """Convert a standard library :class:`socket.socket` object into a Trio socket object. @@ -298,12 +282,12 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType: @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd( +def fromfd( # type: ignore[misc] # use of Any in decorated function fd: SupportsIndex, family: AddressFamily | int = _stdlib_socket.AF_INET, type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, -) -> SocketType: +) -> SocketType[Any]: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -314,7 +298,7 @@ def fromfd( ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(info: bytes) -> SocketType: + def fromshare(info: bytes) -> SocketType[Any]: return from_stdlib_socket(_stdlib_socket.fromshare(info)) @@ -329,11 +313,11 @@ def fromshare(info: bytes) -> SocketType: @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair( +def socketpair( # type: ignore[misc] family: FamilyT = FamilyDefault, type: TypeT = SocketKind.SOCK_STREAM, proto: int = 0, -) -> tuple[SocketType, SocketType]: +) -> tuple[SocketType[Any], SocketType[Any]]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. @@ -343,12 +327,12 @@ def socketpair( @_wraps(_stdlib_socket.socket, assigned=(), updated=()) -def socket( +def socket( # type: ignore[misc] family: AddressFamily | int = _stdlib_socket.AF_INET, type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, fileno: int | None = None, -) -> SocketType: +) -> SocketType[Any]: """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using @@ -416,9 +400,9 @@ def _make_simple_sock_method_wrapper( fn: Callable[Concatenate[_stdlib_socket.socket, P], T], wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], maybe_avail: bool = False, -) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: - @_wraps(fn, assigned=("__name__",), updated=()) - async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: +) -> Callable[Concatenate[_SocketType[Any], P], Awaitable[T]]: + @_wraps(fn, assigned=("__name__",), updated=()) # type: ignore + async def wrapper(self: _SocketType[Any], *args: P.args, **kwargs: P.kwargs) -> T: return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs) wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async. @@ -457,9 +441,9 @@ async def _resolve_address_nocp( proto: int, *, ipv6_v6only: bool | int, - address: AddressWithNoneHost, + address: AddressFormat, local: bool, -) -> Address: +) -> Any: # Do some pre-checking (or exit early for non-IP sockets) if family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -474,8 +458,7 @@ async def _resolve_address_nocp( assert isinstance(address, (str, bytes)) return os.fspath(address) else: - # TODO: check for host is None? - return address # type: ignore[return-value] + return address # -- From here on we know we have IPv4 or IPV6 -- host: str | None @@ -489,7 +472,7 @@ async def _resolve_address_nocp( except (OSError, TypeError): pass else: - return address # type: ignore[return-value] + return address # Special cases to match the stdlib, see gh-277 if host == "": host = None @@ -518,17 +501,17 @@ async def _resolve_address_nocp( if family == _stdlib_socket.AF_INET6: list_normed = list(normed) assert len(normed) == 4 - # typechecking certainly doesn't like this logic, but given just how broad - # Address is, it's quite cumbersome to write the below without type: ignore if len(address) >= 3: - list_normed[2] = address[2] # type: ignore + list_normed[2] = address[2] if len(address) >= 4: - list_normed[3] = address[3] # type: ignore - return tuple(list_normed) # type: ignore + list_normed[3] = address[3] + return tuple(list_normed) return normed -class SocketType: +# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html?highlight=sendmsg#socket-families +# and the return type of several methods will depend on those. typeshed has ended up typing those return types as `Any` in most cases, but for users that know which family/families they're working in and wants complete type coverage they can specify the AddressFormat. +class SocketType(Generic[AddressFormat]): def __init__(self) -> None: if type(self) == SocketType: raise TypeError( @@ -536,24 +519,24 @@ def __init__(self) -> None: "want to construct a socket object" ) - def detach(self: SocketType) -> int: + def detach(self) -> int: raise NotImplementedError - def fileno(self: SocketType) -> int: + def fileno(self) -> int: raise NotImplementedError - def getpeername(self: SocketType) -> Any: + def getpeername(self) -> AddressFormat: raise NotImplementedError - def getsockname(self: SocketType) -> Any: + def getsockname(self) -> AddressFormat: raise NotImplementedError @overload - def getsockopt(self: SocketType, /, level: int, optname: int) -> int: + def getsockopt(self, /, level: int, optname: int) -> int: ... @overload - def getsockopt(self: SocketType, /, level: int, optname: int, buflen: int) -> bytes: + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( @@ -562,15 +545,11 @@ def getsockopt( raise NotImplementedError @overload - def setsockopt( - self: SocketType, /, level: int, optname: int, value: int | Buffer - ) -> None: + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... @overload - def setsockopt( - self: SocketType, /, level: int, optname: int, value: None, optlen: int - ) -> None: + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: ... def setsockopt( @@ -583,29 +562,27 @@ def setsockopt( ) -> None: raise NotImplementedError - def listen( - self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128) - ) -> None: + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: raise NotImplementedError - def get_inheritable(self: SocketType) -> bool: + def get_inheritable(self) -> bool: raise NotImplementedError - def set_inheritable(self: SocketType, inheritable: bool) -> None: + def set_inheritable(self, inheritable: bool) -> None: raise NotImplementedError if sys.platform == "win32" or ( not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") ): - def share(self: SocketType, /, process_id: int) -> bytes: + def share(self, /, process_id: int) -> bytes: raise NotImplementedError def __enter__(self) -> Self: raise NotImplementedError def __exit__( - self: SocketType, + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, @@ -613,15 +590,15 @@ def __exit__( raise NotImplementedError @property - def family(self: SocketType) -> AddressFamily: + def family(self) -> AddressFamily: raise NotImplementedError @property - def type(self: SocketType) -> SocketKind: + def type(self) -> SocketKind: raise NotImplementedError @property - def proto(self: SocketType) -> int: + def proto(self) -> int: raise NotImplementedError @property @@ -631,13 +608,13 @@ def did_shutdown_SHUT_WR(self) -> bool: def __repr__(self) -> str: raise NotImplementedError - def dup(self: SocketType) -> SocketType: + def dup(self) -> SocketType[AddressFormat]: raise NotImplementedError def close(self) -> None: raise NotImplementedError - async def bind(self, address: Address) -> None: + async def bind(self, address: AddressFormat) -> None: raise NotImplementedError def shutdown(self, flag: int) -> None: @@ -649,57 +626,57 @@ def is_readable(self) -> bool: async def wait_writable(self) -> None: raise NotImplementedError - async def accept(self: SocketType) -> tuple[SocketType, object]: + async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]: raise NotImplementedError - async def connect(self, address: Address) -> None: + async def connect(self, address: AddressFormat) -> None: raise NotImplementedError - def recv(__self: SocketType, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: raise NotImplementedError def recv_into( - __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 ) -> Awaitable[int]: raise NotImplementedError # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] def recvfrom( - __self: SocketType, __bufsize: int, __flags: int = 0 - ) -> Awaitable[tuple[bytes, Address]]: + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, AddressFormat]]: raise NotImplementedError # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] def recvfrom_into( - __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[tuple[int, Address]]: + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, AddressFormat]]: raise NotImplementedError - if sys.platform == "win32" or ( + if sys.platform != "win32" or ( not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg") ): def recvmsg( - __self: SocketType, + __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0, ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: raise NotImplementedError - if sys.platform == "win32" or ( + if sys.platform != "win32" or ( not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into") ): def recvmsg_into( - __self: SocketType, + __self, __buffers: Iterable[Buffer], __ancbufsize: int = 0, __flags: int = 0, ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: raise NotImplementedError - def send(__self: SocketType, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: raise NotImplementedError @overload @@ -720,7 +697,7 @@ async def sendto( async def sendto(self, *args: Any) -> int: raise NotImplementedError - if sys.platform == "win32" or ( + if sys.platform != "win32" or ( not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") ): @@ -730,12 +707,12 @@ async def sendmsg( __buffers: Iterable[Buffer], __ancdata: Iterable[tuple[int, int, Buffer]] = (), __flags: int = 0, - __address: AddressWithNoneHost | None = None, + __address: AddressFormat | None = None, ) -> int: raise NotImplementedError -class _SocketType(SocketType): +class _SocketType(SocketType[AddressFormat]): def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we @@ -758,11 +735,11 @@ def detach(self) -> int: def fileno(self) -> int: return self._sock.fileno() - def getpeername(self) -> Any: - return self._sock.getpeername() + def getpeername(self) -> AddressFormat: + return self._sock.getpeername() # type: ignore[no-any-return] - def getsockname(self) -> Any: - return self._sock.getsockname() + def getsockname(self) -> AddressFormat: + return self._sock.getsockname() # type: ignore[no-any-return] @overload def getsockopt(self, /, level: int, optname: int) -> int: @@ -856,7 +833,7 @@ def did_shutdown_SHUT_WR(self) -> bool: def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self) -> _SocketType: + def dup(self) -> SocketType[AddressFormat]: """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) @@ -865,12 +842,12 @@ def close(self) -> None: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address: AddressWithNoneHost) -> None: + async def bind(self, address: AddressFormat) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") and self.family == _stdlib_socket.AF_UNIX - and address[0] + and address[0] # type: ignore[index] ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) @@ -881,7 +858,7 @@ async def bind(self, address: AddressWithNoneHost) -> None: # there aren't yet any real systems that do this, so we'll worry # about it when it happens. await trio.lowlevel.checkpoint() - return self._sock.bind(address) + return self._sock.bind(address) # type: ignore[arg-type] def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: @@ -904,17 +881,17 @@ async def wait_writable(self) -> None: async def _resolve_address_nocp( self, - address: AddressWithNoneHost, + address: AddressFormat, *, local: bool, - ) -> Address: + ) -> AddressFormat: if self.family == _stdlib_socket.AF_INET6: ipv6_v6only = self._sock.getsockopt( _stdlib_socket.IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY ) else: ipv6_v6only = False - return await _resolve_address_nocp( + return await _resolve_address_nocp( # type: ignore[no-any-return] self.type, self.family, self.proto, @@ -975,7 +952,7 @@ async def _nonblocking_helper( _stdlib_socket.socket.accept, _core.wait_readable ) - async def accept(self) -> tuple[SocketType, object]: + async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -984,7 +961,7 @@ async def accept(self) -> tuple[SocketType, object]: # connect ################################################################ - async def connect(self, address: AddressWithNoneHost) -> None: + async def connect(self, address: AddressFormat) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -1039,7 +1016,7 @@ async def connect(self, address: AddressWithNoneHost) -> None: # happens, someone will hopefully tell us, and then hopefully we # can investigate their system to figure out what its semantics # are. - return self._sock.connect(address) + return self._sock.connect(address) # type: ignore[arg-type] # It raised BlockingIOError, meaning that it's started the # connection attempt. We wait for it to complete: await _core.wait_writable(self._sock) @@ -1097,7 +1074,7 @@ def recv_into( # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] def recvfrom( __self, __bufsize: int, __flags: int = 0 - ) -> Awaitable[tuple[bytes, Address]]: + ) -> Awaitable[tuple[bytes, AddressFormat]]: ... recvfrom = _make_simple_sock_method_wrapper( # noqa: F811 @@ -1112,7 +1089,7 @@ def recvfrom( # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] def recvfrom_into( __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[tuple[int, Address]]: + ) -> Awaitable[tuple[int, AddressFormat]]: ... recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811 @@ -1123,7 +1100,9 @@ def recvfrom_into( # recvmsg ################################################################ - if hasattr(_stdlib_socket.socket, "recvmsg"): + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg") + ): if TYPE_CHECKING: def recvmsg( @@ -1139,7 +1118,9 @@ def recvmsg( # recvmsg_into ################################################################ - if hasattr(_stdlib_socket.socket, "recvmsg_into"): + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into") + ): if TYPE_CHECKING: def recvmsg_into( @@ -1208,7 +1189,7 @@ async def sendmsg( __buffers: Iterable[Buffer], __ancdata: Iterable[tuple[int, int, Buffer]] = (), __flags: int = 0, - __address: AddressWithNoneHost | None = None, + __address: AddressFormat | None = None, ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. @@ -1224,7 +1205,7 @@ async def sendmsg( __buffers, __ancdata, __flags, - __address, + __address, # type: ignore[arg-type] ) ################################################################ diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json index 2b89d28d8e..1f7e81c182 100644 --- a/trio/_tests/verify_types_darwin.json +++ b/trio/_tests/verify_types_darwin.json @@ -40,7 +40,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 631, + "withKnownType": 630, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json index ea5af77abc..e128b89775 100644 --- a/trio/_tests/verify_types_linux.json +++ b/trio/_tests/verify_types_linux.json @@ -28,7 +28,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 628, + "withKnownType": 627, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index 5d3e29a5dc..c14fb3bee9 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -7,7 +7,7 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9857369255150554, + "completenessScore": 0.9857142857142858, "diagnostics": [ { "message": "Return type annotation is missing", @@ -144,7 +144,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 622, + "withKnownType": 621, "withUnknownType": 9 }, "ignoreUnknownTypesFromImports": true, @@ -180,7 +180,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 671, + "withKnownType": 667, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/socket.py b/trio/socket.py index ddefc72649..a9e276c782 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -34,7 +34,6 @@ # import the overwrites from ._socket import ( - Address as Address, SocketType as SocketType, from_stdlib_socket as from_stdlib_socket, fromfd as fromfd, From ab66937342b4adbde63f4bc8bf1ab2e47a7fba42 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 5 Sep 2023 11:48:24 +0200 Subject: [PATCH 09/20] fix url --- trio/_socket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_socket.py b/trio/_socket.py index ca97a0ef37..8dee44a279 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -509,7 +509,7 @@ async def _resolve_address_nocp( return normed -# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html?highlight=sendmsg#socket-families +# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html#socket-families # and the return type of several methods will depend on those. typeshed has ended up typing those return types as `Any` in most cases, but for users that know which family/families they're working in and wants complete type coverage they can specify the AddressFormat. class SocketType(Generic[AddressFormat]): def __init__(self) -> None: From 304fd32b022e01be788a02e5862c42d5e04ec80d Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 5 Sep 2023 13:40:33 +0200 Subject: [PATCH 10/20] fix old imports --- trio/_tests/test_highlevel_open_tcp_listeners.py | 1 - trio/_tests/test_highlevel_open_tcp_stream.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py index fe7c55483d..776920dade 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -17,7 +17,6 @@ from .. import socket as tsocket from .._core._tests.tutil import binds_ipv6 -from .._socket import AddressWithNoneHost if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index 775ab1f644..0ea305b5fb 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -15,7 +15,7 @@ open_tcp_stream, reorder_for_rfc_6555_section_5_4, ) -from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, Address, SocketType +from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType if TYPE_CHECKING: from trio.testing import MockClock From 920d6630c930151d0a202e82bbc49621db437048 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 16 Sep 2023 12:07:09 +0200 Subject: [PATCH 11/20] fix verify_types --- trio/_tests/verify_types_darwin.json | 2 +- trio/_tests/verify_types_linux.json | 2 +- trio/_tests/verify_types_windows.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json index d98fbacaf0..507af26b3d 100644 --- a/trio/_tests/verify_types_darwin.json +++ b/trio/_tests/verify_types_darwin.json @@ -76,7 +76,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 686, + "withKnownType": 687, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json index d62d04db82..61ce9afcc5 100644 --- a/trio/_tests/verify_types_linux.json +++ b/trio/_tests/verify_types_linux.json @@ -64,7 +64,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 686, + "withKnownType": 687, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index 555003ae92..e825282bf0 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -180,7 +180,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 677, + "withKnownType": 674, "withUnknownType": 0 }, "packageName": "trio" From a3169cc11adbd4fb1cd0d4fa9defa9a84451cba4 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 16 Sep 2023 14:26:40 +0200 Subject: [PATCH 12/20] reverse making SocketType generic, CI fixes --- trio/_abc.py | 6 +- trio/_dtls.py | 12 ++-- trio/_highlevel_open_tcp_stream.py | 6 +- trio/_highlevel_socket.py | 6 +- trio/_socket.py | 64 +++++++++++-------- .../test_highlevel_open_tcp_listeners.py | 8 +-- trio/_tests/test_highlevel_open_tcp_stream.py | 8 +-- trio/_tests/test_highlevel_socket.py | 10 +-- trio/_tests/test_socket.py | 10 +-- trio/testing/_fake_net.py | 12 +++- 10 files changed, 77 insertions(+), 65 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index 724df92855..dbc7966cee 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -2,7 +2,7 @@ import socket from abc import ABCMeta, abstractmethod -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar import trio @@ -209,12 +209,12 @@ class SocketFactory(metaclass=ABCMeta): """ @abstractmethod - def socket( # type: ignore[misc] + def socket( self, family: socket.AddressFamily | int = socket.AF_INET, type: socket.SocketKind | int = socket.SOCK_STREAM, proto: int = 0, - ) -> SocketType[Any]: + ) -> SocketType: """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, diff --git a/trio/_dtls.py b/trio/_dtls.py index 63ca83db1f..9a98dd8f5f 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -48,21 +48,21 @@ MAX_UDP_PACKET_SIZE = 65527 -def packet_header_overhead(sock: SocketType[Any]) -> int: +def packet_header_overhead(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 28 else: return 48 -def worst_case_mtu(sock: SocketType[Any]) -> int: +def worst_case_mtu(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: return 1280 - packet_header_overhead(sock) -def best_guess_mtu(sock: SocketType[Any]) -> int: +def best_guess_mtu(sock: SocketType) -> int: return 1500 - packet_header_overhead(sock) @@ -739,7 +739,7 @@ async def handle_client_hello_untrusted( async def dtls_receive_loop( - endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType[Any] + endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType ) -> None: try: while True: @@ -1178,7 +1178,7 @@ class DTLSEndpoint(metaclass=Final): """ - def __init__(self, socket: SocketType[Any], *, incoming_packets_buffer: int = 10): + def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL @@ -1189,7 +1189,7 @@ def __init__(self, socket: SocketType[Any], *, incoming_packets_buffer: int = 10 if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") self._initialized = True - self.socket: SocketType[Any] = socket + self.socket: SocketType = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 358d9f32f0..c3dc157827 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -114,8 +114,8 @@ @contextmanager -def close_all() -> Generator[set[SocketType[Any]], None, None]: # type: ignore[misc] - sockets_to_close: set[SocketType[Any]] = set() +def close_all() -> Generator[set[SocketType], None, None]: + sockets_to_close: set[SocketType] = set() try: yield sockets_to_close finally: @@ -292,7 +292,7 @@ async def open_tcp_stream( # Keeps track of the socket that we're going to complete with, # need to make sure this isn't automatically closed - winning_socket: SocketType[Any] | None = None + winning_socket: SocketType | None = None # Try connecting to the specified address. Possible outcomes: # - success: record connected socket in winning_socket and cancel diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index edb07f2e82..d733537752 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -4,7 +4,7 @@ import errno from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, overload import trio @@ -66,7 +66,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final): """ - def __init__(self, socket: SocketType[Any]): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -371,7 +371,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final): """ - def __init__(self, socket: SocketType[Any]): + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: diff --git a/trio/_socket.py b/trio/_socket.py index 8dee44a279..84059f5ba3 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -12,7 +12,6 @@ Any, Awaitable, Callable, - Generic, Literal, SupportsIndex, TypeVar, @@ -38,7 +37,18 @@ T = TypeVar("T") -AddressFormat = TypeVar("AddressFormat") + +# _stdlib_socket.socket supports 13 different socket families, see +# https://docs.python.org/3/library/socket.html#socket-families +# and the return type of several methods in SocketType will depend on those. Typeshed +# has ended up typing those return types as `Any` in most cases, but for users that +# know which family/families they're working in we could make SocketType a generic type, +# where you specify the return values you expect from those methods depending on the +# protocol the socket will be handling. +# But without the ability to default the value to `Any` it will be overly cumbersome for +# most users, so currently we just specify it as `Any`. +# AddressFormat = TypeVar("AddressFormat") +AddressFormat: TypeAlias = Any # Usage: @@ -273,7 +283,7 @@ async def getprotobyname(name: str) -> int: ################################################################ -def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType[Any]: +def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType: """Convert a standard library :class:`socket.socket` object into a Trio socket object. @@ -282,12 +292,12 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType[Any]: @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd( # type: ignore[misc] # use of Any in decorated function +def fromfd( fd: SupportsIndex, family: AddressFamily | int = _stdlib_socket.AF_INET, type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, -) -> SocketType[Any]: +) -> SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -298,7 +308,7 @@ def fromfd( # type: ignore[misc] # use of Any in decorated function ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(info: bytes) -> SocketType[Any]: + def fromshare(info: bytes) -> SocketType: return from_stdlib_socket(_stdlib_socket.fromshare(info)) @@ -313,11 +323,11 @@ def fromshare(info: bytes) -> SocketType[Any]: @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair( # type: ignore[misc] +def socketpair( family: FamilyT = FamilyDefault, type: TypeT = SocketKind.SOCK_STREAM, proto: int = 0, -) -> tuple[SocketType[Any], SocketType[Any]]: +) -> tuple[SocketType, SocketType]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. @@ -327,12 +337,12 @@ def socketpair( # type: ignore[misc] @_wraps(_stdlib_socket.socket, assigned=(), updated=()) -def socket( # type: ignore[misc] +def socket( family: AddressFamily | int = _stdlib_socket.AF_INET, type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, fileno: int | None = None, -) -> SocketType[Any]: +) -> SocketType: """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using @@ -400,9 +410,9 @@ def _make_simple_sock_method_wrapper( fn: Callable[Concatenate[_stdlib_socket.socket, P], T], wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], maybe_avail: bool = False, -) -> Callable[Concatenate[_SocketType[Any], P], Awaitable[T]]: - @_wraps(fn, assigned=("__name__",), updated=()) # type: ignore - async def wrapper(self: _SocketType[Any], *args: P.args, **kwargs: P.kwargs) -> T: +) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: + @_wraps(fn, assigned=("__name__",), updated=()) + async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs) wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async. @@ -509,9 +519,7 @@ async def _resolve_address_nocp( return normed -# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html#socket-families -# and the return type of several methods will depend on those. typeshed has ended up typing those return types as `Any` in most cases, but for users that know which family/families they're working in and wants complete type coverage they can specify the AddressFormat. -class SocketType(Generic[AddressFormat]): +class SocketType: def __init__(self) -> None: if type(self) == SocketType: raise TypeError( @@ -608,7 +616,7 @@ def did_shutdown_SHUT_WR(self) -> bool: def __repr__(self) -> str: raise NotImplementedError - def dup(self) -> SocketType[AddressFormat]: + def dup(self) -> SocketType: raise NotImplementedError def close(self) -> None: @@ -626,7 +634,7 @@ def is_readable(self) -> bool: async def wait_writable(self) -> None: raise NotImplementedError - async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]: + async def accept(self) -> tuple[SocketType, AddressFormat]: raise NotImplementedError async def connect(self, address: AddressFormat) -> None: @@ -712,7 +720,7 @@ async def sendmsg( raise NotImplementedError -class _SocketType(SocketType[AddressFormat]): +class _SocketType(SocketType): def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we @@ -736,10 +744,10 @@ def fileno(self) -> int: return self._sock.fileno() def getpeername(self) -> AddressFormat: - return self._sock.getpeername() # type: ignore[no-any-return] + return self._sock.getpeername() def getsockname(self) -> AddressFormat: - return self._sock.getsockname() # type: ignore[no-any-return] + return self._sock.getsockname() @overload def getsockopt(self, /, level: int, optname: int) -> int: @@ -833,7 +841,7 @@ def did_shutdown_SHUT_WR(self) -> bool: def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self) -> SocketType[AddressFormat]: + def dup(self) -> SocketType: """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) @@ -847,7 +855,7 @@ async def bind(self, address: AddressFormat) -> None: if ( hasattr(_stdlib_socket, "AF_UNIX") and self.family == _stdlib_socket.AF_UNIX - and address[0] # type: ignore[index] + and address[0] ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) @@ -858,7 +866,7 @@ async def bind(self, address: AddressFormat) -> None: # there aren't yet any real systems that do this, so we'll worry # about it when it happens. await trio.lowlevel.checkpoint() - return self._sock.bind(address) # type: ignore[arg-type] + return self._sock.bind(address) def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: @@ -891,7 +899,7 @@ async def _resolve_address_nocp( ) else: ipv6_v6only = False - return await _resolve_address_nocp( # type: ignore[no-any-return] + return await _resolve_address_nocp( self.type, self.family, self.proto, @@ -952,7 +960,7 @@ async def _nonblocking_helper( _stdlib_socket.socket.accept, _core.wait_readable ) - async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]: + async def accept(self) -> tuple[SocketType, AddressFormat]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -1016,7 +1024,7 @@ async def connect(self, address: AddressFormat) -> None: # happens, someone will hopefully tell us, and then hopefully we # can investigate their system to figure out what its semantics # are. - return self._sock.connect(address) # type: ignore[arg-type] + return self._sock.connect(address) # It raised BlockingIOError, meaning that it's started the # connection attempt. We wait for it to complete: await _core.wait_writable(self._sock) @@ -1205,7 +1213,7 @@ async def sendmsg( __buffers, __ancdata, __flags, - __address, # type: ignore[arg-type] + __address, ) ################################################################ diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py index 776920dade..b6cfb55732 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -5,7 +5,7 @@ import sys from math import inf from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Sequence, overload +from typing import TYPE_CHECKING, Any, Sequence, overload import attr import pytest @@ -140,7 +140,7 @@ def family(self) -> AddressFamily: return self._family @property - def proto(self) -> int: + def proto(self) -> int: # pragma: no cover return self._proto @overload @@ -176,7 +176,7 @@ def setsockopt( ) -> None: pass - async def bind(self, address: AddressWithNoneHost) -> None: + async def bind(self, address: Any) -> None: pass def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None: @@ -289,7 +289,7 @@ async def handler(stream: SendStream) -> None: async with trio.open_nursery() as nursery: # nursery.start is incorrectly typed, awaiting #2773 - listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) # type: ignore[arg-type] + listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) stream = await open_stream_to_socket_listener(listeners[0]) async with stream: await stream.receive_some(1) == b"x" diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index 0ea305b5fb..3992c83630 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -169,7 +169,7 @@ async def test_local_address_real() -> None: await client_stream.aclose() server_sock.close() # accept returns tuple[SocketType, object], due to typeshed returning `Any` - assert remote_addr[0] == local_address # type: ignore[index] + assert remote_addr[0] == local_address # Trying to connect to an ipv4 address with the ipv6 wildcard # local_address should fail @@ -206,14 +206,14 @@ def type(self) -> SocketKind: return self._type @property - def family(self) -> AddressFamily: + def family(self) -> AddressFamily: # pragma: no cover return self._family @property - def proto(self) -> int: + def proto(self) -> int: # pragma: no cover return self._proto - async def connect(self, sockaddr: Address) -> None: + async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None: self.ip = sockaddr[0] self.port = sockaddr[1] assert self.ip not in self.scenario.sockets diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py index c5d46b6c6a..830a153c00 100644 --- a/trio/_tests/test_highlevel_socket.py +++ b/trio/_tests/test_highlevel_socket.py @@ -222,10 +222,12 @@ def getsockopt(self, /, level: int, optname: int) -> int: ... @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + def getsockopt( # noqa: F811 + self, /, level: int, optname: int, buflen: int + ) -> bytes: ... - def getsockopt( + def getsockopt( # noqa: F811 self, /, level: int, optname: int, buflen: int | None = None ) -> int | bytes: return True @@ -235,12 +237,12 @@ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... @overload - def setsockopt( + def setsockopt( # noqa: F811 self, /, level: int, optname: int, value: None, optlen: int ) -> None: ... - def setsockopt( + def setsockopt( # noqa: F811 self, /, level: int, diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py index b405e97e72..40ffefb8cd 100644 --- a/trio/_tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -289,7 +289,7 @@ async def child(sock: SocketType) -> None: @pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") async def test_fromshare() -> None: - if TYPE_CHECKING and sys.platform != "win32": + if TYPE_CHECKING and sys.platform != "win32": # pragma: no cover return a, b = tsocket.socketpair() with a, b: @@ -585,12 +585,8 @@ async def res( | tuple[str, str] | tuple[str, str, int] | tuple[str, str, int, int] - ) -> tuple[str, int] | tuple[str, int, int, int]: - # we're only passing IP sockets, so we ignore the str/bytes return type - # But what about when port/family is a string? Should that be part of the public API? - res = await sock._resolve_address_nocp(args, local=local) # type: ignore[arg-type] - # no str/bytes - return res # type: ignore[return-value] + ) -> Any: + return await sock._resolve_address_nocp(args, local=local) assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 2a358119f3..f0ab3f53e4 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -11,7 +11,7 @@ import errno import ipaddress import os -from typing import TYPE_CHECKING, Optional, Union, Type +from typing import TYPE_CHECKING, Optional, Type, Union import attr @@ -21,9 +21,10 @@ if TYPE_CHECKING: from socket import AddressFamily, SocketKind from types import TracebackType + from typing_extensions import TypeAlias -IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] +IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] def _family_for(ip: IPAddress) -> int: @@ -175,7 +176,9 @@ def deliver_packet(self, packet) -> None: class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor): - def __init__(self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int): + def __init__( + self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int + ): self._fake_net = fake_net if not family: @@ -200,12 +203,15 @@ def __init__(self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, p # This is the source-of-truth for what port etc. this socket is bound to self._binding: Optional[UDPBinding] = None + @property def type(self) -> SocketKind: return self._type + @property def family(self) -> AddressFamily: return self._family + @property def proto(self) -> int: return self._proto From b9bbc06868795255562e97e07988b7066ccb9aef Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 16 Sep 2023 14:35:07 +0200 Subject: [PATCH 13/20] fix verifytypes --- trio/_tests/verify_types_darwin.json | 2 +- trio/_tests/verify_types_linux.json | 2 +- trio/_tests/verify_types_windows.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json index 507af26b3d..d98fbacaf0 100644 --- a/trio/_tests/verify_types_darwin.json +++ b/trio/_tests/verify_types_darwin.json @@ -76,7 +76,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 687, + "withKnownType": 686, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json index 61ce9afcc5..d62d04db82 100644 --- a/trio/_tests/verify_types_linux.json +++ b/trio/_tests/verify_types_linux.json @@ -64,7 +64,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 687, + "withKnownType": 686, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index e825282bf0..0c33864b3a 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -180,7 +180,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 674, + "withKnownType": 673, "withUnknownType": 0 }, "packageName": "trio" From 76ab99427f893e98d7b3f2037f6ec9a05646bb8b Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 20 Sep 2023 15:57:24 +0200 Subject: [PATCH 14/20] fix suggestions from teamspen --- trio/_tests/test_highlevel_open_tcp_stream.py | 2 +- trio/testing/_fake_net.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index 3992c83630..eb7929c2db 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -377,7 +377,7 @@ async def run_scenario( return (exc, scenario) -async def test_one_host_quick_success(autojump_clock: trio.testing.MockClock) -> None: +async def test_one_host_quick_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) assert isinstance(sock, FakeSocket) assert sock.ip == "1.2.3.4" diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index f0ab3f53e4..fe886ba131 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -8,10 +8,11 @@ from __future__ import annotations +import builtins import errno import ipaddress import os -from typing import TYPE_CHECKING, Optional, Type, Union +from typing import TYPE_CHECKING, Optional, Union import attr @@ -378,8 +379,7 @@ def __enter__(self): def __exit__( self, - # builtin `type` is shadowed by the property - exc_type: Type[BaseException] | None, + exc_type: builtins.type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: From 0c294ea1729300424d3e084dcb303ab78c07c858 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Fri, 22 Sep 2023 19:16:46 -0500 Subject: [PATCH 15/20] Update `verify_types` --- trio/_tests/verify_types_darwin.json | 2 +- trio/_tests/verify_types_linux.json | 2 +- trio/_tests/verify_types_windows.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json index e83a324714..acb4f58c07 100644 --- a/trio/_tests/verify_types_darwin.json +++ b/trio/_tests/verify_types_darwin.json @@ -76,7 +76,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 685, + "withKnownType": 684, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json index 7c9d745dba..c060b2a51e 100644 --- a/trio/_tests/verify_types_linux.json +++ b/trio/_tests/verify_types_linux.json @@ -64,7 +64,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 685, + "withKnownType": 684, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index 3da2740491..249bb179ae 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -180,7 +180,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 672, + "withKnownType": 671, "withUnknownType": 0 }, "packageName": "trio" From dc09e7eeb7ed5f6969d31920b7eeb1f80ad581fd Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 26 Sep 2023 09:44:51 +0200 Subject: [PATCH 16/20] clarifying comments after a5rocks review --- trio/_socket.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/trio/_socket.py b/trio/_socket.py index 84059f5ba3..77731708ba 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -46,8 +46,9 @@ # where you specify the return values you expect from those methods depending on the # protocol the socket will be handling. # But without the ability to default the value to `Any` it will be overly cumbersome for -# most users, so currently we just specify it as `Any`. -# AddressFormat = TypeVar("AddressFormat") +# most users, so currently we just specify it as `Any`. Otherwise we would write: +# `AddressFormat = TypeVar("AddressFormat")` +# but instead we simply do: AddressFormat: TypeAlias = Any @@ -521,6 +522,9 @@ async def _resolve_address_nocp( class SocketType: def __init__(self) -> None: + # make sure this __init__ works with multiple inheritance + super().__init__() + # and only raises error if it's directly constructed if type(self) == SocketType: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " @@ -640,6 +644,7 @@ async def accept(self) -> tuple[SocketType, AddressFormat]: async def connect(self, address: AddressFormat) -> None: raise NotImplementedError + # argument names with __ used because of typeshed, see comment for recv in _SocketType def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: raise NotImplementedError From d934724af349736ba3e0a15cd17e7601c77c75f7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 26 Sep 2023 11:47:02 +0200 Subject: [PATCH 17/20] fix verify_types ............ --- trio/_tests/verify_types_windows.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index bab6797c02..232b18757c 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -60,7 +60,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 631, + "withKnownType": 630, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, @@ -96,7 +96,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 680, + "withKnownType": 676, "withUnknownType": 0 }, "packageName": "trio" From 73351845510d14008fb436f1099e90c242461208 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 26 Sep 2023 13:46:04 +0200 Subject: [PATCH 18/20] add @_public to exclude_lines in coveragerc --- .coveragerc | 1 + 1 file changed, 1 insertion(+) diff --git a/.coveragerc b/.coveragerc index 4911012653..5f6bcb3e11 100644 --- a/.coveragerc +++ b/.coveragerc @@ -26,6 +26,7 @@ exclude_lines = @overload class .*\bProtocol\b.*\): raise NotImplementedError + @_public partial_branches = pragma: no branch From 618fdbf932ed86a796c0e1da80b55cc3ad291b73 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 26 Sep 2023 13:57:49 +0200 Subject: [PATCH 19/20] Revert "add @_public to exclude_lines in coveragerc" This reverts commit 73351845510d14008fb436f1099e90c242461208. --- .coveragerc | 1 - 1 file changed, 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index 5f6bcb3e11..4911012653 100644 --- a/.coveragerc +++ b/.coveragerc @@ -26,7 +26,6 @@ exclude_lines = @overload class .*\bProtocol\b.*\): raise NotImplementedError - @_public partial_branches = pragma: no branch From a20a909f29c390c09081e7246649854a345ef36e Mon Sep 17 00:00:00 2001 From: jakkdl Date: Thu, 28 Sep 2023 14:05:26 +0200 Subject: [PATCH 20/20] 'pragma: no cover' an added type-converting line --- trio/_tests/test_highlevel_open_tcp_listeners.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py index b6cfb55732..6f39446c7e 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -204,8 +204,8 @@ def socket( ) -> tsocket.SocketType: assert family is not None assert type is not None - if isinstance(family, int): - family = AddressFamily(family) + if isinstance(family, int) and not isinstance(family, AddressFamily): + family = AddressFamily(family) # pragma: no cover if family in self.raise_on_family: raise OSError(self.raise_on_family[family], "nope") sock = FakeSocket(family, type, proto)