From b6513801a72f01b15ce5ed86fb8700dfbd34f191 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Fri, 8 Nov 2019 21:05:19 +0100 Subject: [PATCH] Allow using Client in an async environment --- httpx/concurrency/asyncio.py | 151 ++++++++++++++++++++++++++++-- tests/client/test_client_async.py | 36 +++++++ 2 files changed, 181 insertions(+), 6 deletions(-) create mode 100644 tests/client/test_client_async.py diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 4aeb7ca53d..23d904cfd9 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -1,12 +1,16 @@ import asyncio +import concurrent.futures import functools import ssl import sys +import threading +import time import typing from types import TracebackType from ..config import PoolLimits, TimeoutConfig from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout +from ..utils import get_logger from .base import ( BaseBackgroundManager, BaseEvent, @@ -19,6 +23,8 @@ SSL_MONKEY_PATCH_APPLIED = False +logger = get_logger(__name__) + def ssl_monkey_patch() -> None: """ @@ -207,6 +213,8 @@ def release(self) -> None: class AsyncioBackend(ConcurrencyBackend): + worker_thread_id: typing.Optional[int] = None + def __init__(self) -> None: global SSL_MONKEY_PATCH_APPLIED @@ -220,9 +228,51 @@ def loop(self) -> asyncio.AbstractEventLoop: try: self._loop = asyncio.get_event_loop() except RuntimeError: - self._loop = asyncio.new_event_loop() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._loop = loop + self._loop_thread_id = threading.get_ident() + return self._loop + @property + def worker_executor(self) -> concurrent.futures.ThreadPoolExecutor: + """ + A thread executor where coroutines should be run if + the event loop of the main thread is already running. + + It is important to have only one worker thread, because sync HTTP calls + will create tasks that manipulate I/O resources (namely iterate request body, + get response, close response), and those tasks MUST be running in the + same thread. + """ + # NOTE: currently, the worker executor is never shutdown. It may block + # the interpreter from exiting if the worker thread is pending or blocked. + if not hasattr(self, "_worker_executor"): + self._worker_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + return self._worker_executor + + @property + def worker_loop(self) -> asyncio.AbstractEventLoop: + """The event loop attached to the worker thread. + + It is important to keep a reference to the worker event loop, because sync + HTTP calls will create tasks that manipulate I/O resources (namely iterate + request body, get response, close response), and those tasks MUST be bound to + the same event loop. + """ + if threading.get_ident() == threading.main_thread().ident: + raise RuntimeError("Cannot only access worker loop from a sub-thread") + + if not hasattr(self, "_worker_loop"): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + self._worker_loop = loop + + return self._worker_loop + async def open_tcp_stream( self, hostname: str, @@ -253,13 +303,78 @@ async def run_in_threadpool( def run( self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: - loop = self.loop - if loop.is_running(): - self._loop = asyncio.new_event_loop() + initial_loop = self.loop + + target: typing.Callable + + if ( + initial_loop.is_running() + and threading.main_thread().ident == threading.get_ident() + ): + # The event loop is already running in this thread. + # We must run 'coroutine' on the worker loop in the worker thread. + + def target(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + def _target() -> typing.Any: + if self.worker_thread_id is None: + self.worker_thread_id = threading.get_ident() + + assert self.worker_thread_id is not None + assert threading.get_ident() == self.worker_thread_id + + self._loop = self.worker_loop + return self.loop.run_until_complete(coroutine(*args, **kwargs)) + + future = self.worker_executor.submit(_target) + while not future.done(): + time.sleep(1e-2) + return future.result() + + elif initial_loop.is_running(): + if self.worker_thread_id is None: + self.worker_thread_id = threading.get_ident() + + assert self.worker_thread_id is not None + assert threading.get_ident() == self.worker_thread_id + # The loop is running in a different thread (i.e. the main thread) + # Run 'coroutine' on the existing loop, but from this thread (i.e. + # the worker thread). + + def target(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + # 1) Create a future to hold the result. + call_result: concurrent.futures.Future[typing.Any] = ( + concurrent.futures.Future() + ) + # 2) Schedule the coroutine on the event loop of the main thread, + # so that it sets its result or exception onto 'call_result'. + self.loop.call_soon_threadsafe( + self.loop.create_task, + proxy_coroutine_to_future( + coroutine, + *args, + future=call_result, + exc_info=sys.exc_info(), + **kwargs, + ), + ) + # Wait for the coroutine to terminate, i.e. for the future to have + # a result set. + while not call_result.done(): + # This blocks the worker thread, which is fine because the + # event loop is still being driven in the main thread. + time.sleep(1e-2) + return call_result.result() + + else: + # Event loop is not running: we can run the coroutine on it directly. + + def target(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + return initial_loop.run_until_complete(coroutine(*args, **kwargs)) + try: - return self.loop.run_until_complete(coroutine(*args, **kwargs)) + return target(*args, **kwargs) finally: - self._loop = loop + self._loop = initial_loop def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) @@ -295,3 +410,27 @@ async def __aexit__( await self.task if exc_type is None: self.task.result() + + +async def proxy_coroutine_to_future( + coroutine: typing.Callable, + *args: typing.Any, + future: concurrent.futures.Future, + exc_info: tuple, + **kwargs: typing.Any, +) -> None: + """Run a coroutine, and set its return value or exception on 'future'.""" + try: + # If we have an exception, run the function inside the except + # block after raising it so that 'exc_info' is correctly populated. + if exc_info[1]: + try: + raise exc_info[1] + except BaseException: + result = await coroutine(*args, **kwargs) + else: + result = await coroutine(*args, **kwargs) + except Exception as e: + future.set_exception(e) + else: + future.set_result(result) diff --git a/tests/client/test_client_async.py b/tests/client/test_client_async.py new file mode 100644 index 0000000000..5d21ff57b6 --- /dev/null +++ b/tests/client/test_client_async.py @@ -0,0 +1,36 @@ +"""Test the behavior of Client in async environments.""" + +import threading + +import pytest + +import httpx + +# NOTE: using Client in an async environment is only supported for asyncio, for now. + + +@pytest.mark.asyncio +async def test_sync_request_in_async_environment(server): + with httpx.Client() as client: + response = client.get(server.url) + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +@pytest.mark.asyncio +async def test_sync_request_in_async_environment_with_exception(server): + outer_thread = threading.current_thread() + + class FailingDispatcher(httpx.AsyncDispatcher): + async def send(self, *args, **kwargs): + assert threading.current_thread() != outer_thread + # This shouldn't make the shared sub-thread hang. + raise ValueError("Failed") + + with pytest.raises(ValueError) as ctx: + with httpx.Client(dispatch=FailingDispatcher()) as client: + client.get(server.url) + + exc = ctx.value + assert isinstance(exc, ValueError) + assert str(exc) == "Failed"