diff --git a/.coveragerc b/.coveragerc index de12642263..4911012653 100644 --- a/.coveragerc +++ b/.coveragerc @@ -25,6 +25,7 @@ exclude_lines = if t.TYPE_CHECKING: @overload class .*\bProtocol\b.*\): + raise NotImplementedError partial_branches = pragma: no branch diff --git a/docs/source/conf.py b/docs/source/conf.py index 66aa8dea05..ee23ce587f 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,8 +54,6 @@ ("py:class", "sync function"), # why aren't these found in stdlib? ("py:class", "types.FrameType"), - # TODO: temporary type - ("py:class", "_SocketType"), # these are not defined in https://docs.python.org/3/objects.inv ("py:class", "socket.AddressFamily"), ("py:class", "socket.SocketKind"), diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index e270033b46..e234d255c1 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -504,13 +504,6 @@ Socket objects * :meth:`~socket.socket.set_inheritable` * :meth:`~socket.socket.get_inheritable` -The internal SocketType -~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: _SocketType -.. - TODO: adding `:members:` here gives error due to overload+_wraps on `sendto` - TODO: rewrite ... all of the above when fixing _SocketType vs SocketType - .. currentmodule:: trio diff --git a/pyproject.toml b/pyproject.toml index d123d4f158..3fe1372e20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,15 +79,11 @@ module = [ "trio/_tests/test_exports", "trio/_tests/test_file_io", "trio/_tests/test_highlevel_generic", -"trio/_tests/test_highlevel_open_tcp_listeners", -"trio/_tests/test_highlevel_open_tcp_stream", "trio/_tests/test_highlevel_open_unix_stream", "trio/_tests/test_highlevel_serve_listeners", -"trio/_tests/test_highlevel_socket", "trio/_tests/test_highlevel_ssl_helpers", "trio/_tests/test_path", "trio/_tests/test_scheduler_determinism", -"trio/_tests/test_socket", "trio/_tests/test_ssl", "trio/_tests/test_subprocess", "trio/_tests/test_sync", diff --git a/trio/_abc.py b/trio/_abc.py index 746360c8f8..dbc7966cee 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -12,7 +12,7 @@ from typing_extensions import Self # both of these introduce circular imports if outside a TYPE_CHECKING guard - from ._socket import _SocketType + from ._socket import SocketType from .lowlevel import Task @@ -211,10 +211,10 @@ class SocketFactory(metaclass=ABCMeta): @abstractmethod def socket( self, - family: socket.AddressFamily | int | None = None, - type: socket.SocketKind | int | None = None, - proto: int | None = None, - ) -> _SocketType: + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, + ) -> SocketType: """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index 4dc2b59c98..af73bd21cf 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -3,17 +3,20 @@ # ************************************************************* from __future__ import annotations -import sys -from socket import socket from typing import TYPE_CHECKING from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from .._file_io import _HasFileNo + +import sys + assert not TYPE_CHECKING or sys.platform == "linux" -async def wait_readable(fd: (int | socket)) -> None: +async def wait_readable(fd: (int | _HasFileNo)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -21,7 +24,7 @@ async def wait_readable(fd: (int | socket)) -> None: raise RuntimeError("must be called from async context") -async def wait_writable(fd: (int | socket)) -> None: +async def wait_writable(fd: (int | _HasFileNo)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -29,7 +32,7 @@ async def wait_writable(fd: (int | socket)) -> None: raise RuntimeError("must be called from async context") -def notify_closing(fd: (int | socket)) -> None: +def notify_closing(fd: (int | _HasFileNo)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 9c8ca26ef3..a0883d0179 100644 --- a/trio/_core/_generated_io_kqueue.py +++ b/trio/_core/_generated_io_kqueue.py @@ -10,11 +10,11 @@ if TYPE_CHECKING: import select - from socket import socket from ._traps import Abort, RaiseCancelT from .. import _core + from .._file_io import _HasFileNo import sys @@ -51,7 +51,7 @@ async def wait_kevent( raise RuntimeError("must be called from async context") -async def wait_readable(fd: (int | socket)) -> None: +async def wait_readable(fd: (int | _HasFileNo)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) @@ -59,7 +59,7 @@ async def wait_readable(fd: (int | socket)) -> None: raise RuntimeError("must be called from async context") -async def wait_writable(fd: (int | socket)) -> None: +async def wait_writable(fd: (int | _HasFileNo)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) @@ -67,7 +67,7 @@ async def wait_writable(fd: (int | socket)) -> None: raise RuntimeError("must be called from async context") -def notify_closing(fd: (int | socket)) -> None: +def notify_closing(fd: (int | _HasFileNo)) -> None: locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 0d247cae64..a0373fb8fa 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -13,11 +13,10 @@ from ._wakeup_socketpair import WakeupSocketpair if TYPE_CHECKING: - from socket import socket - from typing_extensions import TypeAlias from .._core import Abort, RaiseCancelT + from .._file_io import _HasFileNo @attr.s(slots=True, eq=False) @@ -290,7 +289,7 @@ def _update_registrations(self, fd: int) -> None: if not wanted_flags: del self._registered[fd] - async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None: + async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None: if not isinstance(fd, int): fd = fd.fileno() waiters = self._registered[fd] @@ -309,15 +308,15 @@ def abort(_: RaiseCancelT) -> Abort: await _core.wait_task_rescheduled(abort) @_public - async def wait_readable(self, fd: int | socket) -> None: + async def wait_readable(self, fd: int | _HasFileNo) -> None: await self._epoll_wait(fd, "read_task") @_public - async def wait_writable(self, fd: int | socket) -> None: + async def wait_writable(self, fd: int | _HasFileNo) -> None: await self._epoll_wait(fd, "write_task") @_public - def notify_closing(self, fd: int | socket) -> None: + def notify_closing(self, fd: int | _HasFileNo) -> None: if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 56a6559091..4faa382eca 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -14,11 +14,10 @@ from ._wakeup_socketpair import WakeupSocketpair if TYPE_CHECKING: - from socket import socket - from typing_extensions import TypeAlias from .._core import Abort, RaiseCancelT, Task, UnboundedQueue + from .._file_io import _HasFileNo assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") @@ -149,7 +148,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: # wait_task_rescheduled does not have its return type typed return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return] - async def _wait_common(self, fd: int | socket, filter: int) -> None: + async def _wait_common(self, fd: int | _HasFileNo, filter: int) -> None: if not isinstance(fd, int): fd = fd.fileno() flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT @@ -181,15 +180,15 @@ def abort(_: RaiseCancelT) -> Abort: await self.wait_kevent(fd, filter, abort) @_public - async def wait_readable(self, fd: int | socket) -> None: + async def wait_readable(self, fd: int | _HasFileNo) -> None: await self._wait_common(fd, select.KQ_FILTER_READ) @_public - async def wait_writable(self, fd: int | socket) -> None: + async def wait_writable(self, fd: int | _HasFileNo) -> None: await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public - def notify_closing(self, fd: int | socket) -> None: + def notify_closing(self, fd: int | _HasFileNo) -> None: if not isinstance(fd, int): fd = fd.fileno() diff --git a/trio/_dtls.py b/trio/_dtls.py index 4a244ecca8..d6baebafbc 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -43,26 +43,26 @@ from OpenSSL.SSL import Context from typing_extensions import Self, TypeAlias - from trio.socket import Address, _SocketType + from trio.socket import SocketType MAX_UDP_PACKET_SIZE = 65527 -def packet_header_overhead(sock: _SocketType) -> int: +def packet_header_overhead(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 28 else: return 48 -def worst_case_mtu(sock: _SocketType) -> int: +def worst_case_mtu(sock: SocketType) -> int: if sock.family == trio.socket.AF_INET: return 576 - packet_header_overhead(sock) else: return 1280 - packet_header_overhead(sock) -def best_guess_mtu(sock: _SocketType) -> int: +def best_guess_mtu(sock: SocketType) -> int: return 1500 - packet_header_overhead(sock) @@ -563,7 +563,7 @@ def _signable(*fields: bytes) -> bytes: def _make_cookie( - key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes + key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes ) -> bytes: assert len(salt) == SALT_BYTES assert len(key) == KEY_BYTES @@ -581,7 +581,7 @@ def _make_cookie( def valid_cookie( - key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes + key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes ) -> bool: if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] @@ -603,7 +603,7 @@ def valid_cookie( def challenge_for( - key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes + key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes ) -> bytes: salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() @@ -664,7 +664,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: async def handle_client_hello_untrusted( - endpoint: DTLSEndpoint, address: Address, packet: bytes + endpoint: DTLSEndpoint, address: Any, packet: bytes ) -> None: if endpoint._listening_context is None: return @@ -739,7 +739,7 @@ async def handle_client_hello_untrusted( async def dtls_receive_loop( - endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType + endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType ) -> None: try: while True: @@ -829,7 +829,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context): + def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -1180,7 +1180,7 @@ class DTLSEndpoint: """ - def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): + def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL @@ -1191,7 +1191,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") self._initialized = True - self.socket: _SocketType = socket + self.socket: SocketType = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() @@ -1200,7 +1200,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary() + self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary() self._listening_context: Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py index e6840eae97..019505c827 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/trio/_highlevel_open_tcp_listeners.py @@ -192,11 +192,12 @@ async def serve_tcp( connect to it to check that it's working properly, you can use something like:: + from trio import SocketListener, SocketStream from trio.testing import open_stream_to_socket_listener async with trio.open_nursery() as nursery: - listeners = await nursery.start(serve_tcp, handler, 0) - client_stream = await open_stream_to_socket_listener(listeners[0]) + listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) + client_stream: SocketStream = await open_stream_to_socket_listener(listeners[0]) # Then send and receive data on 'client_stream', for example: await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n") diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 0c4e8a4a8d..c3dc157827 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -4,11 +4,11 @@ from collections.abc import Generator from contextlib import contextmanager from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import trio from trio._core._multierror import MultiError -from trio.socket import SOCK_STREAM, Address, _SocketType, getaddrinfo, socket +from trio.socket import SOCK_STREAM, SocketType, getaddrinfo, socket if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -114,8 +114,8 @@ @contextmanager -def close_all() -> Generator[set[_SocketType], None, None]: - sockets_to_close: set[_SocketType] = set() +def close_all() -> Generator[set[SocketType], None, None]: + sockets_to_close: set[SocketType] = set() try: yield sockets_to_close finally: @@ -138,7 +138,7 @@ def reorder_for_rfc_6555_section_5_4( SocketKind, int, str, - tuple[str, int] | tuple[str, int, int, int], + Any, ] ] ) -> None: @@ -159,7 +159,7 @@ def reorder_for_rfc_6555_section_5_4( break -def format_host_port(host: str | bytes, port: int) -> str: +def format_host_port(host: str | bytes, port: int | str) -> str: host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: return f"[{host}]:{port}" @@ -193,7 +193,7 @@ async def open_tcp_stream( *, happy_eyeballs_delay: float | None = DEFAULT_DELAY, local_address: str | None = None, -) -> trio.abc.Stream: +) -> trio.SocketStream: """Connect to the given host and port over TCP. If the given ``host`` has multiple IP addresses associated with it, then @@ -292,7 +292,7 @@ async def open_tcp_stream( # Keeps track of the socket that we're going to complete with, # need to make sure this isn't automatically closed - winning_socket: _SocketType | None = None + winning_socket: SocketType | None = None # Try connecting to the specified address. Possible outcomes: # - success: record connected socket in winning_socket and cancel @@ -303,7 +303,7 @@ async def open_tcp_stream( # face of crash or cancellation async def attempt_connect( socket_args: tuple[AddressFamily, SocketKind, int], - sockaddr: Address, + sockaddr: Any, attempt_failed: trio.Event, ) -> None: nonlocal winning_socket diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index af7ed7278d..9ee8aa2249 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from typing_extensions import Buffer - from ._socket import _SocketType as SocketType + from ._socket import SocketType # XX TODO: this number was picked arbitrarily. We should do experiments to # tune it. (Or make it dynamic -- one idea is to start small and increase it diff --git a/trio/_socket.py b/trio/_socket.py index 2834a5b055..77731708ba 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -13,9 +13,7 @@ Awaitable, Callable, Literal, - NoReturn, SupportsIndex, - Tuple, TypeVar, Union, overload, @@ -40,10 +38,18 @@ T = TypeVar("T") -# must use old-style typing because it's evaluated at runtime -Address: TypeAlias = Union[ - str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int] -] +# _stdlib_socket.socket supports 13 different socket families, see +# https://docs.python.org/3/library/socket.html#socket-families +# and the return type of several methods in SocketType will depend on those. Typeshed +# has ended up typing those return types as `Any` in most cases, but for users that +# know which family/families they're working in we could make SocketType a generic type, +# where you specify the return values you expect from those methods depending on the +# protocol the socket will be handling. +# But without the ability to default the value to `Any` it will be overly cumbersome for +# most users, so currently we just specify it as `Any`. Otherwise we would write: +# `AddressFormat = TypeVar("AddressFormat")` +# but instead we simply do: +AddressFormat: TypeAlias = Any # Usage: @@ -173,11 +179,7 @@ async def getaddrinfo( flags: int = 0, ) -> list[ tuple[ - AddressFamily, - SocketKind, - int, - str, - tuple[str, int] | tuple[str, int, int, int], + AddressFamily, SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int] ] ]: """Look up a numeric address given a name. @@ -282,7 +284,7 @@ async def getprotobyname(name: str) -> int: ################################################################ -def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType: +def from_stdlib_socket(sock: _stdlib_socket.socket) -> SocketType: """Convert a standard library :class:`socket.socket` object into a Trio socket object. @@ -296,7 +298,7 @@ def fromfd( family: AddressFamily | int = _stdlib_socket.AF_INET, type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, -) -> _SocketType: +) -> SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -307,7 +309,7 @@ def fromfd( ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(info: bytes) -> _SocketType: + def fromshare(info: bytes) -> SocketType: return from_stdlib_socket(_stdlib_socket.fromshare(info)) @@ -326,7 +328,7 @@ def socketpair( family: FamilyT = FamilyDefault, type: TypeT = SocketKind.SOCK_STREAM, proto: int = 0, -) -> tuple[_SocketType, _SocketType]: +) -> tuple[SocketType, SocketType]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. @@ -341,7 +343,7 @@ def socket( type: SocketKind | int = _stdlib_socket.SOCK_STREAM, proto: int = 0, fileno: int | None = None, -) -> _SocketType: +) -> SocketType: """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using @@ -389,7 +391,7 @@ def _sniff_sockopts_for_fileno( ################################################################ -# _SocketType +# SocketType ################################################################ # sock.type gets weird stuff set in it, in particular on Linux: @@ -450,9 +452,9 @@ async def _resolve_address_nocp( proto: int, *, ipv6_v6only: bool | int, - address: Address, + address: AddressFormat, local: bool, -) -> Address: +) -> Any: # Do some pre-checking (or exit early for non-IP sockets) if family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -475,9 +477,9 @@ async def _resolve_address_nocp( # Fast path for the simple case: already-resolved IP address, # already-resolved port. This is particularly important for UDP, since # every sendto call goes through here. - if isinstance(port, int): + if isinstance(port, int) and host is not None: try: - _stdlib_socket.inet_pton(family, address[0]) + _stdlib_socket.inet_pton(family, host) except (OSError, TypeError): pass else: @@ -510,26 +512,217 @@ async def _resolve_address_nocp( if family == _stdlib_socket.AF_INET6: list_normed = list(normed) assert len(normed) == 4 - # typechecking certainly doesn't like this logic, but given just how broad - # Address is, it's quite cumbersome to write the below without type: ignore if len(address) >= 3: - list_normed[2] = address[2] # type: ignore + list_normed[2] = address[2] if len(address) >= 4: - list_normed[3] = address[3] # type: ignore - return tuple(list_normed) # type: ignore + list_normed[3] = address[3] + return tuple(list_normed) return normed -# TODO: stopping users from initializing this type should be done in a different way, -# so SocketType can be used as a type. Note that this is *far* from trivial without -# breaking subclasses of SocketType. Can maybe add abstract methods to SocketType, -# or rename _SocketType. class SocketType: - def __init__(self) -> NoReturn: - raise TypeError( - "SocketType is an abstract class; use trio.socket.socket if you " - "want to construct a socket object" - ) + def __init__(self) -> None: + # make sure this __init__ works with multiple inheritance + super().__init__() + # and only raises error if it's directly constructed + if type(self) == SocketType: + raise TypeError( + "SocketType is an abstract class; use trio.socket.socket if you " + "want to construct a socket object" + ) + + def detach(self) -> int: + raise NotImplementedError + + def fileno(self) -> int: + raise NotImplementedError + + def getpeername(self) -> AddressFormat: + raise NotImplementedError + + def getsockname(self) -> AddressFormat: + raise NotImplementedError + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + raise NotImplementedError + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + raise NotImplementedError + + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + raise NotImplementedError + + def get_inheritable(self) -> bool: + raise NotImplementedError + + def set_inheritable(self, inheritable: bool) -> None: + raise NotImplementedError + + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, /, process_id: int) -> bytes: + raise NotImplementedError + + def __enter__(self) -> Self: + raise NotImplementedError + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + raise NotImplementedError + + @property + def family(self) -> AddressFamily: + raise NotImplementedError + + @property + def type(self) -> SocketKind: + raise NotImplementedError + + @property + def proto(self) -> int: + raise NotImplementedError + + @property + def did_shutdown_SHUT_WR(self) -> bool: + raise NotImplementedError + + def __repr__(self) -> str: + raise NotImplementedError + + def dup(self) -> SocketType: + raise NotImplementedError + + def close(self) -> None: + raise NotImplementedError + + async def bind(self, address: AddressFormat) -> None: + raise NotImplementedError + + def shutdown(self, flag: int) -> None: + raise NotImplementedError + + def is_readable(self) -> bool: + raise NotImplementedError + + async def wait_writable(self) -> None: + raise NotImplementedError + + async def accept(self) -> tuple[SocketType, AddressFormat]: + raise NotImplementedError + + async def connect(self, address: AddressFormat) -> None: + raise NotImplementedError + + # argument names with __ used because of typeshed, see comment for recv in _SocketType + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: + raise NotImplementedError + + def recv_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + raise NotImplementedError + + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, AddressFormat]]: + raise NotImplementedError + + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, AddressFormat]]: + raise NotImplementedError + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg") + ): + + def recvmsg( + __self, + __bufsize: int, + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + raise NotImplementedError + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into") + ): + + def recvmsg_into( + __self, + __buffers: Iterable[Buffer], + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + raise NotImplementedError + + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + raise NotImplementedError + + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( + self, + __data: Buffer, + __flags: int, + __address: tuple[Any, ...] | str | Buffer, + ) -> int: + ... + + async def sendto(self, *args: Any) -> int: + raise NotImplementedError + + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") + ): + + @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: AddressFormat | None = None, + ) -> int: + raise NotImplementedError class _SocketType(SocketType): @@ -555,10 +748,10 @@ def detach(self) -> int: def fileno(self) -> int: return self._sock.fileno() - def getpeername(self) -> Any: + def getpeername(self) -> AddressFormat: return self._sock.getpeername() - def getsockname(self) -> Any: + def getsockname(self) -> AddressFormat: return self._sock.getsockname() @overload @@ -653,7 +846,7 @@ def did_shutdown_SHUT_WR(self) -> bool: def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self) -> _SocketType: + def dup(self) -> SocketType: """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) @@ -662,7 +855,7 @@ def close(self) -> None: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address: Address) -> None: + async def bind(self, address: AddressFormat) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -701,10 +894,10 @@ async def wait_writable(self) -> None: async def _resolve_address_nocp( self, - address: Address, + address: AddressFormat, *, local: bool, - ) -> Address: + ) -> AddressFormat: if self.family == _stdlib_socket.AF_INET6: ipv6_v6only = self._sock.getsockopt( _stdlib_socket.IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY @@ -772,7 +965,7 @@ async def _nonblocking_helper( _stdlib_socket.socket.accept, _core.wait_readable ) - async def accept(self) -> tuple[_SocketType, object]: + async def accept(self) -> tuple[SocketType, AddressFormat]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -781,7 +974,7 @@ async def accept(self) -> tuple[_SocketType, object]: # connect ################################################################ - async def connect(self, address: Address) -> None: + async def connect(self, address: AddressFormat) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -894,7 +1087,7 @@ def recv_into( # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] def recvfrom( __self, __bufsize: int, __flags: int = 0 - ) -> Awaitable[tuple[bytes, Address]]: + ) -> Awaitable[tuple[bytes, AddressFormat]]: ... recvfrom = _make_simple_sock_method_wrapper( # noqa: F811 @@ -909,7 +1102,7 @@ def recvfrom( # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] def recvfrom_into( __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[tuple[int, Address]]: + ) -> Awaitable[tuple[int, AddressFormat]]: ... recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811 @@ -920,7 +1113,9 @@ def recvfrom_into( # recvmsg ################################################################ - if hasattr(_stdlib_socket.socket, "recvmsg"): + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg") + ): if TYPE_CHECKING: def recvmsg( @@ -936,7 +1131,9 @@ def recvmsg( # recvmsg_into ################################################################ - if hasattr(_stdlib_socket.socket, "recvmsg_into"): + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "recvmsg_into") + ): if TYPE_CHECKING: def recvmsg_into( @@ -1005,7 +1202,7 @@ async def sendmsg( __buffers: Iterable[Buffer], __ancdata: Iterable[tuple[int, int, Buffer]] = (), __flags: int = 0, - __address: Address | None = None, + __address: AddressFormat | None = None, ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py index 6eca844f0c..6f39446c7e 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import errno import socket as stdlib_socket import sys from math import inf +from socket import AddressFamily, SocketKind +from typing import TYPE_CHECKING, Any, Sequence, overload import attr import pytest import trio from trio import SocketListener, open_tcp_listeners, open_tcp_stream, serve_tcp +from trio.abc import HostnameResolver, SendStream, SocketFactory from trio.testing import open_stream_to_socket_listener from .. import socket as tsocket @@ -16,8 +21,11 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup +if TYPE_CHECKING: + from typing_extensions import Buffer + -async def test_open_tcp_listeners_basic(): +async def test_open_tcp_listeners_basic() -> None: listeners = await open_tcp_listeners(0) assert isinstance(listeners, list) for obj in listeners: @@ -45,7 +53,7 @@ async def test_open_tcp_listeners_basic(): await resource.aclose() -async def test_open_tcp_listeners_specific_port_specific_host(): +async def test_open_tcp_listeners_specific_port_specific_host() -> None: # Pick a port sock = tsocket.socket() await sock.bind(("127.0.0.1", 0)) @@ -58,7 +66,7 @@ async def test_open_tcp_listeners_specific_port_specific_host(): @binds_ipv6 -async def test_open_tcp_listeners_ipv6_v6only(): +async def test_open_tcp_listeners_ipv6_v6only() -> None: # Check IPV6_V6ONLY is working properly (ipv6_listener,) = await open_tcp_listeners(0, host="::1") async with ipv6_listener: @@ -68,7 +76,7 @@ async def test_open_tcp_listeners_ipv6_v6only(): await open_tcp_stream("127.0.0.1", port) -async def test_open_tcp_listeners_rebind(): +async def test_open_tcp_listeners_rebind() -> None: (l1,) = await open_tcp_listeners(0, host="127.0.0.1") sockaddr1 = l1.socket.getsockname() @@ -115,43 +123,89 @@ class FakeOSError(OSError): @attr.s class FakeSocket(tsocket.SocketType): - family = attr.ib() - type = attr.ib() - proto = attr.ib() - - closed = attr.ib(default=False) - poison_listen = attr.ib(default=False) - backlog = attr.ib(default=None) - - def getsockopt(self, level, option): - if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): + _family: AddressFamily = attr.ib(converter=AddressFamily) + _type: SocketKind = attr.ib(converter=SocketKind) + _proto: int = attr.ib() + + closed: bool = attr.ib(default=False) + poison_listen: bool = attr.ib(default=False) + backlog: int | None = attr.ib(default=None) + + @property + def type(self) -> SocketKind: + return self._type + + @property + def family(self) -> AddressFamily: + return self._family + + @property + def proto(self) -> int: # pragma: no cover + return self._proto + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): return True assert False # pragma: no cover - def setsockopt(self, level, option, value): + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: pass - async def bind(self, sockaddr): + async def bind(self, address: Any) -> None: pass - def listen(self, backlog): + def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None: assert self.backlog is None assert backlog is not None self.backlog = backlog if self.poison_listen: raise FakeOSError("whoops") - def close(self): + def close(self) -> None: self.closed = True @attr.s -class FakeSocketFactory: - poison_after = attr.ib() - sockets = attr.ib(factory=list) - raise_on_family = attr.ib(factory=dict) # family => errno - - def socket(self, family, type, proto): +class FakeSocketFactory(SocketFactory): + poison_after: int = attr.ib() + sockets: list[tsocket.SocketType] = attr.ib(factory=list) + raise_on_family: dict[AddressFamily, int] = attr.ib(factory=dict) # family => errno + + def socket( + self, + family: AddressFamily | int | None = None, + type: SocketKind | int | None = None, + proto: int = 0, + ) -> tsocket.SocketType: + assert family is not None + assert type is not None + if isinstance(family, int) and not isinstance(family, AddressFamily): + family = AddressFamily(family) # pragma: no cover if family in self.raise_on_family: raise OSError(self.raise_on_family[family], "nope") sock = FakeSocket(family, type, proto) @@ -163,17 +217,39 @@ def socket(self, family, type, proto): @attr.s -class FakeHostnameResolver: - family_addr_pairs = attr.ib() - - async def getaddrinfo(self, host, port, family, type, proto, flags): +class FakeHostnameResolver(HostnameResolver): + family_addr_pairs: Sequence[tuple[AddressFamily, str]] = attr.ib() + + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: + assert isinstance(port, int) return [ (family, tsocket.SOCK_STREAM, 0, "", (addr, port)) for family, addr in self.family_addr_pairs ] + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: + raise NotImplementedError() + -async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): +async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None: # If we were trying to bind to multiple hosts and one of them failed, they # call get cleaned up before returning fsf = FakeSocketFactory(3) @@ -193,25 +269,27 @@ async def test_open_tcp_listeners_multiple_host_cleanup_on_error(): assert len(fsf.sockets) == 3 for sock in fsf.sockets: - assert sock.closed + # property only exists on FakeSocket + assert sock.closed # type: ignore[attr-defined] -async def test_open_tcp_listeners_port_checking(): +async def test_open_tcp_listeners_port_checking() -> None: for host in ["127.0.0.1", None]: with pytest.raises(TypeError): - await open_tcp_listeners(None, host=host) + await open_tcp_listeners(None, host=host) # type: ignore[arg-type] with pytest.raises(TypeError): - await open_tcp_listeners(b"80", host=host) + await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type] with pytest.raises(TypeError): - await open_tcp_listeners("http", host=host) + await open_tcp_listeners("http", host=host) # type: ignore[arg-type] -async def test_serve_tcp(): - async def handler(stream): +async def test_serve_tcp() -> None: + async def handler(stream: SendStream) -> None: await stream.send_all(b"x") async with trio.open_nursery() as nursery: - listeners = await nursery.start(serve_tcp, handler, 0) + # nursery.start is incorrectly typed, awaiting #2773 + listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) stream = await open_stream_to_socket_listener(listeners[0]) async with stream: await stream.receive_some(1) == b"x" @@ -227,8 +305,8 @@ async def handler(stream): [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], ) async def test_open_tcp_listeners_some_address_families_unavailable( - try_families, fail_families -): + try_families: set[AddressFamily], fail_families: set[AddressFamily] +) -> None: fsf = FakeSocketFactory( 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} ) @@ -257,7 +335,7 @@ async def test_open_tcp_listeners_some_address_families_unavailable( assert not should_succeed -async def test_open_tcp_listeners_socket_fails_not_afnosupport(): +async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None: fsf = FakeSocketFactory( 10, raise_on_family={ @@ -285,7 +363,7 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport(): # effectively is no backlog), sometimes the host might not be enough resources # to give us the full requested backlog... it was a mess. So now we just check # that the backlog argument is passed through correctly. -async def test_open_tcp_listeners_backlog(): +async def test_open_tcp_listeners_backlog() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) for given, expected in [ @@ -298,10 +376,11 @@ async def test_open_tcp_listeners_backlog(): listeners = await open_tcp_listeners(0, backlog=given) assert listeners for listener in listeners: - assert listener.socket.backlog == expected + # `backlog` only exists on FakeSocket + assert listener.socket.backlog == expected # type: ignore[attr-defined] -async def test_open_tcp_listeners_backlog_float_error(): +async def test_open_tcp_listeners_backlog_float_error() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) for should_fail in (0.0, 2.18, 3.14, 9.75): diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py index 24f82bddd5..eb7929c2db 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -1,5 +1,9 @@ +from __future__ import annotations + import socket import sys +from socket import AddressFamily, SocketKind +from typing import TYPE_CHECKING, Any, Sequence import attr import pytest @@ -11,24 +15,27 @@ open_tcp_stream, reorder_for_rfc_6555_section_5_4, ) -from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM +from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType + +if TYPE_CHECKING: + from trio.testing import MockClock if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup -def test_close_all(): - class CloseMe: +def test_close_all() -> None: + class CloseMe(SocketType): closed = False - def close(self): + def close(self) -> None: self.closed = True - class CloseKiller: - def close(self): + class CloseKiller(SocketType): + def close(self) -> None: raise OSError - c = CloseMe() + c: CloseMe = CloseMe() with close_all() as to_close: to_close.add(c) assert c.closed @@ -48,8 +55,10 @@ def close(self): assert c.closed -def test_reorder_for_rfc_6555_section_5_4(): - def fake4(i): +def test_reorder_for_rfc_6555_section_5_4() -> None: + def fake4( + i: int, + ) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]: return ( AF_INET, SOCK_STREAM, @@ -58,7 +67,9 @@ def fake4(i): (f"10.0.0.{i}", 80), ) - def fake6(i): + def fake6( + i: int, + ) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]: return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80)) for fake in fake4, fake6: @@ -85,7 +96,7 @@ def fake6(i): assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)] -def test_format_host_port(): +def test_format_host_port() -> None: assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80" assert format_host_port("example.com", 443) == "example.com:443" @@ -95,7 +106,7 @@ def test_format_host_port(): # Make sure we can connect to localhost using real kernel sockets -async def test_open_tcp_stream_real_socket_smoketest(): +async def test_open_tcp_stream_real_socket_smoketest() -> None: listen_sock = trio.socket.socket() await listen_sock.bind(("127.0.0.1", 0)) _, listen_port = listen_sock.getsockname() @@ -110,23 +121,24 @@ async def test_open_tcp_stream_real_socket_smoketest(): listen_sock.close() -async def test_open_tcp_stream_input_validation(): +async def test_open_tcp_stream_input_validation() -> None: with pytest.raises(ValueError): - await open_tcp_stream(None, 80) + await open_tcp_stream(None, 80) # type: ignore[arg-type] with pytest.raises(TypeError): - await open_tcp_stream("127.0.0.1", b"80") + await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type] -def can_bind_127_0_0_2(): +def can_bind_127_0_0_2() -> bool: with socket.socket() as s: try: s.bind(("127.0.0.2", 0)) except OSError: return False - return s.getsockname()[0] == "127.0.0.2" + # s.getsockname() is typed as returning Any + return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return] -async def test_local_address_real(): +async def test_local_address_real() -> None: with trio.socket.socket() as listener: await listener.bind(("127.0.0.1", 0)) listener.listen() @@ -153,10 +165,10 @@ async def test_local_address_real(): assert client_stream.socket.getsockopt( trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT ) - server_sock, remote_addr = await listener.accept() await client_stream.aclose() server_sock.close() + # accept returns tuple[SocketType, object], due to typeshed returning `Any` assert remote_addr[0] == local_address # Trying to connect to an ipv4 address with the ipv6 wildcard @@ -178,18 +190,30 @@ async def test_local_address_real(): @attr.s(eq=False) class FakeSocket(trio.socket.SocketType): - scenario = attr.ib() - family = attr.ib() - type = attr.ib() - proto = attr.ib() - - ip = attr.ib(default=None) - port = attr.ib(default=None) - succeeded = attr.ib(default=False) - closed = attr.ib(default=False) - failing = attr.ib(default=False) - - async def connect(self, sockaddr): + scenario: Scenario = attr.ib() + _family: AddressFamily = attr.ib() + _type: SocketKind = attr.ib() + _proto: int = attr.ib() + + ip: str | int | None = attr.ib(default=None) + port: str | int | None = attr.ib(default=None) + succeeded: bool = attr.ib(default=False) + closed: bool = attr.ib(default=False) + failing: bool = attr.ib(default=False) + + @property + def type(self) -> SocketKind: + return self._type + + @property + def family(self) -> AddressFamily: # pragma: no cover + return self._family + + @property + def proto(self) -> int: # pragma: no cover + return self._proto + + async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None: self.ip = sockaddr[0] self.port = sockaddr[1] assert self.ip not in self.scenario.sockets @@ -203,11 +227,11 @@ async def connect(self, sockaddr): self.failing = True self.succeeded = True - def close(self): + def close(self) -> None: self.closed = True # called when SocketStream is constructed - def setsockopt(self, *args, **kwargs): + def setsockopt(self, *args: object, **kwargs: object) -> None: if self.failing: # raise something that isn't OSError as SocketStream # ignores those @@ -215,11 +239,16 @@ def setsockopt(self, *args, **kwargs): class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver): - def __init__(self, port, ip_list, supported_families): + def __init__( + self, + port: int, + ip_list: Sequence[tuple[str, float, str]], + supported_families: set[AddressFamily], + ): # ip_list have to be unique ip_order = [ip for (ip, _, _) in ip_list] assert len(set(ip_order)) == len(ip_list) - ip_dict = {} + ip_dict: dict[str | int, tuple[float, str]] = {} for ip, delay, result in ip_list: assert 0 <= delay assert result in ["error", "success", "postconnect_fail"] @@ -230,16 +259,33 @@ def __init__(self, port, ip_list, supported_families): self.ip_dict = ip_dict self.supported_families = supported_families self.socket_count = 0 - self.sockets = {} - self.connect_times = {} - - def socket(self, family, type, proto): + self.sockets: dict[str | int, FakeSocket] = {} + self.connect_times: dict[str | int, float] = {} + + def socket( + self, + family: AddressFamily | int | None = None, + type: SocketKind | int | None = None, + proto: int | None = None, + ) -> SocketType: + assert isinstance(family, AddressFamily) + assert isinstance(type, SocketKind) + assert proto is not None if family not in self.supported_families: raise OSError("pretending not to support this family") self.socket_count += 1 return FakeSocket(self, family, type, proto) - def _ip_to_gai_entry(self, ip): + def _ip_to_gai_entry( + self, ip: str + ) -> tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int, int, int] | tuple[str, int], + ]: + sockaddr: tuple[str, int] | tuple[str, int, int, int] if ":" in ip: family = trio.socket.AF_INET6 sockaddr = (ip, self.port, 0, 0) @@ -248,7 +294,23 @@ def _ip_to_gai_entry(self, ip): sockaddr = (ip, self.port) return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr) - async def getaddrinfo(self, host, port, family, type, proto, flags): + async def getaddrinfo( + self, + host: str | bytes | None, + port: bytes | str | int | None, + family: int = -1, + type: int = -1, + proto: int = -1, + flags: int = -1, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int, int, int] | tuple[str, int], + ] + ]: assert host == b"test.example.com" assert port == self.port assert family == trio.socket.AF_UNSPEC @@ -257,10 +319,12 @@ async def getaddrinfo(self, host, port, family, type, proto, flags): assert flags == 0 return [self._ip_to_gai_entry(ip) for ip in self.ip_order] - async def getnameinfo(self, sockaddr, flags): # pragma: no cover + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: raise NotImplementedError - def check(self, succeeded): + def check(self, succeeded: SocketType | None) -> None: # sockets only go into self.sockets when connect is called; make sure # all the sockets that were created did in fact go in there. assert self.socket_count == len(self.sockets) @@ -274,24 +338,24 @@ def check(self, succeeded): async def run_scenario( # The port to connect to - port, + port: int, # A list of # (ip, delay, result) # tuples, where delay is in seconds and result is "success" or "error" # The ip's will be returned from getaddrinfo in this order, and then # connect() calls to them will have the given result. - ip_list, + ip_list: Sequence[tuple[str, float, str]], *, # If False, AF_INET4/6 sockets error out on creation, before connect is # even called. - ipv4_supported=True, - ipv6_supported=True, + ipv4_supported: bool = True, + ipv6_supported: bool = True, # Normally, we return (winning_sock, scenario object) # If this is True, we require there to be an exception, and return # (exception, scenario object) - expect_error=(), - **kwargs, -): + expect_error: tuple[type[BaseException], ...] | type[BaseException] = (), + **kwargs: Any, +) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]: supported_families = set() if ipv4_supported: supported_families.add(trio.socket.AF_INET) @@ -313,19 +377,21 @@ async def run_scenario( return (exc, scenario) -async def test_one_host_quick_success(autojump_clock): +async def test_one_host_quick_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")]) + assert isinstance(sock, FakeSocket) assert sock.ip == "1.2.3.4" assert trio.current_time() == 0.123 -async def test_one_host_slow_success(autojump_clock): +async def test_one_host_slow_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")]) + assert isinstance(sock, FakeSocket) assert sock.ip == "1.2.3.4" assert trio.current_time() == 100 -async def test_one_host_quick_fail(autojump_clock): +async def test_one_host_quick_fail(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 82, [("1.2.3.4", 0.123, "error")], expect_error=OSError ) @@ -333,7 +399,7 @@ async def test_one_host_quick_fail(autojump_clock): assert trio.current_time() == 0.123 -async def test_one_host_slow_fail(autojump_clock): +async def test_one_host_slow_fail(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 100, "error")], expect_error=OSError ) @@ -341,7 +407,7 @@ async def test_one_host_slow_fail(autojump_clock): assert trio.current_time() == 100 -async def test_one_host_failed_after_connect(autojump_clock): +async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt ) @@ -349,7 +415,7 @@ async def test_one_host_failed_after_connect(autojump_clock): # With the default 0.250 second delay, the third attempt will win -async def test_basic_fallthrough(autojump_clock): +async def test_basic_fallthrough(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -358,6 +424,7 @@ async def test_basic_fallthrough(autojump_clock): ("3.3.3.3", 0.2, "success"), ], ) + assert isinstance(sock, FakeSocket) assert sock.ip == "3.3.3.3" # current time is default time + default time + connection time assert trio.current_time() == (0.250 + 0.250 + 0.2) @@ -368,7 +435,7 @@ async def test_basic_fallthrough(autojump_clock): } -async def test_early_success(autojump_clock): +async def test_early_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -377,6 +444,7 @@ async def test_early_success(autojump_clock): ("3.3.3.3", 0.2, "success"), ], ) + assert isinstance(sock, FakeSocket) assert sock.ip == "2.2.2.2" assert trio.current_time() == (0.250 + 0.1) assert scenario.connect_times == { @@ -387,7 +455,7 @@ async def test_early_success(autojump_clock): # With a 0.450 second delay, the first attempt will win -async def test_custom_delay(autojump_clock): +async def test_custom_delay(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -397,6 +465,7 @@ async def test_custom_delay(autojump_clock): ], happy_eyeballs_delay=0.450, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "1.1.1.1" assert trio.current_time() == 1 assert scenario.connect_times == { @@ -406,7 +475,7 @@ async def test_custom_delay(autojump_clock): } -async def test_custom_errors_expedite(autojump_clock): +async def test_custom_errors_expedite(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -417,6 +486,7 @@ async def test_custom_errors_expedite(autojump_clock): ("4.4.4.4", 0.25, "success"), ], ) + assert isinstance(sock, FakeSocket) assert sock.ip == "4.4.4.4" assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25) assert scenario.connect_times == { @@ -427,7 +497,7 @@ async def test_custom_errors_expedite(autojump_clock): } -async def test_all_fail(autojump_clock): +async def test_all_fail(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario( 80, [ @@ -450,7 +520,7 @@ async def test_all_fail(autojump_clock): } -async def test_multi_success(autojump_clock): +async def test_multi_success(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -469,6 +539,7 @@ async def test_multi_success(autojump_clock): or scenario.sockets["4.4.4.4"].succeeded ) assert not scenario.sockets["5.5.5.5"].succeeded + assert isinstance(sock, FakeSocket) assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"] assert trio.current_time() == (0.5 + 10) assert scenario.connect_times == { @@ -480,7 +551,7 @@ async def test_multi_success(autojump_clock): } -async def test_does_reorder(autojump_clock): +async def test_does_reorder(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, [ @@ -492,6 +563,7 @@ async def test_does_reorder(autojump_clock): ], happy_eyeballs_delay=1, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "::3" assert trio.current_time() == 1 + 0.5 assert scenario.connect_times == { @@ -500,7 +572,7 @@ async def test_does_reorder(autojump_clock): } -async def test_handles_no_ipv4(autojump_clock): +async def test_handles_no_ipv4(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -514,6 +586,7 @@ async def test_handles_no_ipv4(autojump_clock): happy_eyeballs_delay=1, ipv4_supported=False, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "::3" assert trio.current_time() == 1 + 0.1 assert scenario.connect_times == { @@ -522,7 +595,7 @@ async def test_handles_no_ipv4(autojump_clock): } -async def test_handles_no_ipv6(autojump_clock): +async def test_handles_no_ipv6(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, # Here the ipv6 addresses fail at socket creation time, so the connect @@ -536,6 +609,7 @@ async def test_handles_no_ipv6(autojump_clock): happy_eyeballs_delay=1, ipv6_supported=False, ) + assert isinstance(sock, FakeSocket) assert sock.ip == "4.4.4.4" assert trio.current_time() == 1 + 0.1 assert scenario.connect_times == { @@ -544,12 +618,12 @@ async def test_handles_no_ipv6(autojump_clock): } -async def test_no_hosts(autojump_clock): +async def test_no_hosts(autojump_clock: MockClock) -> None: exc, scenario = await run_scenario(80, [], expect_error=OSError) assert "no results found" in str(exc) -async def test_cancel(autojump_clock): +async def test_cancel(autojump_clock: MockClock) -> None: with trio.move_on_after(5) as cancel_scope: exc, scenario = await run_scenario( 80, @@ -561,6 +635,7 @@ async def test_cancel(autojump_clock): ], expect_error=BaseExceptionGroup, ) + assert isinstance(exc, BaseException) # What comes out should be 1 or more Cancelled errors that all belong # to this cancel_scope; this is the easiest way to check that raise exc @@ -571,4 +646,4 @@ async def test_cancel(autojump_clock): # This should have been called already, but just to make sure, since the # exception-handling logic in run_scenario is a bit complicated and the # main thing we care about here is that all the sockets were cleaned up. - scenario.check(succeeded=False) + scenario.check(succeeded=None) diff --git a/trio/_tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py index 1a987df3f3..830a153c00 100644 --- a/trio/_tests/test_highlevel_socket.py +++ b/trio/_tests/test_highlevel_socket.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import errno import socket as stdlib_socket import sys +from typing import Sequence import pytest @@ -14,16 +17,17 @@ from .test_socket import setsockopt_tests -async def test_SocketStream_basics(): +async def test_SocketStream_basics() -> None: # stdlib socket bad (even if connected) - a, b = stdlib_socket.socketpair() - with a, b: + stdlib_a, stdlib_b = stdlib_socket.socketpair() + with stdlib_a, stdlib_b: with pytest.raises(TypeError): - SocketStream(a) + SocketStream(stdlib_a) # type: ignore[arg-type] # DGRAM socket bad with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock: with pytest.raises(ValueError): + # TODO: does not raise an error? SocketStream(sock) a, b = tsocket.socketpair() @@ -48,13 +52,13 @@ async def test_SocketStream_basics(): s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) - b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) - assert isinstance(b, bytes) + res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) + assert isinstance(res, bytes) setsockopt_tests(s) -async def test_SocketStream_send_all(): +async def test_SocketStream_send_all() -> None: BIG = 10000000 a_sock, b_sock = tsocket.socketpair() @@ -65,7 +69,7 @@ async def test_SocketStream_send_all(): # Check a send_all that has to be split into multiple parts (on most # platforms... on Windows every send() either succeeds or fails as a # whole) - async def sender(): + async def sender() -> None: data = bytearray(BIG) await a.send_all(data) # send_all uses memoryviews internally, which temporarily "lock" @@ -89,7 +93,7 @@ async def sender(): # and we break our implementation of send_all, then we'll get some # early warning...) - async def receiver(): + async def receiver() -> None: # Make sure the sender fills up the kernel buffers and blocks await wait_all_tasks_blocked() nbytes = 0 @@ -109,12 +113,12 @@ async def receiver(): assert await b.receive_some(10) == b"" -async def fill_stream(s): - async def sender(): +async def fill_stream(s: SocketStream) -> None: + async def sender() -> None: while True: await s.send_all(b"x" * 10000) - async def waiter(nursery): + async def waiter(nursery: _core.Nursery) -> None: await wait_all_tasks_blocked() nursery.cancel_scope.cancel() @@ -123,12 +127,12 @@ async def waiter(nursery): nursery.start_soon(waiter, nursery) -async def test_SocketStream_generic(): - async def stream_maker(): +async def test_SocketStream_generic() -> None: + async def stream_maker() -> tuple[SocketStream, SocketStream]: left, right = tsocket.socketpair() return SocketStream(left), SocketStream(right) - async def clogged_stream_maker(): + async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]: left, right = await stream_maker() await fill_stream(left) await fill_stream(right) @@ -137,13 +141,13 @@ async def clogged_stream_maker(): await check_half_closeable_stream(stream_maker, clogged_stream_maker) -async def test_SocketListener(): +async def test_SocketListener() -> None: # Not a Trio socket with stdlib_socket.socket() as s: s.bind(("127.0.0.1", 0)) s.listen(10) with pytest.raises(TypeError): - SocketListener(s) + SocketListener(s) # type: ignore[arg-type] # Not a SOCK_STREAM with tsocket.socket(type=tsocket.SOCK_DGRAM) as s: @@ -190,7 +194,7 @@ async def test_SocketListener(): await server_stream.aclose() -async def test_SocketListener_socket_closed_underfoot(): +async def test_SocketListener_socket_closed_underfoot() -> None: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(10) @@ -205,21 +209,50 @@ async def test_SocketListener_socket_closed_underfoot(): await listener.accept() -async def test_SocketListener_accept_errors(): +async def test_SocketListener_accept_errors() -> None: class FakeSocket(tsocket.SocketType): - def __init__(self, events): + def __init__(self, events: Sequence[SocketType | BaseException]): self._events = iter(events) type = tsocket.SOCK_STREAM # Fool the check for SO_ACCEPTCONN in SocketListener.__init__ - def getsockopt(self, level, opt): + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt( # noqa: F811 + self, /, level: int, optname: int, buflen: int + ) -> bytes: + ... + + def getsockopt( # noqa: F811 + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: return True - def setsockopt(self, level, opt, value): + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt( # noqa: F811 + self, /, level: int, optname: int, value: None, optlen: int + ) -> None: + ... + + def setsockopt( # noqa: F811 + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: pass - async def accept(self): + async def accept(self) -> tuple[SocketType, object]: await _core.checkpoint() event = next(self._events) if isinstance(event, BaseException): @@ -242,24 +275,24 @@ async def accept(self): ] ) - l = SocketListener(fake_listen_sock) + listener = SocketListener(fake_listen_sock) with assert_checkpoints(): - s = await l.accept() + s = await listener.accept() assert s.socket is fake_server_sock for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]: with assert_checkpoints(): with pytest.raises(OSError) as excinfo: - await l.accept() + await listener.accept() assert excinfo.value.errno == code with assert_checkpoints(): - s = await l.accept() + s = await listener.accept() assert s.socket is fake_server_sock -async def test_socket_stream_works_when_peer_has_already_closed(): +async def test_socket_stream_works_when_peer_has_already_closed() -> None: sock_a, sock_b = tsocket.socketpair() with sock_a, sock_b: await sock_b.send(b"x") diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py index 036098b8e5..40ffefb8cd 100644 --- a/trio/_tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -1,31 +1,51 @@ +from __future__ import annotations + import errno import inspect import os import socket as stdlib_socket import sys import tempfile +from socket import AddressFamily, SocketKind +from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union import attr import pytest from .. import _core, socket as tsocket from .._core._tests.tutil import binds_ipv6, creates_ipv6 -from .._socket import _NUMERIC_ONLY, _try_sync +from .._highlevel_socket import SocketStream +from .._socket import _NUMERIC_ONLY, SocketType, _SocketType, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + GaiTuple: TypeAlias = Tuple[ + AddressFamily, + SocketKind, + int, + str, + Union[Tuple[str, int], Tuple[str, int, int, int]], + ] + getaddrinfoResponse: TypeAlias = List[GaiTuple] +else: + GaiTuple: object + getaddrinfoResponse = object + ################################################################ # utils ################################################################ class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo): + def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]): self._orig_getaddrinfo = orig_getaddrinfo - self._responses = {} - self.record = [] + self._responses: dict[tuple[Any, ...], getaddrinfoResponse | str] = {} + self.record: list[tuple[Any, ...]] = [] # get a normalized getaddrinfo argument tuple - def _frozenbind(self, *args, **kwargs): + def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]: sig = inspect.signature(self._orig_getaddrinfo) bound = sig.bind(*args, **kwargs) bound.apply_defaults() @@ -33,10 +53,12 @@ def _frozenbind(self, *args, **kwargs): assert not bound.kwargs return frozenbound - def set(self, response, *args, **kwargs): + def set( + self, response: getaddrinfoResponse | str, *args: Any, **kwargs: Any + ) -> None: self._responses[self._frozenbind(*args, **kwargs)] = response - def getaddrinfo(self, *args, **kwargs): + def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse | str: bound = self._frozenbind(*args, **kwargs) self.record.append(bound) if bound in self._responses: @@ -48,13 +70,13 @@ def getaddrinfo(self, *args, **kwargs): @pytest.fixture -def monkeygai(monkeypatch): +def monkeygai(monkeypatch: pytest.MonkeyPatch) -> MonkeypatchedGAI: controller = MonkeypatchedGAI(stdlib_socket.getaddrinfo) monkeypatch.setattr(stdlib_socket, "getaddrinfo", controller.getaddrinfo) return controller -async def test__try_sync(): +async def test__try_sync() -> None: with assert_checkpoints(): async with _try_sync(): pass @@ -67,7 +89,7 @@ async def test__try_sync(): async with _try_sync(): raise BlockingIOError - def _is_ValueError(exc): + def _is_ValueError(exc: BaseException) -> bool: return isinstance(exc, ValueError) async with _try_sync(_is_ValueError): @@ -84,7 +106,7 @@ def _is_ValueError(exc): ################################################################ -def test_socket_has_some_reexports(): +def test_socket_has_some_reexports() -> None: assert tsocket.SOL_SOCKET == stdlib_socket.SOL_SOCKET assert tsocket.TCP_NODELAY == stdlib_socket.TCP_NODELAY assert tsocket.gaierror == stdlib_socket.gaierror @@ -96,19 +118,33 @@ def test_socket_has_some_reexports(): ################################################################ -async def test_getaddrinfo(monkeygai): - def check(got, expected): +async def test_getaddrinfo(monkeygai: MonkeypatchedGAI) -> None: + def check(got: getaddrinfoResponse, expected: getaddrinfoResponse) -> None: # win32 returns 0 for the proto field # musl and glibc have inconsistent handling of the canonical name # field (https://github.com/python-trio/trio/issues/1499) # Neither field gets used much and there isn't much opportunity for us # to mess them up, so we don't bother checking them here - def interesting_fields(gai_tup): + def interesting_fields( + gai_tup: GaiTuple, + ) -> tuple[ + AddressFamily, + SocketKind, + tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int], + ]: # (family, type, proto, canonname, sockaddr) family, type, proto, canonname, sockaddr = gai_tup return (family, type, sockaddr) - def filtered(gai_list): + def filtered( + gai_list: getaddrinfoResponse, + ) -> list[ + tuple[ + AddressFamily, + SocketKind, + tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int], + ] + ]: return [interesting_fields(gai_tup) for gai_tup in gai_list] assert filtered(got) == filtered(expected) @@ -172,7 +208,7 @@ def filtered(gai_list): await tsocket.getaddrinfo("asdf", "12345") -async def test_getnameinfo(): +async def test_getnameinfo() -> None: # Trivial test: ni_numeric = stdlib_socket.NI_NUMERICHOST | stdlib_socket.NI_NUMERICSERV with assert_checkpoints(): @@ -207,7 +243,7 @@ async def test_getnameinfo(): ################################################################ -async def test_from_stdlib_socket(): +async def test_from_stdlib_socket() -> None: sa, sb = stdlib_socket.socketpair() assert not isinstance(sa, tsocket.SocketType) with sa, sb: @@ -219,7 +255,7 @@ async def test_from_stdlib_socket(): # rejects other types with pytest.raises(TypeError): - tsocket.from_stdlib_socket(1) + tsocket.from_stdlib_socket(1) # type: ignore[arg-type] class MySocket(stdlib_socket.socket): pass @@ -229,7 +265,7 @@ class MySocket(stdlib_socket.socket): tsocket.from_stdlib_socket(mysock) -async def test_from_fd(): +async def test_from_fd() -> None: sa, sb = stdlib_socket.socketpair() ta = tsocket.fromfd(sa.fileno(), sa.family, sa.type, sa.proto) with sa, sb, ta: @@ -238,8 +274,8 @@ async def test_from_fd(): assert sb.recv(3) == b"x" -async def test_socketpair_simple(): - async def child(sock): +async def test_socketpair_simple() -> None: + async def child(sock: SocketType) -> None: print("sending hello") await sock.send(b"h") assert await sock.recv(1) == b"h" @@ -252,7 +288,9 @@ async def child(sock): @pytest.mark.skipif(not hasattr(tsocket, "fromshare"), reason="windows only") -async def test_fromshare(): +async def test_fromshare() -> None: + if TYPE_CHECKING and sys.platform != "win32": # pragma: no cover + return a, b = tsocket.socketpair() with a, b: # share with ourselves @@ -264,21 +302,21 @@ async def test_fromshare(): assert await b.recv(1) == b"x" -async def test_socket(): +async def test_socket() -> None: with tsocket.socket() as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET @creates_ipv6 -async def test_socket_v6(): +async def test_socket_v6() -> None: with tsocket.socket(tsocket.AF_INET6, tsocket.SOCK_DGRAM) as s: assert isinstance(s, tsocket.SocketType) assert s.family == tsocket.AF_INET6 @pytest.mark.skipif(not sys.platform == "linux", reason="linux only") -async def test_sniff_sockopts(): +async def test_sniff_sockopts() -> None: from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM # generate the combinations of families/types we're testing: @@ -309,7 +347,7 @@ async def test_sniff_sockopts(): ################################################################ -async def test_SocketType_basics(): +async def test_SocketType_basics() -> None: sock = tsocket.socket() with sock as cm_enter_value: assert cm_enter_value is sock @@ -349,7 +387,7 @@ async def test_SocketType_basics(): # our __getattr__ handles unknown names with pytest.raises(AttributeError): - sock.asdf + sock.asdf # type: ignore[attr-defined] # type family proto stdlib_sock = stdlib_socket.socket() @@ -366,7 +404,7 @@ async def test_SocketType_setsockopt() -> None: setsockopt_tests(sock) -def setsockopt_tests(sock): +def setsockopt_tests(sock: SocketType | SocketStream) -> None: """Extract these out, to be reused for SocketStream also.""" # specifying optlen. Not supported on pypy, and I couldn't find # valid calls on darwin or win32. @@ -378,14 +416,14 @@ def setsockopt_tests(sock): # specifying both with pytest.raises(TypeError, match="invalid value for argument 'value'"): - sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) # type: ignore[call-overload] # specifying neither with pytest.raises(TypeError, match="invalid value for argument 'value'"): - sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) # type: ignore[call-overload] -async def test_SocketType_dup(): +async def test_SocketType_dup() -> None: a, b = tsocket.socketpair() with a, b: a2 = a.dup() @@ -397,7 +435,7 @@ async def test_SocketType_dup(): assert await b.recv(1) == b"x" -async def test_SocketType_shutdown(): +async def test_SocketType_shutdown() -> None: a, b = tsocket.socketpair() with a, b: await a.send(b"x") @@ -431,7 +469,9 @@ async def test_SocketType_shutdown(): pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), ], ) -async def test_SocketType_simple_server(address, socket_type): +async def test_SocketType_simple_server( + address: str, socket_type: AddressFamily +) -> None: # listen, bind, accept, connect, getpeername, getsockname listener = tsocket.socket(socket_type) client = tsocket.socket(socket_type) @@ -448,7 +488,7 @@ async def test_SocketType_simple_server(address, socket_type): assert await client.recv(1) == b"x" -async def test_SocketType_is_readable(): +async def test_SocketType_is_readable() -> None: a, b = tsocket.socketpair() with a, b: assert not a.is_readable() @@ -462,7 +502,7 @@ async def test_SocketType_is_readable(): # On some macOS systems, getaddrinfo likes to return V4-mapped addresses even # when we *don't* pass AI_V4MAPPED. # https://github.com/python-trio/trio/issues/580 -def gai_without_v4mapped_is_buggy(): # pragma: no cover +def gai_without_v4mapped_is_buggy() -> bool: # pragma: no cover try: stdlib_socket.getaddrinfo("1.2.3.4", 0, family=stdlib_socket.AF_INET6) except stdlib_socket.gaierror: @@ -473,10 +513,10 @@ def gai_without_v4mapped_is_buggy(): # pragma: no cover @attr.s class Addresses: - bind_all = attr.ib() - localhost = attr.ib() - arbitrary = attr.ib() - broadcast = attr.ib() + bind_all: str = attr.ib() + localhost: str = attr.ib() + arbitrary: str = attr.ib() + broadcast: str = attr.ib() # Direct thorough tests of the implicit resolver helpers @@ -504,34 +544,49 @@ class Addresses: ), ], ) -async def test_SocketType_resolve(socket_type, addrs): +async def test_SocketType_resolve(socket_type: AddressFamily, addrs: Addresses) -> None: v6 = socket_type == tsocket.AF_INET6 - def pad(addr): + def pad(addr: tuple[str | int, ...]) -> tuple[str | int, ...]: if v6: while len(addr) < 4: addr += (0,) return addr - def assert_eq(actual, expected): + def assert_eq( + actual: tuple[str | int, ...], expected: tuple[str | int, ...] + ) -> None: assert pad(expected) == pad(actual) with tsocket.socket(family=socket_type) as sock: + # testing internal functionality, so we check it against the internal type + assert isinstance(sock, _SocketType) + # For some reason the stdlib special-cases "" to pass NULL to # getaddrinfo. They also error out on None, but whatever, None is much # more consistent, so we accept it too. + # TODO: this implies that we can send host=None, but what does that imply for the return value, and other stuff? for null in [None, ""]: got = await sock._resolve_address_nocp((null, 80), local=True) + assert not isinstance(got, (str, bytes)) assert_eq(got, (addrs.bind_all, 80)) got = await sock._resolve_address_nocp((null, 80), local=False) + assert not isinstance(got, (str, bytes)) assert_eq(got, (addrs.localhost, 80)) # AI_PASSIVE only affects the wildcard address, so for everything else # local=True/local=False should work the same: for local in [False, True]: - async def res(*args): - return await sock._resolve_address_nocp(*args, local=local) + async def res( + args: tuple[str, int] + | tuple[str, int, int] + | tuple[str, int, int, int] + | tuple[str, str] + | tuple[str, str, int] + | tuple[str, str, int, int] + ) -> Any: + return await sock._resolve_address_nocp(args, local=local) assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: @@ -582,6 +637,7 @@ async def res(*args): except (AttributeError, OSError): pass else: + assert isinstance(netlink_sock, _SocketType) assert ( await netlink_sock._resolve_address_nocp("asdf", local=local) == "asdf" @@ -589,17 +645,18 @@ async def res(*args): netlink_sock.close() with pytest.raises(ValueError): - await res("1.2.3.4") + await res("1.2.3.4") # type: ignore[arg-type] with pytest.raises(ValueError): - await res(("1.2.3.4",)) + await res(("1.2.3.4",)) # type: ignore[arg-type] with pytest.raises(ValueError): if v6: - await res(("1.2.3.4", 80, 0, 0, 0)) + await res(("1.2.3.4", 80, 0, 0, 0)) # type: ignore[arg-type] else: + # I guess in theory there could be enough overloads that this could error? await res(("1.2.3.4", 80, 0, 0)) -async def test_SocketType_unresolved_names(): +async def test_SocketType_unresolved_names() -> None: with tsocket.socket() as sock: await sock.bind(("localhost", 0)) assert sock.getsockname()[0] == "127.0.0.1" @@ -618,7 +675,7 @@ async def test_SocketType_unresolved_names(): # This tests all the complicated paths through _nonblocking_helper, using recv # as a stand-in for all the methods that use _nonblocking_helper. -async def test_SocketType_non_blocking_paths(): +async def test_SocketType_non_blocking_paths() -> None: a, b = stdlib_socket.socketpair() with a, b: ta = tsocket.from_stdlib_socket(a) @@ -638,10 +695,10 @@ async def test_SocketType_non_blocking_paths(): # immediate failure with assert_checkpoints(): with pytest.raises(TypeError): - await ta.recv("haha") + await ta.recv("haha") # type: ignore[arg-type] # block then succeed - async def do_successful_blocking_recv(): + async def do_successful_blocking_recv() -> None: with assert_checkpoints(): assert await ta.recv(10) == b"2" @@ -651,7 +708,7 @@ async def do_successful_blocking_recv(): b.send(b"2") # block then cancelled - async def do_cancelled_blocking_recv(): + async def do_cancelled_blocking_recv() -> None: with assert_checkpoints(): with pytest.raises(_core.Cancelled): await ta.recv(10) @@ -669,13 +726,13 @@ async def do_cancelled_blocking_recv(): # other: tb = tsocket.from_stdlib_socket(b) - async def t1(): + async def t1() -> None: with assert_checkpoints(): assert await ta.recv(1) == b"a" with assert_checkpoints(): assert await tb.recv(1) == b"b" - async def t2(): + async def t2() -> None: with assert_checkpoints(): assert await tb.recv(1) == b"b" with assert_checkpoints(): @@ -693,7 +750,7 @@ async def t2(): # This tests the complicated paths through connect -async def test_SocketType_connect_paths(): +async def test_SocketType_connect_paths() -> None: with tsocket.socket() as sock: with pytest.raises(ValueError): # Should be a tuple @@ -716,7 +773,10 @@ async def test_SocketType_connect_paths(): # nose -- and then swap it back out again before we hit # wait_socket_writable, which insists on a real socket. class CancelSocket(stdlib_socket.socket): - def connect(self, *args, **kwargs): + def connect(self, *args: Any, **kwargs: Any) -> None: + # accessing private method only available in _SocketType + assert isinstance(sock, _SocketType) + cancel_scope.cancel() sock._sock = stdlib_socket.fromfd( self.detach(), self.family, self.type @@ -725,6 +785,8 @@ def connect(self, *args, **kwargs): # If connect *doesn't* raise, then pretend it did raise BlockingIOError # pragma: no cover + # accessing private method only available in _SocketType + assert isinstance(sock, _SocketType) sock._sock.close() sock._sock = CancelSocket() @@ -747,7 +809,7 @@ def connect(self, *args, **kwargs): # Fix issue #1810 -async def test_address_in_socket_error(): +async def test_address_in_socket_error() -> None: address = "127.0.0.1" with tsocket.socket() as sock: try: @@ -756,23 +818,26 @@ async def test_address_in_socket_error(): assert any(address in str(arg) for arg in e.args) -async def test_resolve_address_exception_in_connect_closes_socket(): +async def test_resolve_address_exception_in_connect_closes_socket() -> None: # Here we are testing issue 247, any cancellation will leave the socket closed with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: - async def _resolve_address_nocp(self, *args, **kwargs): + async def _resolve_address_nocp( + self: Any, *args: Any, **kwargs: Any + ) -> None: cancel_scope.cancel() await _core.checkpoint() - sock._resolve_address_nocp = _resolve_address_nocp + assert isinstance(sock, _SocketType) + sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign, assignment] with assert_checkpoints(): with pytest.raises(_core.Cancelled): await sock.connect("") assert sock.fileno() == -1 -async def test_send_recv_variants(): +async def test_send_recv_variants() -> None: a, b = tsocket.socketpair() with a, b: # recv, including with flags @@ -868,7 +933,7 @@ async def test_send_recv_variants(): assert await b.recv(10) == b"yyy" -async def test_idna(monkeygai): +async def test_idna(monkeygai: MonkeypatchedGAI) -> None: # This is the encoding for "faß.de", which uses one of the characters that # IDNA 2003 handles incorrectly: monkeygai.set("ok faß.de", b"xn--fa-hia.de", 80) @@ -886,24 +951,29 @@ async def test_idna(monkeygai): assert "ok faß.de" == await tsocket.getaddrinfo(b"xn--fa-hia.de", 80) -async def test_getprotobyname(): +async def test_getprotobyname() -> None: # These are the constants used in IP header fields, so the numeric values # had *better* be stable across systems... assert await tsocket.getprotobyname("udp") == 17 assert await tsocket.getprotobyname("tcp") == 6 -async def test_custom_hostname_resolver(monkeygai): +async def test_custom_hostname_resolver(monkeygai: MonkeypatchedGAI) -> None: + # This intentionally breaks the signatures used in HostnameResolver class CustomResolver: - async def getaddrinfo(self, host, port, family, type, proto, flags): + async def getaddrinfo( + self, host: str, port: str, family: int, type: int, proto: int, flags: int + ) -> tuple[str, str, str, int, int, int, int]: return ("custom_gai", host, port, family, type, proto, flags) - async def getnameinfo(self, sockaddr, flags): + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, tuple[str, int] | tuple[str, int, int, int], int]: return ("custom_gni", sockaddr, flags) cr = CustomResolver() - assert tsocket.set_custom_hostname_resolver(cr) is None + assert tsocket.set_custom_hostname_resolver(cr) is None # type: ignore[arg-type] # Check that the arguments are all getting passed through. # We have to use valid calls to avoid making the underlying system @@ -926,7 +996,11 @@ async def getnameinfo(self, sockaddr, flags): expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0) assert got == expected - assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0) + assert await tsocket.getnameinfo("a", 0) == ( # type: ignore[arg-type] + "custom_gni", + "a", + 0, + ) # We can set it back to None assert tsocket.set_custom_hostname_resolver(None) is cr @@ -937,14 +1011,16 @@ async def getnameinfo(self, sockaddr, flags): assert await tsocket.getaddrinfo("host", "port") == "x" -async def test_custom_socket_factory(): +async def test_custom_socket_factory() -> None: class CustomSocketFactory: - def socket(self, family, type, proto): + def socket( + self, family: AddressFamily, type: SocketKind, proto: int + ) -> tuple[str, AddressFamily, SocketKind, int]: return ("hi", family, type, proto) csf = CustomSocketFactory() - assert tsocket.set_custom_socket_factory(csf) is None + assert tsocket.set_custom_socket_factory(csf) is None # type: ignore[arg-type] assert tsocket.socket() == ("hi", tsocket.AF_INET, tsocket.SOCK_STREAM, 0) assert tsocket.socket(1, 2, 3) == ("hi", 1, 2, 3) @@ -964,17 +1040,17 @@ def socket(self, family, type, proto): assert tsocket.set_custom_socket_factory(None) is csf -async def test_SocketType_is_abstract(): +async def test_SocketType_is_abstract() -> None: with pytest.raises(TypeError): tsocket.SocketType() @pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets") -async def test_unix_domain_socket(): +async def test_unix_domain_socket() -> None: # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. - async def check_AF_UNIX(path): + async def check_AF_UNIX(path: str | bytes) -> None: with tsocket.socket(family=tsocket.AF_UNIX) as lsock: await lsock.bind(path) lsock.listen(10) @@ -999,7 +1075,7 @@ async def check_AF_UNIX(path): pass -async def test_interrupted_by_close(): +async def test_interrupted_by_close() -> None: a_stdlib, b_stdlib = stdlib_socket.socketpair() with a_stdlib, b_stdlib: a_stdlib.setblocking(False) @@ -1014,11 +1090,11 @@ async def test_interrupted_by_close(): a = tsocket.from_stdlib_socket(a_stdlib) - async def sender(): + async def sender() -> None: with pytest.raises(_core.ClosedResourceError): await a.send(data) - async def receiver(): + async def receiver() -> None: with pytest.raises(_core.ClosedResourceError): await a.recv(1) @@ -1029,7 +1105,7 @@ async def receiver(): a.close() -async def test_many_sockets(): +async def test_many_sockets() -> None: total = 5000 # Must be more than MAX_AFD_GROUP_SIZE sockets = [] for x in range(total // 2): diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json index d6ba2e6c2e..acb4f58c07 100644 --- a/trio/_tests/verify_types_darwin.json +++ b/trio/_tests/verify_types_darwin.json @@ -40,7 +40,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 631, + "withKnownType": 630, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json index 5b1233d2b2..c060b2a51e 100644 --- a/trio/_tests/verify_types_linux.json +++ b/trio/_tests/verify_types_linux.json @@ -28,7 +28,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 628, + "withKnownType": 627, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index bab6797c02..232b18757c 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -60,7 +60,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 631, + "withKnownType": 630, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, @@ -96,7 +96,7 @@ ], "otherSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 680, + "withKnownType": 676, "withUnknownType": 0 }, "packageName": "trio" diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index b5612c59c3..40ac4c0602 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -306,7 +306,10 @@ def main() -> None: # pragma: no cover """ IMPORTS_EPOLL = """\ -from socket import socket +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .._file_io import _HasFileNo """ IMPORTS_KQUEUE = """\ @@ -314,11 +317,11 @@ def main() -> None: # pragma: no cover if TYPE_CHECKING: import select - from socket import socket from ._traps import Abort, RaiseCancelT from .. import _core + from .._file_io import _HasFileNo """ diff --git a/trio/socket.py b/trio/socket.py index f8d0bc3fc2..a9e276c782 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -34,9 +34,7 @@ # import the overwrites from ._socket import ( - Address as Address, SocketType as SocketType, - _SocketType as _SocketType, from_stdlib_socket as from_stdlib_socket, fromfd as fromfd, getaddrinfo as getaddrinfo, diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 7cd0c25225..74ce32d37f 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -8,6 +8,7 @@ from __future__ import annotations +import builtins import errno import ipaddress import os @@ -22,7 +23,9 @@ from socket import AddressFamily, SocketKind from types import TracebackType -IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + from typing_extensions import TypeAlias + +IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] def _family_for(ip: IPAddress) -> int: @@ -176,7 +179,9 @@ def deliver_packet(self, packet) -> None: @final class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor): - def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): + def __init__( + self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int + ): self._fake_net = fake_net if not family: @@ -189,9 +194,9 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): if type != trio.socket.SOCK_DGRAM: raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}") - self.family = family - self.type = type - self.proto = proto + self._family = family + self._type = type + self._proto = proto self._closed = False @@ -202,6 +207,18 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): # This is the source-of-truth for what port etc. this socket is bound to self._binding: Optional[UDPBinding] = None + @property + def type(self) -> SocketKind: + return self._type + + @property + def family(self) -> AddressFamily: + return self._family + + @property + def proto(self) -> int: + return self._proto + def _check_closed(self): if self._closed: _fake_err(errno.EBADF) @@ -364,7 +381,7 @@ def __enter__(self): def __exit__( self, - exc_type: type[BaseException] | None, + exc_type: builtins.type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: