diff --git a/setup.py b/setup.py index d8d1a32..6a2a2b3 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ DEPENDENCIES = [ 'ConfigArgParse>=0.12.0', 'six>=1.10.0', - 'vcrpy>=1.11.1', + 'vcrpy>=1.11.0', ] with io.open('README.rst', 'r', encoding='utf-8') as f: diff --git a/src/azure_devtools/scenario_tests/recording_processors.py b/src/azure_devtools/scenario_tests/recording_processors.py index d6fc222..e8ab240 100644 --- a/src/azure_devtools/scenario_tests/recording_processors.py +++ b/src/azure_devtools/scenario_tests/recording_processors.py @@ -17,12 +17,11 @@ def replace_header(cls, entity, header, old, new): @classmethod def replace_header_fn(cls, entity, header, replace_fn): - try: - header = header.lower() - values = entity['headers'][header] - entity['headers'][header] = [replace_fn(v) for v in values] - except KeyError: - pass + # Loop over the headers to find the one we want case insensitively, + # but we don't want to modify the case of original header key. + for key, values in entity['headers'].items(): + if key.lower() == header.lower(): + entity['headers'][key] = [replace_fn(v) for v in values] class SubscriptionRecordingProcessor(RecordingProcessor): @@ -31,6 +30,10 @@ def __init__(self, replacement): def process_request(self, request): request.uri = self._replace_subscription_id(request.uri) + + if request.body: + request.body = self._replace_subscription_id(request.body.decode()).encode() + return request def process_response(self, response): @@ -45,14 +48,16 @@ def process_response(self, response): def _replace_subscription_id(self, val): import re # subscription presents in all api call - retval = re.sub('/subscriptions/([^/]+)/', - '/subscriptions/{}/'.format(self._replacement), - val) + retval = re.sub('/(subscriptions)/([^/]+)/', + r'/\1/{}/'.format(self._replacement), + val, + flags=re.IGNORECASE) # subscription is also used in graph call - retval = re.sub('https://graph.windows.net/([^/]+)/', - 'https://graph.windows.net/{}/'.format(self._replacement), - retval) + retval = re.sub('https://(graph.windows.net)/([^/]+)/', + r'https://\1/{}/'.format(self._replacement), + retval, + flags=re.IGNORECASE) return retval diff --git a/src/azure_devtools/scenario_tests/tests/test_recording_processor.py b/src/azure_devtools/scenario_tests/tests/test_recording_processor.py index 54a0498..69b352b 100644 --- a/src/azure_devtools/scenario_tests/tests/test_recording_processor.py +++ b/src/azure_devtools/scenario_tests/tests/test_recording_processor.py @@ -36,6 +36,35 @@ def test_recording_processor_base_class(self): rp.replace_header_fn(request_sample, 'beta', lambda v: 'customized') self.assertSequenceEqual(request_sample['headers']['beta'], ['customized', 'customized']) + def test_access_token_processor(self): + replaced_subscription_id = 'test_fake_token' + rp = AccessTokenReplacer(replaced_subscription_id) + + TOKEN_STR = '{"token_type": "Bearer", "resource": "url", "access_token": "real_token"}' + token_response_sample = {'body': {'string': TOKEN_STR}} + + self.assertEqual(json.loads(rp.process_response(token_response_sample)['body']['string'])['access_token'], + replaced_subscription_id) + + no_token_response_sample = {'body': {'string': '{"location": "westus"}'}} + self.assertDictEqual(rp.process_response(no_token_response_sample), no_token_response_sample) + + @staticmethod + def _mock_subscription_request_body(mock_sub_id): + return json.dumps({ + "location": "westus", + "properties": { + "ipConfigurations": [ + { + "properties": { + "subnet": {"id": "/Subscriptions/{}/resourceGroups/etc".format(mock_sub_id)}, + "name": "azure-sample-ip-config" + } + }, + ] + } + }).encode() + def test_subscription_recording_processor_for_request(self): replaced_subscription_id = str(uuid.uuid4()) rp = SubscriptionRecordingProcessor(replaced_subscription_id) @@ -45,24 +74,15 @@ def test_subscription_recording_processor_for_request(self): 'https://graph.windows.net/{}/applications?api-version=1.6'] for template in uri_templates: + mock_sub_id = str(uuid.uuid4()) mock_request = mock.Mock() - mock_request.uri = template.format(str(uuid.uuid4())) + mock_request.uri = template.format(mock_sub_id) + mock_request.body = self._mock_subscription_request_body(mock_sub_id) rp.process_request(mock_request) self.assertEqual(mock_request.uri, template.format(replaced_subscription_id)) - - def test_access_token_processor(self): - replaced_subscription_id = 'test_fake_token' - rp = AccessTokenReplacer(replaced_subscription_id) - - TOKEN_STR = '{"token_type": "Bearer", "resource": "url", "access_token": "real_token"}' - token_response_sample = {'body': {'string': TOKEN_STR}} - - self.assertEqual(json.loads(rp.process_response(token_response_sample)['body']['string'])['access_token'], - replaced_subscription_id) - - no_token_response_sample = {'body': {'string': '{"location": "westus"}'}} - self.assertDictEqual(rp.process_response(no_token_response_sample), no_token_response_sample) + self.assertEqual(mock_request.body, + self._mock_subscription_request_body(replaced_subscription_id)) def test_subscription_recording_processor_for_response(self): replaced_subscription_id = str(uuid.uuid4()) @@ -70,7 +90,7 @@ def test_subscription_recording_processor_for_response(self): uri_templates = ['https://management.azure.com/subscriptions/{}/providers/Microsoft.ContainerRegistry/' 'checkNameAvailability?api-version=2017-03-01', - 'https://graph.windows.net/{}/applications?api-version=1.6'] + 'https://graph.Windows.net/{}/applications?api-version=1.6'] location_header_template = 'https://graph.windows.net/{}/directoryObjects/' \ 'f604c53a-aa21-44d5-a41f-c1ef0b5304bd/Microsoft.DirectoryServices.Application' @@ -89,8 +109,7 @@ def test_subscription_recording_processor_for_response(self): rp.process_response(mock_response) self.assertEqual(mock_response['body']['string'], template.format(replaced_subscription_id)) - # TODO: Restore after issue https://github.com/Azure/azure-python-devtools/issues/16 is fixed - # self.assertSequenceEqual(mock_response['headers']['Location'], - # [location_header_template.format(replaced_subscription_id)]) + self.assertSequenceEqual(mock_response['headers']['Location'], + [location_header_template.format(replaced_subscription_id)]) self.assertSequenceEqual(mock_response['headers']['azure-asyncoperation'], [asyncoperation_header_template.format(replaced_subscription_id)])