Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions httpx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,7 @@
from .auth import BasicAuth, DigestAuth
from .client import Client
from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseSocketStream,
ConcurrencyBackend,
)
from .concurrency.base import BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
from .config import (
USER_AGENT,
CertTypes,
Expand Down Expand Up @@ -89,7 +84,6 @@
"VerifyTypes",
"HTTPConnection",
"BasePoolSemaphore",
"BaseBackgroundManager",
"ConnectionPool",
"HTTPProxy",
"HTTPProxyMode",
Expand Down
50 changes: 22 additions & 28 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
import ssl
import sys
import typing
from types import TracebackType

from ..config import PoolLimits, Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseSocketStream,
Expand Down Expand Up @@ -317,34 +315,30 @@ def run(
finally:
self._loop = loop

async def fork(
self,
coroutine1: typing.Callable,
args1: typing.Sequence,
coroutine2: typing.Callable,
args2: typing.Sequence,
) -> None:
task1 = self.loop.create_task(coroutine1(*args1))
task2 = self.loop.create_task(coroutine2(*args2))

try:
await asyncio.gather(task1, task2)
finally:
pending: typing.Set[asyncio.Future[typing.Any]] # Please mypy.
_, pending = await asyncio.wait({task1, task2}, timeout=0)
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This'll look a bunch nicer with anyio. (Tho we should leave Trio's implementation alone, since they already have a sensible nursery primitive.)


def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)

def create_event(self) -> BaseEvent:
return typing.cast(BaseEvent, asyncio.Event())

def background_manager(
self, coroutine: typing.Callable, *args: typing.Any
) -> "BackgroundManager":
return BackgroundManager(coroutine, args)


class BackgroundManager(BaseBackgroundManager):
def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None:
self.coroutine = coroutine
self.args = args

async def __aenter__(self) -> "BackgroundManager":
loop = asyncio.get_event_loop()
self.task = loop.create_task(self.coroutine(*self.args))
return self

