Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Breaking Changes

#### Bugs Fixed
* Fixed bug where HTTP 403 responses with sub-status 5300 (AAD_REQUEST_NOT_AUTHORIZED) did not trigger a token refresh and retry, causing AAD-authenticated requests to fail permanently after token expiry instead of recovering transparently. See [PR 46167](https://github.com/Azure/azure-sdk-for-python/pull/46167)

#### Other Changes

Expand Down
29 changes: 28 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import HttpResponseError

from .http_constants import HttpHeaders
from .http_constants import HttpHeaders, SubStatusCodes
from ._constants import _Constants as Constants

HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
Expand Down Expand Up @@ -67,6 +67,33 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
continue
raise

def send(self, request: PipelineRequest[HTTPRequestType]): # type: ignore[override]
"""Authorize request with a bearer token and send it to the next policy.

If Cosmos DB returns HTTP 403 with sub-status AAD_REQUEST_NOT_AUTHORIZED (5300), the cached
token is cleared and a single retry is performed with a fresh token. This handles the case
where an AAD token has expired and Cosmos DB returns 403 instead of 401.

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
response = super().send(request)
if (
response.http_response.status_code == 403
and int(response.http_response.headers.get(HttpHeaders.SubStatus, 0))
== SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED
):
self._token = None # cached token is invalid
self.on_request(request)
try:
response = self.next.send(request)
except Exception:
self.on_exception(request)
raise
return response

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.

Expand Down
29 changes: 28 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import HttpResponseError

from ..http_constants import HttpHeaders
from ..http_constants import HttpHeaders, SubStatusCodes
from .._constants import _Constants as Constants

HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
Expand Down Expand Up @@ -68,6 +68,33 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
continue
raise

async def send(self, request: PipelineRequest[HTTPRequestType]): # type: ignore[override]
"""Authorize request with a bearer token and send it to the next policy.

If Cosmos DB returns HTTP 403 with sub-status AAD_REQUEST_NOT_AUTHORIZED (5300), the cached
token is cleared and a single retry is performed with a fresh token. This handles the case
where an AAD token has expired and Cosmos DB returns 403 instead of 401.

:param request: The pipeline request object
:type request: ~azure.core.pipeline.PipelineRequest
:return: The pipeline response object
:rtype: ~azure.core.pipeline.PipelineResponse
"""
response = await super().send(request)
if (
response.http_response.status_code == 403
and int(response.http_response.headers.get(HttpHeaders.SubStatus, 0))
== SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED
):
self._token = None # cached token is invalid
await self.on_request(request)
try:
response = await self.next.send(request)
except Exception:
self.on_exception(request)
raise
return response

async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.

Expand Down
219 changes: 219 additions & 0 deletions sdk/cosmos/azure-cosmos/tests/test_auth_policy_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.

"""Unit tests for CosmosBearerTokenCredentialPolicy 403/AAD token refresh behavior.

Uses a realistic azure-core Pipeline with a mock transport that returns proper
requests.Response objects (including the x-ms-substatus header), and verifies
that the Authorization header is correctly set in the requests that reach the transport.
"""

import time
import unittest
from unittest.mock import Mock

from requests import Response

from azure.core.credentials import AccessToken
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpTransport, HttpRequest

from azure.cosmos._auth_policy import CosmosBearerTokenCredentialPolicy
from azure.cosmos.http_constants import HttpHeaders, SubStatusCodes

COSMOS_ACCOUNT_URL = "https://example.cosmos.azure.com"
ACCOUNT_SCOPE = "https://cosmos.azure.com/.default"
AAD_AUTH_PREFIX = "type=aad&ver=1.0&sig="


def _make_response(status_code, sub_status=None):
"""Create a requests.Response with optional x-ms-substatus header."""
response = Response()
response.status_code = status_code
if sub_status is not None:
response.headers[HttpHeaders.SubStatus] = str(sub_status)
return response


def _make_credential(token_str="fake-token"):
"""Create a sync credential mock that returns an AccessToken via get_token."""
credential = Mock(spec_set=["get_token"])
credential.get_token.return_value = AccessToken(token_str, int(time.time()) + 3600)
return credential


class MockTransport(HttpTransport):
"""Minimal sync HTTP transport that replays a sequence of canned responses and
records each outgoing request so tests can inspect its headers."""

def __init__(self, *responses):
self._responses = list(responses)
self.requests = []

def open(self):
pass

def close(self):
pass

def __exit__(self, *args):
pass

def __enter__(self):
return self

def send(self, request, **kwargs):
self.requests.append(request)
return self._responses.pop(0)


class TestCosmosBearerTokenPolicySend(unittest.TestCase):

def _run(self, credential, *responses):
"""Build a Pipeline with the Cosmos bearer policy and run a GET against it.

Returns (pipeline_response, transport) so callers can inspect both the
final response and the recorded outgoing requests.
"""
transport = MockTransport(*responses)
policy = CosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE)
pipeline = Pipeline(transport=transport, policies=[policy])
http_response = pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs"))
return http_response, transport

# ------------------------------------------------------------------
# Pass-through cases — no retry expected
# ------------------------------------------------------------------

def test_200_response_passes_through(self):
"""A 200 response is forwarded to the caller with no retry."""
credential = _make_credential()
_, transport = self._run(credential, _make_response(200))

assert transport.requests[0].headers["Authorization"].startswith(AAD_AUTH_PREFIX)
assert len(transport.requests) == 1

def test_403_without_substatus_no_retry(self):
"""A 403 with no sub-status is not an AAD expiry — no retry should occur."""
credential = _make_credential()
result, transport = self._run(credential, _make_response(403))

assert result.http_response.status_code == 403
assert len(transport.requests) == 1

def test_403_write_forbidden_no_retry(self):
"""403/WRITE_FORBIDDEN is a different error — no AAD-triggered retry."""
credential = _make_credential()
result, transport = self._run(
credential, _make_response(403, sub_status=SubStatusCodes.WRITE_FORBIDDEN)
)

assert result.http_response.status_code == 403
assert len(transport.requests) == 1

# ------------------------------------------------------------------
# 403 / AAD_REQUEST_NOT_AUTHORIZED — retry expected
# ------------------------------------------------------------------

def test_403_aad_expired_retries_and_succeeds(self):
"""403/AAD_REQUEST_NOT_AUTHORIZED triggers a token refresh and one retry.

The retry must succeed with the fresh token, and both the initial request
and the retry must carry a properly-formatted Cosmos AAD Authorization header.
"""
credential = _make_credential("fresh-token")
result, transport = self._run(
credential,
_make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED),
_make_response(200),
)

assert result.http_response.status_code == 200
assert len(transport.requests) == 2

# Both requests must carry the Cosmos-specific AAD header format
for req in transport.requests:
assert req.headers["Authorization"].startswith(AAD_AUTH_PREFIX), (
f"Expected Cosmos AAD header format, got: {req.headers.get('Authorization')}"
)

def test_403_aad_expired_sends_fresh_token_on_retry(self):
"""The retry request must use a freshly-acquired token, not the expired one.

We give the credential two different tokens: the first simulates an expired
cached token; the second is the fresh one returned after the cache is cleared.
"""
fresh_token = "brand-new-token"
expired_token = "old-expired-token"

call_count = [0]
tokens = [expired_token, fresh_token]

credential = Mock(spec_set=["get_token"])

def rotating_get_token(*scopes, **kwargs):
token = tokens[min(call_count[0], len(tokens) - 1)]
call_count[0] += 1
return AccessToken(token, int(time.time()) + 3600)

credential.get_token.side_effect = rotating_get_token

transport = MockTransport(
_make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED),
_make_response(200),
)
policy = CosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE)
pipeline = Pipeline(transport=transport, policies=[policy])
pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs"))

assert len(transport.requests) == 2
retry_auth = transport.requests[1].headers["Authorization"]
assert fresh_token in retry_auth, (
f"Expected fresh token '{fresh_token}' in retry Authorization header, got: {retry_auth}"
)

def test_403_aad_expired_auth_header_cleared_before_retry(self):
"""After 403/5300 the policy clears its cached token so the retry gets a new one.

We force the token cache to contain an expired-looking token and verify
that the Authorization header on the retry differs from the initial request.
"""
credential = _make_credential("fresh-token-after-expiry")
transport = MockTransport(
_make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED),
_make_response(200),
)
policy = CosmosBearerTokenCredentialPolicy(credential, ACCOUNT_SCOPE)
# Inject a "stale" token into the policy cache to simulate an expired token
policy._token = AccessToken("stale-token", int(time.time()) - 60)

pipeline = Pipeline(transport=transport, policies=[policy])
pipeline.run(HttpRequest("GET", f"{COSMOS_ACCOUNT_URL}/dbs"))

assert len(transport.requests) == 2
initial_auth = transport.requests[0].headers["Authorization"]
retry_auth = transport.requests[1].headers["Authorization"]
# The stale token must not appear in the retry request
assert "stale-token" not in retry_auth, (
"Stale token should have been replaced before retry"
)
# Both headers must still use the Cosmos-specific format
assert initial_auth.startswith(AAD_AUTH_PREFIX)
assert retry_auth.startswith(AAD_AUTH_PREFIX)

def test_403_aad_retry_still_fails_returns_second_response(self):
"""If the retry also returns a non-retryable 403, that response is returned unchanged."""
credential = _make_credential()
result, transport = self._run(
credential,
_make_response(403, sub_status=SubStatusCodes.AAD_REQUEST_NOT_AUTHORIZED),
_make_response(403, sub_status=SubStatusCodes.WRITE_FORBIDDEN),
)

assert result.http_response.status_code == 403
assert len(transport.requests) == 2


if __name__ == "__main__":
unittest.main()

Loading