diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 6d073a8de3..b49691bf87 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -28,8 +28,25 @@ TooManyRedirects, WriteTimeout, ) -from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol -from .models import URL, Cookies, Headers, Origin, QueryParams, Request, Response +from .interfaces import ( + AsyncDispatcher, + BaseReader, + BaseWriter, + ConcurrencyBackend, + Dispatcher, + Protocol, +) +from .models import ( + URL, + AsyncRequest, + AsyncResponse, + Cookies, + Headers, + Origin, + QueryParams, + Request, + Response, +) from .status_codes import StatusCode, codes __version__ = "0.3.0" diff --git a/httpcore/api.py b/httpcore/api.py index 33d68c5e77..7e2682567f 100644 --- a/httpcore/api.py +++ b/httpcore/api.py @@ -8,7 +8,7 @@ HeaderTypes, QueryParamTypes, RequestData, - SyncResponse, + Response, URLTypes, ) @@ -30,7 +30,7 @@ def request( cert: CertTypes = None, verify: VerifyTypes = True, stream: bool = False, -) -> SyncResponse: +) -> Response: with Client() as client: return client.request( method=method, @@ -61,7 +61,7 @@ def get( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "GET", url, @@ -88,7 +88,7 @@ def options( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "OPTIONS", url, @@ -115,7 +115,7 @@ def head( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "HEAD", url, @@ -144,7 +144,7 @@ def post( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "POST", url, @@ -175,7 +175,7 @@ def put( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "PUT", url, @@ -206,7 +206,7 @@ def patch( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "PATCH", url, @@ -237,7 +237,7 @@ def delete( cert: CertTypes = None, verify: VerifyTypes = True, timeout: TimeoutTypes = None, -) -> SyncResponse: +) -> Response: return request( "DELETE", url, diff --git a/httpcore/auth.py b/httpcore/auth.py index 49ff998b43..6a39c1b2c5 100644 --- a/httpcore/auth.py +++ b/httpcore/auth.py @@ -1,7 +1,7 @@ import typing from base64 import b64encode -from .models import Request +from .models import AsyncRequest class AuthBase: @@ -9,7 +9,7 @@ class AuthBase: Base class that all auth implementations derive from. """ - def __call__(self, request: Request) -> Request: + def __call__(self, request: AsyncRequest) -> AsyncRequest: raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover @@ -20,7 +20,7 @@ def __init__( self.username = username self.password = password - def __call__(self, request: Request) -> Request: + def __call__(self, request: AsyncRequest) -> AsyncRequest: request.headers["Authorization"] = self.build_auth_header() return request diff --git a/httpcore/client.py b/httpcore/client.py index 2946a753fd..0fc60a1a3b 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -1,8 +1,8 @@ -import asyncio import typing from types import TracebackType from .auth import HTTPBasicAuth +from .concurrency import AsyncioBackend from .config import ( DEFAULT_MAX_REDIRECTS, DEFAULT_POOL_LIMITS, @@ -13,10 +13,15 @@ VerifyTypes, ) from .dispatch.connection_pool import ConnectionPool +from .dispatch.threaded import ThreadedDispatcher from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects -from .interfaces import ConcurrencyBackend, Dispatcher +from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher from .models import ( URL, + AsyncRequest, + AsyncRequestData, + AsyncResponse, + AsyncResponseContent, AuthTypes, Cookies, CookieTypes, @@ -26,13 +31,13 @@ Request, RequestData, Response, - SyncResponse, + ResponseContent, URLTypes, ) from .status_codes import codes -class AsyncClient: +class BaseClient: def __init__( self, auth: AuthTypes = None, @@ -42,23 +47,208 @@ def __init__( timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, max_redirects: int = DEFAULT_MAX_REDIRECTS, - dispatch: Dispatcher = None, + dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None, backend: ConcurrencyBackend = None, ): + if backend is None: + backend = AsyncioBackend() + if dispatch is None: - dispatch = ConnectionPool( + async_dispatch = ConnectionPool( verify=verify, cert=cert, timeout=timeout, pool_limits=pool_limits, backend=backend, - ) + ) # type: AsyncDispatcher + elif isinstance(dispatch, Dispatcher): + async_dispatch = ThreadedDispatcher(dispatch, backend) + else: + async_dispatch = dispatch self.auth = auth self.cookies = Cookies(cookies) self.max_redirects = max_redirects - self.dispatch = dispatch + self.dispatch = async_dispatch + self.concurrency_backend = backend + + def merge_cookies( + self, cookies: CookieTypes = None + ) -> typing.Optional[CookieTypes]: + if cookies or self.cookies: + merged_cookies = Cookies(self.cookies) + merged_cookies.update(cookies) + return merged_cookies + return cookies + + async def send( + self, + request: AsyncRequest, + *, + stream: bool = False, + auth: AuthTypes = None, + allow_redirects: bool = True, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> AsyncResponse: + if auth is None: + auth = self.auth + + url = request.url + if auth is None and (url.username or url.password): + auth = HTTPBasicAuth(username=url.username, password=url.password) + + if auth is not None: + if isinstance(auth, tuple): + auth = HTTPBasicAuth(username=auth[0], password=auth[1]) + request = auth(request) + + response = await self.send_handling_redirects( + request, + verify=verify, + cert=cert, + timeout=timeout, + allow_redirects=allow_redirects, + ) + + if not stream: + try: + await response.read() + finally: + await response.close() + + return response + + async def send_handling_redirects( + self, + request: AsyncRequest, + *, + cert: CertTypes = None, + verify: VerifyTypes = None, + timeout: TimeoutTypes = None, + allow_redirects: bool = True, + history: typing.List[AsyncResponse] = None, + ) -> AsyncResponse: + if history is None: + history = [] + + while True: + # We perform these checks here, so that calls to `response.next()` + # will raise redirect errors if appropriate. + if len(history) > self.max_redirects: + raise TooManyRedirects() + if request.url in [response.url for response in history]: + raise RedirectLoop() + + response = await self.dispatch.send( + request, verify=verify, cert=cert, timeout=timeout + ) + assert isinstance(response, AsyncResponse) + response.history = list(history) + self.cookies.extract_cookies(response) + history = [response] + history + if not response.is_redirect: + break + + if allow_redirects: + request = self.build_redirect_request(request, response) + else: + + async def send_next() -> AsyncResponse: + nonlocal request, response, verify, cert, allow_redirects, timeout, history + request = self.build_redirect_request(request, response) + response = await self.send_handling_redirects( + request, + allow_redirects=allow_redirects, + verify=verify, + cert=cert, + timeout=timeout, + history=history, + ) + return response + + response.next = send_next # type: ignore + break + + return response + + def build_redirect_request( + self, request: AsyncRequest, response: AsyncResponse + ) -> AsyncRequest: + method = self.redirect_method(request, response) + url = self.redirect_url(request, response) + headers = self.redirect_headers(request, url) + content = self.redirect_content(request, method) + cookies = self.merge_cookies(request.cookies) + return AsyncRequest( + method=method, url=url, headers=headers, data=content, cookies=cookies + ) + + def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str: + """ + When being redirected we may want to change the method of the request + based on certain specs or browser behavior. + """ + method = request.method + + # https://tools.ietf.org/html/rfc7231#section-6.4.4 + if response.status_code == codes.SEE_OTHER and method != "HEAD": + method = "GET" + + # Do what the browsers do, despite standards... + # Turn 302s into GETs. + if response.status_code == codes.FOUND and method != "HEAD": + method = "GET" + + # If a POST is responded to with a 301, turn it into a GET. + # This bizarre behaviour is explained in 'requests' issue 1704. + if response.status_code == codes.MOVED_PERMANENTLY and method == "POST": + method = "GET" + + return method + + def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL: + """ + Return the URL for the redirect to follow. + """ + location = response.headers["Location"] + + url = URL(location, allow_relative=True) + + # Facilitate relative 'Location' headers, as allowed by RFC 7231. + # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + if url.is_relative_url: + url = url.resolve_with(request.url) + + # Attach previous fragment if needed (RFC 7231 7.1.2) + if request.url.fragment and not url.fragment: + url = url.copy_with(fragment=request.url.fragment) + + return url + + def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers: + """ + Strip Authorization headers when responses are redirected away from + the origin. + """ + headers = Headers(request.headers) + if url.origin != request.url.origin: + del headers["Authorization"] + return headers + + def redirect_content(self, request: AsyncRequest, method: str) -> bytes: + """ + Return the body that should be used for the redirect request. + """ + if method != request.method and method == "GET": + return b"" + if request.is_streaming: + raise RedirectBodyUnavailable() + return request.content + +class AsyncClient(BaseClient): async def get( self, url: URLTypes, @@ -72,7 +262,7 @@ async def get( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "GET", url, @@ -100,7 +290,7 @@ async def options( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "OPTIONS", url, @@ -128,7 +318,7 @@ async def head( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "HEAD", url, @@ -147,7 +337,7 @@ async def post( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -158,7 +348,7 @@ async def post( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "POST", url, @@ -179,7 +369,7 @@ async def put( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -190,7 +380,7 @@ async def put( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "PUT", url, @@ -211,7 +401,7 @@ async def patch( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -222,7 +412,7 @@ async def patch( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "PATCH", url, @@ -243,7 +433,7 @@ async def delete( self, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -254,7 +444,7 @@ async def delete( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: return await self.request( "DELETE", url, @@ -276,7 +466,7 @@ async def request( method: str, url: URLTypes, *, - data: RequestData = b"", + data: AsyncRequestData = b"", json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, @@ -287,8 +477,8 @@ async def request( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> Response: - request = Request( + ) -> AsyncResponse: + request = AsyncRequest( method, url, data=data, @@ -308,174 +498,6 @@ async def request( ) return response - def merge_cookies( - self, cookies: CookieTypes = None - ) -> typing.Optional[CookieTypes]: - if cookies or self.cookies: - merged_cookies = Cookies(self.cookies) - merged_cookies.update(cookies) - return merged_cookies - return cookies - - async def send( - self, - request: Request, - *, - stream: bool = False, - auth: AuthTypes = None, - allow_redirects: bool = True, - verify: VerifyTypes = None, - cert: CertTypes = None, - timeout: TimeoutTypes = None, - ) -> Response: - if auth is None: - auth = self.auth - - url = request.url - if auth is None and (url.username or url.password): - auth = HTTPBasicAuth(username=url.username, password=url.password) - - if auth is not None: - if isinstance(auth, tuple): - auth = HTTPBasicAuth(username=auth[0], password=auth[1]) - request = auth(request) - - response = await self.send_handling_redirects( - request, - stream=stream, - verify=verify, - cert=cert, - timeout=timeout, - allow_redirects=allow_redirects, - ) - return response - - async def send_handling_redirects( - self, - request: Request, - *, - stream: bool = False, - cert: CertTypes = None, - verify: VerifyTypes = None, - timeout: TimeoutTypes = None, - allow_redirects: bool = True, - history: typing.List[Response] = None, - ) -> Response: - if history is None: - history = [] - - while True: - # We perform these checks here, so that calls to `response.next()` - # will raise redirect errors if appropriate. - if len(history) > self.max_redirects: - raise TooManyRedirects() - if request.url in [response.url for response in history]: - raise RedirectLoop() - - response = await self.dispatch.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout - ) - response.history = list(history) - self.cookies.extract_cookies(response) - history = [response] + history - if not response.is_redirect: - break - - if allow_redirects: - request = self.build_redirect_request(request, response) - else: - - async def send_next() -> Response: - nonlocal request, response, verify, cert, allow_redirects, timeout, history - request = self.build_redirect_request(request, response) - response = await self.send_handling_redirects( - request, - stream=stream, - allow_redirects=allow_redirects, - verify=verify, - cert=cert, - timeout=timeout, - history=history, - ) - return response - - response.next = send_next # type: ignore - break - - return response - - def build_redirect_request(self, request: Request, response: Response) -> Request: - method = self.redirect_method(request, response) - url = self.redirect_url(request, response) - headers = self.redirect_headers(request, url) - content = self.redirect_content(request, method) - cookies = self.merge_cookies(request.cookies) - return Request( - method=method, url=url, headers=headers, data=content, cookies=cookies - ) - - def redirect_method(self, request: Request, response: Response) -> str: - """ - When being redirected we may want to change the method of the request - based on certain specs or browser behavior. - """ - method = request.method - - # https://tools.ietf.org/html/rfc7231#section-6.4.4 - if response.status_code == codes.SEE_OTHER and method != "HEAD": - method = "GET" - - # Do what the browsers do, despite standards... - # Turn 302s into GETs. - if response.status_code == codes.FOUND and method != "HEAD": - method = "GET" - - # If a POST is responded to with a 301, turn it into a GET. - # This bizarre behaviour is explained in 'requests' issue 1704. - if response.status_code == codes.MOVED_PERMANENTLY and method == "POST": - method = "GET" - - return method - - def redirect_url(self, request: Request, response: Response) -> URL: - """ - Return the URL for the redirect to follow. - """ - location = response.headers["Location"] - - url = URL(location, allow_relative=True) - - # Facilitate relative 'Location' headers, as allowed by RFC 7231. - # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') - if url.is_relative_url: - url = url.resolve_with(request.url) - - # Attach previous fragment if needed (RFC 7231 7.1.2) - if request.url.fragment and not url.fragment: - url = url.copy_with(fragment=request.url.fragment) - - return url - - def redirect_headers(self, request: Request, url: URL) -> Headers: - """ - Strip Authorization headers when responses are redirected away from - the origin. - """ - headers = Headers(request.headers) - if url.origin != request.url.origin: - del headers["Authorization"] - return headers - - def redirect_content(self, request: Request, method: str) -> bytes: - """ - Return the body that should be used for the redirect request. - """ - if method != request.method and method == "GET": - return b"" - if request.is_streaming: - raise RedirectBodyUnavailable() - return request.content - async def close(self) -> None: await self.dispatch.close() @@ -491,33 +513,28 @@ async def __aexit__( await self.close() -class Client: - def __init__( - self, - auth: AuthTypes = None, - cert: CertTypes = None, - verify: VerifyTypes = True, - timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, - pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, - max_redirects: int = DEFAULT_MAX_REDIRECTS, - dispatch: Dispatcher = None, - backend: ConcurrencyBackend = None, - ) -> None: - self._client = AsyncClient( - auth=auth, - verify=verify, - cert=cert, - timeout=timeout, - pool_limits=pool_limits, - max_redirects=max_redirects, - dispatch=dispatch, - backend=backend, - ) - self._loop = asyncio.new_event_loop() +class Client(BaseClient): + def _async_request_data(self, data: RequestData) -> AsyncRequestData: + """ + If the request data is an bytes iterator then return an async bytes + iterator onto the request data. + """ + if isinstance(data, (bytes, dict)): + return data + + # Coerce an iterator into an async iterator, with each item in the + # iteration running as a thread-pooled operation. + assert hasattr(data, "__iter__") + return self.concurrency_backend.iterate_in_threadpool(data) - @property - def cookies(self) -> Cookies: - return self._client.cookies + def _sync_data(self, data: AsyncResponseContent) -> ResponseContent: + if isinstance(data, bytes): + return data + + # Coerce an async iterator into an iterator, with each item in the + # iteration run within the event loop. + assert hasattr(data, "__aiter__") + return self.concurrency_backend.iterate(data) def request( self, @@ -535,25 +552,55 @@ def request( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: - request = Request( + ) -> Response: + request = AsyncRequest( method, url, - data=data, + data=self._async_request_data(data), json=json, params=params, headers=headers, - cookies=self._client.merge_cookies(cookies), + cookies=self.merge_cookies(cookies), ) - response = self.send( - request, - stream=stream, + concurrency_backend = self.concurrency_backend + + coroutine = self.send + args = [request] + kwargs = dict( + stream=True, auth=auth, allow_redirects=allow_redirects, verify=verify, cert=cert, timeout=timeout, ) + async_response = concurrency_backend.run(coroutine, *args, **kwargs) + + content = getattr( + async_response, "_raw_content", getattr(async_response, "_raw_stream", None) + ) + + sync_content = self._sync_data(content) + + def sync_on_close() -> None: + nonlocal concurrency_backend, async_response + concurrency_backend.run(async_response.on_close) + + response = Response( + status_code=async_response.status_code, + reason_phrase=async_response.reason_phrase, + protocol=async_response.protocol, + headers=async_response.headers, + content=sync_content, + on_close=sync_on_close, + request=async_response.request, + history=async_response.history, + ) + if not stream: + try: + response.read() + finally: + response.close() return response def get( @@ -569,7 +616,7 @@ def get( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "GET", url, @@ -596,7 +643,7 @@ def options( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "OPTIONS", url, @@ -623,7 +670,7 @@ def head( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "HEAD", url, @@ -652,7 +699,7 @@ def post( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "POST", url, @@ -683,7 +730,7 @@ def put( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "PUT", url, @@ -714,7 +761,7 @@ def patch( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "PATCH", url, @@ -745,7 +792,7 @@ def delete( cert: CertTypes = None, verify: VerifyTypes = None, timeout: TimeoutTypes = None, - ) -> SyncResponse: + ) -> Response: return self.request( "DELETE", url, @@ -761,32 +808,9 @@ def delete( timeout=timeout, ) - def send( - self, - request: Request, - *, - stream: bool = False, - auth: AuthTypes = None, - allow_redirects: bool = True, - verify: VerifyTypes = None, - cert: CertTypes = None, - timeout: TimeoutTypes = None, - ) -> SyncResponse: - response = self._loop.run_until_complete( - self._client.send( - request, - stream=stream, - auth=auth, - allow_redirects=allow_redirects, - verify=verify, - cert=cert, - timeout=timeout, - ) - ) - return SyncResponse(response, self._loop) - def close(self) -> None: - self._loop.run_until_complete(self._client.close()) + coroutine = self.dispatch.close + self.concurrency_backend.run(coroutine) def __enter__(self) -> "Client": return self diff --git a/httpcore/concurrency.py b/httpcore/concurrency.py index 0c1d3409eb..664cb29448 100644 --- a/httpcore/concurrency.py +++ b/httpcore/concurrency.py @@ -9,6 +9,7 @@ based, and less strictly `asyncio`-specific. """ import asyncio +import functools import ssl import typing @@ -133,6 +134,15 @@ def __init__(self) -> None: ssl_monkey_patch() SSL_MONKEY_PATCH_APPLIED = True + @property + def loop(self) -> asyncio.AbstractEventLoop: + if not hasattr(self, "_loop"): + try: + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + return self._loop + async def connect( self, hostname: str, @@ -162,5 +172,24 @@ async def connect( return (reader, writer, protocol) + async def run_in_threadpool( + self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + if kwargs: + # loop.run_in_executor doesn't accept 'kwargs', so bind them in here + func = functools.partial(func, **kwargs) + return await self.loop.run_in_executor(None, func, *args) + + def run( + self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + loop = self.loop + if loop.is_running(): + self._loop = asyncio.new_event_loop() + try: + return self.loop.run_until_complete(coroutine(*args, **kwargs)) + finally: + self._loop = loop + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index 60214333fe..d644bcba73 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -15,8 +15,8 @@ VerifyTypes, ) from ..exceptions import ConnectTimeout -from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol -from ..models import Origin, Request, Response +from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol +from ..models import AsyncRequest, AsyncResponse, Origin from .http2 import HTTP2Connection from .http11 import HTTP11Connection @@ -24,7 +24,7 @@ ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]] -class HTTPConnection(Dispatcher): +class HTTPConnection(AsyncDispatcher): def __init__( self, origin: typing.Union[str, Origin], @@ -44,24 +44,19 @@ def __init__( async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: if self.h11_connection is None and self.h2_connection is None: await self.connect(verify=verify, cert=cert, timeout=timeout) if self.h2_connection is not None: - response = await self.h2_connection.send( - request, stream=stream, timeout=timeout - ) + response = await self.h2_connection.send(request, timeout=timeout) else: assert self.h11_connection is not None - response = await self.h11_connection.send( - request, stream=stream, timeout=timeout - ) + response = await self.h11_connection.send(request, timeout=timeout) return response diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index e7cefbd7e4..c84117ca45 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -12,8 +12,8 @@ ) from ..decoders import ACCEPT_ENCODING from ..exceptions import PoolTimeout -from ..interfaces import ConcurrencyBackend, Dispatcher -from ..models import Origin, Request, Response +from ..interfaces import AsyncDispatcher, ConcurrencyBackend +from ..models import AsyncRequest, AsyncResponse, Origin from .connection import HTTPConnection CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] @@ -77,7 +77,7 @@ def __len__(self) -> int: return len(self.all) -class ConnectionPool(Dispatcher): +class ConnectionPool(AsyncDispatcher): def __init__( self, *, @@ -105,16 +105,15 @@ def num_connections(self) -> int: async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: connection = await self.acquire_connection(request.url.origin) try: response = await connection.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout + request, verify=verify, cert=cert, timeout=timeout ) except BaseException as exc: self.active_connections.remove(connection) diff --git a/httpcore/dispatch/http11.py b/httpcore/dispatch/http11.py index 6f1548563b..f19b3d3dc1 100644 --- a/httpcore/dispatch/http11.py +++ b/httpcore/dispatch/http11.py @@ -4,8 +4,8 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout -from ..interfaces import BaseReader, BaseWriter, Dispatcher -from ..models import Request, Response +from ..interfaces import BaseReader, BaseWriter +from ..models import AsyncRequest, AsyncResponse H11Event = typing.Union[ h11.Request, @@ -38,15 +38,15 @@ def __init__( self.h11_state = h11.Connection(our_role=h11.CLIENT) async def send( - self, request: Request, stream: bool = False, timeout: TimeoutTypes = None - ) -> Response: + self, request: AsyncRequest, timeout: TimeoutTypes = None + ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) #  Start sending the request. method = request.method.encode("ascii") target = request.url.full_path.encode("ascii") headers = request.headers.raw - if 'Host' not in request.headers: + if "Host" not in request.headers: host = request.url.authority.encode("ascii") headers = [(b"host", host)] + headers event = h11.Request(method=method, target=target, headers=headers) @@ -72,7 +72,7 @@ async def send( headers = event.headers content = self._body_iter(timeout) - response = Response( + return AsyncResponse( status_code=status_code, reason_phrase=reason_phrase, protocol="HTTP/1.1", @@ -82,14 +82,6 @@ async def send( request=request, ) - if not stream: - try: - await response.read() - finally: - await response.close() - - return response - async def close(self) -> None: event = h11.ConnectionClosed() self.h11_state.send(event) diff --git a/httpcore/dispatch/http2.py b/httpcore/dispatch/http2.py index 402f3b651c..f7814ec3c1 100644 --- a/httpcore/dispatch/http2.py +++ b/httpcore/dispatch/http2.py @@ -6,8 +6,8 @@ from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes from ..exceptions import ConnectTimeout, ReadTimeout -from ..interfaces import BaseReader, BaseWriter, Dispatcher -from ..models import Request, Response +from ..interfaces import BaseReader, BaseWriter +from ..models import AsyncRequest, AsyncResponse class HTTP2Connection: @@ -24,8 +24,8 @@ def __init__( self.initialized = False async def send( - self, request: Request, stream: bool = False, timeout: TimeoutTypes = None - ) -> Response: + self, request: AsyncRequest, timeout: TimeoutTypes = None + ) -> AsyncResponse: timeout = None if timeout is None else TimeoutConfig(timeout) #  Start sending the request. @@ -59,7 +59,7 @@ async def send( content = self.body_iter(stream_id, timeout) on_close = functools.partial(self.response_closed, stream_id=stream_id) - response = Response( + return AsyncResponse( status_code=status_code, protocol="HTTP/2", headers=headers, @@ -68,14 +68,6 @@ async def send( request=request, ) - if not stream: - try: - await response.read() - finally: - await response.close() - - return response - async def close(self) -> None: await self.writer.close() @@ -86,7 +78,7 @@ def initiate_connection(self) -> None: self.initialized = True async def send_headers( - self, request: Request, timeout: TimeoutConfig = None + self, request: AsyncRequest, timeout: TimeoutConfig = None ) -> int: stream_id = self.h2_state.get_next_available_stream_id() headers = [ diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py new file mode 100644 index 0000000000..dbcd4dab19 --- /dev/null +++ b/httpcore/dispatch/threaded.py @@ -0,0 +1,97 @@ +from ..config import CertTypes, TimeoutTypes, VerifyTypes +from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher +from ..models import ( + AsyncRequest, + AsyncRequestData, + AsyncResponse, + AsyncResponseContent, + Request, + RequestData, + Response, + ResponseContent, +) + + +class ThreadedDispatcher(AsyncDispatcher): + """ + The ThreadedDispatcher class is used to mediate between the Client + (which always uses async under the hood), and a synchronous `Dispatch` + class. + """ + + def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None: + self.sync_dispatcher = dispatch + self.backend = backend + + async def send( + self, + request: AsyncRequest, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> AsyncResponse: + concurrency_backend = self.backend + + data = getattr(request, "content", getattr(request, "content_aiter", None)) + sync_data = self._sync_request_data(data) + + sync_request = Request( + method=request.method, + url=request.url, + headers=request.headers, + data=sync_data, + ) + + func = self.sync_dispatcher.send + kwargs = { + "request": sync_request, + "verify": verify, + "cert": cert, + "timeout": timeout, + } + sync_response = await self.backend.run_in_threadpool(func, **kwargs) + assert isinstance(sync_response, Response) + + content = getattr( + sync_response, "_raw_content", getattr(sync_response, "_raw_stream", None) + ) + + async_content = self._async_response_content(content) + + async def async_on_close() -> None: + nonlocal concurrency_backend, sync_response + await concurrency_backend.run_in_threadpool(sync_response.close) + + return AsyncResponse( + status_code=sync_response.status_code, + reason_phrase=sync_response.reason_phrase, + protocol=sync_response.protocol, + headers=sync_response.headers, + content=async_content, + on_close=async_on_close, + request=request, + history=sync_response.history, + ) + + async def close(self) -> None: + """ + The `.close()` method runs the `Dispatcher.close()` within a threadpool, + so as not to block the async event loop. + """ + func = self.sync_dispatcher.close + await self.backend.run_in_threadpool(func) + + def _async_response_content(self, content: ResponseContent) -> AsyncResponseContent: + if isinstance(content, bytes): + return content + + # Coerce an async iterator into an iterator, with each item in the + # iteration run within the event loop. + assert hasattr(content, "__iter__") + return self.backend.iterate_in_threadpool(content) + + def _sync_request_data(self, data: AsyncRequestData) -> RequestData: + if isinstance(data, bytes): + return data + + return self.backend.iterate(data) diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 42ffd157ae..13d118cfcd 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -6,6 +6,9 @@ from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes from .models import ( URL, + AsyncRequest, + AsyncRequestData, + AsyncResponse, Headers, HeaderTypes, QueryParamTypes, @@ -21,9 +24,9 @@ class Protocol(str, enum.Enum): HTTP_2 = "HTTP/2" -class Dispatcher: +class AsyncDispatcher: """ - Base class for dispatcher classes, that handle sending the request. + Base class for async dispatcher classes, that handle sending the request. Stubs out the interface, as well as providing a `.request()` convienence implementation, to make it easy to use or test stand-alone dispatchers, @@ -31,6 +34,54 @@ class Dispatcher: """ async def request( + self, + method: str, + url: URLTypes, + *, + data: AsyncRequestData = b"", + params: QueryParamTypes = None, + headers: HeaderTypes = None, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None + ) -> AsyncResponse: + request = AsyncRequest(method, url, data=data, params=params, headers=headers) + return await self.send(request, verify=verify, cert=cert, timeout=timeout) + + async def send( + self, + request: AsyncRequest, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> AsyncResponse: + raise NotImplementedError() # pragma: nocover + + async def close(self) -> None: + pass # pragma: nocover + + async def __aenter__(self) -> "AsyncDispatcher": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.close() + + +class Dispatcher: + """ + Base class for syncronous dispatcher classes, that handle sending the request. + + Stubs out the interface, as well as providing a `.request()` convienence + implementation, to make it easy to use or test stand-alone dispatchers, + without requiring a complete `Client` instance. + """ + + def request( self, method: str, url: URLTypes, @@ -38,40 +89,35 @@ async def request( data: RequestData = b"", params: QueryParamTypes = None, headers: HeaderTypes = None, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None ) -> Response: request = Request(method, url, data=data, params=params, headers=headers) - response = await self.send( - request, stream=stream, verify=verify, cert=cert, timeout=timeout - ) - return response + return self.send(request, verify=verify, cert=cert, timeout=timeout) - async def send( + def send( self, request: Request, - stream: bool = False, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, ) -> Response: raise NotImplementedError() # pragma: nocover - async def close(self) -> None: + def close(self) -> None: pass # pragma: nocover - async def __aenter__(self) -> "Dispatcher": + def __enter__(self) -> "Dispatcher": return self - async def __aexit__( + def __exit__( self, exc_type: typing.Type[BaseException] = None, exc_value: BaseException = None, traceback: TracebackType = None, ) -> None: - await self.close() + self.close() class BaseReader: @@ -128,3 +174,36 @@ async def connect( def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover + + async def run_in_threadpool( + self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + raise NotImplementedError() # pragma: no cover + + async def iterate_in_threadpool(self, iterator): # type: ignore + class IterationComplete(Exception): + pass + + def next_wrapper(iterator): # type: ignore + try: + return next(iterator) + except StopIteration: + raise IterationComplete() + + while True: + try: + yield await self.run_in_threadpool(next_wrapper, iterator) + except IterationComplete: + break + + def run( + self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + raise NotImplementedError() # pragma: no cover + + def iterate(self, async_iterator): # type: ignore + while True: + try: + yield self.run(async_iterator.__anext__) + except StopAsyncIteration: + break diff --git a/httpcore/models.py b/httpcore/models.py index cb28675d2d..eb610801bc 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -1,4 +1,3 @@ -import asyncio import cgi import email.message import json as jsonlib @@ -48,12 +47,16 @@ AuthTypes = typing.Union[ typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]], - typing.Callable[["Request"], "Request"], + typing.Callable[["AsyncRequest"], "AsyncRequest"], ] -RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] +AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] -ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]] +RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]] + +AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]] + +ResponseContent = typing.Union[bytes, typing.Iterator[bytes]] class URL: @@ -469,14 +472,12 @@ def __repr__(self) -> str: return f"{class_name}({as_list!r}{encoding_str})" -class Request: +class BaseRequest: def __init__( self, method: str, url: typing.Union[str, URL], *, - data: RequestData = b"", - json: typing.Any = None, params: QueryParamTypes = None, headers: HeaderTypes = None, cookies: CookieTypes = None, @@ -488,18 +489,82 @@ def __init__( self._cookies = Cookies(cookies) self._cookies.set_cookie_header(self) + def encode_json(self, json: typing.Any) -> bytes: + return jsonlib.dumps(json).encode("utf-8") + + def urlencode_data(self, data: dict) -> bytes: + return urlencode(data, doseq=True).encode("utf-8") + + def prepare(self) -> None: + content = getattr(self, "content", None) # type: bytes + is_streaming = getattr(self, "is_streaming", False) + + auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + + has_user_agent = "user-agent" in self.headers + has_accept = "accept" in self.headers + has_content_length = ( + "content-length" in self.headers or "transfer-encoding" in self.headers + ) + has_accept_encoding = "accept-encoding" in self.headers + + if not has_user_agent: + auto_headers.append((b"user-agent", b"httpcore")) + if not has_accept: + auto_headers.append((b"accept", b"*/*")) + if not has_content_length: + if is_streaming: + auto_headers.append((b"transfer-encoding", b"chunked")) + elif content: + content_length = str(len(content)).encode() + auto_headers.append((b"content-length", content_length)) + if not has_accept_encoding: + auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode())) + + for item in reversed(auto_headers): + self.headers.raw.insert(0, item) + + @property + def cookies(self) -> "Cookies": + if not hasattr(self, "_cookies"): + self._cookies = Cookies() + return self._cookies + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url = str(self.url) + return f"<{class_name}({self.method!r}, {url!r})>" + + +class AsyncRequest(BaseRequest): + def __init__( + self, + method: str, + url: typing.Union[str, URL], + *, + params: QueryParamTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + data: AsyncRequestData = b"", + json: typing.Any = None, + ): + super().__init__( + method=method, url=url, params=params, headers=headers, cookies=cookies + ) + if json is not None: - data = jsonlib.dumps(json).encode("utf-8") + self.is_streaming = False + self.content = self.encode_json(json) self.headers["Content-Type"] = "application/json" - - if isinstance(data, bytes): + elif isinstance(data, bytes): self.is_streaming = False self.content = data elif isinstance(data, dict): self.is_streaming = False - self.content = urlencode(data, doseq=True).encode("utf-8") + self.content = self.urlencode_data(data) self.headers["Content-Type"] = "application/x-www-form-urlencoded" else: + assert hasattr(data, "__aiter__") self.is_streaming = True self.content_aiter = data @@ -520,39 +585,55 @@ async def stream(self) -> typing.AsyncIterator[bytes]: elif self.content: yield self.content - def prepare(self) -> None: - auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] - has_content_length = ( - "content-length" in self.headers or "transfer-encoding" in self.headers +class Request(BaseRequest): + def __init__( + self, + method: str, + url: typing.Union[str, URL], + *, + params: QueryParamTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + data: RequestData = b"", + json: typing.Any = None, + ): + super().__init__( + method=method, url=url, params=params, headers=headers, cookies=cookies ) - has_accept_encoding = "accept-encoding" in self.headers - if not has_content_length: - if self.is_streaming: - auto_headers.append((b"transfer-encoding", b"chunked")) - elif self.content: - content_length = str(len(self.content)).encode() - auto_headers.append((b"content-length", content_length)) - if not has_accept_encoding: - auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode())) + if json is not None: + self.is_streaming = False + self.content = self.encode_json(json) + self.headers["Content-Type"] = "application/json" + elif isinstance(data, bytes): + self.is_streaming = False + self.content = data + elif isinstance(data, dict): + self.is_streaming = False + self.content = self.urlencode_data(data) + self.headers["Content-Type"] = "application/x-www-form-urlencoded" + else: + assert hasattr(data, "__iter__") + self.is_streaming = True + self.content_iter = data - for item in reversed(auto_headers): - self.headers.raw.insert(0, item) + self.prepare() - @property - def cookies(self) -> "Cookies": - if not hasattr(self, "_cookies"): - self._cookies = Cookies() - return self._cookies + def read(self) -> bytes: + if not hasattr(self, "content"): + self.content = b"".join([part for part in self.stream()]) + return self.content - def __repr__(self) -> str: - class_name = self.__class__.__name__ - url = str(self.url) - return f"<{class_name}({self.method!r}, {url!r})>" + def stream(self) -> typing.Iterator[bytes]: + if self.is_streaming: + for part in self.content_iter: + yield part + elif self.content: + yield self.content -class Response: +class BaseResponse: def __init__( self, status_code: int, @@ -560,28 +641,16 @@ def __init__( reason_phrase: str = None, protocol: str = None, headers: HeaderTypes = None, - content: ResponseContent = b"", + request: BaseRequest = None, on_close: typing.Callable = None, - request: Request = None, - history: typing.List["Response"] = None, ): self.status_code = StatusCode.enum_or_int(status_code) self.reason_phrase = StatusCode.get_reason_phrase(status_code) self.protocol = protocol self.headers = Headers(headers) - if isinstance(content, bytes): - self.is_closed = True - self.is_stream_consumed = True - self._raw_content = content - else: - self.is_closed = False - self.is_stream_consumed = False - self._raw_stream = content - - self.on_close = on_close self.request = request - self.history = [] if history is None else list(history) + self.on_close = on_close self.next = None # typing.Optional[typing.Callable] @property @@ -597,7 +666,8 @@ def url(self) -> typing.Optional[URL]: def content(self) -> bytes: if not hasattr(self, "_content"): if hasattr(self, "_raw_content"): - content = self.decoder.decode(self._raw_content) + raw_content = getattr(self, "_raw_content") # type: bytes + content = self.decoder.decode(raw_content) content += self.decoder.flush() self._content = content else: @@ -682,6 +752,77 @@ def decoder(self) -> Decoder: return self._decoder + @property + def is_redirect(self) -> bool: + return StatusCode.is_redirect(self.status_code) and "location" in self.headers + + def raise_for_status(self) -> None: + """ + Raise the `HttpError` if one occurred. + """ + message = ( + "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n" + "For more information check: https://httpstatuses.com/{0.status_code}" + ) + + if StatusCode.is_client_error(self.status_code): + message = message.format(self, error_type="Client Error") + elif StatusCode.is_server_error(self.status_code): + message = message.format(self, error_type="Server Error") + else: + message = "" + + if message: + raise HttpError(message) + + def json(self) -> typing.Any: + return jsonlib.loads(self.content.decode("utf-8")) + + @property + def cookies(self) -> "Cookies": + if not hasattr(self, "_cookies"): + assert self.request is not None + self._cookies = Cookies() + self._cookies.extract_cookies(self) + return self._cookies + + def __repr__(self) -> str: + return f"" + + +class AsyncResponse(BaseResponse): + def __init__( + self, + status_code: int, + *, + reason_phrase: str = None, + protocol: str = None, + headers: HeaderTypes = None, + content: AsyncResponseContent = b"", + on_close: typing.Callable = None, + request: AsyncRequest = None, + history: typing.List["BaseResponse"] = None, + ): + super().__init__( + status_code=status_code, + reason_phrase=reason_phrase, + protocol=protocol, + headers=headers, + request=request, + on_close=on_close, + ) + + self.history = [] if history is None else list(history) + + if isinstance(content, bytes): + self.is_closed = True + self.is_stream_consumed = True + self._raw_content = content + else: + self.is_closed = False + self.is_stream_consumed = False + self._raw_stream = content + async def read(self) -> bytes: """ Read and return the response content. @@ -729,128 +870,86 @@ async def close(self) -> None: if self.on_close is not None: await self.on_close() - @property - def is_redirect(self) -> bool: - return StatusCode.is_redirect(self.status_code) and "location" in self.headers - def raise_for_status(self) -> None: - """ - Raise the `HttpError` if one occurred. - """ - message = ( - "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n" - "For more information check: https://httpstatuses.com/{0.status_code}" +class Response(BaseResponse): + def __init__( + self, + status_code: int, + *, + reason_phrase: str = None, + protocol: str = None, + headers: HeaderTypes = None, + content: ResponseContent = b"", + on_close: typing.Callable = None, + request: Request = None, + history: typing.List["BaseResponse"] = None, + ): + super().__init__( + status_code=status_code, + reason_phrase=reason_phrase, + protocol=protocol, + headers=headers, + request=request, + on_close=on_close, ) - if StatusCode.is_client_error(self.status_code): - message = message.format(self, error_type="Client Error") - elif StatusCode.is_server_error(self.status_code): - message = message.format(self, error_type="Server Error") - else: - message = "" - - if message: - raise HttpError(message) - - def json(self) -> typing.Any: - return jsonlib.loads(self.content.decode("utf-8")) - - @property - def cookies(self) -> "Cookies": - if not hasattr(self, "_cookies"): - assert self.request is not None - self._cookies = Cookies() - self._cookies.extract_cookies(self) - return self._cookies - - def __repr__(self) -> str: - return f"" - - -class SyncResponse: - """ - A thread-synchronous response. This class proxies onto a `Response` - instance, providing standard synchronous interfaces where required. - """ - - def __init__(self, response: Response, loop: asyncio.AbstractEventLoop): - self._response = response - self._loop = loop - - @property - def status_code(self) -> int: - return self._response.status_code - - @property - def reason_phrase(self) -> str: - return self._response.reason_phrase - - @property - def protocol(self) -> typing.Optional[str]: - return self._response.protocol - - @property - def url(self) -> typing.Optional[URL]: - return self._response.url - - @property - def request(self) -> typing.Optional[Request]: - return self._response.request - - @property - def headers(self) -> Headers: - return self._response.headers - - @property - def content(self) -> bytes: - return self._response.content - - @property - def text(self) -> str: - return self._response.text - - @property - def encoding(self) -> str: - return self._response.encoding - - @property - def is_redirect(self) -> bool: - return self._response.is_redirect - - def raise_for_status(self) -> None: - return self._response.raise_for_status() + self.history = [] if history is None else list(history) - def json(self) -> typing.Any: - return self._response.json() + if isinstance(content, bytes): + self.is_closed = True + self.is_stream_consumed = True + self._raw_content = content + else: + self.is_closed = False + self.is_stream_consumed = False + self._raw_stream = content def read(self) -> bytes: - return self._loop.run_until_complete(self._response.read()) + """ + Read and return the response content. + """ + if not hasattr(self, "_content"): + self._content = b"".join([part for part in self.stream()]) + return self._content def stream(self) -> typing.Iterator[bytes]: - inner = self._response.stream() - while True: - try: - yield self._loop.run_until_complete(inner.__anext__()) - except StopAsyncIteration: - break + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + if hasattr(self, "_content"): + yield self._content + else: + for chunk in self.raw(): + yield self.decoder.decode(chunk) + yield self.decoder.flush() def raw(self) -> typing.Iterator[bytes]: - inner = self._response.raw() - while True: - try: - yield self._loop.run_until_complete(inner.__anext__()) - except StopAsyncIteration: - break - - def close(self) -> None: - return self._loop.run_until_complete(self._response.close()) + """ + A byte-iterator over the raw response content. + """ + if hasattr(self, "_raw_content"): + yield self._raw_content + else: + if self.is_stream_consumed: + raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() - @property - def cookies(self) -> "Cookies": - return self._response.cookies + self.is_stream_consumed = True + for part in self._raw_stream: + yield part + self.close() - def __repr__(self) -> str: - return f"" + def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + if not self.is_closed: + self.is_closed = True + if self.on_close is not None: + self.on_close() class Cookies(MutableMapping): @@ -871,7 +970,7 @@ def __init__(self, cookies: CookieTypes = None) -> None: else: self.jar = cookies - def extract_cookies(self, response: Response) -> None: + def extract_cookies(self, response: BaseResponse) -> None: """ Loads any cookies based on the response `Set-Cookie` headers. """ @@ -881,7 +980,7 @@ def extract_cookies(self, response: Response) -> None: self.jar.extract_cookies(urlib_response, urllib_request) # type: ignore - def set_cookie_header(self, request: Request) -> None: + def set_cookie_header(self, request: BaseRequest) -> None: """ Sets an appropriate 'Cookie:' HTTP header on the `Request`. """ @@ -1000,7 +1099,7 @@ class _CookieCompatRequest(urllib.request.Request): for use with `CookieJar` operations. """ - def __init__(self, request: Request) -> None: + def __init__(self, request: BaseRequest) -> None: super().__init__( url=str(request.url), headers=dict(request.headers), @@ -1018,7 +1117,7 @@ class _CookieCompatResponse: for use with `CookieJar` operations. """ - def __init__(self, response: Response): + def __init__(self, response: BaseResponse): self.response = response def info(self) -> email.message.Message: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 1d2b97239c..17993383a9 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -4,27 +4,26 @@ from httpcore import ( URL, + AsyncDispatcher, + AsyncRequest, + AsyncResponse, CertTypes, Client, - Dispatcher, - Request, - Response, TimeoutTypes, VerifyTypes, ) -class MockDispatch(Dispatcher): +class MockDispatch(AsyncDispatcher): async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: body = json.dumps({"auth": request.headers.get("Authorization")}).encode() - return Response(200, content=body, request=request) + return AsyncResponse(200, content=body, request=request) def test_basic_auth(): diff --git a/tests/client/test_cookies.py b/tests/client/test_cookies.py index a21f5c134f..5cbb380921 100644 --- a/tests/client/test_cookies.py +++ b/tests/client/test_cookies.py @@ -5,32 +5,31 @@ from httpcore import ( URL, + AsyncDispatcher, + AsyncRequest, + AsyncResponse, CertTypes, Client, Cookies, - Dispatcher, - Request, - Response, TimeoutTypes, VerifyTypes, ) -class MockDispatch(Dispatcher): +class MockDispatch(AsyncDispatcher): async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: if request.url.path.startswith("/echo_cookies"): body = json.dumps({"cookies": request.headers.get("Cookie")}).encode() - return Response(200, content=body, request=request) + return AsyncResponse(200, content=body, request=request) elif request.url.path.startswith("/set_cookie"): headers = {"set-cookie": "example-name=example-value"} - return Response(200, headers=headers, request=request) + return AsyncResponse(200, headers=headers, request=request) def test_set_cookie(): diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index c3b384dc95..3f5168974a 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -6,8 +6,10 @@ from httpcore import ( URL, AsyncClient, + AsyncDispatcher, + AsyncRequest, + AsyncResponse, CertTypes, - Dispatcher, RedirectBodyUnavailable, RedirectLoop, Request, @@ -19,37 +21,36 @@ ) -class MockDispatch(Dispatcher): +class MockDispatch(AsyncDispatcher): async def send( self, - request: Request, - stream: bool = False, + request: AsyncRequest, verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - ) -> Response: + ) -> AsyncResponse: if request.url.path == "/redirect_301": status_code = codes.MOVED_PERMANENTLY headers = {"location": "https://example.org/"} - return Response(status_code, headers=headers, request=request) + return AsyncResponse(status_code, headers=headers, request=request) elif request.url.path == "/redirect_302": status_code = codes.FOUND headers = {"location": "https://example.org/"} - return Response(status_code, headers=headers, request=request) + return AsyncResponse(status_code, headers=headers, request=request) elif request.url.path == "/redirect_303": status_code = codes.SEE_OTHER headers = {"location": "https://example.org/"} - return Response(status_code, headers=headers, request=request) + return AsyncResponse(status_code, headers=headers, request=request) elif request.url.path == "/relative_redirect": headers = {"location": "/"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/no_scheme_redirect": headers = {"location": "//example.org/"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/multiple_redirects": params = parse_qs(request.url.query) @@ -60,32 +61,34 @@ async def send( if redirect_count: location += "?count=" + str(redirect_count) headers = {"location": location} if count else {} - return Response(code, headers=headers, request=request) + return AsyncResponse(code, headers=headers, request=request) if request.url.path == "/redirect_loop": headers = {"location": "/redirect_loop"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/cross_domain": headers = {"location": "https://example.org/cross_domain_target"} - return Response(codes.SEE_OTHER, headers=headers, request=request) + return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request) elif request.url.path == "/cross_domain_target": headers = dict(request.headers.items()) content = json.dumps({"headers": headers}).encode() - return Response(codes.OK, content=content, request=request) + return AsyncResponse(codes.OK, content=content, request=request) elif request.url.path == "/redirect_body": await request.read() headers = {"location": "/redirect_body_target"} - return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request) + return AsyncResponse( + codes.PERMANENT_REDIRECT, headers=headers, request=request + ) elif request.url.path == "/redirect_body_target": content = await request.read() body = json.dumps({"body": content.decode()}).encode() - return Response(codes.OK, content=body, request=request) + return AsyncResponse(codes.OK, content=body, request=request) - return Response(codes.OK, content=b"Hello, world!", request=request) + return AsyncResponse(codes.OK, content=b"Hello, world!", request=request) @pytest.mark.asyncio diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index bbe200ba94..b8049c70b8 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -10,10 +10,12 @@ async def test_keepalive_connections(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 @@ -25,10 +27,12 @@ async def test_differing_connection_keys(server): """ async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 2 @@ -42,10 +46,12 @@ async def test_soft_limit(server): async with httpcore.ConnectionPool(pool_limits=pool_limits) as http: response = await http.request("GET", "http://127.0.0.1:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 response = await http.request("GET", "http://localhost:8000/") + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 1 @@ -56,7 +62,7 @@ async def test_streaming_response_holds_connection(server): A streaming request should hold the connection open until the response is read. """ async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 @@ -72,11 +78,11 @@ async def test_multiple_concurrent_connections(server): Multiple conncurrent requests should open multiple conncurrent connections. """ async with httpcore.ConnectionPool() as http: - response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response_a = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 - response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response_b = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 2 assert len(http.keepalive_connections) == 0 @@ -97,6 +103,7 @@ async def test_close_connections(server): headers = [(b"connection", b"close")] async with httpcore.ConnectionPool() as http: response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers) + await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 @@ -107,7 +114,7 @@ async def test_standard_response_close(server): A standard close should keep the connection open. """ async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() await response.close() assert len(http.active_connections) == 0 @@ -120,7 +127,7 @@ async def test_premature_response_close(server): A premature close should close the connection. """ async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + response = await http.request("GET", "http://127.0.0.1:8000/") await response.close() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 diff --git a/tests/dispatch/test_connections.py b/tests/dispatch/test_connections.py index ca401c7813..4b267f4fd9 100644 --- a/tests/dispatch/test_connections.py +++ b/tests/dispatch/test_connections.py @@ -7,6 +7,7 @@ async def test_get(server): conn = HTTPConnection(origin="http://127.0.0.1:8000/") response = await conn.request("GET", "http://127.0.0.1:8000/") + await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" @@ -27,6 +28,7 @@ async def test_https_get_with_ssl_defaults(https_server): """ conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False) response = await conn.request("GET", "https://127.0.0.1:8001/") + await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" @@ -38,5 +40,6 @@ async def test_https_get_with_sll_overrides(https_server): """ conn = HTTPConnection(origin="https://127.0.0.1:8001/") response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False) + await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py new file mode 100644 index 0000000000..d177dbba96 --- /dev/null +++ b/tests/dispatch/test_threaded.py @@ -0,0 +1,100 @@ +import json + +import pytest + +from httpcore import ( + CertTypes, + Client, + Dispatcher, + Request, + Response, + TimeoutTypes, + VerifyTypes, +) + + +def streaming_body(): + for part in [b"Hello", b", ", b"world!"]: + yield part + + +class MockDispatch(Dispatcher): + def send( + self, + request: Request, + verify: VerifyTypes = None, + cert: CertTypes = None, + timeout: TimeoutTypes = None, + ) -> Response: + if request.url.path == "/streaming_response": + return Response(200, content=streaming_body(), request=request) + elif request.url.path == "/echo_request_body": + content = request.read() + return Response(200, content=content, request=request) + elif request.url.path == "/echo_request_body_streaming": + content = b"".join([part for part in request.stream()]) + return Response(200, content=content, request=request) + else: + body = json.dumps({"hello": "world"}).encode() + return Response(200, content=body, request=request) + + +def test_threaded_dispatch(): + """ + Use a syncronous 'Dispatcher' class with the client. + Calls to the dispatcher will end up running within a thread pool. + """ + url = "https://example.org/" + with Client(dispatch=MockDispatch()) as client: + response = client.get(url) + + assert response.status_code == 200 + assert response.json() == {"hello": "world"} + + +def test_threaded_streaming_response(): + url = "https://example.org/streaming_response" + with Client(dispatch=MockDispatch()) as client: + response = client.get(url) + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_threaded_streaming_request(): + url = "https://example.org/echo_request_body" + with Client(dispatch=MockDispatch()) as client: + response = client.post(url, data=streaming_body()) + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_threaded_request_body(): + url = "https://example.org/echo_request_body" + with Client(dispatch=MockDispatch()) as client: + response = client.post(url, data=b"Hello, world!") + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_threaded_request_body_streaming(): + url = "https://example.org/echo_request_body_streaming" + with Client(dispatch=MockDispatch()) as client: + response = client.post(url, data=b"Hello, world!") + + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_dispatch_class(): + """ + Use a syncronous 'Dispatcher' class directly. + """ + url = "https://example.org/" + with MockDispatch() as dispatcher: + response = dispatcher.request("GET", url) + + assert response.status_code == 200 + assert response.json() == {"hello": "world"} diff --git a/tests/models/test_requests.py b/tests/models/test_requests.py index d0d521a468..79cbba36e8 100644 --- a/tests/models/test_requests.py +++ b/tests/models/test_requests.py @@ -10,87 +10,62 @@ def test_request_repr(): def test_no_content(): request = httpcore.Request("GET", "http://example.org") - request.prepare() - assert request.headers == httpcore.Headers( - [(b"accept-encoding", b"deflate, gzip, br")] - ) + assert "Content-Length" not in request.headers def test_content_length_header(): request = httpcore.Request("POST", "http://example.org", data=b"test 123") - request.prepare() - assert request.headers == httpcore.Headers( - [ - (b"content-length", b"8"), - (b"accept-encoding", b"deflate, gzip, br"), - ] - ) + assert request.headers["Content-Length"] == "8" def test_url_encoded_data(): - request = httpcore.Request("POST", "http://example.org", data={"test": "123"}) - request.prepare() - assert request.headers == httpcore.Headers( - [ - (b"content-length", b"8"), - (b"accept-encoding", b"deflate, gzip, br"), - (b"content-type", b"application/x-www-form-urlencoded"), - ] - ) - assert request.content == b"test=123" + for RequestClass in (httpcore.Request, httpcore.AsyncRequest): + request = RequestClass("POST", "http://example.org", data={"test": "123"}) + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.content == b"test=123" + + +def test_json_encoded_data(): + for RequestClass in (httpcore.Request, httpcore.AsyncRequest): + request = RequestClass("POST", "http://example.org", json={"test": 123}) + assert request.headers["Content-Type"] == "application/json" + assert request.content == b'{"test": 123}' def test_transfer_encoding_header(): - async def streaming_body(data): + def streaming_body(data): yield data # pragma: nocover data = streaming_body(b"test 123") request = httpcore.Request("POST", "http://example.org", data=data) - request.prepare() - assert request.headers == httpcore.Headers( - [ - (b"transfer-encoding", b"chunked"), - (b"accept-encoding", b"deflate, gzip, br"), - ] - ) + assert "Content-Length" not in request.headers + assert request.headers["Transfer-Encoding"] == "chunked" def test_override_host_header(): - headers = [(b"host", b"1.2.3.4:80")] + headers = {"host": "1.2.3.4:80"} request = httpcore.Request("GET", "http://example.org", headers=headers) - request.prepare() - assert request.headers == httpcore.Headers( - [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")] - ) + assert request.headers["Host"] == "1.2.3.4:80" def test_override_accept_encoding_header(): - headers = [(b"accept-encoding", b"identity")] + headers = {"Accept-Encoding": "identity"} request = httpcore.Request("GET", "http://example.org", headers=headers) - request.prepare() - assert request.headers == httpcore.Headers( - [(b"accept-encoding", b"identity")] - ) + assert request.headers["Accept-Encoding"] == "identity" def test_override_content_length_header(): - async def streaming_body(data): + def streaming_body(data): yield data # pragma: nocover data = streaming_body(b"test 123") - headers = [(b"content-length", b"8")] + headers = {"Content-Length": "8"} request = httpcore.Request("POST", "http://example.org", data=data, headers=headers) - request.prepare() - assert request.headers == httpcore.Headers( - [ - (b"accept-encoding", b"deflate, gzip, br"), - (b"content-length", b"8"), - ] - ) + assert request.headers["Content-Length"] == "8" def test_url(): diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 8ecd37ab5c..f2d080ffc8 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -3,7 +3,12 @@ import httpcore -async def streaming_body(): +def streaming_body(): + yield b"Hello, " + yield b"world!" + + +async def async_streaming_body(): yield b"Hello, " yield b"world!" @@ -105,8 +110,7 @@ def test_response_force_encoding(): assert response.encoding == "iso-8859-1" -@pytest.mark.asyncio -async def test_read_response(): +def test_read_response(): response = httpcore.Response(200, content=b"Hello, world!") assert response.status_code == 200 @@ -114,37 +118,56 @@ async def test_read_response(): assert response.encoding == "ascii" assert response.is_closed - content = await response.read() + content = response.read() assert content == b"Hello, world!" assert response.content == b"Hello, world!" assert response.is_closed -@pytest.mark.asyncio -async def test_raw_interface(): +def test_raw_interface(): response = httpcore.Response(200, content=b"Hello, world!") raw = b"" - async for part in response.raw(): + for part in response.raw(): raw += part assert raw == b"Hello, world!" -@pytest.mark.asyncio -async def test_stream_interface(): +def test_stream_interface(): response = httpcore.Response(200, content=b"Hello, world!") content = b"" - async for part in response.stream(): + for part in response.stream(): content += part assert content == b"Hello, world!" @pytest.mark.asyncio -async def test_stream_interface_after_read(): +async def test_async_stream_interface(): + response = httpcore.AsyncResponse(200, content=b"Hello, world!") + + content = b"" + async for part in response.stream(): + content += part + assert content == b"Hello, world!" + + +def test_stream_interface_after_read(): response = httpcore.Response(200, content=b"Hello, world!") + response.read() + + content = b"" + for part in response.stream(): + content += part + assert content == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_async_stream_interface_after_read(): + response = httpcore.AsyncResponse(200, content=b"Hello, world!") + await response.read() content = b"" @@ -153,14 +176,13 @@ async def test_stream_interface_after_read(): assert content == b"Hello, world!" -@pytest.mark.asyncio -async def test_streaming_response(): +def test_streaming_response(): response = httpcore.Response(200, content=streaming_body()) assert response.status_code == 200 assert not response.is_closed - content = await response.read() + content = response.read() assert content == b"Hello, world!" assert response.content == b"Hello, world!" @@ -168,9 +190,34 @@ async def test_streaming_response(): @pytest.mark.asyncio -async def test_cannot_read_after_stream_consumed(): +async def test_async_streaming_response(): + response = httpcore.AsyncResponse(200, content=async_streaming_body()) + + assert response.status_code == 200 + assert not response.is_closed + + content = await response.read() + + assert content == b"Hello, world!" + assert response.content == b"Hello, world!" + assert response.is_closed + + +def test_cannot_read_after_stream_consumed(): response = httpcore.Response(200, content=streaming_body()) + content = b"" + for part in response.stream(): + content += part + + with pytest.raises(httpcore.StreamConsumed): + response.read() + + +@pytest.mark.asyncio +async def test_async_cannot_read_after_stream_consumed(): + response = httpcore.AsyncResponse(200, content=async_streaming_body()) + content = b"" async for part in response.stream(): content += part @@ -179,10 +226,19 @@ async def test_cannot_read_after_stream_consumed(): await response.read() -@pytest.mark.asyncio -async def test_cannot_read_after_response_closed(): +def test_cannot_read_after_response_closed(): response = httpcore.Response(200, content=streaming_body()) + response.close() + + with pytest.raises(httpcore.ResponseClosed): + response.read() + + +@pytest.mark.asyncio +async def test_async_cannot_read_after_response_closed(): + response = httpcore.AsyncResponse(200, content=async_streaming_body()) + await response.close() with pytest.raises(httpcore.ResponseClosed): diff --git a/tests/test_api.py b/tests/test_api.py index 6a62359c16..1247a41602 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -38,6 +38,18 @@ def test_post(server): assert response.reason_phrase == "OK" +@threadpool +def test_post_byte_iterator(server): + def data(): + yield b"Hello" + yield b", " + yield b"world!" + + response = httpcore.post("http://127.0.0.1:8000/", data=data()) + assert response.status_code == 200 + assert response.reason_phrase == "OK" + + @threadpool def test_options(server): response = httpcore.options("http://127.0.0.1:8000/") diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 20273eec26..ac795ca91e 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -64,19 +64,18 @@ def test_multi_with_identity(): assert response.content == body -@pytest.mark.asyncio -async def test_streaming(): +def test_streaming(): body = b"test 123" compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) - async def compress(body): + def compress(body): yield compressor.compress(body) yield compressor.flush() headers = [(b"Content-Encoding", b"gzip")] response = httpcore.Response(200, headers=headers, content=compress(body)) assert not hasattr(response, "body") - assert await response.read() == body + assert response.read() == body @pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br"))