-
-
Notifications
You must be signed in to change notification settings - Fork 392
Add type annotations for _ssl.py
#2745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
09c5761
d9a7e24
abcb7cd
aaaea9b
81d8d85
067527a
985acca
11842a9
edbf981
4a92a2d
564bb92
ce776ac
c896899
93333bc
1be18a3
49c2c1b
cbee73e
b8f18ee
acfcead
31b0ccf
40e9bf2
3406eaf
38c9186
a583326
bf88b26
beaf82d
b6e7303
cf7a777
e791683
3020001
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,18 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import operator as _operator | ||
| import ssl as _stdlib_ssl | ||
| from collections.abc import Awaitable, Callable | ||
| from enum import Enum as _Enum | ||
| from typing import Any, Final as TFinal, TypeVar | ||
|
|
||
| import trio | ||
|
|
||
| from . import _sync | ||
| from ._highlevel_generic import aclose_forcefully | ||
| from ._util import ConflictDetector, Final | ||
| from .abc import Listener, Stream | ||
|
|
||
| # General theory of operation: | ||
| # | ||
| # We implement an API that closely mirrors the stdlib ssl module's blocking | ||
|
|
@@ -149,16 +164,8 @@ | |
| # docs will need to make very clear that this is different from all the other | ||
| # cancellations in core Trio | ||
|
|
||
| import operator as _operator | ||
| import ssl as _stdlib_ssl | ||
| from enum import Enum as _Enum | ||
|
|
||
| import trio | ||
|
|
||
| from . import _sync | ||
| from ._highlevel_generic import aclose_forcefully | ||
| from ._util import ConflictDetector, Final | ||
| from .abc import Listener, Stream | ||
| T = TypeVar("T") | ||
|
|
||
| ################################################################ | ||
| # SSLStream | ||
|
|
@@ -187,16 +194,16 @@ | |
| # MTU and an initial window of 10 (see RFC 6928), then the initial burst of | ||
| # data will be limited to ~15000 bytes (or a bit less due to IP-level framing | ||
| # overhead), so this is chosen to be larger than that. | ||
| STARTING_RECEIVE_SIZE = 16384 | ||
| STARTING_RECEIVE_SIZE: TFinal = 16384 | ||
|
|
||
|
|
||
| def _is_eof(exc): | ||
| def _is_eof(exc: BaseException | None) -> bool: | ||
| # There appears to be a bug on Python 3.10, where SSLErrors | ||
| # aren't properly translated into SSLEOFErrors. | ||
| # This stringly-typed error check is borrowed from the AnyIO | ||
| # project. | ||
| return isinstance(exc, _stdlib_ssl.SSLEOFError) or ( | ||
| hasattr(exc, "strerror") and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror | ||
| "UNEXPECTED_EOF_WHILE_READING" in getattr(exc, "strerror", ()) | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -209,13 +216,13 @@ class NeedHandshakeError(Exception): | |
|
|
||
|
|
||
| class _Once: | ||
| def __init__(self, afn, *args): | ||
| def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None: | ||
| self._afn = afn | ||
| self._args = args | ||
| self.started = False | ||
| self._done = _sync.Event() | ||
|
|
||
| async def ensure(self, *, checkpoint): | ||
| async def ensure(self, *, checkpoint: bool) -> None: | ||
| if not self.started: | ||
| self.started = True | ||
| await self._afn(*self._args) | ||
|
|
@@ -226,8 +233,8 @@ async def ensure(self, *, checkpoint): | |
| await self._done.wait() | ||
|
|
||
| @property | ||
| def done(self): | ||
| return self._done.is_set() | ||
| def done(self) -> bool: | ||
| return bool(self._done.is_set()) | ||
|
|
||
|
|
||
| _State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) | ||
|
|
@@ -257,8 +264,8 @@ class SSLStream(Stream, metaclass=Final): | |
| this connection. Required. Usually created by calling | ||
| :func:`ssl.create_default_context`. | ||
|
|
||
| server_hostname (str or None): The name of the server being connected | ||
| to. Used for `SNI | ||
| server_hostname (str, bytes, or None): The name of the server being | ||
| connected to. Used for `SNI | ||
| <https://en.wikipedia.org/wiki/Server_Name_Indication>`__ and for | ||
| validating the server's certificate (if hostname checking is | ||
| enabled). This is effectively mandatory for clients, and actually | ||
|
|
@@ -331,24 +338,24 @@ class SSLStream(Stream, metaclass=Final): | |
| # SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers. | ||
| def __init__( | ||
| self, | ||
| transport_stream, | ||
| ssl_context, | ||
| transport_stream: Stream, | ||
| ssl_context: _stdlib_ssl.SSLContext, | ||
| *, | ||
| server_hostname=None, | ||
| server_side=False, | ||
| https_compatible=False, | ||
| ): | ||
| self.transport_stream = transport_stream | ||
| server_hostname: str | bytes | None = None, | ||
| server_side: bool = False, | ||
| https_compatible: bool = False, | ||
| ) -> None: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: inconsequential
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is not inconsequential, I am fairly certain that
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you referring to #2740 (comment) ?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just bumping on this! |
||
| self.transport_stream: Stream = transport_stream | ||
| self._state = _State.OK | ||
| self._https_compatible = https_compatible | ||
| self._outgoing = _stdlib_ssl.MemoryBIO() | ||
| self._delayed_outgoing = None | ||
| self._delayed_outgoing: bytes | None = None | ||
| self._incoming = _stdlib_ssl.MemoryBIO() | ||
| self._ssl_object = ssl_context.wrap_bio( | ||
| self._incoming, | ||
| self._outgoing, | ||
| server_side=server_side, | ||
| server_hostname=server_hostname, | ||
| server_hostname=server_hostname, # type: ignore[arg-type] # Typeshed bug, does accept bytes as well (typeshed#10590) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like that PR was merged, meaning next mypy release should support this! Nice! |
||
| ) | ||
| # Tracks whether we've already done the initial handshake | ||
| self._handshook = _Once(self._do_handshake) | ||
|
|
@@ -399,7 +406,7 @@ def __init__( | |
| "version", | ||
| } | ||
|
|
||
| def __getattr__(self, name): | ||
| def __getattr__(self, name: str) -> Any: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose there's a case for returning |
||
| if name in self._forwarded: | ||
| if name in self._after_handshake and not self._handshook.done: | ||
| raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") | ||
|
|
@@ -408,16 +415,16 @@ def __getattr__(self, name): | |
| else: | ||
| raise AttributeError(name) | ||
|
|
||
| def __setattr__(self, name, value): | ||
| def __setattr__(self, name: str, value: object) -> None: | ||
| if name in self._forwarded: | ||
| setattr(self._ssl_object, name, value) | ||
| else: | ||
| super().__setattr__(name, value) | ||
|
|
||
| def __dir__(self): | ||
| return super().__dir__() + list(self._forwarded) | ||
| def __dir__(self) -> list[str]: | ||
| return list(super().__dir__()) + list(self._forwarded) | ||
|
|
||
| def _check_status(self): | ||
| def _check_status(self) -> None: | ||
| if self._state is _State.OK: | ||
| return | ||
| elif self._state is _State.BROKEN: | ||
|
|
@@ -431,7 +438,13 @@ def _check_status(self): | |
| # comments, though, just make sure to think carefully if you ever have to | ||
| # touch it. The big comment at the top of this file will help explain | ||
| # too. | ||
| async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): | ||
| async def _retry( | ||
| self, | ||
| fn: Callable[..., T], | ||
| *args: object, | ||
| ignore_want_read: bool = False, | ||
| is_handshake: bool = False, | ||
| ) -> T | None: | ||
| await trio.lowlevel.checkpoint_if_cancelled() | ||
| yielded = False | ||
| finished = False | ||
|
|
@@ -603,14 +616,14 @@ async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): | |
| await trio.lowlevel.cancel_shielded_checkpoint() | ||
| return ret | ||
|
|
||
| async def _do_handshake(self): | ||
| async def _do_handshake(self) -> None: | ||
| try: | ||
| await self._retry(self._ssl_object.do_handshake, is_handshake=True) | ||
| except: | ||
| self._state = _State.BROKEN | ||
| raise | ||
|
|
||
| async def do_handshake(self): | ||
| async def do_handshake(self) -> None: | ||
| """Ensure that the initial handshake has completed. | ||
|
|
||
| The SSL protocol requires an initial handshake to exchange | ||
|
|
@@ -645,7 +658,7 @@ async def do_handshake(self): | |
| # https://bugs.python.org/issue30141 | ||
| # So we *definitely* have to make sure that do_handshake is called | ||
| # before doing anything else. | ||
| async def receive_some(self, max_bytes=None): | ||
| async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: | ||
| """Read some data from the underlying transport, decrypt it, and | ||
| return it. | ||
|
|
||
|
|
@@ -684,7 +697,9 @@ async def receive_some(self, max_bytes=None): | |
| if max_bytes < 1: | ||
| raise ValueError("max_bytes must be >= 1") | ||
| try: | ||
| return await self._retry(self._ssl_object.read, max_bytes) | ||
| received = await self._retry(self._ssl_object.read, max_bytes) | ||
| assert received is not None | ||
| return received | ||
| except trio.BrokenResourceError as exc: | ||
| # This isn't quite equivalent to just returning b"" in the | ||
| # first place, because we still end up with self._state set to | ||
|
|
@@ -698,7 +713,7 @@ async def receive_some(self, max_bytes=None): | |
| else: | ||
| raise | ||
|
|
||
| async def send_all(self, data): | ||
| async def send_all(self, data: bytes | bytearray | memoryview) -> None: | ||
| """Encrypt some data and then send it on the underlying transport. | ||
|
|
||
| See :meth:`trio.abc.SendStream.send_all` for details. | ||
|
|
@@ -719,7 +734,7 @@ async def send_all(self, data): | |
| return | ||
| await self._retry(self._ssl_object.write, data) | ||
|
|
||
| async def unwrap(self): | ||
| async def unwrap(self) -> tuple[Stream, bytes | bytearray]: | ||
| """Cleanly close down the SSL/TLS encryption layer, allowing the | ||
| underlying stream to be used for unencrypted communication. | ||
|
|
||
|
|
@@ -741,11 +756,11 @@ async def unwrap(self): | |
| await self._handshook.ensure(checkpoint=False) | ||
| await self._retry(self._ssl_object.unwrap) | ||
| transport_stream = self.transport_stream | ||
| self.transport_stream = None | ||
| self._state = _State.CLOSED | ||
| self.transport_stream = None # type: ignore[assignment] # State is CLOSED now, nothing should use | ||
| return (transport_stream, self._incoming.read()) | ||
|
|
||
| async def aclose(self): | ||
| async def aclose(self) -> None: | ||
| """Gracefully shut down this connection, and close the underlying | ||
| transport. | ||
|
|
||
|
|
@@ -832,7 +847,7 @@ async def aclose(self): | |
| finally: | ||
| self._state = _State.CLOSED | ||
|
|
||
| async def wait_send_all_might_not_block(self): | ||
| async def wait_send_all_might_not_block(self) -> None: | ||
| """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.""" | ||
| # This method's implementation is deceptively simple. | ||
| # | ||
|
|
@@ -897,16 +912,16 @@ class SSLListener(Listener[SSLStream], metaclass=Final): | |
|
|
||
| def __init__( | ||
| self, | ||
| transport_listener, | ||
| ssl_context, | ||
| transport_listener: Listener[Stream], | ||
| ssl_context: _stdlib_ssl.SSLContext, | ||
| *, | ||
| https_compatible=False, | ||
| ): | ||
| https_compatible: bool = False, | ||
| ) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FWIW this
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I haven't seen any reason why it would change anything. |
||
| self.transport_listener = transport_listener | ||
| self._ssl_context = ssl_context | ||
| self._https_compatible = https_compatible | ||
|
|
||
| async def accept(self): | ||
| async def accept(self) -> SSLStream: | ||
| """Accept the next connection and wrap it in an :class:`SSLStream`. | ||
|
|
||
| See :meth:`trio.abc.Listener.accept` for details. | ||
|
|
@@ -920,6 +935,6 @@ async def accept(self): | |
| https_compatible=self._https_compatible, | ||
| ) | ||
|
|
||
| async def aclose(self): | ||
| async def aclose(self) -> None: | ||
| """Close the transport listener.""" | ||
| await self.transport_listener.aclose() | ||
Uh oh!
There was an error while loading. Please reload this page.