diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 90a29db5c680..57b1a0df00bf 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -7,7 +7,8 @@ #### Breaking Changes #### Bugs Fixed - +* Fixed bug where client provided session token was not respected when client-side session management was disabled. See [PR 42965](https://github.com/Azure/azure-sdk-for-python/pull/42965) + #### Other Changes ### 4.14.0b3 (2025-09-09) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 040b98d476c4..8aabfae1ffe2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -335,8 +335,7 @@ def _is_session_token_request( # Verify that it is not a metadata request, and that it is either a read request, batch request, or an account # configured to use multiple write regions. Batch requests are special-cased because they can contain both read and # write operations, and we want to use session consistency for the read operations. - return (is_session_consistency is True and cosmos_client_connection.session is not None - and not IsMasterResource(request_object.resource_type) + return (is_session_consistency is True and not IsMasterResource(request_object.resource_type) and (documents._OperationType.IsReadOnlyOperation(request_object.operation_type) or request_object.operation_type == "Batch" or cosmos_client_connection._global_endpoint_manager.can_use_multiple_write_locations(request_object))) diff --git a/sdk/cosmos/azure-cosmos/tests/test_session.py b/sdk/cosmos/azure-cosmos/tests/test_session.py index 3a88c5ca5fbe..e4a675236125 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session.py @@ -46,6 +46,73 @@ def setUpClass(cls): cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) cls.created_collection = cls.created_db.get_container_client(cls.TEST_COLLECTION_ID) + def test_manual_session_token_takes_precedence(self): + # Establish an initial session state for the primary client. After this call, self.client has an internal session token. + self.created_collection.create_item( + body={'id': 'precedence_doc_1' + str(uuid.uuid4()), 'pk': 'mypk'} + ) + # Capture the session token from the primary client (Token A) + token_A = self.client.client_connection.last_response_headers.get(HttpHeaders.SessionToken) + self.assertIsNotNone(token_A) + + # Use a separate client to create a second item. This gives us a new, distinct session token from the response. + with cosmos_client.CosmosClient(self.host, self.masterKey) as other_client: + other_collection = other_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(self.TEST_COLLECTION_ID) + item2 = other_collection.create_item( + body={'id': 'precedence_doc_2' + str(uuid.uuid4()), 'pk': 'mypk'} + ) + # Capture the session token from the second client (Token B) + manual_session_token = other_client.client_connection.last_response_headers.get(HttpHeaders.SessionToken) + self.assertIsNotNone(manual_session_token) + + # Assert that the two tokens are different to ensure we are testing a real override scenario. + self.assertNotEqual(token_A, manual_session_token) + + # At this point, self.client's session is at first token, but we are holding second token. We will now manually use second token in a request on self.client. + def manual_token_hook(request): + # Assert that the header contains the manually provided second token not the client's automatic first token. + self.assertIn(HttpHeaders.SessionToken, request.http_request.headers) + self.assertEqual(request.http_request.headers[HttpHeaders.SessionToken], manual_session_token) + + #Read an item using the primary client, but manually providing second token. The hook will verify that second token overrides the client's internal first token. + self.created_collection.read_item( + item=item2['id'], # Reading the item associated with second token + partition_key='mypk', + session_token=manual_session_token, # Manually provide second token + raw_request_hook=manual_token_hook + ) + + def test_manual_session_token_override(self): + # Create an item to get a valid session token from the response + created_document = self.created_collection.create_item( + body={'id': 'doc_for_manual_session' + str(uuid.uuid4()), 'pk': 'mypk'} + ) + session_token = self.client.client_connection.last_response_headers.get(HttpHeaders.SessionToken) + self.assertIsNotNone(session_token) + + # temporarily disable client-side session management to test manual override + original_session = self.client.client_connection.session + self.client.client_connection.session = None + + try: + # Define a hook to inspect the request headers + def manual_token_hook(request): + self.assertIn(HttpHeaders.SessionToken, request.http_request.headers) + self.assertEqual(request.http_request.headers[HttpHeaders.SessionToken], session_token) + + # Read the item, passing the session token manually. + # The hook will verify it's correctly added to the request headers. + self.created_collection.read_item( + item=created_document['id'], + partition_key='mypk', + session_token=session_token, # Manually provide the session token + raw_request_hook=manual_token_hook + ) + finally: + # Restore the original session object to avoid affecting other tests + self.client.client_connection.session = original_session + def test_session_token_sm_for_ops(self): # Session token should not be sent for control plane operations diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_async.py b/sdk/cosmos/azure-cosmos/tests/test_session_async.py index 9947f378daa5..6a991be540ab 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session_async.py @@ -48,6 +48,74 @@ async def asyncSetUp(self): async def asyncTearDown(self): await self.client.close() + async def test_manual_session_token_takes_precedence_async(self): + # Establish an initial session state for the primary async client. + await self.created_container.create_item( + body={'id': 'precedence_doc_1_async' + str(uuid.uuid4()), 'pk': 'mypk'} + ) + # Capture the session token from the primary client (Token A) + token_A = self.client.client_connection.last_response_headers.get(HttpHeaders.SessionToken) + self.assertIsNotNone(token_A) + + # Use a separate async client to create a second item. This gives us a new, distinct session token. + async with CosmosClient(self.host, self.masterKey) as other_client: + other_collection = other_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(self.TEST_COLLECTION_ID) + item2 = await other_collection.create_item( + body={'id': 'precedence_doc_2_async' + str(uuid.uuid4()), 'pk': 'mypk'} + ) + # Capture the session token from the second client (Token B) + manual_session_token = other_client.client_connection.last_response_headers.get(HttpHeaders.SessionToken) + self.assertIsNotNone(manual_session_token) + + # Assert that the two tokens are different to ensure we are testing a real override scenario. + self.assertNotEqual(token_A, manual_session_token) + + # Define a hook to verify the correct token is sent. + def manual_token_hook(request): + # Assert that the header contains the manually provided Token B, not the client's automatic Token A. + self.assertIn(HttpHeaders.SessionToken, request.http_request.headers) + self.assertEqual(request.http_request.headers[HttpHeaders.SessionToken], manual_session_token) + + # Read an item using the primary client, but manually providing Token B. + # The hook will verify that Token B overrides the client's internal Token A. + await self.created_container.read_item( + item=item2['id'], + partition_key='mypk', + session_token=manual_session_token, # Manually provide Token B + raw_request_hook=manual_token_hook + ) + + async def test_manual_session_token_override_async(self): + # Create an item to get a valid session token from the response + created_document = await self.created_container.create_item( + body={'id': 'doc_for_manual_session' + str(uuid.uuid4()), 'pk': 'mypk'} + ) + session_token = self.client.client_connection.last_response_headers.get(HttpHeaders.SessionToken) + self.assertIsNotNone(session_token) + + # temporarily disable client-side session management to test manual override + original_session = self.client.client_connection.session + self.client.client_connection.session = None + + try: + # Define a hook to inspect the request headers + def manual_token_hook(request): + self.assertIn(HttpHeaders.SessionToken, request.http_request.headers) + self.assertEqual(request.http_request.headers[HttpHeaders.SessionToken], session_token) + + # Read the item, passing the session token manually. + # The hook will verify it's correctly added to the request headers. + await self.created_container.read_item( + item=created_document['id'], + partition_key='mypk', + session_token=session_token, # Manually provide the session token + raw_request_hook=manual_token_hook + ) + finally: + # Restore the original session object to avoid affecting other tests + self.client.client_connection.session = original_session + async def test_session_token_swr_for_ops_async(self): # Session token should not be sent for control plane operations test_container = await self.created_db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), raw_response_hook=test_config.no_token_response_hook)