From fbae13dd98f64a369385c922b5367fff0156ad41 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 23 Aug 2023 18:53:22 +0200
Subject: [PATCH 01/20] wip
---
docs/source/conf.py | 2 -
docs/source/reference-io.rst | 3 +-
pyproject.toml | 5 +-
trio/_abc.py | 4 +-
trio/_dtls.py | 14 +-
trio/_highlevel_open_tcp_stream.py | 32 ++-
trio/_highlevel_socket.py | 2 +-
trio/_socket.py | 122 +++++++++-
.../test_highlevel_open_tcp_listeners.py | 19 +-
trio/_tests/test_highlevel_open_tcp_stream.py | 214 ++++++++++++------
trio/_tests/test_socket.py | 135 ++++++-----
trio/socket.py | 1 -
trio/testing/_fake_net.py | 23 +-
13 files changed, 407 insertions(+), 169 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 66aa8dea05..ee23ce587f 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -54,8 +54,6 @@
("py:class", "sync function"),
# why aren't these found in stdlib?
("py:class", "types.FrameType"),
- # TODO: temporary type
- ("py:class", "_SocketType"),
# these are not defined in https://docs.python.org/3/objects.inv
("py:class", "socket.AddressFamily"),
("py:class", "socket.SocketKind"),
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index e270033b46..61bbef78c2 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -506,7 +506,8 @@ Socket objects
The internal SocketType
~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: _SocketType
+..
+ .. autoclass:: _SocketType
..
TODO: adding `:members:` here gives error due to overload+_wraps on `sendto`
TODO: rewrite ... all of the above when fixing _SocketType vs SocketType
diff --git a/pyproject.toml b/pyproject.toml
index 24be2d07bf..64e15a4c71 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,7 @@ module = [
"trio/_core/_tests/test_tutil",
"trio/_core/_tests/test_unbounded_queue",
"trio/_core/_tests/tutil",
+"trio/_tests/check_type_completeness",
"trio/_tests/pytest_plugin",
"trio/_tests/test_abc",
"trio/_tests/test_channel",
@@ -95,16 +96,14 @@ module = [
"trio/_tests/test_exports",
"trio/_tests/test_file_io",
"trio/_tests/test_highlevel_generic",
+"trio/_tests/test_highlevel_socket",
"trio/_tests/test_highlevel_open_tcp_listeners",
-"trio/_tests/test_highlevel_open_tcp_stream",
"trio/_tests/test_highlevel_open_unix_stream",
"trio/_tests/test_highlevel_serve_listeners",
-"trio/_tests/test_highlevel_socket",
"trio/_tests/test_highlevel_ssl_helpers",
"trio/_tests/test_path",
"trio/_tests/test_scheduler_determinism",
"trio/_tests/test_signals",
-"trio/_tests/test_socket",
"trio/_tests/test_ssl",
"trio/_tests/test_subprocess",
"trio/_tests/test_sync",
diff --git a/trio/_abc.py b/trio/_abc.py
index 746360c8f8..6a99ea0842 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -12,7 +12,7 @@
from typing_extensions import Self
# both of these introduce circular imports if outside a TYPE_CHECKING guard
- from ._socket import _SocketType
+ from ._socket import SocketType
from .lowlevel import Task
@@ -214,7 +214,7 @@ def socket(
family: socket.AddressFamily | int | None = None,
type: socket.SocketKind | int | None = None,
proto: int | None = None,
- ) -> _SocketType:
+ ) -> SocketType:
"""Create and return a socket object.
Your socket object must inherit from :class:`trio.socket.SocketType`,
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 08b7672a2f..cc940cf206 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -42,26 +42,26 @@
from OpenSSL.SSL import Context
from typing_extensions import Self, TypeAlias
- from trio.socket import Address, _SocketType
+ from trio.socket import Address, SocketType
MAX_UDP_PACKET_SIZE = 65527
-def packet_header_overhead(sock: _SocketType) -> int:
+def packet_header_overhead(sock: SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 28
else:
return 48
-def worst_case_mtu(sock: _SocketType) -> int:
+def worst_case_mtu(sock: SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 576 - packet_header_overhead(sock)
else:
return 1280 - packet_header_overhead(sock)
-def best_guess_mtu(sock: _SocketType) -> int:
+def best_guess_mtu(sock: SocketType) -> int:
return 1500 - packet_header_overhead(sock)
@@ -738,7 +738,7 @@ async def handle_client_hello_untrusted(
async def dtls_receive_loop(
- endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType
+ endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType
) -> None:
try:
while True:
@@ -1177,7 +1177,7 @@ class DTLSEndpoint(metaclass=Final):
"""
- def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
+ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10):
# We do this lazily on first construction, so only people who actually use DTLS
# have to install PyOpenSSL.
global SSL
@@ -1188,7 +1188,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
self._initialized = True
- self.socket: _SocketType = socket
+ self.socket: SocketType = socket
self.incoming_packets_buffer = incoming_packets_buffer
self._token = trio.lowlevel.current_trio_token()
diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py
index 0c4e8a4a8d..322ae4006e 100644
--- a/trio/_highlevel_open_tcp_stream.py
+++ b/trio/_highlevel_open_tcp_stream.py
@@ -8,7 +8,7 @@
import trio
from trio._core._multierror import MultiError
-from trio.socket import SOCK_STREAM, Address, _SocketType, getaddrinfo, socket
+from trio.socket import SOCK_STREAM, Address, SocketType, getaddrinfo, socket
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
@@ -114,8 +114,8 @@
@contextmanager
-def close_all() -> Generator[set[_SocketType], None, None]:
- sockets_to_close: set[_SocketType] = set()
+def close_all() -> Generator[set[SocketType], None, None]:
+ sockets_to_close: set[SocketType] = set()
try:
yield sockets_to_close
finally:
@@ -131,6 +131,7 @@ def close_all() -> Generator[set[_SocketType], None, None]:
raise MultiError(errs)
+# workaround for list being invariant
def reorder_for_rfc_6555_section_5_4(
targets: list[
tuple[
@@ -140,6 +141,22 @@ def reorder_for_rfc_6555_section_5_4(
str,
tuple[str, int] | tuple[str, int, int, int],
]
+ ] | list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int]
+ ]
+ ] | list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int, int, int],
+ ]
]
) -> None:
# RFC 6555 section 5.4 says that if getaddrinfo returns multiple address
@@ -155,11 +172,12 @@ def reorder_for_rfc_6555_section_5_4(
# Found the first entry with a different address family; move it
# so that it becomes the second item on the list.
if i != 1:
- targets.insert(1, targets.pop(i))
+ # invariant workaround in arguments leads to type issues here
+ targets.insert(1, targets.pop(i)) # type: ignore[arg-type]
break
-def format_host_port(host: str | bytes, port: int) -> str:
+def format_host_port(host: str | bytes, port: int|str) -> str:
host = host.decode("ascii") if isinstance(host, bytes) else host
if ":" in host:
return f"[{host}]:{port}"
@@ -193,7 +211,7 @@ async def open_tcp_stream(
*,
happy_eyeballs_delay: float | None = DEFAULT_DELAY,
local_address: str | None = None,
-) -> trio.abc.Stream:
+) -> trio.SocketStream:
"""Connect to the given host and port over TCP.
If the given ``host`` has multiple IP addresses associated with it, then
@@ -292,7 +310,7 @@ async def open_tcp_stream(
# Keeps track of the socket that we're going to complete with,
# need to make sure this isn't automatically closed
- winning_socket: _SocketType | None = None
+ winning_socket: SocketType | None = None
# Try connecting to the specified address. Possible outcomes:
# - success: record connected socket in winning_socket and cancel
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py
index f8d01cd755..d733537752 100644
--- a/trio/_highlevel_socket.py
+++ b/trio/_highlevel_socket.py
@@ -15,7 +15,7 @@
if TYPE_CHECKING:
from typing_extensions import Buffer
- from ._socket import _SocketType as SocketType
+ from ._socket import SocketType
# XX TODO: this number was picked arbitrarily. We should do experiments to
# tune it. (Or make it dynamic -- one idea is to start small and increase it
diff --git a/trio/_socket.py b/trio/_socket.py
index 2834a5b055..0b64e5adec 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -282,7 +282,7 @@ async def getprotobyname(name: str) -> int:
################################################################
-def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType:
+def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType:
"""Convert a standard library :class:`socket.socket` object into a Trio
socket object.
@@ -296,7 +296,7 @@ def fromfd(
family: AddressFamily | int = _stdlib_socket.AF_INET,
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
-) -> _SocketType:
+) -> SocketType:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd))
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -307,7 +307,7 @@ def fromfd(
):
@_wraps(_stdlib_socket.fromshare, assigned=(), updated=())
- def fromshare(info: bytes) -> _SocketType:
+ def fromshare(info: bytes) -> SocketType:
return from_stdlib_socket(_stdlib_socket.fromshare(info))
@@ -326,7 +326,7 @@ def socketpair(
family: FamilyT = FamilyDefault,
type: TypeT = SocketKind.SOCK_STREAM,
proto: int = 0,
-) -> tuple[_SocketType, _SocketType]:
+) -> tuple[SocketType, SocketType]:
"""Like :func:`socket.socketpair`, but returns a pair of Trio socket
objects.
@@ -341,7 +341,7 @@ def socket(
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
fileno: int | None = None,
-) -> _SocketType:
+) -> SocketType:
"""Create a new Trio socket, like :class:`socket.socket`.
This function's behavior can be customized using
@@ -530,7 +530,117 @@ def __init__(self) -> NoReturn:
"SocketType is an abstract class; use trio.socket.socket if you "
"want to construct a socket object"
)
+ def __enter__(self) -> Self:
+ raise NotImplementedError()
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ raise NotImplementedError()
+ @property
+ def type(self) -> SocketKind:
+ raise NotImplementedError()
+ @property
+ def family(self) -> AddressFamily:
+ raise NotImplementedError()
+ @property
+ def proto(self) -> int:
+ raise NotImplementedError()
+ @property
+ def did_shutdown_SHUT_WR(self) -> bool:
+ raise NotImplementedError()
+ def is_readable(self) -> bool:
+ raise NotImplementedError()
+ def fileno(self) -> int:
+ raise NotImplementedError()
+ async def wait_writable(self) -> None:
+ raise NotImplementedError()
+ def shutdown(self, flag: int) -> None:
+ raise NotImplementedError()
+ async def connect(self, address: Address) -> None:
+ raise NotImplementedError()
+ def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+ raise NotImplementedError()
+ async def bind(self, address: Address) -> None:
+ raise NotImplementedError()
+ def close(self) -> None:
+ raise NotImplementedError()
+ def getsockname(self) -> Any:
+ raise NotImplementedError()
+ async def accept(self) -> tuple[SocketType, object]:
+ raise NotImplementedError()
+
+ @overload
+ async def sendto(
+ self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
+
+ @overload
+ async def sendto(
+ self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
+
+ @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc]
+ async def sendto(self, *args: Any) -> int:
+ """Similar to :meth:`socket.socket.sendto`, but async."""
+ raise NotImplementedError()
+ @overload
+ def getsockopt(self, /, level: int, optname: int) -> int:
+ ...
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
+ ...
+
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
+ raise NotImplementedError()
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
+ ...
+
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
+ ...
+
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
+ raise NotImplementedError()
+
+ def recvfrom(
+ __self, __bufsize: int, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, Address]]:
+ raise NotImplementedError()
+ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
+ raise NotImplementedError()
+ def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
+ raise NotImplementedError()
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
+ ):
+
+ def share(self, /, process_id: int) -> bytes:
+ raise NotImplementedError()
+ def detach(self) -> int:
+ raise NotImplementedError()
+ def get_inheritable(self) -> bool:
+ raise NotImplementedError()
+
+ def set_inheritable(self, inheritable: bool) -> None:
+ raise NotImplementedError()
class _SocketType(SocketType):
def __init__(self, sock: _stdlib_socket.socket):
@@ -772,7 +882,7 @@ async def _nonblocking_helper(
_stdlib_socket.socket.accept, _core.wait_readable
)
- async def accept(self) -> tuple[_SocketType, object]:
+ async def accept(self) -> tuple[SocketType, object]:
"""Like :meth:`socket.socket.accept`, but async."""
sock, addr = await self._accept()
return from_stdlib_socket(sock), addr
diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py
index 6eca844f0c..71b4c72a98 100644
--- a/trio/_tests/test_highlevel_open_tcp_listeners.py
+++ b/trio/_tests/test_highlevel_open_tcp_listeners.py
@@ -2,6 +2,7 @@
import socket as stdlib_socket
import sys
from math import inf
+from socket import AddressFamily, SocketKind
import attr
import pytest
@@ -115,14 +116,26 @@ class FakeOSError(OSError):
@attr.s
class FakeSocket(tsocket.SocketType):
- family = attr.ib()
- type = attr.ib()
- proto = attr.ib()
+ _family: SocketKind = attr.ib()
+ _type: AddressFamily = attr.ib()
+ _proto: int = attr.ib()
closed = attr.ib(default=False)
poison_listen = attr.ib(default=False)
backlog = attr.ib(default=None)
+ @property
+ def type(self) -> SocketKind:
+ return self._type
+
+ @property
+ def family(self) -> AddressFamily:
+ return self._family
+
+ @property
+ def proto(self) -> int:
+ return self._proto
+
def getsockopt(self, level, option):
if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
return True
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index 24f82bddd5..a5f16ad777 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
import socket
import sys
import attr
import pytest
+from typing import Any, Sequence, TYPE_CHECKING
import trio
from trio._highlevel_open_tcp_stream import (
close_all,
@@ -11,24 +13,32 @@
open_tcp_stream,
reorder_for_rfc_6555_section_5_4,
)
-from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM
+from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType, Address
+from socket import AddressFamily, SocketKind
+
+if TYPE_CHECKING:
+ from trio.testing import MockClock
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
-def test_close_all():
- class CloseMe:
+def test_close_all() -> None:
+ class CloseMe(SocketType):
+ def __init__(self) -> None:
+ ...
closed = False
- def close(self):
+ def close(self) -> None:
self.closed = True
- class CloseKiller:
- def close(self):
+ class CloseKiller(SocketType):
+ def __init__(self) -> None:
+ ...
+ def close(self) -> None:
raise OSError
- c = CloseMe()
+ c: CloseMe = CloseMe()
with close_all() as to_close:
to_close.add(c)
assert c.closed
@@ -48,8 +58,10 @@ def close(self):
assert c.closed
-def test_reorder_for_rfc_6555_section_5_4():
- def fake4(i):
+def test_reorder_for_rfc_6555_section_5_4() -> None:
+ def fake4(
+ i: int,
+ ) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
return (
AF_INET,
SOCK_STREAM,
@@ -58,7 +70,9 @@ def fake4(i):
(f"10.0.0.{i}", 80),
)
- def fake6(i):
+ def fake6(
+ i: int,
+ ) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80))
for fake in fake4, fake6:
@@ -85,7 +99,7 @@ def fake6(i):
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
-def test_format_host_port():
+def test_format_host_port() -> None:
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port("example.com", 443) == "example.com:443"
@@ -95,7 +109,7 @@ def test_format_host_port():
# Make sure we can connect to localhost using real kernel sockets
-async def test_open_tcp_stream_real_socket_smoketest():
+async def test_open_tcp_stream_real_socket_smoketest() -> None:
listen_sock = trio.socket.socket()
await listen_sock.bind(("127.0.0.1", 0))
_, listen_port = listen_sock.getsockname()
@@ -110,23 +124,23 @@ async def test_open_tcp_stream_real_socket_smoketest():
listen_sock.close()
-async def test_open_tcp_stream_input_validation():
+async def test_open_tcp_stream_input_validation() -> None:
with pytest.raises(ValueError):
- await open_tcp_stream(None, 80)
+ await open_tcp_stream(None, 80) # type: ignore[arg-type]
with pytest.raises(TypeError):
- await open_tcp_stream("127.0.0.1", b"80")
+ await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type]
-def can_bind_127_0_0_2():
+def can_bind_127_0_0_2() -> bool:
with socket.socket() as s:
try:
s.bind(("127.0.0.2", 0))
except OSError:
return False
- return s.getsockname()[0] == "127.0.0.2"
+ return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return]
-async def test_local_address_real():
+async def test_local_address_real() -> None:
with trio.socket.socket() as listener:
await listener.bind(("127.0.0.1", 0))
listener.listen()
@@ -153,11 +167,11 @@ async def test_local_address_real():
assert client_stream.socket.getsockopt(
trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT
)
-
server_sock, remote_addr = await listener.accept()
await client_stream.aclose()
server_sock.close()
- assert remote_addr[0] == local_address
+ # accept returns tuple[SocketType, object]
+ assert remote_addr[0] == local_address # type: ignore[index]
# Trying to connect to an ipv4 address with the ipv6 wildcard
# local_address should fail
@@ -178,18 +192,30 @@ async def test_local_address_real():
@attr.s(eq=False)
class FakeSocket(trio.socket.SocketType):
- scenario = attr.ib()
- family = attr.ib()
- type = attr.ib()
- proto = attr.ib()
-
- ip = attr.ib(default=None)
- port = attr.ib(default=None)
- succeeded = attr.ib(default=False)
- closed = attr.ib(default=False)
- failing = attr.ib(default=False)
-
- async def connect(self, sockaddr):
+ scenario: Scenario = attr.ib()
+ _family: AddressFamily = attr.ib(alias="_family")
+ _type: SocketKind = attr.ib(alias="_type")
+ _proto: int = attr.ib(alias="_proto")
+
+ ip: str | int | None = attr.ib(default=None)
+ port: str | int | None = attr.ib(default=None)
+ succeeded: bool = attr.ib(default=False)
+ closed: bool = attr.ib(default=False)
+ failing: bool = attr.ib(default=False)
+
+ @property
+ def type(self) -> SocketKind:
+ return self._type
+
+ @property
+ def family(self) -> AddressFamily:
+ return self._family
+
+ @property
+ def proto(self) -> int:
+ return self._proto
+
+ async def connect(self, sockaddr: Address) -> None:
self.ip = sockaddr[0]
self.port = sockaddr[1]
assert self.ip not in self.scenario.sockets
@@ -203,11 +229,11 @@ async def connect(self, sockaddr):
self.failing = True
self.succeeded = True
- def close(self):
+ def close(self) -> None:
self.closed = True
# called when SocketStream is constructed
- def setsockopt(self, *args, **kwargs):
+ def setsockopt(self, *args: object, **kwargs: object) -> None:
if self.failing:
# raise something that isn't OSError as SocketStream
# ignores those
@@ -215,11 +241,16 @@ def setsockopt(self, *args, **kwargs):
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
- def __init__(self, port, ip_list, supported_families):
+ def __init__(
+ self,
+ port: int,
+ ip_list: Sequence[tuple[str, float, str]],
+ supported_families: set[AddressFamily],
+ ):
# ip_list have to be unique
ip_order = [ip for (ip, _, _) in ip_list]
assert len(set(ip_order)) == len(ip_list)
- ip_dict = {}
+ ip_dict: dict[str | int, tuple[float, str]] = {}
for ip, delay, result in ip_list:
assert 0 <= delay
assert result in ["error", "success", "postconnect_fail"]
@@ -230,16 +261,32 @@ def __init__(self, port, ip_list, supported_families):
self.ip_dict = ip_dict
self.supported_families = supported_families
self.socket_count = 0
- self.sockets = {}
- self.connect_times = {}
-
- def socket(self, family, type, proto):
+ self.sockets: dict[str | int, FakeSocket] = {}
+ self.connect_times: dict[str | int, float] = {}
+
+ def socket(
+ self,
+ family: AddressFamily | int | None = None,
+ type: SocketKind | int | None = None,
+ proto: int | None = None,
+ ) -> SocketType:
+ assert isinstance(family, AddressFamily)
+ assert isinstance(type, SocketKind)
if family not in self.supported_families:
raise OSError("pretending not to support this family")
self.socket_count += 1
return FakeSocket(self, family, type, proto)
- def _ip_to_gai_entry(self, ip):
+ def _ip_to_gai_entry(
+ self, ip: str
+ ) -> tuple[
+ AddressFamily,
+ SocketKind,
+ int | None,
+ str,
+ tuple[int | str, int, int, int] | tuple[int | str, int],
+ ]:
+ sockaddr: tuple[int | str, int] | tuple[int | str, int, int, int]
if ":" in ip:
family = trio.socket.AF_INET6
sockaddr = (ip, self.port, 0, 0)
@@ -248,7 +295,25 @@ def _ip_to_gai_entry(self, ip):
sockaddr = (ip, self.port)
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
- async def getaddrinfo(self, host, port, family, type, proto, flags):
+ # should hostnameresolver use AddressFamily and SocketKind, instead of int&int?
+ # the return type in supertype is ... wildly incompatible with what this returns
+ async def getaddrinfo( # type: ignore[override]
+ self,
+ host: str | bytes | None,
+ port: bytes | str | int | None,
+ family: int = -1,
+ type: int = -1,
+ proto: int = -1,
+ flags: int = -1,
+ ) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int | None,
+ str,
+ tuple[int | str, int, int, int] | tuple[int | str, int],
+ ]
+ ]:
assert host == b"test.example.com"
assert port == self.port
assert family == trio.socket.AF_UNSPEC
@@ -257,10 +322,12 @@ async def getaddrinfo(self, host, port, family, type, proto, flags):
assert flags == 0
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
- async def getnameinfo(self, sockaddr, flags): # pragma: no cover
+ async def getnameinfo( # pragma: no cover
+ self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
+ ) -> tuple[str, str]:
raise NotImplementedError
- def check(self, succeeded):
+ def check(self, succeeded: SocketType | None) -> None:
# sockets only go into self.sockets when connect is called; make sure
# all the sockets that were created did in fact go in there.
assert self.socket_count == len(self.sockets)
@@ -274,24 +341,24 @@ def check(self, succeeded):
async def run_scenario(
# The port to connect to
- port,
+ port: int,
# A list of
# (ip, delay, result)
# tuples, where delay is in seconds and result is "success" or "error"
# The ip's will be returned from getaddrinfo in this order, and then
# connect() calls to them will have the given result.
- ip_list,
+ ip_list: Sequence[tuple[str, float, str]],
*,
# If False, AF_INET4/6 sockets error out on creation, before connect is
# even called.
- ipv4_supported=True,
- ipv6_supported=True,
+ ipv4_supported: bool = True,
+ ipv6_supported: bool = True,
# Normally, we return (winning_sock, scenario object)
# If this is True, we require there to be an exception, and return
# (exception, scenario object)
- expect_error=(),
- **kwargs,
-):
+ expect_error: tuple[type[BaseException], ...] | type[BaseException] = (),
+ **kwargs: Any,
+) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]:
supported_families = set()
if ipv4_supported:
supported_families.add(trio.socket.AF_INET)
@@ -313,19 +380,21 @@ async def run_scenario(
return (exc, scenario)
-async def test_one_host_quick_success(autojump_clock):
+async def test_one_host_quick_success(autojump_clock: trio.testing.MockClock) -> None:
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 0.123
-async def test_one_host_slow_success(autojump_clock):
+async def test_one_host_slow_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 100
-async def test_one_host_quick_fail(autojump_clock):
+async def test_one_host_quick_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
82, [("1.2.3.4", 0.123, "error")], expect_error=OSError
)
@@ -333,7 +402,7 @@ async def test_one_host_quick_fail(autojump_clock):
assert trio.current_time() == 0.123
-async def test_one_host_slow_fail(autojump_clock):
+async def test_one_host_slow_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
83, [("1.2.3.4", 100, "error")], expect_error=OSError
)
@@ -341,7 +410,7 @@ async def test_one_host_slow_fail(autojump_clock):
assert trio.current_time() == 100
-async def test_one_host_failed_after_connect(autojump_clock):
+async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt
)
@@ -349,7 +418,7 @@ async def test_one_host_failed_after_connect(autojump_clock):
# With the default 0.250 second delay, the third attempt will win
-async def test_basic_fallthrough(autojump_clock):
+async def test_basic_fallthrough(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
@@ -358,6 +427,7 @@ async def test_basic_fallthrough(autojump_clock):
("3.3.3.3", 0.2, "success"),
],
)
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "3.3.3.3"
# current time is default time + default time + connection time
assert trio.current_time() == (0.250 + 0.250 + 0.2)
@@ -368,7 +438,7 @@ async def test_basic_fallthrough(autojump_clock):
}
-async def test_early_success(autojump_clock):
+async def test_early_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
@@ -377,6 +447,7 @@ async def test_early_success(autojump_clock):
("3.3.3.3", 0.2, "success"),
],
)
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "2.2.2.2"
assert trio.current_time() == (0.250 + 0.1)
assert scenario.connect_times == {
@@ -387,7 +458,7 @@ async def test_early_success(autojump_clock):
# With a 0.450 second delay, the first attempt will win
-async def test_custom_delay(autojump_clock):
+async def test_custom_delay(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
@@ -397,6 +468,7 @@ async def test_custom_delay(autojump_clock):
],
happy_eyeballs_delay=0.450,
)
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "1.1.1.1"
assert trio.current_time() == 1
assert scenario.connect_times == {
@@ -406,7 +478,7 @@ async def test_custom_delay(autojump_clock):
}
-async def test_custom_errors_expedite(autojump_clock):
+async def test_custom_errors_expedite(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
@@ -417,6 +489,7 @@ async def test_custom_errors_expedite(autojump_clock):
("4.4.4.4", 0.25, "success"),
],
)
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
assert scenario.connect_times == {
@@ -427,7 +500,7 @@ async def test_custom_errors_expedite(autojump_clock):
}
-async def test_all_fail(autojump_clock):
+async def test_all_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
80,
[
@@ -450,7 +523,7 @@ async def test_all_fail(autojump_clock):
}
-async def test_multi_success(autojump_clock):
+async def test_multi_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
@@ -469,6 +542,7 @@ async def test_multi_success(autojump_clock):
or scenario.sockets["4.4.4.4"].succeeded
)
assert not scenario.sockets["5.5.5.5"].succeeded
+ assert isinstance(sock, FakeSocket)
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
assert trio.current_time() == (0.5 + 10)
assert scenario.connect_times == {
@@ -480,7 +554,7 @@ async def test_multi_success(autojump_clock):
}
-async def test_does_reorder(autojump_clock):
+async def test_does_reorder(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
@@ -492,6 +566,7 @@ async def test_does_reorder(autojump_clock):
],
happy_eyeballs_delay=1,
)
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.5
assert scenario.connect_times == {
@@ -500,7 +575,7 @@ async def test_does_reorder(autojump_clock):
}
-async def test_handles_no_ipv4(autojump_clock):
+async def test_handles_no_ipv4(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
@@ -514,6 +589,7 @@ async def test_handles_no_ipv4(autojump_clock):
happy_eyeballs_delay=1,
ipv4_supported=False,
)
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
@@ -522,7 +598,7 @@ async def test_handles_no_ipv4(autojump_clock):
}
-async def test_handles_no_ipv6(autojump_clock):
+async def test_handles_no_ipv6(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
@@ -536,6 +612,7 @@ async def test_handles_no_ipv6(autojump_clock):
happy_eyeballs_delay=1,
ipv6_supported=False,
)
+ assert isinstance(sock, FakeSocket)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
@@ -544,12 +621,12 @@ async def test_handles_no_ipv6(autojump_clock):
}
-async def test_no_hosts(autojump_clock):
+async def test_no_hosts(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(80, [], expect_error=OSError)
assert "no results found" in str(exc)
-async def test_cancel(autojump_clock):
+async def test_cancel(autojump_clock: MockClock) -> None:
with trio.move_on_after(5) as cancel_scope:
exc, scenario = await run_scenario(
80,
@@ -561,6 +638,7 @@ async def test_cancel(autojump_clock):
],
expect_error=BaseExceptionGroup,
)
+ assert isinstance(exc, BaseException)
# What comes out should be 1 or more Cancelled errors that all belong
# to this cancel_scope; this is the easiest way to check that
raise exc
@@ -571,4 +649,4 @@ async def test_cancel(autojump_clock):
# This should have been called already, but just to make sure, since the
# exception-handling logic in run_scenario is a bit complicated and the
# main thing we care about here is that all the sockets were cleaned up.
- scenario.check(succeeded=False)
+ scenario.check(succeeded=None)
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index 036098b8e5..8536c66f89 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -1,31 +1,41 @@
+from __future__ import annotations
import errno
import inspect
import os
import socket as stdlib_socket
import sys
import tempfile
+from typing import Callable, Any, TYPE_CHECKING, Tuple, List, Union
+from socket import SocketKind, AddressFamily
import attr
import pytest
from .. import _core, socket as tsocket
from .._core._tests.tutil import binds_ipv6, creates_ipv6
-from .._socket import _NUMERIC_ONLY, _try_sync
+from .._socket import _NUMERIC_ONLY, _try_sync, SocketType
from ..testing import assert_checkpoints, wait_all_tasks_blocked
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias
+ GaiTuple: TypeAlias = Tuple[AddressFamily, SocketKind, int, str, Union[Tuple[str, int],Tuple[str, int, int, int]]]
+ getaddrinfoResponse: TypeAlias = List[GaiTuple]
+else:
+ GaiTuple: object
+ getaddrinfoResponse = object
+
################################################################
# utils
################################################################
-
class MonkeypatchedGAI:
- def __init__(self, orig_getaddrinfo):
+ def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]):
self._orig_getaddrinfo = orig_getaddrinfo
- self._responses = {}
- self.record = []
+ self._responses: dict[tuple[Any, ...], getaddrinfoResponse|str] = {}
+ self.record: list[tuple[Any, ...]] = []
# get a normalized getaddrinfo argument tuple
- def _frozenbind(self, *args, **kwargs):
+ def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
sig = inspect.signature(self._orig_getaddrinfo)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
@@ -33,10 +43,10 @@ def _frozenbind(self, *args, **kwargs):
assert not bound.kwargs
return frozenbound
- def set(self, response, *args, **kwargs):
+ def set(self, response: getaddrinfoResponse|str, *args: Any, **kwargs: Any) -> None:
self._responses[self._frozenbind(*args, **kwargs)] = response
- def getaddrinfo(self, *args, **kwargs):
+ def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse|str:
bound = self._frozenbind(*args, **kwargs)
self.record.append(bound)
if bound in self._responses:
@@ -48,13 +58,13 @@ def getaddrinfo(self, *args, **kwargs):
@pytest.fixture
-def monkeygai(monkeypatch):
+def monkeygai(monkeypatch: pytest.MonkeyPatch) -> MonkeypatchedGAI:
controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo)
monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo)
return controller
-async def test__try_sync():
+async def test__try_sync() -> None:
with assert_checkpoints():
async with _try_sync():
pass
@@ -67,7 +77,7 @@ async def test__try_sync():
async with _try_sync():
raise BlockingIOError
- def _is_ValueError(exc):
+ def _is_ValueError(exc: BaseException) -> bool:
return isinstance(exc, ValueError)
async with _try_sync(_is_ValueError):
@@ -84,7 +94,7 @@ def _is_ValueError(exc):
################################################################
-def test_socket_has_some_reexports():
+def test_socket_has_some_reexports() -> None:
assert tsocket.SOL_SOCKET == stdlib_socket.SOL_SOCKET
assert tsocket.TCP_NODELAY == stdlib_socket.TCP_NODELAY
assert tsocket.gaierror == stdlib_socket.gaierror
@@ -96,19 +106,19 @@ def test_socket_has_some_reexports():
################################################################
-async def test_getaddrinfo(monkeygai):
- def check(got, expected):
+async def test_getaddrinfo(monkeygai: MonkeypatchedGAI) -> None:
+ def check(got: getaddrinfoResponse, expected: getaddrinfoResponse) -> None:
# win32 returns 0 for the proto field
# musl and glibc have inconsistent handling of the canonical name
# field (https://github.com/python-trio/trio/issues/1499)
# Neither field gets used much and there isn't much opportunity for us
# to mess them up, so we don't bother checking them here
- def interesting_fields(gai_tup):
+ def interesting_fields(gai_tup: GaiTuple) -> tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]:
# (family, type, proto, canonname, sockaddr)
family, type, proto, canonname, sockaddr = gai_tup
return (family, type, sockaddr)
- def filtered(gai_list):
+ def filtered(gai_list: getaddrinfoResponse) -> list[tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]]:
return [interesting_fields(gai_tup) for gai_tup in gai_list]
assert filtered(got) == filtered(expected)
@@ -172,7 +182,7 @@ def filtered(gai_list):
await tsocket.getaddrinfo("asdf", "12345")
-async def test_getnameinfo():
+async def test_getnameinfo() -> None:
# Trivial test:
ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV
with assert_checkpoints():
@@ -207,7 +217,7 @@ async def test_getnameinfo():
################################################################
-async def test_from_stdlib_socket():
+async def test_from_stdlib_socket() -> None:
sa, sb = stdlib_socket.socketpair()
assert not isinstance(sa, tsocket.SocketType)
with sa, sb:
@@ -219,7 +229,7 @@ async def test_from_stdlib_socket():
# rejects other types
with pytest.raises(TypeError):
- tsocket.from_stdlib_socket(1)
+ tsocket.from_stdlib_socket(1) # type: ignore[arg-type]
class MySocket(stdlib_socket.socket):
pass
@@ -229,7 +239,7 @@ class MySocket(stdlib_socket.socket):
tsocket.from_stdlib_socket(mysock)
-async def test_from_fd():
+async def test_from_fd() -> None:
sa, sb = stdlib_socket.socketpair()
ta = tsocket.fromfd(sa.fileno(), sa.family, sa.type, sa.proto)
with sa, sb, ta:
@@ -238,8 +248,8 @@ async def test_from_fd():
assert sb.recv(3) == b"x"
-async def test_socketpair_simple():
- async def child(sock):
+async def test_socketpair_simple() -> None:
+ async def child(sock: SocketType) -> None:
print("sending hello")
await sock.send(b"h")
assert await sock.recv(1) == b"h"
@@ -252,7 +262,8 @@ async def child(sock):
@pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only")
-async def test_fromshare():
+async def test_fromshare() -> None:
+ assert not TYPE_CHECKING or sys.platform == "win32"
a, b = tsocket.socketpair()
with a, b:
# share with ourselves
@@ -264,21 +275,21 @@ async def test_fromshare():
assert await b.recv(1) == b"x"
-async def test_socket():
+async def test_socket() -> None:
with tsocket.socket() as s:
assert isinstance(s, tsocket.SocketType)
assert s.family == tsocket.AF_INET
@creates_ipv6
-async def test_socket_v6():
+async def test_socket_v6() -> None:
with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s:
assert isinstance(s, tsocket.SocketType)
assert s.family == tsocket.AF_INET6
@pytest.mark.skipif(not sys.platform == "linux", reason="linux only")
-async def test_sniff_sockopts():
+async def test_sniff_sockopts() -> None:
from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM
# generate the combinations of families/types we're testing:
@@ -309,7 +320,7 @@ async def test_sniff_sockopts():
################################################################
-async def test_SocketType_basics():
+async def test_SocketType_basics() -> None:
sock = tsocket.socket()
with sock as cm_enter_value:
assert cm_enter_value is sock
@@ -349,7 +360,7 @@ async def test_SocketType_basics():
# our __getattr__ handles unknown names
with pytest.raises(AttributeError):
- sock.asdf
+ sock.asdf # type: ignore[attr-defined]
# type family proto
stdlib_sock = stdlib_socket.socket()
@@ -366,7 +377,7 @@ async def test_SocketType_setsockopt() -> None:
setsockopt_tests(sock)
-def setsockopt_tests(sock):
+def setsockopt_tests(sock: SocketType) -> None:
"""Extract these out, to be reused for SocketStream also."""
# specifying optlen. Not supported on pypy, and I couldn't find
# valid calls on darwin or win32.
@@ -378,14 +389,14 @@ def setsockopt_tests(sock):
# specifying both
with pytest.raises(TypeError, match="invalid value for argument 'value'"):
- sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5)
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) # type: ignore[call-overload]
# specifying neither
with pytest.raises(TypeError, match="invalid value for argument 'value'"):
- sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None)
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) # type: ignore[call-overload]
-async def test_SocketType_dup():
+async def test_SocketType_dup() -> None:
a, b = tsocket.socketpair()
with a, b:
a2 = a.dup()
@@ -397,7 +408,7 @@ async def test_SocketType_dup():
assert await b.recv(1) == b"x"
-async def test_SocketType_shutdown():
+async def test_SocketType_shutdown() -> None:
a, b = tsocket.socketpair()
with a, b:
await a.send(b"x")
@@ -431,7 +442,7 @@ async def test_SocketType_shutdown():
pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6),
],
)
-async def test_SocketType_simple_server(address, socket_type):
+async def test_SocketType_simple_server(address, socket_type) -> None:
# listen, bind, accept, connect, getpeername, getsockname
listener = tsocket.socket(socket_type)
client = tsocket.socket(socket_type)
@@ -448,7 +459,7 @@ async def test_SocketType_simple_server(address, socket_type):
assert await client.recv(1) == b"x"
-async def test_SocketType_is_readable():
+async def test_SocketType_is_readable() -> None:
a, b = tsocket.socketpair()
with a, b:
assert not a.is_readable()
@@ -462,7 +473,7 @@ async def test_SocketType_is_readable():
# On some macOS systems, getaddrinfo likes to return V4-mapped addresses even
# when we *don't* pass AI_V4MAPPED.
# https://github.com/python-trio/trio/issues/580
-def gai_without_v4mapped_is_buggy(): # pragma: no cover
+def gai_without_v4mapped_is_buggy() -> bool: # pragma: no cover
try:
stdlib_socket.getaddrinfo("1.2.3.4", 0, family=stdlib_socket.AF_INET6)
except stdlib_socket.gaierror:
@@ -504,16 +515,16 @@ class Addresses:
),
],
)
-async def test_SocketType_resolve(socket_type, addrs):
+async def test_SocketType_resolve(socket_type, addrs) -> None:
v6 = socket_type == tsocket.AF_INET6
- def pad(addr):
+ def pad(addr) -> tuple[str, int]|tuple[str, int,int,int]:
if v6:
while len(addr) < 4:
addr += (0,)
return addr
- def assert_eq(actual, expected):
+ def assert_eq(actual, expected) -> None:
assert pad(expected) == pad(actual)
with tsocket.socket(family=socket_type) as sock:
@@ -599,7 +610,7 @@ async def res(*args):
await res(("1.2.3.4", 80, 0, 0))
-async def test_SocketType_unresolved_names():
+async def test_SocketType_unresolved_names() -> None:
with tsocket.socket() as sock:
await sock.bind(("localhost", 0))
assert sock.getsockname()[0] == "127.0.0.1"
@@ -618,7 +629,7 @@ async def test_SocketType_unresolved_names():
# This tests all the complicated paths through _nonblocking_helper, using recv
# as a stand-in for all the methods that use _nonblocking_helper.
-async def test_SocketType_non_blocking_paths():
+async def test_SocketType_non_blocking_paths() -> None:
a, b = stdlib_socket.socketpair()
with a, b:
ta = tsocket.from_stdlib_socket(a)
@@ -641,7 +652,7 @@ async def test_SocketType_non_blocking_paths():
await ta.recv("haha")
# block then succeed
- async def do_successful_blocking_recv():
+ async def do_successful_blocking_recv() -> None:
with assert_checkpoints():
assert await ta.recv(10) == b"2"
@@ -651,7 +662,7 @@ async def do_successful_blocking_recv():
b.send(b"2")
# block then cancelled
- async def do_cancelled_blocking_recv():
+ async def do_cancelled_blocking_recv() -> None:
with assert_checkpoints():
with pytest.raises(_core.Cancelled):
await ta.recv(10)
@@ -669,13 +680,13 @@ async def do_cancelled_blocking_recv():
# other:
tb = tsocket.from_stdlib_socket(b)
- async def t1():
+ async def t1() -> None:
with assert_checkpoints():
assert await ta.recv(1) == b"a"
with assert_checkpoints():
assert await tb.recv(1) == b"b"
- async def t2():
+ async def t2() -> None:
with assert_checkpoints():
assert await tb.recv(1) == b"b"
with assert_checkpoints():
@@ -693,7 +704,7 @@ async def t2():
# This tests the complicated paths through connect
-async def test_SocketType_connect_paths():
+async def test_SocketType_connect_paths() -> None:
with tsocket.socket() as sock:
with pytest.raises(ValueError):
# Should be a tuple
@@ -716,7 +727,7 @@ async def test_SocketType_connect_paths():
# nose -- and then swap it back out again before we hit
# wait_socket_writable, which insists on a real socket.
class CancelSocket(stdlib_socket.socket):
- def connect(self, *args, **kwargs):
+ def connect(self, *args, **kwargs) -> None:
cancel_scope.cancel()
sock._sock = stdlib_socket.fromfd(
self.detach(), self.family, self.type
@@ -747,7 +758,7 @@ def connect(self, *args, **kwargs):
# Fix issue #1810
-async def test_address_in_socket_error():
+async def test_address_in_socket_error() -> None:
address = "127.0.0.1"
with tsocket.socket() as sock:
try:
@@ -756,12 +767,12 @@ async def test_address_in_socket_error():
assert any(address in str(arg) for arg in e.args)
-async def test_resolve_address_exception_in_connect_closes_socket():
+async def test_resolve_address_exception_in_connect_closes_socket() -> None:
# Here we are testing issue 247, any cancellation will leave the socket closed
with _core.CancelScope() as cancel_scope:
with tsocket.socket() as sock:
- async def _resolve_address_nocp(self, *args, **kwargs):
+ async def _resolve_address_nocp(self, *args, **kwargs) -> None:
cancel_scope.cancel()
await _core.checkpoint()
@@ -772,7 +783,7 @@ async def _resolve_address_nocp(self, *args, **kwargs):
assert sock.fileno() == -1
-async def test_send_recv_variants():
+async def test_send_recv_variants() -> None:
a, b = tsocket.socketpair()
with a, b:
# recv, including with flags
@@ -868,7 +879,7 @@ async def test_send_recv_variants():
assert await b.recv(10) == b"yyy"
-async def test_idna(monkeygai):
+async def test_idna(monkeygai) -> None:
# This is the encoding for "faß.de", which uses one of the characters that
# IDNA 2003 handles incorrectly:
monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80)
@@ -886,14 +897,14 @@ async def test_idna(monkeygai):
assert "ok faß.de" == await tsocket.getaddrinfo(b"xn--fa-hia.de", 80)
-async def test_getprotobyname():
+async def test_getprotobyname() -> None:
# These are the constants used in IP header fields, so the numeric values
# had *better* be stable across systems...
assert await tsocket.getprotobyname("udp") == 17
assert await tsocket.getprotobyname("tcp") == 6
-async def test_custom_hostname_resolver(monkeygai):
+async def test_custom_hostname_resolver(monkeygai) -> None:
class CustomResolver:
async def getaddrinfo(self, host, port, family, type, proto, flags):
return ("custom_gai", host, port, family, type, proto, flags)
@@ -937,7 +948,7 @@ async def getnameinfo(self, sockaddr, flags):
assert await tsocket.getaddrinfo("host", "port") == "x"
-async def test_custom_socket_factory():
+async def test_custom_socket_factory() -> None:
class CustomSocketFactory:
def socket(self, family, type, proto):
return ("hi", family, type, proto)
@@ -964,17 +975,17 @@ def socket(self, family, type, proto):
assert tsocket.set_custom_socket_factory(None) is csf
-async def test_SocketType_is_abstract():
+async def test_SocketType_is_abstract() -> None:
with pytest.raises(TypeError):
tsocket.SocketType()
@pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets")
-async def test_unix_domain_socket():
+async def test_unix_domain_socket() -> None:
# Bind has a special branch to use a thread, since it has to do filesystem
# traversal. Maybe connect should too? Not sure.
- async def check_AF_UNIX(path):
+ async def check_AF_UNIX(path) -> None:
with tsocket.socket(family=tsocket.AF_UNIX) as lsock:
await lsock.bind(path)
lsock.listen(10)
@@ -999,7 +1010,7 @@ async def check_AF_UNIX(path):
pass
-async def test_interrupted_by_close():
+async def test_interrupted_by_close() -> None:
a_stdlib, b_stdlib = stdlib_socket.socketpair()
with a_stdlib, b_stdlib:
a_stdlib.setblocking(False)
@@ -1014,11 +1025,11 @@ async def test_interrupted_by_close():
a = tsocket.from_stdlib_socket(a_stdlib)
- async def sender():
+ async def sender() -> None:
with pytest.raises(_core.ClosedResourceError):
await a.send(data)
- async def receiver():
+ async def receiver() -> None:
with pytest.raises(_core.ClosedResourceError):
await a.recv(1)
@@ -1029,7 +1040,7 @@ async def receiver():
a.close()
-async def test_many_sockets():
+async def test_many_sockets() -> None:
total = 5000 # Must be more than MAX_AFD_GROUP_SIZE
sockets = []
for x in range(total // 2):
diff --git a/trio/socket.py b/trio/socket.py
index f8d0bc3fc2..ddefc72649 100644
--- a/trio/socket.py
+++ b/trio/socket.py
@@ -36,7 +36,6 @@
from ._socket import (
Address as Address,
SocketType as SocketType,
- _SocketType as _SocketType,
from_stdlib_socket as from_stdlib_socket,
fromfd as fromfd,
getaddrinfo as getaddrinfo,
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index ddf46174f3..2a358119f3 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -11,7 +11,7 @@
import errno
import ipaddress
import os
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union, Type
import attr
@@ -21,6 +21,7 @@
if TYPE_CHECKING:
from socket import AddressFamily, SocketKind
from types import TracebackType
+ from typing_extensions import TypeAlias
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@@ -174,7 +175,7 @@ def deliver_packet(self, packet) -> None:
class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
- def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
+ def __init__(self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int):
self._fake_net = fake_net
if not family:
@@ -187,9 +188,9 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
if type != trio.socket.SOCK_DGRAM:
raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")
- self.family = family
- self.type = type
- self.proto = proto
+ self._family = family
+ self._type = type
+ self._proto = proto
self._closed = False
@@ -199,6 +200,15 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
# This is the source-of-truth for what port etc. this socket is bound to
self._binding: Optional[UDPBinding] = None
+ @property
+ def type(self) -> SocketKind:
+ return self._type
+ @property
+ def family(self) -> AddressFamily:
+ return self._family
+ @property
+ def proto(self) -> int:
+ return self._proto
def _check_closed(self):
if self._closed:
@@ -362,7 +372,8 @@ def __enter__(self):
def __exit__(
self,
- exc_type: type[BaseException] | None,
+ # builtin `type` is shadowed by the property
+ exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
From 1a38b7e4eda4e2702401e1bb651fd386a0eda1ca Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 25 Aug 2023 17:40:54 +0200
Subject: [PATCH 02/20] wip2
---
pyproject.toml | 4 +-
trio/_core/_generated_io_epoll.py | 9 +-
trio/_core/_io_epoll.py | 9 +-
trio/_socket.py | 213 ++++++++++++------
.../test_highlevel_open_tcp_listeners.py | 74 ++++--
trio/_tests/test_highlevel_open_tcp_stream.py | 10 +-
trio/_tests/test_highlevel_socket.py | 87 ++++---
trio/_tests/test_socket.py | 143 +++++++++---
trio/_tools/gen_exports.py | 4 +
9 files changed, 388 insertions(+), 165 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 64e15a4c71..aeb35977a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,8 +96,8 @@ module = [
"trio/_tests/test_exports",
"trio/_tests/test_file_io",
"trio/_tests/test_highlevel_generic",
-"trio/_tests/test_highlevel_socket",
-"trio/_tests/test_highlevel_open_tcp_listeners",
+#"trio/_tests/test_highlevel_socket", #
+#"trio/_tests/test_highlevel_open_tcp_listeners", #
"trio/_tests/test_highlevel_open_unix_stream",
"trio/_tests/test_highlevel_serve_listeners",
"trio/_tests/test_highlevel_ssl_helpers",
diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py
index abe49ed3ff..54957952be 100644
--- a/trio/_core/_generated_io_epoll.py
+++ b/trio/_core/_generated_io_epoll.py
@@ -10,12 +10,15 @@
from ._run import GLOBAL_RUN_CONTEXT
from socket import socket
from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .._file_io import _HasFileNo
import sys
assert not TYPE_CHECKING or sys.platform=="linux"
-async def wait_readable(fd: (int | socket)) ->None:
+async def wait_readable(fd: (int | _HasFileNo)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -23,7 +26,7 @@ async def wait_readable(fd: (int | socket)) ->None:
raise RuntimeError("must be called from async context")
-async def wait_writable(fd: (int | socket)) ->None:
+async def wait_writable(fd: (int | _HasFileNo)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -31,7 +34,7 @@ async def wait_writable(fd: (int | socket)) ->None:
raise RuntimeError("must be called from async context")
-def notify_closing(fd: (int | socket)) ->None:
+def notify_closing(fd: (int | _HasFileNo)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 0d247cae64..b39ad547f1 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -18,6 +18,7 @@
from typing_extensions import TypeAlias
from .._core import Abort, RaiseCancelT
+ from .._file_io import _HasFileNo
@attr.s(slots=True, eq=False)
@@ -290,7 +291,7 @@ def _update_registrations(self, fd: int) -> None:
if not wanted_flags:
del self._registered[fd]
- async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None:
+ async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
@@ -309,15 +310,15 @@ def abort(_: RaiseCancelT) -> Abort:
await _core.wait_task_rescheduled(abort)
@_public
- async def wait_readable(self, fd: int | socket) -> None:
+ async def wait_readable(self, fd: int | _HasFileNo) -> None:
await self._epoll_wait(fd, "read_task")
@_public
- async def wait_writable(self, fd: int | socket) -> None:
+ async def wait_writable(self, fd: int | _HasFileNo) -> None:
await self._epoll_wait(fd, "write_task")
@_public
- def notify_closing(self, fd: int | socket) -> None:
+ def notify_closing(self, fd: int | _HasFileNo) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
diff --git a/trio/_socket.py b/trio/_socket.py
index 0b64e5adec..ace35ff253 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -14,6 +14,7 @@
Callable,
Literal,
NoReturn,
+ Optional,
SupportsIndex,
Tuple,
TypeVar,
@@ -44,6 +45,13 @@
Address: TypeAlias = Union[
str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int]
]
+AddressWithNoneHost: TypeAlias = Union[
+ str,
+ bytes,
+ Tuple[Optional[str], int],
+ Tuple[Optional[str], int, int],
+ Tuple[Optional[str], int, int, int],
+]
# Usage:
@@ -450,7 +458,7 @@ async def _resolve_address_nocp(
proto: int,
*,
ipv6_v6only: bool | int,
- address: Address,
+ address: AddressWithNoneHost,
local: bool,
) -> Address:
# Do some pre-checking (or exit early for non-IP sockets)
@@ -467,7 +475,8 @@ async def _resolve_address_nocp(
assert isinstance(address, (str, bytes))
return os.fspath(address)
else:
- return address
+ # TODO: check for host is None?
+ return address # type: ignore[return-value]
# -- From here on we know we have IPv4 or IPV6 --
host: str | None
@@ -475,13 +484,13 @@ async def _resolve_address_nocp(
# Fast path for the simple case: already-resolved IP address,
# already-resolved port. This is particularly important for UDP, since
# every sendto call goes through here.
- if isinstance(port, int):
+ if isinstance(port, int) and host is not None:
try:
- _stdlib_socket.inet_pton(family, address[0])
+ _stdlib_socket.inet_pton(family, host)
except (OSError, TypeError):
pass
else:
- return address
+ return address # type: ignore[return-value]
# Special cases to match the stdlib, see gh-277
if host == "":
host = None
@@ -530,6 +539,69 @@ def __init__(self) -> NoReturn:
"SocketType is an abstract class; use trio.socket.socket if you "
"want to construct a socket object"
)
+
+ ################################################################
+ # Simple + portable methods and attributes
+ ################################################################
+ def detach(self) -> int:
+ raise NotImplementedError()
+
+ def fileno(self) -> int:
+ raise NotImplementedError()
+
+ def getpeername(self) -> Any:
+ raise NotImplementedError()
+
+ def getsockname(self) -> Any:
+ raise NotImplementedError()
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int) -> int:
+ ...
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
+ ...
+
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
+ raise NotImplementedError()
+
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
+ ...
+
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
+ ...
+
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
+ raise NotImplementedError()
+
+ def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+ raise NotImplementedError()
+
+ def get_inheritable(self) -> bool:
+ raise NotImplementedError()
+
+ def set_inheritable(self, inheritable: bool) -> None:
+ raise NotImplementedError()
+
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
+ ):
+
+ def share(self, /, process_id: int) -> bytes:
+ raise NotImplementedError()
+
def __enter__(self) -> Self:
raise NotImplementedError()
@@ -540,38 +612,85 @@ def __exit__(
traceback: TracebackType | None,
) -> None:
raise NotImplementedError()
+
@property
- def type(self) -> SocketKind:
+ def family(self) -> AddressFamily:
raise NotImplementedError()
+
@property
- def family(self) -> AddressFamily:
+ def type(self) -> SocketKind:
raise NotImplementedError()
+
@property
def proto(self) -> int:
raise NotImplementedError()
+
@property
def did_shutdown_SHUT_WR(self) -> bool:
raise NotImplementedError()
- def is_readable(self) -> bool:
+
+ # def __repr__(self) -> str:
+ # raise NotImplementedError()
+ def dup(self) -> SocketType:
raise NotImplementedError()
- def fileno(self) -> int:
+
+ def close(self) -> None:
raise NotImplementedError()
- async def wait_writable(self) -> None:
+
+ async def bind(self, address: Address) -> None:
raise NotImplementedError()
+
def shutdown(self, flag: int) -> None:
raise NotImplementedError()
+ def is_readable(self) -> bool:
+ raise NotImplementedError()
+
+ async def wait_writable(self) -> None:
+ raise NotImplementedError()
+
+ async def accept(self) -> tuple[SocketType, object]:
+ raise NotImplementedError()
+
async def connect(self, address: Address) -> None:
raise NotImplementedError()
- def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+
+ def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
raise NotImplementedError()
- async def bind(self, address: Address) -> None:
+
+ def recv_into(
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[int]:
raise NotImplementedError()
- def close(self) -> None:
+
+ ###
+ def recvfrom(
+ __self, __bufsize: int, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, Address]]:
raise NotImplementedError()
- def getsockname(self) -> Any:
+
+ # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
+ def recvfrom_into(
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[tuple[int, Address]]:
raise NotImplementedError()
- async def accept(self) -> tuple[SocketType, object]:
+
+ # if hasattr(_stdlib_socket.socket, "recvmsg"):
+ # def recvmsg(
+ # __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0
+ # ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
+ # raise NotImplementedError()
+
+ # if hasattr(_stdlib_socket.socket, "recvmsg_into"):
+ # def recvmsg_into(
+ # __self,
+ # __buffers: Iterable[Buffer],
+ # __ancbufsize: int = 0,
+ # __flags: int = 0,
+ # ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
+ # raise NotImplementedError()
+
+ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
raise NotImplementedError()
@overload
@@ -590,57 +709,21 @@ async def sendto(
async def sendto(self, *args: Any) -> int:
"""Similar to :meth:`socket.socket.sendto`, but async."""
raise NotImplementedError()
- @overload
- def getsockopt(self, /, level: int, optname: int) -> int:
- ...
-
- @overload
- def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
- ...
-
- def getsockopt(
- self, /, level: int, optname: int, buflen: int | None = None
- ) -> int | bytes:
- raise NotImplementedError()
- @overload
- def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
- ...
-
- @overload
- def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
- ...
-
- def setsockopt(
- self,
- /,
- level: int,
- optname: int,
- value: int | Buffer | None,
- optlen: int | None = None,
- ) -> None:
- raise NotImplementedError()
- def recvfrom(
- __self, __bufsize: int, __flags: int = 0
- ) -> Awaitable[tuple[bytes, Address]]:
- raise NotImplementedError()
- def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
- raise NotImplementedError()
- def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
- raise NotImplementedError()
- if sys.platform == "win32" or (
- not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
+ if sys.platform != "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg")
):
- def share(self, /, process_id: int) -> bytes:
+ @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
+ async def sendmsg(
+ self,
+ __buffers: Iterable[Buffer],
+ __ancdata: Iterable[tuple[int, int, Buffer]] = (),
+ __flags: int = 0,
+ __address: Address | None = None,
+ ) -> int:
raise NotImplementedError()
- def detach(self) -> int:
- raise NotImplementedError()
- def get_inheritable(self) -> bool:
- raise NotImplementedError()
- def set_inheritable(self, inheritable: bool) -> None:
- raise NotImplementedError()
class _SocketType(SocketType):
def __init__(self, sock: _stdlib_socket.socket):
@@ -772,7 +855,7 @@ def close(self) -> None:
trio.lowlevel.notify_closing(self._sock)
self._sock.close()
- async def bind(self, address: Address) -> None:
+ async def bind(self, address: AddressWithNoneHost) -> None:
address = await self._resolve_address_nocp(address, local=True)
if (
hasattr(_stdlib_socket, "AF_UNIX")
@@ -811,7 +894,7 @@ async def wait_writable(self) -> None:
async def _resolve_address_nocp(
self,
- address: Address,
+ address: AddressWithNoneHost,
*,
local: bool,
) -> Address:
@@ -891,7 +974,7 @@ async def accept(self) -> tuple[SocketType, object]:
# connect
################################################################
- async def connect(self, address: Address) -> None:
+ async def connect(self, address: AddressWithNoneHost) -> None:
# nonblocking connect is weird -- you call it to start things
# off, then the socket becomes writable as a completion
# notification. This means it isn't really cancellable... we close the
@@ -1115,7 +1198,7 @@ async def sendmsg(
__buffers: Iterable[Buffer],
__ancdata: Iterable[tuple[int, int, Buffer]] = (),
__flags: int = 0,
- __address: Address | None = None,
+ __address: AddressWithNoneHost | None = None,
) -> int:
"""Similar to :meth:`socket.socket.sendmsg`, but async.
diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py
index 71b4c72a98..0d9e9316fa 100644
--- a/trio/_tests/test_highlevel_open_tcp_listeners.py
+++ b/trio/_tests/test_highlevel_open_tcp_listeners.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
import errno
import socket as stdlib_socket
import sys
from math import inf
from socket import AddressFamily, SocketKind
+from typing import overload
import attr
import pytest
@@ -18,7 +21,7 @@
from exceptiongroup import BaseExceptionGroup
-async def test_open_tcp_listeners_basic():
+async def test_open_tcp_listeners_basic() -> None:
listeners = await open_tcp_listeners(0)
assert isinstance(listeners, list)
for obj in listeners:
@@ -46,7 +49,7 @@ async def test_open_tcp_listeners_basic():
await resource.aclose()
-async def test_open_tcp_listeners_specific_port_specific_host():
+async def test_open_tcp_listeners_specific_port_specific_host() -> None:
# Pick a port
sock = tsocket.socket()
await sock.bind(("127.0.0.1", 0))
@@ -59,7 +62,7 @@ async def test_open_tcp_listeners_specific_port_specific_host():
@binds_ipv6
-async def test_open_tcp_listeners_ipv6_v6only():
+async def test_open_tcp_listeners_ipv6_v6only() -> None:
# Check IPV6_V6ONLY is working properly
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
async with ipv6_listener:
@@ -69,7 +72,7 @@ async def test_open_tcp_listeners_ipv6_v6only():
await open_tcp_stream("127.0.0.1", port)
-async def test_open_tcp_listeners_rebind():
+async def test_open_tcp_listeners_rebind() -> None:
(l1,) = await open_tcp_listeners(0, host="127.0.0.1")
sockaddr1 = l1.socket.getsockname()
@@ -116,13 +119,13 @@ class FakeOSError(OSError):
@attr.s
class FakeSocket(tsocket.SocketType):
- _family: SocketKind = attr.ib()
- _type: AddressFamily = attr.ib()
+ _family: AddressFamily = attr.ib()
+ _type: SocketKind = attr.ib()
_proto: int = attr.ib()
- closed = attr.ib(default=False)
- poison_listen = attr.ib(default=False)
- backlog = attr.ib(default=None)
+ closed: bool = attr.ib(default=False)
+ poison_listen: bool = attr.ib(default=False)
+ backlog: int | None = attr.ib(default=None)
@property
def type(self) -> SocketKind:
@@ -136,25 +139,50 @@ def family(self) -> AddressFamily:
def proto(self) -> int:
return self._proto
- def getsockopt(self, level, option):
- if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
+ @overload
+ def getsockopt(self, /, level: int, optname: int) -> int:
+ ...
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
+ ...
+
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
+ if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
return True
assert False # pragma: no cover
- def setsockopt(self, level, option, value):
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
+ ...
+
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
+ ...
+
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
pass
- async def bind(self, sockaddr):
+ async def bind(self, address: AddressWithNoneHost) -> None:
pass
- def listen(self, backlog):
+ def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
assert self.backlog is None
assert backlog is not None
self.backlog = backlog
if self.poison_listen:
raise FakeOSError("whoops")
- def close(self):
+ def close(self) -> None:
self.closed = True
@@ -186,7 +214,7 @@ async def getaddrinfo(self, host, port, family, type, proto, flags):
]
-async def test_open_tcp_listeners_multiple_host_cleanup_on_error():
+async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
# If we were trying to bind to multiple hosts and one of them failed, they
# call get cleaned up before returning
fsf = FakeSocketFactory(3)
@@ -209,7 +237,7 @@ async def test_open_tcp_listeners_multiple_host_cleanup_on_error():
assert sock.closed
-async def test_open_tcp_listeners_port_checking():
+async def test_open_tcp_listeners_port_checking() -> None:
for host in ["127.0.0.1", None]:
with pytest.raises(TypeError):
await open_tcp_listeners(None, host=host)
@@ -219,8 +247,8 @@ async def test_open_tcp_listeners_port_checking():
await open_tcp_listeners("http", host=host)
-async def test_serve_tcp():
- async def handler(stream):
+async def test_serve_tcp() -> None:
+ async def handler(stream) -> None:
await stream.send_all(b"x")
async with trio.open_nursery() as nursery:
@@ -241,7 +269,7 @@ async def handler(stream):
)
async def test_open_tcp_listeners_some_address_families_unavailable(
try_families, fail_families
-):
+) -> None:
fsf = FakeSocketFactory(
10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families}
)
@@ -270,7 +298,7 @@ async def test_open_tcp_listeners_some_address_families_unavailable(
assert not should_succeed
-async def test_open_tcp_listeners_socket_fails_not_afnosupport():
+async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None:
fsf = FakeSocketFactory(
10,
raise_on_family={
@@ -298,7 +326,7 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport():
# effectively is no backlog), sometimes the host might not be enough resources
# to give us the full requested backlog... it was a mess. So now we just check
# that the backlog argument is passed through correctly.
-async def test_open_tcp_listeners_backlog():
+async def test_open_tcp_listeners_backlog() -> None:
fsf = FakeSocketFactory(99)
tsocket.set_custom_socket_factory(fsf)
for given, expected in [
@@ -314,7 +342,7 @@ async def test_open_tcp_listeners_backlog():
assert listener.socket.backlog == expected
-async def test_open_tcp_listeners_backlog_float_error():
+async def test_open_tcp_listeners_backlog_float_error() -> None:
fsf = FakeSocketFactory(99)
tsocket.set_custom_socket_factory(fsf)
for should_fail in (0.0, 2.18, 3.14, 9.75):
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index a5f16ad777..fc1bf4c006 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -1,11 +1,13 @@
from __future__ import annotations
+
import socket
import sys
+from socket import AddressFamily, SocketKind
+from typing import TYPE_CHECKING, Any, Sequence
import attr
import pytest
-from typing import Any, Sequence, TYPE_CHECKING
import trio
from trio._highlevel_open_tcp_stream import (
close_all,
@@ -13,8 +15,7 @@
open_tcp_stream,
reorder_for_rfc_6555_section_5_4,
)
-from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType, Address
-from socket import AddressFamily, SocketKind
+from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, Address, SocketType
if TYPE_CHECKING:
from trio.testing import MockClock
@@ -27,6 +28,7 @@ def test_close_all() -> None:
class CloseMe(SocketType):
def __init__(self) -> None:
...
+
closed = False
def close(self) -> None:
@@ -35,6 +37,7 @@ def close(self) -> None:
class CloseKiller(SocketType):
def __init__(self) -> None:
...
+
def close(self) -> None:
raise OSError
@@ -272,6 +275,7 @@ def socket(
) -> SocketType:
assert isinstance(family, AddressFamily)
assert isinstance(type, SocketKind)
+ assert proto is not None
if family not in self.supported_families:
raise OSError("pretending not to support this family")
self.socket_count += 1
diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py
index 1a987df3f3..c5d46b6c6a 100644
--- a/trio/_tests/test_highlevel_socket.py
+++ b/trio/_tests/test_highlevel_socket.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import errno
import socket as stdlib_socket
import sys
+from typing import Sequence
import pytest
@@ -14,16 +17,17 @@
from .test_socket import setsockopt_tests
-async def test_SocketStream_basics():
+async def test_SocketStream_basics() -> None:
# stdlib socket bad (even if connected)
- a, b = stdlib_socket.socketpair()
- with a, b:
+ stdlib_a, stdlib_b = stdlib_socket.socketpair()
+ with stdlib_a, stdlib_b:
with pytest.raises(TypeError):
- SocketStream(a)
+ SocketStream(stdlib_a) # type: ignore[arg-type]
# DGRAM socket bad
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
with pytest.raises(ValueError):
+ # TODO: does not raise an error?
SocketStream(sock)
a, b = tsocket.socketpair()
@@ -48,13 +52,13 @@ async def test_SocketStream_basics():
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
- b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
- assert isinstance(b, bytes)
+ res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
+ assert isinstance(res, bytes)
setsockopt_tests(s)
-async def test_SocketStream_send_all():
+async def test_SocketStream_send_all() -> None:
BIG = 10000000
a_sock, b_sock = tsocket.socketpair()
@@ -65,7 +69,7 @@ async def test_SocketStream_send_all():
# Check a send_all that has to be split into multiple parts (on most
# platforms... on Windows every send() either succeeds or fails as a
# whole)
- async def sender():
+ async def sender() -> None:
data = bytearray(BIG)
await a.send_all(data)
# send_all uses memoryviews internally, which temporarily "lock"
@@ -89,7 +93,7 @@ async def sender():
# and we break our implementation of send_all, then we'll get some
# early warning...)
- async def receiver():
+ async def receiver() -> None:
# Make sure the sender fills up the kernel buffers and blocks
await wait_all_tasks_blocked()
nbytes = 0
@@ -109,12 +113,12 @@ async def receiver():
assert await b.receive_some(10) == b""
-async def fill_stream(s):
- async def sender():
+async def fill_stream(s: SocketStream) -> None:
+ async def sender() -> None:
while True:
await s.send_all(b"x" * 10000)
- async def waiter(nursery):
+ async def waiter(nursery: _core.Nursery) -> None:
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
@@ -123,12 +127,12 @@ async def waiter(nursery):
nursery.start_soon(waiter, nursery)
-async def test_SocketStream_generic():
- async def stream_maker():
+async def test_SocketStream_generic() -> None:
+ async def stream_maker() -> tuple[SocketStream, SocketStream]:
left, right = tsocket.socketpair()
return SocketStream(left), SocketStream(right)
- async def clogged_stream_maker():
+ async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]:
left, right = await stream_maker()
await fill_stream(left)
await fill_stream(right)
@@ -137,13 +141,13 @@ async def clogged_stream_maker():
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
-async def test_SocketListener():
+async def test_SocketListener() -> None:
# Not a Trio socket
with stdlib_socket.socket() as s:
s.bind(("127.0.0.1", 0))
s.listen(10)
with pytest.raises(TypeError):
- SocketListener(s)
+ SocketListener(s) # type: ignore[arg-type]
# Not a SOCK_STREAM
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
@@ -190,7 +194,7 @@ async def test_SocketListener():
await server_stream.aclose()
-async def test_SocketListener_socket_closed_underfoot():
+async def test_SocketListener_socket_closed_underfoot() -> None:
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(10)
@@ -205,21 +209,48 @@ async def test_SocketListener_socket_closed_underfoot():
await listener.accept()
-async def test_SocketListener_accept_errors():
+async def test_SocketListener_accept_errors() -> None:
class FakeSocket(tsocket.SocketType):
- def __init__(self, events):
+ def __init__(self, events: Sequence[SocketType | BaseException]):
self._events = iter(events)
type = tsocket.SOCK_STREAM
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
- def getsockopt(self, level, opt):
+ @overload
+ def getsockopt(self, /, level: int, optname: int) -> int:
+ ...
+
+ @overload
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
+ ...
+
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
return True
- def setsockopt(self, level, opt, value):
+ @overload
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
+ ...
+
+ @overload
+ def setsockopt(
+ self, /, level: int, optname: int, value: None, optlen: int
+ ) -> None:
+ ...
+
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
pass
- async def accept(self):
+ async def accept(self) -> tuple[SocketType, object]:
await _core.checkpoint()
event = next(self._events)
if isinstance(event, BaseException):
@@ -242,24 +273,24 @@ async def accept(self):
]
)
- l = SocketListener(fake_listen_sock)
+ listener = SocketListener(fake_listen_sock)
with assert_checkpoints():
- s = await l.accept()
+ s = await listener.accept()
assert s.socket is fake_server_sock
for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]:
with assert_checkpoints():
with pytest.raises(OSError) as excinfo:
- await l.accept()
+ await listener.accept()
assert excinfo.value.errno == code
with assert_checkpoints():
- s = await l.accept()
+ s = await listener.accept()
assert s.socket is fake_server_sock
-async def test_socket_stream_works_when_peer_has_already_closed():
+async def test_socket_stream_works_when_peer_has_already_closed() -> None:
sock_a, sock_b = tsocket.socketpair()
with sock_a, sock_b:
await sock_b.send(b"x")
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index 8536c66f89..b405e97e72 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -1,24 +1,33 @@
from __future__ import annotations
+
import errno
import inspect
import os
import socket as stdlib_socket
import sys
import tempfile
-from typing import Callable, Any, TYPE_CHECKING, Tuple, List, Union
-from socket import SocketKind, AddressFamily
+from socket import AddressFamily, SocketKind
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union
import attr
import pytest
from .. import _core, socket as tsocket
from .._core._tests.tutil import binds_ipv6, creates_ipv6
-from .._socket import _NUMERIC_ONLY, _try_sync, SocketType
+from .._highlevel_socket import SocketStream
+from .._socket import _NUMERIC_ONLY, SocketType, _SocketType, _try_sync
from ..testing import assert_checkpoints, wait_all_tasks_blocked
if TYPE_CHECKING:
from typing_extensions import TypeAlias
- GaiTuple: TypeAlias = Tuple[AddressFamily, SocketKind, int, str, Union[Tuple[str, int],Tuple[str, int, int, int]]]
+
+ GaiTuple: TypeAlias = Tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ Union[Tuple[str, int], Tuple[str, int, int, int]],
+ ]
getaddrinfoResponse: TypeAlias = List[GaiTuple]
else:
GaiTuple: object
@@ -28,10 +37,11 @@
# utils
################################################################
+
class MonkeypatchedGAI:
def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]):
self._orig_getaddrinfo = orig_getaddrinfo
- self._responses: dict[tuple[Any, ...], getaddrinfoResponse|str] = {}
+ self._responses: dict[tuple[Any, ...], getaddrinfoResponse | str] = {}
self.record: list[tuple[Any, ...]] = []
# get a normalized getaddrinfo argument tuple
@@ -43,10 +53,12 @@ def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
assert not bound.kwargs
return frozenbound
- def set(self, response: getaddrinfoResponse|str, *args: Any, **kwargs: Any) -> None:
+ def set(
+ self, response: getaddrinfoResponse | str, *args: Any, **kwargs: Any
+ ) -> None:
self._responses[self._frozenbind(*args, **kwargs)] = response
- def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse|str:
+ def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse | str:
bound = self._frozenbind(*args, **kwargs)
self.record.append(bound)
if bound in self._responses:
@@ -113,12 +125,26 @@ def check(got: getaddrinfoResponse, expected: getaddrinfoResponse) -> None:
# field (https://github.com/python-trio/trio/issues/1499)
# Neither field gets used much and there isn't much opportunity for us
# to mess them up, so we don't bother checking them here
- def interesting_fields(gai_tup: GaiTuple) -> tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]:
+ def interesting_fields(
+ gai_tup: GaiTuple,
+ ) -> tuple[
+ AddressFamily,
+ SocketKind,
+ tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int],
+ ]:
# (family, type, proto, canonname, sockaddr)
family, type, proto, canonname, sockaddr = gai_tup
return (family, type, sockaddr)
- def filtered(gai_list: getaddrinfoResponse) -> list[tuple[AddressFamily, SocketKind, tuple[str, int]|tuple[str, int, int]|tuple[str, int, int, int]]]:
+ def filtered(
+ gai_list: getaddrinfoResponse,
+ ) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int],
+ ]
+ ]:
return [interesting_fields(gai_tup) for gai_tup in gai_list]
assert filtered(got) == filtered(expected)
@@ -263,7 +289,8 @@ async def child(sock: SocketType) -> None:
@pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only")
async def test_fromshare() -> None:
- assert not TYPE_CHECKING or sys.platform == "win32"
+ if TYPE_CHECKING and sys.platform != "win32":
+ return
a, b = tsocket.socketpair()
with a, b:
# share with ourselves
@@ -377,7 +404,7 @@ async def test_SocketType_setsockopt() -> None:
setsockopt_tests(sock)
-def setsockopt_tests(sock: SocketType) -> None:
+def setsockopt_tests(sock: SocketType | SocketStream) -> None:
"""Extract these out, to be reused for SocketStream also."""
# specifying optlen. Not supported on pypy, and I couldn't find
# valid calls on darwin or win32.
@@ -442,7 +469,9 @@ async def test_SocketType_shutdown() -> None:
pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6),
],
)
-async def test_SocketType_simple_server(address, socket_type) -> None:
+async def test_SocketType_simple_server(
+ address: str, socket_type: AddressFamily
+) -> None:
# listen, bind, accept, connect, getpeername, getsockname
listener = tsocket.socket(socket_type)
client = tsocket.socket(socket_type)
@@ -484,10 +513,10 @@ def gai_without_v4mapped_is_buggy() -> bool: # pragma: no cover
@attr.s
class Addresses:
- bind_all = attr.ib()
- localhost = attr.ib()
- arbitrary = attr.ib()
- broadcast = attr.ib()
+ bind_all: str = attr.ib()
+ localhost: str = attr.ib()
+ arbitrary: str = attr.ib()
+ broadcast: str = attr.ib()
# Direct thorough tests of the implicit resolver helpers
@@ -515,34 +544,53 @@ class Addresses:
),
],
)
-async def test_SocketType_resolve(socket_type, addrs) -> None:
+async def test_SocketType_resolve(socket_type: AddressFamily, addrs: Addresses) -> None:
v6 = socket_type == tsocket.AF_INET6
- def pad(addr) -> tuple[str, int]|tuple[str, int,int,int]:
+ def pad(addr: tuple[str | int, ...]) -> tuple[str | int, ...]:
if v6:
while len(addr) < 4:
addr += (0,)
return addr
- def assert_eq(actual, expected) -> None:
+ def assert_eq(
+ actual: tuple[str | int, ...], expected: tuple[str | int, ...]
+ ) -> None:
assert pad(expected) == pad(actual)
with tsocket.socket(family=socket_type) as sock:
+ # testing internal functionality, so we check it against the internal type
+ assert isinstance(sock, _SocketType)
+
# For some reason the stdlib special-cases "" to pass NULL to
# getaddrinfo. They also error out on None, but whatever, None is much
# more consistent, so we accept it too.
+ # TODO: this implies that we can send host=None, but what does that imply for the return value, and other stuff?
for null in [None, ""]:
got = await sock._resolve_address_nocp((null, 80), local=True)
+ assert not isinstance(got, (str, bytes))
assert_eq(got, (addrs.bind_all, 80))
got = await sock._resolve_address_nocp((null, 80), local=False)
+ assert not isinstance(got, (str, bytes))
assert_eq(got, (addrs.localhost, 80))
# AI_PASSIVE only affects the wildcard address, so for everything else
# local=True/local=False should work the same:
for local in [False, True]:
- async def res(*args):
- return await sock._resolve_address_nocp(*args, local=local)
+ async def res(
+ args: tuple[str, int]
+ | tuple[str, int, int]
+ | tuple[str, int, int, int]
+ | tuple[str, str]
+ | tuple[str, str, int]
+ | tuple[str, str, int, int]
+ ) -> tuple[str, int] | tuple[str, int, int, int]:
+ # we're only passing IP sockets, so we ignore the str/bytes return type
+ # But what about when port/family is a string? Should that be part of the public API?
+ res = await sock._resolve_address_nocp(args, local=local) # type: ignore[arg-type]
+ # no str/bytes
+ return res # type: ignore[return-value]
assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80))
if v6:
@@ -593,6 +641,7 @@ async def res(*args):
except (AttributeError, OSError):
pass
else:
+ assert isinstance(netlink_sock, _SocketType)
assert (
await netlink_sock._resolve_address_nocp("asdf", local=local)
== "asdf"
@@ -600,13 +649,14 @@ async def res(*args):
netlink_sock.close()
with pytest.raises(ValueError):
- await res("1.2.3.4")
+ await res("1.2.3.4") # type: ignore[arg-type]
with pytest.raises(ValueError):
- await res(("1.2.3.4",))
+ await res(("1.2.3.4",)) # type: ignore[arg-type]
with pytest.raises(ValueError):
if v6:
- await res(("1.2.3.4", 80, 0, 0, 0))
+ await res(("1.2.3.4", 80, 0, 0, 0)) # type: ignore[arg-type]
else:
+ # I guess in theory there could be enough overloads that this could error?
await res(("1.2.3.4", 80, 0, 0))
@@ -649,7 +699,7 @@ async def test_SocketType_non_blocking_paths() -> None:
# immediate failure
with assert_checkpoints():
with pytest.raises(TypeError):
- await ta.recv("haha")
+ await ta.recv("haha") # type: ignore[arg-type]
# block then succeed
async def do_successful_blocking_recv() -> None:
@@ -727,7 +777,10 @@ async def test_SocketType_connect_paths() -> None:
# nose -- and then swap it back out again before we hit
# wait_socket_writable, which insists on a real socket.
class CancelSocket(stdlib_socket.socket):
- def connect(self, *args, **kwargs) -> None:
+ def connect(self, *args: Any, **kwargs: Any) -> None:
+ # accessing private method only available in _SocketType
+ assert isinstance(sock, _SocketType)
+
cancel_scope.cancel()
sock._sock = stdlib_socket.fromfd(
self.detach(), self.family, self.type
@@ -736,6 +789,8 @@ def connect(self, *args, **kwargs) -> None:
# If connect *doesn't* raise, then pretend it did
raise BlockingIOError # pragma: no cover
+ # accessing private method only available in _SocketType
+ assert isinstance(sock, _SocketType)
sock._sock.close()
sock._sock = CancelSocket()
@@ -772,11 +827,14 @@ async def test_resolve_address_exception_in_connect_closes_socket() -> None:
with _core.CancelScope() as cancel_scope:
with tsocket.socket() as sock:
- async def _resolve_address_nocp(self, *args, **kwargs) -> None:
+ async def _resolve_address_nocp(
+ self: Any, *args: Any, **kwargs: Any
+ ) -> None:
cancel_scope.cancel()
await _core.checkpoint()
- sock._resolve_address_nocp = _resolve_address_nocp
+ assert isinstance(sock, _SocketType)
+ sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign, assignment]
with assert_checkpoints():
with pytest.raises(_core.Cancelled):
await sock.connect("")
@@ -879,7 +937,7 @@ async def test_send_recv_variants() -> None:
assert await b.recv(10) == b"yyy"
-async def test_idna(monkeygai) -> None:
+async def test_idna(monkeygai: MonkeypatchedGAI) -> None:
# This is the encoding for "faß.de", which uses one of the characters that
# IDNA 2003 handles incorrectly:
monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80)
@@ -904,17 +962,22 @@ async def test_getprotobyname() -> None:
assert await tsocket.getprotobyname("tcp") == 6
-async def test_custom_hostname_resolver(monkeygai) -> None:
+async def test_custom_hostname_resolver(monkeygai: MonkeypatchedGAI) -> None:
+ # This intentionally breaks the signatures used in HostnameResolver
class CustomResolver:
- async def getaddrinfo(self, host, port, family, type, proto, flags):
+ async def getaddrinfo(
+ self, host: str, port: str, family: int, type: int, proto: int, flags: int
+ ) -> tuple[str, str, str, int, int, int, int]:
return ("custom_gai", host, port, family, type, proto, flags)
- async def getnameinfo(self, sockaddr, flags):
+ async def getnameinfo(
+ self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
+ ) -> tuple[str, tuple[str, int] | tuple[str, int, int, int], int]:
return ("custom_gni", sockaddr, flags)
cr = CustomResolver()
- assert tsocket.set_custom_hostname_resolver(cr) is None
+ assert tsocket.set_custom_hostname_resolver(cr) is None # type: ignore[arg-type]
# Check that the arguments are all getting passed through.
# We have to use valid calls to avoid making the underlying system
@@ -937,7 +1000,11 @@ async def getnameinfo(self, sockaddr, flags):
expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0)
assert got == expected
- assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0)
+ assert await tsocket.getnameinfo("a", 0) == ( # type: ignore[arg-type]
+ "custom_gni",
+ "a",
+ 0,
+ )
# We can set it back to None
assert tsocket.set_custom_hostname_resolver(None) is cr
@@ -950,12 +1017,14 @@ async def getnameinfo(self, sockaddr, flags):
async def test_custom_socket_factory() -> None:
class CustomSocketFactory:
- def socket(self, family, type, proto):
+ def socket(
+ self, family: AddressFamily, type: SocketKind, proto: int
+ ) -> tuple[str, AddressFamily, SocketKind, int]:
return ("hi", family, type, proto)
csf = CustomSocketFactory()
- assert tsocket.set_custom_socket_factory(csf) is None
+ assert tsocket.set_custom_socket_factory(csf) is None # type: ignore[arg-type]
assert tsocket.socket() == ("hi", tsocket.AF_INET, tsocket.SOCK_STREAM, 0)
assert tsocket.socket(1, 2, 3) == ("hi", 1, 2, 3)
@@ -985,7 +1054,7 @@ async def test_unix_domain_socket() -> None:
# Bind has a special branch to use a thread, since it has to do filesystem
# traversal. Maybe connect should too? Not sure.
- async def check_AF_UNIX(path) -> None:
+ async def check_AF_UNIX(path: str | bytes) -> None:
with tsocket.socket(family=tsocket.AF_UNIX) as lsock:
await lsock.bind(path)
lsock.listen(10)
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 3c598e8eae..3e8c300d53 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -263,6 +263,10 @@ def main() -> None: # pragma: no cover
IMPORTS_EPOLL = """\
from socket import socket
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .._file_io import _HasFileNo
"""
IMPORTS_KQUEUE = """\
From 201da1a26b07a4eeb82c77d5fc843e8596936596 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 28 Aug 2023 17:04:01 +0200
Subject: [PATCH 03/20] hide methods in SocketType behind a TYPE_CHECKING
guard, update _io_kqueue to use _HasFileNo
---
.coveragerc | 1 +
pyproject.toml | 2 -
trio/_abc.py | 6 +-
trio/_core/_generated_io_kqueue.py | 8 +-
trio/_core/_io_epoll.py | 2 -
trio/_core/_io_kqueue.py | 11 +-
trio/_highlevel_open_tcp_listeners.py | 5 +-
trio/_socket.py | 330 +++++++++---------
.../test_highlevel_open_tcp_listeners.py | 79 +++--
trio/_tests/test_highlevel_open_tcp_stream.py | 27 +-
trio/_tools/gen_exports.py | 2 +-
11 files changed, 257 insertions(+), 216 deletions(-)
diff --git a/.coveragerc b/.coveragerc
index 5272237caf..07709cd482 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -24,6 +24,7 @@ exclude_lines =
if t.TYPE_CHECKING:
@overload
class .*\bProtocol\b.*\):
+ raise NotImplementedError
partial_branches =
pragma: no branch
diff --git a/pyproject.toml b/pyproject.toml
index aeb35977a7..ec2ef1ffe1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,8 +96,6 @@ module = [
"trio/_tests/test_exports",
"trio/_tests/test_file_io",
"trio/_tests/test_highlevel_generic",
-#"trio/_tests/test_highlevel_socket", #
-#"trio/_tests/test_highlevel_open_tcp_listeners", #
"trio/_tests/test_highlevel_open_unix_stream",
"trio/_tests/test_highlevel_serve_listeners",
"trio/_tests/test_highlevel_ssl_helpers",
diff --git a/trio/_abc.py b/trio/_abc.py
index 6a99ea0842..a839bd380a 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -211,9 +211,9 @@ class SocketFactory(metaclass=ABCMeta):
@abstractmethod
def socket(
self,
- family: socket.AddressFamily | int | None = None,
- type: socket.SocketKind | int | None = None,
- proto: int | None = None,
+ family: socket.AddressFamily | int = ...,
+ type: socket.SocketKind | int = ...,
+ proto: int = ...,
) -> SocketType:
"""Create and return a socket object.
diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py
index cfcf6354c7..b8b58c2edc 100644
--- a/trio/_core/_generated_io_kqueue.py
+++ b/trio/_core/_generated_io_kqueue.py
@@ -12,11 +12,11 @@
if TYPE_CHECKING:
import select
- from socket import socket
from ._traps import Abort, RaiseCancelT
from .. import _core
+ from .._file_io import _HasFileNo
import sys
@@ -49,7 +49,7 @@ async def wait_kevent(ident: int, filter: int, abort_func: Callable[[
raise RuntimeError("must be called from async context")
-async def wait_readable(fd: (int | socket)) ->None:
+async def wait_readable(fd: (int | _HasFileNo)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
@@ -57,7 +57,7 @@ async def wait_readable(fd: (int | socket)) ->None:
raise RuntimeError("must be called from async context")
-async def wait_writable(fd: (int | socket)) ->None:
+async def wait_writable(fd: (int | _HasFileNo)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
@@ -65,7 +65,7 @@ async def wait_writable(fd: (int | socket)) ->None:
raise RuntimeError("must be called from async context")
-def notify_closing(fd: (int | socket)) ->None:
+def notify_closing(fd: (int | _HasFileNo)) ->None:
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index b39ad547f1..a0373fb8fa 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -13,8 +13,6 @@
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
- from socket import socket
-
from typing_extensions import TypeAlias
from .._core import Abort, RaiseCancelT
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index 56a6559091..4faa382eca 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -14,11 +14,10 @@
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
- from socket import socket
-
from typing_extensions import TypeAlias
from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
+ from .._file_io import _HasFileNo
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
@@ -149,7 +148,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
# wait_task_rescheduled does not have its return type typed
return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]
- async def _wait_common(self, fd: int | socket, filter: int) -> None:
+ async def _wait_common(self, fd: int | _HasFileNo, filter: int) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
@@ -181,15 +180,15 @@ def abort(_: RaiseCancelT) -> Abort:
await self.wait_kevent(fd, filter, abort)
@_public
- async def wait_readable(self, fd: int | socket) -> None:
+ async def wait_readable(self, fd: int | _HasFileNo) -> None:
await self._wait_common(fd, select.KQ_FILTER_READ)
@_public
- async def wait_writable(self, fd: int | socket) -> None:
+ async def wait_writable(self, fd: int | _HasFileNo) -> None:
await self._wait_common(fd, select.KQ_FILTER_WRITE)
@_public
- def notify_closing(self, fd: int | socket) -> None:
+ def notify_closing(self, fd: int | _HasFileNo) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py
index e6840eae97..019505c827 100644
--- a/trio/_highlevel_open_tcp_listeners.py
+++ b/trio/_highlevel_open_tcp_listeners.py
@@ -192,11 +192,12 @@ async def serve_tcp(
connect to it to check that it's working properly, you can use something
like::
+ from trio import SocketListener, SocketStream
from trio.testing import open_stream_to_socket_listener
async with trio.open_nursery() as nursery:
- listeners = await nursery.start(serve_tcp, handler, 0)
- client_stream = await open_stream_to_socket_listener(listeners[0])
+ listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
+ client_stream: SocketStream = await open_stream_to_socket_listener(listeners[0])
# Then send and receive data on 'client_stream', for example:
await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n")
diff --git a/trio/_socket.py b/trio/_socket.py
index ace35ff253..8321de1d63 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -13,7 +13,6 @@
Awaitable,
Callable,
Literal,
- NoReturn,
Optional,
SupportsIndex,
Tuple,
@@ -397,7 +396,7 @@ def _sniff_sockopts_for_fileno(
################################################################
-# _SocketType
+# SocketType
################################################################
# sock.type gets weird stuff set in it, in particular on Linux:
@@ -529,200 +528,211 @@ async def _resolve_address_nocp(
return normed
-# TODO: stopping users from initializing this type should be done in a different way,
-# so SocketType can be used as a type. Note that this is *far* from trivial without
-# breaking subclasses of SocketType. Can maybe add abstract methods to SocketType,
-# or rename _SocketType.
class SocketType:
- def __init__(self) -> NoReturn:
- raise TypeError(
- "SocketType is an abstract class; use trio.socket.socket if you "
- "want to construct a socket object"
- )
+ def __init__(self) -> None:
+ if type(self) == SocketType:
+ raise TypeError(
+ "SocketType is an abstract class; use trio.socket.socket if you "
+ "want to construct a socket object"
+ )
- ################################################################
- # Simple + portable methods and attributes
- ################################################################
- def detach(self) -> int:
- raise NotImplementedError()
+ if TYPE_CHECKING:
- def fileno(self) -> int:
- raise NotImplementedError()
+ def detach(self: SocketType) -> int:
+ ...
- def getpeername(self) -> Any:
- raise NotImplementedError()
+ def fileno(self: SocketType) -> int:
+ ...
- def getsockname(self) -> Any:
- raise NotImplementedError()
+ def getpeername(self: SocketType) -> Any:
+ ...
- @overload
- def getsockopt(self, /, level: int, optname: int) -> int:
- ...
+ def getsockname(self: SocketType) -> Any:
+ ...
- @overload
- def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
- ...
+ @overload
+ def getsockopt(self: SocketType, /, level: int, optname: int) -> int:
+ ...
- def getsockopt(
- self, /, level: int, optname: int, buflen: int | None = None
- ) -> int | bytes:
- raise NotImplementedError()
+ @overload
+ def getsockopt(
+ self: SocketType, /, level: int, optname: int, buflen: int
+ ) -> bytes:
+ ...
- @overload
- def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
- ...
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
+ ...
- @overload
- def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
- ...
+ @overload
+ def setsockopt(
+ self: SocketType, /, level: int, optname: int, value: int | Buffer
+ ) -> None:
+ ...
- def setsockopt(
- self,
- /,
- level: int,
- optname: int,
- value: int | Buffer | None,
- optlen: int | None = None,
- ) -> None:
- raise NotImplementedError()
+ @overload
+ def setsockopt(
+ self: SocketType, /, level: int, optname: int, value: None, optlen: int
+ ) -> None:
+ ...
- def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
- raise NotImplementedError()
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
+ ...
- def get_inheritable(self) -> bool:
- raise NotImplementedError()
+ def listen(
+ self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)
+ ) -> None:
+ ...
- def set_inheritable(self, inheritable: bool) -> None:
- raise NotImplementedError()
+ def get_inheritable(self: SocketType) -> bool:
+ ...
- if sys.platform == "win32" or (
- not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
- ):
+ def set_inheritable(self: SocketType, inheritable: bool) -> None:
+ ...
- def share(self, /, process_id: int) -> bytes:
- raise NotImplementedError()
+ if sys.platform == "win32":
- def __enter__(self) -> Self:
- raise NotImplementedError()
+ def share(self: SocketType, /, process_id: int) -> bytes:
+ ...
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc_value: BaseException | None,
- traceback: TracebackType | None,
- ) -> None:
- raise NotImplementedError()
+ def __enter__(self) -> Self:
+ ...
- @property
- def family(self) -> AddressFamily:
- raise NotImplementedError()
+ def __exit__(
+ self: SocketType,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ ...
- @property
- def type(self) -> SocketKind:
- raise NotImplementedError()
+ @property
+ def family(self: SocketType) -> AddressFamily:
+ ...
- @property
- def proto(self) -> int:
- raise NotImplementedError()
+ @property
+ def type(self: SocketType) -> SocketKind:
+ ...
- @property
- def did_shutdown_SHUT_WR(self) -> bool:
- raise NotImplementedError()
+ @property
+ def proto(self: SocketType) -> int:
+ ...
- # def __repr__(self) -> str:
- # raise NotImplementedError()
- def dup(self) -> SocketType:
- raise NotImplementedError()
+ @property
+ def did_shutdown_SHUT_WR(self) -> bool:
+ ...
- def close(self) -> None:
- raise NotImplementedError()
+ def __repr__(self) -> str:
+ ...
- async def bind(self, address: Address) -> None:
- raise NotImplementedError()
+ def dup(self: SocketType) -> SocketType:
+ ...
- def shutdown(self, flag: int) -> None:
- raise NotImplementedError()
+ def close(self) -> None:
+ ...
- def is_readable(self) -> bool:
- raise NotImplementedError()
+ async def bind(self, address: Address) -> None:
+ ...
- async def wait_writable(self) -> None:
- raise NotImplementedError()
+ def shutdown(self, flag: int) -> None:
+ ...
- async def accept(self) -> tuple[SocketType, object]:
- raise NotImplementedError()
-
- async def connect(self, address: Address) -> None:
- raise NotImplementedError()
-
- def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
- raise NotImplementedError()
-
- def recv_into(
- __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
- ) -> Awaitable[int]:
- raise NotImplementedError()
-
- ###
- def recvfrom(
- __self, __bufsize: int, __flags: int = 0
- ) -> Awaitable[tuple[bytes, Address]]:
- raise NotImplementedError()
-
- # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
- def recvfrom_into(
- __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
- ) -> Awaitable[tuple[int, Address]]:
- raise NotImplementedError()
-
- # if hasattr(_stdlib_socket.socket, "recvmsg"):
- # def recvmsg(
- # __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0
- # ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
- # raise NotImplementedError()
-
- # if hasattr(_stdlib_socket.socket, "recvmsg_into"):
- # def recvmsg_into(
- # __self,
- # __buffers: Iterable[Buffer],
- # __ancbufsize: int = 0,
- # __flags: int = 0,
- # ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
- # raise NotImplementedError()
-
- def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
- raise NotImplementedError()
+ def is_readable(self) -> bool:
+ ...
- @overload
- async def sendto(
- self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer
- ) -> int:
- ...
+ async def wait_writable(self) -> None:
+ ...
- @overload
- async def sendto(
- self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer
- ) -> int:
- ...
+ async def accept(self: SocketType) -> tuple[SocketType, object]:
+ ...
- @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc]
- async def sendto(self, *args: Any) -> int:
- """Similar to :meth:`socket.socket.sendto`, but async."""
- raise NotImplementedError()
+ async def connect(self, address: Address) -> None:
+ ...
- if sys.platform != "win32" or (
- not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg")
- ):
+ def recv(
+ __self: SocketType, __buflen: int, __flags: int = 0
+ ) -> Awaitable[bytes]:
+ ...
- @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
- async def sendmsg(
+ def recv_into(
+ __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[int]:
+ ...
+
+ # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any]
+ def recvfrom(
+ __self: SocketType, __bufsize: int, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, Address]]:
+ ...
+
+ # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
+ def recvfrom_into(
+ __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[tuple[int, Address]]:
+ ...
+
+ if hasattr(_stdlib_socket.socket, "recvmsg"):
+
+ def recvmsg(
+ __self: SocketType,
+ __bufsize: int,
+ __ancbufsize: int = 0,
+ __flags: int = 0,
+ ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
+ ...
+
+ if hasattr(_stdlib_socket.socket, "recvmsg_into"):
+
+ def recvmsg_into(
+ __self: SocketType,
+ __buffers: Iterable[Buffer],
+ __ancbufsize: int = 0,
+ __flags: int = 0,
+ ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
+ ...
+
+ def send(
+ __self: SocketType, __bytes: Buffer, __flags: int = 0
+ ) -> Awaitable[int]:
+ ...
+
+ @overload
+ async def sendto(
+ self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
+
+ @overload
+ async def sendto(
self,
- __buffers: Iterable[Buffer],
- __ancdata: Iterable[tuple[int, int, Buffer]] = (),
- __flags: int = 0,
- __address: Address | None = None,
+ __data: Buffer,
+ __flags: int,
+ __address: tuple[Any, ...] | str | Buffer,
) -> int:
- raise NotImplementedError()
+ ...
+
+ async def sendto(self, *args: Any) -> int:
+ ...
+
+ if sys.platform != "win32":
+
+ @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
+ async def sendmsg(
+ self,
+ __buffers: Iterable[Buffer],
+ __ancdata: Iterable[tuple[int, int, Buffer]] = (),
+ __flags: int = 0,
+ __address: AddressWithNoneHost | None = None,
+ ) -> int:
+ ...
class _SocketType(SocketType):
diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py
index 0d9e9316fa..fe7c55483d 100644
--- a/trio/_tests/test_highlevel_open_tcp_listeners.py
+++ b/trio/_tests/test_highlevel_open_tcp_listeners.py
@@ -5,21 +5,26 @@
import sys
from math import inf
from socket import AddressFamily, SocketKind
-from typing import overload
+from typing import TYPE_CHECKING, Sequence, overload
import attr
import pytest
import trio
from trio import SocketListener, open_tcp_listeners, open_tcp_stream, serve_tcp
+from trio.abc import HostnameResolver, SendStream, SocketFactory
from trio.testing import open_stream_to_socket_listener
from .. import socket as tsocket
from .._core._tests.tutil import binds_ipv6
+from .._socket import AddressWithNoneHost
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
+if TYPE_CHECKING:
+ from typing_extensions import Buffer
+
async def test_open_tcp_listeners_basic() -> None:
listeners = await open_tcp_listeners(0)
@@ -119,8 +124,8 @@ class FakeOSError(OSError):
@attr.s
class FakeSocket(tsocket.SocketType):
- _family: AddressFamily = attr.ib()
- _type: SocketKind = attr.ib()
+ _family: AddressFamily = attr.ib(converter=AddressFamily)
+ _type: SocketKind = attr.ib(converter=SocketKind)
_proto: int = attr.ib()
closed: bool = attr.ib(default=False)
@@ -175,7 +180,7 @@ def setsockopt(
async def bind(self, address: AddressWithNoneHost) -> None:
pass
- def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+ def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
assert self.backlog is None
assert backlog is not None
self.backlog = backlog
@@ -187,12 +192,21 @@ def close(self) -> None:
@attr.s
-class FakeSocketFactory:
- poison_after = attr.ib()
- sockets = attr.ib(factory=list)
- raise_on_family = attr.ib(factory=dict) # family => errno
+class FakeSocketFactory(SocketFactory):
+ poison_after: int = attr.ib()
+ sockets: list[tsocket.SocketType] = attr.ib(factory=list)
+ raise_on_family: dict[AddressFamily, int] = attr.ib(factory=dict) # family => errno
- def socket(self, family, type, proto):
+ def socket(
+ self,
+ family: AddressFamily | int | None = None,
+ type: SocketKind | int | None = None,
+ proto: int = 0,
+ ) -> tsocket.SocketType:
+ assert family is not None
+ assert type is not None
+ if isinstance(family, int):
+ family = AddressFamily(family)
if family in self.raise_on_family:
raise OSError(self.raise_on_family[family], "nope")
sock = FakeSocket(family, type, proto)
@@ -204,15 +218,37 @@ def socket(self, family, type, proto):
@attr.s
-class FakeHostnameResolver:
- family_addr_pairs = attr.ib()
+class FakeHostnameResolver(HostnameResolver):
+ family_addr_pairs: Sequence[tuple[AddressFamily, str]] = attr.ib()
- async def getaddrinfo(self, host, port, family, type, proto, flags):
+ async def getaddrinfo(
+ self,
+ host: bytes | str | None,
+ port: bytes | str | int | None,
+ family: int = 0,
+ type: int = 0,
+ proto: int = 0,
+ flags: int = 0,
+ ) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+ ]:
+ assert isinstance(port, int)
return [
(family, tsocket.SOCK_STREAM, 0, "", (addr, port))
for family, addr in self.family_addr_pairs
]
+ async def getnameinfo(
+ self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
+ ) -> tuple[str, str]:
+ raise NotImplementedError()
+
async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
# If we were trying to bind to multiple hosts and one of them failed, they
@@ -234,25 +270,27 @@ async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
assert len(fsf.sockets) == 3
for sock in fsf.sockets:
- assert sock.closed
+ # property only exists on FakeSocket
+ assert sock.closed # type: ignore[attr-defined]
async def test_open_tcp_listeners_port_checking() -> None:
for host in ["127.0.0.1", None]:
with pytest.raises(TypeError):
- await open_tcp_listeners(None, host=host)
+ await open_tcp_listeners(None, host=host) # type: ignore[arg-type]
with pytest.raises(TypeError):
- await open_tcp_listeners(b"80", host=host)
+ await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type]
with pytest.raises(TypeError):
- await open_tcp_listeners("http", host=host)
+ await open_tcp_listeners("http", host=host) # type: ignore[arg-type]
async def test_serve_tcp() -> None:
- async def handler(stream) -> None:
+ async def handler(stream: SendStream) -> None:
await stream.send_all(b"x")
async with trio.open_nursery() as nursery:
- listeners = await nursery.start(serve_tcp, handler, 0)
+ # nursery.start is incorrectly typed, awaiting #2773
+ listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) # type: ignore[arg-type]
stream = await open_stream_to_socket_listener(listeners[0])
async with stream:
await stream.receive_some(1) == b"x"
@@ -268,7 +306,7 @@ async def handler(stream) -> None:
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
async def test_open_tcp_listeners_some_address_families_unavailable(
- try_families, fail_families
+ try_families: set[AddressFamily], fail_families: set[AddressFamily]
) -> None:
fsf = FakeSocketFactory(
10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families}
@@ -339,7 +377,8 @@ async def test_open_tcp_listeners_backlog() -> None:
listeners = await open_tcp_listeners(0, backlog=given)
assert listeners
for listener in listeners:
- assert listener.socket.backlog == expected
+ # `backlog` only exists on FakeSocket
+ assert listener.socket.backlog == expected # type: ignore[attr-defined]
async def test_open_tcp_listeners_backlog_float_error() -> None:
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index fc1bf4c006..0fa142ba4b 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -26,18 +26,12 @@
def test_close_all() -> None:
class CloseMe(SocketType):
- def __init__(self) -> None:
- ...
-
closed = False
def close(self) -> None:
self.closed = True
class CloseKiller(SocketType):
- def __init__(self) -> None:
- ...
-
def close(self) -> None:
raise OSError
@@ -140,6 +134,7 @@ def can_bind_127_0_0_2() -> bool:
s.bind(("127.0.0.2", 0))
except OSError:
return False
+ # s.getsockname() is typed as returning Any
return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return]
@@ -173,7 +168,7 @@ async def test_local_address_real() -> None:
server_sock, remote_addr = await listener.accept()
await client_stream.aclose()
server_sock.close()
- # accept returns tuple[SocketType, object]
+ # accept returns tuple[SocketType, object], due to typeshed returning `Any`
assert remote_addr[0] == local_address # type: ignore[index]
# Trying to connect to an ipv4 address with the ipv6 wildcard
@@ -196,9 +191,9 @@ async def test_local_address_real() -> None:
@attr.s(eq=False)
class FakeSocket(trio.socket.SocketType):
scenario: Scenario = attr.ib()
- _family: AddressFamily = attr.ib(alias="_family")
- _type: SocketKind = attr.ib(alias="_type")
- _proto: int = attr.ib(alias="_proto")
+ _family: AddressFamily = attr.ib()
+ _type: SocketKind = attr.ib()
+ _proto: int = attr.ib()
ip: str | int | None = attr.ib(default=None)
port: str | int | None = attr.ib(default=None)
@@ -286,11 +281,11 @@ def _ip_to_gai_entry(
) -> tuple[
AddressFamily,
SocketKind,
- int | None,
+ int,
str,
- tuple[int | str, int, int, int] | tuple[int | str, int],
+ tuple[str, int, int, int] | tuple[str, int],
]:
- sockaddr: tuple[int | str, int] | tuple[int | str, int, int, int]
+ sockaddr: tuple[str, int] | tuple[str, int, int, int]
if ":" in ip:
family = trio.socket.AF_INET6
sockaddr = (ip, self.port, 0, 0)
@@ -301,7 +296,7 @@ def _ip_to_gai_entry(
# should hostnameresolver use AddressFamily and SocketKind, instead of int&int?
# the return type in supertype is ... wildly incompatible with what this returns
- async def getaddrinfo( # type: ignore[override]
+ async def getaddrinfo(
self,
host: str | bytes | None,
port: bytes | str | int | None,
@@ -313,9 +308,9 @@ async def getaddrinfo( # type: ignore[override]
tuple[
AddressFamily,
SocketKind,
- int | None,
+ int,
str,
- tuple[int | str, int, int, int] | tuple[int | str, int],
+ tuple[str, int, int, int] | tuple[str, int],
]
]:
assert host == b"test.example.com"
diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py
index 3e8c300d53..63fd38d5d2 100755
--- a/trio/_tools/gen_exports.py
+++ b/trio/_tools/gen_exports.py
@@ -274,11 +274,11 @@ def main() -> None: # pragma: no cover
if TYPE_CHECKING:
import select
- from socket import socket
from ._traps import Abort, RaiseCancelT
from .. import _core
+ from .._file_io import _HasFileNo
"""
From c8fc73ce7af51aaea0d63e1b58bef25352cc6e44 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 29 Aug 2023 15:24:32 +0200
Subject: [PATCH 04/20] remove unneeded pragma: no cover
---
trio/_tests/test_highlevel_open_tcp_stream.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index 0fa142ba4b..775ab1f644 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -294,8 +294,6 @@ def _ip_to_gai_entry(
sockaddr = (ip, self.port)
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
- # should hostnameresolver use AddressFamily and SocketKind, instead of int&int?
- # the return type in supertype is ... wildly incompatible with what this returns
async def getaddrinfo(
self,
host: str | bytes | None,
@@ -321,7 +319,7 @@ async def getaddrinfo(
assert flags == 0
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
- async def getnameinfo( # pragma: no cover
+ async def getnameinfo(
self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
) -> tuple[str, str]:
raise NotImplementedError
From e3fc6ae5e2f81bcc7e2d6d89edec775f497a64a2 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 29 Aug 2023 16:17:07 +0200
Subject: [PATCH 05/20] fix default parameters
---
trio/_abc.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index a839bd380a..dbc7966cee 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -211,9 +211,9 @@ class SocketFactory(metaclass=ABCMeta):
@abstractmethod
def socket(
self,
- family: socket.AddressFamily | int = ...,
- type: socket.SocketKind | int = ...,
- proto: int = ...,
+ family: socket.AddressFamily | int = socket.AF_INET,
+ type: socket.SocketKind | int = socket.SOCK_STREAM,
+ proto: int = 0,
) -> SocketType:
"""Create and return a socket object.
From d60d1a141360f9ea0d29c1a6ca0cb620ce1baf6c Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 2 Sep 2023 12:54:01 +0200
Subject: [PATCH 06/20] raise NotImplementedError instead of methods not being
defined at runtime
---
trio/_socket.py | 308 ++++++++++++++++++++++++------------------------
1 file changed, 154 insertions(+), 154 deletions(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index 8321de1d63..58d1d83cba 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -536,203 +536,203 @@ def __init__(self) -> None:
"want to construct a socket object"
)
- if TYPE_CHECKING:
+ def detach(self: SocketType) -> int:
+ raise NotImplementedError
- def detach(self: SocketType) -> int:
- ...
+ def fileno(self: SocketType) -> int:
+ raise NotImplementedError
- def fileno(self: SocketType) -> int:
- ...
+ def getpeername(self: SocketType) -> Any:
+ raise NotImplementedError
- def getpeername(self: SocketType) -> Any:
- ...
+ def getsockname(self: SocketType) -> Any:
+ raise NotImplementedError
- def getsockname(self: SocketType) -> Any:
- ...
-
- @overload
- def getsockopt(self: SocketType, /, level: int, optname: int) -> int:
- ...
+ @overload
+ def getsockopt(self: SocketType, /, level: int, optname: int) -> int:
+ ...
- @overload
- def getsockopt(
- self: SocketType, /, level: int, optname: int, buflen: int
- ) -> bytes:
- ...
+ @overload
+ def getsockopt(self: SocketType, /, level: int, optname: int, buflen: int) -> bytes:
+ ...
- def getsockopt(
- self, /, level: int, optname: int, buflen: int | None = None
- ) -> int | bytes:
- ...
+ def getsockopt(
+ self, /, level: int, optname: int, buflen: int | None = None
+ ) -> int | bytes:
+ raise NotImplementedError
- @overload
- def setsockopt(
- self: SocketType, /, level: int, optname: int, value: int | Buffer
- ) -> None:
- ...
+ @overload
+ def setsockopt(
+ self: SocketType, /, level: int, optname: int, value: int | Buffer
+ ) -> None:
+ ...
- @overload
- def setsockopt(
- self: SocketType, /, level: int, optname: int, value: None, optlen: int
- ) -> None:
- ...
+ @overload
+ def setsockopt(
+ self: SocketType, /, level: int, optname: int, value: None, optlen: int
+ ) -> None:
+ ...
- def setsockopt(
- self,
- /,
- level: int,
- optname: int,
- value: int | Buffer | None,
- optlen: int | None = None,
- ) -> None:
- ...
+ def setsockopt(
+ self,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
+ ) -> None:
+ raise NotImplementedError
- def listen(
- self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)
- ) -> None:
- ...
+ def listen(
+ self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)
+ ) -> None:
+ raise NotImplementedError
- def get_inheritable(self: SocketType) -> bool:
- ...
+ def get_inheritable(self: SocketType) -> bool:
+ raise NotImplementedError
- def set_inheritable(self: SocketType, inheritable: bool) -> None:
- ...
+ def set_inheritable(self: SocketType, inheritable: bool) -> None:
+ raise NotImplementedError
- if sys.platform == "win32":
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
+ ):
- def share(self: SocketType, /, process_id: int) -> bytes:
- ...
+ def share(self: SocketType, /, process_id: int) -> bytes:
+ raise NotImplementedError
- def __enter__(self) -> Self:
- ...
+ def __enter__(self) -> Self:
+ raise NotImplementedError
- def __exit__(
- self: SocketType,
- exc_type: type[BaseException] | None,
- exc_value: BaseException | None,
- traceback: TracebackType | None,
- ) -> None:
- ...
+ def __exit__(
+ self: SocketType,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ raise NotImplementedError
- @property
- def family(self: SocketType) -> AddressFamily:
- ...
+ @property
+ def family(self: SocketType) -> AddressFamily:
+ raise NotImplementedError
- @property
- def type(self: SocketType) -> SocketKind:
- ...
+ @property
+ def type(self: SocketType) -> SocketKind:
+ raise NotImplementedError
- @property
- def proto(self: SocketType) -> int:
- ...
+ @property
+ def proto(self: SocketType) -> int:
+ raise NotImplementedError
- @property
- def did_shutdown_SHUT_WR(self) -> bool:
- ...
+ @property
+ def did_shutdown_SHUT_WR(self) -> bool:
+ raise NotImplementedError
- def __repr__(self) -> str:
- ...
+ def __repr__(self) -> str:
+ raise NotImplementedError
- def dup(self: SocketType) -> SocketType:
- ...
+ def dup(self: SocketType) -> SocketType:
+ raise NotImplementedError
- def close(self) -> None:
- ...
+ def close(self) -> None:
+ raise NotImplementedError
- async def bind(self, address: Address) -> None:
- ...
+ async def bind(self, address: Address) -> None:
+ raise NotImplementedError
- def shutdown(self, flag: int) -> None:
- ...
+ def shutdown(self, flag: int) -> None:
+ raise NotImplementedError
- def is_readable(self) -> bool:
- ...
+ def is_readable(self) -> bool:
+ raise NotImplementedError
- async def wait_writable(self) -> None:
- ...
+ async def wait_writable(self) -> None:
+ raise NotImplementedError
- async def accept(self: SocketType) -> tuple[SocketType, object]:
- ...
+ async def accept(self: SocketType) -> tuple[SocketType, object]:
+ raise NotImplementedError
- async def connect(self, address: Address) -> None:
- ...
+ async def connect(self, address: Address) -> None:
+ raise NotImplementedError
- def recv(
- __self: SocketType, __buflen: int, __flags: int = 0
- ) -> Awaitable[bytes]:
- ...
+ def recv(__self: SocketType, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
+ raise NotImplementedError
- def recv_into(
- __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
- ) -> Awaitable[int]:
- ...
+ def recv_into(
+ __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[int]:
+ raise NotImplementedError
- # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any]
- def recvfrom(
- __self: SocketType, __bufsize: int, __flags: int = 0
- ) -> Awaitable[tuple[bytes, Address]]:
- ...
+ # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any]
+ def recvfrom(
+ __self: SocketType, __bufsize: int, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, Address]]:
+ raise NotImplementedError
- # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
- def recvfrom_into(
- __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
- ) -> Awaitable[tuple[int, Address]]:
- ...
+ # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
+ def recvfrom_into(
+ __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[tuple[int, Address]]:
+ raise NotImplementedError
- if hasattr(_stdlib_socket.socket, "recvmsg"):
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg")
+ ):
- def recvmsg(
- __self: SocketType,
- __bufsize: int,
- __ancbufsize: int = 0,
- __flags: int = 0,
- ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
- ...
+ def recvmsg(
+ __self: SocketType,
+ __bufsize: int,
+ __ancbufsize: int = 0,
+ __flags: int = 0,
+ ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
+ raise NotImplementedError
- if hasattr(_stdlib_socket.socket, "recvmsg_into"):
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into")
+ ):
- def recvmsg_into(
- __self: SocketType,
- __buffers: Iterable[Buffer],
- __ancbufsize: int = 0,
- __flags: int = 0,
- ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
- ...
+ def recvmsg_into(
+ __self: SocketType,
+ __buffers: Iterable[Buffer],
+ __ancbufsize: int = 0,
+ __flags: int = 0,
+ ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
+ raise NotImplementedError
- def send(
- __self: SocketType, __bytes: Buffer, __flags: int = 0
- ) -> Awaitable[int]:
- ...
+ def send(__self: SocketType, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
+ raise NotImplementedError
- @overload
- async def sendto(
- self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer
- ) -> int:
- ...
+ @overload
+ async def sendto(
+ self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
- @overload
- async def sendto(
- self,
- __data: Buffer,
- __flags: int,
- __address: tuple[Any, ...] | str | Buffer,
- ) -> int:
- ...
+ @overload
+ async def sendto(
+ self,
+ __data: Buffer,
+ __flags: int,
+ __address: tuple[Any, ...] | str | Buffer,
+ ) -> int:
+ ...
- async def sendto(self, *args: Any) -> int:
- ...
+ async def sendto(self, *args: Any) -> int:
+ raise NotImplementedError
- if sys.platform != "win32":
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg")
+ ):
- @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
- async def sendmsg(
- self,
- __buffers: Iterable[Buffer],
- __ancdata: Iterable[tuple[int, int, Buffer]] = (),
- __flags: int = 0,
- __address: AddressWithNoneHost | None = None,
- ) -> int:
- ...
+ @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
+ async def sendmsg(
+ self,
+ __buffers: Iterable[Buffer],
+ __ancdata: Iterable[tuple[int, int, Buffer]] = (),
+ __flags: int = 0,
+ __address: AddressWithNoneHost | None = None,
+ ) -> int:
+ raise NotImplementedError
class _SocketType(SocketType):
From b213602bc05271a1d9c28443a3792ec1cb5e8f62 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 2 Sep 2023 13:18:55 +0200
Subject: [PATCH 07/20] remove docs reference
---
docs/source/reference-io.rst | 8 --------
1 file changed, 8 deletions(-)
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index 61bbef78c2..e234d255c1 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -504,14 +504,6 @@ Socket objects
* :meth:`~socket.socket.set_inheritable`
* :meth:`~socket.socket.get_inheritable`
-The internal SocketType
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-..
- .. autoclass:: _SocketType
-..
- TODO: adding `:members:` here gives error due to overload+_wraps on `sendto`
- TODO: rewrite ... all of the above when fixing _SocketType vs SocketType
-
.. currentmodule:: trio
From 7eb099d7e5d7c63c5e3cfff511c668d30eb10216 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 2 Sep 2023 14:08:15 +0200
Subject: [PATCH 08/20] experimentally changing SocketType to be a generic
---
trio/_abc.py | 6 +-
trio/_dtls.py | 26 ++--
trio/_highlevel_open_tcp_stream.py | 36 ++---
trio/_highlevel_socket.py | 6 +-
trio/_socket.py | 183 ++++++++++++--------------
trio/_tests/verify_types_darwin.json | 2 +-
trio/_tests/verify_types_linux.json | 2 +-
trio/_tests/verify_types_windows.json | 6 +-
trio/socket.py | 1 -
9 files changed, 115 insertions(+), 153 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index dbc7966cee..724df92855 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -2,7 +2,7 @@
import socket
from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
import trio
@@ -209,12 +209,12 @@ class SocketFactory(metaclass=ABCMeta):
"""
@abstractmethod
- def socket(
+ def socket( # type: ignore[misc]
self,
family: socket.AddressFamily | int = socket.AF_INET,
type: socket.SocketKind | int = socket.SOCK_STREAM,
proto: int = 0,
- ) -> SocketType:
+ ) -> SocketType[Any]:
"""Create and return a socket object.
Your socket object must inherit from :class:`trio.socket.SocketType`,
diff --git a/trio/_dtls.py b/trio/_dtls.py
index ff6a61ba88..63ca83db1f 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -43,26 +43,26 @@
from OpenSSL.SSL import Context
from typing_extensions import Self, TypeAlias
- from trio.socket import Address, SocketType
+ from trio.socket import SocketType
MAX_UDP_PACKET_SIZE = 65527
-def packet_header_overhead(sock: SocketType) -> int:
+def packet_header_overhead(sock: SocketType[Any]) -> int:
if sock.family == trio.socket.AF_INET:
return 28
else:
return 48
-def worst_case_mtu(sock: SocketType) -> int:
+def worst_case_mtu(sock: SocketType[Any]) -> int:
if sock.family == trio.socket.AF_INET:
return 576 - packet_header_overhead(sock)
else:
return 1280 - packet_header_overhead(sock)
-def best_guess_mtu(sock: SocketType) -> int:
+def best_guess_mtu(sock: SocketType[Any]) -> int:
return 1500 - packet_header_overhead(sock)
@@ -563,7 +563,7 @@ def _signable(*fields: bytes) -> bytes:
def _make_cookie(
- key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes
+ key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
) -> bytes:
assert len(salt) == SALT_BYTES
assert len(key) == KEY_BYTES
@@ -581,7 +581,7 @@ def _make_cookie(
def valid_cookie(
- key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes
+ key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
) -> bool:
if len(cookie) > SALT_BYTES:
salt = cookie[:SALT_BYTES]
@@ -603,7 +603,7 @@ def valid_cookie(
def challenge_for(
- key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes
+ key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
) -> bytes:
salt = os.urandom(SALT_BYTES)
tick = _current_cookie_tick()
@@ -664,7 +664,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
async def handle_client_hello_untrusted(
- endpoint: DTLSEndpoint, address: Address, packet: bytes
+ endpoint: DTLSEndpoint, address: Any, packet: bytes
) -> None:
if endpoint._listening_context is None:
return
@@ -739,7 +739,7 @@ async def handle_client_hello_untrusted(
async def dtls_receive_loop(
- endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType
+ endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType[Any]
) -> None:
try:
while True:
@@ -828,7 +828,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
"""
- def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context):
+ def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
self.endpoint = endpoint
self.peer_address = peer_address
self._packets_dropped_in_trio = 0
@@ -1178,7 +1178,7 @@ class DTLSEndpoint(metaclass=Final):
"""
- def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10):
+ def __init__(self, socket: SocketType[Any], *, incoming_packets_buffer: int = 10):
# We do this lazily on first construction, so only people who actually use DTLS
# have to install PyOpenSSL.
global SSL
@@ -1189,7 +1189,7 @@ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10):
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
self._initialized = True
- self.socket: SocketType = socket
+ self.socket: SocketType[Any] = socket
self.incoming_packets_buffer = incoming_packets_buffer
self._token = trio.lowlevel.current_trio_token()
@@ -1198,7 +1198,7 @@ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10):
# as a peer provides a valid cookie, we can immediately tear down the
# old connection.
# {remote address: DTLSChannel}
- self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary()
+ self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary()
self._listening_context: Context | None = None
self._listening_key: bytes | None = None
self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py
index 322ae4006e..358d9f32f0 100644
--- a/trio/_highlevel_open_tcp_stream.py
+++ b/trio/_highlevel_open_tcp_stream.py
@@ -4,11 +4,11 @@
from collections.abc import Generator
from contextlib import contextmanager
from socket import AddressFamily, SocketKind
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
import trio
from trio._core._multierror import MultiError
-from trio.socket import SOCK_STREAM, Address, SocketType, getaddrinfo, socket
+from trio.socket import SOCK_STREAM, SocketType, getaddrinfo, socket
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
@@ -114,8 +114,8 @@
@contextmanager
-def close_all() -> Generator[set[SocketType], None, None]:
- sockets_to_close: set[SocketType] = set()
+def close_all() -> Generator[set[SocketType[Any]], None, None]: # type: ignore[misc]
+ sockets_to_close: set[SocketType[Any]] = set()
try:
yield sockets_to_close
finally:
@@ -131,7 +131,6 @@ def close_all() -> Generator[set[SocketType], None, None]:
raise MultiError(errs)
-# workaround for list being invariant
def reorder_for_rfc_6555_section_5_4(
targets: list[
tuple[
@@ -139,23 +138,7 @@ def reorder_for_rfc_6555_section_5_4(
SocketKind,
int,
str,
- tuple[str, int] | tuple[str, int, int, int],
- ]
- ] | list[
- tuple[
- AddressFamily,
- SocketKind,
- int,
- str,
- tuple[str, int]
- ]
- ] | list[
- tuple[
- AddressFamily,
- SocketKind,
- int,
- str,
- tuple[str, int, int, int],
+ Any,
]
]
) -> None:
@@ -172,12 +155,11 @@ def reorder_for_rfc_6555_section_5_4(
# Found the first entry with a different address family; move it
# so that it becomes the second item on the list.
if i != 1:
- # invariant workaround in arguments leads to type issues here
- targets.insert(1, targets.pop(i)) # type: ignore[arg-type]
+ targets.insert(1, targets.pop(i))
break
-def format_host_port(host: str | bytes, port: int|str) -> str:
+def format_host_port(host: str | bytes, port: int | str) -> str:
host = host.decode("ascii") if isinstance(host, bytes) else host
if ":" in host:
return f"[{host}]:{port}"
@@ -310,7 +292,7 @@ async def open_tcp_stream(
# Keeps track of the socket that we're going to complete with,
# need to make sure this isn't automatically closed
- winning_socket: SocketType | None = None
+ winning_socket: SocketType[Any] | None = None
# Try connecting to the specified address. Possible outcomes:
# - success: record connected socket in winning_socket and cancel
@@ -321,7 +303,7 @@ async def open_tcp_stream(
# face of crash or cancellation
async def attempt_connect(
socket_args: tuple[AddressFamily, SocketKind, int],
- sockaddr: Address,
+ sockaddr: Any,
attempt_failed: trio.Event,
) -> None:
nonlocal winning_socket
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py
index d733537752..edb07f2e82 100644
--- a/trio/_highlevel_socket.py
+++ b/trio/_highlevel_socket.py
@@ -4,7 +4,7 @@
import errno
from collections.abc import Generator
from contextlib import contextmanager
-from typing import TYPE_CHECKING, overload
+from typing import TYPE_CHECKING, Any, overload
import trio
@@ -66,7 +66,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final):
"""
- def __init__(self, socket: SocketType):
+ def __init__(self, socket: SocketType[Any]):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
@@ -371,7 +371,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final):
"""
- def __init__(self, socket: SocketType):
+ def __init__(self, socket: SocketType[Any]):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
diff --git a/trio/_socket.py b/trio/_socket.py
index 58d1d83cba..ca97a0ef37 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -12,10 +12,9 @@
Any,
Awaitable,
Callable,
+ Generic,
Literal,
- Optional,
SupportsIndex,
- Tuple,
TypeVar,
Union,
overload,
@@ -39,18 +38,7 @@
T = TypeVar("T")
-
-# must use old-style typing because it's evaluated at runtime
-Address: TypeAlias = Union[
- str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int]
-]
-AddressWithNoneHost: TypeAlias = Union[
- str,
- bytes,
- Tuple[Optional[str], int],
- Tuple[Optional[str], int, int],
- Tuple[Optional[str], int, int, int],
-]
+AddressFormat = TypeVar("AddressFormat")
# Usage:
@@ -180,11 +168,7 @@ async def getaddrinfo(
flags: int = 0,
) -> list[
tuple[
- AddressFamily,
- SocketKind,
- int,
- str,
- tuple[str, int] | tuple[str, int, int, int],
+ AddressFamily, SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int]
]
]:
"""Look up a numeric address given a name.
@@ -289,7 +273,7 @@ async def getprotobyname(name: str) -> int:
################################################################
-def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType:
+def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType[Any]:
"""Convert a standard library :class:`socket.socket` object into a Trio
socket object.
@@ -298,12 +282,12 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType:
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
-def fromfd(
+def fromfd( # type: ignore[misc] # use of Any in decorated function
fd: SupportsIndex,
family: AddressFamily | int = _stdlib_socket.AF_INET,
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
-) -> SocketType:
+) -> SocketType[Any]:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd))
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -314,7 +298,7 @@ def fromfd(
):
@_wraps(_stdlib_socket.fromshare, assigned=(), updated=())
- def fromshare(info: bytes) -> SocketType:
+ def fromshare(info: bytes) -> SocketType[Any]:
return from_stdlib_socket(_stdlib_socket.fromshare(info))
@@ -329,11 +313,11 @@ def fromshare(info: bytes) -> SocketType:
@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
-def socketpair(
+def socketpair( # type: ignore[misc]
family: FamilyT = FamilyDefault,
type: TypeT = SocketKind.SOCK_STREAM,
proto: int = 0,
-) -> tuple[SocketType, SocketType]:
+) -> tuple[SocketType[Any], SocketType[Any]]:
"""Like :func:`socket.socketpair`, but returns a pair of Trio socket
objects.
@@ -343,12 +327,12 @@ def socketpair(
@_wraps(_stdlib_socket.socket, assigned=(), updated=())
-def socket(
+def socket( # type: ignore[misc]
family: AddressFamily | int = _stdlib_socket.AF_INET,
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
fileno: int | None = None,
-) -> SocketType:
+) -> SocketType[Any]:
"""Create a new Trio socket, like :class:`socket.socket`.
This function's behavior can be customized using
@@ -416,9 +400,9 @@ def _make_simple_sock_method_wrapper(
fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]],
maybe_avail: bool = False,
-) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]:
- @_wraps(fn, assigned=("__name__",), updated=())
- async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T:
+) -> Callable[Concatenate[_SocketType[Any], P], Awaitable[T]]:
+ @_wraps(fn, assigned=("__name__",), updated=()) # type: ignore
+ async def wrapper(self: _SocketType[Any], *args: P.args, **kwargs: P.kwargs) -> T:
return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs)
wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async.
@@ -457,9 +441,9 @@ async def _resolve_address_nocp(
proto: int,
*,
ipv6_v6only: bool | int,
- address: AddressWithNoneHost,
+ address: AddressFormat,
local: bool,
-) -> Address:
+) -> Any:
# Do some pre-checking (or exit early for non-IP sockets)
if family == _stdlib_socket.AF_INET:
if not isinstance(address, tuple) or not len(address) == 2:
@@ -474,8 +458,7 @@ async def _resolve_address_nocp(
assert isinstance(address, (str, bytes))
return os.fspath(address)
else:
- # TODO: check for host is None?
- return address # type: ignore[return-value]
+ return address
# -- From here on we know we have IPv4 or IPV6 --
host: str | None
@@ -489,7 +472,7 @@ async def _resolve_address_nocp(
except (OSError, TypeError):
pass
else:
- return address # type: ignore[return-value]
+ return address
# Special cases to match the stdlib, see gh-277
if host == "":
host = None
@@ -518,17 +501,17 @@ async def _resolve_address_nocp(
if family == _stdlib_socket.AF_INET6:
list_normed = list(normed)
assert len(normed) == 4
- # typechecking certainly doesn't like this logic, but given just how broad
- # Address is, it's quite cumbersome to write the below without type: ignore
if len(address) >= 3:
- list_normed[2] = address[2] # type: ignore
+ list_normed[2] = address[2]
if len(address) >= 4:
- list_normed[3] = address[3] # type: ignore
- return tuple(list_normed) # type: ignore
+ list_normed[3] = address[3]
+ return tuple(list_normed)
return normed
-class SocketType:
+# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html?highlight=sendmsg#socket-families
+# and the return type of several methods will depend on those. typeshed has ended up typing those return types as `Any` in most cases, but for users that know which family/families they're working in and wants complete type coverage they can specify the AddressFormat.
+class SocketType(Generic[AddressFormat]):
def __init__(self) -> None:
if type(self) == SocketType:
raise TypeError(
@@ -536,24 +519,24 @@ def __init__(self) -> None:
"want to construct a socket object"
)
- def detach(self: SocketType) -> int:
+ def detach(self) -> int:
raise NotImplementedError
- def fileno(self: SocketType) -> int:
+ def fileno(self) -> int:
raise NotImplementedError
- def getpeername(self: SocketType) -> Any:
+ def getpeername(self) -> AddressFormat:
raise NotImplementedError
- def getsockname(self: SocketType) -> Any:
+ def getsockname(self) -> AddressFormat:
raise NotImplementedError
@overload
- def getsockopt(self: SocketType, /, level: int, optname: int) -> int:
+ def getsockopt(self, /, level: int, optname: int) -> int:
...
@overload
- def getsockopt(self: SocketType, /, level: int, optname: int, buflen: int) -> bytes:
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
...
def getsockopt(
@@ -562,15 +545,11 @@ def getsockopt(
raise NotImplementedError
@overload
- def setsockopt(
- self: SocketType, /, level: int, optname: int, value: int | Buffer
- ) -> None:
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
...
@overload
- def setsockopt(
- self: SocketType, /, level: int, optname: int, value: None, optlen: int
- ) -> None:
+ def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
...
def setsockopt(
@@ -583,29 +562,27 @@ def setsockopt(
) -> None:
raise NotImplementedError
- def listen(
- self: SocketType, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)
- ) -> None:
+ def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
raise NotImplementedError
- def get_inheritable(self: SocketType) -> bool:
+ def get_inheritable(self) -> bool:
raise NotImplementedError
- def set_inheritable(self: SocketType, inheritable: bool) -> None:
+ def set_inheritable(self, inheritable: bool) -> None:
raise NotImplementedError
if sys.platform == "win32" or (
not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
):
- def share(self: SocketType, /, process_id: int) -> bytes:
+ def share(self, /, process_id: int) -> bytes:
raise NotImplementedError
def __enter__(self) -> Self:
raise NotImplementedError
def __exit__(
- self: SocketType,
+ self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
@@ -613,15 +590,15 @@ def __exit__(
raise NotImplementedError
@property
- def family(self: SocketType) -> AddressFamily:
+ def family(self) -> AddressFamily:
raise NotImplementedError
@property
- def type(self: SocketType) -> SocketKind:
+ def type(self) -> SocketKind:
raise NotImplementedError
@property
- def proto(self: SocketType) -> int:
+ def proto(self) -> int:
raise NotImplementedError
@property
@@ -631,13 +608,13 @@ def did_shutdown_SHUT_WR(self) -> bool:
def __repr__(self) -> str:
raise NotImplementedError
- def dup(self: SocketType) -> SocketType:
+ def dup(self) -> SocketType[AddressFormat]:
raise NotImplementedError
def close(self) -> None:
raise NotImplementedError
- async def bind(self, address: Address) -> None:
+ async def bind(self, address: AddressFormat) -> None:
raise NotImplementedError
def shutdown(self, flag: int) -> None:
@@ -649,57 +626,57 @@ def is_readable(self) -> bool:
async def wait_writable(self) -> None:
raise NotImplementedError
- async def accept(self: SocketType) -> tuple[SocketType, object]:
+ async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]:
raise NotImplementedError
- async def connect(self, address: Address) -> None:
+ async def connect(self, address: AddressFormat) -> None:
raise NotImplementedError
- def recv(__self: SocketType, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
+ def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
raise NotImplementedError
def recv_into(
- __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
) -> Awaitable[int]:
raise NotImplementedError
# return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any]
def recvfrom(
- __self: SocketType, __bufsize: int, __flags: int = 0
- ) -> Awaitable[tuple[bytes, Address]]:
+ __self, __bufsize: int, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, AddressFormat]]:
raise NotImplementedError
# return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
def recvfrom_into(
- __self: SocketType, buffer: Buffer, nbytes: int = 0, flags: int = 0
- ) -> Awaitable[tuple[int, Address]]:
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[tuple[int, AddressFormat]]:
raise NotImplementedError
- if sys.platform == "win32" or (
+ if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg")
):
def recvmsg(
- __self: SocketType,
+ __self,
__bufsize: int,
__ancbufsize: int = 0,
__flags: int = 0,
) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
raise NotImplementedError
- if sys.platform == "win32" or (
+ if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into")
):
def recvmsg_into(
- __self: SocketType,
+ __self,
__buffers: Iterable[Buffer],
__ancbufsize: int = 0,
__flags: int = 0,
) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
raise NotImplementedError
- def send(__self: SocketType, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
+ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
raise NotImplementedError
@overload
@@ -720,7 +697,7 @@ async def sendto(
async def sendto(self, *args: Any) -> int:
raise NotImplementedError
- if sys.platform == "win32" or (
+ if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg")
):
@@ -730,12 +707,12 @@ async def sendmsg(
__buffers: Iterable[Buffer],
__ancdata: Iterable[tuple[int, int, Buffer]] = (),
__flags: int = 0,
- __address: AddressWithNoneHost | None = None,
+ __address: AddressFormat | None = None,
) -> int:
raise NotImplementedError
-class _SocketType(SocketType):
+class _SocketType(SocketType[AddressFormat]):
def __init__(self, sock: _stdlib_socket.socket):
if type(sock) is not _stdlib_socket.socket:
# For example, ssl.SSLSocket subclasses socket.socket, but we
@@ -758,11 +735,11 @@ def detach(self) -> int:
def fileno(self) -> int:
return self._sock.fileno()
- def getpeername(self) -> Any:
- return self._sock.getpeername()
+ def getpeername(self) -> AddressFormat:
+ return self._sock.getpeername() # type: ignore[no-any-return]
- def getsockname(self) -> Any:
- return self._sock.getsockname()
+ def getsockname(self) -> AddressFormat:
+ return self._sock.getsockname() # type: ignore[no-any-return]
@overload
def getsockopt(self, /, level: int, optname: int) -> int:
@@ -856,7 +833,7 @@ def did_shutdown_SHUT_WR(self) -> bool:
def __repr__(self) -> str:
return repr(self._sock).replace("socket.socket", "trio.socket.socket")
- def dup(self) -> _SocketType:
+ def dup(self) -> SocketType[AddressFormat]:
"""Same as :meth:`socket.socket.dup`."""
return _SocketType(self._sock.dup())
@@ -865,12 +842,12 @@ def close(self) -> None:
trio.lowlevel.notify_closing(self._sock)
self._sock.close()
- async def bind(self, address: AddressWithNoneHost) -> None:
+ async def bind(self, address: AddressFormat) -> None:
address = await self._resolve_address_nocp(address, local=True)
if (
hasattr(_stdlib_socket, "AF_UNIX")
and self.family == _stdlib_socket.AF_UNIX
- and address[0]
+ and address[0] # type: ignore[index]
):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
@@ -881,7 +858,7 @@ async def bind(self, address: AddressWithNoneHost) -> None:
# there aren't yet any real systems that do this, so we'll worry
# about it when it happens.
await trio.lowlevel.checkpoint()
- return self._sock.bind(address)
+ return self._sock.bind(address) # type: ignore[arg-type]
def shutdown(self, flag: int) -> None:
# no need to worry about return value b/c always returns None:
@@ -904,17 +881,17 @@ async def wait_writable(self) -> None:
async def _resolve_address_nocp(
self,
- address: AddressWithNoneHost,
+ address: AddressFormat,
*,
local: bool,
- ) -> Address:
+ ) -> AddressFormat:
if self.family == _stdlib_socket.AF_INET6:
ipv6_v6only = self._sock.getsockopt(
_stdlib_socket.IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY
)
else:
ipv6_v6only = False
- return await _resolve_address_nocp(
+ return await _resolve_address_nocp( # type: ignore[no-any-return]
self.type,
self.family,
self.proto,
@@ -975,7 +952,7 @@ async def _nonblocking_helper(
_stdlib_socket.socket.accept, _core.wait_readable
)
- async def accept(self) -> tuple[SocketType, object]:
+ async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]:
"""Like :meth:`socket.socket.accept`, but async."""
sock, addr = await self._accept()
return from_stdlib_socket(sock), addr
@@ -984,7 +961,7 @@ async def accept(self) -> tuple[SocketType, object]:
# connect
################################################################
- async def connect(self, address: AddressWithNoneHost) -> None:
+ async def connect(self, address: AddressFormat) -> None:
# nonblocking connect is weird -- you call it to start things
# off, then the socket becomes writable as a completion
# notification. This means it isn't really cancellable... we close the
@@ -1039,7 +1016,7 @@ async def connect(self, address: AddressWithNoneHost) -> None:
# happens, someone will hopefully tell us, and then hopefully we
# can investigate their system to figure out what its semantics
# are.
- return self._sock.connect(address)
+ return self._sock.connect(address) # type: ignore[arg-type]
# It raised BlockingIOError, meaning that it's started the
# connection attempt. We wait for it to complete:
await _core.wait_writable(self._sock)
@@ -1097,7 +1074,7 @@ def recv_into(
# return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any]
def recvfrom(
__self, __bufsize: int, __flags: int = 0
- ) -> Awaitable[tuple[bytes, Address]]:
+ ) -> Awaitable[tuple[bytes, AddressFormat]]:
...
recvfrom = _make_simple_sock_method_wrapper( # noqa: F811
@@ -1112,7 +1089,7 @@ def recvfrom(
# return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
def recvfrom_into(
__self, buffer: Buffer, nbytes: int = 0, flags: int = 0
- ) -> Awaitable[tuple[int, Address]]:
+ ) -> Awaitable[tuple[int, AddressFormat]]:
...
recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811
@@ -1123,7 +1100,9 @@ def recvfrom_into(
# recvmsg
################################################################
- if hasattr(_stdlib_socket.socket, "recvmsg"):
+ if sys.platform != "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg")
+ ):
if TYPE_CHECKING:
def recvmsg(
@@ -1139,7 +1118,9 @@ def recvmsg(
# recvmsg_into
################################################################
- if hasattr(_stdlib_socket.socket, "recvmsg_into"):
+ if sys.platform != "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into")
+ ):
if TYPE_CHECKING:
def recvmsg_into(
@@ -1208,7 +1189,7 @@ async def sendmsg(
__buffers: Iterable[Buffer],
__ancdata: Iterable[tuple[int, int, Buffer]] = (),
__flags: int = 0,
- __address: AddressWithNoneHost | None = None,
+ __address: AddressFormat | None = None,
) -> int:
"""Similar to :meth:`socket.socket.sendmsg`, but async.
@@ -1224,7 +1205,7 @@ async def sendmsg(
__buffers,
__ancdata,
__flags,
- __address,
+ __address, # type: ignore[arg-type]
)
################################################################
diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json
index 2b89d28d8e..1f7e81c182 100644
--- a/trio/_tests/verify_types_darwin.json
+++ b/trio/_tests/verify_types_darwin.json
@@ -40,7 +40,7 @@
],
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 631,
+ "withKnownType": 630,
"withUnknownType": 0
},
"ignoreUnknownTypesFromImports": true,
diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json
index ea5af77abc..e128b89775 100644
--- a/trio/_tests/verify_types_linux.json
+++ b/trio/_tests/verify_types_linux.json
@@ -28,7 +28,7 @@
],
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 628,
+ "withKnownType": 627,
"withUnknownType": 0
},
"ignoreUnknownTypesFromImports": true,
diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json
index 5d3e29a5dc..c14fb3bee9 100644
--- a/trio/_tests/verify_types_windows.json
+++ b/trio/_tests/verify_types_windows.json
@@ -7,7 +7,7 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.9857369255150554,
+ "completenessScore": 0.9857142857142858,
"diagnostics": [
{
"message": "Return type annotation is missing",
@@ -144,7 +144,7 @@
],
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 622,
+ "withKnownType": 621,
"withUnknownType": 9
},
"ignoreUnknownTypesFromImports": true,
@@ -180,7 +180,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 671,
+ "withKnownType": 667,
"withUnknownType": 0
},
"packageName": "trio"
diff --git a/trio/socket.py b/trio/socket.py
index ddefc72649..a9e276c782 100644
--- a/trio/socket.py
+++ b/trio/socket.py
@@ -34,7 +34,6 @@
# import the overwrites
from ._socket import (
- Address as Address,
SocketType as SocketType,
from_stdlib_socket as from_stdlib_socket,
fromfd as fromfd,
From ab66937342b4adbde63f4bc8bf1ab2e47a7fba42 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 5 Sep 2023 11:48:24 +0200
Subject: [PATCH 09/20] fix url
---
trio/_socket.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index ca97a0ef37..8dee44a279 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -509,7 +509,7 @@ async def _resolve_address_nocp(
return normed
-# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html?highlight=sendmsg#socket-families
+# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html#socket-families
# and the return type of several methods will depend on those. typeshed has ended up typing those return types as `Any` in most cases, but for users that know which family/families they're working in and wants complete type coverage they can specify the AddressFormat.
class SocketType(Generic[AddressFormat]):
def __init__(self) -> None:
From 304fd32b022e01be788a02e5862c42d5e04ec80d Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 5 Sep 2023 13:40:33 +0200
Subject: [PATCH 10/20] fix old imports
---
trio/_tests/test_highlevel_open_tcp_listeners.py | 1 -
trio/_tests/test_highlevel_open_tcp_stream.py | 2 +-
2 files changed, 1 insertion(+), 2 deletions(-)
diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py
index fe7c55483d..776920dade 100644
--- a/trio/_tests/test_highlevel_open_tcp_listeners.py
+++ b/trio/_tests/test_highlevel_open_tcp_listeners.py
@@ -17,7 +17,6 @@
from .. import socket as tsocket
from .._core._tests.tutil import binds_ipv6
-from .._socket import AddressWithNoneHost
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index 775ab1f644..0ea305b5fb 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -15,7 +15,7 @@
open_tcp_stream,
reorder_for_rfc_6555_section_5_4,
)
-from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, Address, SocketType
+from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType
if TYPE_CHECKING:
from trio.testing import MockClock
From 920d6630c930151d0a202e82bbc49621db437048 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 16 Sep 2023 12:07:09 +0200
Subject: [PATCH 11/20] fix verify_types
---
trio/_tests/verify_types_darwin.json | 2 +-
trio/_tests/verify_types_linux.json | 2 +-
trio/_tests/verify_types_windows.json | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json
index d98fbacaf0..507af26b3d 100644
--- a/trio/_tests/verify_types_darwin.json
+++ b/trio/_tests/verify_types_darwin.json
@@ -76,7 +76,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 686,
+ "withKnownType": 687,
"withUnknownType": 0
},
"packageName": "trio"
diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json
index d62d04db82..61ce9afcc5 100644
--- a/trio/_tests/verify_types_linux.json
+++ b/trio/_tests/verify_types_linux.json
@@ -64,7 +64,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 686,
+ "withKnownType": 687,
"withUnknownType": 0
},
"packageName": "trio"
diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json
index 555003ae92..e825282bf0 100644
--- a/trio/_tests/verify_types_windows.json
+++ b/trio/_tests/verify_types_windows.json
@@ -180,7 +180,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 677,
+ "withKnownType": 674,
"withUnknownType": 0
},
"packageName": "trio"
From a3169cc11adbd4fb1cd0d4fa9defa9a84451cba4 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 16 Sep 2023 14:26:40 +0200
Subject: [PATCH 12/20] reverse making SocketType generic, CI fixes
---
trio/_abc.py | 6 +-
trio/_dtls.py | 12 ++--
trio/_highlevel_open_tcp_stream.py | 6 +-
trio/_highlevel_socket.py | 6 +-
trio/_socket.py | 64 +++++++++++--------
.../test_highlevel_open_tcp_listeners.py | 8 +--
trio/_tests/test_highlevel_open_tcp_stream.py | 8 +--
trio/_tests/test_highlevel_socket.py | 10 +--
trio/_tests/test_socket.py | 10 +--
trio/testing/_fake_net.py | 12 +++-
10 files changed, 77 insertions(+), 65 deletions(-)
diff --git a/trio/_abc.py b/trio/_abc.py
index 724df92855..dbc7966cee 100644
--- a/trio/_abc.py
+++ b/trio/_abc.py
@@ -2,7 +2,7 @@
import socket
from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Any, Generic, TypeVar
+from typing import TYPE_CHECKING, Generic, TypeVar
import trio
@@ -209,12 +209,12 @@ class SocketFactory(metaclass=ABCMeta):
"""
@abstractmethod
- def socket( # type: ignore[misc]
+ def socket(
self,
family: socket.AddressFamily | int = socket.AF_INET,
type: socket.SocketKind | int = socket.SOCK_STREAM,
proto: int = 0,
- ) -> SocketType[Any]:
+ ) -> SocketType:
"""Create and return a socket object.
Your socket object must inherit from :class:`trio.socket.SocketType`,
diff --git a/trio/_dtls.py b/trio/_dtls.py
index 63ca83db1f..9a98dd8f5f 100644
--- a/trio/_dtls.py
+++ b/trio/_dtls.py
@@ -48,21 +48,21 @@
MAX_UDP_PACKET_SIZE = 65527
-def packet_header_overhead(sock: SocketType[Any]) -> int:
+def packet_header_overhead(sock: SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 28
else:
return 48
-def worst_case_mtu(sock: SocketType[Any]) -> int:
+def worst_case_mtu(sock: SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 576 - packet_header_overhead(sock)
else:
return 1280 - packet_header_overhead(sock)
-def best_guess_mtu(sock: SocketType[Any]) -> int:
+def best_guess_mtu(sock: SocketType) -> int:
return 1500 - packet_header_overhead(sock)
@@ -739,7 +739,7 @@ async def handle_client_hello_untrusted(
async def dtls_receive_loop(
- endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType[Any]
+ endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType
) -> None:
try:
while True:
@@ -1178,7 +1178,7 @@ class DTLSEndpoint(metaclass=Final):
"""
- def __init__(self, socket: SocketType[Any], *, incoming_packets_buffer: int = 10):
+ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10):
# We do this lazily on first construction, so only people who actually use DTLS
# have to install PyOpenSSL.
global SSL
@@ -1189,7 +1189,7 @@ def __init__(self, socket: SocketType[Any], *, incoming_packets_buffer: int = 10
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
self._initialized = True
- self.socket: SocketType[Any] = socket
+ self.socket: SocketType = socket
self.incoming_packets_buffer = incoming_packets_buffer
self._token = trio.lowlevel.current_trio_token()
diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py
index 358d9f32f0..c3dc157827 100644
--- a/trio/_highlevel_open_tcp_stream.py
+++ b/trio/_highlevel_open_tcp_stream.py
@@ -114,8 +114,8 @@
@contextmanager
-def close_all() -> Generator[set[SocketType[Any]], None, None]: # type: ignore[misc]
- sockets_to_close: set[SocketType[Any]] = set()
+def close_all() -> Generator[set[SocketType], None, None]:
+ sockets_to_close: set[SocketType] = set()
try:
yield sockets_to_close
finally:
@@ -292,7 +292,7 @@ async def open_tcp_stream(
# Keeps track of the socket that we're going to complete with,
# need to make sure this isn't automatically closed
- winning_socket: SocketType[Any] | None = None
+ winning_socket: SocketType | None = None
# Try connecting to the specified address. Possible outcomes:
# - success: record connected socket in winning_socket and cancel
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py
index edb07f2e82..d733537752 100644
--- a/trio/_highlevel_socket.py
+++ b/trio/_highlevel_socket.py
@@ -4,7 +4,7 @@
import errno
from collections.abc import Generator
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, overload
+from typing import TYPE_CHECKING, overload
import trio
@@ -66,7 +66,7 @@ class SocketStream(HalfCloseableStream, metaclass=Final):
"""
- def __init__(self, socket: SocketType[Any]):
+ def __init__(self, socket: SocketType):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
@@ -371,7 +371,7 @@ class SocketListener(Listener[SocketStream], metaclass=Final):
"""
- def __init__(self, socket: SocketType[Any]):
+ def __init__(self, socket: SocketType):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
diff --git a/trio/_socket.py b/trio/_socket.py
index 8dee44a279..84059f5ba3 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -12,7 +12,6 @@
Any,
Awaitable,
Callable,
- Generic,
Literal,
SupportsIndex,
TypeVar,
@@ -38,7 +37,18 @@
T = TypeVar("T")
-AddressFormat = TypeVar("AddressFormat")
+
+# _stdlib_socket.socket supports 13 different socket families, see
+# https://docs.python.org/3/library/socket.html#socket-families
+# and the return type of several methods in SocketType will depend on those. Typeshed
+# has ended up typing those return types as `Any` in most cases, but for users that
+# know which family/families they're working in we could make SocketType a generic type,
+# where you specify the return values you expect from those methods depending on the
+# protocol the socket will be handling.
+# But without the ability to default the value to `Any` it will be overly cumbersome for
+# most users, so currently we just specify it as `Any`.
+# AddressFormat = TypeVar("AddressFormat")
+AddressFormat: TypeAlias = Any
# Usage:
@@ -273,7 +283,7 @@ async def getprotobyname(name: str) -> int:
################################################################
-def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType[Any]:
+def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType:
"""Convert a standard library :class:`socket.socket` object into a Trio
socket object.
@@ -282,12 +292,12 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType[Any]:
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
-def fromfd( # type: ignore[misc] # use of Any in decorated function
+def fromfd(
fd: SupportsIndex,
family: AddressFamily | int = _stdlib_socket.AF_INET,
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
-) -> SocketType[Any]:
+) -> SocketType:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd))
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -298,7 +308,7 @@ def fromfd( # type: ignore[misc] # use of Any in decorated function
):
@_wraps(_stdlib_socket.fromshare, assigned=(), updated=())
- def fromshare(info: bytes) -> SocketType[Any]:
+ def fromshare(info: bytes) -> SocketType:
return from_stdlib_socket(_stdlib_socket.fromshare(info))
@@ -313,11 +323,11 @@ def fromshare(info: bytes) -> SocketType[Any]:
@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
-def socketpair( # type: ignore[misc]
+def socketpair(
family: FamilyT = FamilyDefault,
type: TypeT = SocketKind.SOCK_STREAM,
proto: int = 0,
-) -> tuple[SocketType[Any], SocketType[Any]]:
+) -> tuple[SocketType, SocketType]:
"""Like :func:`socket.socketpair`, but returns a pair of Trio socket
objects.
@@ -327,12 +337,12 @@ def socketpair( # type: ignore[misc]
@_wraps(_stdlib_socket.socket, assigned=(), updated=())
-def socket( # type: ignore[misc]
+def socket(
family: AddressFamily | int = _stdlib_socket.AF_INET,
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
fileno: int | None = None,
-) -> SocketType[Any]:
+) -> SocketType:
"""Create a new Trio socket, like :class:`socket.socket`.
This function's behavior can be customized using
@@ -400,9 +410,9 @@ def _make_simple_sock_method_wrapper(
fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]],
maybe_avail: bool = False,
-) -> Callable[Concatenate[_SocketType[Any], P], Awaitable[T]]:
- @_wraps(fn, assigned=("__name__",), updated=()) # type: ignore
- async def wrapper(self: _SocketType[Any], *args: P.args, **kwargs: P.kwargs) -> T:
+) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]:
+ @_wraps(fn, assigned=("__name__",), updated=())
+ async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T:
return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs)
wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async.
@@ -509,9 +519,7 @@ async def _resolve_address_nocp(
return normed
-# _stdlib_socket.socket supports 13 different socket families: https://docs.python.org/3/library/socket.html#socket-families
-# and the return type of several methods will depend on those. typeshed has ended up typing those return types as `Any` in most cases, but for users that know which family/families they're working in and wants complete type coverage they can specify the AddressFormat.
-class SocketType(Generic[AddressFormat]):
+class SocketType:
def __init__(self) -> None:
if type(self) == SocketType:
raise TypeError(
@@ -608,7 +616,7 @@ def did_shutdown_SHUT_WR(self) -> bool:
def __repr__(self) -> str:
raise NotImplementedError
- def dup(self) -> SocketType[AddressFormat]:
+ def dup(self) -> SocketType:
raise NotImplementedError
def close(self) -> None:
@@ -626,7 +634,7 @@ def is_readable(self) -> bool:
async def wait_writable(self) -> None:
raise NotImplementedError
- async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]:
+ async def accept(self) -> tuple[SocketType, AddressFormat]:
raise NotImplementedError
async def connect(self, address: AddressFormat) -> None:
@@ -712,7 +720,7 @@ async def sendmsg(
raise NotImplementedError
-class _SocketType(SocketType[AddressFormat]):
+class _SocketType(SocketType):
def __init__(self, sock: _stdlib_socket.socket):
if type(sock) is not _stdlib_socket.socket:
# For example, ssl.SSLSocket subclasses socket.socket, but we
@@ -736,10 +744,10 @@ def fileno(self) -> int:
return self._sock.fileno()
def getpeername(self) -> AddressFormat:
- return self._sock.getpeername() # type: ignore[no-any-return]
+ return self._sock.getpeername()
def getsockname(self) -> AddressFormat:
- return self._sock.getsockname() # type: ignore[no-any-return]
+ return self._sock.getsockname()
@overload
def getsockopt(self, /, level: int, optname: int) -> int:
@@ -833,7 +841,7 @@ def did_shutdown_SHUT_WR(self) -> bool:
def __repr__(self) -> str:
return repr(self._sock).replace("socket.socket", "trio.socket.socket")
- def dup(self) -> SocketType[AddressFormat]:
+ def dup(self) -> SocketType:
"""Same as :meth:`socket.socket.dup`."""
return _SocketType(self._sock.dup())
@@ -847,7 +855,7 @@ async def bind(self, address: AddressFormat) -> None:
if (
hasattr(_stdlib_socket, "AF_UNIX")
and self.family == _stdlib_socket.AF_UNIX
- and address[0] # type: ignore[index]
+ and address[0]
):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
@@ -858,7 +866,7 @@ async def bind(self, address: AddressFormat) -> None:
# there aren't yet any real systems that do this, so we'll worry
# about it when it happens.
await trio.lowlevel.checkpoint()
- return self._sock.bind(address) # type: ignore[arg-type]
+ return self._sock.bind(address)
def shutdown(self, flag: int) -> None:
# no need to worry about return value b/c always returns None:
@@ -891,7 +899,7 @@ async def _resolve_address_nocp(
)
else:
ipv6_v6only = False
- return await _resolve_address_nocp( # type: ignore[no-any-return]
+ return await _resolve_address_nocp(
self.type,
self.family,
self.proto,
@@ -952,7 +960,7 @@ async def _nonblocking_helper(
_stdlib_socket.socket.accept, _core.wait_readable
)
- async def accept(self) -> tuple[SocketType[AddressFormat], AddressFormat]:
+ async def accept(self) -> tuple[SocketType, AddressFormat]:
"""Like :meth:`socket.socket.accept`, but async."""
sock, addr = await self._accept()
return from_stdlib_socket(sock), addr
@@ -1016,7 +1024,7 @@ async def connect(self, address: AddressFormat) -> None:
# happens, someone will hopefully tell us, and then hopefully we
# can investigate their system to figure out what its semantics
# are.
- return self._sock.connect(address) # type: ignore[arg-type]
+ return self._sock.connect(address)
# It raised BlockingIOError, meaning that it's started the
# connection attempt. We wait for it to complete:
await _core.wait_writable(self._sock)
@@ -1205,7 +1213,7 @@ async def sendmsg(
__buffers,
__ancdata,
__flags,
- __address, # type: ignore[arg-type]
+ __address,
)
################################################################
diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py
index 776920dade..b6cfb55732 100644
--- a/trio/_tests/test_highlevel_open_tcp_listeners.py
+++ b/trio/_tests/test_highlevel_open_tcp_listeners.py
@@ -5,7 +5,7 @@
import sys
from math import inf
from socket import AddressFamily, SocketKind
-from typing import TYPE_CHECKING, Sequence, overload
+from typing import TYPE_CHECKING, Any, Sequence, overload
import attr
import pytest
@@ -140,7 +140,7 @@ def family(self) -> AddressFamily:
return self._family
@property
- def proto(self) -> int:
+ def proto(self) -> int: # pragma: no cover
return self._proto
@overload
@@ -176,7 +176,7 @@ def setsockopt(
) -> None:
pass
- async def bind(self, address: AddressWithNoneHost) -> None:
+ async def bind(self, address: Any) -> None:
pass
def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
@@ -289,7 +289,7 @@ async def handler(stream: SendStream) -> None:
async with trio.open_nursery() as nursery:
# nursery.start is incorrectly typed, awaiting #2773
- listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) # type: ignore[arg-type]
+ listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
stream = await open_stream_to_socket_listener(listeners[0])
async with stream:
await stream.receive_some(1) == b"x"
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index 0ea305b5fb..3992c83630 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -169,7 +169,7 @@ async def test_local_address_real() -> None:
await client_stream.aclose()
server_sock.close()
# accept returns tuple[SocketType, object], due to typeshed returning `Any`
- assert remote_addr[0] == local_address # type: ignore[index]
+ assert remote_addr[0] == local_address
# Trying to connect to an ipv4 address with the ipv6 wildcard
# local_address should fail
@@ -206,14 +206,14 @@ def type(self) -> SocketKind:
return self._type
@property
- def family(self) -> AddressFamily:
+ def family(self) -> AddressFamily: # pragma: no cover
return self._family
@property
- def proto(self) -> int:
+ def proto(self) -> int: # pragma: no cover
return self._proto
- async def connect(self, sockaddr: Address) -> None:
+ async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None:
self.ip = sockaddr[0]
self.port = sockaddr[1]
assert self.ip not in self.scenario.sockets
diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py
index c5d46b6c6a..830a153c00 100644
--- a/trio/_tests/test_highlevel_socket.py
+++ b/trio/_tests/test_highlevel_socket.py
@@ -222,10 +222,12 @@ def getsockopt(self, /, level: int, optname: int) -> int:
...
@overload
- def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
+ def getsockopt( # noqa: F811
+ self, /, level: int, optname: int, buflen: int
+ ) -> bytes:
...
- def getsockopt(
+ def getsockopt( # noqa: F811
self, /, level: int, optname: int, buflen: int | None = None
) -> int | bytes:
return True
@@ -235,12 +237,12 @@ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
...
@overload
- def setsockopt(
+ def setsockopt( # noqa: F811
self, /, level: int, optname: int, value: None, optlen: int
) -> None:
...
- def setsockopt(
+ def setsockopt( # noqa: F811
self,
/,
level: int,
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index b405e97e72..40ffefb8cd 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -289,7 +289,7 @@ async def child(sock: SocketType) -> None:
@pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only")
async def test_fromshare() -> None:
- if TYPE_CHECKING and sys.platform != "win32":
+ if TYPE_CHECKING and sys.platform != "win32": # pragma: no cover
return
a, b = tsocket.socketpair()
with a, b:
@@ -585,12 +585,8 @@ async def res(
| tuple[str, str]
| tuple[str, str, int]
| tuple[str, str, int, int]
- ) -> tuple[str, int] | tuple[str, int, int, int]:
- # we're only passing IP sockets, so we ignore the str/bytes return type
- # But what about when port/family is a string? Should that be part of the public API?
- res = await sock._resolve_address_nocp(args, local=local) # type: ignore[arg-type]
- # no str/bytes
- return res # type: ignore[return-value]
+ ) -> Any:
+ return await sock._resolve_address_nocp(args, local=local)
assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80))
if v6:
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index 2a358119f3..f0ab3f53e4 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -11,7 +11,7 @@
import errno
import ipaddress
import os
-from typing import TYPE_CHECKING, Optional, Union, Type
+from typing import TYPE_CHECKING, Optional, Type, Union
import attr
@@ -21,9 +21,10 @@
if TYPE_CHECKING:
from socket import AddressFamily, SocketKind
from types import TracebackType
+
from typing_extensions import TypeAlias
-IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
+IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
def _family_for(ip: IPAddress) -> int:
@@ -175,7 +176,9 @@ def deliver_packet(self, packet) -> None:
class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
- def __init__(self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int):
+ def __init__(
+ self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int
+ ):
self._fake_net = fake_net
if not family:
@@ -200,12 +203,15 @@ def __init__(self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, p
# This is the source-of-truth for what port etc. this socket is bound to
self._binding: Optional[UDPBinding] = None
+
@property
def type(self) -> SocketKind:
return self._type
+
@property
def family(self) -> AddressFamily:
return self._family
+
@property
def proto(self) -> int:
return self._proto
From b9bbc06868795255562e97e07988b7066ccb9aef Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 16 Sep 2023 14:35:07 +0200
Subject: [PATCH 13/20] fix verifytypes
---
trio/_tests/verify_types_darwin.json | 2 +-
trio/_tests/verify_types_linux.json | 2 +-
trio/_tests/verify_types_windows.json | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json
index 507af26b3d..d98fbacaf0 100644
--- a/trio/_tests/verify_types_darwin.json
+++ b/trio/_tests/verify_types_darwin.json
@@ -76,7 +76,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 687,
+ "withKnownType": 686,
"withUnknownType": 0
},
"packageName": "trio"
diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json
index 61ce9afcc5..d62d04db82 100644
--- a/trio/_tests/verify_types_linux.json
+++ b/trio/_tests/verify_types_linux.json
@@ -64,7 +64,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 687,
+ "withKnownType": 686,
"withUnknownType": 0
},
"packageName": "trio"
diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json
index e825282bf0..0c33864b3a 100644
--- a/trio/_tests/verify_types_windows.json
+++ b/trio/_tests/verify_types_windows.json
@@ -180,7 +180,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 674,
+ "withKnownType": 673,
"withUnknownType": 0
},
"packageName": "trio"
From 76ab99427f893e98d7b3f2037f6ec9a05646bb8b Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Wed, 20 Sep 2023 15:57:24 +0200
Subject: [PATCH 14/20] fix suggestions from teamspen
---
trio/_tests/test_highlevel_open_tcp_stream.py | 2 +-
trio/testing/_fake_net.py | 6 +++---
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py
index 3992c83630..eb7929c2db 100644
--- a/trio/_tests/test_highlevel_open_tcp_stream.py
+++ b/trio/_tests/test_highlevel_open_tcp_stream.py
@@ -377,7 +377,7 @@ async def run_scenario(
return (exc, scenario)
-async def test_one_host_quick_success(autojump_clock: trio.testing.MockClock) -> None:
+async def test_one_host_quick_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
assert isinstance(sock, FakeSocket)
assert sock.ip == "1.2.3.4"
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
index f0ab3f53e4..fe886ba131 100644
--- a/trio/testing/_fake_net.py
+++ b/trio/testing/_fake_net.py
@@ -8,10 +8,11 @@
from __future__ import annotations
+import builtins
import errno
import ipaddress
import os
-from typing import TYPE_CHECKING, Optional, Type, Union
+from typing import TYPE_CHECKING, Optional, Union
import attr
@@ -378,8 +379,7 @@ def __enter__(self):
def __exit__(
self,
- # builtin `type` is shadowed by the property
- exc_type: Type[BaseException] | None,
+ exc_type: builtins.type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
From 0c294ea1729300424d3e084dcb303ab78c07c858 Mon Sep 17 00:00:00 2001
From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
Date: Fri, 22 Sep 2023 19:16:46 -0500
Subject: [PATCH 15/20] Update `verify_types`
---
trio/_tests/verify_types_darwin.json | 2 +-
trio/_tests/verify_types_linux.json | 2 +-
trio/_tests/verify_types_windows.json | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json
index e83a324714..acb4f58c07 100644
--- a/trio/_tests/verify_types_darwin.json
+++ b/trio/_tests/verify_types_darwin.json
@@ -76,7 +76,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 685,
+ "withKnownType": 684,
"withUnknownType": 0
},
"packageName": "trio"
diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json
index 7c9d745dba..c060b2a51e 100644
--- a/trio/_tests/verify_types_linux.json
+++ b/trio/_tests/verify_types_linux.json
@@ -64,7 +64,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 685,
+ "withKnownType": 684,
"withUnknownType": 0
},
"packageName": "trio"
diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json
index 3da2740491..249bb179ae 100644
--- a/trio/_tests/verify_types_windows.json
+++ b/trio/_tests/verify_types_windows.json
@@ -180,7 +180,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 672,
+ "withKnownType": 671,
"withUnknownType": 0
},
"packageName": "trio"
From dc09e7eeb7ed5f6969d31920b7eeb1f80ad581fd Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 26 Sep 2023 09:44:51 +0200
Subject: [PATCH 16/20] clarifying comments after a5rocks review
---
trio/_socket.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index 84059f5ba3..77731708ba 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -46,8 +46,9 @@
# where you specify the return values you expect from those methods depending on the
# protocol the socket will be handling.
# But without the ability to default the value to `Any` it will be overly cumbersome for
-# most users, so currently we just specify it as `Any`.
-# AddressFormat = TypeVar("AddressFormat")
+# most users, so currently we just specify it as `Any`. Otherwise we would write:
+# `AddressFormat = TypeVar("AddressFormat")`
+# but instead we simply do:
AddressFormat: TypeAlias = Any
@@ -521,6 +522,9 @@ async def _resolve_address_nocp(
class SocketType:
def __init__(self) -> None:
+ # make sure this __init__ works with multiple inheritance
+ super().__init__()
+ # and only raises error if it's directly constructed
if type(self) == SocketType:
raise TypeError(
"SocketType is an abstract class; use trio.socket.socket if you "
@@ -640,6 +644,7 @@ async def accept(self) -> tuple[SocketType, AddressFormat]:
async def connect(self, address: AddressFormat) -> None:
raise NotImplementedError
+ # argument names with __ used because of typeshed, see comment for recv in _SocketType
def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
raise NotImplementedError
From d934724af349736ba3e0a15cd17e7601c77c75f7 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 26 Sep 2023 11:47:02 +0200
Subject: [PATCH 17/20] fix verify_types ............
---
trio/_tests/verify_types_windows.json | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json
index bab6797c02..232b18757c 100644
--- a/trio/_tests/verify_types_windows.json
+++ b/trio/_tests/verify_types_windows.json
@@ -60,7 +60,7 @@
],
"exportedSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 631,
+ "withKnownType": 630,
"withUnknownType": 0
},
"ignoreUnknownTypesFromImports": true,
@@ -96,7 +96,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 0,
- "withKnownType": 680,
+ "withKnownType": 676,
"withUnknownType": 0
},
"packageName": "trio"
From 73351845510d14008fb436f1099e90c242461208 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 26 Sep 2023 13:46:04 +0200
Subject: [PATCH 18/20] add @_public to exclude_lines in coveragerc
---
.coveragerc | 1 +
1 file changed, 1 insertion(+)
diff --git a/.coveragerc b/.coveragerc
index 4911012653..5f6bcb3e11 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -26,6 +26,7 @@ exclude_lines =
@overload
class .*\bProtocol\b.*\):
raise NotImplementedError
+ @_public
partial_branches =
pragma: no branch
From 618fdbf932ed86a796c0e1da80b55cc3ad291b73 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 26 Sep 2023 13:57:49 +0200
Subject: [PATCH 19/20] Revert "add @_public to exclude_lines in coveragerc"
This reverts commit 73351845510d14008fb436f1099e90c242461208.
---
.coveragerc | 1 -
1 file changed, 1 deletion(-)
diff --git a/.coveragerc b/.coveragerc
index 5f6bcb3e11..4911012653 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -26,7 +26,6 @@ exclude_lines =
@overload
class .*\bProtocol\b.*\):
raise NotImplementedError
- @_public
partial_branches =
pragma: no branch
From a20a909f29c390c09081e7246649854a345ef36e Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 28 Sep 2023 14:05:26 +0200
Subject: [PATCH 20/20] 'pragma: no cover' an added type-converting line
---
trio/_tests/test_highlevel_open_tcp_listeners.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py
index b6cfb55732..6f39446c7e 100644
--- a/trio/_tests/test_highlevel_open_tcp_listeners.py
+++ b/trio/_tests/test_highlevel_open_tcp_listeners.py
@@ -204,8 +204,8 @@ def socket(
) -> tsocket.SocketType:
assert family is not None
assert type is not None
- if isinstance(family, int):
- family = AddressFamily(family)
+ if isinstance(family, int) and not isinstance(family, AddressFamily):
+ family = AddressFamily(family) # pragma: no cover
if family in self.raise_on_family:
raise OSError(self.raise_on_family[family], "nope")
sock = FakeSocket(family, type, proto)