Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
09c5761
Somehow this did not get in to edbf981
CoolCat467 Aug 6, 2023
d9a7e24
Add type annotations to `_ssl.py`
CoolCat467 Aug 7, 2023
abcb7cd
Fix for pre-3.10
CoolCat467 Aug 7, 2023
aaaea9b
Merge branch 'python-trio:master' into typing-ssl
CoolCat467 Aug 7, 2023
81d8d85
Fix CI issues
CoolCat467 Aug 7, 2023
067527a
Update `verify_types.json`
CoolCat467 Aug 7, 2023
985acca
Remove unused imports
CoolCat467 Aug 7, 2023
11842a9
Merge branch 'master' into typing-ssl
CoolCat467 Aug 8, 2023
edbf981
Change lots of `Any` to `object`
CoolCat467 Aug 6, 2023
4a92a2d
Update `verify_types.json`
CoolCat467 Aug 8, 2023
564bb92
Add `_ssl` to stricker check block and sort modules
CoolCat467 Aug 6, 2023
ce776ac
Add `trio._abc.T_resource` to doc ignore list
CoolCat467 Aug 8, 2023
c896899
Attempt to fix pyright issues
CoolCat467 Aug 10, 2023
93333bc
Update `verify_types.json`
CoolCat467 Aug 10, 2023
1be18a3
Merge branch 'master' into typing-ssl
CoolCat467 Aug 14, 2023
49c2c1b
Remove `__slots__`
CoolCat467 Aug 14, 2023
cbee73e
Update `verify_types.json`
CoolCat467 Aug 14, 2023
b8f18ee
Merge branch 'typing-ssl' of https://github.com/CoolCat467/trio into …
CoolCat467 Aug 16, 2023
acfcead
Change to `object` (@A5rocks suggestion)
CoolCat467 Aug 16, 2023
31b0ccf
Fix `server_hostname`
CoolCat467 Aug 16, 2023
40e9bf2
Update `verify_types.json`
CoolCat467 Aug 16, 2023
3406eaf
Add `_highlevel_ssl_helpers` to stricter checks block
CoolCat467 Aug 16, 2023
38c9186
Revert "Add `_highlevel_ssl_helpers` to stricter checks block"
CoolCat467 Aug 16, 2023
a583326
Revert changes to `_highlevel_ssl_helpers`
CoolCat467 Aug 16, 2023
bf88b26
Revert "Update `verify_types.json`"
CoolCat467 Aug 16, 2023
beaf82d
Merge branch 'master' into typing-ssl
CoolCat467 Aug 20, 2023
b6e7303
Update `verify_types.json` and fix line endings
CoolCat467 Aug 20, 2023
cf7a777
Remove slots from exception
CoolCat467 Aug 20, 2023
e791683
Remove `_ssl` from not fully typed block
CoolCat467 Aug 20, 2023
3020001
Merge remote-tracking branch 'origin/master' into typing-ssl
jakkdl Aug 21, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ disallow_untyped_calls = false
# files not yet fully typed
[[tool.mypy.overrides]]
module = [
# 2745
"trio/_ssl",
# 2761
"trio/_core/_generated_io_windows",
"trio/_core/_io_windows",
Expand Down
2 changes: 1 addition & 1 deletion trio/_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ class Listener(AsyncResource, Generic[T_resource]):
__slots__ = ()

@abstractmethod
async def accept(self) -> AsyncResource:
async def accept(self) -> T_resource:
"""Wait until an incoming connection arrives, and then return it.

Returns:
Expand Down
111 changes: 63 additions & 48 deletions trio/_ssl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
from __future__ import annotations

import operator as _operator
import ssl as _stdlib_ssl
from collections.abc import Awaitable, Callable
from enum import Enum as _Enum
from typing import Any, Final as TFinal, TypeVar

import trio

from . import _sync
from ._highlevel_generic import aclose_forcefully
from ._util import ConflictDetector, Final
from .abc import Listener, Stream

# General theory of operation:
#
# We implement an API that closely mirrors the stdlib ssl module's blocking
Expand Down Expand Up @@ -149,16 +164,8 @@
# docs will need to make very clear that this is different from all the other
# cancellations in core Trio

import operator as _operator
import ssl as _stdlib_ssl
from enum import Enum as _Enum

import trio

from . import _sync
from ._highlevel_generic import aclose_forcefully
from ._util import ConflictDetector, Final
from .abc import Listener, Stream
T = TypeVar("T")

################################################################
# SSLStream
Expand Down Expand Up @@ -187,16 +194,16 @@
# MTU and an initial window of 10 (see RFC 6928), then the initial burst of
# data will be limited to ~15000 bytes (or a bit less due to IP-level framing
# overhead), so this is chosen to be larger than that.
STARTING_RECEIVE_SIZE = 16384
STARTING_RECEIVE_SIZE: TFinal = 16384


def _is_eof(exc):
def _is_eof(exc: BaseException | None) -> bool:
# There appears to be a bug on Python 3.10, where SSLErrors
# aren't properly translated into SSLEOFErrors.
# This stringly-typed error check is borrowed from the AnyIO
# project.
return isinstance(exc, _stdlib_ssl.SSLEOFError) or (
hasattr(exc, "strerror") and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
"UNEXPECTED_EOF_WHILE_READING" in getattr(exc, "strerror", ())
)


Expand All @@ -209,13 +216,13 @@ class NeedHandshakeError(Exception):


class _Once:
def __init__(self, afn, *args):
def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None:
self._afn = afn
self._args = args
self.started = False
self._done = _sync.Event()

async def ensure(self, *, checkpoint):
async def ensure(self, *, checkpoint: bool) -> None:
if not self.started:
self.started = True
await self._afn(*self._args)
Expand All @@ -226,8 +233,8 @@ async def ensure(self, *, checkpoint):
await self._done.wait()

@property
def done(self):
return self._done.is_set()
def done(self) -> bool:
return bool(self._done.is_set())


_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
Expand Down Expand Up @@ -257,8 +264,8 @@ class SSLStream(Stream, metaclass=Final):
this connection. Required. Usually created by calling
:func:`ssl.create_default_context`.

server_hostname (str or None): The name of the server being connected
to. Used for `SNI
server_hostname (str, bytes, or None): The name of the server being
connected to. Used for `SNI
<https://en.wikipedia.org/wiki/Server_Name_Indication>`__ and for
validating the server's certificate (if hostname checking is
enabled). This is effectively mandatory for clients, and actually
Expand Down Expand Up @@ -331,24 +338,24 @@ class SSLStream(Stream, metaclass=Final):
# SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers.
def __init__(
self,
transport_stream,
ssl_context,
transport_stream: Stream,
ssl_context: _stdlib_ssl.SSLContext,
*,
server_hostname=None,
server_side=False,
https_compatible=False,
):
self.transport_stream = transport_stream
server_hostname: str | bytes | None = None,
server_side: bool = False,
https_compatible: bool = False,
) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: inconsequential -> None. We ought to create a style guide somewhere for typing, and/or add a check to #2744 that removes/adds it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not inconsequential, I am fairly certain that verify_types.json changes for the better when it's there

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to #2740 (comment) ?
I removed the -> None without any impact on verify_types.json

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just bumping on this!

self.transport_stream: Stream = transport_stream
self._state = _State.OK
self._https_compatible = https_compatible
self._outgoing = _stdlib_ssl.MemoryBIO()
self._delayed_outgoing = None
self._delayed_outgoing: bytes | None = None
self._incoming = _stdlib_ssl.MemoryBIO()
self._ssl_object = ssl_context.wrap_bio(
self._incoming,
self._outgoing,
server_side=server_side,
server_hostname=server_hostname,
server_hostname=server_hostname, # type: ignore[arg-type] # Typeshed bug, does accept bytes as well (typeshed#10590)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like that PR was merged, meaning next mypy release should support this! Nice!

)
# Tracks whether we've already done the initial handshake
self._handshook = _Once(self._do_handshake)
Expand Down Expand Up @@ -399,7 +406,7 @@ def __init__(
"version",
}

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose there's a case for returning object here to tell the user they need to do isinstance. But they can also enable disallow_any_expression I suppose.
I tried finding an authoritative best-practice of what to do in this case, but no dice

if name in self._forwarded:
if name in self._after_handshake and not self._handshook.done:
raise NeedHandshakeError(f"call do_handshake() before calling {name!r}")
Expand All @@ -408,16 +415,16 @@ def __getattr__(self, name):
else:
raise AttributeError(name)

def __setattr__(self, name, value):
def __setattr__(self, name: str, value: object) -> None:
if name in self._forwarded:
setattr(self._ssl_object, name, value)
else:
super().__setattr__(name, value)

def __dir__(self):
return super().__dir__() + list(self._forwarded)
def __dir__(self) -> list[str]:
return list(super().__dir__()) + list(self._forwarded)

def _check_status(self):
def _check_status(self) -> None:
if self._state is _State.OK:
return
elif self._state is _State.BROKEN:
Expand All @@ -431,7 +438,13 @@ def _check_status(self):
# comments, though, just make sure to think carefully if you ever have to
# touch it. The big comment at the top of this file will help explain
# too.
async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False):
async def _retry(
self,
fn: Callable[..., T],
*args: object,
ignore_want_read: bool = False,
is_handshake: bool = False,
) -> T | None:
await trio.lowlevel.checkpoint_if_cancelled()
yielded = False
finished = False
Expand Down Expand Up @@ -603,14 +616,14 @@ async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False):
await trio.lowlevel.cancel_shielded_checkpoint()
return ret

async def _do_handshake(self):
async def _do_handshake(self) -> None:
try:
await self._retry(self._ssl_object.do_handshake, is_handshake=True)
except:
self._state = _State.BROKEN
raise

async def do_handshake(self):
async def do_handshake(self) -> None:
"""Ensure that the initial handshake has completed.

The SSL protocol requires an initial handshake to exchange
Expand Down Expand Up @@ -645,7 +658,7 @@ async def do_handshake(self):
# https://bugs.python.org/issue30141
# So we *definitely* have to make sure that do_handshake is called
# before doing anything else.
async def receive_some(self, max_bytes=None):
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
"""Read some data from the underlying transport, decrypt it, and
return it.

Expand Down Expand Up @@ -684,7 +697,9 @@ async def receive_some(self, max_bytes=None):
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
try:
return await self._retry(self._ssl_object.read, max_bytes)
received = await self._retry(self._ssl_object.read, max_bytes)
assert received is not None
return received
except trio.BrokenResourceError as exc:
# This isn't quite equivalent to just returning b"" in the
# first place, because we still end up with self._state set to
Expand All @@ -698,7 +713,7 @@ async def receive_some(self, max_bytes=None):
else:
raise

async def send_all(self, data):
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Encrypt some data and then send it on the underlying transport.

See :meth:`trio.abc.SendStream.send_all` for details.
Expand All @@ -719,7 +734,7 @@ async def send_all(self, data):
return
await self._retry(self._ssl_object.write, data)

async def unwrap(self):
async def unwrap(self) -> tuple[Stream, bytes | bytearray]:
"""Cleanly close down the SSL/TLS encryption layer, allowing the
underlying stream to be used for unencrypted communication.

Expand All @@ -741,11 +756,11 @@ async def unwrap(self):
await self._handshook.ensure(checkpoint=False)
await self._retry(self._ssl_object.unwrap)
transport_stream = self.transport_stream
self.transport_stream = None
self._state = _State.CLOSED
self.transport_stream = None # type: ignore[assignment] # State is CLOSED now, nothing should use
return (transport_stream, self._incoming.read())

async def aclose(self):
async def aclose(self) -> None:
"""Gracefully shut down this connection, and close the underlying
transport.

Expand Down Expand Up @@ -832,7 +847,7 @@ async def aclose(self):
finally:
self._state = _State.CLOSED

async def wait_send_all_might_not_block(self):
async def wait_send_all_might_not_block(self) -> None:
"""See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`."""
# This method's implementation is deceptively simple.
#
Expand Down Expand Up @@ -897,16 +912,16 @@ class SSLListener(Listener[SSLStream], metaclass=Final):

def __init__(
self,
transport_listener,
ssl_context,
transport_listener: Listener[Stream],
ssl_context: _stdlib_ssl.SSLContext,
*,
https_compatible=False,
):
https_compatible: bool = False,
) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW this -> None is probably also unnecessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I haven't seen any reason why it would change anything.

self.transport_listener = transport_listener
self._ssl_context = ssl_context
self._https_compatible = https_compatible

async def accept(self):
async def accept(self) -> SSLStream:
"""Accept the next connection and wrap it in an :class:`SSLStream`.

See :meth:`trio.abc.Listener.accept` for details.
Expand All @@ -920,6 +935,6 @@ async def accept(self):
https_compatible=self._https_compatible,
)

async def aclose(self):
async def aclose(self) -> None:
"""Close the transport listener."""
await self.transport_listener.aclose()
7 changes: 4 additions & 3 deletions trio/_tests/check_type_completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,10 @@ def main(args: argparse.Namespace) -> int:
invert=invert,
)

assert (
res.returncode != 0
), "Fully type complete! Delete this script and instead directly run `pyright --verifytypes=trio` (consider `--ignoreexternal`) in CI and checking exit code."
# handle in separate PR
# assert (
# res.returncode != 0
# ), "Fully type complete! Delete this script and instead directly run `pyright --verifytypes=trio` (consider `--ignoreexternal`) in CI and checking exit code."

if args.overwrite_file:
print("Overwriting file")
Expand Down
26 changes: 6 additions & 20 deletions trio/_tests/verify_types.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
"completenessScore": 0.9968152866242038,
"completenessScore": 1,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
"withKnownType": 626,
"withUnknownType": 2
"withKnownType": 628,
"withUnknownType": 0
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
Expand Down Expand Up @@ -45,26 +45,12 @@
}
],
"otherSymbolCounts": {
"withAmbiguousType": 1,
"withKnownType": 666,
"withUnknownType": 15
"withAmbiguousType": 0,
"withKnownType": 682,
"withUnknownType": 0
},
"packageName": "trio",
"symbols": [
"trio._ssl.SSLListener.__init__",
"trio._ssl.SSLListener.accept",
"trio._ssl.SSLListener.aclose",
"trio._ssl.SSLStream.__dir__",
"trio._ssl.SSLStream.__getattr__",
"trio._ssl.SSLStream.__init__",
"trio._ssl.SSLStream.__setattr__",
"trio._ssl.SSLStream.aclose",
"trio._ssl.SSLStream.do_handshake",
"trio._ssl.SSLStream.receive_some",
"trio._ssl.SSLStream.send_all",
"trio._ssl.SSLStream.transport_stream",
"trio._ssl.SSLStream.unwrap",
"trio._ssl.SSLStream.wait_send_all_might_not_block",
"trio.lowlevel.notify_closing",
"trio.lowlevel.wait_readable",
"trio.lowlevel.wait_writable",
Expand Down