diff --git a/httpx/client.py b/httpx/client.py index b2c2e5eed0..5b02c9727e 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -451,7 +451,7 @@ async def send_handling_redirects( raise RedirectLoop() response = await self.send_handling_auth( - request, auth=auth, timeout=timeout, + request, history, auth=auth, timeout=timeout, ) response.history = list(history) @@ -566,7 +566,11 @@ def redirect_stream( return request.stream async def send_handling_auth( - self, request: Request, auth: Auth, timeout: Timeout, + self, + request: Request, + history: typing.List[Response], + auth: Auth, + timeout: Timeout, ) -> Response: auth_flow = auth(request) request = next(auth_flow) @@ -580,8 +584,10 @@ async def send_handling_auth( await response.aclose() raise exc from None else: + response.history = list(history) + await response.aread() request = next_request - await response.aclose() + history.append(response) async def send_single_request( self, request: Request, timeout: Timeout, diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index d4dd76a905..ea6ff8acac 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -6,6 +6,7 @@ import pytest from httpx import URL, AsyncClient, DigestAuth, ProtocolError, Request, Response +from httpx.auth import Auth, AuthFlow from httpx.config import CertTypes, TimeoutTypes, VerifyTypes from httpx.dispatch.base import Dispatcher @@ -218,6 +219,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() -> assert response.status_code == 200 assert response.json() == {"auth": None} + assert len(response.history) == 0 @pytest.mark.asyncio @@ -233,6 +235,7 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None: assert response.status_code == 200 assert response.json() == {"auth": None} + assert len(response.history) == 0 @pytest.mark.asyncio @@ -245,6 +248,7 @@ async def test_digest_auth_401_response_without_digest_auth_header() -> None: assert response.status_code == 401 assert response.json() == {"auth": None} + assert len(response.history) == 0 @pytest.mark.parametrize( @@ -271,6 +275,8 @@ async def test_digest_auth( response = await client.get(url, auth=auth) assert response.status_code == 200 + assert len(response.history) == 1 + authorization = typing.cast(dict, response.json())["auth"] scheme, _, fields = authorization.partition(" ") assert scheme == "Digest" @@ -299,6 +305,8 @@ async def test_digest_auth_no_specified_qop() -> None: response = await client.get(url, auth=auth) assert response.status_code == 200 + assert len(response.history) == 1 + authorization = typing.cast(dict, response.json())["auth"] scheme, _, fields = authorization.partition(" ") assert scheme == "Digest" @@ -325,7 +333,10 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str) auth = DigestAuth(username="tomchristie", password="password123") client = AsyncClient(dispatch=MockDigestAuthDispatch(qop=qop)) - await client.get(url, auth=auth) + response = await client.get(url, auth=auth) + + assert response.status_code == 200 + assert len(response.history) == 1 @pytest.mark.asyncio @@ -357,6 +368,7 @@ async def test_digest_auth_incorrect_credentials() -> None: response = await client.get(url, auth=auth) assert response.status_code == 401 + assert len(response.history) == 1 @pytest.mark.parametrize( @@ -381,3 +393,52 @@ async def test_digest_auth_raises_protocol_error_on_malformed_header( with pytest.raises(ProtocolError): await client.get(url, auth=auth) + + +@pytest.mark.asyncio +async def test_auth_history() -> None: + """ + Test that intermediate requests sent as part of an authentication flow + are recorded in the response history. + """ + + class RepeatAuth(Auth): + """ + A mock authentication scheme that requires clients to send + the request a fixed number of times, and then send a last request containing + an aggregation of nonces that the server sent in 'WWW-Authenticate' headers + of intermediate responses. + """ + + def __init__(self, repeat: int): + self.repeat = repeat + + def __call__(self, request: Request) -> AuthFlow: + nonces = [] + + for index in range(self.repeat): + request.headers["Authorization"] = f"Repeat {index}" + response = yield request + nonces.append(response.headers["www-authenticate"]) + + key = ".".join(nonces) + request.headers["Authorization"] = f"Repeat {key}" + yield request + + url = "https://example.org/" + auth = RepeatAuth(repeat=2) + client = AsyncClient(dispatch=MockDispatch(auth_header="abc")) + + response = await client.get(url, auth=auth) + assert response.status_code == 200 + assert response.json() == {"auth": "Repeat abc.abc"} + + assert len(response.history) == 2 + resp1, resp2 = response.history + assert resp1.json() == {"auth": "Repeat 0"} + assert resp2.json() == {"auth": "Repeat 1"} + + assert len(resp2.history) == 1 + assert resp2.history == [resp1] + + assert len(resp1.history) == 0