From 829a6a6c24c6293ca31066dc0f6d005ea76385a9 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Fri, 6 Dec 2019 10:04:39 +0100 Subject: [PATCH 1/2] Drop BackgroundManager in favor of fork(func1, func2) --- httpx/__init__.py | 8 +----- httpx/concurrency/asyncio.py | 49 ++++++++++++++++-------------------- httpx/concurrency/auto.py | 6 ----- httpx/concurrency/base.py | 36 +++++++++++--------------- httpx/concurrency/trio.py | 48 ++++++++++++++--------------------- httpx/dispatch/http2.py | 18 +++++++++++-- tests/test_concurrency.py | 48 ++++++++++++++++++++++++++++++++++- 7 files changed, 118 insertions(+), 95 deletions(-) diff --git a/httpx/__init__.py b/httpx/__init__.py index 686359a7d3..b6cd6df2b7 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -3,12 +3,7 @@ from .auth import BasicAuth, DigestAuth from .client import Client from .concurrency.asyncio import AsyncioBackend -from .concurrency.base import ( - BaseBackgroundManager, - BasePoolSemaphore, - BaseSocketStream, - ConcurrencyBackend, -) +from .concurrency.base import BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend from .config import ( USER_AGENT, CertTypes, @@ -89,7 +84,6 @@ "VerifyTypes", "HTTPConnection", "BasePoolSemaphore", - "BaseBackgroundManager", "ConnectionPool", "HTTPProxy", "HTTPProxyMode", diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index a75971620b..41019d15b2 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -3,12 +3,10 @@ import ssl import sys import typing -from types import TracebackType from ..config import PoolLimits, Timeout from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .base import ( - BaseBackgroundManager, BaseEvent, BasePoolSemaphore, BaseSocketStream, @@ -317,34 +315,29 @@ def run( finally: self._loop = loop + async def fork( + self, + coroutine1: typing.Callable, + args1: typing.Sequence, + coroutine2: typing.Callable, + args2: typing.Sequence, + ) -> None: + task1 = self.loop.create_task(coroutine1(*args1)) + task2 = self.loop.create_task(coroutine2(*args2)) + + try: + await asyncio.gather(task1, task2) + finally: + _, pending = await asyncio.wait({task1, task2}, timeout=0) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) def create_event(self) -> BaseEvent: return typing.cast(BaseEvent, asyncio.Event()) - - def background_manager( - self, coroutine: typing.Callable, *args: typing.Any - ) -> "BackgroundManager": - return BackgroundManager(coroutine, args) - - -class BackgroundManager(BaseBackgroundManager): - def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None: - self.coroutine = coroutine - self.args = args - - async def __aenter__(self) -> "BackgroundManager": - loop = asyncio.get_event_loop() - self.task = loop.create_task(self.coroutine(*self.args)) - return self - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.task - if exc_type is None: - self.task.result() diff --git a/httpx/concurrency/auto.py b/httpx/concurrency/auto.py index 3dd31a8bbf..3b57e5674d 100644 --- a/httpx/concurrency/auto.py +++ b/httpx/concurrency/auto.py @@ -5,7 +5,6 @@ from ..config import PoolLimits, Timeout from .base import ( - BaseBackgroundManager, BaseEvent, BasePoolSemaphore, BaseSocketStream, @@ -52,8 +51,3 @@ async def run_in_threadpool( def create_event(self) -> BaseEvent: return self.backend.create_event() - - def background_manager( - self, coroutine: typing.Callable, *args: typing.Any - ) -> BaseBackgroundManager: - return self.backend.background_manager(coroutine, *args) diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index f32501a3a0..ff5f72f30d 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -1,6 +1,5 @@ import ssl import typing -from types import TracebackType from ..config import PoolLimits, Timeout @@ -154,27 +153,20 @@ def run( def create_event(self) -> BaseEvent: raise NotImplementedError() # pragma: no cover - def background_manager( - self, coroutine: typing.Callable, *args: typing.Any - ) -> "BaseBackgroundManager": - raise NotImplementedError() # pragma: no cover - - -class BaseBackgroundManager: - async def __aenter__(self) -> "BaseBackgroundManager": - raise NotImplementedError() # pragma: no cover - - async def __aexit__( + async def fork( self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, + coroutine1: typing.Callable, + args1: typing.Sequence, + coroutine2: typing.Callable, + args2: typing.Sequence, ) -> None: - raise NotImplementedError() # pragma: no cover + """ + Run two coroutines concurrently. + + This should start 'coroutine1' with '*args1' and 'coroutine2' with '*args2', + and wait for them to finish. - async def close(self, exception: BaseException = None) -> None: - if exception is None: - await self.__aexit__(None, None, None) - else: - traceback = exception.__traceback__ # type: ignore - await self.__aexit__(type(exception), exception, traceback) + In case one of the coroutines raises an exception, cancel the other one then + raise. If the other coroutine had also raised an exception, ignore it (for now). + """ + raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 4af4242e19..f1fc7c4286 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -1,14 +1,12 @@ import functools import ssl import typing -from types import TracebackType import trio from ..config import PoolLimits, Timeout from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .base import ( - BaseBackgroundManager, BaseEvent, BasePoolSemaphore, BaseSocketStream, @@ -204,17 +202,31 @@ def run( functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args ) + async def fork( + self, + coroutine1: typing.Callable, + args1: typing.Sequence, + coroutine2: typing.Callable, + args2: typing.Sequence, + ) -> None: + try: + async with trio.open_nursery() as nursery: + nursery.start_soon(coroutine1, *args1) + nursery.start_soon(coroutine2, *args2) + except trio.MultiError as exc: + # NOTE: asyncio doesn't handle multi-errors yet, so we must align on its + # behavior here, and need to arbitrarily decide which exception to raise. + # We may want to add an 'httpx.MultiError', manually add support + # for this situation in the asyncio backend, and re-raise + # an 'httpx.MultiError' from trio's here. + raise exc.exceptions[0] + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) def create_event(self) -> BaseEvent: return Event() - def background_manager( - self, coroutine: typing.Callable, *args: typing.Any - ) -> "BackgroundManager": - return BackgroundManager(coroutine, *args) - class Event(BaseEvent): def __init__(self) -> None: @@ -233,25 +245,3 @@ def clear(self) -> None: # trio.Event.clear() was deprecated in Trio 0.12. # https://github.com/python-trio/trio/issues/637 self._event = trio.Event() - - -class BackgroundManager(BaseBackgroundManager): - def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None: - self.coroutine = coroutine - self.args = args - self.nursery_manager = trio.open_nursery() - self.nursery: typing.Optional[trio.Nursery] = None - - async def __aenter__(self) -> "BackgroundManager": - self.nursery = await self.nursery_manager.__aenter__() - self.nursery.start_soon(self.coroutine, *self.args) - return self - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - assert self.nursery is not None - await self.nursery_manager.__aexit__(exc_type, exc_value, traceback) diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 09368c884e..471ba9b7e0 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -65,9 +65,23 @@ async def send(self, request: Request, timeout: Timeout = None) -> Response: self.timeout_flags[stream_id] = TimeoutFlag() self.window_update_received[stream_id] = self.backend.create_event() - task, args = self.send_request_data, [stream_id, request.stream(), timeout] - async with self.backend.background_manager(task, *args): + status_code: typing.Optional[int] = None + headers: typing.Optional[list] = None + + async def receive_response(stream_id: int, timeout: Timeout) -> None: + nonlocal status_code, headers status_code, headers = await self.receive_response(stream_id, timeout) + + await self.backend.fork( + self.send_request_data, + [stream_id, request.stream(), timeout], + receive_response, + [stream_id, timeout], + ) + + assert status_code is not None + assert headers is not None + content = self.body_iter(stream_id, timeout) on_close = functools.partial(self.response_closed, stream_id=stream_id) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 4fecaa6dcb..0898dd4bb7 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -3,7 +3,7 @@ from httpx import AsyncioBackend, SSLConfig, Timeout from httpx.concurrency.trio import TrioBackend -from tests.concurrency import run_concurrently +from tests.concurrency import run_concurrently, sleep def get_asyncio_cipher(stream): @@ -110,3 +110,49 @@ async def test_concurrent_read(server, backend): ) finally: await stream.close() + + +async def test_fork(backend): + ok_counter = 0 + + async def ok(delay: int) -> None: + nonlocal ok_counter + await sleep(backend, delay) + ok_counter += 1 + + async def fail(message: str, delay: int) -> None: + await sleep(backend, delay) + raise RuntimeError(message) + + await backend.fork(ok, [0], ok, [0]) + assert ok_counter == 2 + + with pytest.raises(RuntimeError, match="Oops"): + await backend.fork(ok, [0], fail, ["Oops", 0.01]) + + assert ok_counter == 3 + + with pytest.raises(RuntimeError, match="Oops"): + await backend.fork(ok, [0.01], fail, ["Oops", 0]) + + assert ok_counter == 3 + + with pytest.raises(RuntimeError, match="Oops"): + await backend.fork(fail, ["Oops", 0.01], ok, [0]) + + assert ok_counter == 4 + + with pytest.raises(RuntimeError, match="Oops"): + await backend.fork(fail, ["Oops", 0], ok, [0.01]) + + assert ok_counter == 4 + + with pytest.raises(RuntimeError, match="My bad"): + await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0.01]) + + with pytest.raises(RuntimeError, match="Oops"): + await backend.fork(fail, ["My bad", 0.01], fail, ["Oops", 0]) + + # No 'match', since we can't know which will be raised first. + with pytest.raises(RuntimeError): + await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0]) From 57b6a23f60539801e2b2011827287decb3c442b2 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Fri, 6 Dec 2019 10:21:57 +0100 Subject: [PATCH 2/2] Please mypy --- httpx/concurrency/asyncio.py | 1 + 1 file changed, 1 insertion(+) diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 41019d15b2..eee4ec32d5 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -328,6 +328,7 @@ async def fork( try: await asyncio.gather(task1, task2) finally: + pending: typing.Set[asyncio.Future[typing.Any]] # Please mypy. _, pending = await asyncio.wait({task1, task2}, timeout=0) for task in pending: task.cancel()