diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index f08e9b526bbc..c62a5e98a241 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -7,6 +7,7 @@ #### Breaking Changes #### Bugs Fixed +* Fixed bug where HTTP 403 responses with sub-status 5300 (AAD_REQUEST_NOT_AUTHORIZED) did not trigger a token refresh and retry, causing AAD-authenticated requests to fail permanently after token expiry instead of recovering transparently. See [PR 46167](https://github.com/Azure/azure-sdk-for-python/pull/46167) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py index 83418e1f375d..f622720a69bc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py @@ -12,7 +12,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import HttpResponseError -from .http_constants import HttpHeaders +from .http_constants import HttpHeaders, SubStatusCodes from ._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) @@ -67,6 +67,33 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: continue raise + def send(self, request: PipelineRequest[HTTPRequestType]): # type: ignore[override] + """Authorize request with a bearer token and send it to the next policy. + + If Cosmos DB returns HTTP 403 with sub-status AAD_REQUEST_NOT_AUTHORIZED (5300), the cached + token is cleared and a single retry is performed with a fresh token. This handles the case + where an AAD token has expired and Cosmos DB returns 403 instead of 401. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + response = super().send(request) + if ( + response.http_response.status_code == 403 + and int(response.http_response.headers.get(HttpHeaders.SubStatus, 0)) + == SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED + ): + self._token = None # cached token is invalid + self.on_request(request) + try: + response = self.next.send(request) + except Exception: + self.on_exception(request) + raise + return response + def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py index ea1a86b120a1..b0d6d55e0b18 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py @@ -13,7 +13,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import HttpResponseError -from ..http_constants import HttpHeaders +from ..http_constants import HttpHeaders, SubStatusCodes from .._constants import _Constants as Constants HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest) @@ -68,6 +68,33 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: continue raise + async def send(self, request: PipelineRequest[HTTPRequestType]): # type: ignore[override] + """Authorize request with a bearer token and send it to the next policy. + + If Cosmos DB returns HTTP 403 with sub-status AAD_REQUEST_NOT_AUTHORIZED (5300), the cached + token is cleared and a single retry is performed with a fresh token. This handles the case + where an AAD token has expired and Cosmos DB returns 403 instead of 401. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + response = await super().send(request) + if ( + response.http_response.status_code == 403 + and int(response.http_response.headers.get(HttpHeaders.SubStatus, 0)) + == SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED + ): + self._token = None # cached token is invalid + await self.on_request(request) + try: + response = await self.next.send(request) + except Exception: + self.on_exception(request) + raise + return response + async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. diff --git a/sdk/cosmos/azure-cosmos/tests/test_auth_policy_unit.py b/sdk/cosmos/azure-cosmos/tests/test_auth_policy_unit.py new file mode 100644 index 000000000000..288a778975b1 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_auth_policy_unit.py @@ -0,0 +1,219 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Unit tests for CosmosBearerTokenCredentialPolicy 403/AAD token refresh behavior. + +Uses a realistic azure-core Pipeline with a mock transport that returns proper +requests.Response objects (including the x-ms-substatus header), and verifies +that the Authorization header is correctly set in the requests that reach the transport. +""" + +import time +import unittest +from unittest.mock import Mock + +from requests import Response + +from azure.core.credentials import AccessToken +from azure.core.pipeline import Pipeline +from azure.core.pipeline.transport import HttpTransport, HttpRequest + +from azure.cosmos._auth_policy import CosmosBearerTokenCredentialPolicy +from azure.cosmos.http_constants import HttpHeaders, SubStatusCodes + +COSMOS_ACCOUNT_URL = "https://example.cosmos.azure.com" +ACCOUNT_SCOPE = "https://cosmos.azure.com/.default" +AAD_AUTH_PREFIX = "type=aad&ver=1.0&sig=" + + +def _make_response(status_code, sub_status=None): + """Create a requests.Response with optional x-ms-substatus header.""" + response = Response() + response.status_code = status_code + if sub_status is not None: + response.headers[HttpHeaders.SubStatus] = str(sub_status) + return response + + +def _make_credential(token_str="fake-token"): + """Create a sync credential mock that returns an AccessToken via get_token.""" + credential = Mock(spec_set=["get_token"]) + credential.get_token.return_value = AccessToken(token_str, int(time.time()) + 3600) + return credential + + +class MockTransport(HttpTransport): + """Minimal sync HTTP transport that replays a sequence of canned responses and + records each outgoing request so tests can inspect its headers.""" + + def __init__(self, *responses): + self._responses = list(responses) + self.requests = [] + + def open(self): + pass + + def close(self): + pass + + def __exit__(self, *args): + pass + + def __enter__(self): + return self + + def send(self, request, **kwargs): + self.requests.append(request) + return self._responses.pop(0) + + +class TestCosmosBearerTokenPolicySend(unittest.TestCase): + + def _run(self, credential, *responses): + """Build a Pipeline with the Cosmos bearer policy and run a GET against it. + + Returns (pipeline_response, transport) so callers can inspect both the + final response and the recorded outgoing requests. + """ + transport = MockTransport(*responses) + policy = CosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE) + pipeline = Pipeline(transport=transport, policies=[policy]) + http_response = pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs")) + return http_response, transport + + # ------------------------------------------------------------------ + # Pass-through cases — no retry expected + # ------------------------------------------------------------------ + + def test_200_response_passes_through(self): + """A 200 response is forwarded to the caller with no retry.""" + credential = _make_credential() + _, transport = self._run(credential, _make_response(200)) + + assert transport.requests[0].headers["Authorization"].startswith(AAD_AUTH_PREFIX) + assert len(transport.requests) == 1 + + def test_403_without_substatus_no_retry(self): + """A 403 with no sub-status is not an AAD expiry — no retry should occur.""" + credential = _make_credential() + result, transport = self._run(credential, _make_response(403)) + + assert result.http_response.status_code == 403 + assert len(transport.requests) == 1 + + def test_403_write_forbidden_no_retry(self): + """403/WRITE_FORBIDDEN is a different error — no AAD-triggered retry.""" + credential = _make_credential() + result, transport = self._run( + credential, _make_response(403, sub_status=SubStatusCodes.WRITE_FORBIDDEN) + ) + + assert result.http_response.status_code == 403 + assert len(transport.requests) == 1 + + # ------------------------------------------------------------------ + # 403 / AAD_REQUEST_NOT_AUTHORIZED — retry expected + # ------------------------------------------------------------------ + + def test_403_aad_expired_retries_and_succeeds(self): + """403/AAD_REQUEST_NOT_AUTHORIZED triggers a token refresh and one retry. + + The retry must succeed with the fresh token, and both the initial request + and the retry must carry a properly-formatted Cosmos AAD Authorization header. + """ + credential = _make_credential("fresh-token") + result, transport = self._run( + credential, + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(200), + ) + + assert result.http_response.status_code == 200 + assert len(transport.requests) == 2 + + # Both requests must carry the Cosmos-specific AAD header format + for req in transport.requests: + assert req.headers["Authorization"].startswith(AAD_AUTH_PREFIX), ( + f"Expected Cosmos AAD header format, got: {req.headers.get('Authorization')}" + ) + + def test_403_aad_expired_sends_fresh_token_on_retry(self): + """The retry request must use a freshly-acquired token, not the expired one. + + We give the credential two different tokens: the first simulates an expired + cached token; the second is the fresh one returned after the cache is cleared. + """ + fresh_token = "brand-new-token" + expired_token = "old-expired-token" + + call_count = [0] + tokens = [expired_token, fresh_token] + + credential = Mock(spec_set=["get_token"]) + + def rotating_get_token(*scopes, **kwargs): + token = tokens[min(call_count[0], len(tokens) - 1)] + call_count[0] += 1 + return AccessToken(token, int(time.time()) + 3600) + + credential.get_token.side_effect = rotating_get_token + + transport = MockTransport( + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(200), + ) + policy = CosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE) + pipeline = Pipeline(transport=transport, policies=[policy]) + pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs")) + + assert len(transport.requests) == 2 + retry_auth = transport.requests[1].headers["Authorization"] + assert fresh_token in retry_auth, ( + f"Expected fresh token '{fresh_token}' in retry Authorization header, got: {retry_auth}" + ) + + def test_403_aad_expired_auth_header_cleared_before_retry(self): + """After 403/5300 the policy clears its cached token so the retry gets a new one. + + We force the token cache to contain an expired-looking token and verify + that the Authorization header on the retry differs from the initial request. + """ + credential = _make_credential("fresh-token-after-expiry") + transport = MockTransport( + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(200), + ) + policy = CosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE) + # Inject a "stale" token into the policy cache to simulate an expired token + policy._token = AccessToken("stale-token", int(time.time()) - 60) + + pipeline = Pipeline(transport=transport, policies=[policy]) + pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs")) + + assert len(transport.requests) == 2 + initial_auth = transport.requests[0].headers["Authorization"] + retry_auth = transport.requests[1].headers["Authorization"] + # The stale token must not appear in the retry request + assert "stale-token" not in retry_auth, ( + "Stale token should have been replaced before retry" + ) + # Both headers must still use the Cosmos-specific format + assert initial_auth.startswith(AAD_AUTH_PREFIX) + assert retry_auth.startswith(AAD_AUTH_PREFIX) + + def test_403_aad_retry_still_fails_returns_second_response(self): + """If the retry also returns a non-retryable 403, that response is returned unchanged.""" + credential = _make_credential() + result, transport = self._run( + credential, + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(403, sub_status=SubStatusCodes.WRITE_FORBIDDEN), + ) + + assert result.http_response.status_code == 403 + assert len(transport.requests) == 2 + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_auth_policy_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_auth_policy_unit_async.py new file mode 100644 index 000000000000..ba6e594c63c4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_auth_policy_unit_async.py @@ -0,0 +1,219 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Async unit tests for AsyncCosmosBearerTokenCredentialPolicy 403/AAD token refresh behavior. + +Uses a realistic azure-core AsyncPipeline with an async mock transport that returns proper +requests.Response objects (including the x-ms-substatus header), and verifies that the +Authorization header is correctly set in the requests that reach the transport. +""" + +import time +import unittest +from unittest.mock import Mock, AsyncMock + +from requests import Response + +from azure.core.credentials import AccessToken +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.transport import AsyncHttpTransport, HttpRequest + +from azure.cosmos.aio._auth_policy_async import AsyncCosmosBearerTokenCredentialPolicy +from azure.cosmos.http_constants import HttpHeaders, SubStatusCodes + +COSMOS_ACCOUNT_URL = "https://example.cosmos.azure.com" +ACCOUNT_SCOPE = "https://cosmos.azure.com/.default" +AAD_AUTH_PREFIX = "type=aad&ver=1.0&sig=" + + +def _make_response(status_code, sub_status=None): + """Create a requests.Response with optional x-ms-substatus header.""" + response = Response() + response.status_code = status_code + if sub_status is not None: + response.headers[HttpHeaders.SubStatus] = str(sub_status) + return response + + +def _make_async_credential(token_str="fake-token"): + """Create an async credential mock that returns an AccessToken via get_token.""" + credential = Mock(spec_set=["get_token"]) + credential.get_token = AsyncMock(return_value=AccessToken(token_str, int(time.time()) + 3600)) + return credential + + +class MockAsyncTransport(AsyncHttpTransport): + """Minimal async HTTP transport that replays a sequence of canned responses and + records each outgoing request so tests can inspect its headers.""" + + def __init__(self, *responses): + self._responses = list(responses) + self.requests = [] + + async def open(self): + pass + + async def close(self): + pass + + async def __aexit__(self, *args): + pass + + async def __aenter__(self): + return self + + async def send(self, request, **kwargs): + self.requests.append(request) + return self._responses.pop(0) + + +class TestAsyncCosmosBearerTokenPolicySend(unittest.IsolatedAsyncioTestCase): + + async def _run(self, credential, *responses): + """Build an AsyncPipeline with the Cosmos bearer policy and run a GET against it. + + Returns (pipeline_response, transport) so callers can inspect both the + final response and the recorded outgoing requests. + """ + transport = MockAsyncTransport(*responses) + policy = AsyncCosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE) + pipeline = AsyncPipeline(transport=transport, policies=[policy]) + http_response = await pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs")) + return http_response, transport + + # ------------------------------------------------------------------ + # Pass-through cases — no retry expected + # ------------------------------------------------------------------ + + async def test_200_response_passes_through(self): + """A 200 response is forwarded to the caller with no retry.""" + credential = _make_async_credential() + _, transport = await self._run(credential, _make_response(200)) + + assert transport.requests[0].headers["Authorization"].startswith(AAD_AUTH_PREFIX) + assert len(transport.requests) == 1 + + async def test_403_without_substatus_no_retry(self): + """A 403 with no sub-status is not an AAD expiry — no retry should occur.""" + credential = _make_async_credential() + result, transport = await self._run(credential, _make_response(403)) + + assert result.http_response.status_code == 403 + assert len(transport.requests) == 1 + + async def test_403_write_forbidden_no_retry(self): + """403/WRITE_FORBIDDEN is a different error — no AAD-triggered retry.""" + credential = _make_async_credential() + result, transport = await self._run( + credential, _make_response(403, sub_status=SubStatusCodes.WRITE_FORBIDDEN) + ) + + assert result.http_response.status_code == 403 + assert len(transport.requests) == 1 + + # ------------------------------------------------------------------ + # 403 / AAD_REQUEST_NOT_AUTHORIZED — retry expected + # ------------------------------------------------------------------ + + async def test_403_aad_expired_retries_and_succeeds(self): + """403/AAD_REQUEST_NOT_AUTHORIZED triggers a token refresh and one retry. + + The retry must succeed with the fresh token, and both the initial request + and the retry must carry a properly-formatted Cosmos AAD Authorization header. + """ + credential = _make_async_credential("fresh-token") + result, transport = await self._run( + credential, + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(200), + ) + + assert result.http_response.status_code == 200 + assert len(transport.requests) == 2 + + # Both requests must carry the Cosmos-specific AAD header format + for req in transport.requests: + assert req.headers["Authorization"].startswith(AAD_AUTH_PREFIX), ( + f"Expected Cosmos AAD header format, got: {req.headers.get('Authorization')}" + ) + + async def test_403_aad_expired_sends_fresh_token_on_retry(self): + """The retry request must use a freshly-acquired token, not the expired one. + + We give the credential two different tokens: the first simulates an expired + cached token; the second is the fresh one returned after the cache is cleared. + """ + fresh_token = "brand-new-token" + expired_token = "old-expired-token" + + call_count = [0] + tokens = [expired_token, fresh_token] + + credential = Mock(spec_set=["get_token"]) + + async def rotating_get_token(*scopes, **kwargs): + token = tokens[min(call_count[0], len(tokens) - 1)] + call_count[0] += 1 + return AccessToken(token, int(time.time()) + 3600) + + credential.get_token = rotating_get_token + + transport = MockAsyncTransport( + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(200), + ) + policy = AsyncCosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE) + pipeline = AsyncPipeline(transport=transport, policies=[policy]) + await pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs")) + + assert len(transport.requests) == 2 + retry_auth = transport.requests[1].headers["Authorization"] + assert fresh_token in retry_auth, ( + f"Expected fresh token '{fresh_token}' in retry Authorization header, got: {retry_auth}" + ) + + async def test_403_aad_expired_auth_header_cleared_before_retry(self): + """After 403/5300 the policy clears its cached token so the retry gets a new one. + + We force the token cache to contain an expired-looking token and verify + that the Authorization header on the retry differs from the initial request. + """ + credential = _make_async_credential("fresh-token-after-expiry") + transport = MockAsyncTransport( + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(200), + ) + policy = AsyncCosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE) + # Inject a "stale" token into the policy cache to simulate an expired token + policy._token = AccessToken("stale-token", int(time.time()) - 60) + + pipeline = AsyncPipeline(transport=transport, policies=[policy]) + await pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs")) + + assert len(transport.requests) == 2 + initial_auth = transport.requests[0].headers["Authorization"] + retry_auth = transport.requests[1].headers["Authorization"] + # The stale token must not appear in the retry request + assert "stale-token" not in retry_auth, ( + "Stale token should have been replaced before retry" + ) + # Both headers must still use the Cosmos-specific format + assert initial_auth.startswith(AAD_AUTH_PREFIX) + assert retry_auth.startswith(AAD_AUTH_PREFIX) + + async def test_403_aad_retry_still_fails_returns_second_response(self): + """If the retry also returns a non-retryable 403, that response is returned unchanged.""" + credential = _make_async_credential() + result, transport = await self._run( + credential, + _make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED), + _make_response(403, sub_status=SubStatusCodes.WRITE_FORBIDDEN), + ) + + assert result.http_response.status_code == 403 + assert len(transport.requests) == 2 + + +if __name__ == "__main__": + unittest.main() +