From 7e220dd6564576bad1bf3f7d8a4a4a49305a7be3 Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 19 Oct 2019 11:38:30 +0200 Subject: [PATCH 1/2] Move start_tls to stream & return a new stream --- httpx/concurrency/asyncio.py | 69 ++++++++++++++++-------------------- httpx/concurrency/base.py | 14 +++----- httpx/concurrency/trio.py | 44 +++++++++++------------ httpx/dispatch/proxy_http.py | 7 ++-- tests/dispatch/utils.py | 16 ++++----- tests/test_concurrency.py | 2 +- 6 files changed, 65 insertions(+), 87 deletions(-) diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index dfee425456..0a7d378b70 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -51,6 +51,37 @@ def __init__( self.stream_writer = stream_writer self.timeout = timeout + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> BaseTCPStream: + loop = asyncio.get_event_loop() + if not hasattr(loop, "start_tls"): # pragma: no cover + raise NotImplementedError( + "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+" + ) + + stream_reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(stream_reader) + transport = self.stream_writer.transport + + loop_start_tls = loop.start_tls # type: ignore + transport = await asyncio.wait_for( + loop_start_tls( + transport=transport, + protocol=protocol, + sslcontext=ssl_context, + server_hostname=hostname, + ), + timeout=timeout.connect_timeout, + ) + + stream_reader.set_transport(transport) + stream_writer = asyncio.StreamWriter( + transport=transport, protocol=protocol, reader=stream_reader, loop=loop + ) + + return TCPStream(stream_reader, stream_writer, self.timeout) + def get_http_version(self) -> str: ssl_object = self.stream_writer.get_extra_info("ssl_object") @@ -201,44 +232,6 @@ async def open_tcp_stream( stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout ) - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - - loop = self.loop - if not hasattr(loop, "start_tls"): # pragma: no cover - raise NotImplementedError( - "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+" - ) - - assert isinstance(stream, TCPStream) - - stream_reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(stream_reader) - transport = stream.stream_writer.transport - - loop_start_tls = loop.start_tls # type: ignore - transport = await asyncio.wait_for( - loop_start_tls( - transport=transport, - protocol=protocol, - sslcontext=ssl_context, - server_hostname=hostname, - ), - timeout=timeout.connect_timeout, - ) - - stream_reader.set_transport(transport) - stream.stream_reader = stream_reader - stream.stream_writer = asyncio.StreamWriter( - transport=transport, protocol=protocol, reader=stream_reader, loop=loop - ) - return stream - async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index fc784b30f8..a23d89bd30 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -47,6 +47,11 @@ class BaseTCPStream: def get_http_version(self) -> str: raise NotImplementedError() # pragma: no cover + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> "BaseTCPStream": + raise NotImplementedError() # pragma: no cover + async def read( self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None ) -> bytes: @@ -119,15 +124,6 @@ async def open_tcp_stream( ) -> BaseTCPStream: raise NotImplementedError() # pragma: no cover - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - raise NotImplementedError() # pragma: no cover - def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 3de3d14078..da8e38a0ef 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -34,6 +34,26 @@ def __init__( self.write_buffer = b"" self.write_lock = trio.Lock() + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> BaseTCPStream: + # Check that the write buffer is empty. We should never start a TLS stream + # while there is still pending data to write. + assert self.write_buffer == b"" + + connect_timeout = _or_inf(timeout.connect_timeout) + ssl_stream = trio.SSLStream( + self.stream, ssl_context=ssl_context, server_hostname=hostname + ) + + with trio.move_on_after(connect_timeout) as cancel_scope: + await ssl_stream.do_handshake() + + if cancel_scope.cancelled_caught: + raise ConnectTimeout() + + return TCPStream(ssl_stream, self.timeout) + def get_http_version(self) -> str: if not isinstance(self.stream, trio.SSLStream): return "HTTP/1.1" @@ -171,30 +191,6 @@ async def open_tcp_stream( return TCPStream(stream=stream, timeout=timeout) - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - assert isinstance(stream, TCPStream) - - connect_timeout = _or_inf(timeout.connect_timeout) - ssl_stream = trio.SSLStream( - stream.stream, ssl_context=ssl_context, server_hostname=hostname - ) - - with trio.move_on_after(connect_timeout) as cancel_scope: - await ssl_stream.do_handshake() - - if cancel_scope.cancelled_caught: - raise ConnectTimeout() - - stream.stream = ssl_stream - - return stream - async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/httpx/dispatch/proxy_http.py b/httpx/dispatch/proxy_http.py index be2e289ffc..e54a5b3073 100644 --- a/httpx/dispatch/proxy_http.py +++ b/httpx/dispatch/proxy_http.py @@ -192,11 +192,8 @@ async def tunnel_start_tls( f"proxy_url={self.proxy_url!r} " f"origin={origin!r}" ) - stream = await self.backend.start_tls( - stream=stream, - hostname=origin.host, - ssl_context=ssl_context, - timeout=timeout, + stream = await stream.start_tls( + hostname=origin.host, ssl_context=ssl_context, timeout=timeout ) http_version = stream.get_http_version() logger.debug( diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index 0798b31ff2..3b0d534000 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -184,16 +184,6 @@ async def open_tcp_stream( ) return self.stream - async def start_tls( - self, - stream: BaseTCPStream, - hostname: str, - ssl_context: ssl.SSLContext, - timeout: TimeoutConfig, - ) -> BaseTCPStream: - self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode()) - return self.stream - # Defer all other attributes and methods to the underlying backend. def __getattr__(self, name: str) -> typing.Any: return getattr(self.backend, name) @@ -203,6 +193,12 @@ class MockRawSocketStream(BaseTCPStream): def __init__(self, backend: MockRawSocketBackend): self.backend = backend + async def start_tls( + self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig + ) -> BaseTCPStream: + self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode()) + return MockRawSocketStream(self.backend) + def get_http_version(self) -> str: return "HTTP/1.1" diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 3f9e826235..27bbeaf280 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -45,7 +45,7 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher): assert stream.is_connection_dropped() is False assert get_cipher(stream) is None - stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout) + stream = await stream.start_tls(https_server.url.host, ctx, timeout) assert stream.is_connection_dropped() is False assert get_cipher(stream) is not None From 6384d5a24dacf317f9e775b53f2a1215fa589b4a Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 19 Oct 2019 11:47:07 +0200 Subject: [PATCH 2/2] asyncio: Keep a reference to the inner stream when upgrading to TLS --- httpx/concurrency/asyncio.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 0a7d378b70..7083345426 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -51,6 +51,8 @@ def __init__( self.stream_writer = stream_writer self.timeout = timeout + self._inner: typing.Optional[TCPStream] = None + async def start_tls( self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig ) -> BaseTCPStream: @@ -80,7 +82,12 @@ async def start_tls( transport=transport, protocol=protocol, reader=stream_reader, loop=loop ) - return TCPStream(stream_reader, stream_writer, self.timeout) + ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout) + # When we return a new TCPStream with new StreamReader/StreamWriter instances, + # we need to keep references to the old StreamReader/StreamWriter so that they + # are not garbage collected and closed while we're still using them. + ssl_stream._inner = self + return ssl_stream def get_http_version(self) -> str: ssl_object = self.stream_writer.get_extra_info("ssl_object")