From 83f22ec485c37ef6dc2a8021909911df9a6a6fd4 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 14 Jul 2023 17:53:03 +0200
Subject: [PATCH 01/16] adding typehints to _socket and friends
---
pyproject.toml | 7 +
trio/_core/_local.py | 58 ++++--
trio/_socket.py | 356 ++++++++++++++++++++++++++--------
trio/_sync.py | 16 +-
trio/_tests/verify_types.json | 34 +---
trio/_threads.py | 24 ++-
6 files changed, 352 insertions(+), 143 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index cfb4060ee7..1f6b15e45e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,13 @@ disallow_untyped_defs = false
# DO NOT use `ignore_errors`; it doesn't apply
# downstream and users have to deal with them.
+[[tool.mypy.overrides]]
+module = [
+ "trio._socket",
+ "trio._core._local",
+ "trio._sync",
+]
+disallow_untyped_defs = true
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index a54f424fdf..fe509ca7ad 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -1,25 +1,36 @@
+from __future__ import annotations
+
+from typing import Generic, TypeVar, overload
+
# Runvar implementations
import attr
from .._util import Final
from . import _run
+T = TypeVar("T")
+C = TypeVar("C", bound="_RunVarToken")
+
+
+class NoValue(object):
+ ...
+
@attr.s(eq=False, hash=False, slots=True)
-class _RunVarToken:
- _no_value = object()
+class _RunVarToken(Generic[T]):
+ _no_value = NoValue()
- _var = attr.ib()
- previous_value = attr.ib(default=_no_value)
- redeemed = attr.ib(default=False, init=False)
+ _var: RunVar[T] = attr.ib()
+ previous_value: T | NoValue = attr.ib(default=_no_value)
+ redeemed: bool = attr.ib(default=False, init=False)
@classmethod
- def empty(cls, var):
+ def empty(cls: type[C], var: RunVar[T]) -> C:
return cls(var)
@attr.s(eq=False, hash=False, slots=True)
-class RunVar(metaclass=Final):
+class RunVar(Generic[T], metaclass=Final):
"""The run-local variant of a context variable.
:class:`RunVar` objects are similar to context variable objects,
@@ -28,14 +39,23 @@ class RunVar(metaclass=Final):
"""
- _NO_DEFAULT = object()
- _name = attr.ib()
- _default = attr.ib(default=_NO_DEFAULT)
+ _NO_DEFAULT = NoValue()
+ _name: str = attr.ib()
+ _default: T | NoValue = attr.ib(default=_NO_DEFAULT)
+
+ @overload
+ def get(self, default: T) -> T:
+ ...
+
+ @overload
+ def get(self, default: NoValue = _NO_DEFAULT) -> T | NoValue:
+ ...
- def get(self, default=_NO_DEFAULT):
+ def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
- return _run.GLOBAL_RUN_CONTEXT.runner._locals[self]
+ # not typed yet
+ return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
@@ -48,7 +68,7 @@ def get(self, default=_NO_DEFAULT):
raise LookupError(self) from None
- def set(self, value):
+ def set(self, value: T) -> _RunVarToken[T]:
"""Sets the value of this :class:`RunVar` for this current run
call.
@@ -56,16 +76,16 @@ def set(self, value):
try:
old_value = self.get()
except LookupError:
- token = _RunVarToken.empty(self)
+ token: _RunVarToken[T] = _RunVarToken.empty(self)
else:
token = _RunVarToken(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index]
return token
- def reset(self, token):
+ def reset(self, token: _RunVarToken[T]) -> None:
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
@@ -82,13 +102,13 @@ def reset(self, token):
previous = token.previous_value
try:
if previous is _RunVarToken._no_value:
- _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
+ _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type]
else:
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
+ _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context")
token.redeemed = True
- def __repr__(self):
+ def __repr__(self) -> str:
return f""
diff --git a/trio/_socket.py b/trio/_socket.py
index eaf0e04d15..dc06e5c755 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -5,9 +5,22 @@
import socket as _stdlib_socket
import sys
from functools import wraps as _wraps
-from typing import TYPE_CHECKING
+from socket import AddressFamily, SocketKind
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ NoReturn,
+ SupportsIndex,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
import idna as _idna
+from typing_extensions import Concatenate, ParamSpec
import trio
@@ -17,7 +30,18 @@
from collections.abc import Iterable
from types import TracebackType
- from typing_extensions import Self
+ from typing_extensions import Buffer, Self, TypeAlias
+
+ from ._abc import HostnameResolver, SocketFactory
+
+
+T = TypeVar("T")
+P = ParamSpec("P")
+
+# must use old-style typing for TypeAlias
+Address: TypeAlias = Union[
+ str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int]
+]
# Usage:
@@ -29,16 +53,18 @@
# return await do_it_properly_with_a_check_point()
#
class _try_sync:
- def __init__(self, blocking_exc_override=None):
+ def __init__(
+ self, blocking_exc_override: Callable[[BaseException], bool] | None = None
+ ):
self._blocking_exc_override = blocking_exc_override
- def _is_blocking_io_error(self, exc):
+ def _is_blocking_io_error(self, exc: BaseException) -> bool:
if self._blocking_exc_override is None:
return isinstance(exc, BlockingIOError)
else:
return self._blocking_exc_override(exc)
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
await trio.lowlevel.checkpoint_if_cancelled()
async def __aexit__(
@@ -73,11 +99,13 @@ async def __aexit__(
# Overrides
################################################################
-_resolver = _core.RunVar("hostname_resolver")
-_socket_factory = _core.RunVar("socket_factory")
+_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver")
+_socket_factory: _core.RunVar[SocketFactory | None] = _core.RunVar("socket_factory")
-def set_custom_hostname_resolver(hostname_resolver):
+def set_custom_hostname_resolver(
+ hostname_resolver: HostnameResolver | None,
+) -> HostnameResolver | None:
"""Set a custom hostname resolver.
By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions
@@ -109,7 +137,9 @@ def set_custom_hostname_resolver(hostname_resolver):
return old
-def set_custom_socket_factory(socket_factory):
+def set_custom_socket_factory(
+ socket_factory: SocketFactory | None,
+) -> SocketFactory | None:
"""Set a custom socket object factory.
This function allows you to replace Trio's normal socket class with a
@@ -143,7 +173,23 @@ def set_custom_socket_factory(socket_factory):
_NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV
-async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
+# It would be possible to @overload the return value depending on Literal[AddressFamily.INET/6], but should probably be added in typeshed first
+async def getaddrinfo(
+ host: bytes | str | None,
+ port: bytes | str | int | None,
+ family: int = 0,
+ type: int = 0,
+ proto: int = 0,
+ flags: int = 0,
+) -> list[
+ tuple[
+ AddressFamily,
+ SocketKind,
+ int,
+ str,
+ tuple[str, int] | tuple[str, int, int, int],
+ ]
+]:
"""Look up a numeric address given a name.
Arguments and return values are identical to :func:`socket.getaddrinfo`,
@@ -164,7 +210,7 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
# skip the whole thread thing, which seems worthwhile. So we try first
# with the _NUMERIC_ONLY flags set, and then only spawn a thread if that
# fails with EAI_NONAME:
- def numeric_only_failure(exc):
+ def numeric_only_failure(exc: BaseException) -> bool:
return (
isinstance(exc, _stdlib_socket.gaierror)
and exc.errno == _stdlib_socket.EAI_NONAME
@@ -190,9 +236,10 @@ def numeric_only_failure(exc):
# idna.encode will error out if the hostname has Capital Letters
# in it; with uts46=True it will lowercase them instead.
host = _idna.encode(host, uts46=True)
- hr = _resolver.get(None)
+ hr: HostnameResolver | None = _resolver.get(None)
+ # waiting on ._abc to get typed
if hr is not None:
- return await hr.getaddrinfo(host, port, family, type, proto, flags)
+ return await hr.getaddrinfo(host, port, family, type, proto, flags) # type: ignore
else:
return await trio.to_thread.run_sync(
_stdlib_socket.getaddrinfo,
@@ -206,7 +253,9 @@ def numeric_only_failure(exc):
)
-async def getnameinfo(sockaddr, flags):
+async def getnameinfo(
+ sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int
+) -> tuple[str, str]:
"""Look up a name given a numeric address.
Arguments and return values are identical to :func:`socket.getnameinfo`,
@@ -218,14 +267,15 @@ async def getnameinfo(sockaddr, flags):
"""
hr = _resolver.get(None)
if hr is not None:
- return await hr.getnameinfo(sockaddr, flags)
+ # waiting on ._abc to get typed
+ return await hr.getnameinfo(sockaddr, flags) # type: ignore
else:
return await trio.to_thread.run_sync(
_stdlib_socket.getnameinfo, sockaddr, flags, cancellable=True
)
-async def getprotobyname(name):
+async def getprotobyname(name: str) -> int:
"""Look up a protocol number by name. (Rarely used.)
Like :func:`socket.getprotobyname`, but async.
@@ -244,7 +294,7 @@ async def getprotobyname(name):
################################################################
-def from_stdlib_socket(sock):
+def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType:
"""Convert a standard library :class:`socket.socket` object into a Trio
socket object.
@@ -253,9 +303,14 @@ def from_stdlib_socket(sock):
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
-def fromfd(fd, family, type, proto=0):
+def fromfd(
+ fd: SupportsIndex,
+ family: AddressFamily | int = _stdlib_socket.AF_INET,
+ type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
+ proto: int = 0,
+) -> _SocketType:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
- family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd)
+ family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, int(fd))
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -264,27 +319,41 @@ def fromfd(fd, family, type, proto=0):
):
@_wraps(_stdlib_socket.fromshare, assigned=(), updated=())
- def fromshare(*args, **kwargs):
- return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs))
+ def fromshare(info: bytes) -> _SocketType:
+ return from_stdlib_socket(_stdlib_socket.fromshare(info))
+
+
+if sys.platform == "win32":
+ FamilyT = int
+ TypeT = int
+ FamilyDefault = _stdlib_socket.AF_INET
+else:
+ FamilyDefault = None
+ FamilyT = Union[int, AddressFamily, None]
+ TypeT = Union[_stdlib_socket.socket, int]
@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
-def socketpair(*args, **kwargs):
+def socketpair(
+ family: FamilyT = FamilyDefault,
+ type: TypeT = SocketKind.SOCK_STREAM,
+ proto: int = 0,
+) -> tuple[_SocketType, _SocketType]:
"""Like :func:`socket.socketpair`, but returns a pair of Trio socket
objects.
"""
- left, right = _stdlib_socket.socketpair(*args, **kwargs)
+ left, right = _stdlib_socket.socketpair(family, type, proto)
return (from_stdlib_socket(left), from_stdlib_socket(right))
@_wraps(_stdlib_socket.socket, assigned=(), updated=())
def socket(
- family=_stdlib_socket.AF_INET,
- type=_stdlib_socket.SOCK_STREAM,
- proto=0,
- fileno=None,
-):
+ family: AddressFamily | int = _stdlib_socket.AF_INET,
+ type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
+ proto: int = 0,
+ fileno: int | None = None,
+) -> _SocketType:
"""Create a new Trio socket, like :class:`socket.socket`.
This function's behavior can be customized using
@@ -294,21 +363,32 @@ def socket(
if fileno is None:
sf = _socket_factory.get(None)
if sf is not None:
- return sf.socket(family, type, proto)
+ # waiting on ._abc to get typed
+ return sf.socket(family, type, proto) # type: ignore
else:
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fileno)
stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno)
return from_stdlib_socket(stdlib_socket)
-def _sniff_sockopts_for_fileno(family, type, proto, fileno):
+def _sniff_sockopts_for_fileno(
+ family: AddressFamily | int,
+ type: SocketKind | int,
+ proto: int,
+ fileno: int | None,
+) -> tuple[AddressFamily | int, SocketKind | int, int]:
"""Correct SOCKOPTS for given fileno, falling back to provided values."""
# Wrap the raw fileno into a Python socket object
# This object might have the wrong metadata, but it lets us easily call getsockopt
# and then we'll throw it away and construct a new one with the correct metadata.
if sys.platform != "linux":
return family, type, proto
- from socket import SO_DOMAIN, SO_PROTOCOL, SO_TYPE, SOL_SOCKET
+ from socket import ( # type: ignore[attr-defined]
+ SO_DOMAIN,
+ SO_PROTOCOL,
+ SO_TYPE,
+ SOL_SOCKET,
+ )
sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno)
try:
@@ -338,19 +418,21 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno):
)
-def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False):
- fn = getattr(_stdlib_socket.socket, methname)
-
+def _make_simple_sock_method_wrapper(
+ fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
+ wait_fn: Callable,
+ maybe_avail: bool = False,
+) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]:
@_wraps(fn, assigned=("__name__",), updated=())
- async def wrapper(self, *args, **kwargs):
- return await self._nonblocking_helper(fn, args, kwargs, wait_fn)
+ async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T:
+ return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs)
- wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async.
+ wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async.
"""
if maybe_avail:
wrapper.__doc__ += (
- f"Only available on platforms where :meth:`socket.socket.{methname}` is "
+ f"Only available on platforms where :meth:`socket.socket.{fn.__name__}` is "
"available."
)
return wrapper
@@ -369,8 +451,21 @@ async def wrapper(self, *args, **kwargs):
# local=False means that the address is being used with connect() or sendto() or
# similar.
#
+
+
+# Using a TypeVar to indicate we return the same type of address appears to give errors
+# when passed a union of address types.
+# @overload likely works, but is extremely verbose.
# NOTE: this function does not always checkpoint
-async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local):
+async def _resolve_address_nocp(
+ type: int,
+ family: AddressFamily,
+ proto: int,
+ *,
+ ipv6_v6only: bool | int,
+ address: Address,
+ local: bool,
+) -> Address:
# Do some pre-checking (or exit early for non-IP sockets)
if family == _stdlib_socket.AF_INET:
if not isinstance(address, tuple) or not len(address) == 2:
@@ -380,13 +475,15 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo
raise ValueError(
"address should be a (host, port, [flowinfo, [scopeid]]) tuple"
)
- elif family == _stdlib_socket.AF_UNIX:
+ elif family == getattr(_stdlib_socket, "AF_UNIX"):
# unwrap path-likes
+ assert isinstance(address, (str, bytes))
return os.fspath(address)
else:
return address
# -- From here on we know we have IPv4 or IPV6 --
+ host: str | None
host, port, *_ = address
# Fast path for the simple case: already-resolved IP address,
# already-resolved port. This is particularly important for UDP, since
@@ -424,18 +521,24 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo
# The above ignored any flowid and scopeid in the passed-in address,
# so restore them if present:
if family == _stdlib_socket.AF_INET6:
- normed = list(normed)
+ list_normed = list(normed)
assert len(normed) == 4
+ # typechecking certainly doesn't like this logic, but given just how broad
+ # Address is, it's kind of impossible to write the below without type: ignore
if len(address) >= 3:
- normed[2] = address[2]
+ list_normed[2] = address[2] # type: ignore
if len(address) >= 4:
- normed[3] = address[3]
- normed = tuple(normed)
+ list_normed[3] = address[3] # type: ignore
+ return tuple(list_normed) # type: ignore
return normed
+# TODO: stopping users from initializing this type should be done in a different way,
+# so SocketType can be used as a type. Note that this is *far* from trivial without
+# breaking subclasses of SocketType. Should maybe just add abstract methods to SocketType,
+# or rename _SocketType.
class SocketType:
- def __init__(self):
+ def __init__(self) -> NoReturn:
raise TypeError(
"SocketType is an abstract class; use trio.socket.socket if you "
"want to construct a socket object"
@@ -481,7 +584,7 @@ def __init__(self, sock: _stdlib_socket.socket):
"share",
}
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
if name in self._forward:
return getattr(self._sock, name)
raise AttributeError(name)
@@ -501,11 +604,11 @@ def __exit__(
return self._sock.__exit__(exc_type, exc_value, traceback)
@property
- def family(self) -> _stdlib_socket.AddressFamily:
+ def family(self) -> AddressFamily:
return self._sock.family
@property
- def type(self) -> _stdlib_socket.SocketKind:
+ def type(self) -> SocketKind:
return self._sock.type
@property
@@ -528,7 +631,7 @@ def close(self) -> None:
trio.lowlevel.notify_closing(self._sock)
self._sock.close()
- async def bind(self, address: tuple[object, ...] | str | bytes) -> None:
+ async def bind(self, address: Address) -> None:
address = await self._resolve_address_nocp(address, local=True)
if (
hasattr(_stdlib_socket, "AF_UNIX")
@@ -537,8 +640,7 @@ async def bind(self, address: tuple[object, ...] | str | bytes) -> None:
):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
- # remove the `type: ignore` when run.sync is typed.
- return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return]
+ return await trio.to_thread.run_sync(self._sock.bind, address)
else:
# POSIX actually says that bind can return EWOULDBLOCK and
# complete asynchronously, like connect. But in practice AFAICT
@@ -566,7 +668,12 @@ def is_readable(self) -> bool:
async def wait_writable(self) -> None:
await _core.wait_writable(self._sock)
- async def _resolve_address_nocp(self, address, *, local):
+ async def _resolve_address_nocp(
+ self,
+ address: Address,
+ *,
+ local: bool,
+ ) -> Address:
if self.family == _stdlib_socket.AF_INET6:
ipv6_v6only = self._sock.getsockopt(
IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY
@@ -582,7 +689,19 @@ async def _resolve_address_nocp(self, address, *, local):
local=local,
)
- async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
+ # args and kwargs must be starred, otherwise pyright complains:
+ # '"args" member of ParamSpec is valid only when used with *args parameter'
+ # '"kwargs" member of ParamSpec is valid only when used with **kwargs parameter'
+ # wait_fn and fn must also be first in the signature
+ # 'Keyword parameter cannot appear in signature after ParamSpec args parameter'
+
+ async def _nonblocking_helper(
+ self,
+ wait_fn: Callable[[_stdlib_socket.socket], Awaitable],
+ fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> T:
# We have to reconcile two conflicting goals:
# - We want to make it look like we always blocked in doing these
# operations. The obvious way is to always do an IO wait before
@@ -618,9 +737,11 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
# accept
################################################################
- _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable)
+ _accept = _make_simple_sock_method_wrapper(
+ _stdlib_socket.socket.accept, _core.wait_readable
+ )
- async def accept(self):
+ async def accept(self) -> tuple[_SocketType, object]:
"""Like :meth:`socket.socket.accept`, but async."""
sock, addr = await self._accept()
return from_stdlib_socket(sock), addr
@@ -629,7 +750,7 @@ async def accept(self):
# connect
################################################################
- async def connect(self, address):
+ async def connect(self, address: Address) -> None:
# nonblocking connect is weird -- you call it to start things
# off, then the socket becomes writable as a completion
# notification. This means it isn't really cancellable... we close the
@@ -697,38 +818,69 @@ async def connect(self, address):
# Okay, the connect finished, but it might have failed:
err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR)
if err != 0:
- raise OSError(err, f"Error connecting to {address}: {os.strerror(err)}")
+ raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}")
################################################################
# recv
################################################################
+ # Not possible to typecheck with a Callable (due to DefaultArg), nor with a
+ # callback Protocol (https://github.com/python/typing/discussions/1040)
+ # but this seems to work. If not explicitly defined then pyright --verifytypes will
+ # complain about AmbiguousType
if TYPE_CHECKING:
- async def recv(self, buffersize: int, flags: int = 0) -> bytes:
+ def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
...
- else:
- recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable)
+ # _make_simple_sock_method_wrapper is typed, so this check that the above is correct
+ recv = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recv, _core.wait_readable
+ )
################################################################
# recv_into
################################################################
- recv_into = _make_simple_sock_method_wrapper("recv_into", _core.wait_readable)
+ if TYPE_CHECKING:
+
+ def recv_into(
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[int]:
+ ...
+
+ recv_into = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recv_into, _core.wait_readable
+ )
################################################################
# recvfrom
################################################################
- recvfrom = _make_simple_sock_method_wrapper("recvfrom", _core.wait_readable)
+ if TYPE_CHECKING:
+ # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any]
+ def recvfrom(
+ __self, __bufsize: int, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, Address]]:
+ ...
+
+ recvfrom = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvfrom, _core.wait_readable
+ )
################################################################
# recvfrom_into
################################################################
- recvfrom_into = _make_simple_sock_method_wrapper(
- "recvfrom_into", _core.wait_readable
+ if TYPE_CHECKING:
+ # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any]
+ def recvfrom_into(
+ __self, buffer: Buffer, nbytes: int = 0, flags: int = 0
+ ) -> Awaitable[tuple[int, Address]]:
+ ...
+
+ recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvfrom_into, _core.wait_readable
)
################################################################
@@ -736,8 +888,15 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes:
################################################################
if hasattr(_stdlib_socket.socket, "recvmsg"):
- recvmsg = _make_simple_sock_method_wrapper(
- "recvmsg", _core.wait_readable, maybe_avail=True
+ if TYPE_CHECKING:
+
+ def recvmsg(
+ __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0
+ ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]:
+ ...
+
+ recvmsg = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True
)
################################################################
@@ -745,29 +904,58 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes:
################################################################
if hasattr(_stdlib_socket.socket, "recvmsg_into"):
- recvmsg_into = _make_simple_sock_method_wrapper(
- "recvmsg_into", _core.wait_readable, maybe_avail=True
+ if TYPE_CHECKING:
+
+ def recvmsg_into(
+ __self,
+ __buffers: Iterable[Buffer],
+ __ancbufsize: int = 0,
+ __flags: int = 0,
+ ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]:
+ ...
+
+ recvmsg_into = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True
)
################################################################
# send
################################################################
- send = _make_simple_sock_method_wrapper("send", _core.wait_writable)
+ if TYPE_CHECKING:
+
+ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]:
+ ...
+
+ send = _make_simple_sock_method_wrapper( # noqa: F811
+ _stdlib_socket.socket.send, _core.wait_writable
+ )
################################################################
# sendto
################################################################
+ @overload
+ async def sendto(
+ self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
+
+ @overload
+ async def sendto(
+ self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer
+ ) -> int:
+ ...
+
@_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=())
- async def sendto(self, *args):
+ async def sendto(self, *args: Any) -> int:
"""Similar to :meth:`socket.socket.sendto`, but async."""
# args is: data[, flags], address)
# and kwargs are not accepted
- args = list(args)
- args[-1] = await self._resolve_address_nocp(args[-1], local=False)
+ args_list = list(args)
+ args_list[-1] = await self._resolve_address_nocp(args[-1], local=False)
return await self._nonblocking_helper(
- _stdlib_socket.socket.sendto, args, {}, _core.wait_writable
+ _core.wait_writable, _stdlib_socket.socket.sendto, *args_list
)
################################################################
@@ -779,20 +967,28 @@ async def sendto(self, *args):
):
@_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
- async def sendmsg(self, *args):
+ async def sendmsg(
+ self,
+ __buffers: Iterable[Buffer],
+ __ancdata: Iterable[tuple[int, int, Buffer]] = (),
+ __flags: int = 0,
+ __address: Address | None = None,
+ ) -> int:
"""Similar to :meth:`socket.socket.sendmsg`, but async.
Only available on platforms where :meth:`socket.socket.sendmsg` is
available.
"""
- # args is: buffers[, ancdata[, flags[, address]]]
- # and kwargs are not accepted
- if len(args) == 4 and args[-1] is not None:
- args = list(args)
- args[-1] = await self._resolve_address_nocp(args[-1], local=False)
+ if __address is not None:
+ __address = await self._resolve_address_nocp(__address, local=False)
return await self._nonblocking_helper(
- _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable
+ _core.wait_writable,
+ _stdlib_socket.socket.sendmsg,
+ __buffers,
+ __ancdata,
+ __flags,
+ __address,
)
################################################################
diff --git a/trio/_sync.py b/trio/_sync.py
index 5a7f240d5e..96c6025e99 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -8,7 +8,7 @@
import trio
from . import _core
-from ._core import ParkingLot, enable_ki_protection
+from ._core import Abort, ParkingLot, RaiseCancelT, enable_ki_protection
from ._util import Final
if TYPE_CHECKING:
@@ -87,7 +87,7 @@ async def wait(self) -> None:
task = _core.current_task()
self._tasks.add(task)
- def abort_fn(_):
+ def abort_fn(_: RaiseCancelT) -> Abort:
self._tasks.remove(task)
return _core.Abort.SUCCEEDED
@@ -143,7 +143,7 @@ class CapacityLimiterStatistics:
borrowed_tokens: int = attr.ib()
total_tokens: int | float = attr.ib()
- borrowers: list[Task] = attr.ib()
+ borrowers: list[object] = attr.ib()
tasks_waiting: int = attr.ib()
@@ -204,9 +204,9 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float):
self._lot = ParkingLot()
- self._borrowers: set[Task] = set()
+ self._borrowers: set[object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
- self._pending_borrowers: dict[Task, Task] = {}
+ self._pending_borrowers: dict[Task, object] = {}
# invoke the property setter for validation
self.total_tokens: int | float = total_tokens
assert self._total_tokens == total_tokens
@@ -268,7 +268,7 @@ def acquire_nowait(self) -> None:
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
@enable_ki_protection
- def acquire_on_behalf_of_nowait(self, borrower: Task) -> None:
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, without
blocking.
@@ -307,7 +307,7 @@ async def acquire(self) -> None:
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- async def acquire_on_behalf_of(self, borrower: Task) -> None:
+ async def acquire_on_behalf_of(self, borrower: object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
necessary.
@@ -347,7 +347,7 @@ def release(self) -> None:
self.release_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- def release_on_behalf_of(self, borrower: Task) -> None:
+ def release_on_behalf_of(self, borrower: object) -> None:
"""Put a token back into the sack on behalf of ``borrower``.
Raises:
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 9d7d7aa912..18b3e0e51c 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
- "completenessScore": 0.8764044943820225,
+ "completenessScore": 0.8892455858747994,
"exportedSymbolCounts": {
"withAmbiguousType": 1,
- "withKnownType": 546,
- "withUnknownType": 76
+ "withKnownType": 554,
+ "withUnknownType": 68
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
@@ -46,8 +46,8 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 433,
- "withUnknownType": 135
+ "withKnownType": 467,
+ "withUnknownType": 114
},
"packageName": "trio",
"symbols": [
@@ -66,10 +66,6 @@
"trio._abc.Listener.accept",
"trio._abc.SocketFactory.socket",
"trio._core._entry_queue.TrioToken.run_sync_soon",
- "trio._core._local.RunVar.__repr__",
- "trio._core._local.RunVar.get",
- "trio._core._local.RunVar.reset",
- "trio._core._local.RunVar.set",
"trio._core._mock_clock.MockClock.jump",
"trio._core._run.Nursery.start",
"trio._core._run.Nursery.start_soon",
@@ -98,8 +94,6 @@
"trio._dtls.DTLSEndpoint.serve",
"trio._dtls.DTLSEndpoint.socket",
"trio._highlevel_socket.SocketListener",
- "trio._highlevel_socket.SocketListener.__init__",
- "trio._highlevel_socket.SocketStream.__init__",
"trio._highlevel_socket.SocketStream.getsockopt",
"trio._highlevel_socket.SocketStream.send_all",
"trio._highlevel_socket.SocketStream.setsockopt",
@@ -117,17 +111,6 @@
"trio._path.Path.__rtruediv__",
"trio._path.Path.__truediv__",
"trio._path.Path.open",
- "trio._socket._SocketType.__getattr__",
- "trio._socket._SocketType.accept",
- "trio._socket._SocketType.connect",
- "trio._socket._SocketType.recv_into",
- "trio._socket._SocketType.recvfrom",
- "trio._socket._SocketType.recvfrom_into",
- "trio._socket._SocketType.recvmsg",
- "trio._socket._SocketType.recvmsg_into",
- "trio._socket._SocketType.send",
- "trio._socket._SocketType.sendmsg",
- "trio._socket._SocketType.sendto",
"trio._ssl.SSLListener",
"trio._ssl.SSLListener.__init__",
"trio._ssl.SSLListener.accept",
@@ -190,15 +173,8 @@
"trio.serve_listeners",
"trio.serve_ssl_over_tcp",
"trio.serve_tcp",
- "trio.socket.from_stdlib_socket",
- "trio.socket.fromfd",
- "trio.socket.getaddrinfo",
- "trio.socket.getnameinfo",
- "trio.socket.getprotobyname",
"trio.socket.set_custom_hostname_resolver",
"trio.socket.set_custom_socket_factory",
- "trio.socket.socket",
- "trio.socket.socketpair",
"trio.testing._memory_streams.MemoryReceiveStream.__init__",
"trio.testing._memory_streams.MemoryReceiveStream.aclose",
"trio.testing._memory_streams.MemoryReceiveStream.close",
diff --git a/trio/_threads.py b/trio/_threads.py
index 807212e0f9..3fbab05750 100644
--- a/trio/_threads.py
+++ b/trio/_threads.py
@@ -1,16 +1,19 @@
+from __future__ import annotations
+
import contextvars
import functools
import inspect
import queue as stdlib_queue
import threading
from itertools import count
-from typing import Optional
+from typing import Any, Callable, Optional, TypeVar
import attr
import outcome
from sniffio import current_async_library_cvar
import trio
+from trio._core._traps import RaiseCancelT
from ._core import (
RunVar,
@@ -22,10 +25,12 @@
from ._sync import CapacityLimiter
from ._util import coroutine_or_error
+T = TypeVar("T")
+
# Global due to Threading API, thread local storage for trio token
TOKEN_LOCAL = threading.local()
-_limiter_local = RunVar("limiter")
+_limiter_local: RunVar[CapacityLimiter] = RunVar("limiter")
# I pulled this number out of the air; it isn't based on anything. Probably we
# should make some kind of measurements to pick a good value.
DEFAULT_LIMIT = 40
@@ -59,8 +64,12 @@ class ThreadPlaceholder:
@enable_ki_protection
async def to_thread_run_sync(
- sync_fn, *args, thread_name: Optional[str] = None, cancellable=False, limiter=None
-):
+ sync_fn: Callable[..., T],
+ *args: Any,
+ thread_name: Optional[str] = None,
+ cancellable: bool = False,
+ limiter: CapacityLimiter | None = None,
+) -> T:
"""Convert a blocking operation into an async operation using a thread.
These two lines are equivalent::
@@ -152,7 +161,7 @@ async def to_thread_run_sync(
# Holds a reference to the task that's blocked in this function waiting
# for the result – or None if this function was cancelled and we should
# discard the result.
- task_register = [trio.lowlevel.current_task()]
+ task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()]
name = f"trio.to_thread.run_sync-{next(_thread_counter)}"
placeholder = ThreadPlaceholder(name)
@@ -217,14 +226,15 @@ def deliver_worker_fn_result(result):
limiter.release_on_behalf_of(placeholder)
raise
- def abort(_):
+ def abort(_: RaiseCancelT) -> trio.lowlevel.Abort:
if cancellable:
task_register[0] = None
return trio.lowlevel.Abort.SUCCEEDED
else:
return trio.lowlevel.Abort.FAILED
- return await trio.lowlevel.wait_task_rescheduled(abort)
+ # wait_task_rescheduled return value cannot be typed
+ return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[no-any-return]
def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None):
From ba5a6ee9bc82f693f4be7974ee90f92a5e61bf0c Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Sat, 15 Jul 2023 10:38:05 +1000
Subject: [PATCH 02/16] Change fromfd() to use the correct type for file
descriptors
---
trio/_socket.py | 15 ++++++++++-----
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index dc06e5c755..c7f4a62e9e 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -12,7 +12,7 @@
Awaitable,
Callable,
NoReturn,
- SupportsIndex,
+ SupportsInt,
Tuple,
TypeVar,
Union,
@@ -20,7 +20,6 @@
)
import idna as _idna
-from typing_extensions import Concatenate, ParamSpec
import trio
@@ -30,10 +29,16 @@
from collections.abc import Iterable
from types import TracebackType
- from typing_extensions import Buffer, Self, TypeAlias
+ from typing_extensions import Buffer, Self, TypeAlias, Concatenate, ParamSpec, SupportsIndex
from ._abc import HostnameResolver, SocketFactory
+ # Duplicated from _socket type stubs
+ if sys.version_info >= (3, 8):
+ FileDescriptor: TypeAlias = SupportsIndex
+ else:
+ FileDescriptor: TypeAlias = SupportsInt
+
T = TypeVar("T")
P = ParamSpec("P")
@@ -88,7 +93,7 @@ async def __aexit__(
################################################################
try:
- from socket import IPPROTO_IPV6
+ from socket import IPPROTO_IPV6 # type: ignore
except ImportError:
# Before Python 3.8, Windows is missing IPPROTO_IPV6
# https://bugs.python.org/issue29515
@@ -304,7 +309,7 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType:
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
def fromfd(
- fd: SupportsIndex,
+ fd: FileDescriptor,
family: AddressFamily | int = _stdlib_socket.AF_INET,
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
From aa32d87dd620511d2f225fe2f7a2f607511636d5 Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Sat, 15 Jul 2023 14:04:14 +1000
Subject: [PATCH 03/16] This also needs to be guarded under TYPE_CHECKING
---
trio/_socket.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index c7f4a62e9e..56e4a29c2f 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -39,9 +39,10 @@
else:
FileDescriptor: TypeAlias = SupportsInt
+ P = ParamSpec("P")
+
T = TypeVar("T")
-P = ParamSpec("P")
# must use old-style typing for TypeAlias
Address: TypeAlias = Union[
From 30e272bee5a1b0dec5be598c551a2e29e0f9e676 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 17 Jul 2023 10:01:24 +0200
Subject: [PATCH 04/16] revert changes for python <3.8
---
trio/_socket.py | 20 ++++----------------
1 file changed, 4 insertions(+), 16 deletions(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index 56e4a29c2f..3e4c198077 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -12,7 +12,7 @@
Awaitable,
Callable,
NoReturn,
- SupportsInt,
+ SupportsIndex,
Tuple,
TypeVar,
Union,
@@ -29,16 +29,10 @@
from collections.abc import Iterable
from types import TracebackType
- from typing_extensions import Buffer, Self, TypeAlias, Concatenate, ParamSpec, SupportsIndex
+ from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias
from ._abc import HostnameResolver, SocketFactory
- # Duplicated from _socket type stubs
- if sys.version_info >= (3, 8):
- FileDescriptor: TypeAlias = SupportsIndex
- else:
- FileDescriptor: TypeAlias = SupportsInt
-
P = ParamSpec("P")
@@ -93,13 +87,7 @@ async def __aexit__(
# CONSTANTS
################################################################
-try:
- from socket import IPPROTO_IPV6 # type: ignore
-except ImportError:
- # Before Python 3.8, Windows is missing IPPROTO_IPV6
- # https://bugs.python.org/issue29515
- if sys.platform == "win32": # pragma: no branch
- IPPROTO_IPV6 = 41
+from socket import IPPROTO_IPV6
################################################################
# Overrides
@@ -310,7 +298,7 @@ def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType:
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
def fromfd(
- fd: FileDescriptor,
+ fd: SupportsIndex,
family: AddressFamily | int = _stdlib_socket.AF_INET,
type: SocketKind | int = _stdlib_socket.SOCK_STREAM,
proto: int = 0,
From 399cce7e0492e4daf4ecf78de85f8e2f581c2dc3 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 17 Jul 2023 12:03:29 +0200
Subject: [PATCH 05/16] fixes after review from TeamSpen210
---
pyproject.toml | 1 +
trio/_core/_local.py | 53 +++++++---------
trio/_socket.py | 111 ++++++++++++++++++++++------------
trio/_sync.py | 12 ++--
trio/_tests/verify_types.json | 2 +-
5 files changed, 104 insertions(+), 75 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 1f6b15e45e..52751c2a11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ module = [
"trio._sync",
]
disallow_untyped_defs = true
+disallow_any_generics = true
[tool.pytest.ini_options]
addopts = ["--strict-markers", "--strict-config"]
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index fe509ca7ad..dd3e9518be 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -1,32 +1,32 @@
from __future__ import annotations
-from typing import Generic, TypeVar, overload
+from typing import Generic, TypeVar, final
# Runvar implementations
import attr
-from .._util import Final
+from .._util import Final, NoPublicConstructor
from . import _run
+# `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released
+
T = TypeVar("T")
-C = TypeVar("C", bound="_RunVarToken")
-class NoValue(object):
+@final
+class _NoValue:
...
-@attr.s(eq=False, hash=False, slots=True)
-class _RunVarToken(Generic[T]):
- _no_value = NoValue()
-
+@attr.s(eq=False, hash=False, slots=False)
+class RunVarToken(Generic[T], metaclass=NoPublicConstructor):
_var: RunVar[T] = attr.ib()
- previous_value: T | NoValue = attr.ib(default=_no_value)
+ previous_value: T | type[_NoValue] = attr.ib(default=_NoValue)
redeemed: bool = attr.ib(default=False, init=False)
@classmethod
- def empty(cls: type[C], var: RunVar[T]) -> C:
- return cls(var)
+ def _empty(cls, var: RunVar[T]) -> RunVarToken[T]:
+ return cls._create(var)
@attr.s(eq=False, hash=False, slots=True)
@@ -39,19 +39,10 @@ class RunVar(Generic[T], metaclass=Final):
"""
- _NO_DEFAULT = NoValue()
_name: str = attr.ib()
- _default: T | NoValue = attr.ib(default=_NO_DEFAULT)
-
- @overload
- def get(self, default: T) -> T:
- ...
-
- @overload
- def get(self, default: NoValue = _NO_DEFAULT) -> T | NoValue:
- ...
+ _default: T | type[_NoValue] = attr.ib(default=_NoValue)
- def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue:
+ def get(self, default: T | type[_NoValue] = _NoValue) -> T:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
# not typed yet
@@ -60,15 +51,15 @@ def get(self, default: T | NoValue = _NO_DEFAULT) -> T | NoValue:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
# contextvars consistency
- if default is not self._NO_DEFAULT:
- return default
+ if default is not _NoValue:
+ return default # type: ignore[return-value]
- if self._default is not self._NO_DEFAULT:
- return self._default
+ if self._default is not _NoValue:
+ return self._default # type: ignore[return-value]
raise LookupError(self) from None
- def set(self, value: T) -> _RunVarToken[T]:
+ def set(self, value: T) -> RunVarToken[T]:
"""Sets the value of this :class:`RunVar` for this current run
call.
@@ -76,16 +67,16 @@ def set(self, value: T) -> _RunVarToken[T]:
try:
old_value = self.get()
except LookupError:
- token: _RunVarToken[T] = _RunVarToken.empty(self)
+ token = RunVarToken[T]._empty(self)
else:
- token = _RunVarToken(self, old_value)
+ token = RunVarToken[T]._create(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index]
return token
- def reset(self, token: _RunVarToken[T]) -> None:
+ def reset(self, token: RunVarToken[T]) -> None:
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
@@ -101,7 +92,7 @@ def reset(self, token: _RunVarToken[T]) -> None:
previous = token.previous_value
try:
- if previous is _RunVarToken._no_value:
+ if previous is _NoValue:
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type]
else:
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment]
diff --git a/trio/_socket.py b/trio/_socket.py
index 3e4c198077..5aa135905c 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -16,6 +16,7 @@
Tuple,
TypeVar,
Union,
+ cast,
overload,
)
@@ -318,13 +319,13 @@ def fromshare(info: bytes) -> _SocketType:
if sys.platform == "win32":
- FamilyT = int
- TypeT = int
+ FamilyT: TypeAlias = int
+ TypeT: TypeAlias = int
FamilyDefault = _stdlib_socket.AF_INET
else:
FamilyDefault = None
- FamilyT = Union[int, AddressFamily, None]
- TypeT = Union[_stdlib_socket.socket, int]
+ FamilyT: TypeAlias = Union[int, AddressFamily, None]
+ TypeT: TypeAlias = Union[_stdlib_socket.socket, int]
@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
@@ -414,7 +415,7 @@ def _sniff_sockopts_for_fileno(
def _make_simple_sock_method_wrapper(
fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
- wait_fn: Callable,
+ wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]],
maybe_avail: bool = False,
) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]:
@_wraps(fn, assigned=("__name__",), updated=())
@@ -518,7 +519,7 @@ async def _resolve_address_nocp(
list_normed = list(normed)
assert len(normed) == 4
# typechecking certainly doesn't like this logic, but given just how broad
- # Address is, it's kind of impossible to write the below without type: ignore
+ # Address is, it's quite cumbersome to write the below without type: ignore
if len(address) >= 3:
list_normed[2] = address[2] # type: ignore
if len(address) >= 4:
@@ -555,36 +556,72 @@ def __init__(self, sock: _stdlib_socket.socket):
# Simple + portable methods and attributes
################################################################
- # NB this doesn't work because for loops don't create a scope
- # for _name in [
- # ]:
- # _meth = getattr(_stdlib_socket.socket, _name)
- # @_wraps(_meth, assigned=("__name__", "__doc__"), updated=())
- # def _wrapped(self, *args, **kwargs):
- # return getattr(self._sock, _meth)(*args, **kwargs)
- # locals()[_meth] = _wrapped
- # del _name, _meth, _wrapped
-
- _forward = {
- "detach",
- "get_inheritable",
- "set_inheritable",
- "fileno",
- "getpeername",
- "getsockname",
- "getsockopt",
- "setsockopt",
- "listen",
- "share",
- }
-
- def __getattr__(self, name: str) -> Any:
- if name in self._forward:
- return getattr(self._sock, name)
- raise AttributeError(name)
-
- def __dir__(self) -> Iterable[str]:
- return [*super().__dir__(), *self._forward]
+ # forwarded methods
+ def detach(self) -> int:
+ return self._sock.detach()
+
+ def fileno(self) -> int:
+ return self._sock.fileno()
+
+ def getpeername(self) -> Any:
+ return self._sock.getpeername()
+
+ def getsockname(self) -> Any:
+ return self._sock.getsockname()
+
+ @overload
+ def getsockopt(self, __level: int, __optname: int) -> int:
+ ...
+
+ @overload
+ def getsockopt(self, __level: int, __optname: int, __buflen: int) -> bytes:
+ ...
+
+ def getsockopt(
+ self, __level: int, __optname: int, __buflen: int | None = None
+ ) -> int | bytes:
+ if __buflen is None:
+ return self._sock.getsockopt(__level, __optname)
+ return self._sock.getsockopt(__level, __optname, __buflen)
+
+ @overload
+ def setsockopt(self, __level: int, __optname: int, __value: int | Buffer) -> None:
+ ...
+
+ @overload
+ def setsockopt(
+ self, __level: int, __optname: int, __value: None, __optlen: int
+ ) -> None:
+ ...
+
+ def setsockopt(
+ self,
+ __level: int,
+ __optname: int,
+ __value: int | Buffer | None,
+ __optlen: int | None = None,
+ ) -> None:
+ if __optlen is None:
+ return self._sock.setsockopt(
+ __level, __optname, cast("int|Buffer", __value)
+ )
+ return self._sock.setsockopt(__level, __optname, cast(None, __value), __optlen)
+
+ def listen(self, __backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+ return self._sock.listen(__backlog)
+
+ def get_inheritable(self) -> bool:
+ return self._sock.get_inheritable()
+
+ def set_inheritable(self, inheritable: bool) -> None:
+ return self._sock.set_inheritable(inheritable)
+
+ if sys.platform == "win32" or (
+ not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
+ ):
+
+ def share(self, __process_id: int) -> bytes:
+ return self._sock.share(__process_id)
def __enter__(self) -> Self:
return self
@@ -691,7 +728,7 @@ async def _resolve_address_nocp(
async def _nonblocking_helper(
self,
- wait_fn: Callable[[_stdlib_socket.socket], Awaitable],
+ wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]],
fn: Callable[Concatenate[_stdlib_socket.socket, P], T],
*args: P.args,
**kwargs: P.kwargs,
diff --git a/trio/_sync.py b/trio/_sync.py
index 96c6025e99..9764ddce2d 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -143,7 +143,7 @@ class CapacityLimiterStatistics:
borrowed_tokens: int = attr.ib()
total_tokens: int | float = attr.ib()
- borrowers: list[object] = attr.ib()
+ borrowers: list[Task | object] = attr.ib()
tasks_waiting: int = attr.ib()
@@ -204,9 +204,9 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float):
self._lot = ParkingLot()
- self._borrowers: set[object] = set()
+ self._borrowers: set[Task | object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
- self._pending_borrowers: dict[Task, object] = {}
+ self._pending_borrowers: dict[Task, Task | object] = {}
# invoke the property setter for validation
self.total_tokens: int | float = total_tokens
assert self._total_tokens == total_tokens
@@ -268,7 +268,7 @@ def acquire_nowait(self) -> None:
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
@enable_ki_protection
- def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
+ def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, without
blocking.
@@ -307,7 +307,7 @@ async def acquire(self) -> None:
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- async def acquire_on_behalf_of(self, borrower: object) -> None:
+ async def acquire_on_behalf_of(self, borrower: Task | object) -> None:
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
necessary.
@@ -347,7 +347,7 @@ def release(self) -> None:
self.release_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
- def release_on_behalf_of(self, borrower: object) -> None:
+ def release_on_behalf_of(self, borrower: Task | object) -> None:
"""Put a token back into the sack on behalf of ``borrower``.
Raises:
diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json
index 18b3e0e51c..dd9e8018b3 100644
--- a/trio/_tests/verify_types.json
+++ b/trio/_tests/verify_types.json
@@ -46,7 +46,7 @@
],
"otherSymbolCounts": {
"withAmbiguousType": 8,
- "withKnownType": 467,
+ "withKnownType": 473,
"withUnknownType": 114
},
"packageName": "trio",
From dd828c2bdf7eb591b6eafd8b27283cfd23273cbf Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 17 Jul 2023 12:48:39 +0200
Subject: [PATCH 06/16] "fix" readthedocs build
---
docs/source/conf.py | 4 ++++
docs/source/reference-io.rst | 8 ++++++++
trio/socket.py | 1 +
3 files changed, 13 insertions(+)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 68a5a22a81..0992297014 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -62,7 +62,11 @@
("py:obj", "trio._abc.SendType"),
("py:obj", "trio._abc.T"),
("py:obj", "trio._abc.T_resource"),
+ ("py:class", "trio._threads.T"),
+ # why aren't these found in stdlib?
("py:class", "types.FrameType"),
+ ("py:class", "socket.AddressFamily"),
+ ("py:class", "socket.SocketKind"),
]
autodoc_inherit_docstrings = False
default_role = "obj"
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index a3291ef2ae..cc1ccb127a 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -501,6 +501,14 @@ Socket objects
* :meth:`~socket.socket.set_inheritable`
* :meth:`~socket.socket.get_inheritable`
+The internal SocketType
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: _SocketType
+..
+ TODO: adding `:members:` here gives error due to overload+_wraps on `sendto`
+ TODO: rewrite ... all of the above when fixing _SocketType vs SocketType
+
+
.. currentmodule:: trio
diff --git a/trio/socket.py b/trio/socket.py
index a9e276c782..f6aebb6a6e 100644
--- a/trio/socket.py
+++ b/trio/socket.py
@@ -35,6 +35,7 @@
# import the overwrites
from ._socket import (
SocketType as SocketType,
+ _SocketType as _SocketType,
from_stdlib_socket as from_stdlib_socket,
fromfd as fromfd,
getaddrinfo as getaddrinfo,
From 87a05bee45b70a02688b01be12086722c48d0ebe Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 21 Jul 2023 12:51:23 +0200
Subject: [PATCH 07/16] set merge strategy in .gitattributes
---
.gitattributes | 2 ++
1 file changed, 2 insertions(+)
diff --git a/.gitattributes b/.gitattributes
index 991065e069..7fbcb4fe2d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -2,3 +2,5 @@
trio/_core/_generated* linguist-generated=true
# Treat generated files as binary in git diff
trio/_core/_generated* -diff
+# don't merge the generated json file, let the user (script) handle it
+trio/_tests/verify_types.json merge=binary
From 06e004d6d1def25f0bc82327b5a5da63dc5a4c29 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 21 Jul 2023 15:51:47 +0200
Subject: [PATCH 08/16] fixes after review by A5rocks
---
trio/_core/_local.py | 2 +-
trio/_socket.py | 16 +++++++---------
2 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index dd3e9518be..b9dada64fe 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -67,7 +67,7 @@ def set(self, value: T) -> RunVarToken[T]:
try:
old_value = self.get()
except LookupError:
- token = RunVarToken[T]._empty(self)
+ token = RunVarToken._empty(self)
else:
token = RunVarToken[T]._create(self, old_value)
diff --git a/trio/_socket.py b/trio/_socket.py
index 57e4d32a71..f68338f966 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -5,6 +5,7 @@
import socket as _stdlib_socket
import sys
from functools import wraps as _wraps
+from operator import index
from socket import AddressFamily, SocketKind
from typing import (
TYPE_CHECKING,
@@ -39,7 +40,7 @@
T = TypeVar("T")
-# must use old-style typing for TypeAlias
+# must use old-style typing because it's evaluated at runtime
Address: TypeAlias = Union[
str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int]
]
@@ -225,10 +226,9 @@ def numeric_only_failure(exc: BaseException) -> bool:
# idna.encode will error out if the hostname has Capital Letters
# in it; with uts46=True it will lowercase them instead.
host = _idna.encode(host, uts46=True)
- hr: HostnameResolver | None = _resolver.get(None)
- # waiting on ._abc to get typed
+ hr = _resolver.get(None)
if hr is not None:
- return await hr.getaddrinfo(host, port, family, type, proto, flags) # type: ignore
+ return await hr.getaddrinfo(host, port, family, type, proto, flags)
else:
return await trio.to_thread.run_sync(
_stdlib_socket.getaddrinfo,
@@ -256,8 +256,7 @@ async def getnameinfo(
"""
hr = _resolver.get(None)
if hr is not None:
- # waiting on ._abc to get typed
- return await hr.getnameinfo(sockaddr, flags) # type: ignore
+ return await hr.getnameinfo(sockaddr, flags)
else:
return await trio.to_thread.run_sync(
_stdlib_socket.getnameinfo, sockaddr, flags, cancellable=True
@@ -299,7 +298,7 @@ def fromfd(
proto: int = 0,
) -> _SocketType:
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
- family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, int(fd))
+ family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd))
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
@@ -352,8 +351,7 @@ def socket(
if fileno is None:
sf = _socket_factory.get(None)
if sf is not None:
- # waiting on ._abc to get typed
- return sf.socket(family, type, proto) # type: ignore
+ return sf.socket(family, type, proto)
else:
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fileno)
stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno)
From 21422523ebebea23c455ede5d7f3c30a155ca0a9 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sun, 23 Jul 2023 23:56:24 +0200
Subject: [PATCH 09/16] use / for pos-only args
---
.coveragerc | 1 +
trio/_socket.py | 47 +++++++++++++++++++++++------------------------
2 files changed, 24 insertions(+), 24 deletions(-)
diff --git a/.coveragerc b/.coveragerc
index 98f923bd8e..d577aa8adf 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -21,6 +21,7 @@ exclude_lines =
abc.abstractmethod
if TYPE_CHECKING:
if _t.TYPE_CHECKING:
+ @overload
partial_branches =
pragma: no branch
diff --git a/trio/_socket.py b/trio/_socket.py
index f68338f966..e9fa8f3537 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -562,45 +562,42 @@ def getsockname(self) -> Any:
return self._sock.getsockname()
@overload
- def getsockopt(self, __level: int, __optname: int) -> int:
+ def getsockopt(self, /, level: int, optname: int) -> int:
...
@overload
- def getsockopt(self, __level: int, __optname: int, __buflen: int) -> bytes:
+ def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes:
...
def getsockopt(
- self, __level: int, __optname: int, __buflen: int | None = None
+ self, /, level: int, optname: int, buflen: int | None = None
) -> int | bytes:
- if __buflen is None:
- return self._sock.getsockopt(__level, __optname)
- return self._sock.getsockopt(__level, __optname, __buflen)
+ if buflen is None:
+ return self._sock.getsockopt(level, optname)
+ return self._sock.getsockopt(level, optname, buflen)
@overload
- def setsockopt(self, __level: int, __optname: int, __value: int | Buffer) -> None:
+ def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None:
...
@overload
- def setsockopt(
- self, __level: int, __optname: int, __value: None, __optlen: int
- ) -> None:
+ def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
...
def setsockopt(
self,
- __level: int,
- __optname: int,
- __value: int | Buffer | None,
- __optlen: int | None = None,
+ /,
+ level: int,
+ optname: int,
+ value: int | Buffer | None,
+ optlen: int | None = None,
) -> None:
- if __optlen is None:
- return self._sock.setsockopt(
- __level, __optname, cast("int|Buffer", __value)
- )
- return self._sock.setsockopt(__level, __optname, cast(None, __value), __optlen)
+ if optlen is None:
+ return self._sock.setsockopt(level, optname, cast("int|Buffer", value))
+ return self._sock.setsockopt(level, optname, cast(None, value), optlen)
- def listen(self, __backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
- return self._sock.listen(__backlog)
+ def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+ return self._sock.listen(backlog)
def get_inheritable(self) -> bool:
return self._sock.get_inheritable()
@@ -612,8 +609,8 @@ def set_inheritable(self, inheritable: bool) -> None:
not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
):
- def share(self, __process_id: int) -> bytes:
- return self._sock.share(__process_id)
+ def share(self, /, process_id: int) -> bytes:
+ return self._sock.share(process_id)
def __enter__(self) -> Self:
return self
@@ -856,7 +853,9 @@ async def connect(self, address: Address) -> None:
def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]:
...
- # _make_simple_sock_method_wrapper is typed, so this check that the above is correct
+ # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct
+ # this requires that we refrain from using `/` to specify pos-only
+ # args, or mypy thinks the signature differs from typeshed.
recv = _make_simple_sock_method_wrapper( # noqa: F811
_stdlib_socket.socket.recv, _core.wait_readable
)
From 831087fe7bb7ce80fe59a586afafbd7e655d0282 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 27 Jul 2023 14:27:11 +0200
Subject: [PATCH 10/16] add test to setsockopt, update comments
---
trio/_core/_local.py | 3 +--
trio/_socket.py | 15 +++++++++++----
trio/_sync.py | 5 +++++
trio/_tests/test_socket.py | 17 +++++++++++++++++
4 files changed, 34 insertions(+), 6 deletions(-)
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index b9dada64fe..c1e6fa01fe 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -8,8 +8,6 @@
from .._util import Final, NoPublicConstructor
from . import _run
-# `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released
-
T = TypeVar("T")
@@ -51,6 +49,7 @@ def get(self, default: T | type[_NoValue] = _NoValue) -> T:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
# contextvars consistency
+ # `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released
if default is not _NoValue:
return default # type: ignore[return-value]
diff --git a/trio/_socket.py b/trio/_socket.py
index e9fa8f3537..e17ae725a9 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -17,7 +17,6 @@
Tuple,
TypeVar,
Union,
- cast,
overload,
)
@@ -522,7 +521,7 @@ async def _resolve_address_nocp(
# TODO: stopping users from initializing this type should be done in a different way,
# so SocketType can be used as a type. Note that this is *far* from trivial without
-# breaking subclasses of SocketType. Should maybe just add abstract methods to SocketType,
+# breaking subclasses of SocketType. Can maybe add abstract methods to SocketType,
# or rename _SocketType.
class SocketType:
def __init__(self) -> NoReturn:
@@ -593,8 +592,16 @@ def setsockopt(
optlen: int | None = None,
) -> None:
if optlen is None:
- return self._sock.setsockopt(level, optname, cast("int|Buffer", value))
- return self._sock.setsockopt(level, optname, cast(None, value), optlen)
+ if value is None:
+ raise TypeError(
+ "invalid value for argument 'value', must not be None when specifying optlen"
+ )
+ return self._sock.setsockopt(level, optname, value)
+ if value is not None:
+ raise TypeError(
+ "invalid value for argument 'value': {value!r}, must be None when specifying optlen"
+ )
+ return self._sock.setsockopt(level, optname, value, optlen)
def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
return self._sock.listen(backlog)
diff --git a/trio/_sync.py b/trio/_sync.py
index 9764ddce2d..6e8ba13230 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -147,6 +147,11 @@ class CapacityLimiterStatistics:
tasks_waiting: int = attr.ib()
+# Can be a generic type with a default of Task if/when PEP 696 is released
+# and implemented in type checkers. Making it fully generic would currently
+# introduce a lot of unnecessary hassle.
+
+
class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
"""An object for controlling access to a resource with limited capacity.
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index e559b98240..ec6567500c 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -360,6 +360,23 @@ async def test_SocketType_basics():
sock.close()
+async def test_SocketType_setsockopt():
+ sock = tsocket.socket()
+ with sock as _:
+ # specifying optlen
+ sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0)
+ # specifying value
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
+
+ # specifying both
+ with pytest.raises(TypeError, match="invalid value for argument 'value'"):
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) # type: ignore[call-overload]
+
+ # specifying neither
+ with pytest.raises(TypeError, match="invalid value for argument 'value'"):
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) # type: ignore[call-overload]
+
+
async def test_SocketType_dup():
a, b = tsocket.socketpair()
with a, b:
From 37c7944a071a97d05bdfc101897c048d50939a79 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 27 Jul 2023 17:50:14 +0200
Subject: [PATCH 11/16] fix tests on non-linux
---
trio/_tests/test_socket.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index ec6567500c..40fabd73ac 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -363,8 +363,12 @@ async def test_SocketType_basics():
async def test_SocketType_setsockopt():
sock = tsocket.socket()
with sock as _:
- # specifying optlen
- sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0)
+ # no SO_BINDTODEVICE on other platforms. There's maybe other
+ # options that are if anybody wants to hunt through socket
+ # documentation on different platforms.
+ if sys.platform == "linux":
+ # specifying optlen
+ sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0)
# specifying value
sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
From 397252660732c6dba71280e861280b7dbab307da Mon Sep 17 00:00:00 2001
From: Spencer Brown
Date: Fri, 28 Jul 2023 13:27:07 +1000
Subject: [PATCH 12/16] Fix missing sys import in test_socket
---
trio/_tests/test_socket.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index 40fabd73ac..68c20c2ed5 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -2,7 +2,7 @@
import inspect
import os
import socket as stdlib_socket
-import sys as _sys
+import sys
import tempfile
import attr
@@ -277,7 +277,7 @@ async def test_socket_v6():
assert s.family == tsocket.AF_INET6
-@pytest.mark.skipif(not _sys.platform == "linux", reason="linux only")
+@pytest.mark.skipif(not sys.platform == "linux", reason="linux only")
async def test_sniff_sockopts():
from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM
From a06aee2bb1f485fb8204be6024e0b5b5e9263a62 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 13:16:13 +0200
Subject: [PATCH 13/16] somewhat hackish cross-platform test
---
trio/_tests/test_socket.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index 68c20c2ed5..99288823ac 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -363,12 +363,15 @@ async def test_SocketType_basics():
async def test_SocketType_setsockopt():
sock = tsocket.socket()
with sock as _:
- # no SO_BINDTODEVICE on other platforms. There's maybe other
- # options that are if anybody wants to hunt through socket
- # documentation on different platforms.
- if sys.platform == "linux":
- # specifying optlen
+ # specifying optlen
+ if hasattr(tsocket, "SO_BINDTODEVICE"):
sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0)
+ # I couldn't find valid calls using optlen on systems other than
+ # linux CPython, so we instead check that we get an
+ # 'Invalid argument' error from the underlying socket.socket
+ with pytest.raises(OSError, match="Invalid argument"):
+ sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None, 0)
+
# specifying value
sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
From 5abc8ce65f07a1cf01c7ae8aab067545aff71455 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 28 Jul 2023 13:32:50 +0200
Subject: [PATCH 14/16] skip hacky test
---
trio/_socket.py | 3 +++
trio/_tests/test_socket.py | 8 ++------
2 files changed, 5 insertions(+), 6 deletions(-)
diff --git a/trio/_socket.py b/trio/_socket.py
index e17ae725a9..e75b1a9c2f 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -601,6 +601,9 @@ def setsockopt(
raise TypeError(
"invalid value for argument 'value': {value!r}, must be None when specifying optlen"
)
+
+ # Note: PyPy may crash here due to setsockopt only supporting
+ # four parameters.
return self._sock.setsockopt(level, optname, value, optlen)
def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py
index 99288823ac..e9baff436a 100644
--- a/trio/_tests/test_socket.py
+++ b/trio/_tests/test_socket.py
@@ -363,14 +363,10 @@ async def test_SocketType_basics():
async def test_SocketType_setsockopt():
sock = tsocket.socket()
with sock as _:
- # specifying optlen
+ # specifying optlen. Not supported on pypy, and I couldn't find
+ # valid calls on darwin or win32.
if hasattr(tsocket, "SO_BINDTODEVICE"):
sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0)
- # I couldn't find valid calls using optlen on systems other than
- # linux CPython, so we instead check that we get an
- # 'Invalid argument' error from the underlying socket.socket
- with pytest.raises(OSError, match="Invalid argument"):
- sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None, 0)
# specifying value
sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
From 22fa8dffa916006c63d84b2216b7c56759cdca9e Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 29 Jul 2023 12:20:46 +0200
Subject: [PATCH 15/16] remove extraneous lines
---
trio/_sync.py | 2 --
1 file changed, 2 deletions(-)
diff --git a/trio/_sync.py b/trio/_sync.py
index 6e8ba13230..bd2122858e 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -150,8 +150,6 @@ class CapacityLimiterStatistics:
# Can be a generic type with a default of Task if/when PEP 696 is released
# and implemented in type checkers. Making it fully generic would currently
# introduce a lot of unnecessary hassle.
-
-
class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final):
"""An object for controlling access to a resource with limited capacity.
From 975da8a91cfe20be6729aa945be49dc6b92196dc Mon Sep 17 00:00:00 2001
From: John Litborn <11260241+jakkdl@users.noreply.github.com>
Date: Sat, 29 Jul 2023 12:23:04 +0200
Subject: [PATCH 16/16] Update trio/_core/_local.py
make `_NoValue` `Final`
Co-authored-by: EXPLOSION
---
trio/_core/_local.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/trio/_core/_local.py b/trio/_core/_local.py
index c1e6fa01fe..7f2c632153 100644
--- a/trio/_core/_local.py
+++ b/trio/_core/_local.py
@@ -12,7 +12,7 @@
@final
-class _NoValue:
+class _NoValue(metaclass=Final):
...