From b11db9937be8ea19a6a22ea3fc479b866b6a1315 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 29 Oct 2024 10:49:03 +0000 Subject: [PATCH 1/5] Add proxy configuration to ConnectionPool --- docs/async.md | 15 ---------- docs/proxies.md | 47 +++++++++++++++++------------- docs/table-of-contents.md | 3 +- httpcore/__init__.py | 3 +- httpcore/_async/connection_pool.py | 44 ++++++++++++++++++++++++++-- httpcore/_models.py | 25 ++++++++++++++++ httpcore/_sync/connection_pool.py | 44 ++++++++++++++++++++++++++-- 7 files changed, 139 insertions(+), 42 deletions(-) diff --git a/docs/async.md b/docs/async.md index f80f8b7d6..e0f0a65c8 100644 --- a/docs/async.md +++ b/docs/async.md @@ -34,14 +34,6 @@ async with httpcore.AsyncConnectionPool() as http: ... ``` -Or if connecting via a proxy: - -```python -# The async variation of `httpcore.HTTPProxy` -async with httpcore.AsyncHTTPProxy() as proxy: - ... -``` - ### Sending requests Sending requests with the async version of `httpcore` requires the `await` keyword: @@ -221,10 +213,3 @@ anyio.run(main) handler: python rendering: show_source: False - -## `httpcore.AsyncHTTPProxy` - -::: httpcore.AsyncHTTPProxy - handler: python - rendering: - show_source: False diff --git a/docs/proxies.md b/docs/proxies.md index 72eaeb644..970d53d53 100644 --- a/docs/proxies.md +++ b/docs/proxies.md @@ -7,7 +7,8 @@ Sending requests via a proxy is very similar to sending requests using a standar ```python import httpcore -proxy = httpcore.HTTPProxy(proxy_url="http://127.0.0.1:8080/") +proxy = httpcore.Proxy("http://127.0.0.1:8080/") +pool = httpcore.ConnectionPool(proxy=proxy) r = proxy.request("GET", "https://www.example.com/") print(r) @@ -31,10 +32,11 @@ Proxy authentication can be included in the initial configuration: import httpcore # A `Proxy-Authorization` header will be included on the initial proxy connection. -proxy = httpcore.HTTPProxy( - proxy_url="http://127.0.0.1:8080/", - proxy_auth=("", "") +proxy = httpcore.Proxy( + url="http://127.0.0.1:8080/", + auth=("", "") ) +pool = httpcore.ConnectionPool(proxy=proxy) ``` Custom headers can also be included: @@ -45,10 +47,11 @@ import base64 # Construct and include a `Proxy-Authorization` header. auth = base64.b64encode(b":") -proxy = httpcore.HTTPProxy( - proxy_url="http://127.0.0.1:8080/", - proxy_headers={"Proxy-Authorization": b"Basic " + auth} +proxy = httpcore.Proxy( + url="http://127.0.0.1:8080/", + headers={"Proxy-Authorization": b"Basic " + auth} ) +pool = httpcore.ConnectionPool(proxy=proxy) ``` ## Proxy SSL @@ -58,10 +61,10 @@ The `httpcore` package also supports HTTPS proxies for http and https destinatio HTTPS proxies can be used in the same way that HTTP proxies are. ```python -proxy = httpcore.HTTPProxy(proxy_url="https://127.0.0.1:8080/") +proxy = httpcore.Proxy(url="https://127.0.0.1:8080/") ``` -Also, when using HTTPS proxies, you may need to configure the SSL context, which you can do with the `proxy_ssl_context` argument. +Also, when using HTTPS proxies, you may need to configure the SSL context, which you can do with the `ssl_context` argument. 
```python import ssl @@ -70,11 +73,13 @@ import httpcore proxy_ssl_context = ssl.create_default_context() proxy_ssl_context.check_hostname = False -proxy = httpcore.HTTPProxy('https://127.0.0.1:8080/', proxy_ssl_context=proxy_ssl_context) +proxy = httpcore.Proxy( + url='https://127.0.0.1:8080/', + ssl_context=proxy_ssl_context +) +pool = httpcore.ConnectionPool(proxy=proxy) ``` -It is important to note that the `ssl_context` argument is always used for the remote connection, and the `proxy_ssl_context` argument is always used for the proxy connection. - ## HTTP Versions If you use proxies, keep in mind that the `httpcore` package only supports proxies to HTTP/1.1 servers. @@ -91,8 +96,9 @@ The `SOCKSProxy` class should be using instead of a standard connection pool: import httpcore # Note that the SOCKS port is 1080. -proxy = httpcore.SOCKSProxy(proxy_url="socks5://127.0.0.1:1080/") -r = proxy.request("GET", "https://www.example.com/") +proxy = httpcore.Proxy(url="socks5://127.0.0.1:1080/") +pool = httpcore.ConnectionPool(proxy=proxy) +r = pool.request("GET", "https://www.example.com/") ``` Authentication via SOCKS is also supported: @@ -100,20 +106,21 @@ Authentication via SOCKS is also supported: ```python import httpcore -proxy = httpcore.SOCKSProxy( - proxy_url="socks5://127.0.0.1:8080/", - proxy_auth=("", "") +proxy = httpcore.Proxy( + url="socks5://127.0.0.1:1080/", + auth=("", ""), ) -r = proxy.request("GET", "https://www.example.com/") +pool = httpcore.ConnectionPool(proxy=proxy) +r = pool.request("GET", "https://www.example.com/") ``` --- # Reference -## `httpcore.HTTPProxy` +## `httpcore.Proxy` -::: httpcore.HTTPProxy +::: httpcore.Proxy handler: python rendering: show_source: False diff --git a/docs/table-of-contents.md b/docs/table-of-contents.md index 3cf1f725e..5dc9a10b4 100644 --- a/docs/table-of-contents.md +++ b/docs/table-of-contents.md @@ -10,14 +10,13 @@ * Connection Pools * `httpcore.ConnectionPool` * Proxies - * `httpcore.HTTPProxy` + * `httpcore.Proxy` * Connections * `httpcore.HTTPConnection` * `httpcore.HTTP11Connection` * `httpcore.HTTP2Connection` * Async Support * `httpcore.AsyncConnectionPool` - * `httpcore.AsyncHTTPProxy` * `httpcore.AsyncHTTPConnection` * `httpcore.AsyncHTTP11Connection` * `httpcore.AsyncHTTP2Connection` diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 330745a5d..0d4946e70 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -34,7 +34,7 @@ WriteError, WriteTimeout, ) -from ._models import URL, Origin, Request, Response +from ._models import URL, Origin, Proxy, Request, Response from ._ssl import default_ssl_context from ._sync import ( ConnectionInterface, @@ -79,6 +79,7 @@ def __init__(self, *args, **kwargs): # type: ignore "URL", "Request", "Response", + "Proxy", # async "AsyncHTTPConnection", "AsyncConnectionPool", diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 0795b9ccb..96e973d0c 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -8,7 +8,7 @@ from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol -from .._models import Origin, Request, Response +from .._models import Origin, Proxy, Request, Response from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock from .connection import AsyncHTTPConnection from .interfaces import AsyncConnectionInterface, AsyncRequestInterface @@ -48,6 +48,7 
@@ class AsyncConnectionPool(AsyncRequestInterface): def __init__( self, ssl_context: ssl.SSLContext | None = None, + proxy: Proxy | None = None, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -89,7 +90,7 @@ def __init__( in the TCP socket when the connection was established. """ self._ssl_context = ssl_context - + self._proxy = proxy self._max_connections = ( sys.maxsize if max_connections is None else max_connections ) @@ -125,6 +126,45 @@ def __init__( self._optional_thread_lock = AsyncThreadLock() def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + if self._proxy is not None: + if self._proxy.url.scheme in (b"socks5", b"socks5h"): + from .socks_proxy import AsyncSocks5Connection + + return AsyncSocks5Connection( + proxy_origin=self._proxy.url.origin, + proxy_auth=self._proxy.auth, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + elif origin.scheme == b"http": + from .http_proxy import AsyncForwardHTTPConnection + + return AsyncForwardHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + ) + from .http_proxy import AsyncTunnelHTTPConnection + + return AsyncTunnelHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + return AsyncHTTPConnection( origin=origin, ssl_context=self._ssl_context, diff --git a/httpcore/_models.py b/httpcore/_models.py index c739a7fa6..8a65f1334 100644 --- a/httpcore/_models.py +++ b/httpcore/_models.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import ssl import typing import urllib.parse @@ -489,3 +491,26 @@ async def aclose(self) -> None: ) if hasattr(self.stream, "aclose"): await self.stream.aclose() + + +class Proxy: + def __init__( + self, + url: URL | bytes | str, + auth: tuple[bytes | str, bytes | str] | None = None, + headers: HeadersAsMapping | HeadersAsSequence | None = None, + ssl_context: ssl.SSLContext | None = None, + ): + self.url = enforce_url(url, name="url") + self.headers = enforce_headers(headers, name="headers") + self.ssl_context = ssl_context + + if auth is not None: + username = enforce_bytes(auth[0], name="auth") + password = enforce_bytes(auth[1], name="auth") + userpass = username + b":" + password + authorization = b"Basic " + base64.b64encode(userpass) + self.auth: tuple[bytes, bytes] | None = (username, password) + self.headers = [(b"Proxy-Authorization", authorization)] + self.headers + else: + self.auth = None diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 00c3983dd..9ccfa53e5 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -8,7 +8,7 @@ from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol -from .._models import Origin, Request, Response +from .._models import Origin, Proxy, Request, Response from .._synchronization import 
Event, ShieldCancellation, ThreadLock from .connection import HTTPConnection from .interfaces import ConnectionInterface, RequestInterface @@ -48,6 +48,7 @@ class ConnectionPool(RequestInterface): def __init__( self, ssl_context: ssl.SSLContext | None = None, + proxy: Proxy | None = None, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -89,7 +90,7 @@ def __init__( in the TCP socket when the connection was established. """ self._ssl_context = ssl_context - + self._proxy = proxy self._max_connections = ( sys.maxsize if max_connections is None else max_connections ) @@ -125,6 +126,45 @@ def __init__( self._optional_thread_lock = ThreadLock() def create_connection(self, origin: Origin) -> ConnectionInterface: + if self._proxy is not None: + if self._proxy.url.scheme in (b"socks5", b"socks5h"): + from .socks_proxy import Socks5Connection + + return Socks5Connection( + proxy_origin=self._proxy.url.origin, + proxy_auth=self._proxy.auth, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + elif origin.scheme == b"http": + from .http_proxy import ForwardHTTPConnection + + return ForwardHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, + ) + from .http_proxy import TunnelHTTPConnection + + return TunnelHTTPConnection( + proxy_origin=self._proxy.url.origin, + proxy_headers=self._proxy.headers, + proxy_ssl_context=self._proxy.ssl_context, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + http1=self._http1, + http2=self._http2, + network_backend=self._network_backend, + ) + return HTTPConnection( origin=origin, ssl_context=self._ssl_context, From 0f9dfb041eb32e1af35797c1c177cf96a0ad934e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 29 Oct 2024 17:48:03 +0000 Subject: [PATCH 2/5] Update tests for new proxy API, and nocover old classes. --- httpcore/_async/http_proxy.py | 10 +++----- httpcore/_async/socks_proxy.py | 2 +- httpcore/_sync/http_proxy.py | 10 +++----- httpcore/_sync/socks_proxy.py | 2 +- tests/_async/test_http_proxy.py | 41 +++++++++++++++++++------------- tests/_async/test_socks_proxy.py | 28 ++++++++++++---------- tests/_sync/test_http_proxy.py | 41 +++++++++++++++++++------------- tests/_sync/test_socks_proxy.py | 28 ++++++++++++---------- 8 files changed, 88 insertions(+), 74 deletions(-) diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index 4cbfe5186..cc9d92066 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -51,12 +51,7 @@ def merge_headers( return default_headers + override_headers -def build_auth_header(username: bytes, password: bytes) -> bytes: - userpass = username + b":" + password - return b"Basic " + base64.b64encode(userpass) - - -class AsyncHTTPProxy(AsyncConnectionPool): +class AsyncHTTPProxy(AsyncConnectionPool): # pragma: nocover """ A connection pool that sends requests via an HTTP proxy. 
""" @@ -142,7 +137,8 @@ def __init__( if proxy_auth is not None: username = enforce_bytes(proxy_auth[0], name="proxy_auth") password = enforce_bytes(proxy_auth[1], name="proxy_auth") - authorization = build_auth_header(username, password) + userpass = username + b":" + password + authorization = b"Basic " + base64.b64encode(userpass) self._proxy_headers = [ (b"Proxy-Authorization", authorization) ] + self._proxy_headers diff --git a/httpcore/_async/socks_proxy.py b/httpcore/_async/socks_proxy.py index ef96a8c17..b363f55a0 100644 --- a/httpcore/_async/socks_proxy.py +++ b/httpcore/_async/socks_proxy.py @@ -102,7 +102,7 @@ async def _init_socks5_connection( raise ProxyError(f"Proxy Server could not connect: {reply_code}.") -class AsyncSOCKSProxy(AsyncConnectionPool): +class AsyncSOCKSProxy(AsyncConnectionPool): # pragma: nocover """ A connection pool that sends requests via an HTTP proxy. """ diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index e58693ecb..ecca88f7d 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -51,12 +51,7 @@ def merge_headers( return default_headers + override_headers -def build_auth_header(username: bytes, password: bytes) -> bytes: - userpass = username + b":" + password - return b"Basic " + base64.b64encode(userpass) - - -class HTTPProxy(ConnectionPool): +class HTTPProxy(ConnectionPool): # pragma: nocover """ A connection pool that sends requests via an HTTP proxy. """ @@ -142,7 +137,8 @@ def __init__( if proxy_auth is not None: username = enforce_bytes(proxy_auth[0], name="proxy_auth") password = enforce_bytes(proxy_auth[1], name="proxy_auth") - authorization = build_auth_header(username, password) + userpass = username + b":" + password + authorization = b"Basic " + base64.b64encode(userpass) self._proxy_headers = [ (b"Proxy-Authorization", authorization) ] + self._proxy_headers diff --git a/httpcore/_sync/socks_proxy.py b/httpcore/_sync/socks_proxy.py index 61dd7e380..0ca96ddfb 100644 --- a/httpcore/_sync/socks_proxy.py +++ b/httpcore/_sync/socks_proxy.py @@ -102,7 +102,7 @@ def _init_socks5_connection( raise ProxyError(f"Proxy Server could not connect: {reply_code}.") -class SOCKSProxy(ConnectionPool): +class SOCKSProxy(ConnectionPool): # pragma: nocover """ A connection pool that sends requests via an HTTP proxy. """ diff --git a/tests/_async/test_http_proxy.py b/tests/_async/test_http_proxy.py index b35fc2899..84a984b80 100644 --- a/tests/_async/test_http_proxy.py +++ b/tests/_async/test_http_proxy.py @@ -7,11 +7,12 @@ from httpcore import ( SOCKET_OPTION, - AsyncHTTPProxy, + AsyncConnectionPool, AsyncMockBackend, AsyncMockStream, AsyncNetworkStream, Origin, + Proxy, ProxyError, ) @@ -31,8 +32,8 @@ async def test_proxy_forwarding(): ] ) - async with AsyncHTTPProxy( - proxy_url="http://localhost:8080/", + async with AsyncConnectionPool( + proxy=Proxy("http://localhost:8080/"), max_connections=10, network_backend=network_backend, ) as proxy: @@ -87,8 +88,8 @@ async def test_proxy_tunneling(): ] ) - async with AsyncHTTPProxy( - proxy_url="http://localhost:8080/", + async with AsyncConnectionPool( + proxy=Proxy("http://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending an intial request, which once complete will return to the pool, IDLE. 
@@ -178,8 +179,8 @@ async def test_proxy_tunneling_http2(): ], ) - async with AsyncHTTPProxy( - proxy_url="http://localhost:8080/", + async with AsyncConnectionPool( + proxy=Proxy("http://localhost:8080/"), network_backend=network_backend, http2=True, ) as proxy: @@ -227,8 +228,8 @@ async def test_proxy_tunneling_with_403(): ] ) - async with AsyncHTTPProxy( - proxy_url="http://localhost:8080/", + async with AsyncConnectionPool( + proxy=Proxy("http://localhost:8080/"), network_backend=network_backend, ) as proxy: with pytest.raises(ProxyError) as exc_info: @@ -255,17 +256,23 @@ async def test_proxy_tunneling_with_auth(): ] ) - async with AsyncHTTPProxy( - proxy_url="http://localhost:8080/", - proxy_auth=("username", "password"), + async with AsyncConnectionPool( + proxy=Proxy( + url="http://localhost:8080/", + auth=("username", "password"), + ), network_backend=network_backend, ) as proxy: response = await proxy.request("GET", "https://example.com/") assert response.status == 200 assert response.content == b"Hello, world!" - # Dig into this private property as a cheap lazy way of - # checking that the proxy header is set correctly. - assert proxy._proxy_headers == [ # type: ignore - (b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=") - ] + +def test_proxy_headers(): + proxy = Proxy( + url="http://localhost:8080/", + auth=("username", "password"), + ) + assert proxy.headers == [ + (b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=") + ] diff --git a/tests/_async/test_socks_proxy.py b/tests/_async/test_socks_proxy.py index 3f5dd1cc0..907594a40 100644 --- a/tests/_async/test_socks_proxy.py +++ b/tests/_async/test_socks_proxy.py @@ -24,8 +24,8 @@ async def test_socks5_request(): ] ) - async with httpcore.AsyncSOCKSProxy( - proxy_url="socks5://localhost:8080/", + async with httpcore.AsyncConnectionPool( + proxy=httpcore.Proxy("socks5://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending an intial request, which once complete will return to the pool, IDLE. @@ -84,9 +84,11 @@ async def test_authenticated_socks5_request(): ] ) - async with httpcore.AsyncSOCKSProxy( - proxy_url="socks5://localhost:8080/", - proxy_auth=(b"username", b"password"), + async with httpcore.AsyncConnectionPool( + proxy=httpcore.Proxy( + url="socks5://localhost:8080/", + auth=(b"username", b"password"), + ), network_backend=network_backend, ) as proxy: # Sending an intial request, which once complete will return to the pool, IDLE. 
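The SOCKS tests follow the same pattern: a `socks5://` scheme on the proxy URL is what routes `create_connection()` to a `Socks5Connection`, and the `auth` tuple is handed to that connection as `proxy_auth` credentials. A sketch of the equivalent call (sync variant, placeholder addresses):

```python
import httpcore

proxy = httpcore.Proxy(
    url="socks5://127.0.0.1:1080/",
    auth=(b"username", b"password"),
)
with httpcore.ConnectionPool(proxy=proxy) as pool:
    response = pool.request("GET", "https://www.example.com/")
    print(response.status)
```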
@@ -123,8 +125,8 @@ async def test_socks5_request_connect_failed(): ] ) - async with httpcore.AsyncSOCKSProxy( - proxy_url="socks5://localhost:8080/", + async with httpcore.AsyncConnectionPool( + proxy=httpcore.Proxy("socks5://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending a request, which the proxy rejects @@ -150,8 +152,8 @@ async def test_socks5_request_failed_to_provide_auth(): ] ) - async with httpcore.AsyncSOCKSProxy( - proxy_url="socks5://localhost:8080/", + async with httpcore.AsyncConnectionPool( + proxy=httpcore.Proxy("socks5://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending a request, which the proxy rejects @@ -180,9 +182,11 @@ async def test_socks5_request_incorrect_auth(): ] ) - async with httpcore.AsyncSOCKSProxy( - proxy_url="socks5://localhost:8080/", - proxy_auth=(b"invalid", b"invalid"), + async with httpcore.AsyncConnectionPool( + proxy=httpcore.Proxy( + url="socks5://localhost:8080/", + auth=(b"invalid", b"invalid"), + ), network_backend=network_backend, ) as proxy: # Sending a request, which the proxy rejects diff --git a/tests/_sync/test_http_proxy.py b/tests/_sync/test_http_proxy.py index 2d66578e2..966672dd2 100644 --- a/tests/_sync/test_http_proxy.py +++ b/tests/_sync/test_http_proxy.py @@ -7,11 +7,12 @@ from httpcore import ( SOCKET_OPTION, - HTTPProxy, + ConnectionPool, MockBackend, MockStream, NetworkStream, Origin, + Proxy, ProxyError, ) @@ -31,8 +32,8 @@ def test_proxy_forwarding(): ] ) - with HTTPProxy( - proxy_url="http://localhost:8080/", + with ConnectionPool( + proxy=Proxy("http://localhost:8080/"), max_connections=10, network_backend=network_backend, ) as proxy: @@ -87,8 +88,8 @@ def test_proxy_tunneling(): ] ) - with HTTPProxy( - proxy_url="http://localhost:8080/", + with ConnectionPool( + proxy=Proxy("http://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending an intial request, which once complete will return to the pool, IDLE. @@ -178,8 +179,8 @@ def test_proxy_tunneling_http2(): ], ) - with HTTPProxy( - proxy_url="http://localhost:8080/", + with ConnectionPool( + proxy=Proxy("http://localhost:8080/"), network_backend=network_backend, http2=True, ) as proxy: @@ -227,8 +228,8 @@ def test_proxy_tunneling_with_403(): ] ) - with HTTPProxy( - proxy_url="http://localhost:8080/", + with ConnectionPool( + proxy=Proxy("http://localhost:8080/"), network_backend=network_backend, ) as proxy: with pytest.raises(ProxyError) as exc_info: @@ -255,17 +256,23 @@ def test_proxy_tunneling_with_auth(): ] ) - with HTTPProxy( - proxy_url="http://localhost:8080/", - proxy_auth=("username", "password"), + with ConnectionPool( + proxy=Proxy( + url="http://localhost:8080/", + auth=("username", "password"), + ), network_backend=network_backend, ) as proxy: response = proxy.request("GET", "https://example.com/") assert response.status == 200 assert response.content == b"Hello, world!" - # Dig into this private property as a cheap lazy way of - # checking that the proxy header is set correctly. 
- assert proxy._proxy_headers == [ # type: ignore - (b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=") - ] + +def test_proxy_headers(): + proxy = Proxy( + url="http://localhost:8080/", + auth=("username", "password"), + ) + assert proxy.headers == [ + (b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=") + ] diff --git a/tests/_sync/test_socks_proxy.py b/tests/_sync/test_socks_proxy.py index 2d39bb97a..89ec9faee 100644 --- a/tests/_sync/test_socks_proxy.py +++ b/tests/_sync/test_socks_proxy.py @@ -24,8 +24,8 @@ def test_socks5_request(): ] ) - with httpcore.SOCKSProxy( - proxy_url="socks5://localhost:8080/", + with httpcore.ConnectionPool( + proxy=httpcore.Proxy("socks5://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending an intial request, which once complete will return to the pool, IDLE. @@ -84,9 +84,11 @@ def test_authenticated_socks5_request(): ] ) - with httpcore.SOCKSProxy( - proxy_url="socks5://localhost:8080/", - proxy_auth=(b"username", b"password"), + with httpcore.ConnectionPool( + proxy=httpcore.Proxy( + url="socks5://localhost:8080/", + auth=(b"username", b"password"), + ), network_backend=network_backend, ) as proxy: # Sending an intial request, which once complete will return to the pool, IDLE. @@ -123,8 +125,8 @@ def test_socks5_request_connect_failed(): ] ) - with httpcore.SOCKSProxy( - proxy_url="socks5://localhost:8080/", + with httpcore.ConnectionPool( + proxy=httpcore.Proxy("socks5://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending a request, which the proxy rejects @@ -150,8 +152,8 @@ def test_socks5_request_failed_to_provide_auth(): ] ) - with httpcore.SOCKSProxy( - proxy_url="socks5://localhost:8080/", + with httpcore.ConnectionPool( + proxy=httpcore.Proxy("socks5://localhost:8080/"), network_backend=network_backend, ) as proxy: # Sending a request, which the proxy rejects @@ -180,9 +182,11 @@ def test_socks5_request_incorrect_auth(): ] ) - with httpcore.SOCKSProxy( - proxy_url="socks5://localhost:8080/", - proxy_auth=(b"invalid", b"invalid"), + with httpcore.ConnectionPool( + proxy=httpcore.Proxy( + url="socks5://localhost:8080/", + auth=(b"invalid", b"invalid"), + ), network_backend=network_backend, ) as proxy: # Sending a request, which the proxy rejects From 228e7913ce116370946357100a128c13e7c5f66f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 4 Nov 2024 15:45:27 +0000 Subject: [PATCH 3/5] Update CHANGELOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f63110425..e33b10c40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [Unreleased] + +- Support `proxy=…` configuration on `ConnectionPool()`. + ## Version 1.0.6 (October 1st, 2024) - Relax `trio` dependency pinning. 
(#956) From ee9cfe10e7673179bfd802a0737e7143bae4a309 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 11 Nov 2024 17:11:53 +0000 Subject: [PATCH 4/5] Iterate refactor --- httpcore/_async/connection.py | 25 ++- httpcore/_async/connection_pool.py | 272 ++++++----------------------- httpcore/_async/http11.py | 174 ++++++++---------- httpcore/_async/http2.py | 173 +++++++++--------- httpcore/_async/http_proxy.py | 4 +- httpcore/_async/interfaces.py | 58 +++++- httpcore/_models.py | 3 + httpcore/_sync/connection.py | 26 ++- httpcore/_sync/connection_pool.py | 272 ++++++----------------------- httpcore/_sync/http11.py | 174 ++++++++---------- httpcore/_sync/http_proxy.py | 4 +- httpcore/_sync/interfaces.py | 57 +++++- httpcore/_synchronization.py | 18 +- scripts/unasync.py | 1 + tests/_async/test_connection.py | 46 ++--- tests/_sync/test_connection.py | 46 ++--- 16 files changed, 532 insertions(+), 821 deletions(-) diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index b42581dff..bb6d1709a 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -9,12 +9,12 @@ from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream from .._exceptions import ConnectError, ConnectTimeout -from .._models import Origin, Request, Response +from .._models import Origin, Request from .._ssl import default_ssl_context -from .._synchronization import AsyncLock +from .._synchronization import AsyncSemaphore from .._trace import Trace from .http11 import AsyncHTTP11Connection -from .interfaces import AsyncConnectionInterface +from .interfaces import AsyncConnectionInterface, StartResponse RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. @@ -63,10 +63,10 @@ def __init__( ) self._connection: AsyncConnectionInterface | None = None self._connect_failed: bool = False - self._request_lock = AsyncLock() + self._request_lock = AsyncSemaphore(bound=1) self._socket_options = socket_options - async def handle_async_request(self, request: Request) -> Response: + async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection to {self._origin}" @@ -100,7 +100,11 @@ async def handle_async_request(self, request: Request) -> Response: self._connect_failed = True raise exc - return await self._connection.handle_async_request(request) + iterator = self._connection.iterate_response(request) + start_response = await anext(iterator) + yield start_response + async for body in iterator: + yield body async def _connect(self, request: Request) -> AsyncNetworkStream: timeouts = request.extensions.get("timeout", {}) @@ -174,14 +178,7 @@ async def aclose(self) -> None: def is_available(self) -> bool: if self._connection is None: - # If HTTP/2 support is enabled, and the resulting connection could - # end up as HTTP/2 then we should indicate the connection as being - # available to service multiple requests. 
- return ( - self._http2 - and (self._origin.scheme == b"https" or not self._http1) - and not self._connect_failed - ) + return False return self._connection.is_available() def has_expired(self) -> bool: diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 96e973d0c..23805e8c4 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -2,42 +2,15 @@ import ssl import sys -import types import typing from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend -from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol -from .._models import Origin, Proxy, Request, Response -from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock +from .._exceptions import UnsupportedProtocol +from .._models import Origin, Proxy, Request +from .._synchronization import AsyncSemaphore from .connection import AsyncHTTPConnection -from .interfaces import AsyncConnectionInterface, AsyncRequestInterface - - -class AsyncPoolRequest: - def __init__(self, request: Request) -> None: - self.request = request - self.connection: AsyncConnectionInterface | None = None - self._connection_acquired = AsyncEvent() - - def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None: - self.connection = connection - self._connection_acquired.set() - - def clear_connection(self) -> None: - self.connection = None - self._connection_acquired = AsyncEvent() - - async def wait_for_connection( - self, timeout: float | None = None - ) -> AsyncConnectionInterface: - if self.connection is None: - await self._connection_acquired.wait(timeout=timeout) - assert self.connection is not None - return self.connection - - def is_queued(self) -> bool: - return self.connection is None +from .interfaces import AsyncConnectionInterface, AsyncRequestInterface, StartResponse class AsyncConnectionPool(AsyncRequestInterface): @@ -49,6 +22,7 @@ def __init__( self, ssl_context: ssl.SSLContext | None = None, proxy: Proxy | None = None, + concurrency_limit: int = 100, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -102,6 +76,7 @@ def __init__( self._max_keepalive_connections = min( self._max_connections, self._max_keepalive_connections ) + self._limits = AsyncSemaphore(bound=concurrency_limit) self._keepalive_expiry = keepalive_expiry self._http1 = http1 @@ -123,7 +98,7 @@ def __init__( # We only mutate the state of the connection pool within an 'optional_thread_lock' # context. This holds a threading lock unless we're running in async mode, # in which case it is a no-op. - self._optional_thread_lock = AsyncThreadLock() + # self._optional_thread_lock = AsyncThreadLock() def create_connection(self, origin: Origin) -> AsyncConnectionInterface: if self._proxy is not None: @@ -196,7 +171,7 @@ def connections(self) -> list[AsyncConnectionInterface]: """ return list(self._connections) - async def handle_async_request(self, request: Request) -> Response: + async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]: """ Send an HTTP request, and return an HTTP response. @@ -212,145 +187,50 @@ async def handle_async_request(self, request: Request) -> Response: f"Request URL has an unsupported protocol '{scheme}://'." 
) - timeouts = request.extensions.get("timeout", {}) - timeout = timeouts.get("pool", None) - - with self._optional_thread_lock: - # Add the incoming request to our request queue. - pool_request = AsyncPoolRequest(request) - self._requests.append(pool_request) - - try: - while True: - with self._optional_thread_lock: - # Assign incoming requests to available connections, - # closing or creating new connections as required. - closing = self._assign_requests_to_connections() - await self._close_connections(closing) - - # Wait until this request has an assigned connection. - connection = await pool_request.wait_for_connection(timeout=timeout) - - try: - # Send the request on the assigned connection. - response = await connection.handle_async_request( - pool_request.request - ) - except ConnectionNotAvailable: - # In some cases a connection may initially be available to - # handle a request, but then become unavailable. - # - # In this case we clear the connection and try again. - pool_request.clear_connection() - else: - break # pragma: nocover - - except BaseException as exc: - with self._optional_thread_lock: - # For any exception or cancellation we remove the request from - # the queue, and then re-assign requests to connections. - self._requests.remove(pool_request) - closing = self._assign_requests_to_connections() - - await self._close_connections(closing) - raise exc from None - - # Return the response. Note that in this case we still have to manage - # the point at which the response is closed. - assert isinstance(response.stream, typing.AsyncIterable) - return Response( - status=response.status, - headers=response.headers, - content=PoolByteStream( - stream=response.stream, pool_request=pool_request, pool=self - ), - extensions=response.extensions, - ) - - def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]: - """ - Manage the state of the connection pool, assigning incoming - requests to connections as available. - - Called whenever a new request is added or removed from the pool. - - Any closing connections are returned, allowing the I/O for closing - those connections to be handled seperately. - """ - closing_connections = [] - - # First we handle cleaning up any connections that are closed, - # have expired their keep-alive, or surplus idle connections. - for connection in list(self._connections): - if connection.is_closed(): - # log: "removing closed connection" - self._connections.remove(connection) - elif connection.has_expired(): - # log: "closing expired connection" - self._connections.remove(connection) - closing_connections.append(connection) - elif ( - connection.is_idle() - and len([connection.is_idle() for connection in self._connections]) - > self._max_keepalive_connections - ): - # log: "closing idle connection" - self._connections.remove(connection) - closing_connections.append(connection) - - # Assign queued requests to connections. - queued_requests = [request for request in self._requests if request.is_queued()] - for pool_request in queued_requests: - origin = pool_request.request.url.origin - available_connections = [ - connection - for connection in self._connections - if connection.can_handle_request(origin) and connection.is_available() - ] - idle_connections = [ - connection for connection in self._connections if connection.is_idle() - ] - - # There are three cases for how we may be able to handle the request: - # - # 1. There is an existing connection that can handle the request. - # 2. 
We can create a new connection to handle the request. - # 3. We can close an idle connection and then create a new connection - # to handle the request. - if available_connections: - # log: "reusing existing connection" - connection = available_connections[0] - pool_request.assign_to_connection(connection) - elif len(self._connections) < self._max_connections: - # log: "creating new connection" - connection = self.create_connection(origin) - self._connections.append(connection) - pool_request.assign_to_connection(connection) - elif idle_connections: - # log: "closing idle connection" - connection = idle_connections[0] - self._connections.remove(connection) - closing_connections.append(connection) - # log: "creating new connection" - connection = self.create_connection(origin) - self._connections.append(connection) - pool_request.assign_to_connection(connection) - - return closing_connections - - async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None: - # Close connections which have been removed from the pool. - with AsyncShieldCancellation(): - for connection in closing: - await connection.aclose() + # timeouts = request.extensions.get("timeout", {}) + # timeout = timeouts.get("pool", None) + + async with self._limits: + connection = self._get_connection(request) + iterator = connection.iterate_response(request) + try: + response_start = await anext(iterator) + # Return the response status and headers. + yield response_start + # Return the response. + async for event in iterator: + yield event + finally: + await iterator.aclose() + closing = self._close_connections() + for conn in closing: + await conn.aclose() + + def _get_connection(self, request): + origin = request.url.origin + for connection in self._connections: + if connection.can_handle_request(origin) and connection.is_available(): + return connection + + connection = self.create_connection(origin) + self._connections.append(connection) + return connection + + def _close_connections(self): + closing = [conn for conn in self._connections if conn.has_expired()] + self._connections = [ + conn for conn in self._connections + if not (conn.has_expired() or conn.is_closed()) + ] + return closing async def aclose(self) -> None: # Explicitly close the connection pool. # Clears all existing requests and connections. 
- with self._optional_thread_lock: - closing_connections = list(self._connections) - self._connections = [] - await self._close_connections(closing_connections) + closing = list(self._connections) + self._connections = [] + for conn in closing: + await conn.aclose() async def __aenter__(self) -> AsyncConnectionPool: return self @@ -365,56 +245,12 @@ async def __aexit__( def __repr__(self) -> str: class_name = self.__class__.__name__ - with self._optional_thread_lock: - request_is_queued = [request.is_queued() for request in self._requests] - connection_is_idle = [ - connection.is_idle() for connection in self._connections - ] - - num_active_requests = request_is_queued.count(False) - num_queued_requests = request_is_queued.count(True) - num_active_connections = connection_is_idle.count(False) - num_idle_connections = connection_is_idle.count(True) - - requests_info = ( - f"Requests: {num_active_requests} active, {num_queued_requests} queued" - ) + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) connection_info = ( f"Connections: {num_active_connections} active, {num_idle_connections} idle" ) - - return f"<{class_name} [{requests_info} | {connection_info}]>" - - -class PoolByteStream: - def __init__( - self, - stream: typing.AsyncIterable[bytes], - pool_request: AsyncPoolRequest, - pool: AsyncConnectionPool, - ) -> None: - self._stream = stream - self._pool_request = pool_request - self._pool = pool - self._closed = False - - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - try: - async for part in self._stream: - yield part - except BaseException as exc: - await self.aclose() - raise exc from None - - async def aclose(self) -> None: - if not self._closed: - self._closed = True - with AsyncShieldCancellation(): - if hasattr(self._stream, "aclose"): - await self._stream.aclose() - - with self._pool._optional_thread_lock: - self._pool._requests.remove(self._pool_request) - closing = self._pool._assign_requests_to_connections() - - await self._pool._close_connections(closing) + return f"<{class_name} [{connection_info}]>" diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index e6d6d7098..bba95eedd 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -17,10 +17,10 @@ WriteError, map_exceptions, ) -from .._models import Origin, Request, Response -from .._synchronization import AsyncLock, AsyncShieldCancellation +from .._models import Origin, Request +from .._synchronization import AsyncSemaphore from .._trace import Trace -from .interfaces import AsyncConnectionInterface +from .interfaces import AsyncConnectionInterface, StartResponse logger = logging.getLogger("httpcore.http11") @@ -55,21 +55,23 @@ def __init__( self._keepalive_expiry: float | None = keepalive_expiry self._expire_at: float | None = None self._state = HTTPConnectionState.NEW - self._state_lock = AsyncLock() + self._request_lock = AsyncSemaphore(bound=1) self._request_count = 0 self._h11_state = h11.Connection( our_role=h11.CLIENT, max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, ) - async def handle_async_request(self, request: Request) -> Response: + async def iterate_response( + self, request: Request + ) -> typing.AsyncIterator[StartResponse | bytes]: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection " f"to {self._origin}" ) - async 
with self._state_lock: + async with self._request_lock: if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): self._request_count += 1 self._state = HTTPConnectionState.ACTIVE @@ -77,63 +79,69 @@ async def handle_async_request(self, request: Request) -> Response: else: raise ConnectionNotAvailable() - try: - kwargs = {"request": request} try: + kwargs = {"request": request} + try: + async with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + await self._send_request_headers(**kwargs) + async with Trace( + "send_request_body", logger, request, kwargs + ) as trace: + await self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + async with Trace( - "send_request_headers", logger, request, kwargs + "receive_response_headers", logger, request, kwargs ) as trace: - await self._send_request_headers(**kwargs) - async with Trace("send_request_body", logger, request, kwargs) as trace: - await self._send_request_body(**kwargs) - except WriteError: - # If we get a write error while we're writing the request, - # then we supress this error and move on to attempting to - # read the response. Servers can sometimes close the request - # pre-emptively and then respond with a well formed HTTP - # error response. - pass - - async with Trace( - "receive_response_headers", logger, request, kwargs - ) as trace: - ( - http_version, - status, - reason_phrase, - headers, - trailing_data, - ) = await self._receive_response_headers(**kwargs) - trace.return_value = ( - http_version, - status, - reason_phrase, - headers, + ( + http_version, + status, + reason_phrase, + headers, + trailing_data, + ) = await self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = AsyncHTTP11UpgradeStream( + network_stream, trailing_data + ) + + yield StartResponse( + status=status, + headers=headers, + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": network_stream, + }, ) - - network_stream = self._network_stream - - # CONNECT or Upgrade request - if (status == 101) or ( - (request.method == b"CONNECT") and (200 <= status < 300) - ): - network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data) - - return Response( - status=status, - headers=headers, - content=HTTP11ConnectionByteStream(self, request), - extensions={ - "http_version": http_version, - "reason_phrase": reason_phrase, - "network_stream": network_stream, - }, - ) - except BaseException as exc: - with AsyncShieldCancellation(): + async with Trace("receive_response_body", logger, request, kwargs): + async for chunk in self._receive_response_body(**kwargs): + yield chunk + finally: + await self._response_closed() async with Trace("response_closed", logger, request) as trace: - await self._response_closed() - raise exc + if self.is_closed(): + await self.aclose() # Sending the request... 
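After this refactor a connection no longer returns a `Response` wrapping a byte-stream object; `iterate_response()` is an async generator that yields one `StartResponse` (status, headers, extensions) followed by the body as `bytes` chunks, with connection cleanup handled in its `finally` block. A minimal consumer, using the `StartResponse` type added in `interfaces.py` (illustrative only):

```python
import httpcore

async def fetch(
    connection: httpcore.AsyncHTTP11Connection, request: httpcore.Request
) -> bytes:
    events = connection.iterate_response(request)
    start = await anext(events)                  # first event: StartResponse
    chunks = [chunk async for chunk in events]   # remaining events: body bytes
    print(start.status, start.headers)
    return b"".join(chunks)
```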
@@ -236,18 +244,17 @@ async def _receive_event( return event # type: ignore[return-value] async def _response_closed(self) -> None: - async with self._state_lock: - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._state = HTTPConnectionState.IDLE - self._h11_state.start_next_cycle() - if self._keepalive_expiry is not None: - now = time.monotonic() - self._expire_at = now + self._keepalive_expiry - else: - await self.aclose() + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + self._state = HTTPConnectionState.CLOSED # Once the connection is no longer required... @@ -321,33 +328,6 @@ async def __aexit__( await self.aclose() -class HTTP11ConnectionByteStream: - def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: - self._connection = connection - self._request = request - self._closed = False - - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - kwargs = {"request": self._request} - try: - async with Trace("receive_response_body", logger, self._request, kwargs): - async for chunk in self._connection._receive_response_body(**kwargs): - yield chunk - except BaseException as exc: - # If we get an exception while streaming the response, - # we want to close the response (and possibly the connection) - # before raising that exception. - with AsyncShieldCancellation(): - await self.aclose() - raise exc - - async def aclose(self) -> None: - if not self._closed: - self._closed = True - async with Trace("response_closed", logger, self._request): - await self._connection._response_closed() - - class AsyncHTTP11UpgradeStream(AsyncNetworkStream): def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None: self._stream = stream diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index c6434a049..3406da00b 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -21,7 +21,7 @@ from .._models import Origin, Request, Response from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation from .._trace import Trace -from .interfaces import AsyncConnectionInterface +from .interfaces import AsyncConnectionInterface, StartResponse logger = logging.getLogger("httpcore.http2") @@ -60,6 +60,7 @@ def __init__( self._state_lock = AsyncLock() self._read_lock = AsyncLock() self._write_lock = AsyncLock() + self._max_streams_semaphore = AsyncSemaphore(100) self._sent_connection_init = False self._used_all_stream_ids = False self._connection_error = False @@ -80,7 +81,9 @@ def __init__( self._read_exception: Exception | None = None self._write_exception: Exception | None = None - async def handle_async_request(self, request: Request) -> Response: + async def iterate_response( + self, request: Request + ) -> typing.AsyncIterator[StartResponse | bytes]: if not self.can_handle_request(request.url.origin): # This cannot occur in normal operation, since the connection pool # will only send requests on connections that handle them. 
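On the HTTP/2 side, the dynamic stream limit that used to track the server's `SETTINGS` frame is replaced by a fixed `AsyncSemaphore(100)` acquired around each stream, so an established HTTP/2 connection serves at most 100 concurrent streams. A rough usage sketch (illustrative URL):

```python
import asyncio
import httpcore

async def main() -> None:
    async with httpcore.AsyncConnectionPool(http2=True) as pool:
        # Each stream on an established HTTP/2 connection acquires the
        # fixed semaphore (bound 100 in this patch) before being sent.
        tasks = [pool.request("GET", "https://www.example.com/") for _ in range(5)]
        responses = await asyncio.gather(*tasks)
        print([response.status for response in responses])

asyncio.run(main())
```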
@@ -112,76 +115,65 @@ async def handle_async_request(self, request: Request) -> Response: self._sent_connection_init = True - # Initially start with just 1 until the remote server provides - # its max_concurrent_streams value - self._max_streams = 1 - - local_settings_max_streams = ( - self._h2_state.local_settings.max_concurrent_streams - ) - self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams) - - for _ in range(local_settings_max_streams - self._max_streams): - await self._max_streams_semaphore.acquire() - - await self._max_streams_semaphore.acquire() - - try: - stream_id = self._h2_state.get_next_available_stream_id() - self._events[stream_id] = [] - except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover - self._used_all_stream_ids = True - self._request_count -= 1 - raise ConnectionNotAvailable() + async with self._max_streams_semaphore: + try: + stream_id = self._h2_state.get_next_available_stream_id() + self._events[stream_id] = [] + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + self._request_count -= 1 + raise ConnectionNotAvailable() - try: - kwargs = {"request": request, "stream_id": stream_id} - async with Trace("send_request_headers", logger, request, kwargs): - await self._send_request_headers(request=request, stream_id=stream_id) - async with Trace("send_request_body", logger, request, kwargs): - await self._send_request_body(request=request, stream_id=stream_id) - async with Trace( - "receive_response_headers", logger, request, kwargs - ) as trace: - status, headers = await self._receive_response( - request=request, stream_id=stream_id + try: + kwargs = {"request": request, "stream_id": stream_id} + async with Trace("send_request_headers", logger, request, kwargs): + await self._send_request_headers(request=request, stream_id=stream_id) + async with Trace("send_request_body", logger, request, kwargs): + await self._send_request_body(request=request, stream_id=stream_id) + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + status, headers = await self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + yield StartResponse( + status=status, + headers=headers, + extensions={ + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + }, ) - trace.return_value = (status, headers) - - return Response( - status=status, - headers=headers, - content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), - extensions={ - "http_version": b"HTTP/2", - "network_stream": self._network_stream, - "stream_id": stream_id, - }, - ) - except BaseException as exc: # noqa: PIE786 - with AsyncShieldCancellation(): + async with Trace("receive_response_body", logger, request, kwargs): + async for chunk in self._receive_response_body( + request=request, stream_id=stream_id + ): + yield chunk + except BaseException as exc: # noqa: PIE786 + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. 
+ if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. + raise LocalProtocolError(exc) # pragma: nocover + raise exc + finally: kwargs = {"stream_id": stream_id} async with Trace("response_closed", logger, request, kwargs): await self._response_closed(stream_id=stream_id) - if isinstance(exc, h2.exceptions.ProtocolError): - # One case where h2 can raise a protocol error is when a - # closed frame has been seen by the state machine. - # - # This happens when one stream is reading, and encounters - # a GOAWAY event. Other flows of control may then raise - # a protocol error at any point they interact with the 'h2_state'. - # - # In this case we'll have stored the event, and should raise - # it as a RemoteProtocolError. - if self._connection_terminated: # pragma: nocover - raise RemoteProtocolError(self._connection_terminated) - # If h2 raises a protocol error in some other state then we - # must somehow have made a protocol violation. - raise LocalProtocolError(exc) # pragma: nocover - - raise exc - async def _send_connection_init(self, request: Request) -> None: """ The HTTP/2 connection requires some initial setup before we can start @@ -356,14 +348,14 @@ async def _receive_events( if stream_id is None or not self._events.get(stream_id): events = await self._read_incoming_data(request) for event in events: - if isinstance(event, h2.events.RemoteSettingsChanged): - async with Trace( - "receive_remote_settings", logger, request - ) as trace: - await self._receive_remote_settings_change(event) - trace.return_value = event - - elif isinstance( + # if isinstance(event, h2.events.RemoteSettingsChanged): + # async with Trace( + # "receive_remote_settings", logger, request + # ) as trace: + # await self._receive_remote_settings_change(event) + # trace.return_value = event + + if isinstance( event, ( h2.events.ResponseReceived, @@ -380,25 +372,24 @@ async def _receive_events( await self._write_outgoing_data(request) - async def _receive_remote_settings_change(self, event: h2.events.Event) -> None: - max_concurrent_streams = event.changed_settings.get( - h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS - ) - if max_concurrent_streams: - new_max_streams = min( - max_concurrent_streams.new_value, - self._h2_state.local_settings.max_concurrent_streams, - ) - if new_max_streams and new_max_streams != self._max_streams: - while new_max_streams > self._max_streams: - await self._max_streams_semaphore.release() - self._max_streams += 1 - while new_max_streams < self._max_streams: - await self._max_streams_semaphore.acquire() - self._max_streams -= 1 + # async def _receive_remote_settings_change(self, event: h2.events.Event) -> None: + # max_concurrent_streams = event.changed_settings.get( + # h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + # ) + # if max_concurrent_streams: + # new_max_streams = min( + # max_concurrent_streams.new_value, + # self._h2_state.local_settings.max_concurrent_streams, + # ) + # if new_max_streams and new_max_streams != self._max_streams: + # while new_max_streams > self._max_streams: + # await self._max_streams_semaphore.release() + # self._max_streams += 1 + # while new_max_streams < self._max_streams: + # await self._max_streams_semaphore.acquire() + # self._max_streams -= 1 async def _response_closed(self, stream_id: int) -> None: - await self._max_streams_semaphore.release() del self._events[stream_id] async with 
self._state_lock: if self._connection_terminated and not self._events: diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index cc9d92066..ac2a8e015 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -17,7 +17,7 @@ enforce_url, ) from .._ssl import default_ssl_context -from .._synchronization import AsyncLock +from .._synchronization import AsyncSemaphore from .._trace import Trace from .connection import AsyncHTTPConnection from .connection_pool import AsyncConnectionPool @@ -259,7 +259,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 - self._connect_lock = AsyncLock() + self._connect_lock = AsyncSemaphore(bound=1) self._connected = False async def handle_async_request(self, request: Request) -> Response: diff --git a/httpcore/_async/interfaces.py b/httpcore/_async/interfaces.py index 361583bed..9b9000162 100644 --- a/httpcore/_async/interfaces.py +++ b/httpcore/_async/interfaces.py @@ -17,6 +17,33 @@ ) +class StartResponse: + def __init__(self, status: int, headers: HeaderTypes, extensions: Extensions): + self.status = status + self.headers = headers + self.extensions = extensions + + +class ResponseContext: + def __init__(self, status: int, headers: HeaderTypes, iterator, extensions: Extensions): + self._status = status + self._headers = headers + self._iterator = iterator + self._extensions = extensions + + async def __aenter__(self): + self._response = Response( + status=self._status, + headers=self._headers, + content=self._iterator, + extensions=self._extensions + ) + return self._response + + async def __aexit__(self, *args, **kwargs): + await self._response.aclose() + + class AsyncRequestInterface: async def request( self, @@ -42,12 +69,15 @@ async def request( content=content, extensions=extensions, ) - response = await self.handle_async_request(request) - try: - await response.aread() - finally: - await response.aclose() - return response + iterator = self.iterate_response(request) + start_response = await anext(iterator) + content = b"".join([part async for part in iterator]) + return Response( + status=start_response.status, + headers=start_response.headers, + content=content, + extensions=start_response.extensions, + ) @contextlib.asynccontextmanager async def stream( @@ -58,7 +88,7 @@ async def stream( headers: HeaderTypes = None, content: bytes | typing.AsyncIterator[bytes] | None = None, extensions: Extensions | None = None, - ) -> typing.AsyncIterator[Response]: + ) -> ResponseContext: # Strict type checking on our parameters. 
method = enforce_bytes(method, name="method") url = enforce_url(url, name="url") @@ -74,14 +104,24 @@ async def stream( content=content, extensions=extensions, ) - response = await self.handle_async_request(request) + iterator = self.iterate_response(request) + start_response = await anext(iterator) + response = Response( + status=start_response.status, + headers=start_response.headers, + content=iterator, + extensions=start_response.extensions, + ) try: yield response finally: await response.aclose() - async def handle_async_request(self, request: Request) -> Response: + async def iterate_response( + self, request: Request + ) -> typing.AsyncIterator[StartResponse | bytes]: raise NotImplementedError() # pragma: nocover + yield b'' class AsyncConnectionInterface(AsyncRequestInterface): diff --git a/httpcore/_models.py b/httpcore/_models.py index 8a65f1334..1b1b02b7a 100644 --- a/httpcore/_models.py +++ b/httpcore/_models.py @@ -397,6 +397,9 @@ def __init__( ) self.extensions = {} if extensions is None else extensions + if isinstance(content, bytes): + self._content = content + self._stream_consumed = False @property diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 363f8be81..b877eaf0d 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -9,12 +9,12 @@ from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream from .._exceptions import ConnectError, ConnectTimeout -from .._models import Origin, Request, Response +from .._models import Origin, Request from .._ssl import default_ssl_context -from .._synchronization import Lock +from .._synchronization import Semaphore from .._trace import Trace from .http11 import HTTP11Connection -from .interfaces import ConnectionInterface +from .interfaces import ConnectionInterface, StartResponse RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. @@ -63,10 +63,10 @@ def __init__( ) self._connection: ConnectionInterface | None = None self._connect_failed: bool = False - self._request_lock = Lock() + self._request_lock = Semaphore(bound=1) self._socket_options = socket_options - def handle_request(self, request: Request) -> Response: + def iterate_response(self, request: Request) -> typing.Iterator[StartResponse | bytes]: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection to {self._origin}" @@ -100,7 +100,12 @@ def handle_request(self, request: Request) -> Response: self._connect_failed = True raise exc - return self._connection.handle_request(request) + # iterator = self._connection.iterate_response(request) + iterator = self._connection.iterate_response(request) + start_response = next(iterator) + yield start_response + for body in iterator: + yield body def _connect(self, request: Request) -> NetworkStream: timeouts = request.extensions.get("timeout", {}) @@ -174,14 +179,7 @@ def close(self) -> None: def is_available(self) -> bool: if self._connection is None: - # If HTTP/2 support is enabled, and the resulting connection could - # end up as HTTP/2 then we should indicate the connection as being - # available to service multiple requests. 
- return ( - self._http2 - and (self._origin.scheme == b"https" or not self._http1) - and not self._connect_failed - ) + return False return self._connection.is_available() def has_expired(self) -> bool: diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 9ccfa53e5..63a9799d7 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -2,42 +2,15 @@ import ssl import sys -import types import typing from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend -from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol -from .._models import Origin, Proxy, Request, Response -from .._synchronization import Event, ShieldCancellation, ThreadLock +from .._exceptions import UnsupportedProtocol +from .._models import Origin, Proxy, Request +from .._synchronization import Semaphore from .connection import HTTPConnection -from .interfaces import ConnectionInterface, RequestInterface - - -class PoolRequest: - def __init__(self, request: Request) -> None: - self.request = request - self.connection: ConnectionInterface | None = None - self._connection_acquired = Event() - - def assign_to_connection(self, connection: ConnectionInterface | None) -> None: - self.connection = connection - self._connection_acquired.set() - - def clear_connection(self) -> None: - self.connection = None - self._connection_acquired = Event() - - def wait_for_connection( - self, timeout: float | None = None - ) -> ConnectionInterface: - if self.connection is None: - self._connection_acquired.wait(timeout=timeout) - assert self.connection is not None - return self.connection - - def is_queued(self) -> bool: - return self.connection is None +from .interfaces import ConnectionInterface, RequestInterface, StartResponse class ConnectionPool(RequestInterface): @@ -49,6 +22,7 @@ def __init__( self, ssl_context: ssl.SSLContext | None = None, proxy: Proxy | None = None, + concurrency_limit: int = 100, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -102,6 +76,7 @@ def __init__( self._max_keepalive_connections = min( self._max_connections, self._max_keepalive_connections ) + self._limits = Semaphore(bound=concurrency_limit) self._keepalive_expiry = keepalive_expiry self._http1 = http1 @@ -123,7 +98,7 @@ def __init__( # We only mutate the state of the connection pool within an 'optional_thread_lock' # context. This holds a threading lock unless we're running in async mode, # in which case it is a no-op. - self._optional_thread_lock = ThreadLock() + # self._optional_thread_lock = ThreadLock() def create_connection(self, origin: Origin) -> ConnectionInterface: if self._proxy is not None: @@ -196,7 +171,7 @@ def connections(self) -> list[ConnectionInterface]: """ return list(self._connections) - def handle_request(self, request: Request) -> Response: + def iterate_response(self, request: Request) -> typing.Iterator[StartResponse | bytes]: """ Send an HTTP request, and return an HTTP response. @@ -212,145 +187,50 @@ def handle_request(self, request: Request) -> Response: f"Request URL has an unsupported protocol '{scheme}://'." ) - timeouts = request.extensions.get("timeout", {}) - timeout = timeouts.get("pool", None) - - with self._optional_thread_lock: - # Add the incoming request to our request queue. 
- pool_request = PoolRequest(request) - self._requests.append(pool_request) - - try: - while True: - with self._optional_thread_lock: - # Assign incoming requests to available connections, - # closing or creating new connections as required. - closing = self._assign_requests_to_connections() - self._close_connections(closing) - - # Wait until this request has an assigned connection. - connection = pool_request.wait_for_connection(timeout=timeout) - - try: - # Send the request on the assigned connection. - response = connection.handle_request( - pool_request.request - ) - except ConnectionNotAvailable: - # In some cases a connection may initially be available to - # handle a request, but then become unavailable. - # - # In this case we clear the connection and try again. - pool_request.clear_connection() - else: - break # pragma: nocover - - except BaseException as exc: - with self._optional_thread_lock: - # For any exception or cancellation we remove the request from - # the queue, and then re-assign requests to connections. - self._requests.remove(pool_request) - closing = self._assign_requests_to_connections() - - self._close_connections(closing) - raise exc from None - - # Return the response. Note that in this case we still have to manage - # the point at which the response is closed. - assert isinstance(response.stream, typing.Iterable) - return Response( - status=response.status, - headers=response.headers, - content=PoolByteStream( - stream=response.stream, pool_request=pool_request, pool=self - ), - extensions=response.extensions, - ) - - def _assign_requests_to_connections(self) -> list[ConnectionInterface]: - """ - Manage the state of the connection pool, assigning incoming - requests to connections as available. - - Called whenever a new request is added or removed from the pool. - - Any closing connections are returned, allowing the I/O for closing - those connections to be handled seperately. - """ - closing_connections = [] - - # First we handle cleaning up any connections that are closed, - # have expired their keep-alive, or surplus idle connections. - for connection in list(self._connections): - if connection.is_closed(): - # log: "removing closed connection" - self._connections.remove(connection) - elif connection.has_expired(): - # log: "closing expired connection" - self._connections.remove(connection) - closing_connections.append(connection) - elif ( - connection.is_idle() - and len([connection.is_idle() for connection in self._connections]) - > self._max_keepalive_connections - ): - # log: "closing idle connection" - self._connections.remove(connection) - closing_connections.append(connection) - - # Assign queued requests to connections. - queued_requests = [request for request in self._requests if request.is_queued()] - for pool_request in queued_requests: - origin = pool_request.request.url.origin - available_connections = [ - connection - for connection in self._connections - if connection.can_handle_request(origin) and connection.is_available() - ] - idle_connections = [ - connection for connection in self._connections if connection.is_idle() - ] - - # There are three cases for how we may be able to handle the request: - # - # 1. There is an existing connection that can handle the request. - # 2. We can create a new connection to handle the request. - # 3. We can close an idle connection and then create a new connection - # to handle the request. 
- if available_connections: - # log: "reusing existing connection" - connection = available_connections[0] - pool_request.assign_to_connection(connection) - elif len(self._connections) < self._max_connections: - # log: "creating new connection" - connection = self.create_connection(origin) - self._connections.append(connection) - pool_request.assign_to_connection(connection) - elif idle_connections: - # log: "closing idle connection" - connection = idle_connections[0] - self._connections.remove(connection) - closing_connections.append(connection) - # log: "creating new connection" - connection = self.create_connection(origin) - self._connections.append(connection) - pool_request.assign_to_connection(connection) - - return closing_connections - - def _close_connections(self, closing: list[ConnectionInterface]) -> None: - # Close connections which have been removed from the pool. - with ShieldCancellation(): - for connection in closing: - connection.close() + # timeouts = request.extensions.get("timeout", {}) + # timeout = timeouts.get("pool", None) + + with self._limits: + connection = self._get_connection(request) + iterator = connection.iterate_response(request) + try: + response_start = next(iterator) + # Return the response status and headers. + yield response_start + # Return the response. + for event in iterator: + yield event + finally: + iterator.close() + closing = self._close_connections() + for conn in closing: + conn.close() + + def _get_connection(self, request): + origin = request.url.origin + for connection in self._connections: + if connection.can_handle_request(origin) and connection.is_available(): + return connection + + connection = self.create_connection(origin) + self._connections.append(connection) + return connection + + def _close_connections(self): + closing = [conn for conn in self._connections if conn.has_expired()] + self._connections = [ + conn for conn in self._connections + if not (conn.has_expired() or conn.is_closed()) + ] + return closing def close(self) -> None: # Explicitly close the connection pool. # Clears all existing requests and connections. 
- with self._optional_thread_lock: - closing_connections = list(self._connections) - self._connections = [] - self._close_connections(closing_connections) + closing = list(self._connections) + self._connections = [] + for conn in closing: + conn.close() def __enter__(self) -> ConnectionPool: return self @@ -365,56 +245,12 @@ def __exit__( def __repr__(self) -> str: class_name = self.__class__.__name__ - with self._optional_thread_lock: - request_is_queued = [request.is_queued() for request in self._requests] - connection_is_idle = [ - connection.is_idle() for connection in self._connections - ] - - num_active_requests = request_is_queued.count(False) - num_queued_requests = request_is_queued.count(True) - num_active_connections = connection_is_idle.count(False) - num_idle_connections = connection_is_idle.count(True) - - requests_info = ( - f"Requests: {num_active_requests} active, {num_queued_requests} queued" - ) + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) connection_info = ( f"Connections: {num_active_connections} active, {num_idle_connections} idle" ) - - return f"<{class_name} [{requests_info} | {connection_info}]>" - - -class PoolByteStream: - def __init__( - self, - stream: typing.Iterable[bytes], - pool_request: PoolRequest, - pool: ConnectionPool, - ) -> None: - self._stream = stream - self._pool_request = pool_request - self._pool = pool - self._closed = False - - def __iter__(self) -> typing.Iterator[bytes]: - try: - for part in self._stream: - yield part - except BaseException as exc: - self.close() - raise exc from None - - def close(self) -> None: - if not self._closed: - self._closed = True - with ShieldCancellation(): - if hasattr(self._stream, "close"): - self._stream.close() - - with self._pool._optional_thread_lock: - self._pool._requests.remove(self._pool_request) - closing = self._pool._assign_requests_to_connections() - - self._pool._close_connections(closing) + return f"<{class_name} [{connection_info}]>" diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index ebd3a9748..fdf2df2d1 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -17,10 +17,10 @@ WriteError, map_exceptions, ) -from .._models import Origin, Request, Response -from .._synchronization import Lock, ShieldCancellation +from .._models import Origin, Request +from .._synchronization import Semaphore from .._trace import Trace -from .interfaces import ConnectionInterface +from .interfaces import ConnectionInterface, StartResponse logger = logging.getLogger("httpcore.http11") @@ -55,21 +55,23 @@ def __init__( self._keepalive_expiry: float | None = keepalive_expiry self._expire_at: float | None = None self._state = HTTPConnectionState.NEW - self._state_lock = Lock() + self._request_lock = Semaphore(bound=1) self._request_count = 0 self._h11_state = h11.Connection( our_role=h11.CLIENT, max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, ) - def handle_request(self, request: Request) -> Response: + def iterate_response( + self, request: Request + ) -> typing.Iterator[StartResponse | bytes]: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection " f"to {self._origin}" ) - with self._state_lock: + with self._request_lock: if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): self._request_count += 1 self._state = 
HTTPConnectionState.ACTIVE @@ -77,63 +79,69 @@ def handle_request(self, request: Request) -> Response: else: raise ConnectionNotAvailable() - try: - kwargs = {"request": request} try: + kwargs = {"request": request} + try: + with Trace( + "send_request_headers", logger, request, kwargs + ) as trace: + self._send_request_headers(**kwargs) + with Trace( + "send_request_body", logger, request, kwargs + ) as trace: + self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + with Trace( - "send_request_headers", logger, request, kwargs + "receive_response_headers", logger, request, kwargs ) as trace: - self._send_request_headers(**kwargs) - with Trace("send_request_body", logger, request, kwargs) as trace: - self._send_request_body(**kwargs) - except WriteError: - # If we get a write error while we're writing the request, - # then we supress this error and move on to attempting to - # read the response. Servers can sometimes close the request - # pre-emptively and then respond with a well formed HTTP - # error response. - pass - - with Trace( - "receive_response_headers", logger, request, kwargs - ) as trace: - ( - http_version, - status, - reason_phrase, - headers, - trailing_data, - ) = self._receive_response_headers(**kwargs) - trace.return_value = ( - http_version, - status, - reason_phrase, - headers, + ( + http_version, + status, + reason_phrase, + headers, + trailing_data, + ) = self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = HTTP11UpgradeStream( + network_stream, trailing_data + ) + + yield StartResponse( + status=status, + headers=headers, + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": network_stream, + }, ) - - network_stream = self._network_stream - - # CONNECT or Upgrade request - if (status == 101) or ( - (request.method == b"CONNECT") and (200 <= status < 300) - ): - network_stream = HTTP11UpgradeStream(network_stream, trailing_data) - - return Response( - status=status, - headers=headers, - content=HTTP11ConnectionByteStream(self, request), - extensions={ - "http_version": http_version, - "reason_phrase": reason_phrase, - "network_stream": network_stream, - }, - ) - except BaseException as exc: - with ShieldCancellation(): + with Trace("receive_response_body", logger, request, kwargs): + for chunk in self._receive_response_body(**kwargs): + yield chunk + finally: + self._response_closed() with Trace("response_closed", logger, request) as trace: - self._response_closed() - raise exc + if self.is_closed(): + self.close() # Sending the request... 
@@ -236,18 +244,17 @@ def _receive_event( return event # type: ignore[return-value] def _response_closed(self) -> None: - with self._state_lock: - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._state = HTTPConnectionState.IDLE - self._h11_state.start_next_cycle() - if self._keepalive_expiry is not None: - now = time.monotonic() - self._expire_at = now + self._keepalive_expiry - else: - self.close() + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + self._state = HTTPConnectionState.CLOSED # Once the connection is no longer required... @@ -321,33 +328,6 @@ def __exit__( self.close() -class HTTP11ConnectionByteStream: - def __init__(self, connection: HTTP11Connection, request: Request) -> None: - self._connection = connection - self._request = request - self._closed = False - - def __iter__(self) -> typing.Iterator[bytes]: - kwargs = {"request": self._request} - try: - with Trace("receive_response_body", logger, self._request, kwargs): - for chunk in self._connection._receive_response_body(**kwargs): - yield chunk - except BaseException as exc: - # If we get an exception while streaming the response, - # we want to close the response (and possibly the connection) - # before raising that exception. - with ShieldCancellation(): - self.close() - raise exc - - def close(self) -> None: - if not self._closed: - self._closed = True - with Trace("response_closed", logger, self._request): - self._connection._response_closed() - - class HTTP11UpgradeStream(NetworkStream): def __init__(self, stream: NetworkStream, leading_data: bytes) -> None: self._stream = stream diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index ecca88f7d..dea1effe2 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -17,7 +17,7 @@ enforce_url, ) from .._ssl import default_ssl_context -from .._synchronization import Lock +from .._synchronization import Semaphore from .._trace import Trace from .connection import HTTPConnection from .connection_pool import ConnectionPool @@ -259,7 +259,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 - self._connect_lock = Lock() + self._connect_lock = Semaphore(bound=1) self._connected = False def handle_request(self, request: Request) -> Response: diff --git a/httpcore/_sync/interfaces.py b/httpcore/_sync/interfaces.py index e673d4cc1..77860234b 100644 --- a/httpcore/_sync/interfaces.py +++ b/httpcore/_sync/interfaces.py @@ -17,6 +17,33 @@ ) +class StartResponse: + def __init__(self, status: int, headers: HeaderTypes, extensions: Extensions): + self.status = status + self.headers = headers + self.extensions = extensions + + +class ResponseContext: + def __init__(self, status: int, headers: HeaderTypes, iterator, extensions: Extensions): + self._status = status + self._headers = headers + self._iterator = iterator + self._extensions = extensions + + def __enter__(self): + self._response = Response( + status=self._status, + headers=self._headers, + content=self._iterator, + extensions=self._extensions + ) + return self._response + + def __exit__(self, *args, **kwargs): + self._response.close() + + class RequestInterface: def request( self, @@ -42,12 +69,15 @@ def request( content=content, 
extensions=extensions, ) - response = self.handle_request(request) - try: - response.read() - finally: - response.close() - return response + iterator = self.iterate_response(request) + start_response = next(iterator) + content = b"".join([part for part in iterator]) + return Response( + status=start_response.status, + headers=start_response.headers, + content=content, + extensions=start_response.extensions, + ) @contextlib.contextmanager def stream( @@ -58,7 +88,7 @@ def stream( headers: HeaderTypes = None, content: bytes | typing.Iterator[bytes] | None = None, extensions: Extensions | None = None, - ) -> typing.Iterator[Response]: + ) -> ResponseContext: # Strict type checking on our parameters. method = enforce_bytes(method, name="method") url = enforce_url(url, name="url") @@ -74,13 +104,22 @@ def stream( content=content, extensions=extensions, ) - response = self.handle_request(request) + iterator = self.iterate_response(request) + start_response = next(iterator) + response = Response( + status=start_response.status, + headers=start_response.headers, + content=iterator, + extensions=start_response.extensions, + ) try: yield response finally: response.close() - def handle_request(self, request: Request) -> Response: + def iterate_response( + self, request: Request + ) -> typing.Iterator[StartResponse | bytes]: raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 2ecc9e9c3..892130638 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -171,7 +171,7 @@ def setup(self) -> None: initial_value=self._bound, max_value=self._bound ) - async def acquire(self) -> None: + async def __aenter__(self) -> None: if not self._backend: self.setup() @@ -180,7 +180,12 @@ async def acquire(self) -> None: elif self._backend == "asyncio": await self._anyio_semaphore.acquire() - async def release(self) -> None: + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: if self._backend == "trio": self._trio_semaphore.release() elif self._backend == "asyncio": @@ -295,10 +300,15 @@ class Semaphore: def __init__(self, bound: int) -> None: self._semaphore = threading.Semaphore(value=bound) - def acquire(self) -> None: + def __enter__(self) -> None: self._semaphore.acquire() - def release(self) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ) -> None: self._semaphore.release() diff --git a/scripts/unasync.py b/scripts/unasync.py index 5a5627d71..f724df30a 100644 --- a/scripts/unasync.py +++ b/scripts/unasync.py @@ -17,6 +17,7 @@ ('aclose', 'close'), ('aiter_stream', 'iter_stream'), ('aread', 'read'), + ('anext', 'next'), ('asynccontextmanager', 'contextmanager'), ('__aenter__', '__enter__'), ('__aexit__', '__exit__'), diff --git a/tests/_async/test_connection.py b/tests/_async/test_connection.py index b6ee0c7e3..a31b4f8da 100644 --- a/tests/_async/test_connection.py +++ b/tests/_async/test_connection.py @@ -61,29 +61,29 @@ async def test_http_connection(): ) -@pytest.mark.anyio -async def test_concurrent_requests_not_available_on_http11_connections(): - """ - Attempting to issue a request against an already active HTTP/1.1 connection - will raise a `ConnectionNotAvailable` exception. 
- """ - origin = Origin(b"https", b"example.com", 443) - network_backend = AsyncMockBackend( - [ - b"HTTP/1.1 200 OK\r\n", - b"Content-Type: plain/text\r\n", - b"Content-Length: 13\r\n", - b"\r\n", - b"Hello, world!", - ] - ) - - async with AsyncHTTPConnection( - origin=origin, network_backend=network_backend, keepalive_expiry=5.0 - ) as conn: - async with conn.stream("GET", "https://example.com/"): - with pytest.raises(ConnectionNotAvailable): - await conn.request("GET", "https://example.com/") +# @pytest.mark.anyio +# async def test_concurrent_requests_not_available_on_http11_connections(): +# """ +# Attempting to issue a request against an already active HTTP/1.1 connection +# will raise a `ConnectionNotAvailable` exception. +# """ +# origin = Origin(b"https", b"example.com", 443) +# network_backend = AsyncMockBackend( +# [ +# b"HTTP/1.1 200 OK\r\n", +# b"Content-Type: plain/text\r\n", +# b"Content-Length: 13\r\n", +# b"\r\n", +# b"Hello, world!", +# ] +# ) + +# async with AsyncHTTPConnection( +# origin=origin, network_backend=network_backend, keepalive_expiry=5.0 +# ) as conn: +# async with conn.stream("GET", "https://example.com/"): +# with pytest.raises(ConnectionNotAvailable): +# await conn.request("GET", "https://example.com/") @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") diff --git a/tests/_sync/test_connection.py b/tests/_sync/test_connection.py index 37c82e025..3dc848532 100644 --- a/tests/_sync/test_connection.py +++ b/tests/_sync/test_connection.py @@ -61,29 +61,29 @@ def test_http_connection(): ) - -def test_concurrent_requests_not_available_on_http11_connections(): - """ - Attempting to issue a request against an already active HTTP/1.1 connection - will raise a `ConnectionNotAvailable` exception. - """ - origin = Origin(b"https", b"example.com", 443) - network_backend = MockBackend( - [ - b"HTTP/1.1 200 OK\r\n", - b"Content-Type: plain/text\r\n", - b"Content-Length: 13\r\n", - b"\r\n", - b"Hello, world!", - ] - ) - - with HTTPConnection( - origin=origin, network_backend=network_backend, keepalive_expiry=5.0 - ) as conn: - with conn.stream("GET", "https://example.com/"): - with pytest.raises(ConnectionNotAvailable): - conn.request("GET", "https://example.com/") +# @pytest.mark.anyio +# def test_concurrent_requests_not_available_on_http11_connections(): +# """ +# Attempting to issue a request against an already active HTTP/1.1 connection +# will raise a `ConnectionNotAvailable` exception. +# """ +# origin = Origin(b"https", b"example.com", 443) +# network_backend = MockBackend( +# [ +# b"HTTP/1.1 200 OK\r\n", +# b"Content-Type: plain/text\r\n", +# b"Content-Length: 13\r\n", +# b"\r\n", +# b"Hello, world!", +# ] +# ) + +# with HTTPConnection( +# origin=origin, network_backend=network_backend, keepalive_expiry=5.0 +# ) as conn: +# with conn.stream("GET", "https://example.com/"): +# with pytest.raises(ConnectionNotAvailable): +# conn.request("GET", "https://example.com/") @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") From 98ee5dc01ffc94cad73fcc97461c674c64f75621 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 15 Nov 2024 10:46:30 +0000 Subject: [PATCH 5/5] Revert "Iterate refactor" This reverts commit ee9cfe10e7673179bfd802a0737e7143bae4a309. 
--- httpcore/_async/connection.py | 25 +-- httpcore/_async/connection_pool.py | 272 +++++++++++++++++++++++------ httpcore/_async/http11.py | 174 ++++++++++-------- httpcore/_async/http2.py | 173 +++++++++--------- httpcore/_async/http_proxy.py | 4 +- httpcore/_async/interfaces.py | 58 +----- httpcore/_models.py | 3 - httpcore/_sync/connection.py | 26 +-- httpcore/_sync/connection_pool.py | 272 +++++++++++++++++++++++------ httpcore/_sync/http11.py | 174 ++++++++++-------- httpcore/_sync/http_proxy.py | 4 +- httpcore/_sync/interfaces.py | 57 +----- httpcore/_synchronization.py | 18 +- scripts/unasync.py | 1 - tests/_async/test_connection.py | 46 ++--- tests/_sync/test_connection.py | 46 ++--- 16 files changed, 821 insertions(+), 532 deletions(-) diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index bb6d1709a..b42581dff 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -9,12 +9,12 @@ from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream from .._exceptions import ConnectError, ConnectTimeout -from .._models import Origin, Request +from .._models import Origin, Request, Response from .._ssl import default_ssl_context -from .._synchronization import AsyncSemaphore +from .._synchronization import AsyncLock from .._trace import Trace from .http11 import AsyncHTTP11Connection -from .interfaces import AsyncConnectionInterface, StartResponse +from .interfaces import AsyncConnectionInterface RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. @@ -63,10 +63,10 @@ def __init__( ) self._connection: AsyncConnectionInterface | None = None self._connect_failed: bool = False - self._request_lock = AsyncSemaphore(bound=1) + self._request_lock = AsyncLock() self._socket_options = socket_options - async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]: + async def handle_async_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection to {self._origin}" @@ -100,11 +100,7 @@ async def iterate_response(self, request: Request) -> typing.AsyncIterator[Start self._connect_failed = True raise exc - iterator = self._connection.iterate_response(request) - start_response = await anext(iterator) - yield start_response - async for body in iterator: - yield body + return await self._connection.handle_async_request(request) async def _connect(self, request: Request) -> AsyncNetworkStream: timeouts = request.extensions.get("timeout", {}) @@ -178,7 +174,14 @@ async def aclose(self) -> None: def is_available(self) -> bool: if self._connection is None: - return False + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. 
+ return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) return self._connection.is_available() def has_expired(self) -> bool: diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 23805e8c4..96e973d0c 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -2,15 +2,42 @@ import ssl import sys +import types import typing from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend -from .._exceptions import UnsupportedProtocol -from .._models import Origin, Proxy, Request -from .._synchronization import AsyncSemaphore +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Proxy, Request, Response +from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock from .connection import AsyncHTTPConnection -from .interfaces import AsyncConnectionInterface, AsyncRequestInterface, StartResponse +from .interfaces import AsyncConnectionInterface, AsyncRequestInterface + + +class AsyncPoolRequest: + def __init__(self, request: Request) -> None: + self.request = request + self.connection: AsyncConnectionInterface | None = None + self._connection_acquired = AsyncEvent() + + def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None: + self.connection = connection + self._connection_acquired.set() + + def clear_connection(self) -> None: + self.connection = None + self._connection_acquired = AsyncEvent() + + async def wait_for_connection( + self, timeout: float | None = None + ) -> AsyncConnectionInterface: + if self.connection is None: + await self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection + + def is_queued(self) -> bool: + return self.connection is None class AsyncConnectionPool(AsyncRequestInterface): @@ -22,7 +49,6 @@ def __init__( self, ssl_context: ssl.SSLContext | None = None, proxy: Proxy | None = None, - concurrency_limit: int = 100, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -76,7 +102,6 @@ def __init__( self._max_keepalive_connections = min( self._max_connections, self._max_keepalive_connections ) - self._limits = AsyncSemaphore(bound=concurrency_limit) self._keepalive_expiry = keepalive_expiry self._http1 = http1 @@ -98,7 +123,7 @@ def __init__( # We only mutate the state of the connection pool within an 'optional_thread_lock' # context. This holds a threading lock unless we're running in async mode, # in which case it is a no-op. - # self._optional_thread_lock = AsyncThreadLock() + self._optional_thread_lock = AsyncThreadLock() def create_connection(self, origin: Origin) -> AsyncConnectionInterface: if self._proxy is not None: @@ -171,7 +196,7 @@ def connections(self) -> list[AsyncConnectionInterface]: """ return list(self._connections) - async def iterate_response(self, request: Request) -> typing.AsyncIterator[StartResponse | bytes]: + async def handle_async_request(self, request: Request) -> Response: """ Send an HTTP request, and return an HTTP response. @@ -187,50 +212,145 @@ async def iterate_response(self, request: Request) -> typing.AsyncIterator[Start f"Request URL has an unsupported protocol '{scheme}://'." 
) - # timeouts = request.extensions.get("timeout", {}) - # timeout = timeouts.get("pool", None) - - async with self._limits: - connection = self._get_connection(request) - iterator = connection.iterate_response(request) - try: - response_start = await anext(iterator) - # Return the response status and headers. - yield response_start - # Return the response. - async for event in iterator: - yield event - finally: - await iterator.aclose() - closing = self._close_connections() - for conn in closing: - await conn.aclose() - - def _get_connection(self, request): - origin = request.url.origin - for connection in self._connections: - if connection.can_handle_request(origin) and connection.is_available(): - return connection - - connection = self.create_connection(origin) - self._connections.append(connection) - return connection - - def _close_connections(self): - closing = [conn for conn in self._connections if conn.has_expired()] - self._connections = [ - conn for conn in self._connections - if not (conn.has_expired() or conn.is_closed()) - ] - return closing + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + + with self._optional_thread_lock: + # Add the incoming request to our request queue. + pool_request = AsyncPoolRequest(request) + self._requests.append(pool_request) + + try: + while True: + with self._optional_thread_lock: + # Assign incoming requests to available connections, + # closing or creating new connections as required. + closing = self._assign_requests_to_connections() + await self._close_connections(closing) + + # Wait until this request has an assigned connection. + connection = await pool_request.wait_for_connection(timeout=timeout) + + try: + # Send the request on the assigned connection. + response = await connection.handle_async_request( + pool_request.request + ) + except ConnectionNotAvailable: + # In some cases a connection may initially be available to + # handle a request, but then become unavailable. + # + # In this case we clear the connection and try again. + pool_request.clear_connection() + else: + break # pragma: nocover + + except BaseException as exc: + with self._optional_thread_lock: + # For any exception or cancellation we remove the request from + # the queue, and then re-assign requests to connections. + self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() + + await self._close_connections(closing) + raise exc from None + + # Return the response. Note that in this case we still have to manage + # the point at which the response is closed. + assert isinstance(response.stream, typing.AsyncIterable) + return Response( + status=response.status, + headers=response.headers, + content=PoolByteStream( + stream=response.stream, pool_request=pool_request, pool=self + ), + extensions=response.extensions, + ) + + def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]: + """ + Manage the state of the connection pool, assigning incoming + requests to connections as available. + + Called whenever a new request is added or removed from the pool. + + Any closing connections are returned, allowing the I/O for closing + those connections to be handled seperately. + """ + closing_connections = [] + + # First we handle cleaning up any connections that are closed, + # have expired their keep-alive, or surplus idle connections. 
+ for connection in list(self._connections): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): + # log: "closing expired connection" + self._connections.remove(connection) + closing_connections.append(connection) + elif ( + connection.is_idle() + and len([connection.is_idle() for connection in self._connections]) + > self._max_keepalive_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) + + # Assign queued requests to connections. + queued_requests = [request for request in self._requests if request.is_queued()] + for pool_request in queued_requests: + origin = pool_request.request.url.origin + available_connections = [ + connection + for connection in self._connections + if connection.can_handle_request(origin) and connection.is_available() + ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] + + # There are three cases for how we may be able to handle the request: + # + # 1. There is an existing connection that can handle the request. + # 2. We can create a new connection to handle the request. + # 3. We can close an idle connection and then create a new connection + # to handle the request. + if available_connections: + # log: "reusing existing connection" + connection = available_connections[0] + pool_request.assign_to_connection(connection) + elif len(self._connections) < self._max_connections: + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + elif idle_connections: + # log: "closing idle connection" + connection = idle_connections[0] + self._connections.remove(connection) + closing_connections.append(connection) + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + + return closing_connections + + async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None: + # Close connections which have been removed from the pool. + with AsyncShieldCancellation(): + for connection in closing: + await connection.aclose() async def aclose(self) -> None: # Explicitly close the connection pool. # Clears all existing requests and connections. 
- closing = list(self._connections) - self._connections = [] - for conn in closing: - await conn.aclose() + with self._optional_thread_lock: + closing_connections = list(self._connections) + self._connections = [] + await self._close_connections(closing_connections) async def __aenter__(self) -> AsyncConnectionPool: return self @@ -245,12 +365,56 @@ async def __aexit__( def __repr__(self) -> str: class_name = self.__class__.__name__ - connection_is_idle = [ - connection.is_idle() for connection in self._connections - ] - num_active_connections = connection_is_idle.count(False) - num_idle_connections = connection_is_idle.count(True) + with self._optional_thread_lock: + request_is_queued = [request.is_queued() for request in self._requests] + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + + num_active_requests = request_is_queued.count(False) + num_queued_requests = request_is_queued.count(True) + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) + + requests_info = ( + f"Requests: {num_active_requests} active, {num_queued_requests} queued" + ) connection_info = ( f"Connections: {num_active_connections} active, {num_idle_connections} idle" ) - return f"<{class_name} [{connection_info}]>" + + return f"<{class_name} [{requests_info} | {connection_info}]>" + + +class PoolByteStream: + def __init__( + self, + stream: typing.AsyncIterable[bytes], + pool_request: AsyncPoolRequest, + pool: AsyncConnectionPool, + ) -> None: + self._stream = stream + self._pool_request = pool_request + self._pool = pool + self._closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + try: + async for part in self._stream: + yield part + except BaseException as exc: + await self.aclose() + raise exc from None + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + with AsyncShieldCancellation(): + if hasattr(self._stream, "aclose"): + await self._stream.aclose() + + with self._pool._optional_thread_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + + await self._pool._close_connections(closing) diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index bba95eedd..e6d6d7098 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -17,10 +17,10 @@ WriteError, map_exceptions, ) -from .._models import Origin, Request -from .._synchronization import AsyncSemaphore +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock, AsyncShieldCancellation from .._trace import Trace -from .interfaces import AsyncConnectionInterface, StartResponse +from .interfaces import AsyncConnectionInterface logger = logging.getLogger("httpcore.http11") @@ -55,23 +55,21 @@ def __init__( self._keepalive_expiry: float | None = keepalive_expiry self._expire_at: float | None = None self._state = HTTPConnectionState.NEW - self._request_lock = AsyncSemaphore(bound=1) + self._state_lock = AsyncLock() self._request_count = 0 self._h11_state = h11.Connection( our_role=h11.CLIENT, max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, ) - async def iterate_response( - self, request: Request - ) -> typing.AsyncIterator[StartResponse | bytes]: + async def handle_async_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection " f"to {self._origin}" ) - async 
with self._request_lock: + async with self._state_lock: if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): self._request_count += 1 self._state = HTTPConnectionState.ACTIVE @@ -79,69 +77,63 @@ async def iterate_response( else: raise ConnectionNotAvailable() + try: + kwargs = {"request": request} try: - kwargs = {"request": request} - try: - async with Trace( - "send_request_headers", logger, request, kwargs - ) as trace: - await self._send_request_headers(**kwargs) - async with Trace( - "send_request_body", logger, request, kwargs - ) as trace: - await self._send_request_body(**kwargs) - except WriteError: - # If we get a write error while we're writing the request, - # then we supress this error and move on to attempting to - # read the response. Servers can sometimes close the request - # pre-emptively and then respond with a well formed HTTP - # error response. - pass - async with Trace( - "receive_response_headers", logger, request, kwargs + "send_request_headers", logger, request, kwargs ) as trace: - ( - http_version, - status, - reason_phrase, - headers, - trailing_data, - ) = await self._receive_response_headers(**kwargs) - trace.return_value = ( - http_version, - status, - reason_phrase, - headers, - ) - - network_stream = self._network_stream - - # CONNECT or Upgrade request - if (status == 101) or ( - (request.method == b"CONNECT") and (200 <= status < 300) - ): - network_stream = AsyncHTTP11UpgradeStream( - network_stream, trailing_data - ) - - yield StartResponse( - status=status, - headers=headers, - extensions={ - "http_version": http_version, - "reason_phrase": reason_phrase, - "network_stream": network_stream, - }, + await self._send_request_headers(**kwargs) + async with Trace("send_request_body", logger, request, kwargs) as trace: + await self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + trailing_data, + ) = await self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, ) - async with Trace("receive_response_body", logger, request, kwargs): - async for chunk in self._receive_response_body(**kwargs): - yield chunk - finally: - await self._response_closed() + + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": network_stream, + }, + ) + except BaseException as exc: + with AsyncShieldCancellation(): async with Trace("response_closed", logger, request) as trace: - if self.is_closed(): - await self.aclose() + await self._response_closed() + raise exc # Sending the request... 
@@ -244,17 +236,18 @@ async def _receive_event( return event # type: ignore[return-value] async def _response_closed(self) -> None: - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._state = HTTPConnectionState.IDLE - self._h11_state.start_next_cycle() - if self._keepalive_expiry is not None: - now = time.monotonic() - self._expire_at = now + self._keepalive_expiry - else: - self._state = HTTPConnectionState.CLOSED + async with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + await self.aclose() # Once the connection is no longer required... @@ -328,6 +321,33 @@ async def __aexit__( await self.aclose() +class HTTP11ConnectionByteStream: + def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + kwargs = {"request": self._request} + try: + async with Trace("receive_response_body", logger, self._request, kwargs): + async for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with AsyncShieldCancellation(): + await self.aclose() + raise exc + + async def aclose(self) -> None: + if not self._closed: + self._closed = True + async with Trace("response_closed", logger, self._request): + await self._connection._response_closed() + + class AsyncHTTP11UpgradeStream(AsyncNetworkStream): def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None: self._stream = stream diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 3406da00b..c6434a049 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -21,7 +21,7 @@ from .._models import Origin, Request, Response from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation from .._trace import Trace -from .interfaces import AsyncConnectionInterface, StartResponse +from .interfaces import AsyncConnectionInterface logger = logging.getLogger("httpcore.http2") @@ -60,7 +60,6 @@ def __init__( self._state_lock = AsyncLock() self._read_lock = AsyncLock() self._write_lock = AsyncLock() - self._max_streams_semaphore = AsyncSemaphore(100) self._sent_connection_init = False self._used_all_stream_ids = False self._connection_error = False @@ -81,9 +80,7 @@ def __init__( self._read_exception: Exception | None = None self._write_exception: Exception | None = None - async def iterate_response( - self, request: Request - ) -> typing.AsyncIterator[StartResponse | bytes]: + async def handle_async_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): # This cannot occur in normal operation, since the connection pool # will only send requests on connections that handle them. 
@@ -115,65 +112,76 @@ async def iterate_response( self._sent_connection_init = True - async with self._max_streams_semaphore: - try: - stream_id = self._h2_state.get_next_available_stream_id() - self._events[stream_id] = [] - except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover - self._used_all_stream_ids = True - self._request_count -= 1 - raise ConnectionNotAvailable() + # Initially start with just 1 until the remote server provides + # its max_concurrent_streams value + self._max_streams = 1 - try: - kwargs = {"request": request, "stream_id": stream_id} - async with Trace("send_request_headers", logger, request, kwargs): - await self._send_request_headers(request=request, stream_id=stream_id) - async with Trace("send_request_body", logger, request, kwargs): - await self._send_request_body(request=request, stream_id=stream_id) - async with Trace( - "receive_response_headers", logger, request, kwargs - ) as trace: - status, headers = await self._receive_response( - request=request, stream_id=stream_id - ) - trace.return_value = (status, headers) - - yield StartResponse( - status=status, - headers=headers, - extensions={ - "http_version": b"HTTP/2", - "network_stream": self._network_stream, - "stream_id": stream_id, - }, + local_settings_max_streams = ( + self._h2_state.local_settings.max_concurrent_streams ) - async with Trace("receive_response_body", logger, request, kwargs): - async for chunk in self._receive_response_body( - request=request, stream_id=stream_id - ): - yield chunk - except BaseException as exc: # noqa: PIE786 - if isinstance(exc, h2.exceptions.ProtocolError): - # One case where h2 can raise a protocol error is when a - # closed frame has been seen by the state machine. - # - # This happens when one stream is reading, and encounters - # a GOAWAY event. Other flows of control may then raise - # a protocol error at any point they interact with the 'h2_state'. - # - # In this case we'll have stored the event, and should raise - # it as a RemoteProtocolError. - if self._connection_terminated: # pragma: nocover - raise RemoteProtocolError(self._connection_terminated) - # If h2 raises a protocol error in some other state then we - # must somehow have made a protocol violation. 
- raise LocalProtocolError(exc) # pragma: nocover - raise exc - finally: + self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams) + + for _ in range(local_settings_max_streams - self._max_streams): + await self._max_streams_semaphore.acquire() + + await self._max_streams_semaphore.acquire() + + try: + stream_id = self._h2_state.get_next_available_stream_id() + self._events[stream_id] = [] + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + self._request_count -= 1 + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + async with Trace("send_request_headers", logger, request, kwargs): + await self._send_request_headers(request=request, stream_id=stream_id) + async with Trace("send_request_body", logger, request, kwargs): + await self._send_request_body(request=request, stream_id=stream_id) + async with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + status, headers = await self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={ + "http_version": b"HTTP/2", + "network_stream": self._network_stream, + "stream_id": stream_id, + }, + ) + except BaseException as exc: # noqa: PIE786 + with AsyncShieldCancellation(): kwargs = {"stream_id": stream_id} async with Trace("response_closed", logger, request, kwargs): await self._response_closed(stream_id=stream_id) + if isinstance(exc, h2.exceptions.ProtocolError): + # One case where h2 can raise a protocol error is when a + # closed frame has been seen by the state machine. + # + # This happens when one stream is reading, and encounters + # a GOAWAY event. Other flows of control may then raise + # a protocol error at any point they interact with the 'h2_state'. + # + # In this case we'll have stored the event, and should raise + # it as a RemoteProtocolError. + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) + # If h2 raises a protocol error in some other state then we + # must somehow have made a protocol violation. 
+ raise LocalProtocolError(exc) # pragma: nocover + + raise exc + async def _send_connection_init(self, request: Request) -> None: """ The HTTP/2 connection requires some initial setup before we can start @@ -348,14 +356,14 @@ async def _receive_events( if stream_id is None or not self._events.get(stream_id): events = await self._read_incoming_data(request) for event in events: - # if isinstance(event, h2.events.RemoteSettingsChanged): - # async with Trace( - # "receive_remote_settings", logger, request - # ) as trace: - # await self._receive_remote_settings_change(event) - # trace.return_value = event - - if isinstance( + if isinstance(event, h2.events.RemoteSettingsChanged): + async with Trace( + "receive_remote_settings", logger, request + ) as trace: + await self._receive_remote_settings_change(event) + trace.return_value = event + + elif isinstance( event, ( h2.events.ResponseReceived, @@ -372,24 +380,25 @@ async def _receive_events( await self._write_outgoing_data(request) - # async def _receive_remote_settings_change(self, event: h2.events.Event) -> None: - # max_concurrent_streams = event.changed_settings.get( - # h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS - # ) - # if max_concurrent_streams: - # new_max_streams = min( - # max_concurrent_streams.new_value, - # self._h2_state.local_settings.max_concurrent_streams, - # ) - # if new_max_streams and new_max_streams != self._max_streams: - # while new_max_streams > self._max_streams: - # await self._max_streams_semaphore.release() - # self._max_streams += 1 - # while new_max_streams < self._max_streams: - # await self._max_streams_semaphore.acquire() - # self._max_streams -= 1 + async def _receive_remote_settings_change(self, event: h2.events.Event) -> None: + max_concurrent_streams = event.changed_settings.get( + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS + ) + if max_concurrent_streams: + new_max_streams = min( + max_concurrent_streams.new_value, + self._h2_state.local_settings.max_concurrent_streams, + ) + if new_max_streams and new_max_streams != self._max_streams: + while new_max_streams > self._max_streams: + await self._max_streams_semaphore.release() + self._max_streams += 1 + while new_max_streams < self._max_streams: + await self._max_streams_semaphore.acquire() + self._max_streams -= 1 async def _response_closed(self, stream_id: int) -> None: + await self._max_streams_semaphore.release() del self._events[stream_id] async with self._state_lock: if self._connection_terminated and not self._events: diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index ac2a8e015..cc9d92066 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -17,7 +17,7 @@ enforce_url, ) from .._ssl import default_ssl_context -from .._synchronization import AsyncSemaphore +from .._synchronization import AsyncLock from .._trace import Trace from .connection import AsyncHTTPConnection from .connection_pool import AsyncConnectionPool @@ -259,7 +259,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 - self._connect_lock = AsyncSemaphore(bound=1) + self._connect_lock = AsyncLock() self._connected = False async def handle_async_request(self, request: Request) -> Response: diff --git a/httpcore/_async/interfaces.py b/httpcore/_async/interfaces.py index 9b9000162..361583bed 100644 --- a/httpcore/_async/interfaces.py +++ b/httpcore/_async/interfaces.py @@ -17,33 +17,6 @@ ) -class StartResponse: - def __init__(self, status: int, headers: HeaderTypes, 
extensions: Extensions): - self.status = status - self.headers = headers - self.extensions = extensions - - -class ResponseContext: - def __init__(self, status: int, headers: HeaderTypes, iterator, extensions: Extensions): - self._status = status - self._headers = headers - self._iterator = iterator - self._extensions = extensions - - async def __aenter__(self): - self._response = Response( - status=self._status, - headers=self._headers, - content=self._iterator, - extensions=self._extensions - ) - return self._response - - async def __aexit__(self, *args, **kwargs): - await self._response.aclose() - - class AsyncRequestInterface: async def request( self, @@ -69,15 +42,12 @@ async def request( content=content, extensions=extensions, ) - iterator = self.iterate_response(request) - start_response = await anext(iterator) - content = b"".join([part async for part in iterator]) - return Response( - status=start_response.status, - headers=start_response.headers, - content=content, - extensions=start_response.extensions, - ) + response = await self.handle_async_request(request) + try: + await response.aread() + finally: + await response.aclose() + return response @contextlib.asynccontextmanager async def stream( @@ -88,7 +58,7 @@ async def stream( headers: HeaderTypes = None, content: bytes | typing.AsyncIterator[bytes] | None = None, extensions: Extensions | None = None, - ) -> ResponseContext: + ) -> typing.AsyncIterator[Response]: # Strict type checking on our parameters. method = enforce_bytes(method, name="method") url = enforce_url(url, name="url") @@ -104,24 +74,14 @@ async def stream( content=content, extensions=extensions, ) - iterator = self.iterate_response(request) - start_response = await anext(iterator) - response = Response( - status=start_response.status, - headers=start_response.headers, - content=iterator, - extensions=start_response.extensions, - ) + response = await self.handle_async_request(request) try: yield response finally: await response.aclose() - async def iterate_response( - self, request: Request - ) -> typing.AsyncIterator[StartResponse | bytes]: + async def handle_async_request(self, request: Request) -> Response: raise NotImplementedError() # pragma: nocover - yield b'' class AsyncConnectionInterface(AsyncRequestInterface): diff --git a/httpcore/_models.py b/httpcore/_models.py index 1b1b02b7a..8a65f1334 100644 --- a/httpcore/_models.py +++ b/httpcore/_models.py @@ -397,9 +397,6 @@ def __init__( ) self.extensions = {} if extensions is None else extensions - if isinstance(content, bytes): - self._content = content - self._stream_consumed = False @property diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index b877eaf0d..363f8be81 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -9,12 +9,12 @@ from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream from .._exceptions import ConnectError, ConnectTimeout -from .._models import Origin, Request +from .._models import Origin, Request, Response from .._ssl import default_ssl_context -from .._synchronization import Semaphore +from .._synchronization import Lock from .._trace import Trace from .http11 import HTTP11Connection -from .interfaces import ConnectionInterface, StartResponse +from .interfaces import ConnectionInterface RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. 
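
The interfaces change above replaces the `iterate_response()` generator with `handle_async_request()`: `request()` now reads and closes the response before returning, while `stream()` yields the open response and closes it on exit. A minimal usage sketch of the two entry points (the URL is illustrative and not part of this patch):

```python
import asyncio
import httpcore


async def main() -> None:
    async with httpcore.AsyncConnectionPool() as pool:
        # Buffered: the body is fully read and the connection released
        # before the response is returned.
        response = await pool.request("GET", "https://www.example.com/")
        print(response.status, len(response.content))

        # Streamed: the body is consumed incrementally, and the response
        # is closed when the context manager exits.
        async with pool.stream("GET", "https://www.example.com/") as response:
            async for chunk in response.aiter_stream():
                print(f"received {len(chunk)} bytes")


asyncio.run(main())
```
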
@@ -63,10 +63,10 @@ def __init__( ) self._connection: ConnectionInterface | None = None self._connect_failed: bool = False - self._request_lock = Semaphore(bound=1) + self._request_lock = Lock() self._socket_options = socket_options - def iterate_response(self, request: Request) -> typing.Iterator[StartResponse | bytes]: + def handle_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection to {self._origin}" @@ -100,12 +100,7 @@ def iterate_response(self, request: Request) -> typing.Iterator[StartResponse | self._connect_failed = True raise exc - # iterator = self._connection.iterate_response(request) - iterator = self._connection.iterate_response(request) - start_response = next(iterator) - yield start_response - for body in iterator: - yield body + return self._connection.handle_request(request) def _connect(self, request: Request) -> NetworkStream: timeouts = request.extensions.get("timeout", {}) @@ -179,7 +174,14 @@ def close(self) -> None: def is_available(self) -> bool: if self._connection is None: - return False + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return ( + self._http2 + and (self._origin.scheme == b"https" or not self._http1) + and not self._connect_failed + ) return self._connection.is_available() def has_expired(self) -> bool: diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 63a9799d7..9ccfa53e5 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -2,15 +2,42 @@ import ssl import sys +import types import typing from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend -from .._exceptions import UnsupportedProtocol -from .._models import Origin, Proxy, Request -from .._synchronization import Semaphore +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Proxy, Request, Response +from .._synchronization import Event, ShieldCancellation, ThreadLock from .connection import HTTPConnection -from .interfaces import ConnectionInterface, RequestInterface, StartResponse +from .interfaces import ConnectionInterface, RequestInterface + + +class PoolRequest: + def __init__(self, request: Request) -> None: + self.request = request + self.connection: ConnectionInterface | None = None + self._connection_acquired = Event() + + def assign_to_connection(self, connection: ConnectionInterface | None) -> None: + self.connection = connection + self._connection_acquired.set() + + def clear_connection(self) -> None: + self.connection = None + self._connection_acquired = Event() + + def wait_for_connection( + self, timeout: float | None = None + ) -> ConnectionInterface: + if self.connection is None: + self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection + + def is_queued(self) -> bool: + return self.connection is None class ConnectionPool(RequestInterface): @@ -22,7 +49,6 @@ def __init__( self, ssl_context: ssl.SSLContext | None = None, proxy: Proxy | None = None, - concurrency_limit: int = 100, max_connections: int | None = 10, max_keepalive_connections: int | None = None, keepalive_expiry: float | None = None, @@ -76,7 +102,6 @@ def __init__( self._max_keepalive_connections = min( self._max_connections, 
self._max_keepalive_connections ) - self._limits = Semaphore(bound=concurrency_limit) self._keepalive_expiry = keepalive_expiry self._http1 = http1 @@ -98,7 +123,7 @@ def __init__( # We only mutate the state of the connection pool within an 'optional_thread_lock' # context. This holds a threading lock unless we're running in async mode, # in which case it is a no-op. - # self._optional_thread_lock = ThreadLock() + self._optional_thread_lock = ThreadLock() def create_connection(self, origin: Origin) -> ConnectionInterface: if self._proxy is not None: @@ -171,7 +196,7 @@ def connections(self) -> list[ConnectionInterface]: """ return list(self._connections) - def iterate_response(self, request: Request) -> typing.Iterator[StartResponse | bytes]: + def handle_request(self, request: Request) -> Response: """ Send an HTTP request, and return an HTTP response. @@ -187,50 +212,145 @@ def iterate_response(self, request: Request) -> typing.Iterator[StartResponse | f"Request URL has an unsupported protocol '{scheme}://'." ) - # timeouts = request.extensions.get("timeout", {}) - # timeout = timeouts.get("pool", None) - - with self._limits: - connection = self._get_connection(request) - iterator = connection.iterate_response(request) - try: - response_start = next(iterator) - # Return the response status and headers. - yield response_start - # Return the response. - for event in iterator: - yield event - finally: - iterator.close() - closing = self._close_connections() - for conn in closing: - conn.close() - - def _get_connection(self, request): - origin = request.url.origin - for connection in self._connections: - if connection.can_handle_request(origin) and connection.is_available(): - return connection - - connection = self.create_connection(origin) - self._connections.append(connection) - return connection - - def _close_connections(self): - closing = [conn for conn in self._connections if conn.has_expired()] - self._connections = [ - conn for conn in self._connections - if not (conn.has_expired() or conn.is_closed()) - ] - return closing + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("pool", None) + + with self._optional_thread_lock: + # Add the incoming request to our request queue. + pool_request = PoolRequest(request) + self._requests.append(pool_request) + + try: + while True: + with self._optional_thread_lock: + # Assign incoming requests to available connections, + # closing or creating new connections as required. + closing = self._assign_requests_to_connections() + self._close_connections(closing) + + # Wait until this request has an assigned connection. + connection = pool_request.wait_for_connection(timeout=timeout) + + try: + # Send the request on the assigned connection. + response = connection.handle_request( + pool_request.request + ) + except ConnectionNotAvailable: + # In some cases a connection may initially be available to + # handle a request, but then become unavailable. + # + # In this case we clear the connection and try again. + pool_request.clear_connection() + else: + break # pragma: nocover + + except BaseException as exc: + with self._optional_thread_lock: + # For any exception or cancellation we remove the request from + # the queue, and then re-assign requests to connections. + self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() + + self._close_connections(closing) + raise exc from None + + # Return the response. Note that in this case we still have to manage + # the point at which the response is closed. 
+ assert isinstance(response.stream, typing.Iterable) + return Response( + status=response.status, + headers=response.headers, + content=PoolByteStream( + stream=response.stream, pool_request=pool_request, pool=self + ), + extensions=response.extensions, + ) + + def _assign_requests_to_connections(self) -> list[ConnectionInterface]: + """ + Manage the state of the connection pool, assigning incoming + requests to connections as available. + + Called whenever a new request is added or removed from the pool. + + Any closing connections are returned, allowing the I/O for closing + those connections to be handled seperately. + """ + closing_connections = [] + + # First we handle cleaning up any connections that are closed, + # have expired their keep-alive, or surplus idle connections. + for connection in list(self._connections): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): + # log: "closing expired connection" + self._connections.remove(connection) + closing_connections.append(connection) + elif ( + connection.is_idle() + and len([connection.is_idle() for connection in self._connections]) + > self._max_keepalive_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) + + # Assign queued requests to connections. + queued_requests = [request for request in self._requests if request.is_queued()] + for pool_request in queued_requests: + origin = pool_request.request.url.origin + available_connections = [ + connection + for connection in self._connections + if connection.can_handle_request(origin) and connection.is_available() + ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] + + # There are three cases for how we may be able to handle the request: + # + # 1. There is an existing connection that can handle the request. + # 2. We can create a new connection to handle the request. + # 3. We can close an idle connection and then create a new connection + # to handle the request. + if available_connections: + # log: "reusing existing connection" + connection = available_connections[0] + pool_request.assign_to_connection(connection) + elif len(self._connections) < self._max_connections: + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + elif idle_connections: + # log: "closing idle connection" + connection = idle_connections[0] + self._connections.remove(connection) + closing_connections.append(connection) + # log: "creating new connection" + connection = self.create_connection(origin) + self._connections.append(connection) + pool_request.assign_to_connection(connection) + + return closing_connections + + def _close_connections(self, closing: list[ConnectionInterface]) -> None: + # Close connections which have been removed from the pool. + with ShieldCancellation(): + for connection in closing: + connection.close() def close(self) -> None: # Explicitly close the connection pool. # Clears all existing requests and connections. 
- closing = list(self._connections) - self._connections = [] - for conn in closing: - conn.close() + with self._optional_thread_lock: + closing_connections = list(self._connections) + self._connections = [] + self._close_connections(closing_connections) def __enter__(self) -> ConnectionPool: return self @@ -245,12 +365,56 @@ def __exit__( def __repr__(self) -> str: class_name = self.__class__.__name__ - connection_is_idle = [ - connection.is_idle() for connection in self._connections - ] - num_active_connections = connection_is_idle.count(False) - num_idle_connections = connection_is_idle.count(True) + with self._optional_thread_lock: + request_is_queued = [request.is_queued() for request in self._requests] + connection_is_idle = [ + connection.is_idle() for connection in self._connections + ] + + num_active_requests = request_is_queued.count(False) + num_queued_requests = request_is_queued.count(True) + num_active_connections = connection_is_idle.count(False) + num_idle_connections = connection_is_idle.count(True) + + requests_info = ( + f"Requests: {num_active_requests} active, {num_queued_requests} queued" + ) connection_info = ( f"Connections: {num_active_connections} active, {num_idle_connections} idle" ) - return f"<{class_name} [{connection_info}]>" + + return f"<{class_name} [{requests_info} | {connection_info}]>" + + +class PoolByteStream: + def __init__( + self, + stream: typing.Iterable[bytes], + pool_request: PoolRequest, + pool: ConnectionPool, + ) -> None: + self._stream = stream + self._pool_request = pool_request + self._pool = pool + self._closed = False + + def __iter__(self) -> typing.Iterator[bytes]: + try: + for part in self._stream: + yield part + except BaseException as exc: + self.close() + raise exc from None + + def close(self) -> None: + if not self._closed: + self._closed = True + with ShieldCancellation(): + if hasattr(self._stream, "close"): + self._stream.close() + + with self._pool._optional_thread_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + + self._pool._close_connections(closing) diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index fdf2df2d1..ebd3a9748 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -17,10 +17,10 @@ WriteError, map_exceptions, ) -from .._models import Origin, Request -from .._synchronization import Semaphore +from .._models import Origin, Request, Response +from .._synchronization import Lock, ShieldCancellation from .._trace import Trace -from .interfaces import ConnectionInterface, StartResponse +from .interfaces import ConnectionInterface logger = logging.getLogger("httpcore.http11") @@ -55,23 +55,21 @@ def __init__( self._keepalive_expiry: float | None = keepalive_expiry self._expire_at: float | None = None self._state = HTTPConnectionState.NEW - self._request_lock = Semaphore(bound=1) + self._state_lock = Lock() self._request_count = 0 self._h11_state = h11.Connection( our_role=h11.CLIENT, max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, ) - def iterate_response( - self, request: Request - ) -> typing.Iterator[StartResponse | bytes]: + def handle_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): raise RuntimeError( f"Attempted to send request to {request.url.origin} on connection " f"to {self._origin}" ) - with self._request_lock: + with self._state_lock: if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): self._request_count += 1 self._state = 
HTTPConnectionState.ACTIVE @@ -79,69 +77,63 @@ def iterate_response( else: raise ConnectionNotAvailable() + try: + kwargs = {"request": request} try: - kwargs = {"request": request} - try: - with Trace( - "send_request_headers", logger, request, kwargs - ) as trace: - self._send_request_headers(**kwargs) - with Trace( - "send_request_body", logger, request, kwargs - ) as trace: - self._send_request_body(**kwargs) - except WriteError: - # If we get a write error while we're writing the request, - # then we supress this error and move on to attempting to - # read the response. Servers can sometimes close the request - # pre-emptively and then respond with a well formed HTTP - # error response. - pass - with Trace( - "receive_response_headers", logger, request, kwargs + "send_request_headers", logger, request, kwargs ) as trace: - ( - http_version, - status, - reason_phrase, - headers, - trailing_data, - ) = self._receive_response_headers(**kwargs) - trace.return_value = ( - http_version, - status, - reason_phrase, - headers, - ) - - network_stream = self._network_stream - - # CONNECT or Upgrade request - if (status == 101) or ( - (request.method == b"CONNECT") and (200 <= status < 300) - ): - network_stream = HTTP11UpgradeStream( - network_stream, trailing_data - ) - - yield StartResponse( - status=status, - headers=headers, - extensions={ - "http_version": http_version, - "reason_phrase": reason_phrase, - "network_stream": network_stream, - }, + self._send_request_headers(**kwargs) + with Trace("send_request_body", logger, request, kwargs) as trace: + self._send_request_body(**kwargs) + except WriteError: + # If we get a write error while we're writing the request, + # then we supress this error and move on to attempting to + # read the response. Servers can sometimes close the request + # pre-emptively and then respond with a well formed HTTP + # error response. + pass + + with Trace( + "receive_response_headers", logger, request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + trailing_data, + ) = self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, ) - with Trace("receive_response_body", logger, request, kwargs): - for chunk in self._receive_response_body(**kwargs): - yield chunk - finally: - self._response_closed() + + network_stream = self._network_stream + + # CONNECT or Upgrade request + if (status == 101) or ( + (request.method == b"CONNECT") and (200 <= status < 300) + ): + network_stream = HTTP11UpgradeStream(network_stream, trailing_data) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": network_stream, + }, + ) + except BaseException as exc: + with ShieldCancellation(): with Trace("response_closed", logger, request) as trace: - if self.is_closed(): - self.close() + self._response_closed() + raise exc # Sending the request... 
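
In the rewritten `handle_request()` above, the response body is no longer yielded inline; it is wrapped in `HTTP11ConnectionByteStream` (defined further down), whose `close()` drives `_response_closed()` so the connection is only recycled once the caller has finished with the body. A simplified sketch of that deferred-cleanup pattern, using hypothetical names rather than the library's own classes:

```python
# Illustrative sketch only: names here are hypothetical, not the httpcore
# implementation. The point is that cleanup runs exactly once, whether the
# stream is fully consumed, abandoned, or fails mid-iteration.
import typing


class DeferredCloseStream:
    def __init__(
        self,
        chunks: typing.Iterable[bytes],
        on_close: typing.Callable[[], None],
    ) -> None:
        self._chunks = chunks
        self._on_close = on_close
        self._closed = False

    def __iter__(self) -> typing.Iterator[bytes]:
        try:
            for chunk in self._chunks:
                yield chunk
        except BaseException:
            # An error while streaming still triggers cleanup before propagating.
            self.close()
            raise

    def close(self) -> None:
        # Idempotent: the cleanup callback fires at most once.
        if not self._closed:
            self._closed = True
            self._on_close()
```
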
@@ -244,17 +236,18 @@ def _receive_event( return event # type: ignore[return-value] def _response_closed(self) -> None: - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._state = HTTPConnectionState.IDLE - self._h11_state.start_next_cycle() - if self._keepalive_expiry is not None: - now = time.monotonic() - self._expire_at = now + self._keepalive_expiry - else: - self._state = HTTPConnectionState.CLOSED + with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + self.close() # Once the connection is no longer required... @@ -328,6 +321,33 @@ def __exit__( self.close() +class HTTP11ConnectionByteStream: + def __init__(self, connection: HTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + self._closed = False + + def __iter__(self) -> typing.Iterator[bytes]: + kwargs = {"request": self._request} + try: + with Trace("receive_response_body", logger, self._request, kwargs): + for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + except BaseException as exc: + # If we get an exception while streaming the response, + # we want to close the response (and possibly the connection) + # before raising that exception. + with ShieldCancellation(): + self.close() + raise exc + + def close(self) -> None: + if not self._closed: + self._closed = True + with Trace("response_closed", logger, self._request): + self._connection._response_closed() + + class HTTP11UpgradeStream(NetworkStream): def __init__(self, stream: NetworkStream, leading_data: bytes) -> None: self._stream = stream diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index dea1effe2..ecca88f7d 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -17,7 +17,7 @@ enforce_url, ) from .._ssl import default_ssl_context -from .._synchronization import Semaphore +from .._synchronization import Lock from .._trace import Trace from .connection import HTTPConnection from .connection_pool import ConnectionPool @@ -259,7 +259,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 - self._connect_lock = Semaphore(bound=1) + self._connect_lock = Lock() self._connected = False def handle_request(self, request: Request) -> Response: diff --git a/httpcore/_sync/interfaces.py b/httpcore/_sync/interfaces.py index 77860234b..e673d4cc1 100644 --- a/httpcore/_sync/interfaces.py +++ b/httpcore/_sync/interfaces.py @@ -17,33 +17,6 @@ ) -class StartResponse: - def __init__(self, status: int, headers: HeaderTypes, extensions: Extensions): - self.status = status - self.headers = headers - self.extensions = extensions - - -class ResponseContext: - def __init__(self, status: int, headers: HeaderTypes, iterator, extensions: Extensions): - self._status = status - self._headers = headers - self._iterator = iterator - self._extensions = extensions - - def __enter__(self): - self._response = Response( - status=self._status, - headers=self._headers, - content=self._iterator, - extensions=self._extensions - ) - return self._response - - def __exit__(self, *args, **kwargs): - self._response.close() - - class RequestInterface: def request( self, @@ -69,15 +42,12 @@ def request( content=content, 
extensions=extensions, ) - iterator = self.iterate_response(request) - start_response = next(iterator) - content = b"".join([part for part in iterator]) - return Response( - status=start_response.status, - headers=start_response.headers, - content=content, - extensions=start_response.extensions, - ) + response = self.handle_request(request) + try: + response.read() + finally: + response.close() + return response @contextlib.contextmanager def stream( @@ -88,7 +58,7 @@ def stream( headers: HeaderTypes = None, content: bytes | typing.Iterator[bytes] | None = None, extensions: Extensions | None = None, - ) -> ResponseContext: + ) -> typing.Iterator[Response]: # Strict type checking on our parameters. method = enforce_bytes(method, name="method") url = enforce_url(url, name="url") @@ -104,22 +74,13 @@ def stream( content=content, extensions=extensions, ) - iterator = self.iterate_response(request) - start_response = next(iterator) - response = Response( - status=start_response.status, - headers=start_response.headers, - content=iterator, - extensions=start_response.extensions, - ) + response = self.handle_request(request) try: yield response finally: response.close() - def iterate_response( - self, request: Request - ) -> typing.Iterator[StartResponse | bytes]: + def handle_request(self, request: Request) -> Response: raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 892130638..2ecc9e9c3 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -171,7 +171,7 @@ def setup(self) -> None: initial_value=self._bound, max_value=self._bound ) - async def __aenter__(self) -> None: + async def acquire(self) -> None: if not self._backend: self.setup() @@ -180,12 +180,7 @@ async def __aenter__(self) -> None: elif self._backend == "asyncio": await self._anyio_semaphore.acquire() - async def __aexit__( - self, - exc_type: type[BaseException] | None = None, - exc_value: BaseException | None = None, - traceback: types.TracebackType | None = None, - ) -> None: + async def release(self) -> None: if self._backend == "trio": self._trio_semaphore.release() elif self._backend == "asyncio": @@ -300,15 +295,10 @@ class Semaphore: def __init__(self, bound: int) -> None: self._semaphore = threading.Semaphore(value=bound) - def __enter__(self) -> None: + def acquire(self) -> None: self._semaphore.acquire() - def __exit__( - self, - exc_type: type[BaseException] | None = None, - exc_value: BaseException | None = None, - traceback: types.TracebackType | None = None, - ) -> None: + def release(self) -> None: self._semaphore.release() diff --git a/scripts/unasync.py b/scripts/unasync.py index f724df30a..5a5627d71 100644 --- a/scripts/unasync.py +++ b/scripts/unasync.py @@ -17,7 +17,6 @@ ('aclose', 'close'), ('aiter_stream', 'iter_stream'), ('aread', 'read'), - ('anext', 'next'), ('asynccontextmanager', 'contextmanager'), ('__aenter__', '__enter__'), ('__aexit__', '__exit__'), diff --git a/tests/_async/test_connection.py b/tests/_async/test_connection.py index a31b4f8da..b6ee0c7e3 100644 --- a/tests/_async/test_connection.py +++ b/tests/_async/test_connection.py @@ -61,29 +61,29 @@ async def test_http_connection(): ) -# @pytest.mark.anyio -# async def test_concurrent_requests_not_available_on_http11_connections(): -# """ -# Attempting to issue a request against an already active HTTP/1.1 connection -# will raise a `ConnectionNotAvailable` exception. 
-# """ -# origin = Origin(b"https", b"example.com", 443) -# network_backend = AsyncMockBackend( -# [ -# b"HTTP/1.1 200 OK\r\n", -# b"Content-Type: plain/text\r\n", -# b"Content-Length: 13\r\n", -# b"\r\n", -# b"Hello, world!", -# ] -# ) - -# async with AsyncHTTPConnection( -# origin=origin, network_backend=network_backend, keepalive_expiry=5.0 -# ) as conn: -# async with conn.stream("GET", "https://example.com/"): -# with pytest.raises(ConnectionNotAvailable): -# await conn.request("GET", "https://example.com/") +@pytest.mark.anyio +async def test_concurrent_requests_not_available_on_http11_connections(): + """ + Attempting to issue a request against an already active HTTP/1.1 connection + will raise a `ConnectionNotAvailable` exception. + """ + origin = Origin(b"https", b"example.com", 443) + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend, keepalive_expiry=5.0 + ) as conn: + async with conn.stream("GET", "https://example.com/"): + with pytest.raises(ConnectionNotAvailable): + await conn.request("GET", "https://example.com/") @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") diff --git a/tests/_sync/test_connection.py b/tests/_sync/test_connection.py index 3dc848532..37c82e025 100644 --- a/tests/_sync/test_connection.py +++ b/tests/_sync/test_connection.py @@ -61,29 +61,29 @@ def test_http_connection(): ) -# @pytest.mark.anyio -# def test_concurrent_requests_not_available_on_http11_connections(): -# """ -# Attempting to issue a request against an already active HTTP/1.1 connection -# will raise a `ConnectionNotAvailable` exception. -# """ -# origin = Origin(b"https", b"example.com", 443) -# network_backend = MockBackend( -# [ -# b"HTTP/1.1 200 OK\r\n", -# b"Content-Type: plain/text\r\n", -# b"Content-Length: 13\r\n", -# b"\r\n", -# b"Hello, world!", -# ] -# ) - -# with HTTPConnection( -# origin=origin, network_backend=network_backend, keepalive_expiry=5.0 -# ) as conn: -# with conn.stream("GET", "https://example.com/"): -# with pytest.raises(ConnectionNotAvailable): -# conn.request("GET", "https://example.com/") + +def test_concurrent_requests_not_available_on_http11_connections(): + """ + Attempting to issue a request against an already active HTTP/1.1 connection + will raise a `ConnectionNotAvailable` exception. + """ + origin = Origin(b"https", b"example.com", 443) + network_backend = MockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + with HTTPConnection( + origin=origin, network_backend=network_backend, keepalive_expiry=5.0 + ) as conn: + with conn.stream("GET", "https://example.com/"): + with pytest.raises(ConnectionNotAvailable): + conn.request("GET", "https://example.com/") @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")