diff --git a/pyproject.toml b/pyproject.toml index ac3e1a3ea5..90bc98f64f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ module = [ "trio._deprecate", "trio._dtls", "trio._file_io", + "trio._highlevel_open_tcp_stream.py", "trio._ki", "trio._socket", "trio._sync", diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index a2477104d9..0c4e8a4a8d 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import sys +from collections.abc import Generator from contextlib import contextmanager +from socket import AddressFamily, SocketKind +from typing import TYPE_CHECKING import trio from trio._core._multierror import MultiError -from trio.socket import SOCK_STREAM, getaddrinfo, socket +from trio.socket import SOCK_STREAM, Address, _SocketType, getaddrinfo, socket if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -109,8 +114,8 @@ @contextmanager -def close_all(): - sockets_to_close = set() +def close_all() -> Generator[set[_SocketType], None, None]: + sockets_to_close: set[_SocketType] = set() try: yield sockets_to_close finally: @@ -126,7 +131,17 @@ def close_all(): raise MultiError(errs) -def reorder_for_rfc_6555_section_5_4(targets): +def reorder_for_rfc_6555_section_5_4( + targets: list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ] +) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address # families (e.g. IPv4 and IPv6), then you should make sure that your first # and second attempts use different families: @@ -144,7 +159,7 @@ def reorder_for_rfc_6555_section_5_4(targets): break -def format_host_port(host, port): +def format_host_port(host: str | bytes, port: int) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return f"[{host}]:{port}" @@ -173,8 +188,12 @@ def format_host_port(host, port): # AF_INET6: "..."} # this might be simpler after async def open_tcp_stream( - host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None -): + host: str | bytes, + port: int, + *, + happy_eyeballs_delay: float | None = DEFAULT_DELAY, + local_address: str | None = None, +) -> trio.abc.Stream: """Connect to the given host and port over TCP. If the given ``host`` has multiple IP addresses associated with it, then @@ -212,9 +231,9 @@ async def open_tcp_stream( port (int): The port to connect to. - happy_eyeballs_delay (float): How many seconds to wait for each + happy_eyeballs_delay (float or None): How many seconds to wait for each connection attempt to succeed or fail before getting impatient and - starting another one in parallel. Set to `math.inf` if you want + starting another one in parallel. Set to `None` if you want to limit to only one connection attempt at a time (like :func:`socket.create_connection`). Default: 0.25 (250 ms). @@ -247,9 +266,8 @@ async def open_tcp_stream( # To keep our public API surface smaller, rule out some cases that # getaddrinfo will accept in some circumstances, but that act weird or # have non-portable behavior or are just plain not useful. - # No type check on host though b/c we want to allow bytes-likes. - if host is None: - raise ValueError("host cannot be None") + if not isinstance(host, (str, bytes)): + raise ValueError(f"host must be str or bytes, not {host!r}") if not isinstance(port, int): raise TypeError(f"port must be int, not {port!r}") @@ -274,7 +292,7 @@ async def open_tcp_stream( # Keeps track of the socket that we're going to complete with, # need to make sure this isn't automatically closed - winning_socket = None + winning_socket: _SocketType | None = None # Try connecting to the specified address. Possible outcomes: # - success: record connected socket in winning_socket and cancel @@ -283,7 +301,11 @@ async def open_tcp_stream( # the next connection attempt to start early # code needs to ensure sockets can be closed appropriately in the # face of crash or cancellation - async def attempt_connect(socket_args, sockaddr, attempt_failed): + async def attempt_connect( + socket_args: tuple[AddressFamily, SocketKind, int], + sockaddr: Address, + attempt_failed: trio.Event, + ) -> None: nonlocal winning_socket try: @@ -334,7 +356,7 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed): except OSError: raise OSError( f"local_address={local_address!r} is incompatible " - f"with remote address {sockaddr}" + f"with remote address {sockaddr!r}" ) await sock.connect(sockaddr) @@ -355,12 +377,23 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed): # nursery spawns a task for each connection attempt, will be # cancelled by the task that gets a successful connection async with trio.open_nursery() as nursery: - for *sa, _, addr in targets: + for address_family, socket_type, proto, _, addr in targets: # create an event to indicate connection failure, # allowing the next target to be tried early attempt_failed = trio.Event() - nursery.start_soon(attempt_connect, sa, addr, attempt_failed) + # workaround to check types until typing of nursery.start_soon improved + if TYPE_CHECKING: + await attempt_connect( + (address_family, socket_type, proto), addr, attempt_failed + ) + + nursery.start_soon( + attempt_connect, + (address_family, socket_type, proto), + addr, + attempt_failed, + ) # give this attempt at most this time before moving on with trio.move_on_after(happy_eyeballs_delay): diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index 4dbd256dcf..ac2cfbd197 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9137380191693291, + "completenessScore": 0.9154704944178629, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 572, - "withUnknownType": 54 + "withKnownType": 574, + "withUnknownType": 53 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -109,7 +109,6 @@ "trio.open_ssl_over_tcp_listeners", "trio.open_ssl_over_tcp_stream", "trio.open_tcp_listeners", - "trio.open_tcp_stream", "trio.open_unix_socket", "trio.run", "trio.run_process", diff --git a/trio/socket.py b/trio/socket.py index f6aebb6a6e..f8d0bc3fc2 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -34,6 +34,7 @@ # import the overwrites from ._socket import ( + Address as Address, SocketType as SocketType, _SocketType as _SocketType, from_stdlib_socket as from_stdlib_socket,