From e382d49874edc2e157bc299dbbe3b83c39f6679d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 24 May 2019 15:17:05 +0100 Subject: [PATCH 01/19] Support thread-pooled dispatch --- httpcore/__init__.py | 9 +++- httpcore/client.py | 21 ++++++--- httpcore/concurrency.py | 13 ++++++ httpcore/dispatch/connection.py | 4 +- httpcore/dispatch/connection_pool.py | 4 +- httpcore/dispatch/http11.py | 2 +- httpcore/dispatch/http2.py | 2 +- httpcore/dispatch/threaded.py | 31 +++++++++++++ httpcore/interfaces.py | 66 +++++++++++++++++++++++++--- tests/client/test_auth.py | 4 +- tests/client/test_cookies.py | 4 +- tests/client/test_redirects.py | 4 +- tests/dispatch/test_threaded.py | 51 +++++++++++++++++++++ 13 files changed, 191 insertions(+), 24 deletions(-) create mode 100644 httpcore/dispatch/threaded.py create mode 100644 tests/dispatch/test_threaded.py diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 6d073a8de3..45508d3b4a 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -28,7 +28,14 @@ TooManyRedirects, WriteTimeout, ) -from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol +from .interfaces import ( + AsyncDispatcher, + BaseReader, + BaseWriter, + ConcurrencyBackend, + Dispatcher, + Protocol, +) from .models import URL, Cookies, Headers, Origin, QueryParams, Request, Response from .status_codes import StatusCode, codes diff --git a/httpcore/client.py b/httpcore/client.py index 2946a753fd..4f965cbe5b 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -3,6 +3,7 @@ from types import TracebackType from .auth import HTTPBasicAuth +from .concurrency import AsyncioBackend from .config import ( DEFAULT_MAX_REDIRECTS, DEFAULT_POOL_LIMITS, @@ -13,8 +14,9 @@ VerifyTypes, ) from .dispatch.connection_pool import ConnectionPool +from .dispatch.threaded import ThreadedDispatcher from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects -from .interfaces import ConcurrencyBackend, Dispatcher +from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher from .models import ( URL, AuthTypes, @@ -42,22 +44,29 @@ def __init__( timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, max_redirects: int = DEFAULT_MAX_REDIRECTS, - dispatch: Dispatcher = None, + dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None, backend: ConcurrencyBackend = None, ): + if backend is None: + backend = AsyncioBackend() + if dispatch is None: - dispatch = ConnectionPool( + async_dispatch = ConnectionPool( verify=verify, cert=cert, timeout=timeout, pool_limits=pool_limits, backend=backend, - ) + ) # type: AsyncDispatcher + elif isinstance(dispatch, Dispatcher): + async_dispatch = ThreadedDispatcher(dispatch, backend) + else: + async_dispatch = dispatch self.auth = auth self.cookies = Cookies(cookies) self.max_redirects = max_redirects - self.dispatch = dispatch + self.dispatch = async_dispatch async def get( self, @@ -500,7 +509,7 @@ def __init__( timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, max_redirects: int = DEFAULT_MAX_REDIRECTS, - dispatch: Dispatcher = None, + dispatch: typing.Union[Dispatcher, AsyncDispatcher] = None, backend: ConcurrencyBackend = None, ) -> None: self._client = AsyncClient( diff --git a/httpcore/concurrency.py b/httpcore/concurrency.py index 0c1d3409eb..9ec10b2879 100644 --- a/httpcore/concurrency.py +++ b/httpcore/concurrency.py @@ -9,6 +9,7 @@ based, and less strictly `asyncio`-specific. """ import asyncio +import functools import ssl import typing @@ -133,6 +134,12 @@ def __init__(self) -> None: ssl_monkey_patch() SSL_MONKEY_PATCH_APPLIED = True + @property + def loop(self) -> asyncio.BaseEventLoop: + if not hasattr(self, '_loop'): + self._loop = asyncio.get_event_loop() + return self._loop + async def connect( self, hostname: str, @@ -162,5 +169,11 @@ async def connect( return (reader, writer, protocol) + async def run_in_threadpool(self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + if kwargs: + # loop.run_in_executor doesn't accept 'kwargs', so bind them in here + func = functools.partial(func, **kwargs) + return await self.loop.run_in_executor(None, func, *args) + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index 60214333fe..3c6cc81cb8 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -15,7 +15,7 @@ VerifyTypes, ) from ..exceptions import ConnectTimeout -from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol +from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol from ..models import Origin, Request, Response from .http2 import HTTP2Connection from .http11 import HTTP11Connection @@ -24,7 +24,7 @@ ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]] -class HTTPConnection(Dispatcher): +class HTTPConnection(AsyncDispatcher): def __init__( self, origin: typing.Union[str, Origin], diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index e7cefbd7e4..713e56fb24 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -12,7 +12,7 @@ ) from ..decoders import ACCEPT_ENCODING from ..exceptions import PoolTimeout -from ..interfaces import ConcurrencyBackend, Dispatcher +from ..interfaces import AsyncDispatcher, ConcurrencyBackend from ..models import Origin, Request, Response from .connection import HTTPConnection @@ -77,7 +77,7 @@ def __len__(self) -> int: return len(self.all) -class ConnectionPool(Dispatcher): +class ConnectionPool(AsyncDispatcher): def __init__( self, *, diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index 4308f64a3a..03da58a9d9 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -4,7 +4,7 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout -from ..interfaces import BaseReader, BaseWriter, Dispatcher +from ..interfaces import BaseReader, BaseWriter from ..models import Request, Response H11Event = typing.Union[ diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index bb1857f307..013fc0a4a6 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -6,7 +6,7 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout -from ..interfaces import BaseReader, BaseWriter, Dispatcher +from ..interfaces import BaseReader, BaseWriter from ..models import Request, Response diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py new file mode 100644 index 0000000000..441abbb8cc --- /dev/null +++ b/httpcore/dispatch/threaded.py @@ -0,0 +1,31 @@ +from ..config import CertTypes, TimeoutTypes, VerifyTypes +from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher +from ..models import Request, Response + + +class ThreadedDispatcher(AsyncDispatcher): + def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None: + self.sync_dispatcher = dispatch + self.backend = backend + + async def send( + self, + request: Request, + stream: bool = False, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> Response: + func = self.sync_dispatcher.send + kwargs = { + "request": request, + "stream": stream, + "verify": verify, + "cert": cert, + "timeout": timeout, + } + return await self.backend.run_in_threadpool(func, **kwargs) + + async def close(self) -> None: + func = self.sync_dispatcher.close + await self.backend.run_in_threadpool(func) diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 42ffd157ae..a6a511f00e 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -21,9 +21,9 @@ class Protocol(str, enum.Enum): HTTP_2 = "HTTP/2" -class Dispatcher: +class AsyncDispatcher: """ - Base class for dispatcher classes, that handle sending the request. + Base class for async dispatcher classes, that handle sending the request. Stubs out the interface, as well as providing a `.request()` convienence implementation, to make it easy to use or test stand-alone dispatchers, @@ -44,10 +44,9 @@ async def request( timeout: TimeoutTypes = None ) -> Response: request = Request(method, url, data=data, params=params, headers=headers) - response = await self.send( + return await self.send( request, stream=stream, verify=verify, cert=cert, timeout=timeout ) - return response async def send( self, @@ -62,7 +61,7 @@ async def send( async def close(self) -> None: pass # pragma: nocover - async def __aenter__(self) -> "Dispatcher": + async def __aenter__(self) -> "AsyncDispatcher": return self async def __aexit__( @@ -74,6 +73,58 @@ async def __aexit__( await self.close() +class Dispatcher: + """ + Base class for syncronous dispatcher classes, that handle sending the request. + + Stubs out the interface, as well as providing a `.request()` convienence + implementation, to make it easy to use or test stand-alone dispatchers, + without requiring a complete `Client` instance. + """ + + def request( + self, + method: str, + url: URLTypes, + *, + data: RequestData = b"", + params: QueryParamTypes = None, + headers: HeaderTypes = None, + stream: bool = False, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None + ) -> Response: + request = Request(method, url, data=data, params=params, headers=headers) + return self.send( + request, stream=stream, verify=verify, cert=cert, timeout=timeout + ) + + def send( + self, + request: Request, + stream: bool = False, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> Response: + raise NotImplementedError() # pragma: nocover + + def close(self) -> None: + pass # pragma: nocover + + def __enter__(self) -> "Dispatcher": + return self + + def __exit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.close() + + class BaseReader: """ A stream reader. Abstracts away any asyncio-specfic interfaces @@ -128,3 +179,8 @@ async def connect( def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover + + def run_in_threadpool( + self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + raise NotImplementedError() # pragma: no cover diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 1d2b97239c..631a058561 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -6,7 +6,7 @@ URL, CertTypes, Client, - Dispatcher, + AsyncDispatcher, Request, Response, TimeoutTypes, @@ -14,7 +14,7 @@ ) -class MockDispatch(Dispatcher): +class MockDispatch(AsyncDispatcher): async def send( self, request: Request, diff --git a/tests/client/test_cookies.py b/tests/client/test_cookies.py index a21f5c134f..e96c386bc1 100644 --- a/tests/client/test_cookies.py +++ b/tests/client/test_cookies.py @@ -8,7 +8,7 @@ CertTypes, Client, Cookies, - Dispatcher, + AsyncDispatcher, Request, Response, TimeoutTypes, @@ -16,7 +16,7 @@ ) -class MockDispatch(Dispatcher): +class MockDispatch(AsyncDispatcher): async def send( self, request: Request, diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index c3b384dc95..24acc3aff8 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -7,7 +7,7 @@ URL, AsyncClient, CertTypes, - Dispatcher, + AsyncDispatcher, RedirectBodyUnavailable, RedirectLoop, Request, @@ -19,7 +19,7 @@ ) -class MockDispatch(Dispatcher): +class MockDispatch(AsyncDispatcher): async def send( self, request: Request, diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py new file mode 100644 index 0000000000..b2d99698e9 --- /dev/null +++ b/tests/dispatch/test_threaded.py @@ -0,0 +1,51 @@ +import json + +import pytest + +from httpcore import ( + CertTypes, + Client, + Dispatcher, + Request, + Response, + TimeoutTypes, + VerifyTypes, +) + + +class MockDispatch(Dispatcher): + def send( + self, + request: Request, + stream: bool = False, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> Response: + body = json.dumps({"hello": "world"}).encode() + return Response(200, content=body, request=request) + + +def test_threaded_dispatch(): + """ + Use a syncronous 'Dispatcher' class with the client. + Calls to the dispatcher will end up running within a thread pool. + """ + url = "https://example.org/" + with Client(dispatch=MockDispatch()) as client: + response = client.get(url) + + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + +def test_dispatch_class(): + """ + Use a syncronous 'Dispatcher' class directly. + """ + url = "https://example.org/" + with MockDispatch() as dispatcher: + response = dispatcher.request("GET", url) + + assert response.status_code == 200 + assert response.json() == {"hello": "world"} From 51fedbfe438ddedd447f39e41e55aac49f7355a1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 24 May 2019 16:09:28 +0100 Subject: [PATCH 02/19] Add ConcurrencyBackend.run --- httpcore/client.py | 32 ++++++++++++++++++-------------- httpcore/concurrency.py | 18 ++++++++++++++---- httpcore/interfaces.py | 5 +++++ httpcore/models.py | 17 ++++++++++------- tests/client/test_auth.py | 2 +- tests/client/test_cookies.py | 2 +- tests/client/test_redirects.py | 2 +- 7 files changed, 50 insertions(+), 28 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index 4f965cbe5b..e0d00d8664 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -1,4 +1,3 @@ -import asyncio import typing from types import TracebackType @@ -67,6 +66,7 @@ def __init__( self.cookies = Cookies(cookies) self.max_redirects = max_redirects self.dispatch = async_dispatch + self.concurrency_backend = backend async def get( self, @@ -522,12 +522,15 @@ def __init__( dispatch=dispatch, backend=backend, ) - self._loop = asyncio.new_event_loop() @property def cookies(self) -> Cookies: return self._client.cookies + @property + def concurrency_backend(self) -> ConcurrencyBackend: + return self._client.concurrency_backend + def request( self, method: str, @@ -781,21 +784,22 @@ def send( cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> SyncResponse: - response = self._loop.run_until_complete( - self._client.send( - request, - stream=stream, - auth=auth, - allow_redirects=allow_redirects, - verify=verify, - cert=cert, - timeout=timeout, - ) + coroutine = self._client.send + args = [request] + kwargs = dict( + stream=stream, + auth=auth, + allow_redirects=allow_redirects, + verify=verify, + cert=cert, + timeout=timeout, ) - return SyncResponse(response, self._loop) + response = self.concurrency_backend.run(coroutine, *args, **kwargs) + return SyncResponse(response, self.concurrency_backend) def close(self) -> None: - self._loop.run_until_complete(self._client.close()) + coroutine = self._client.close + self.concurrency_backend.run(coroutine) def __enter__(self) -> "Client": return self diff --git a/httpcore/concurrency.py b/httpcore/concurrency.py index 9ec10b2879..d45953904f 100644 --- a/httpcore/concurrency.py +++ b/httpcore/concurrency.py @@ -135,9 +135,12 @@ def __init__(self) -> None: SSL_MONKEY_PATCH_APPLIED = True @property - def loop(self) -> asyncio.BaseEventLoop: - if not hasattr(self, '_loop'): - self._loop = asyncio.get_event_loop() + def loop(self) -> asyncio.AbstractEventLoop: + if not hasattr(self, "_loop"): + try: + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() return self._loop async def connect( @@ -169,11 +172,18 @@ async def connect( return (reader, writer, protocol) - async def run_in_threadpool(self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + async def run_in_threadpool( + self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: if kwargs: # loop.run_in_executor doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) return await self.loop.run_in_executor(None, func, *args) + def run( + self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + return self.loop.run_until_complete(coroutine(*args, **kwargs)) + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index a6a511f00e..56bf5e4d0e 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -184,3 +184,8 @@ def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: raise NotImplementedError() # pragma: no cover + + def run( + self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + raise NotImplementedError() # pragma: no cover diff --git a/httpcore/models.py b/httpcore/models.py index f8c1084716..6bcd42608d 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -1,4 +1,3 @@ -import asyncio import cgi import email.message import json as jsonlib @@ -776,9 +775,9 @@ class SyncResponse: instance, providing standard synchronous interfaces where required. """ - def __init__(self, response: Response, loop: asyncio.AbstractEventLoop): + def __init__(self, response: Response, backend: "ConcurrencyBackend"): self._response = response - self._loop = loop + self._backend = backend @property def status_code(self) -> int: @@ -827,13 +826,13 @@ def json(self) -> typing.Any: return self._response.json() def read(self) -> bytes: - return self._loop.run_until_complete(self._response.read()) + return self._backend.run(self._response.read) def stream(self) -> typing.Iterator[bytes]: inner = self._response.stream() while True: try: - yield self._loop.run_until_complete(inner.__anext__()) + yield self._backend.run(inner.__anext__) except StopAsyncIteration: break @@ -841,12 +840,12 @@ def raw(self) -> typing.Iterator[bytes]: inner = self._response.raw() while True: try: - yield self._loop.run_until_complete(inner.__anext__()) + yield self._backend.run(inner.__anext__) except StopAsyncIteration: break def close(self) -> None: - return self._loop.run_until_complete(self._response.close()) + return self._backend.run(self._response.close) @property def cookies(self) -> "Cookies": @@ -1029,3 +1028,7 @@ def info(self) -> email.message.Message: for key, value in self.response.headers.items(): info[key] = value return info + + +if True: + from .interfaces import ConcurrencyBackend diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 631a058561..7601ffc3ef 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -4,9 +4,9 @@ from httpcore import ( URL, + AsyncDispatcher, CertTypes, Client, - AsyncDispatcher, Request, Response, TimeoutTypes, diff --git a/tests/client/test_cookies.py b/tests/client/test_cookies.py index e96c386bc1..eb951726f8 100644 --- a/tests/client/test_cookies.py +++ b/tests/client/test_cookies.py @@ -5,10 +5,10 @@ from httpcore import ( URL, + AsyncDispatcher, CertTypes, Client, Cookies, - AsyncDispatcher, Request, Response, TimeoutTypes, diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 24acc3aff8..b0cfbb9bfc 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -6,8 +6,8 @@ from httpcore import ( URL, AsyncClient, - CertTypes, AsyncDispatcher, + CertTypes, RedirectBodyUnavailable, RedirectLoop, Request, From d8465dfbef34bbededf566ed13052e7b9b96bb4e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 24 May 2019 17:18:28 +0100 Subject: [PATCH 03/19] Initial work towards support byte-iterators on sync request data --- httpcore/client.py | 35 +++++++++++++++++++++++++++++------ httpcore/interfaces.py | 5 +++-- httpcore/models.py | 7 +++++-- tests/test_api.py | 8 ++++++++ 4 files changed, 45 insertions(+), 10 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index e0d00d8664..865b6e296f 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -18,6 +18,7 @@ from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher from .models import ( URL, + AsyncRequestData, AuthTypes, Cookies, CookieTypes, @@ -156,7 +157,7 @@ async def post( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -188,7 +189,7 @@ async def put( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -220,7 +221,7 @@ async def patch( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -252,7 +253,7 @@ async def delete( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -285,7 +286,7 @@ async def request( method: str, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -531,6 +532,28 @@ def cookies(self) -> Cookies: def concurrency_backend(self) -> ConcurrencyBackend: return self._client.concurrency_backend + def _async_request_data(self, data: RequestData) -> AsyncRequestData: + """ + If the request data is an bytes iterator then return an async bytes + iterator onto the request data. + """ + if isinstance(data, (bytes, dict)): + return data + + assert hasattr(data, "__iter__") + + async def async_iterator(backend, data): # type: ignore + while True: + print(123) + try: + yield await self.concurrency_backend.run_in_threadpool( + data.__next__ + ) + except StopIteration: + raise StopAsyncIteration() + + return async_iterator(self.concurrency_backend, data) + def request( self, method: str, @@ -551,7 +574,7 @@ def request( request = Request( method, url, - data=data, + data=self._async_request_data(data), json=json, params=params, headers=headers, diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 56bf5e4d0e..be0731c9f6 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -6,6 +6,7 @@ from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes from .models import ( URL, + AsyncRequestData, Headers, HeaderTypes, QueryParamTypes, @@ -35,7 +36,7 @@ async def request( method: str, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", params: QueryParamTypes = None, headers: HeaderTypes = None, stream: bool = False, @@ -180,7 +181,7 @@ async def connect( def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover - def run_in_threadpool( + async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: raise NotImplementedError() # pragma: no cover diff --git a/httpcore/models.py b/httpcore/models.py index 6bcd42608d..ca08b034d0 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -50,7 +50,9 @@ typing.Callable[["Request"], "Request"], ] -RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] +AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] + +RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]] ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]] @@ -474,7 +476,7 @@ def __init__( method: str, url: typing.Union[str, URL], *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -499,6 +501,7 @@ def __init__( self.content = urlencode(data, doseq=True).encode("utf-8") self.headers["Content-Type"] = "application/x-www-form-urlencoded" else: + assert hasattr(data, "__aiter__") self.is_streaming = True self.content_aiter = data diff --git a/tests/test_api.py b/tests/test_api.py index 6a62359c16..06e59694db 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -38,6 +38,14 @@ def test_post(server): assert response.reason_phrase == "OK" +# @threadpool +# def test_post_byte_iterator(server): +# data = (i for i in [b"Hello", b", ", b"world!"]) +# response = httpcore.post("http://127.0.0.1:8000/", data=data) +# assert response.status_code == 200 +# assert response.reason_phrase == "OK" + + @threadpool def test_options(server): response = httpcore.options("http://127.0.0.1:8000/") From 0a6cc1fe6308fc4b48c208ff5fde8164e958e9c3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 24 May 2019 17:33:38 +0100 Subject: [PATCH 04/19] Test case for byte iterator content --- httpcore/client.py | 4 ++-- tests/test_api.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index 865b6e296f..e16615e748 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -544,13 +544,13 @@ def _async_request_data(self, data: RequestData) -> AsyncRequestData: async def async_iterator(backend, data): # type: ignore while True: - print(123) try: - yield await self.concurrency_backend.run_in_threadpool( + chunk = await self.concurrency_backend.run_in_threadpool( data.__next__ ) except StopIteration: raise StopAsyncIteration() + yield chunk return async_iterator(self.concurrency_backend, data) diff --git a/tests/test_api.py b/tests/test_api.py index 06e59694db..53e5e044d3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -40,8 +40,12 @@ def test_post(server): # @threadpool # def test_post_byte_iterator(server): -# data = (i for i in [b"Hello", b", ", b"world!"]) -# response = httpcore.post("http://127.0.0.1:8000/", data=data) +# def data(): +# yield b"Hello" +# yield b", " +# yield b"world!" +# +# response = httpcore.post("http://127.0.0.1:8000/", data=data()) # assert response.status_code == 200 # assert response.reason_phrase == "OK" From 9c9db45f2da91ba779ce5f1967e2d5ec46d3a115 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 24 May 2019 22:17:44 +0100 Subject: [PATCH 05/19] byte iterator support for RequestData --- httpcore/client.py | 13 +------------ httpcore/interfaces.py | 16 ++++++++++++++++ tests/test_api.py | 20 ++++++++++---------- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index e16615e748..6aafae88c9 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -541,18 +541,7 @@ def _async_request_data(self, data: RequestData) -> AsyncRequestData: return data assert hasattr(data, "__iter__") - - async def async_iterator(backend, data): # type: ignore - while True: - try: - chunk = await self.concurrency_backend.run_in_threadpool( - data.__next__ - ) - except StopIteration: - raise StopAsyncIteration() - yield chunk - - return async_iterator(self.concurrency_backend, data) + return self.concurrency_backend.iterate_in_threadpool(data) def request( self, diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index be0731c9f6..27c9180876 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -186,6 +186,22 @@ async def run_in_threadpool( ) -> typing.Any: raise NotImplementedError() # pragma: no cover + async def iterate_in_threadpool(self, iterator): # type: ignore + class IterationComplete(Exception): + pass + + def next_wrapper(iterator): # type: ignore + try: + return next(iterator) + except StopIteration: + raise IterationComplete() + + while True: + try: + yield await self.run_in_threadpool(next_wrapper, iterator) + except IterationComplete: + break + def run( self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: diff --git a/tests/test_api.py b/tests/test_api.py index 53e5e044d3..1247a41602 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -38,16 +38,16 @@ def test_post(server): assert response.reason_phrase == "OK" -# @threadpool -# def test_post_byte_iterator(server): -# def data(): -# yield b"Hello" -# yield b", " -# yield b"world!" -# -# response = httpcore.post("http://127.0.0.1:8000/", data=data()) -# assert response.status_code == 200 -# assert response.reason_phrase == "OK" +@threadpool +def test_post_byte_iterator(server): + def data(): + yield b"Hello" + yield b", " + yield b"world!" + + response = httpcore.post("http://127.0.0.1:8000/", data=data()) + assert response.status_code == 200 + assert response.reason_phrase == "OK" @threadpool From cda91e955c36ad6c0e13400306dc6a22fa887b5d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 28 May 2019 17:00:39 +0100 Subject: [PATCH 06/19] Add BaseResponse --- httpcore/dispatch/http11.py | 2 +- httpcore/models.py | 119 ++++++++++++++++++++-------------- tests/models/test_requests.py | 19 ++---- 3 files changed, 77 insertions(+), 63 deletions(-) diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index 1c0306b6de..669d80726f 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -46,7 +46,7 @@ async def send( method = request.method.encode("ascii") target = request.url.full_path.encode("ascii") headers = request.headers.raw - if 'Host' not in request.headers: + if "Host" not in request.headers: host = request.url.authority.encode("ascii") headers = [(b"host", host)] + headers event = h11.Request(method=method, target=target, headers=headers) diff --git a/httpcore/models.py b/httpcore/models.py index 336db88132..f3c9a6cd40 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -554,7 +554,7 @@ def __repr__(self) -> str: return f"<{class_name}({self.method!r}, {url!r})>" -class Response: +class BaseResponse: def __init__( self, status_code: int, @@ -572,15 +572,6 @@ def __init__( self.protocol = protocol self.headers = Headers(headers) - if isinstance(content, bytes): - self.is_closed = True - self.is_stream_consumed = True - self._raw_content = content - else: - self.is_closed = False - self.is_stream_consumed = False - self._raw_stream = content - self.on_close = on_close self.request = request self.history = [] if history is None else list(history) @@ -684,6 +675,77 @@ def decoder(self) -> Decoder: return self._decoder + + @property + def is_redirect(self) -> bool: + return StatusCode.is_redirect(self.status_code) and "location" in self.headers + + def raise_for_status(self) -> None: + """ + Raise the `HttpError` if one occurred. + """ + message = ( + "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n" + "For more information check: https://httpstatuses.com/{0.status_code}" + ) + + if StatusCode.is_client_error(self.status_code): + message = message.format(self, error_type="Client Error") + elif StatusCode.is_server_error(self.status_code): + message = message.format(self, error_type="Server Error") + else: + message = "" + + if message: + raise HttpError(message) + + def json(self) -> typing.Any: + return jsonlib.loads(self.content.decode("utf-8")) + + @property + def cookies(self) -> "Cookies": + if not hasattr(self, "_cookies"): + assert self.request is not None + self._cookies = Cookies() + self._cookies.extract_cookies(self) + return self._cookies + + def __repr__(self) -> str: + return f"" + + +class Response(BaseResponse): + def __init__( + self, + status_code: int, + *, + reason_phrase: str = None, + protocol: str = None, + headers: HeaderTypes = None, + content: ResponseContent = b"", + on_close: typing.Callable = None, + request: Request = None, + history: typing.List["Response"] = None, + ): + super().__init__( + status_code=status_code, + reason_phrase=reason_phrase, + protocol=protocol, + headers=headers, + on_close=on_close, + request=request, + history=history, + ) + + if isinstance(content, bytes): + self.is_closed = True + self.is_stream_consumed = True + self._raw_content = content + else: + self.is_closed = False + self.is_stream_consumed = False + self._raw_stream = content + async def read(self) -> bytes: """ Read and return the response content. @@ -731,43 +793,6 @@ async def close(self) -> None: if self.on_close is not None: await self.on_close() - @property - def is_redirect(self) -> bool: - return StatusCode.is_redirect(self.status_code) and "location" in self.headers - - def raise_for_status(self) -> None: - """ - Raise the `HttpError` if one occurred. - """ - message = ( - "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n" - "For more information check: https://httpstatuses.com/{0.status_code}" - ) - - if StatusCode.is_client_error(self.status_code): - message = message.format(self, error_type="Client Error") - elif StatusCode.is_server_error(self.status_code): - message = message.format(self, error_type="Server Error") - else: - message = "" - - if message: - raise HttpError(message) - - def json(self) -> typing.Any: - return jsonlib.loads(self.content.decode("utf-8")) - - @property - def cookies(self) -> "Cookies": - if not hasattr(self, "_cookies"): - assert self.request is not None - self._cookies = Cookies() - self._cookies.extract_cookies(self) - return self._cookies - - def __repr__(self) -> str: - return f"" - class SyncResponse: """ diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py index d0d521a468..5d3a27d3ff 100644 --- a/tests/models/test_requests.py +++ b/tests/models/test_requests.py @@ -20,10 +20,7 @@ def test_content_length_header(): request = httpcore.Request("POST", "http://example.org", data=b"test 123") request.prepare() assert request.headers == httpcore.Headers( - [ - (b"content-length", b"8"), - (b"accept-encoding", b"deflate, gzip, br"), - ] + [(b"content-length", b"8"), (b"accept-encoding", b"deflate, gzip, br")] ) @@ -49,10 +46,7 @@ async def streaming_body(data): request = httpcore.Request("POST", "http://example.org", data=data) request.prepare() assert request.headers == httpcore.Headers( - [ - (b"transfer-encoding", b"chunked"), - (b"accept-encoding", b"deflate, gzip, br"), - ] + [(b"transfer-encoding", b"chunked"), (b"accept-encoding", b"deflate, gzip, br")] ) @@ -71,9 +65,7 @@ def test_override_accept_encoding_header(): request = httpcore.Request("GET", "http://example.org", headers=headers) request.prepare() - assert request.headers == httpcore.Headers( - [(b"accept-encoding", b"identity")] - ) + assert request.headers == httpcore.Headers([(b"accept-encoding", b"identity")]) def test_override_content_length_header(): @@ -86,10 +78,7 @@ async def streaming_body(data): request = httpcore.Request("POST", "http://example.org", data=data, headers=headers) request.prepare() assert request.headers == httpcore.Headers( - [ - (b"accept-encoding", b"deflate, gzip, br"), - (b"content-length", b"8"), - ] + [(b"accept-encoding", b"deflate, gzip, br"), (b"content-length", b"8")] ) From 2b839b55a011256be186133f3ed37036fff632e0 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 30 May 2019 13:20:30 +0100 Subject: [PATCH 07/19] Bridge sync/async data in SyncResponse --- httpcore/client.py | 41 +++++++++++- httpcore/interfaces.py | 7 ++ httpcore/models.py | 148 ++++++++++++++++++++--------------------- 3 files changed, 116 insertions(+), 80 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index 6aafae88c9..7ef2a38f76 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -540,9 +540,21 @@ def _async_request_data(self, data: RequestData) -> AsyncRequestData: if isinstance(data, (bytes, dict)): return data + # Coerce an iterator into an async iterator, with each item in the + # iteration running as a thread-pooled operation. assert hasattr(data, "__iter__") return self.concurrency_backend.iterate_in_threadpool(data) + def _sync_data(self, data): + if isinstance(data, (bytes, dict)): + return data + + # Coerce an async iterator into an iterator, with each item in the + # iteration run within the event loop. + assert hasattr(data, "__aiter__") + return self.concurrency_backend.iterate(data) + + def request( self, method: str, @@ -796,18 +808,41 @@ def send( cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> SyncResponse: + concurrency_backend = self.concurrency_backend + coroutine = self._client.send args = [request] kwargs = dict( - stream=stream, + stream=True, auth=auth, allow_redirects=allow_redirects, verify=verify, cert=cert, timeout=timeout, ) - response = self.concurrency_backend.run(coroutine, *args, **kwargs) - return SyncResponse(response, self.concurrency_backend) + response = concurrency_backend.run(coroutine, *args, **kwargs) + + content = getattr(response, '_raw_content', getattr(response, '_raw_stream', None)) + + sync_content = self._sync_data(content) + + def sync_on_close(): + nonlocal concurrency_backend, response + return concurrency_backend.run(response.on_close) + + sync_response = SyncResponse( + status_code=response.status_code, + reason_phrase=response.reason_phrase, + protocol=response.protocol, + headers=response.headers, + content=sync_content, + on_close=sync_on_close, + request=response.request, + history=response.history, + ) + if not stream: + sync_response.read() + return sync_response def close(self) -> None: coroutine = self._client.close diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 27c9180876..63096c7fb9 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -206,3 +206,10 @@ def run( self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: raise NotImplementedError() # pragma: no cover + + def iterate(self, async_iterator): # type: ignore + while True: + try: + yield self.run(async_iterator.__anext__) + except StopAsyncIteration: + break diff --git a/httpcore/models.py b/httpcore/models.py index f3c9a6cd40..a68061a271 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -675,7 +675,6 @@ def decoder(self) -> Decoder: return self._decoder - @property def is_redirect(self) -> bool: return StatusCode.is_redirect(self.status_code) and "location" in self.headers @@ -794,90 +793,89 @@ async def close(self) -> None: await self.on_close() -class SyncResponse: +class SyncResponse(BaseResponse): """ A thread-synchronous response. This class proxies onto a `Response` instance, providing standard synchronous interfaces where required. """ - def __init__(self, response: Response, backend: "ConcurrencyBackend"): - self._response = response - self._backend = backend - - @property - def status_code(self) -> int: - return self._response.status_code - - @property - def reason_phrase(self) -> str: - return self._response.reason_phrase - - @property - def protocol(self) -> typing.Optional[str]: - return self._response.protocol - - @property - def url(self) -> typing.Optional[URL]: - return self._response.url - - @property - def request(self) -> typing.Optional[Request]: - return self._response.request - - @property - def headers(self) -> Headers: - return self._response.headers - - @property - def content(self) -> bytes: - return self._response.content - - @property - def text(self) -> str: - return self._response.text - - @property - def encoding(self) -> str: - return self._response.encoding - - @property - def is_redirect(self) -> bool: - return self._response.is_redirect - - def raise_for_status(self) -> None: - return self._response.raise_for_status() + def __init__( + self, + status_code: int, + *, + reason_phrase: str = None, + protocol: str = None, + headers: HeaderTypes = None, + content: ResponseContent = b"", + on_close: typing.Callable = None, + request: Request = None, + history: typing.List["Response"] = None, + ): + super().__init__( + status_code=status_code, + reason_phrase=reason_phrase, + protocol=protocol, + headers=headers, + on_close=on_close, + request=request, + history=history, + ) - def json(self) -> typing.Any: - return self._response.json() + if isinstance(content, bytes): + self.is_closed = True + self.is_stream_consumed = True + self._raw_content = content + else: + self.is_closed = False + self.is_stream_consumed = False + self._raw_stream = content def read(self) -> bytes: - return self._backend.run(self._response.read) + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join([part for part in self.stream()]) + return self._content def stream(self) -> typing.Iterator[bytes]: - inner = self._response.stream() - while True: - try: - yield self._backend.run(inner.__anext__) - except StopAsyncIteration: - break + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + if hasattr(self, "_content"): + yield self._content + else: + for chunk in self.raw(): + yield self.decoder.decode(chunk) + yield self.decoder.flush() def raw(self) -> typing.Iterator[bytes]: - inner = self._response.raw() - while True: - try: - yield self._backend.run(inner.__anext__) - except StopAsyncIteration: - break - - def close(self) -> None: - return self._backend.run(self._response.close) + """ + A byte-iterator over the raw response content. + """ + if hasattr(self, "_raw_content"): + yield self._raw_content + else: + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() - @property - def cookies(self) -> "Cookies": - return self._response.cookies + self.is_stream_consumed = True + for part in self._raw_stream: + yield part + self.close() - def __repr__(self) -> str: - return f"" + def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not self.is_closed: + self.is_closed = True + if self.on_close is not None: + self.on_close() class Cookies(MutableMapping): @@ -898,7 +896,7 @@ def __init__(self, cookies: CookieTypes = None) -> None: else: self.jar = cookies - def extract_cookies(self, response: Response) -> None: + def extract_cookies(self, response: BaseResponse) -> None: """ Loads any cookies based on the response `Set-Cookie` headers. """ @@ -1045,7 +1043,7 @@ class _CookieCompatResponse: for use with `CookieJar` operations. """ - def __init__(self, response: Response): + def __init__(self, response: BaseResponse): self.response = response def info(self) -> email.message.Message: @@ -1053,7 +1051,3 @@ def info(self) -> email.message.Message: for key, value in self.response.headers.items(): info[key] = value return info - - -if True: - from .interfaces import ConcurrencyBackend From 7fe21560edf0f110244f5dee0fe43fd93646e8a6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 30 May 2019 15:18:57 +0100 Subject: [PATCH 08/19] Add BaseClient --- httpcore/client.py | 459 ++++++++++++++++++++------------------------- 1 file changed, 205 insertions(+), 254 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index 7ef2a38f76..9cc287a984 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -34,7 +34,7 @@ from .status_codes import codes -class AsyncClient: +class BaseClient: def __init__( self, auth: AuthTypes = None, @@ -69,6 +69,176 @@ def __init__( self.dispatch = async_dispatch self.concurrency_backend = backend + def merge_cookies( + self, cookies: CookieTypes = None + ) -> typing.Optional[CookieTypes]: + if cookies or self.cookies: + merged_cookies = Cookies(self.cookies) + merged_cookies.update(cookies) + return merged_cookies + return cookies + + async def send( + self, + request: Request, + *, + stream: bool = False, + auth: AuthTypes = None, + allow_redirects: bool = True, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> Response: + if auth is None: + auth = self.auth + + url = request.url + if auth is None and (url.username or url.password): + auth = HTTPBasicAuth(username=url.username, password=url.password) + + if auth is not None: + if isinstance(auth, tuple): + auth = HTTPBasicAuth(username=auth[0], password=auth[1]) + request = auth(request) + + response = await self.send_handling_redirects( + request, + stream=stream, + verify=verify, + cert=cert, + timeout=timeout, + allow_redirects=allow_redirects, + ) + return response + + async def send_handling_redirects( + self, + request: Request, + *, + stream: bool = False, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, + allow_redirects: bool = True, + history: typing.List[Response] = None, + ) -> Response: + if history is None: + history = [] + + while True: + # We perform these checks here, so that calls to `response.next()` + # will raise redirect errors if appropriate. + if len(history) > self.max_redirects: + raise TooManyRedirects() + if request.url in [response.url for response in history]: + raise RedirectLoop() + + response = await self.dispatch.send( + request, stream=stream, verify=verify, cert=cert, timeout=timeout + ) + response.history = list(history) + self.cookies.extract_cookies(response) + history = [response] + history + if not response.is_redirect: + break + + if allow_redirects: + request = self.build_redirect_request(request, response) + else: + + async def send_next() -> Response: + nonlocal request, response, verify, cert, allow_redirects, timeout, history + request = self.build_redirect_request(request, response) + response = await self.send_handling_redirects( + request, + stream=stream, + allow_redirects=allow_redirects, + verify=verify, + cert=cert, + timeout=timeout, + history=history, + ) + return response + + response.next = send_next # type: ignore + break + + return response + + def build_redirect_request(self, request: Request, response: Response) -> Request: + method = self.redirect_method(request, response) + url = self.redirect_url(request, response) + headers = self.redirect_headers(request, url) + content = self.redirect_content(request, method) + cookies = self.merge_cookies(request.cookies) + return Request( + method=method, url=url, headers=headers, data=content, cookies=cookies + ) + + def redirect_method(self, request: Request, response: Response) -> str: + """ + When being redirected we may want to change the method of the request + based on certain specs or browser behavior. + """ + method = request.method + + # https://tools.ietf.org/html/rfc7231#section-6.4.4 + if response.status_code == codes.SEE_OTHER and method != "HEAD": + method = "GET" + + # Do what the browsers do, despite standards... + # Turn 302s into GETs. + if response.status_code == codes.FOUND and method != "HEAD": + method = "GET" + + # If a POST is responded to with a 301, turn it into a GET. + # This bizarre behaviour is explained in 'requests' issue 1704. + if response.status_code == codes.MOVED_PERMANENTLY and method == "POST": + method = "GET" + + return method + + def redirect_url(self, request: Request, response: Response) -> URL: + """ + Return the URL for the redirect to follow. + """ + location = response.headers["Location"] + + url = URL(location, allow_relative=True) + + # Facilitate relative 'Location' headers, as allowed by RFC 7231. + # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + if url.is_relative_url: + url = url.resolve_with(request.url) + + # Attach previous fragment if needed (RFC 7231 7.1.2) + if request.url.fragment and not url.fragment: + url = url.copy_with(fragment=request.url.fragment) + + return url + + def redirect_headers(self, request: Request, url: URL) -> Headers: + """ + Strip Authorization headers when responses are redirected away from + the origin. + """ + headers = Headers(request.headers) + if url.origin != request.url.origin: + del headers["Authorization"] + return headers + + def redirect_content(self, request: Request, method: str) -> bytes: + """ + Return the body that should be used for the redirect request. + """ + if method != request.method and method == "GET": + return b"" + if request.is_streaming: + raise RedirectBodyUnavailable() + return request.content + + +class AsyncClient(BaseClient): async def get( self, url: URLTypes, @@ -318,174 +488,6 @@ async def request( ) return response - def merge_cookies( - self, cookies: CookieTypes = None - ) -> typing.Optional[CookieTypes]: - if cookies or self.cookies: - merged_cookies = Cookies(self.cookies) - merged_cookies.update(cookies) - return merged_cookies - return cookies - - async def send( - self, - request: Request, - *, - stream: bool = False, - auth: AuthTypes = None, - allow_redirects: bool = True, - verify: VerifyTypes = None, - cert: CertTypes = None, - timeout: TimeoutTypes = None, - ) -> Response: - if auth is None: - auth = self.auth - - url = request.url - if auth is None and (url.username or url.password): - auth = HTTPBasicAuth(username=url.username, password=url.password) - - if auth is not None: - if isinstance(auth, tuple): - auth = HTTPBasicAuth(username=auth[0], password=auth[1]) - request = auth(request) - - response = await self.send_handling_redirects( - request, - stream=stream, - verify=verify, - cert=cert, - timeout=timeout, - allow_redirects=allow_redirects, - ) - return response - - async def send_handling_redirects( - self, - request: Request, - *, - stream: bool = False, - cert: CertTypes = None, - verify: VerifyTypes = None, - timeout: TimeoutTypes = None, - allow_redirects: bool = True, - history: typing.List[Response] = None, - ) -> Response: - if history is None: - history = [] - - while True: - # We perform these checks here, so that calls to `response.next()` - # will raise redirect errors if appropriate. - if len(history) > self.max_redirects: - raise TooManyRedirects() - if request.url in [response.url for response in history]: - raise RedirectLoop() - - response = await self.dispatch.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout - ) - response.history = list(history) - self.cookies.extract_cookies(response) - history = [response] + history - if not response.is_redirect: - break - - if allow_redirects: - request = self.build_redirect_request(request, response) - else: - - async def send_next() -> Response: - nonlocal request, response, verify, cert, allow_redirects, timeout, history - request = self.build_redirect_request(request, response) - response = await self.send_handling_redirects( - request, - stream=stream, - allow_redirects=allow_redirects, - verify=verify, - cert=cert, - timeout=timeout, - history=history, - ) - return response - - response.next = send_next # type: ignore - break - - return response - - def build_redirect_request(self, request: Request, response: Response) -> Request: - method = self.redirect_method(request, response) - url = self.redirect_url(request, response) - headers = self.redirect_headers(request, url) - content = self.redirect_content(request, method) - cookies = self.merge_cookies(request.cookies) - return Request( - method=method, url=url, headers=headers, data=content, cookies=cookies - ) - - def redirect_method(self, request: Request, response: Response) -> str: - """ - When being redirected we may want to change the method of the request - based on certain specs or browser behavior. - """ - method = request.method - - # https://tools.ietf.org/html/rfc7231#section-6.4.4 - if response.status_code == codes.SEE_OTHER and method != "HEAD": - method = "GET" - - # Do what the browsers do, despite standards... - # Turn 302s into GETs. - if response.status_code == codes.FOUND and method != "HEAD": - method = "GET" - - # If a POST is responded to with a 301, turn it into a GET. - # This bizarre behaviour is explained in 'requests' issue 1704. - if response.status_code == codes.MOVED_PERMANENTLY and method == "POST": - method = "GET" - - return method - - def redirect_url(self, request: Request, response: Response) -> URL: - """ - Return the URL for the redirect to follow. - """ - location = response.headers["Location"] - - url = URL(location, allow_relative=True) - - # Facilitate relative 'Location' headers, as allowed by RFC 7231. - # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') - if url.is_relative_url: - url = url.resolve_with(request.url) - - # Attach previous fragment if needed (RFC 7231 7.1.2) - if request.url.fragment and not url.fragment: - url = url.copy_with(fragment=request.url.fragment) - - return url - - def redirect_headers(self, request: Request, url: URL) -> Headers: - """ - Strip Authorization headers when responses are redirected away from - the origin. - """ - headers = Headers(request.headers) - if url.origin != request.url.origin: - del headers["Authorization"] - return headers - - def redirect_content(self, request: Request, method: str) -> bytes: - """ - Return the body that should be used for the redirect request. - """ - if method != request.method and method == "GET": - return b"" - if request.is_streaming: - raise RedirectBodyUnavailable() - return request.content - async def close(self) -> None: await self.dispatch.close() @@ -501,37 +503,7 @@ async def __aexit__( await self.close() -class Client: - def __init__( - self, - auth: AuthTypes = None, - cert: CertTypes = None, - verify: VerifyTypes = True, - timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, - pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, - max_redirects: int = DEFAULT_MAX_REDIRECTS, - dispatch: typing.Union[Dispatcher, AsyncDispatcher] = None, - backend: ConcurrencyBackend = None, - ) -> None: - self._client = AsyncClient( - auth=auth, - verify=verify, - cert=cert, - timeout=timeout, - pool_limits=pool_limits, - max_redirects=max_redirects, - dispatch=dispatch, - backend=backend, - ) - - @property - def cookies(self) -> Cookies: - return self._client.cookies - - @property - def concurrency_backend(self) -> ConcurrencyBackend: - return self._client.concurrency_backend - +class Client(BaseClient): def _async_request_data(self, data: RequestData) -> AsyncRequestData: """ If the request data is an bytes iterator then return an async bytes @@ -554,7 +526,6 @@ def _sync_data(self, data): assert hasattr(data, "__aiter__") return self.concurrency_backend.iterate(data) - def request( self, method: str, @@ -579,18 +550,45 @@ def request( json=json, params=params, headers=headers, - cookies=self._client.merge_cookies(cookies), + cookies=self.merge_cookies(cookies), ) - response = self.send( - request, - stream=stream, + concurrency_backend = self.concurrency_backend + + coroutine = self.send + args = [request] + kwargs = dict( + stream=True, auth=auth, allow_redirects=allow_redirects, verify=verify, cert=cert, timeout=timeout, ) - return response + response = concurrency_backend.run(coroutine, *args, **kwargs) + + content = getattr( + response, "_raw_content", getattr(response, "_raw_stream", None) + ) + + sync_content = self._sync_data(content) + + def sync_on_close(): + nonlocal concurrency_backend, response + return concurrency_backend.run(response.on_close) + + sync_response = SyncResponse( + status_code=response.status_code, + reason_phrase=response.reason_phrase, + protocol=response.protocol, + headers=response.headers, + content=sync_content, + on_close=sync_on_close, + request=response.request, + history=response.history, + ) + if not stream: + sync_response.read() + return sync_response def get( self, @@ -797,55 +795,8 @@ def delete( timeout=timeout, ) - def send( - self, - request: Request, - *, - stream: bool = False, - auth: AuthTypes = None, - allow_redirects: bool = True, - verify: VerifyTypes = None, - cert: CertTypes = None, - timeout: TimeoutTypes = None, - ) -> SyncResponse: - concurrency_backend = self.concurrency_backend - - coroutine = self._client.send - args = [request] - kwargs = dict( - stream=True, - auth=auth, - allow_redirects=allow_redirects, - verify=verify, - cert=cert, - timeout=timeout, - ) - response = concurrency_backend.run(coroutine, *args, **kwargs) - - content = getattr(response, '_raw_content', getattr(response, '_raw_stream', None)) - - sync_content = self._sync_data(content) - - def sync_on_close(): - nonlocal concurrency_backend, response - return concurrency_backend.run(response.on_close) - - sync_response = SyncResponse( - status_code=response.status_code, - reason_phrase=response.reason_phrase, - protocol=response.protocol, - headers=response.headers, - content=sync_content, - on_close=sync_on_close, - request=response.request, - history=response.history, - ) - if not stream: - sync_response.read() - return sync_response - def close(self) -> None: - coroutine = self._client.close + coroutine = self.dispatch.close self.concurrency_backend.run(coroutine) def __enter__(self) -> "Client": From 1ae52096e544412cd5da484353a67fd4d7170014 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 30 May 2019 15:58:39 +0100 Subject: [PATCH 09/19] SyncResponse -> Response --- httpcore/api.py | 18 +++++------ httpcore/client.py | 48 ++++++++++++++-------------- httpcore/dispatch/connection.py | 4 +-- httpcore/dispatch/connection_pool.py | 4 +-- httpcore/dispatch/http11.py | 6 ++-- httpcore/dispatch/http2.py | 6 ++-- httpcore/dispatch/threaded.py | 4 +-- httpcore/interfaces.py | 5 +-- httpcore/models.py | 9 ++---- tests/models/test_responses.py | 43 +++++++++++-------------- tests/test_decoders.py | 7 ++-- 11 files changed, 71 insertions(+), 83 deletions(-) diff --git a/httpcore/api.py b/httpcore/api.py index 33d68c5e77..7e2682567f 100644 --- a/httpcore/api.py +++ b/httpcore/api.py @@ -8,7 +8,7 @@ HeaderTypes, QueryParamTypes, RequestData, - SyncResponse, + Response, URLTypes, ) @@ -30,7 +30,7 @@ def request( cert: CertTypes = None, verify: VerifyTypes = True, stream: bool = False, -) -> SyncResponse: +) -> Response: with Client() as client: return client.request( method=method, @@ -61,7 +61,7 @@ def get( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "GET", url, @@ -88,7 +88,7 @@ def options( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "OPTIONS", url, @@ -115,7 +115,7 @@ def head( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "HEAD", url, @@ -144,7 +144,7 @@ def post( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "POST", url, @@ -175,7 +175,7 @@ def put( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "PUT", url, @@ -206,7 +206,7 @@ def patch( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "PATCH", url, @@ -237,7 +237,7 @@ def delete( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "DELETE", url, diff --git a/httpcore/client.py b/httpcore/client.py index 9cc287a984..b37d7b0fb9 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -19,6 +19,7 @@ from .models import ( URL, AsyncRequestData, + AsyncResponse, AuthTypes, Cookies, CookieTypes, @@ -28,7 +29,6 @@ Request, RequestData, Response, - SyncResponse, URLTypes, ) from .status_codes import codes @@ -88,7 +88,7 @@ async def send( verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: if auth is None: auth = self.auth @@ -542,7 +542,7 @@ def request( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: request = Request( method, url, @@ -564,31 +564,31 @@ def request( cert=cert, timeout=timeout, ) - response = concurrency_backend.run(coroutine, *args, **kwargs) + async_response = concurrency_backend.run(coroutine, *args, **kwargs) content = getattr( - response, "_raw_content", getattr(response, "_raw_stream", None) + async_response, "_raw_content", getattr(async_response, "_raw_stream", None) ) sync_content = self._sync_data(content) def sync_on_close(): - nonlocal concurrency_backend, response - return concurrency_backend.run(response.on_close) - - sync_response = SyncResponse( - status_code=response.status_code, - reason_phrase=response.reason_phrase, - protocol=response.protocol, - headers=response.headers, + nonlocal concurrency_backend, async_response + return concurrency_backend.run(async_response.on_close) + + response = Response( + status_code=async_response.status_code, + reason_phrase=async_response.reason_phrase, + protocol=async_response.protocol, + headers=async_response.headers, content=sync_content, on_close=sync_on_close, - request=response.request, - history=response.history, + request=async_response.request, + history=async_response.history, ) if not stream: - sync_response.read() - return sync_response + response.read() + return response def get( self, @@ -603,7 +603,7 @@ def get( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "GET", url, @@ -630,7 +630,7 @@ def options( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "OPTIONS", url, @@ -657,7 +657,7 @@ def head( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "HEAD", url, @@ -686,7 +686,7 @@ def post( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "POST", url, @@ -717,7 +717,7 @@ def put( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "PUT", url, @@ -748,7 +748,7 @@ def patch( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "PATCH", url, @@ -779,7 +779,7 @@ def delete( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "DELETE", url, diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index 3c6cc81cb8..c0ac499d6e 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -16,7 +16,7 @@ ) from ..exceptions import ConnectTimeout from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol -from ..models import Origin, Request, Response +from ..models import AsyncResponse, Origin, Request from .http2 import HTTP2Connection from .http11 import HTTP11Connection @@ -49,7 +49,7 @@ async def send( verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: if self.h11_connection is None and self.h2_connection is None: await self.connect(verify=verify, cert=cert, timeout=timeout) diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index 713e56fb24..cf3a481947 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -13,7 +13,7 @@ from ..decoders import ACCEPT_ENCODING from ..exceptions import PoolTimeout from ..interfaces import AsyncDispatcher, ConcurrencyBackend -from ..models import Origin, Request, Response +from ..models import AsyncResponse, Origin, Request from .connection import HTTPConnection CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] @@ -110,7 +110,7 @@ async def send( verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: connection = await self.acquire_connection(request.url.origin) try: response = await connection.send( diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index 669d80726f..549a82059a 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -5,7 +5,7 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter -from ..models import Request, Response +from ..models import AsyncResponse, Request H11Event = typing.Union[ h11.Request, @@ -39,7 +39,7 @@ def __init__( async def send( self, request: Request, stream: bool = False, timeout: TimeoutTypes = None - ) -> Response: + ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) #  Start sending the request. @@ -72,7 +72,7 @@ async def send( headers = event.headers content = self._body_iter(timeout) - response = Response( + response = AsyncResponse( status_code=status_code, reason_phrase=reason_phrase, protocol="HTTP/1.1", diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index 4621a9cd20..5b6b26b4c8 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -7,7 +7,7 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter -from ..models import Request, Response +from ..models import AsyncResponse, Request class HTTP2Connection: @@ -25,7 +25,7 @@ def __init__( async def send( self, request: Request, stream: bool = False, timeout: TimeoutTypes = None - ) -> Response: + ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) #  Start sending the request. @@ -59,7 +59,7 @@ async def send( content = self.body_iter(stream_id, timeout) on_close = functools.partial(self.response_closed, stream_id=stream_id) - response = Response( + response = AsyncResponse( status_code=status_code, protocol="HTTP/2", headers=headers, diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py index 441abbb8cc..3869235c94 100644 --- a/httpcore/dispatch/threaded.py +++ b/httpcore/dispatch/threaded.py @@ -1,6 +1,6 @@ from ..config import CertTypes, TimeoutTypes, VerifyTypes from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher -from ..models import Request, Response +from ..models import AsyncResponse, Request, Response class ThreadedDispatcher(AsyncDispatcher): @@ -15,7 +15,7 @@ async def send( verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: func = self.sync_dispatcher.send kwargs = { "request": request, diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 63096c7fb9..7c45d972b7 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -7,6 +7,7 @@ from .models import ( URL, AsyncRequestData, + AsyncResponse, Headers, HeaderTypes, QueryParamTypes, @@ -43,7 +44,7 @@ async def request( verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None - ) -> Response: + ) -> AsyncResponse: request = Request(method, url, data=data, params=params, headers=headers) return await self.send( request, stream=stream, verify=verify, cert=cert, timeout=timeout @@ -56,7 +57,7 @@ async def send( verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: raise NotImplementedError() # pragma: nocover async def close(self) -> None: diff --git a/httpcore/models.py b/httpcore/models.py index a68061a271..ae49bb6620 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -713,7 +713,7 @@ def __repr__(self) -> str: return f"" -class Response(BaseResponse): +class AsyncResponse(BaseResponse): def __init__( self, status_code: int, @@ -793,12 +793,7 @@ async def close(self) -> None: await self.on_close() -class SyncResponse(BaseResponse): - """ - A thread-synchronous response. This class proxies onto a `Response` - instance, providing standard synchronous interfaces where required. - """ - +class Response(BaseResponse): def __init__( self, status_code: int, diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 8ecd37ab5c..1fe1e0535d 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -3,7 +3,7 @@ import httpcore -async def streaming_body(): +def streaming_body(): yield b"Hello, " yield b"world!" @@ -105,8 +105,7 @@ def test_response_force_encoding(): assert response.encoding == "iso-8859-1" -@pytest.mark.asyncio -async def test_read_response(): +def test_read_response(): response = httpcore.Response(200, content=b"Hello, world!") assert response.status_code == 200 @@ -114,79 +113,73 @@ async def test_read_response(): assert response.encoding == "ascii" assert response.is_closed - content = await response.read() + content = response.read() assert content == b"Hello, world!" assert response.content == b"Hello, world!" assert response.is_closed -@pytest.mark.asyncio -async def test_raw_interface(): +def test_raw_interface(): response = httpcore.Response(200, content=b"Hello, world!") raw = b"" - async for part in response.raw(): + for part in response.raw(): raw += part assert raw == b"Hello, world!" -@pytest.mark.asyncio -async def test_stream_interface(): +def test_stream_interface(): response = httpcore.Response(200, content=b"Hello, world!") content = b"" - async for part in response.stream(): + for part in response.stream(): content += part assert content == b"Hello, world!" -@pytest.mark.asyncio -async def test_stream_interface_after_read(): +def test_stream_interface_after_read(): response = httpcore.Response(200, content=b"Hello, world!") - await response.read() + response.read() content = b"" - async for part in response.stream(): + for part in response.stream(): content += part assert content == b"Hello, world!" -@pytest.mark.asyncio -async def test_streaming_response(): +def test_streaming_response(): response = httpcore.Response(200, content=streaming_body()) assert response.status_code == 200 assert not response.is_closed - content = await response.read() + content = response.read() assert content == b"Hello, world!" assert response.content == b"Hello, world!" assert response.is_closed -@pytest.mark.asyncio -async def test_cannot_read_after_stream_consumed(): +def test_cannot_read_after_stream_consumed(): response = httpcore.Response(200, content=streaming_body()) content = b"" - async for part in response.stream(): + for part in response.stream(): content += part with pytest.raises(httpcore.StreamConsumed): - await response.read() + response.read() -@pytest.mark.asyncio -async def test_cannot_read_after_response_closed(): +def test_cannot_read_after_response_closed(): response = httpcore.Response(200, content=streaming_body()) - await response.close() + response.close() with pytest.raises(httpcore.ResponseClosed): - await response.read() + response.read() def test_unknown_status_code(): diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 20273eec26..ac795ca91e 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -64,19 +64,18 @@ def test_multi_with_identity(): assert response.content == body -@pytest.mark.asyncio -async def test_streaming(): +def test_streaming(): body = b"test 123" compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) - async def compress(body): + def compress(body): yield compressor.compress(body) yield compressor.flush() headers = [(b"Content-Encoding", b"gzip")] response = httpcore.Response(200, headers=headers, content=compress(body)) assert not hasattr(response, "body") - assert await response.read() == body + assert response.read() == body @pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br")) From bb75f0acbb033448d2d32961839bd0569acc1071 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 4 Jun 2019 16:38:03 +0100 Subject: [PATCH 10/19] Tweaking type annotation --- httpcore/client.py | 40 ++++++++++++++++++++++------------------ httpcore/models.py | 20 +++++++++++--------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index b37d7b0fb9..01825a952f 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -20,6 +20,7 @@ URL, AsyncRequestData, AsyncResponse, + AsyncResponseContent, AuthTypes, Cookies, CookieTypes, @@ -29,6 +30,7 @@ Request, RequestData, Response, + ResponseContent, URLTypes, ) from .status_codes import codes @@ -120,8 +122,8 @@ async def send_handling_redirects( verify: VerifyTypes = None, timeout: TimeoutTypes = None, allow_redirects: bool = True, - history: typing.List[Response] = None, - ) -> Response: + history: typing.List[AsyncResponse] = None, + ) -> AsyncResponse: if history is None: history = [] @@ -146,7 +148,7 @@ async def send_handling_redirects( request = self.build_redirect_request(request, response) else: - async def send_next() -> Response: + async def send_next() -> AsyncResponse: nonlocal request, response, verify, cert, allow_redirects, timeout, history request = self.build_redirect_request(request, response) response = await self.send_handling_redirects( @@ -165,7 +167,9 @@ async def send_next() -> Response: return response - def build_redirect_request(self, request: Request, response: Response) -> Request: + def build_redirect_request( + self, request: Request, response: AsyncResponse + ) -> Request: method = self.redirect_method(request, response) url = self.redirect_url(request, response) headers = self.redirect_headers(request, url) @@ -175,7 +179,7 @@ def build_redirect_request(self, request: Request, response: Response) -> Reques method=method, url=url, headers=headers, data=content, cookies=cookies ) - def redirect_method(self, request: Request, response: Response) -> str: + def redirect_method(self, request: Request, response: AsyncResponse) -> str: """ When being redirected we may want to change the method of the request based on certain specs or browser behavior. @@ -198,7 +202,7 @@ def redirect_method(self, request: Request, response: Response) -> str: return method - def redirect_url(self, request: Request, response: Response) -> URL: + def redirect_url(self, request: Request, response: AsyncResponse) -> URL: """ Return the URL for the redirect to follow. """ @@ -252,7 +256,7 @@ async def get( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "GET", url, @@ -280,7 +284,7 @@ async def options( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "OPTIONS", url, @@ -308,7 +312,7 @@ async def head( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "HEAD", url, @@ -338,7 +342,7 @@ async def post( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "POST", url, @@ -370,7 +374,7 @@ async def put( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "PUT", url, @@ -402,7 +406,7 @@ async def patch( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "PATCH", url, @@ -434,7 +438,7 @@ async def delete( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "DELETE", url, @@ -467,7 +471,7 @@ async def request( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: request = Request( method, url, @@ -517,8 +521,8 @@ def _async_request_data(self, data: RequestData) -> AsyncRequestData: assert hasattr(data, "__iter__") return self.concurrency_backend.iterate_in_threadpool(data) - def _sync_data(self, data): - if isinstance(data, (bytes, dict)): + def _sync_data(self, data: AsyncResponseContent) -> ResponseContent: + if isinstance(data, bytes): return data # Coerce an async iterator into an iterator, with each item in the @@ -572,9 +576,9 @@ def request( sync_content = self._sync_data(content) - def sync_on_close(): + def sync_on_close() -> None: nonlocal concurrency_backend, async_response - return concurrency_backend.run(async_response.on_close) + concurrency_backend.run(async_response.on_close) response = Response( status_code=async_response.status_code, diff --git a/httpcore/models.py b/httpcore/models.py index ae49bb6620..d7e8f772ef 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -54,7 +54,9 @@ RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]] -ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]] +AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]] + +ResponseContent = typing.Union[bytes, typing.Iterator[bytes]] class URL: @@ -562,10 +564,8 @@ def __init__( reason_phrase: str = None, protocol: str = None, headers: HeaderTypes = None, - content: ResponseContent = b"", on_close: typing.Callable = None, request: Request = None, - history: typing.List["Response"] = None, ): self.status_code = StatusCode.enum_or_int(status_code) self.reason_phrase = StatusCode.get_reason_phrase(status_code) @@ -574,7 +574,6 @@ def __init__( self.on_close = on_close self.request = request - self.history = [] if history is None else list(history) self.next = None # typing.Optional[typing.Callable] @property @@ -590,7 +589,8 @@ def url(self) -> typing.Optional[URL]: def content(self) -> bytes: if not hasattr(self, "_content"): if hasattr(self, "_raw_content"): - content = self.decoder.decode(self._raw_content) + raw_content = getattr(self, "_raw_content") # type: bytes + content = self.decoder.decode(raw_content) content += self.decoder.flush() self._content = content else: @@ -721,10 +721,10 @@ def __init__( reason_phrase: str = None, protocol: str = None, headers: HeaderTypes = None, - content: ResponseContent = b"", + content: AsyncResponseContent = b"", on_close: typing.Callable = None, request: Request = None, - history: typing.List["Response"] = None, + history: typing.List["AsyncResponse"] = None, ): super().__init__( status_code=status_code, @@ -733,9 +733,10 @@ def __init__( headers=headers, on_close=on_close, request=request, - history=history, ) + self.history = [] if history is None else list(history) + if isinstance(content, bytes): self.is_closed = True self.is_stream_consumed = True @@ -813,9 +814,10 @@ def __init__( headers=headers, on_close=on_close, request=request, - history=history, ) + self.history = [] if history is None else list(history) + if isinstance(content, bytes): self.is_closed = True self.is_stream_consumed = True From 9f185e18bf6395909ea4d8eebd3c6093dfd0b0c3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Jun 2019 11:47:27 +0100 Subject: [PATCH 11/19] Distinct classes for Request, AsyncRequest --- httpcore/auth.py | 6 +- httpcore/client.py | 23 +++-- httpcore/dispatch/connection.py | 4 +- httpcore/dispatch/connection_pool.py | 4 +- httpcore/dispatch/http11.py | 4 +- httpcore/dispatch/http2.py | 6 +- httpcore/dispatch/threaded.py | 4 +- httpcore/interfaces.py | 5 +- httpcore/models.py | 149 ++++++++++++++++++++------- tests/models/test_requests.py | 4 +- 10 files changed, 141 insertions(+), 68 deletions(-) diff --git a/httpcore/auth.py b/httpcore/auth.py index 49ff998b43..6a39c1b2c5 100644 --- a/httpcore/auth.py +++ b/httpcore/auth.py @@ -1,7 +1,7 @@ import typing from base64 import b64encode -from .models import Request +from .models import AsyncRequest class AuthBase: @@ -9,7 +9,7 @@ class AuthBase: Base class that all auth implementations derive from. """ - def __call__(self, request: Request) -> Request: + def __call__(self, request: AsyncRequest) -> AsyncRequest: raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover @@ -20,7 +20,7 @@ def __init__( self.username = username self.password = password - def __call__(self, request: Request) -> Request: + def __call__(self, request: AsyncRequest) -> AsyncRequest: request.headers["Authorization"] = self.build_auth_header() return request diff --git a/httpcore/client.py b/httpcore/client.py index 01825a952f..3edba29240 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -18,6 +18,7 @@ from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher from .models import ( URL, + AsyncRequest, AsyncRequestData, AsyncResponse, AsyncResponseContent, @@ -82,7 +83,7 @@ def merge_cookies( async def send( self, - request: Request, + request: AsyncRequest, *, stream: bool = False, auth: AuthTypes = None, @@ -115,7 +116,7 @@ async def send( async def send_handling_redirects( self, - request: Request, + request: AsyncRequest, *, stream: bool = False, cert: CertTypes = None, @@ -168,18 +169,18 @@ async def send_next() -> AsyncResponse: return response def build_redirect_request( - self, request: Request, response: AsyncResponse - ) -> Request: + self, request: AsyncRequest, response: AsyncResponse + ) -> AsyncRequest: method = self.redirect_method(request, response) url = self.redirect_url(request, response) headers = self.redirect_headers(request, url) content = self.redirect_content(request, method) cookies = self.merge_cookies(request.cookies) - return Request( + return AsyncRequest( method=method, url=url, headers=headers, data=content, cookies=cookies ) - def redirect_method(self, request: Request, response: AsyncResponse) -> str: + def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str: """ When being redirected we may want to change the method of the request based on certain specs or browser behavior. @@ -202,7 +203,7 @@ def redirect_method(self, request: Request, response: AsyncResponse) -> str: return method - def redirect_url(self, request: Request, response: AsyncResponse) -> URL: + def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL: """ Return the URL for the redirect to follow. """ @@ -221,7 +222,7 @@ def redirect_url(self, request: Request, response: AsyncResponse) -> URL: return url - def redirect_headers(self, request: Request, url: URL) -> Headers: + def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers: """ Strip Authorization headers when responses are redirected away from the origin. @@ -231,7 +232,7 @@ def redirect_headers(self, request: Request, url: URL) -> Headers: del headers["Authorization"] return headers - def redirect_content(self, request: Request, method: str) -> bytes: + def redirect_content(self, request: AsyncRequest, method: str) -> bytes: """ Return the body that should be used for the redirect request. """ @@ -472,7 +473,7 @@ async def request( verify: VerifyTypes = None, timeout: TimeoutTypes = None, ) -> AsyncResponse: - request = Request( + request = AsyncRequest( method, url, data=data, @@ -547,7 +548,7 @@ def request( verify: VerifyTypes = None, timeout: TimeoutTypes = None, ) -> Response: - request = Request( + request = AsyncRequest( method, url, data=self._async_request_data(data), diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index c0ac499d6e..55592432a6 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -16,7 +16,7 @@ ) from ..exceptions import ConnectTimeout from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol -from ..models import AsyncResponse, Origin, Request +from ..models import AsyncRequest, AsyncResponse, Origin from .http2 import HTTP2Connection from .http11 import HTTP11Connection @@ -44,7 +44,7 @@ def __init__( async def send( self, - request: Request, + request: AsyncRequest, stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index cf3a481947..2777c04713 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -13,7 +13,7 @@ from ..decoders import ACCEPT_ENCODING from ..exceptions import PoolTimeout from ..interfaces import AsyncDispatcher, ConcurrencyBackend -from ..models import AsyncResponse, Origin, Request +from ..models import AsyncRequest, AsyncResponse, Origin from .connection import HTTPConnection CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] @@ -105,7 +105,7 @@ def num_connections(self) -> int: async def send( self, - request: Request, + request: AsyncRequest, stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index 549a82059a..3aa81ca215 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -5,7 +5,7 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter -from ..models import AsyncResponse, Request +from ..models import AsyncRequest, AsyncResponse H11Event = typing.Union[ h11.Request, @@ -38,7 +38,7 @@ def __init__( self.h11_state = h11.Connection(our_role=h11.CLIENT) async def send( - self, request: Request, stream: bool = False, timeout: TimeoutTypes = None + self, request: AsyncRequest, stream: bool = False, timeout: TimeoutTypes = None ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index 5b6b26b4c8..94a1967a78 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -7,7 +7,7 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout from ..interfaces import BaseReader, BaseWriter -from ..models import AsyncResponse, Request +from ..models import AsyncRequest, AsyncResponse class HTTP2Connection: @@ -24,7 +24,7 @@ def __init__( self.initialized = False async def send( - self, request: Request, stream: bool = False, timeout: TimeoutTypes = None + self, request: AsyncRequest, stream: bool = False, timeout: TimeoutTypes = None ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) @@ -86,7 +86,7 @@ def initiate_connection(self) -> None: self.initialized = True async def send_headers( - self, request: Request, timeout: TimeoutConfig = None + self, request: AsyncRequest, timeout: TimeoutConfig = None ) -> int: stream_id = self.h2_state.get_next_available_stream_id() headers = [ diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py index 3869235c94..3366714145 100644 --- a/httpcore/dispatch/threaded.py +++ b/httpcore/dispatch/threaded.py @@ -1,6 +1,6 @@ from ..config import CertTypes, TimeoutTypes, VerifyTypes from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher -from ..models import AsyncResponse, Request, Response +from ..models import AsyncRequest, AsyncResponse, Request, Response class ThreadedDispatcher(AsyncDispatcher): @@ -10,7 +10,7 @@ def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None: async def send( self, - request: Request, + request: AsyncRequest, stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 7c45d972b7..74f778c127 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -6,6 +6,7 @@ from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes from .models import ( URL, + AsyncRequest, AsyncRequestData, AsyncResponse, Headers, @@ -45,14 +46,14 @@ async def request( cert: CertTypes = None, timeout: TimeoutTypes = None ) -> AsyncResponse: - request = Request(method, url, data=data, params=params, headers=headers) + request = AsyncRequest(method, url, data=data, params=params, headers=headers) return await self.send( request, stream=stream, verify=verify, cert=cert, timeout=timeout ) async def send( self, - request: Request, + request: AsyncRequest, stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, diff --git a/httpcore/models.py b/httpcore/models.py index d7e8f772ef..3ee0f18ee5 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -47,7 +47,7 @@ AuthTypes = typing.Union[ typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]], - typing.Callable[["Request"], "Request"], + typing.Callable[["AsyncRequest"], "AsyncRequest"], ] AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] @@ -472,14 +472,12 @@ def __repr__(self) -> str: return f"{class_name}({as_list!r}{encoding_str})" -class Request: +class BaseRequest: def __init__( self, method: str, url: typing.Union[str, URL], *, - data: AsyncRequestData = b"", - json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, cookies: CookieTypes = None, @@ -491,16 +489,73 @@ def __init__( self._cookies = Cookies(cookies) self._cookies.set_cookie_header(self) + self.content = b"" + self.is_streaming = False + + def encode_json(self, json: typing.Any) -> bytes: + return jsonlib.dumps(json).encode("utf-8") + + def urlencode_data(self, data: dict) -> bytes: + return urlencode(data, doseq=True).encode("utf-8") + + def prepare(self) -> None: + auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + + has_content_length = ( + "content-length" in self.headers or "transfer-encoding" in self.headers + ) + has_accept_encoding = "accept-encoding" in self.headers + + if not has_content_length: + if self.is_streaming: + auto_headers.append((b"transfer-encoding", b"chunked")) + elif self.content: + content_length = str(len(self.content)).encode() + auto_headers.append((b"content-length", content_length)) + if not has_accept_encoding: + auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode())) + + for item in reversed(auto_headers): + self.headers.raw.insert(0, item) + + @property + def cookies(self) -> "Cookies": + if not hasattr(self, "_cookies"): + self._cookies = Cookies() + return self._cookies + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url = str(self.url) + return f"<{class_name}({self.method!r}, {url!r})>" + + +class AsyncRequest(BaseRequest): + def __init__( + self, + method: str, + url: typing.Union[str, URL], + *, + params: QueryParamTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + data: AsyncRequestData = b"", + json: typing.Any = None, + ): + super().__init__( + method=method, url=url, params=params, headers=headers, cookies=cookies + ) + if json is not None: - data = jsonlib.dumps(json).encode("utf-8") + self.is_streaming = False + self.content = self.encode_json(json) self.headers["Content-Type"] = "application/json" - - if isinstance(data, bytes): + elif isinstance(data, bytes): self.is_streaming = False self.content = data elif isinstance(data, dict): self.is_streaming = False - self.content = urlencode(data, doseq=True).encode("utf-8") + self.content = self.urlencode_data(data) self.headers["Content-Type"] = "application/x-www-form-urlencoded" else: assert hasattr(data, "__aiter__") @@ -524,36 +579,52 @@ async def stream(self) -> typing.AsyncIterator[bytes]: elif self.content: yield self.content - def prepare(self) -> None: - auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] - has_content_length = ( - "content-length" in self.headers or "transfer-encoding" in self.headers +class Request(BaseRequest): + def __init__( + self, + method: str, + url: typing.Union[str, URL], + *, + params: QueryParamTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + data: RequestData = b"", + json: typing.Any = None, + ): + super().__init__( + method=method, url=url, params=params, headers=headers, cookies=cookies ) - has_accept_encoding = "accept-encoding" in self.headers - if not has_content_length: - if self.is_streaming: - auto_headers.append((b"transfer-encoding", b"chunked")) - elif self.content: - content_length = str(len(self.content)).encode() - auto_headers.append((b"content-length", content_length)) - if not has_accept_encoding: - auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode())) + if json is not None: + self.is_streaming = False + self.content = self.encode_json(json) + self.headers["Content-Type"] = "application/json" + elif isinstance(data, bytes): + self.is_streaming = False + self.content = data + elif isinstance(data, dict): + self.is_streaming = False + self.content = self.urlencode_data(data) + self.headers["Content-Type"] = "application/x-www-form-urlencoded" + else: + assert hasattr(data, "__iter__") + self.is_streaming = True + self.content_iter = data - for item in reversed(auto_headers): - self.headers.raw.insert(0, item) + self.prepare() - @property - def cookies(self) -> "Cookies": - if not hasattr(self, "_cookies"): - self._cookies = Cookies() - return self._cookies + def read(self) -> bytes: + if not hasattr(self, "content"): + self.content = b"".join([part for part in self.stream()]) + return self.content - def __repr__(self) -> str: - class_name = self.__class__.__name__ - url = str(self.url) - return f"<{class_name}({self.method!r}, {url!r})>" + def stream(self) -> typing.Iterator[bytes]: + if self.is_streaming: + for part in self.content_iter: + yield part + elif self.content: + yield self.content class BaseResponse: @@ -564,16 +635,16 @@ def __init__( reason_phrase: str = None, protocol: str = None, headers: HeaderTypes = None, + request: BaseRequest = None, on_close: typing.Callable = None, - request: Request = None, ): self.status_code = StatusCode.enum_or_int(status_code) self.reason_phrase = StatusCode.get_reason_phrase(status_code) self.protocol = protocol self.headers = Headers(headers) - self.on_close = on_close self.request = request + self.on_close = on_close self.next = None # typing.Optional[typing.Callable] @property @@ -723,7 +794,7 @@ def __init__( headers: HeaderTypes = None, content: AsyncResponseContent = b"", on_close: typing.Callable = None, - request: Request = None, + request: AsyncRequest = None, history: typing.List["AsyncResponse"] = None, ): super().__init__( @@ -731,8 +802,8 @@ def __init__( reason_phrase=reason_phrase, protocol=protocol, headers=headers, - on_close=on_close, request=request, + on_close=on_close, ) self.history = [] if history is None else list(history) @@ -812,8 +883,8 @@ def __init__( reason_phrase=reason_phrase, protocol=protocol, headers=headers, - on_close=on_close, request=request, + on_close=on_close, ) self.history = [] if history is None else list(history) @@ -903,7 +974,7 @@ def extract_cookies(self, response: BaseResponse) -> None: self.jar.extract_cookies(urlib_response, urllib_request) # type: ignore - def set_cookie_header(self, request: Request) -> None: + def set_cookie_header(self, request: BaseRequest) -> None: """ Sets an appropriate 'Cookie:' HTTP header on the `Request`. """ @@ -1022,7 +1093,7 @@ class _CookieCompatRequest(urllib.request.Request): for use with `CookieJar` operations. """ - def __init__(self, request: Request) -> None: + def __init__(self, request: BaseRequest) -> None: super().__init__( url=str(request.url), headers=dict(request.headers), diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py index 5d3a27d3ff..183c9a7caa 100644 --- a/tests/models/test_requests.py +++ b/tests/models/test_requests.py @@ -38,7 +38,7 @@ def test_url_encoded_data(): def test_transfer_encoding_header(): - async def streaming_body(data): + def streaming_body(data): yield data # pragma: nocover data = streaming_body(b"test 123") @@ -69,7 +69,7 @@ def test_override_accept_encoding_header(): def test_override_content_length_header(): - async def streaming_body(data): + def streaming_body(data): yield data # pragma: nocover data = streaming_body(b"test 123") From 20e929ad387648a06dc1bc9ab347554f900b4687 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Jun 2019 14:11:46 +0100 Subject: [PATCH 12/19] Tweak is_streaming, content in BaseRequest --- httpcore/models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/httpcore/models.py b/httpcore/models.py index 3ee0f18ee5..904c41ec02 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -489,9 +489,6 @@ def __init__( self._cookies = Cookies(cookies) self._cookies.set_cookie_header(self) - self.content = b"" - self.is_streaming = False - def encode_json(self, json: typing.Any) -> bytes: return jsonlib.dumps(json).encode("utf-8") @@ -499,6 +496,9 @@ def urlencode_data(self, data: dict) -> bytes: return urlencode(data, doseq=True).encode("utf-8") def prepare(self) -> None: + content = getattr(self, "content", None) # type: bytes + is_streaming = getattr(self, "is_streaming", False) + auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] has_content_length = ( @@ -507,10 +507,10 @@ def prepare(self) -> None: has_accept_encoding = "accept-encoding" in self.headers if not has_content_length: - if self.is_streaming: + if is_streaming: auto_headers.append((b"transfer-encoding", b"chunked")) - elif self.content: - content_length = str(len(self.content)).encode() + elif content: + content_length = str(len(content)).encode() auto_headers.append((b"content-length", content_length)) if not has_accept_encoding: auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode())) From b89d1fc71ef3938ff575ceeada33228e6c7b80c3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Jun 2019 15:27:11 +0100 Subject: [PATCH 13/19] Stream handling moves to client --- httpcore/__init__.py | 12 ++++++++- httpcore/client.py | 13 +++++++--- httpcore/dispatch/connection.py | 9 ++----- httpcore/dispatch/connection_pool.py | 3 +-- httpcore/dispatch/http11.py | 12 ++------- httpcore/dispatch/http2.py | 12 ++------- httpcore/dispatch/threaded.py | 2 -- httpcore/interfaces.py | 12 ++------- tests/client/test_auth.py | 11 ++++----- tests/client/test_cookies.py | 13 +++++----- tests/client/test_redirects.py | 33 ++++++++++++++----------- tests/dispatch/test_connection_pools.py | 17 +++++++++---- tests/dispatch/test_connections.py | 3 +++ tests/dispatch/test_threaded.py | 1 - 14 files changed, 73 insertions(+), 80 deletions(-) diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 45508d3b4a..b49691bf87 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -36,7 +36,17 @@ Dispatcher, Protocol, ) -from .models import URL, Cookies, Headers, Origin, QueryParams, Request, Response +from .models import ( + URL, + AsyncRequest, + AsyncResponse, + Cookies, + Headers, + Origin, + QueryParams, + Request, + Response, +) from .status_codes import StatusCode, codes __version__ = "0.3.0" diff --git a/httpcore/client.py b/httpcore/client.py index 3edba29240..08fe6efb02 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -106,19 +106,24 @@ async def send( response = await self.send_handling_redirects( request, - stream=stream, verify=verify, cert=cert, timeout=timeout, allow_redirects=allow_redirects, ) + + if not stream: + try: + await response.read() + finally: + await response.close() + return response async def send_handling_redirects( self, request: AsyncRequest, *, - stream: bool = False, cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, @@ -137,8 +142,9 @@ async def send_handling_redirects( raise RedirectLoop() response = await self.dispatch.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout + request, verify=verify, cert=cert, timeout=timeout ) + assert isinstance(response, AsyncResponse) response.history = list(history) self.cookies.extract_cookies(response) history = [response] + history @@ -154,7 +160,6 @@ async def send_next() -> AsyncResponse: request = self.build_redirect_request(request, response) response = await self.send_handling_redirects( request, - stream=stream, allow_redirects=allow_redirects, verify=verify, cert=cert, diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index 55592432a6..d644bcba73 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -45,7 +45,6 @@ def __init__( async def send( self, request: AsyncRequest, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, @@ -54,14 +53,10 @@ async def send( await self.connect(verify=verify, cert=cert, timeout=timeout) if self.h2_connection is not None: - response = await self.h2_connection.send( - request, stream=stream, timeout=timeout - ) + response = await self.h2_connection.send(request, timeout=timeout) else: assert self.h11_connection is not None - response = await self.h11_connection.send( - request, stream=stream, timeout=timeout - ) + response = await self.h11_connection.send(request, timeout=timeout) return response diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index 2777c04713..c84117ca45 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -106,7 +106,6 @@ def num_connections(self) -> int: async def send( self, request: AsyncRequest, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, @@ -114,7 +113,7 @@ async def send( connection = await self.acquire_connection(request.url.origin) try: response = await connection.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout + request, verify=verify, cert=cert, timeout=timeout ) except BaseException as exc: self.active_connections.remove(connection) diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index 3aa81ca215..f19b3d3dc1 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -38,7 +38,7 @@ def __init__( self.h11_state = h11.Connection(our_role=h11.CLIENT) async def send( - self, request: AsyncRequest, stream: bool = False, timeout: TimeoutTypes = None + self, request: AsyncRequest, timeout: TimeoutTypes = None ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) @@ -72,7 +72,7 @@ async def send( headers = event.headers content = self._body_iter(timeout) - response = AsyncResponse( + return AsyncResponse( status_code=status_code, reason_phrase=reason_phrase, protocol="HTTP/1.1", @@ -82,14 +82,6 @@ async def send( request=request, ) - if not stream: - try: - await response.read() - finally: - await response.close() - - return response - async def close(self) -> None: event = h11.ConnectionClosed() self.h11_state.send(event) diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index 94a1967a78..f7814ec3c1 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -24,7 +24,7 @@ def __init__( self.initialized = False async def send( - self, request: AsyncRequest, stream: bool = False, timeout: TimeoutTypes = None + self, request: AsyncRequest, timeout: TimeoutTypes = None ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) @@ -59,7 +59,7 @@ async def send( content = self.body_iter(stream_id, timeout) on_close = functools.partial(self.response_closed, stream_id=stream_id) - response = AsyncResponse( + return AsyncResponse( status_code=status_code, protocol="HTTP/2", headers=headers, @@ -68,14 +68,6 @@ async def send( request=request, ) - if not stream: - try: - await response.read() - finally: - await response.close() - - return response - async def close(self) -> None: await self.writer.close() diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py index 3366714145..96b93d6117 100644 --- a/httpcore/dispatch/threaded.py +++ b/httpcore/dispatch/threaded.py @@ -11,7 +11,6 @@ def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None: async def send( self, request: AsyncRequest, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, @@ -19,7 +18,6 @@ async def send( func = self.sync_dispatcher.send kwargs = { "request": request, - "stream": stream, "verify": verify, "cert": cert, "timeout": timeout, diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 74f778c127..13d118cfcd 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -41,20 +41,16 @@ async def request( data: AsyncRequestData = b"", params: QueryParamTypes = None, headers: HeaderTypes = None, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None ) -> AsyncResponse: request = AsyncRequest(method, url, data=data, params=params, headers=headers) - return await self.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout - ) + return await self.send(request, verify=verify, cert=cert, timeout=timeout) async def send( self, request: AsyncRequest, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, @@ -93,20 +89,16 @@ def request( data: RequestData = b"", params: QueryParamTypes = None, headers: HeaderTypes = None, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None ) -> Response: request = Request(method, url, data=data, params=params, headers=headers) - return self.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout - ) + return self.send(request, verify=verify, cert=cert, timeout=timeout) def send( self, request: Request, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 7601ffc3ef..17993383a9 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -5,10 +5,10 @@ from httpcore import ( URL, AsyncDispatcher, + AsyncRequest, + AsyncResponse, CertTypes, Client, - Request, - Response, TimeoutTypes, VerifyTypes, ) @@ -17,14 +17,13 @@ class MockDispatch(AsyncDispatcher): async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: body = json.dumps({"auth": request.headers.get("Authorization")}).encode() - return Response(200, content=body, request=request) + return AsyncResponse(200, content=body, request=request) def test_basic_auth(): diff --git a/tests/client/test_cookies.py b/tests/client/test_cookies.py index eb951726f8..5cbb380921 100644 --- a/tests/client/test_cookies.py +++ b/tests/client/test_cookies.py @@ -6,11 +6,11 @@ from httpcore import ( URL, AsyncDispatcher, + AsyncRequest, + AsyncResponse, CertTypes, Client, Cookies, - Request, - Response, TimeoutTypes, VerifyTypes, ) @@ -19,18 +19,17 @@ class MockDispatch(AsyncDispatcher): async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: if request.url.path.startswith("/echo_cookies"): body = json.dumps({"cookies": request.headers.get("Cookie")}).encode() - return Response(200, content=body, request=request) + return AsyncResponse(200, content=body, request=request) elif request.url.path.startswith("/set_cookie"): headers = {"set-cookie": "example-name=example-value"} - return Response(200, headers=headers, request=request) + return AsyncResponse(200, headers=headers, request=request) def test_set_cookie(): diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index b0cfbb9bfc..3f5168974a 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -7,6 +7,8 @@ URL, AsyncClient, AsyncDispatcher, + AsyncRequest, + AsyncResponse, CertTypes, RedirectBodyUnavailable, RedirectLoop, @@ -22,34 +24,33 @@ class MockDispatch(AsyncDispatcher): async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: if request.url.path == "/redirect_301": status_code = codes.MOVED_PERMANENTLY headers = {"location": "https://example.org/"} - return Response(status_code, headers=headers, request=request) + return AsyncResponse(status_code, headers=headers, request=request) elif request.url.path == "/redirect_302": status_code = codes.FOUND headers = {"location": "https://example.org/"} - return Response(status_code, headers=headers, request=request) + return AsyncResponse(status_code, headers=headers, request=request) elif request.url.path == "/redirect_303": status_code = codes.SEE_OTHER headers = {"location": "https://example.org/"} - return Response(status_code, headers=headers, request=request) + return AsyncResponse(status_code, headers=headers, request=request) elif request.url.path == "/relative_redirect": headers = {"location": "/"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/no_scheme_redirect": headers = {"location": "//example.org/"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/multiple_redirects": params = parse_qs(request.url.query) @@ -60,32 +61,34 @@ async def send( if redirect_count: location += "?count=" + str(redirect_count) headers = {"location": location} if count else {} - return Response(code, headers=headers, request=request) + return AsyncResponse(code, headers=headers, request=request) if request.url.path == "/redirect_loop": headers = {"location": "/redirect_loop"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/cross_domain": headers = {"location": "https://example.org/cross_domain_target"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/cross_domain_target": headers = dict(request.headers.items()) content = json.dumps({"headers": headers}).encode() - return Response(codes.OK, content=content, request=request) + return AsyncResponse(codes.OK, content=content, request=request) elif request.url.path == "/redirect_body": await request.read() headers = {"location": "/redirect_body_target"} - return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request) + return AsyncResponse( + codes.PERMANENT_REDIRECT, headers=headers, request=request + ) elif request.url.path == "/redirect_body_target": content = await request.read() body = json.dumps({"body": content.decode()}).encode() - return Response(codes.OK, content=body, request=request) + return AsyncResponse(codes.OK, content=body, request=request) - return Response(codes.OK, content=b"Hello, world!", request=request) + return AsyncResponse(codes.OK, content=b"Hello, world!", request=request) @pytest.mark.asyncio diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index bbe200ba94..b8049c70b8 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -10,10 +10,12 @@ async def test_keepalive_connections(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 @@ -25,10 +27,12 @@ async def test_differing_connection_keys(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 2 @@ -42,10 +46,12 @@ async def test_soft_limit(server): async with httpcore.ConnectionPool(pool_limits=pool_limits) as http: response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 @@ -56,7 +62,7 @@ async def test_streaming_response_holds_connection(server): A streaming request should hold the connection open until the response is read. """ async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 @@ -72,11 +78,11 @@ async def test_multiple_concurrent_connections(server): Multiple conncurrent requests should open multiple conncurrent connections. """ async with httpcore.ConnectionPool() as http: - response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response_a = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 - response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response_b = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 2 assert len(http.keepalive_connections) == 0 @@ -97,6 +103,7 @@ async def test_close_connections(server): headers = [(b"connection", b"close")] async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers) + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 @@ -107,7 +114,7 @@ async def test_standard_response_close(server): A standard close should keep the connection open. """ async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() await response.close() assert len(http.active_connections) == 0 @@ -120,7 +127,7 @@ async def test_premature_response_close(server): A premature close should close the connection. """ async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response = await http.request("GET", "http://127.0.0.1:8000/") await response.close() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 diff --git a/tests/dispatch/test_connections.py b/tests/dispatch/test_connections.py index ca401c7813..4b267f4fd9 100644 --- a/tests/dispatch/test_connections.py +++ b/tests/dispatch/test_connections.py @@ -7,6 +7,7 @@ async def test_get(server): conn = HTTPConnection(origin="http://127.0.0.1:8000/") response = await conn.request("GET", "http://127.0.0.1:8000/") + await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" @@ -27,6 +28,7 @@ async def test_https_get_with_ssl_defaults(https_server): """ conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False) response = await conn.request("GET", "https://127.0.0.1:8001/") + await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" @@ -38,5 +40,6 @@ async def test_https_get_with_sll_overrides(https_server): """ conn = HTTPConnection(origin="https://127.0.0.1:8001/") response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False) + await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py index b2d99698e9..2e93ff8f84 100644 --- a/tests/dispatch/test_threaded.py +++ b/tests/dispatch/test_threaded.py @@ -17,7 +17,6 @@ class MockDispatch(Dispatcher): def send( self, request: Request, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, From 37021dea32432eb317717755feaaefc7235cb2e7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Jun 2019 16:36:06 +0100 Subject: [PATCH 14/19] Handle mediating to AsyncResponse from a standard sync Dispatcher class --- httpcore/client.py | 5 +++- httpcore/dispatch/threaded.py | 54 +++++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/httpcore/client.py b/httpcore/client.py index 08fe6efb02..0fc60a1a3b 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -597,7 +597,10 @@ def sync_on_close() -> None: history=async_response.history, ) if not stream: - response.read() + try: + response.read() + finally: + response.close() return response def get( diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py index 96b93d6117..5a2280839c 100644 --- a/httpcore/dispatch/threaded.py +++ b/httpcore/dispatch/threaded.py @@ -1,9 +1,22 @@ from ..config import CertTypes, TimeoutTypes, VerifyTypes from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher -from ..models import AsyncRequest, AsyncResponse, Request, Response +from ..models import ( + AsyncRequest, + AsyncResponse, + AsyncResponseContent, + Request, + Response, + ResponseContent, +) class ThreadedDispatcher(AsyncDispatcher): + """ + The ThreadedDispatcher class is used to mediate between the Client + (which always uses async under the hood), and a synchronous `Dispatch` + class. + """ + def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None: self.sync_dispatcher = dispatch self.backend = backend @@ -15,6 +28,8 @@ async def send( cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> AsyncResponse: + concurrency_backend = self.backend + func = self.sync_dispatcher.send kwargs = { "request": request, @@ -22,8 +37,43 @@ async def send( "cert": cert, "timeout": timeout, } - return await self.backend.run_in_threadpool(func, **kwargs) + sync_response = await self.backend.run_in_threadpool(func, **kwargs) + assert isinstance(sync_response, Response) + + content = getattr( + sync_response, "_raw_content", getattr(sync_response, "_raw_stream", None) + ) + + async_content = self._async_data(content) + + def async_on_close() -> None: + nonlocal concurrency_backend, sync_response + concurrency_backend.run_in_threadpool(sync_response.on_close) + + return AsyncResponse( + status_code=sync_response.status_code, + reason_phrase=sync_response.reason_phrase, + protocol=sync_response.protocol, + headers=sync_response.headers, + content=async_content, + on_close=async_on_close, + request=request, + history=sync_response.history, + ) async def close(self) -> None: + """ + The `.close()` method runs the `Dispatcher.close()` within a threadpool, + so as not to block the async event loop. + """ func = self.sync_dispatcher.close await self.backend.run_in_threadpool(func) + + def _async_data(self, data: ResponseContent) -> AsyncResponseContent: + if isinstance(data, bytes): + return data + + # Coerce an async iterator into an iterator, with each item in the + # iteration run within the event loop. + assert hasattr(data, "__iter__") + return self.backend.iterate_in_threadpool(data) From 6f1e93374ef6cae34644a4b0a9c8a9f42db0784e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 6 Jun 2019 14:44:03 +0100 Subject: [PATCH 15/19] Working on thread-pooled dispatcher --- httpcore/dispatch/threaded.py | 4 ++-- tests/dispatch/test_threaded.py | 21 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py index 5a2280839c..626c60183e 100644 --- a/httpcore/dispatch/threaded.py +++ b/httpcore/dispatch/threaded.py @@ -46,9 +46,9 @@ async def send( async_content = self._async_data(content) - def async_on_close() -> None: + async def async_on_close() -> None: nonlocal concurrency_backend, sync_response - concurrency_backend.run_in_threadpool(sync_response.on_close) + await concurrency_backend.run_in_threadpool(sync_response.close) return AsyncResponse( status_code=sync_response.status_code, diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py index 2e93ff8f84..af64ab3fa4 100644 --- a/tests/dispatch/test_threaded.py +++ b/tests/dispatch/test_threaded.py @@ -13,6 +13,11 @@ ) +def streaming_body(): + for part in [b"Hello", b", ", b"world!"]: + yield part + + class MockDispatch(Dispatcher): def send( self, @@ -21,8 +26,11 @@ def send( cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> Response: - body = json.dumps({"hello": "world"}).encode() - return Response(200, content=body, request=request) + if request.url.path == "/streaming_response": + return Response(200, content=streaming_body(), request=request) + else: + body = json.dumps({"hello": "world"}).encode() + return Response(200, content=body, request=request) def test_threaded_dispatch(): @@ -38,6 +46,15 @@ def test_threaded_dispatch(): assert response.json() == {"hello": "world"} +def test_threaded_streaming_response(): + url = "https://example.org/streaming_response" + with Client(dispatch=MockDispatch()) as client: + response = client.get(url) + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + def test_dispatch_class(): """ Use a syncronous 'Dispatcher' class directly. From 11917f01c24cd382335557a686459afe67236e5c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 7 Jun 2019 14:42:01 +0100 Subject: [PATCH 16/19] Support threaded dispatch, inc. streaming requests/responses --- httpcore/concurrency.py | 8 +++++++- httpcore/dispatch/threaded.py | 32 +++++++++++++++++++++++++------- tests/dispatch/test_threaded.py | 12 ++++++++++++ 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/httpcore/concurrency.py b/httpcore/concurrency.py index d45953904f..664cb29448 100644 --- a/httpcore/concurrency.py +++ b/httpcore/concurrency.py @@ -183,7 +183,13 @@ async def run_in_threadpool( def run( self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: - return self.loop.run_until_complete(coroutine(*args, **kwargs)) + loop = self.loop + if loop.is_running(): + self._loop = asyncio.new_event_loop() + try: + return self.loop.run_until_complete(coroutine(*args, **kwargs)) + finally: + self._loop = loop def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py index 626c60183e..dbcd4dab19 100644 --- a/httpcore/dispatch/threaded.py +++ b/httpcore/dispatch/threaded.py @@ -2,9 +2,11 @@ from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher from ..models import ( AsyncRequest, + AsyncRequestData, AsyncResponse, AsyncResponseContent, Request, + RequestData, Response, ResponseContent, ) @@ -30,9 +32,19 @@ async def send( ) -> AsyncResponse: concurrency_backend = self.backend + data = getattr(request, "content", getattr(request, "content_aiter", None)) + sync_data = self._sync_request_data(data) + + sync_request = Request( + method=request.method, + url=request.url, + headers=request.headers, + data=sync_data, + ) + func = self.sync_dispatcher.send kwargs = { - "request": request, + "request": sync_request, "verify": verify, "cert": cert, "timeout": timeout, @@ -44,7 +56,7 @@ async def send( sync_response, "_raw_content", getattr(sync_response, "_raw_stream", None) ) - async_content = self._async_data(content) + async_content = self._async_response_content(content) async def async_on_close() -> None: nonlocal concurrency_backend, sync_response @@ -69,11 +81,17 @@ async def close(self) -> None: func = self.sync_dispatcher.close await self.backend.run_in_threadpool(func) - def _async_data(self, data: ResponseContent) -> AsyncResponseContent: - if isinstance(data, bytes): - return data + def _async_response_content(self, content: ResponseContent) -> AsyncResponseContent: + if isinstance(content, bytes): + return content # Coerce an async iterator into an iterator, with each item in the # iteration run within the event loop. - assert hasattr(data, "__iter__") - return self.backend.iterate_in_threadpool(data) + assert hasattr(content, "__iter__") + return self.backend.iterate_in_threadpool(content) + + def _sync_request_data(self, data: AsyncRequestData) -> RequestData: + if isinstance(data, bytes): + return data + + return self.backend.iterate(data) diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py index af64ab3fa4..7eff1a0ab3 100644 --- a/tests/dispatch/test_threaded.py +++ b/tests/dispatch/test_threaded.py @@ -28,6 +28,9 @@ def send( ) -> Response: if request.url.path == "/streaming_response": return Response(200, content=streaming_body(), request=request) + if request.url.path == "/streaming_request": + content = request.read() + return Response(200, content=content, request=request) else: body = json.dumps({"hello": "world"}).encode() return Response(200, content=body, request=request) @@ -55,6 +58,15 @@ def test_threaded_streaming_response(): assert response.text == "Hello, world!" +def test_threaded_streaming_request(): + url = "https://example.org/streaming_request" + with Client(dispatch=MockDispatch()) as client: + response = client.post(url, data=streaming_body()) + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + def test_dispatch_class(): """ Use a syncronous 'Dispatcher' class directly. From 28d1dc422e28636f097db310a63d593472ed7af8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 7 Jun 2019 14:54:57 +0100 Subject: [PATCH 17/19] Increase test coverage --- tests/dispatch/test_threaded.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py index 7eff1a0ab3..d177dbba96 100644 --- a/tests/dispatch/test_threaded.py +++ b/tests/dispatch/test_threaded.py @@ -28,9 +28,12 @@ def send( ) -> Response: if request.url.path == "/streaming_response": return Response(200, content=streaming_body(), request=request) - if request.url.path == "/streaming_request": + elif request.url.path == "/echo_request_body": content = request.read() return Response(200, content=content, request=request) + elif request.url.path == "/echo_request_body_streaming": + content = b"".join([part for part in request.stream()]) + return Response(200, content=content, request=request) else: body = json.dumps({"hello": "world"}).encode() return Response(200, content=body, request=request) @@ -59,7 +62,7 @@ def test_threaded_streaming_response(): def test_threaded_streaming_request(): - url = "https://example.org/streaming_request" + url = "https://example.org/echo_request_body" with Client(dispatch=MockDispatch()) as client: response = client.post(url, data=streaming_body()) @@ -67,6 +70,24 @@ def test_threaded_streaming_request(): assert response.text == "Hello, world!" +def test_threaded_request_body(): + url = "https://example.org/echo_request_body" + with Client(dispatch=MockDispatch()) as client: + response = client.post(url, data=b"Hello, world!") + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_threaded_request_body_streaming(): + url = "https://example.org/echo_request_body_streaming" + with Client(dispatch=MockDispatch()) as client: + response = client.post(url, data=b"Hello, world!") + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + def test_dispatch_class(): """ Use a syncronous 'Dispatcher' class directly. From 5be5598386250677d25d3495226de4db83f384f2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 10 Jun 2019 11:34:44 +0100 Subject: [PATCH 18/19] Coverage and tweaks --- httpcore/models.py | 4 +-- tests/models/test_requests.py | 27 ++++++--------- tests/models/test_responses.py | 63 ++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 18 deletions(-) diff --git a/httpcore/models.py b/httpcore/models.py index 904c41ec02..e4ec202945 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -795,7 +795,7 @@ def __init__( content: AsyncResponseContent = b"", on_close: typing.Callable = None, request: AsyncRequest = None, - history: typing.List["AsyncResponse"] = None, + history: typing.List["BaseResponse"] = None, ): super().__init__( status_code=status_code, @@ -876,7 +876,7 @@ def __init__( content: ResponseContent = b"", on_close: typing.Callable = None, request: Request = None, - history: typing.List["Response"] = None, + history: typing.List["BaseResponse"] = None, ): super().__init__( status_code=status_code, diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py index 183c9a7caa..88f2a80323 100644 --- a/tests/models/test_requests.py +++ b/tests/models/test_requests.py @@ -10,7 +10,6 @@ def test_request_repr(): def test_no_content(): request = httpcore.Request("GET", "http://example.org") - request.prepare() assert request.headers == httpcore.Headers( [(b"accept-encoding", b"deflate, gzip, br")] ) @@ -18,23 +17,23 @@ def test_no_content(): def test_content_length_header(): request = httpcore.Request("POST", "http://example.org", data=b"test 123") - request.prepare() assert request.headers == httpcore.Headers( [(b"content-length", b"8"), (b"accept-encoding", b"deflate, gzip, br")] ) def test_url_encoded_data(): - request = httpcore.Request("POST", "http://example.org", data={"test": "123"}) - request.prepare() - assert request.headers == httpcore.Headers( - [ - (b"content-length", b"8"), - (b"accept-encoding", b"deflate, gzip, br"), - (b"content-type", b"application/x-www-form-urlencoded"), - ] - ) - assert request.content == b"test=123" + for RequestClass in (httpcore.Request, httpcore.AsyncRequest): + request = RequestClass("POST", "http://example.org", data={"test": "123"}) + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.content == b"test=123" + + +def test_json_encoded_data(): + for RequestClass in (httpcore.Request, httpcore.AsyncRequest): + request = RequestClass("POST", "http://example.org", json={"test": 123}) + assert request.headers["Content-Type"] == "application/json" + assert request.content == b'{"test": 123}' def test_transfer_encoding_header(): @@ -44,7 +43,6 @@ def streaming_body(data): data = streaming_body(b"test 123") request = httpcore.Request("POST", "http://example.org", data=data) - request.prepare() assert request.headers == httpcore.Headers( [(b"transfer-encoding", b"chunked"), (b"accept-encoding", b"deflate, gzip, br")] ) @@ -54,7 +52,6 @@ def test_override_host_header(): headers = [(b"host", b"1.2.3.4:80")] request = httpcore.Request("GET", "http://example.org", headers=headers) - request.prepare() assert request.headers == httpcore.Headers( [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")] ) @@ -64,7 +61,6 @@ def test_override_accept_encoding_header(): headers = [(b"accept-encoding", b"identity")] request = httpcore.Request("GET", "http://example.org", headers=headers) - request.prepare() assert request.headers == httpcore.Headers([(b"accept-encoding", b"identity")]) @@ -76,7 +72,6 @@ def streaming_body(data): headers = [(b"content-length", b"8")] request = httpcore.Request("POST", "http://example.org", data=data, headers=headers) - request.prepare() assert request.headers == httpcore.Headers( [(b"accept-encoding", b"deflate, gzip, br"), (b"content-length", b"8")] ) diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 1fe1e0535d..f2d080ffc8 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -8,6 +8,11 @@ def streaming_body(): yield b"world!" +async def async_streaming_body(): + yield b"Hello, " + yield b"world!" + + def test_response(): response = httpcore.Response(200, content=b"Hello, world!") assert response.status_code == 200 @@ -138,6 +143,16 @@ def test_stream_interface(): assert content == b"Hello, world!" +@pytest.mark.asyncio +async def test_async_stream_interface(): + response = httpcore.AsyncResponse(200, content=b"Hello, world!") + + content = b"" + async for part in response.stream(): + content += part + assert content == b"Hello, world!" + + def test_stream_interface_after_read(): response = httpcore.Response(200, content=b"Hello, world!") @@ -149,6 +164,18 @@ def test_stream_interface_after_read(): assert content == b"Hello, world!" +@pytest.mark.asyncio +async def test_async_stream_interface_after_read(): + response = httpcore.AsyncResponse(200, content=b"Hello, world!") + + await response.read() + + content = b"" + async for part in response.stream(): + content += part + assert content == b"Hello, world!" + + def test_streaming_response(): response = httpcore.Response(200, content=streaming_body()) @@ -162,6 +189,20 @@ def test_streaming_response(): assert response.is_closed +@pytest.mark.asyncio +async def test_async_streaming_response(): + response = httpcore.AsyncResponse(200, content=async_streaming_body()) + + assert response.status_code == 200 + assert not response.is_closed + + content = await response.read() + + assert content == b"Hello, world!" + assert response.content == b"Hello, world!" + assert response.is_closed + + def test_cannot_read_after_stream_consumed(): response = httpcore.Response(200, content=streaming_body()) @@ -173,6 +214,18 @@ def test_cannot_read_after_stream_consumed(): response.read() +@pytest.mark.asyncio +async def test_async_cannot_read_after_stream_consumed(): + response = httpcore.AsyncResponse(200, content=async_streaming_body()) + + content = b"" + async for part in response.stream(): + content += part + + with pytest.raises(httpcore.StreamConsumed): + await response.read() + + def test_cannot_read_after_response_closed(): response = httpcore.Response(200, content=streaming_body()) @@ -182,6 +235,16 @@ def test_cannot_read_after_response_closed(): response.read() +@pytest.mark.asyncio +async def test_async_cannot_read_after_response_closed(): + response = httpcore.AsyncResponse(200, content=async_streaming_body()) + + await response.close() + + with pytest.raises(httpcore.ResponseClosed): + await response.read() + + def test_unknown_status_code(): response = httpcore.Response(600) assert response.status_code == 600 From 12e1b29058926e5723ed92417aa4a0df1879d20c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 10 Jun 2019 12:04:07 +0100 Subject: [PATCH 19/19] Include Accept and User-Agent headers by default --- httpcore/models.py | 6 ++++++ tests/models/test_requests.py | 29 ++++++++++------------------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/httpcore/models.py b/httpcore/models.py index e4ec202945..eb610801bc 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -501,11 +501,17 @@ def prepare(self) -> None: auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + has_user_agent = "user-agent" in self.headers + has_accept = "accept" in self.headers has_content_length = ( "content-length" in self.headers or "transfer-encoding" in self.headers ) has_accept_encoding = "accept-encoding" in self.headers + if not has_user_agent: + auto_headers.append((b"user-agent", b"httpcore")) + if not has_accept: + auto_headers.append((b"accept", b"*/*")) if not has_content_length: if is_streaming: auto_headers.append((b"transfer-encoding", b"chunked")) diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py index 88f2a80323..79cbba36e8 100644 --- a/tests/models/test_requests.py +++ b/tests/models/test_requests.py @@ -10,16 +10,12 @@ def test_request_repr(): def test_no_content(): request = httpcore.Request("GET", "http://example.org") - assert request.headers == httpcore.Headers( - [(b"accept-encoding", b"deflate, gzip, br")] - ) + assert "Content-Length" not in request.headers def test_content_length_header(): request = httpcore.Request("POST", "http://example.org", data=b"test 123") - assert request.headers == httpcore.Headers( - [(b"content-length", b"8"), (b"accept-encoding", b"deflate, gzip, br")] - ) + assert request.headers["Content-Length"] == "8" def test_url_encoded_data(): @@ -43,25 +39,22 @@ def streaming_body(data): data = streaming_body(b"test 123") request = httpcore.Request("POST", "http://example.org", data=data) - assert request.headers == httpcore.Headers( - [(b"transfer-encoding", b"chunked"), (b"accept-encoding", b"deflate, gzip, br")] - ) + assert "Content-Length" not in request.headers + assert request.headers["Transfer-Encoding"] == "chunked" def test_override_host_header(): - headers = [(b"host", b"1.2.3.4:80")] + headers = {"host": "1.2.3.4:80"} request = httpcore.Request("GET", "http://example.org", headers=headers) - assert request.headers == httpcore.Headers( - [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")] - ) + assert request.headers["Host"] == "1.2.3.4:80" def test_override_accept_encoding_header(): - headers = [(b"accept-encoding", b"identity")] + headers = {"Accept-Encoding": "identity"} request = httpcore.Request("GET", "http://example.org", headers=headers) - assert request.headers == httpcore.Headers([(b"accept-encoding", b"identity")]) + assert request.headers["Accept-Encoding"] == "identity" def test_override_content_length_header(): @@ -69,12 +62,10 @@ def streaming_body(data): yield data # pragma: nocover data = streaming_body(b"test 123") - headers = [(b"content-length", b"8")] + headers = {"Content-Length": "8"} request = httpcore.Request("POST", "http://example.org", data=data, headers=headers) - assert request.headers == httpcore.Headers( - [(b"accept-encoding", b"deflate, gzip, br"), (b"content-length", b"8")] - ) + assert request.headers["Content-Length"] == "8" def test_url():