async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.task
if exc_type is None:
self.task.result()
6 changes: 0 additions & 6 deletions httpx/concurrency/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from ..config import PoolLimits, Timeout
from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseSocketStream,
Expand Down Expand Up @@ -52,8 +51,3 @@ async def run_in_threadpool(

def create_event(self) -> BaseEvent:
return self.backend.create_event()

def background_manager(
self, coroutine: typing.Callable, *args: typing.Any
) -> BaseBackgroundManager:
return self.backend.background_manager(coroutine, *args)
36 changes: 14 additions & 22 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ssl
import typing
from types import TracebackType

from ..config import PoolLimits, Timeout

Expand Down Expand Up @@ -154,27 +153,20 @@ def run(
def create_event(self) -> BaseEvent:
raise NotImplementedError() # pragma: no cover

def background_manager(
self, coroutine: typing.Callable, *args: typing.Any
) -> "BaseBackgroundManager":
raise NotImplementedError() # pragma: no cover


class BaseBackgroundManager:
async def __aenter__(self) -> "BaseBackgroundManager":
raise NotImplementedError() # pragma: no cover

async def __aexit__(
async def fork(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
coroutine1: typing.Callable,
args1: typing.Sequence,
coroutine2: typing.Callable,
args2: typing.Sequence,
) -> None:
raise NotImplementedError() # pragma: no cover
"""
Run two coroutines concurrently.

This should start 'coroutine1' with '*args1' and 'coroutine2' with '*args2',
and wait for them to finish.

async def close(self, exception: BaseException = None) -> None:
if exception is None:
await self.__aexit__(None, None, None)
else:
traceback = exception.__traceback__ # type: ignore
await self.__aexit__(type(exception), exception, traceback)
In case one of the coroutines raises an exception, cancel the other one then
raise. If the other coroutine had also raised an exception, ignore it (for now).
"""
raise NotImplementedError() # pragma: no cover
48 changes: 19 additions & 29 deletions httpx/concurrency/trio.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import functools
import ssl
import typing
from types import TracebackType

import trio

from ..config import PoolLimits, Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseSocketStream,
Expand Down Expand Up @@ -204,17 +202,31 @@ def run(
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
)

async def fork(
self,
coroutine1: typing.Callable,
args1: typing.Sequence,
coroutine2: typing.Callable,
args2: typing.Sequence,
) -> None:
try:
async with trio.open_nursery() as nursery:
nursery.start_soon(coroutine1, *args1)
nursery.start_soon(coroutine2, *args2)
except trio.MultiError as exc:
# NOTE: asyncio doesn't handle multi-errors yet, so we must align on its
# behavior here, and need to arbitrarily decide which exception to raise.
# We may want to add an 'httpx.MultiError', manually add support
# for this situation in the asyncio backend, and re-raise
# an 'httpx.MultiError' from trio's here.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can be more direct on our wording here. If MultiError occurs then we semantically only want either one of those two exceptions raised up to the user.

raise exc.exceptions[0]

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)

def create_event(self) -> BaseEvent:
return Event()

def background_manager(
self, coroutine: typing.Callable, *args: typing.Any
) -> "BackgroundManager":
return BackgroundManager(coroutine, *args)


class Event(BaseEvent):
def __init__(self) -> None:
Expand All @@ -233,25 +245,3 @@ def clear(self) -> None:
# trio.Event.clear() was deprecated in Trio 0.12.
# https://github.com/python-trio/trio/issues/637
self._event = trio.Event()


class BackgroundManager(BaseBackgroundManager):
def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
self.coroutine = coroutine
self.args = args
self.nursery_manager = trio.open_nursery()
self.nursery: typing.Optional[trio.Nursery] = None

async def __aenter__(self) -> "BackgroundManager":
self.nursery = await self.nursery_manager.__aenter__()
self.nursery.start_soon(self.coroutine, *self.args)
return self

async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
assert self.nursery is not None
await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)
18 changes: 16 additions & 2 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,23 @@ async def send(self, request: Request, timeout: Timeout = None) -> Response:
self.timeout_flags[stream_id] = TimeoutFlag()
self.window_update_received[stream_id] = self.backend.create_event()

task, args = self.send_request_data, [stream_id, request.stream(), timeout]
async with self.backend.background_manager(task, *args):
status_code: typing.Optional[int] = None
headers: typing.Optional[list] = None

async def receive_response(stream_id: int, timeout: Timeout) -> None:
nonlocal status_code, headers
status_code, headers = await self.receive_response(stream_id, timeout)

await self.backend.fork(
self.send_request_data,
[stream_id, request.stream(), timeout],
receive_response,
[stream_id, timeout],
)

assert status_code is not None
assert headers is not None

content = self.body_iter(stream_id, timeout)
on_close = functools.partial(self.response_closed, stream_id=stream_id)

Expand Down
48 changes: 47 additions & 1 deletion tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from httpx import AsyncioBackend, SSLConfig, Timeout
from httpx.concurrency.trio import TrioBackend
from tests.concurrency import run_concurrently
from tests.concurrency import run_concurrently, sleep


def get_asyncio_cipher(stream):
Expand Down Expand Up @@ -110,3 +110,49 @@ async def test_concurrent_read(server, backend):
)
finally:
await stream.close()


async def test_fork(backend):
ok_counter = 0

async def ok(delay: int) -> None:
nonlocal ok_counter
await sleep(backend, delay)
ok_counter += 1

async def fail(message: str, delay: int) -> None:
await sleep(backend, delay)
raise RuntimeError(message)

await backend.fork(ok, [0], ok, [0])
assert ok_counter == 2

with pytest.raises(RuntimeError, match="Oops"):
await backend.fork(ok, [0], fail, ["Oops", 0.01])

assert ok_counter == 3

with pytest.raises(RuntimeError, match="Oops"):
await backend.fork(ok, [0.01], fail, ["Oops", 0])

assert ok_counter == 3

with pytest.raises(RuntimeError, match="Oops"):
await backend.fork(fail, ["Oops", 0.01], ok, [0])

assert ok_counter == 4

with pytest.raises(RuntimeError, match="Oops"):
await backend.fork(fail, ["Oops", 0], ok, [0.01])

assert ok_counter == 4

with pytest.raises(RuntimeError, match="My bad"):
await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0.01])

with pytest.raises(RuntimeError, match="Oops"):
await backend.fork(fail, ["My bad", 0.01], fail, ["Oops", 0])

# No 'match', since we can't know which will be raised first.
with pytest.raises(RuntimeError):
await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0])