Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 145 additions & 6 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import asyncio
import concurrent.futures
import functools
import ssl
import sys
import threading
import time
import typing
from types import TracebackType

from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from ..utils import get_logger
from .base import (
BaseBackgroundManager,
BaseEvent,
Expand All @@ -19,6 +23,8 @@

SSL_MONKEY_PATCH_APPLIED = False

logger = get_logger(__name__)


def ssl_monkey_patch() -> None:
"""
Expand Down Expand Up @@ -207,6 +213,8 @@ def release(self) -> None:


class AsyncioBackend(ConcurrencyBackend):
worker_thread_id: typing.Optional[int] = None

def __init__(self) -> None:
global SSL_MONKEY_PATCH_APPLIED

Expand All @@ -220,9 +228,51 @@ def loop(self) -> asyncio.AbstractEventLoop:
try:
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._loop = loop
self._loop_thread_id = threading.get_ident()

return self._loop

@property
def worker_executor(self) -> concurrent.futures.ThreadPoolExecutor:
"""
A thread executor where coroutines should be run if
the event loop of the main thread is already running.

It is important to have only one worker thread, because sync HTTP calls
will create tasks that manipulate I/O resources (namely iterate request body,
get response, close response), and those tasks MUST be running in the
same thread.
"""
# NOTE: currently, the worker executor is never shutdown. It may block
# the interpreter from exiting if the worker thread is pending or blocked.
if not hasattr(self, "_worker_executor"):
self._worker_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
return self._worker_executor

@property
def worker_loop(self) -> asyncio.AbstractEventLoop:
"""The event loop attached to the worker thread.

It is important to keep a reference to the worker event loop, because sync
HTTP calls will create tasks that manipulate I/O resources (namely iterate
request body, get response, close response), and those tasks MUST be bound to
the same event loop.
"""
if threading.get_ident() == threading.main_thread().ident:
raise RuntimeError("Cannot only access worker loop from a sub-thread")

if not hasattr(self, "_worker_loop"):
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
self._worker_loop = loop

return self._worker_loop

async def open_tcp_stream(
self,
hostname: str,
Expand Down Expand Up @@ -253,13 +303,78 @@ async def run_in_threadpool(
def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
loop = self.loop
if loop.is_running():
self._loop = asyncio.new_event_loop()
initial_loop = self.loop

target: typing.Callable

if (
initial_loop.is_running()
and threading.main_thread().ident == threading.get_ident()
):
# The event loop is already running in this thread.
# We must run 'coroutine' on the worker loop in the worker thread.

def target(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
def _target() -> typing.Any:
if self.worker_thread_id is None:
self.worker_thread_id = threading.get_ident()

assert self.worker_thread_id is not None
assert threading.get_ident() == self.worker_thread_id

self._loop = self.worker_loop
return self.loop.run_until_complete(coroutine(*args, **kwargs))

future = self.worker_executor.submit(_target)
while not future.done():
time.sleep(1e-2)
return future.result()

elif initial_loop.is_running():
if self.worker_thread_id is None:
self.worker_thread_id = threading.get_ident()

assert self.worker_thread_id is not None
assert threading.get_ident() == self.worker_thread_id
# The loop is running in a different thread (i.e. the main thread)
# Run 'coroutine' on the existing loop, but from this thread (i.e.
# the worker thread).

def target(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
# 1) Create a future to hold the result.
call_result: concurrent.futures.Future[typing.Any] = (
concurrent.futures.Future()
)
# 2) Schedule the coroutine on the event loop of the main thread,
# so that it sets its result or exception onto 'call_result'.
self.loop.call_soon_threadsafe(
self.loop.create_task,
proxy_coroutine_to_future(
coroutine,
*args,
future=call_result,
exc_info=sys.exc_info(),
**kwargs,
),
)
# Wait for the coroutine to terminate, i.e. for the future to have
# a result set.
while not call_result.done():
# This blocks the worker thread, which is fine because the
# event loop is still being driven in the main thread.
time.sleep(1e-2)
return call_result.result()

else:
# Event loop is not running: we can run the coroutine on it directly.

def target(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return initial_loop.run_until_complete(coroutine(*args, **kwargs))

try:
return self.loop.run_until_complete(coroutine(*args, **kwargs))
return target(*args, **kwargs)
finally:
self._loop = loop
self._loop = initial_loop

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)
Expand Down Expand Up @@ -295,3 +410,27 @@ async def __aexit__(
await self.task
if exc_type is None:
self.task.result()


async def proxy_coroutine_to_future(
coroutine: typing.Callable,
*args: typing.Any,
future: concurrent.futures.Future,
exc_info: tuple,
**kwargs: typing.Any,
) -> None:
"""Run a coroutine, and set its return value or exception on 'future'."""
try:
# If we have an exception, run the function inside the except
# block after raising it so that 'exc_info' is correctly populated.
if exc_info[1]:
try:
raise exc_info[1]
except BaseException:
result = await coroutine(*args, **kwargs)
else:
result = await coroutine(*args, **kwargs)
except Exception as e:
future.set_exception(e)
else:
future.set_result(result)
36 changes: 36 additions & 0 deletions tests/client/test_client_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Test the behavior of Client in async environments."""

import threading

import pytest

import httpx

# NOTE: using Client in an async environment is only supported for asyncio, for now.


@pytest.mark.asyncio
async def test_sync_request_in_async_environment(server):
with httpx.Client() as client:
response = client.get(server.url)
assert response.status_code == 200
assert response.text == "Hello, world!"


@pytest.mark.asyncio
async def test_sync_request_in_async_environment_with_exception(server):
outer_thread = threading.current_thread()

class FailingDispatcher(httpx.AsyncDispatcher):
async def send(self, *args, **kwargs):
assert threading.current_thread() != outer_thread
# This shouldn't make the shared sub-thread hang.
raise ValueError("Failed")

with pytest.raises(ValueError) as ctx:
with httpx.Client(dispatch=FailingDispatcher()) as client:
client.get(server.url)

exc = ctx.value
assert isinstance(exc, ValueError)
assert str(exc) == "Failed"