diff --git a/pyproject.toml b/pyproject.toml index 34f2f069b3..184b46056d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,8 +60,6 @@ disallow_untyped_calls = false # files not yet fully typed [[tool.mypy.overrides]] module = [ -# 2745 -"trio/_ssl", # 2761 "trio/_core/_generated_io_windows", "trio/_core/_io_windows", diff --git a/trio/_abc.py b/trio/_abc.py index 59454b794c..746360c8f8 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -565,7 +565,7 @@ class Listener(AsyncResource, Generic[T_resource]): __slots__ = () @abstractmethod - async def accept(self) -> AsyncResource: + async def accept(self) -> T_resource: """Wait until an incoming connection arrives, and then return it. Returns: diff --git a/trio/_ssl.py b/trio/_ssl.py index bd8b3b06b6..f0f01f7583 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -1,3 +1,18 @@ +from __future__ import annotations + +import operator as _operator +import ssl as _stdlib_ssl +from collections.abc import Awaitable, Callable +from enum import Enum as _Enum +from typing import Any, Final as TFinal, TypeVar + +import trio + +from . import _sync +from ._highlevel_generic import aclose_forcefully +from ._util import ConflictDetector, Final +from .abc import Listener, Stream + # General theory of operation: # # We implement an API that closely mirrors the stdlib ssl module's blocking @@ -149,16 +164,8 @@ # docs will need to make very clear that this is different from all the other # cancellations in core Trio -import operator as _operator -import ssl as _stdlib_ssl -from enum import Enum as _Enum -import trio - -from . import _sync -from ._highlevel_generic import aclose_forcefully -from ._util import ConflictDetector, Final -from .abc import Listener, Stream +T = TypeVar("T") ################################################################ # SSLStream @@ -187,16 +194,16 @@ # MTU and an initial window of 10 (see RFC 6928), then the initial burst of # data will be limited to ~15000 bytes (or a bit less due to IP-level framing # overhead), so this is chosen to be larger than that. -STARTING_RECEIVE_SIZE = 16384 +STARTING_RECEIVE_SIZE: TFinal = 16384 -def _is_eof(exc): +def _is_eof(exc: BaseException | None) -> bool: # There appears to be a bug on Python 3.10, where SSLErrors # aren't properly translated into SSLEOFErrors. # This stringly-typed error check is borrowed from the AnyIO # project. return isinstance(exc, _stdlib_ssl.SSLEOFError) or ( - hasattr(exc, "strerror") and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror + "UNEXPECTED_EOF_WHILE_READING" in getattr(exc, "strerror", ()) ) @@ -209,13 +216,13 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn, *args): + def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None: self._afn = afn self._args = args self.started = False self._done = _sync.Event() - async def ensure(self, *, checkpoint): + async def ensure(self, *, checkpoint: bool) -> None: if not self.started: self.started = True await self._afn(*self._args) @@ -226,8 +233,8 @@ async def ensure(self, *, checkpoint): await self._done.wait() @property - def done(self): - return self._done.is_set() + def done(self) -> bool: + return bool(self._done.is_set()) _State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) @@ -257,8 +264,8 @@ class SSLStream(Stream, metaclass=Final): this connection. Required. Usually created by calling :func:`ssl.create_default_context`. - server_hostname (str or None): The name of the server being connected - to. Used for `SNI + server_hostname (str, bytes, or None): The name of the server being + connected to. Used for `SNI `__ and for validating the server's certificate (if hostname checking is enabled). This is effectively mandatory for clients, and actually @@ -331,24 +338,24 @@ class SSLStream(Stream, metaclass=Final): # SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers. def __init__( self, - transport_stream, - ssl_context, + transport_stream: Stream, + ssl_context: _stdlib_ssl.SSLContext, *, - server_hostname=None, - server_side=False, - https_compatible=False, - ): - self.transport_stream = transport_stream + server_hostname: str | bytes | None = None, + server_side: bool = False, + https_compatible: bool = False, + ) -> None: + self.transport_stream: Stream = transport_stream self._state = _State.OK self._https_compatible = https_compatible self._outgoing = _stdlib_ssl.MemoryBIO() - self._delayed_outgoing = None + self._delayed_outgoing: bytes | None = None self._incoming = _stdlib_ssl.MemoryBIO() self._ssl_object = ssl_context.wrap_bio( self._incoming, self._outgoing, server_side=server_side, - server_hostname=server_hostname, + server_hostname=server_hostname, # type: ignore[arg-type] # Typeshed bug, does accept bytes as well (typeshed#10590) ) # Tracks whether we've already done the initial handshake self._handshook = _Once(self._do_handshake) @@ -399,7 +406,7 @@ def __init__( "version", } - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in self._forwarded: if name in self._after_handshake and not self._handshook.done: raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") @@ -408,16 +415,16 @@ def __getattr__(self, name): else: raise AttributeError(name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: object) -> None: if name in self._forwarded: setattr(self._ssl_object, name, value) else: super().__setattr__(name, value) - def __dir__(self): - return super().__dir__() + list(self._forwarded) + def __dir__(self) -> list[str]: + return list(super().__dir__()) + list(self._forwarded) - def _check_status(self): + def _check_status(self) -> None: if self._state is _State.OK: return elif self._state is _State.BROKEN: @@ -431,7 +438,13 @@ def _check_status(self): # comments, though, just make sure to think carefully if you ever have to # touch it. The big comment at the top of this file will help explain # too. - async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): + async def _retry( + self, + fn: Callable[..., T], + *args: object, + ignore_want_read: bool = False, + is_handshake: bool = False, + ) -> T | None: await trio.lowlevel.checkpoint_if_cancelled() yielded = False finished = False @@ -603,14 +616,14 @@ async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): await trio.lowlevel.cancel_shielded_checkpoint() return ret - async def _do_handshake(self): + async def _do_handshake(self) -> None: try: await self._retry(self._ssl_object.do_handshake, is_handshake=True) except: self._state = _State.BROKEN raise - async def do_handshake(self): + async def do_handshake(self) -> None: """Ensure that the initial handshake has completed. The SSL protocol requires an initial handshake to exchange @@ -645,7 +658,7 @@ async def do_handshake(self): # https://bugs.python.org/issue30141 # So we *definitely* have to make sure that do_handshake is called # before doing anything else. - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """Read some data from the underlying transport, decrypt it, and return it. @@ -684,7 +697,9 @@ async def receive_some(self, max_bytes=None): if max_bytes < 1: raise ValueError("max_bytes must be >= 1") try: - return await self._retry(self._ssl_object.read, max_bytes) + received = await self._retry(self._ssl_object.read, max_bytes) + assert received is not None + return received except trio.BrokenResourceError as exc: # This isn't quite equivalent to just returning b"" in the # first place, because we still end up with self._state set to @@ -698,7 +713,7 @@ async def receive_some(self, max_bytes=None): else: raise - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Encrypt some data and then send it on the underlying transport. See :meth:`trio.abc.SendStream.send_all` for details. @@ -719,7 +734,7 @@ async def send_all(self, data): return await self._retry(self._ssl_object.write, data) - async def unwrap(self): + async def unwrap(self) -> tuple[Stream, bytes | bytearray]: """Cleanly close down the SSL/TLS encryption layer, allowing the underlying stream to be used for unencrypted communication. @@ -741,11 +756,11 @@ async def unwrap(self): await self._handshook.ensure(checkpoint=False) await self._retry(self._ssl_object.unwrap) transport_stream = self.transport_stream - self.transport_stream = None self._state = _State.CLOSED + self.transport_stream = None # type: ignore[assignment] # State is CLOSED now, nothing should use return (transport_stream, self._incoming.read()) - async def aclose(self): + async def aclose(self) -> None: """Gracefully shut down this connection, and close the underlying transport. @@ -832,7 +847,7 @@ async def aclose(self): finally: self._state = _State.CLOSED - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.""" # This method's implementation is deceptively simple. # @@ -897,16 +912,16 @@ class SSLListener(Listener[SSLStream], metaclass=Final): def __init__( self, - transport_listener, - ssl_context, + transport_listener: Listener[Stream], + ssl_context: _stdlib_ssl.SSLContext, *, - https_compatible=False, - ): + https_compatible: bool = False, + ) -> None: self.transport_listener = transport_listener self._ssl_context = ssl_context self._https_compatible = https_compatible - async def accept(self): + async def accept(self) -> SSLStream: """Accept the next connection and wrap it in an :class:`SSLStream`. See :meth:`trio.abc.Listener.accept` for details. @@ -920,6 +935,6 @@ async def accept(self): https_compatible=self._https_compatible, ) - async def aclose(self): + async def aclose(self) -> None: """Close the transport listener.""" await self.transport_listener.aclose() diff --git a/trio/_tests/check_type_completeness.py b/trio/_tests/check_type_completeness.py index abaabcf785..6a8761b88c 100755 --- a/trio/_tests/check_type_completeness.py +++ b/trio/_tests/check_type_completeness.py @@ -128,9 +128,10 @@ def main(args: argparse.Namespace) -> int: invert=invert, ) - assert ( - res.returncode != 0 - ), "Fully type complete! Delete this script and instead directly run `pyright --verifytypes=trio` (consider `--ignoreexternal`) in CI and checking exit code." + # handle in separate PR + # assert ( + # res.returncode != 0 + # ), "Fully type complete! Delete this script and instead directly run `pyright --verifytypes=trio` (consider `--ignoreexternal`) in CI and checking exit code." if args.overwrite_file: print("Overwriting file") diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index e8c405d2eb..2d93f0fb3f 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.9968152866242038, + "completenessScore": 1, "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 626, - "withUnknownType": 2 + "withKnownType": 628, + "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -45,26 +45,12 @@ } ], "otherSymbolCounts": { - "withAmbiguousType": 1, - "withKnownType": 666, - "withUnknownType": 15 + "withAmbiguousType": 0, + "withKnownType": 682, + "withUnknownType": 0 }, "packageName": "trio", "symbols": [ - "trio._ssl.SSLListener.__init__", - "trio._ssl.SSLListener.accept", - "trio._ssl.SSLListener.aclose", - "trio._ssl.SSLStream.__dir__", - "trio._ssl.SSLStream.__getattr__", - "trio._ssl.SSLStream.__init__", - "trio._ssl.SSLStream.__setattr__", - "trio._ssl.SSLStream.aclose", - "trio._ssl.SSLStream.do_handshake", - "trio._ssl.SSLStream.receive_some", - "trio._ssl.SSLStream.send_all", - "trio._ssl.SSLStream.transport_stream", - "trio._ssl.SSLStream.unwrap", - "trio._ssl.SSLStream.wait_send_all_might_not_block", "trio.lowlevel.notify_closing", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable",