diff --git a/dev_requirements.txt b/dev_requirements.txt new file mode 100644 index 0000000..edf602b --- /dev/null +++ b/dev_requirements.txt @@ -0,0 +1,2 @@ +-e .[ci_tools] +mock;python_version<="2.7" diff --git a/setup.py b/setup.py index e00249b..f67516a 100644 --- a/setup.py +++ b/setup.py @@ -8,9 +8,7 @@ import io from setuptools import setup - -VERSION = "1.2.0" - +VERSION = "1.2.1" CLASSIFIERS = [ 'Development Status :: 3 - Alpha', @@ -29,7 +27,7 @@ DEPENDENCIES = [ 'ConfigArgParse>=0.12.0', 'six>=1.10.0', - 'vcrpy>=1.11.0', + 'vcrpy==3.0.0', ] with io.open('README.rst', 'r', encoding='utf-8') as f: @@ -49,13 +47,27 @@ packages=[ 'azure_devtools', 'azure_devtools.scenario_tests', + 'azure_devtools.perfstress_tests', 'azure_devtools.ci_tools', ], + entry_points={ + 'console_scripts': [ + 'perfstress = azure_devtools.perfstress_tests:run_perfstress_cmd', + 'systemperf = azure_devtools.perfstress_tests:run_system_perfstress_tests_cmd', + ], + }, extras_require={ 'ci_tools':[ "PyGithub>=1.40", # Can Merge PR after 1.36, "requests" and tests after 1.40 "GitPython", "requests>=2.0" + ], + 'systemperf':[ + "aiohttp>=3.0", + "requests>=2.0", + "tornado==6.0.3", + "pycurl==7.43.0.5", + "httpx==0.11.1" ] }, package_dir={'': 'src'}, diff --git a/src/azure_devtools/ci_tools/git_tools.py b/src/azure_devtools/ci_tools/git_tools.py index 8c29dd8..810e464 100644 --- a/src/azure_devtools/ci_tools/git_tools.py +++ b/src/azure_devtools/ci_tools/git_tools.py @@ -88,3 +88,18 @@ def get_files_in_commit(git_folder, commit_id="HEAD"): repo = Repo(str(git_folder)) output = repo.git.diff("--name-only", commit_id+"^", commit_id) return output.splitlines() + +def get_diff_file_list(git_folder): + """List of unstaged files. + """ + repo = Repo(str(git_folder)) + output = repo.git.diff("--name-only") + return output.splitlines() + +def get_add_diff_file_list(git_folder): + """List of new files. 
+ """ + repo = Repo(str(git_folder)) + repo.git.add("sdk") + output = repo.git.diff("HEAD", "--name-only") + return output.splitlines() diff --git a/src/azure_devtools/perfstress_tests/__init__.py b/src/azure_devtools/perfstress_tests/__init__.py new file mode 100644 index 0000000..0716333 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/__init__.py @@ -0,0 +1,35 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import os +import asyncio + +from .perf_stress_runner import PerfStressRunner +from .perf_stress_test import PerfStressTest +from .random_stream import RandomStream, WriteStream, get_random_bytes +from .async_random_stream import AsyncRandomStream + +__all__ = [ + "PerfStressRunner", + "PerfStressTest", + "RandomStream", + "WriteStream", + "AsyncRandomStream", + "get_random_bytes" +] + + +def run_perfstress_cmd(): + main_loop = PerfStressRunner() + loop = asyncio.get_event_loop() + loop.run_until_complete(main_loop.start()) + + +def run_system_perfstress_tests_cmd(): + root_dir = os.path.dirname(os.path.abspath(__file__)) + sys_test_dir = os.path.join(root_dir, 'system_perfstress') + main_loop = PerfStressRunner(test_folder_path=sys_test_dir) + loop = asyncio.get_event_loop() + loop.run_until_complete(main_loop.start()) diff --git a/src/azure_devtools/perfstress_tests/async_random_stream.py b/src/azure_devtools/perfstress_tests/async_random_stream.py new file mode 100644 index 0000000..e483bc0 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/async_random_stream.py @@ -0,0 +1,57 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from io import BytesIO + +from .random_stream import get_random_bytes, _DEFAULT_LENGTH + + +class AsyncRandomStream(BytesIO): + def __init__(self, length, initial_buffer_length=_DEFAULT_LENGTH): + super().__init__() + self._base_data = get_random_bytes(initial_buffer_length) + self._data_length = length + self._base_buffer_length = initial_buffer_length + self._position = 0 + self._remaining = length + self._closed = False + + def reset(self): + self._position = 0 + self._remaining = self._data_length + self._closed = False + + def read(self, size=None): + if self._remaining == 0: + return b"" + + if size is None: + e = self._base_buffer_length + else: + e = size + e = min(e, self._remaining) + if e > self._base_buffer_length: + self._base_data = get_random_bytes(e) + self._base_buffer_length = e + self._remaining = self._remaining - e + self._position += e + return self._base_data[:e] + + def seek(self, index, whence=0): + if whence == 0: + self._position = index + elif whence == 1: + self._position = self._position + index + elif whence == 2: + self._position = self._data_length - 1 + index + + def tell(self): + return self._position + + def remaining(self): + return self._remaining + + def close(self): + self._closed = True diff --git a/src/azure_devtools/perfstress_tests/perf_stress_runner.py b/src/azure_devtools/perfstress_tests/perf_stress_runner.py new file mode 100644 index 0000000..861fb84 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/perf_stress_runner.py @@ -0,0 +1,201 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import argparse +import asyncio +import time +import inspect +import logging +import os +import pkgutil +import sys +import threading + +from .perf_stress_test import PerfStressTest +from .repeated_timer import RepeatedTimer + + +class PerfStressRunner: + def __init__(self, test_folder_path=None): + if test_folder_path is None: + # Use current working directory + test_folder_path = os.getcwd() + + self.logger = logging.getLogger(__name__) + self.logger.setLevel(level=logging.INFO) + handler = logging.StreamHandler() + handler.setLevel(level=logging.INFO) + self.logger.addHandler(handler) + + #NOTE: If you need to support registering multiple test locations, move this into Initialize, call lazily on Run, expose RegisterTestLocation function. + self._discover_tests(test_folder_path) + self._parse_args() + + def _get_completed_operations(self): + return sum(self._completed_operations) + + def _get_operations_per_second(self): + return sum(map( + lambda x: x[0] / x[1] if x[1] else 0, + zip(self._completed_operations, self._last_completion_times))) + + def _parse_args(self): + # First, detect which test we're running. + arg_parser = argparse.ArgumentParser( + description='Python Perf Test Runner', + usage='{} []'.format(__file__)) + + # NOTE: remove this and add another help string to query for available tests + # if/when # of classes become enough that this isn't practical. + arg_parser.add_argument('test', help='Which test to run. Supported tests: {}'.format(" ".join(sorted(self._test_classes.keys())))) + + args = arg_parser.parse_args(sys.argv[1:2]) + try: + self._test_class_to_run = self._test_classes[args.test] + except KeyError as e: + self.logger.error("Invalid test: {}\n Test must be one of: {}\n".format(args.test, " ".join(sorted(self._test_classes.keys())))) + raise + + # Next, parse args for that test. 
We also do global args here too so as not to confuse the initial test parse. + per_test_arg_parser = argparse.ArgumentParser( + description=self._test_class_to_run.__doc__ or args.test, + usage='{} {} []'.format(__file__, args.test)) + + # Global args + per_test_arg_parser.add_argument('-p', '--parallel', nargs='?', type=int, help='Degree of parallelism to run with. Default is 1.', default=1) + per_test_arg_parser.add_argument('-d', '--duration', nargs='?', type=int, help='Duration of the test in seconds. Default is 10.', default=10) + per_test_arg_parser.add_argument('-i', '--iterations', nargs='?', type=int, help='Number of iterations in the main test loop. Default is 1.', default=1) + per_test_arg_parser.add_argument('-w', '--warmup', nargs='?', type=int, help='Duration of warmup in seconds. Default is 5.', default=5) + per_test_arg_parser.add_argument('--no-cleanup', action='store_true', help='Do not run cleanup logic. Default is false.', default=False) + per_test_arg_parser.add_argument('--sync', action='store_true', help='Run tests in sync mode. 
Default is False.', default=False) + + # Per-test args + self._test_class_to_run.add_arguments(per_test_arg_parser) + self.per_test_args = per_test_arg_parser.parse_args(sys.argv[2:]) + + self.logger.info("") + self.logger.info("=== Options ===") + self.logger.info(args) + self.logger.info(self.per_test_args) + self.logger.info("") + + def _discover_tests(self, test_folder_path): + self._test_classes = {} + + # Dynamically enumerate all python modules under the tests path for classes that implement PerfStressTest + for loader, name, _ in pkgutil.walk_packages([test_folder_path]): + try: + module = loader.find_module(name).load_module(name) + except Exception as e: + self.logger.warning("Unable to load module {}: {}".format(name, e)) + continue + for name, value in inspect.getmembers(module): + + if name.startswith('_'): + continue + if inspect.isclass(value) and issubclass(value, PerfStressTest) and value != PerfStressTest: + self.logger.info("Loaded test class: {}".format(name)) + self._test_classes[name] = value + + async def start(self): + self.logger.info("=== Setup ===") + + tests = [] + for _ in range(0, self.per_test_args.parallel): + tests.append(self._test_class_to_run(self.per_test_args)) + + try: + try: + await tests[0].global_setup() + try: + await asyncio.gather(*[test.setup() for test in tests]) + + self.logger.info("") + + if self.per_test_args.warmup > 0: + await self._run_tests(tests, self.per_test_args.warmup, "Warmup") + + for i in range(0, self.per_test_args.iterations): + title = "Test" + if self.per_test_args.iterations > 1: + title += " " + str(i + 1) + await self._run_tests(tests, self.per_test_args.duration, title) + except Exception as e: + print("Exception: " + str(e)) + finally: + if not self.per_test_args.no_cleanup: + self.logger.info("=== Cleanup ===") + await asyncio.gather(*[test.cleanup() for test in tests]) + except Exception as e: + print("Exception: " + str(e)) + finally: + if not self.per_test_args.no_cleanup: + await 
tests[0].global_cleanup() + except Exception as e: + print("Exception: " + str(e)) + finally: + await asyncio.gather(*[test.close() for test in tests]) + + async def _run_tests(self, tests, duration, title): + self._completed_operations = [0] * len(tests) + self._last_completion_times = [0] * len(tests) + self._last_total_operations = -1 + + status_thread = RepeatedTimer(1, self._print_status, title) + + if self.per_test_args.sync: + threads = [] + for id, test in enumerate(tests): + thread = threading.Thread(target=lambda: self._run_sync_loop(test, duration, id)) + threads.append(thread) + thread.start() + for thread in threads: + thread.join() + else: + await asyncio.gather(*[self._run_async_loop(test, duration, id) for id, test in enumerate(tests)]) + + status_thread.stop() + + self.logger.info("") + self.logger.info("=== Results ===") + + total_operations = self._get_completed_operations() + operations_per_second = self._get_operations_per_second() + seconds_per_operation = 1 / operations_per_second + weighted_average_seconds = total_operations / operations_per_second + + self.logger.info("Completed {:,} operations in a weighted-average of {:,.2f}s ({:,.2f} ops/s, {:,.3f} s/op)".format( + total_operations, weighted_average_seconds, operations_per_second, seconds_per_operation)) + self.logger.info("") + + def _run_sync_loop(self, test, duration, id): + start = time.time() + runtime = 0 + while runtime < duration: + test.run_sync() + runtime = time.time() - start + self._completed_operations[id] += 1 + self._last_completion_times[id] = runtime + + async def _run_async_loop(self, test, duration, id): + start = time.time() + runtime = 0 + while runtime < duration: + await test.run_async() + runtime = time.time() - start + self._completed_operations[id] += 1 + self._last_completion_times[id] = runtime + + def _print_status(self, title): + if self._last_total_operations == -1: + self._last_total_operations = 0 + self.logger.info("=== {} 
===\nCurrent\t\tTotal\t\tAverage".format(title)) + + total_operations = self._get_completed_operations() + current_operations = total_operations - self._last_total_operations + average_operations = self._get_operations_per_second() + + self._last_total_operations = total_operations + self.logger.info("{}\t\t{}\t\t{:.2f}".format(current_operations, total_operations, average_operations)) diff --git a/src/azure_devtools/perfstress_tests/perf_stress_test.py b/src/azure_devtools/perfstress_tests/perf_stress_test.py new file mode 100644 index 0000000..c59a797 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/perf_stress_test.py @@ -0,0 +1,63 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import os + + +class PerfStressTest: + '''Base class for implementing a python perf test. + + - run_sync and run_async must be implemented. + - global_setup and global_cleanup are optional and run once, ever, regardless of parallelism. + - setup and cleanup are run once per test instance (where each instance runs in its own thread/process), regardless of #iterations. + - close is run once per test instance, after cleanup and global_cleanup. + - run_sync/run_async are run once per iteration. 
+ ''' + args = {} + + def __init__(self, arguments): + self.args = arguments + + async def global_setup(self): + return + + async def global_cleanup(self): + return + + async def setup(self): + return + + async def cleanup(self): + return + + async def close(self): + return + + def __enter__(self): + return + + def __exit__(self, exc_type, exc_value, traceback): + return + + def run_sync(self): + raise Exception('run_sync must be implemented for {}'.format(self.__class__.__name__)) + + async def run_async(self): + raise Exception('run_async must be implemented for {}'.format(self.__class__.__name__)) + + @staticmethod + def add_arguments(parser): + """ + Override this method to add test-specific argparser args to the class. + These are accessible in __init__() and the self.args property. + """ + return + + @staticmethod + def get_from_env(variable): + value = os.environ.get(variable) + if not value: + raise Exception("Undefined environment variable {}".format(variable)) + return value diff --git a/src/azure_devtools/perfstress_tests/random_stream.py b/src/azure_devtools/perfstress_tests/random_stream.py new file mode 100644 index 0000000..b68be9b --- /dev/null +++ b/src/azure_devtools/perfstress_tests/random_stream.py @@ -0,0 +1,82 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import os + +_DEFAULT_LENGTH = 1024*1024 +_BYTE_BUFFER = [_DEFAULT_LENGTH, os.urandom(_DEFAULT_LENGTH)] + + +def get_random_bytes(buffer_length): + if buffer_length > _BYTE_BUFFER[0]: + _BYTE_BUFFER[0] = buffer_length + _BYTE_BUFFER[1] = os.urandom(buffer_length) + return _BYTE_BUFFER[1][:buffer_length] + + +class RandomStream: + def __init__(self, length, initial_buffer_length=_DEFAULT_LENGTH): + self._base_data = get_random_bytes(initial_buffer_length) + self._data_length = length + self._base_buffer_length = initial_buffer_length + self._position = 0 + self._remaining = length + + def reset(self): + self._position = 0 + self._remaining = self._data_length + + def read(self, size=None): + if self._remaining == 0: + return b"" + + if size is None: + e = self._base_buffer_length + else: + e = size + e = min(e, self._remaining) + if e > self._base_buffer_length: + self._base_data = get_random_bytes(e) + self._base_buffer_length = e + self._remaining = self._remaining - e + self._position += e + return self._base_data[:e] + + def tell(self): + return self._position + + def seek(self, index, whence=0): + if whence == 0: + self._position = index + elif whence == 1: + self._position = self._position + index + elif whence == 2: + self._position = self._data_length - 1 + index + + def remaining(self): + return self._remaining + + +class WriteStream: + + def __init__(self): + self._position = 0 + + def reset(self): + self._position = 0 + + def write(self, content): + length = len(content) + self._position += length + return length + + def seek(self, index): + self._position = index + + def seekable(self): + return True + + def tell(self): + return self._position diff --git a/src/azure_devtools/perfstress_tests/repeated_timer.py b/src/azure_devtools/perfstress_tests/repeated_timer.py new file mode 100644 index 0000000..9022035 --- /dev/null +++ 
b/src/azure_devtools/perfstress_tests/repeated_timer.py @@ -0,0 +1,37 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from threading import Timer + +# Credit to https://stackoverflow.com/questions/3393612/run-certain-code-every-n-seconds +class RepeatedTimer(object): + def __init__(self, interval, function, *args, **kwargs): + self._timer = None + self.interval = interval + self.function = function + self.args = args + self.kwargs = kwargs + self.is_running = False + self.start() + + + def _run(self): + self.is_running = False + self.start() + self.function(*self.args, **self.kwargs) + + + def start(self): + if not self.is_running: + #NOTE: If there is a concern about perf impact of this Timer, we'd need to convert to multiprocess and use IPC. + + self._timer = Timer(self.interval, self._run) + self._timer.start() + self.is_running = True + + + def stop(self): + self._timer.cancel() + self.is_running = False \ No newline at end of file diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/__init__.py b/src/azure_devtools/perfstress_tests/system_perfstress/__init__.py new file mode 100644 index 0000000..34913fb --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/aiohttp_get_test.py b/src/azure_devtools/perfstress_tests/system_perfstress/aiohttp_get_test.py new file mode 100644 index 0000000..4f0d033 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/aiohttp_get_test.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import aiohttp + +from azure_devtools.perfstress_tests import PerfStressTest + + +class AioHttpGetTest(PerfStressTest): + + async def global_setup(self): + type(self).session = aiohttp.ClientSession() + + async def global_cleanup(self): + await type(self).session.close() + + async def run_async(self): + async with type(self).session.get(self.args.url) as response: + await response.text() + + @staticmethod + def add_arguments(parser): + parser.add_argument('-u', '--url', required=True) diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/httpx_get_test.py b/src/azure_devtools/perfstress_tests/system_perfstress/httpx_get_test.py new file mode 100644 index 0000000..08c6214 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/httpx_get_test.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import httpx + +from azure_devtools.perfstress_tests import PerfStressTest + + +class HttpxGetTest(PerfStressTest): + + async def global_setup(self): + type(self).client = httpx.AsyncClient() + + async def global_cleanup(self): + await type(self).client.aclose() + + async def run_async(self): + response = await type(self).client.get(self.args.url) + _ = response.text + + @staticmethod + def add_arguments(parser): + parser.add_argument('-u', '--url', required=True) diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/no_op_test.py b/src/azure_devtools/perfstress_tests/system_perfstress/no_op_test.py new file mode 100644 index 0000000..ff3f696 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/no_op_test.py @@ -0,0 +1,14 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from azure_devtools.perfstress_tests import PerfStressTest + + +class NoOpTest(PerfStressTest): + def run_sync(self): + pass + + async def run_async(self): + pass diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/requests_get_test.py b/src/azure_devtools/perfstress_tests/system_perfstress/requests_get_test.py new file mode 100644 index 0000000..ee04fa1 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/requests_get_test.py @@ -0,0 +1,21 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import requests + +from azure_devtools.perfstress_tests import PerfStressTest + + +class RequestsGetTest(PerfStressTest): + + async def global_setup(self): + type(self).session = requests.Session() + + def run_sync(self): + type(self).session.get(self.args.url).text + + @staticmethod + def add_arguments(parser): + parser.add_argument('-u', '--url', required=True) diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/sleep_test.py b/src/azure_devtools/perfstress_tests/system_perfstress/sleep_test.py new file mode 100644 index 0000000..c4f3786 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/sleep_test.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import math +import time +import asyncio + +from azure_devtools.perfstress_tests import PerfStressTest + + +# Used for verifying the perf framework correctly computes average throughput across parallel tests of different speed +class SleepTest(PerfStressTest): + instance_count = 0 + + def __init__(self, arguments): + type(self).instance_count += 1 + self.seconds_per_operation = math.pow(2, type(self).instance_count) + + def run_sync(self): + time.sleep(self.seconds_per_operation) + + async def run_async(self): + await asyncio.sleep(self.seconds_per_operation) diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/socket_http_get_test.py b/src/azure_devtools/perfstress_tests/system_perfstress/socket_http_get_test.py new file mode 100644 index 0000000..1563c93 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/socket_http_get_test.py @@ -0,0 +1,33 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +import asyncio +from urllib.parse import urlparse + +from azure_devtools.perfstress_tests import PerfStressTest + + +class SocketHttpGetTest(PerfStressTest): + + async def setup(self): + parsed_url = urlparse(self.args.url) + hostname = parsed_url.hostname + port = parsed_url.port + path = parsed_url.path + + message = f'GET {path} HTTP/1.1\r\nHost: {hostname}:{port}\r\n\r\n' + self.message_bytes = message.encode() + self.reader, self.writer = await asyncio.open_connection(parsed_url.hostname, parsed_url.port) + + async def cleanup(self): + self.writer.close() + + async def run_async(self): + self.writer.write(self.message_bytes) + await self.reader.read(200) + + @staticmethod + def add_arguments(parser): + parser.add_argument('-u', '--url', required=True) diff --git a/src/azure_devtools/perfstress_tests/system_perfstress/tornado_get_test.py b/src/azure_devtools/perfstress_tests/system_perfstress/tornado_get_test.py new file mode 100644 index 0000000..adda137 --- /dev/null +++ b/src/azure_devtools/perfstress_tests/system_perfstress/tornado_get_test.py @@ -0,0 +1,22 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +from tornado import httpclient + +from azure_devtools.perfstress_tests import PerfStressTest + + +class TornadoGetTest(PerfStressTest): + + async def global_setup(self): + httpclient.AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient") + type(self).client = httpclient.AsyncHTTPClient() + + async def run_async(self): + await type(self).client.fetch(self.args.url) + + @staticmethod + def add_arguments(parser): + parser.add_argument('-u', '--url', required=True) diff --git a/src/azure_devtools/scenario_tests/__init__.py b/src/azure_devtools/scenario_tests/__init__.py index b90b5aa..5fcd349 100644 --- a/src/azure_devtools/scenario_tests/__init__.py +++ b/src/azure_devtools/scenario_tests/__init__.py @@ -4,24 +4,25 @@ # -------------------------------------------------------------------------------------------- from .base import IntegrationTestBase, ReplayableTest, LiveTest -from .exceptions import AzureTestError +from .exceptions import AzureTestError, AzureNameError, NameInUseError, ReservedResourceNameError from .decorators import live_only, record_only, AllowLargeResponse from .patches import mock_in_unit_test, patch_time_sleep_api, patch_long_run_operation_delay from .preparers import AbstractPreparer, SingleValueReplacer from .recording_processors import ( RecordingProcessor, SubscriptionRecordingProcessor, - LargeRequestBodyProcessor, LargeResponseBodyProcessor, LargeResponseBodyReplacer, + LargeRequestBodyProcessor, LargeResponseBodyProcessor, LargeResponseBodyReplacer, AuthenticationMetadataFilter, OAuthRequestResponsesFilter, DeploymentNameReplacer, GeneralNameReplacer, AccessTokenReplacer, RequestUrlNormalizer, ) from .utilities import create_random_name, get_sha1_hash __all__ = ['IntegrationTestBase', 'ReplayableTest', 'LiveTest', - 'AzureTestError', + 'AzureTestError', 'AzureNameError', 'NameInUseError', 'ReservedResourceNameError', 
'mock_in_unit_test', 'patch_time_sleep_api', 'patch_long_run_operation_delay', 'AbstractPreparer', 'SingleValueReplacer', 'AllowLargeResponse', 'RecordingProcessor', 'SubscriptionRecordingProcessor', 'LargeRequestBodyProcessor', 'LargeResponseBodyProcessor', 'LargeResponseBodyReplacer', - 'OAuthRequestResponsesFilter', 'DeploymentNameReplacer', 'GeneralNameReplacer', + 'AuthenticationMetadataFilter', 'OAuthRequestResponsesFilter', + 'DeploymentNameReplacer', 'GeneralNameReplacer', 'AccessTokenReplacer', 'RequestUrlNormalizer', 'live_only', 'record_only', 'create_random_name', 'get_sha1_hash'] diff --git a/src/azure_devtools/scenario_tests/base.py b/src/azure_devtools/scenario_tests/base.py index 60fe576..e836d47 100644 --- a/src/azure_devtools/scenario_tests/base.py +++ b/src/azure_devtools/scenario_tests/base.py @@ -91,7 +91,8 @@ class ReplayableTest(IntegrationTestBase): # pylint: disable=too-many-instance- def __init__(self, # pylint: disable=too-many-arguments method_name, config_file=None, recording_dir=None, recording_name=None, recording_processors=None, - replay_processors=None, recording_patches=None, replay_patches=None): + replay_processors=None, recording_patches=None, replay_patches=None, match_body=False, + custom_request_matchers=None): super(ReplayableTest, self).__init__(method_name) self.recording_processors = recording_processors or [] @@ -117,6 +118,11 @@ def __init__(self, # pylint: disable=too-many-arguments filter_headers=self.FILTER_HEADERS ) self.vcr.register_matcher('query', self._custom_request_query_matcher) + if match_body: + self.vcr.match_on += ('body',) + for matcher in custom_request_matchers or []: + self.vcr.register_matcher(matcher.__name__, matcher) + self.vcr.match_on += (matcher.__name__,) self.recording_file = os.path.join( recording_dir, @@ -132,6 +138,9 @@ def __init__(self, # pylint: disable=too-many-arguments def setUp(self): super(ReplayableTest, self).setUp() + if self.is_live and 
os.environ.get('AZURE_SKIP_LIVE_RECORDING', '').lower() == 'true': + return + # set up cassette cm = self.vcr.use_cassette(self.recording_file) self.cassette = cm.__enter__() diff --git a/src/azure_devtools/scenario_tests/exceptions.py b/src/azure_devtools/scenario_tests/exceptions.py index bdebae0..41d7b06 100644 --- a/src/azure_devtools/scenario_tests/exceptions.py +++ b/src/azure_devtools/scenario_tests/exceptions.py @@ -8,3 +8,16 @@ class AzureTestError(Exception): def __init__(self, error_message): message = 'An error caused by the Azure test harness failed the test: {}' super(AzureTestError, self).__init__(message.format(error_message)) + +class AzureNameError(Exception): + pass + +class NameInUseError(AzureNameError): + def __init__(self, vault_name): + error_message = "A vault with the name {} already exists".format(vault_name) + super(NameInUseError, self).__init__(error_message) + +class ReservedResourceNameError(AzureNameError): + def __init__(self, rg_name): + error_message = "The resource name {} or a part of the name is trademarked / reserved".format(rg_name) + super(ReservedResourceNameError, self).__init__(error_message) \ No newline at end of file diff --git a/src/azure_devtools/scenario_tests/preparers.py b/src/azure_devtools/scenario_tests/preparers.py index 593e1b4..7e503a3 100644 --- a/src/azure_devtools/scenario_tests/preparers.py +++ b/src/azure_devtools/scenario_tests/preparers.py @@ -5,15 +5,25 @@ import contextlib import functools +import logging +import sys +from collections import namedtuple +from threading import Lock from .base import ReplayableTest from .utilities import create_random_name, is_text_payload, trim_kwargs_from_test_function from .recording_processors import RecordingProcessor +from .exceptions import AzureNameError - +_logger = logging.getLogger(__name__) # Core Utility + class AbstractPreparer(object): + _cache_lock = Lock() + _resource_cache = {} + ResourceCacheEntry = namedtuple('ResourceCacheEntry', 'resource_name 
kwargs preparer') + def __init__(self, name_prefix, name_len, disable_recording=False): self.name_prefix = name_prefix self.name_len = name_len @@ -22,37 +32,140 @@ def __init__(self, name_prefix, name_len, disable_recording=False): self.test_class_instance = None self.live_test = False self.disable_recording = disable_recording + self._cache_key = (self.__class__.__name__,) + self._use_cache = False + self._aggregate_cache_key = None + + def _prepare_create_resource(self, test_class_instance, **kwargs): + self.live_test = not isinstance(test_class_instance, ReplayableTest) + self.test_class_instance = test_class_instance + + # This latter conditional is to triage a specific failure mode: + # If the first cached test run does not have any http traffic, a recording will not have been + # generated, so in_recording will be True even if live_test is false, so a random name would be given. + # In cached mode we need to avoid this because then for tests with recordings, they would not have a moniker. 
+ if (self.live_test or test_class_instance.in_recording) \ + and not (not test_class_instance.is_live and test_class_instance.in_recording and self._use_cache): + resource_name = self.random_name + if not self.live_test and isinstance(self, RecordingProcessor): + test_class_instance.recording_processors.append(self) + else: + resource_name = self.moniker + + _logger.debug("Creating resource %s for %s", resource_name, self.__class__.__name__) + with self.override_disable_recording(): + retries = 4 + for i in range(retries): + try: + parameter_update = self.create_resource( + resource_name, + **kwargs + ) + _logger.debug("Successfully created resource %s", resource_name) + break + except AzureNameError: + if i == retries - 1: + raise + self.resource_random_name = None + resource_name = self.random_name + except Exception as e: + msg = "Preparer failure when creating resource {} for test {}: {}".format( + self.__class__.__name__, + test_class_instance, + e) + while e: + try: + e = e.inner_exception + except AttributeError: + break + try: + msg += "\nDetailed error message: " + str(e.additional_properties['error']['message']) + except (AttributeError, KeyError): + pass + + _logger.error(msg) + raise Exception(msg) + + if parameter_update: + kwargs.update(parameter_update) + + return resource_name, kwargs + def __call__(self, fn): def _preparer_wrapper(test_class_instance, **kwargs): - self.live_test = not isinstance(test_class_instance, ReplayableTest) - self.test_class_instance = test_class_instance - - if self.live_test or test_class_instance.in_recording: - resource_name = self.random_name - if not self.live_test and isinstance(self, RecordingProcessor): - test_class_instance.recording_processors.append(self) + _logger.debug("Entering preparer wrapper for %s and test %s", + self.__class__.__name__, str(test_class_instance)) + + # If a child is cached we must use the same cached resource their equivalent parent did so all the deps line up + child_is_cached = 
getattr(fn, '__use_cache', False) + # Note: If it is ever desired to make caching inferred, remove this if/throw. + # This ensures that a user must _very specifically say they want caching_ on an item and all parents. + if not self._use_cache and child_is_cached: + raise Exception("""Preparer exception for test {}:\n Child preparers are cached, but parent {} is not. +You must specify use_cache=True in the preparer decorator""".format(test_class_instance, self.__class__.__name__)) + self._use_cache |= child_is_cached + _logger.debug("Child cache status for %s: %s", self.__class__.__name__, child_is_cached) + + # We must use a cache_key that includes our parents, so that we get a cached stack + # matching the desired resource stack. (e.g. if parent resource has specific settings) + try: + aggregate_cache_key = (self._cache_key, kwargs['__aggregate_cache_key']) + except KeyError: # If we're at the root of the cache stack, start with our own key. + aggregate_cache_key = self._cache_key + kwargs['__aggregate_cache_key'] = aggregate_cache_key + self._aggregate_cache_key = aggregate_cache_key + _logger.debug("Aggregate cache key: %s", aggregate_cache_key) + + # If cache is enabled, and the cached resource exists, use it, otherwise create and store. 
+ if self._use_cache and aggregate_cache_key in AbstractPreparer._resource_cache: + _logger.debug("Using cached resource for %s", self.__class__.__name__) + with self._cache_lock: + resource_name, kwargs, _ = AbstractPreparer._resource_cache[aggregate_cache_key] else: - resource_name = self.moniker + resource_name, kwargs = self._prepare_create_resource(test_class_instance, **kwargs) - with self.override_disable_recording(): - parameter_update = self.create_resource( + if self._use_cache: + with self._cache_lock: + if aggregate_cache_key not in AbstractPreparer._resource_cache: + _logger.debug("Storing cached resource for %s", self.__class__.__name__) + AbstractPreparer._resource_cache[aggregate_cache_key] = AbstractPreparer.ResourceCacheEntry(resource_name, kwargs, self) + + if test_class_instance.is_live: + test_class_instance.scrubber.register_name_pair( resource_name, - **kwargs + self.moniker ) - if parameter_update: - kwargs.update(parameter_update) - - trim_kwargs_from_test_function(fn, kwargs) + # We shouldn't trim the same kwargs that we use for deletion, + # we may remove some of the variables we needed to do the delete. + trimmed_kwargs = {k:v for k,v in kwargs.items()} + trim_kwargs_from_test_function(fn, trimmed_kwargs) try: - fn(test_class_instance, **kwargs) - finally: - # Russian Doll - the last declared resource to be deleted first. - self.remove_resource_with_record_override(resource_name, **kwargs) - + try: + import asyncio + except ImportError: + fn(test_class_instance, **trimmed_kwargs) + else: + if asyncio.iscoroutinefunction(fn): + loop = asyncio.get_event_loop() + loop.run_until_complete(fn(test_class_instance, **trimmed_kwargs)) + else: + fn(test_class_instance, **trimmed_kwargs) + finally: + # If we use cache we delay deletion for the end. + # This won't guarantee deletion order, but it will guarantee everything delayed + # does get deleted, in the worst case by getting rid of the RG at the top. 
+ if not (self._use_cache or child_is_cached): + # Russian Doll - the last declared resource to be deleted first. + self.remove_resource_with_record_override(resource_name, **kwargs) + + # _logger.debug("Setting up preparer stack for {}".format(self.__class__.__name__)) setattr(_preparer_wrapper, '__is_preparer', True) + # Inform the next step in the chain (our parent) that we're cached. + if self._use_cache or getattr(fn, '__use_cache', False): + setattr(_preparer_wrapper, '__use_cache', True) functools.update_wrapper(_preparer_wrapper, fn) return _preparer_wrapper @@ -83,6 +196,16 @@ def random_name(self): self.resource_random_name = self.create_random_name() return self.resource_random_name + # The only other design idea I see that doesn't require each preparer to be instrumented + # would be to have a decorator at the top that wraps the rest, but the user would have to define + # the "cache key" themselves which seems riskier (As opposed to as below, where it's defined + # locally that sku and location are the parameters that make a resource unique) + # This also would prevent fine-grained caching where leaf resources are still created. + def set_cache(self, enabled, *args): + # can't use *args expansion directly into a tuple, py27 compat. 
+ self._cache_key = tuple([self.__class__.__name__] + list(args)) + self._use_cache = enabled + def create_resource(self, name, **kwargs): # pylint: disable=unused-argument,no-self-use return {} @@ -93,6 +216,16 @@ def remove_resource_with_record_override(self, name, **kwargs): with self.override_disable_recording(): self.remove_resource(name, **kwargs) + @classmethod + def _perform_pending_deletes(cls): + _logger.debug("Perform all delayed resource removal.") + for resource_name, kwargs, preparer in reversed([e for e in cls._resource_cache.values()]): + try: + _logger.debug("Performing delayed delete for: %s %s", preparer, resource_name) + preparer.remove_resource_with_record_override(resource_name, **kwargs) + except Exception as e: #pylint: disable=broad-except + # Intentionally broad exception to attempt to leave as few orphan resources as possible even on error. + _logger.warning("Exception while performing delayed deletes (this can happen): %s", e) class SingleValueReplacer(RecordingProcessor): # pylint: disable=no-member diff --git a/src/azure_devtools/scenario_tests/recording_processors.py b/src/azure_devtools/scenario_tests/recording_processors.py index 376531d..553e195 100644 --- a/src/azure_devtools/scenario_tests/recording_processors.py +++ b/src/azure_devtools/scenario_tests/recording_processors.py @@ -2,8 +2,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +import six -from .utilities import is_text_payload, is_json_payload +from .utilities import is_text_payload, is_json_payload, is_batch_payload class RecordingProcessor(object): @@ -23,7 +24,10 @@ def replace_header_fn(cls, entity, header, replace_fn): # but we don't want to modify the case of original header key. 
for key, values in entity['headers'].items(): if key.lower() == header.lower(): - entity['headers'][key] = [replace_fn(v) for v in values] + if isinstance(values, list): + entity['headers'][key] = [replace_fn(v) for v in values] + else: + entity['headers'][key] = replace_fn(values) class SubscriptionRecordingProcessor(RecordingProcessor): @@ -127,6 +131,15 @@ def process_response(self, response): response['body']['string'] = bytes([0] * length) return response +class AuthenticationMetadataFilter(RecordingProcessor): + """Remove authority and tenant discovery requests and responses from recordings. + MSAL sends these requests to obtain non-secret metadata about the token authority. Recording them is unnecessary + because tests use fake credentials during playback that don't invoke MSAL. + """ + def process_request(self, request): + if "/.well-known/openid-configuration" in request.uri or "/common/discovery/instance" in request.uri: + return None + return request class OAuthRequestResponsesFilter(RecordingProcessor): @@ -135,8 +148,9 @@ class OAuthRequestResponsesFilter(RecordingProcessor): def process_request(self, request): # filter request like: # GET https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/token + # POST https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/v2.0/token import re - if not re.match('https://login.microsoftonline.com/([^/]+)/oauth2/token', request.uri): + if not re.match('https://login.microsoftonline.com/([^/]+)/oauth2(?:/v2.0)?/token', request.uri): return request return None @@ -179,17 +193,27 @@ def process_request(self, request): request.uri = request.uri.replace(old, new) if is_text_payload(request) and request.body: - body = str(request.body) + body = six.ensure_str(request.body) if old in body: request.body = body.replace(old, new) + if request.body and request.uri and is_batch_payload(request): + import re + body = six.ensure_str(request.body) + matched_objects = 
set(re.findall(old, body)) + for matched_object in matched_objects: + request.body = body.replace(matched_object, new) + body = body.replace(matched_object, new) return request def process_response(self, response): for old, new in self.names_name: if is_text_payload(response) and response['body']['string']: - response['body']['string'] = response['body']['string'].replace(old, new) - + try: + response['body']['string'] = response['body']['string'].replace(old, new) + except UnicodeDecodeError: + body = response['body']['string'] + response['body']['string'] = body.decode('utf8', 'backslashreplace').replace(old, new).encode('utf8', 'backslashreplace') self.replace_header(response, 'location', old, new) self.replace_header(response, 'azure-asyncoperation', old, new) diff --git a/src/azure_devtools/scenario_tests/tests/async_tests/test_preparer_async.py b/src/azure_devtools/scenario_tests/tests/async_tests/test_preparer_async.py new file mode 100644 index 0000000..d564bca --- /dev/null +++ b/src/azure_devtools/scenario_tests/tests/async_tests/test_preparer_async.py @@ -0,0 +1,49 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +import asyncio +from azure_devtools.scenario_tests.preparers import AbstractPreparer + +traces = [] + +# Separated into its own file to not disrupt 2.7 code with syntax errors from the use of async. 
+class _TestPreparer(AbstractPreparer): + def __init__(self, name, use_cache=False): + super(_TestPreparer, self).__init__('test', 20) + self._name = name + self.set_cache(use_cache, name) + + def create_resource(self, name, **kwargs): + traces.append('create ' + self._name) + return {} + + def remove_resource(self, name, **kwargs): + traces.append('remove ' + self._name) + + +class _AsyncTestClassSample(unittest.TestCase): + @_TestPreparer('A') + @_TestPreparer('B') + async def example_async_test(self): + traces.append('ran async') + + @_TestPreparer('A') + @_TestPreparer('B') + def example_test(self): + traces.append('ran sync') + + +def test_preparer_async_handling(): + # Mimic a real test runner, for better compat 2.7 / 3.x + # This test won't work for 2.7, however, because it relies on asyncio. + + suite = unittest.TestSuite() + suite.addTest(_AsyncTestClassSample('example_test')) + suite.addTest(_AsyncTestClassSample('example_async_test')) + unittest.TextTestRunner().run(suite) + + assert len(traces) == 10 + assert traces == ['create A', 'create B', 'ran sync', 'remove B', 'remove A', 'create A', 'create B', 'ran async', 'remove B', 'remove A'] \ No newline at end of file diff --git a/src/azure_devtools/scenario_tests/tests/test_preparer_order.py b/src/azure_devtools/scenario_tests/tests/test_preparer_order.py index b9eb5dd..c7d5bf8 100644 --- a/src/azure_devtools/scenario_tests/tests/test_preparer_order.py +++ b/src/azure_devtools/scenario_tests/tests/test_preparer_order.py @@ -10,9 +10,10 @@ class _TestPreparer(AbstractPreparer): - def __init__(self, name): + def __init__(self, name, use_cache=False): super(_TestPreparer, self).__init__('test', 20) self._name = name + self.set_cache(use_cache, name) def create_resource(self, name, **kwargs): traces.append('create ' + self._name) @@ -28,6 +29,33 @@ class _TestClassSample(unittest.TestCase): def example_test(self): pass +class _CachedTestClassSample(unittest.TestCase): + @_TestPreparer('A', True) + 
@_TestPreparer('B', True) + def example_test(self): + pass + + @_TestPreparer('A', True) + @_TestPreparer('C', True) + def example_test_2(self): + pass + + @_TestPreparer('A', True) + @_TestPreparer('C', False) + def example_test_3(self): + pass + + @_TestPreparer('A', True) + @_TestPreparer('C', False) + def fail_test(self): + raise Exception("Intentional failure to test cache.") + + @_TestPreparer('PARENT', True) + @_TestPreparer('A', True) + @_TestPreparer('C', True) + def parent_cache_test(self): + pass + def test_preparer_order(): # Mimic a real test runner, for better compat 2.7 / 3.x @@ -40,3 +68,57 @@ def test_preparer_order(): assert traces[1] == 'create B' assert traces[2] == 'remove B' assert traces[3] == 'remove A' + + + +def test_cached_preparer_order(): + # Mimic a real test runner, for better compat 2.7 / 3.x + suite = unittest.TestSuite() + suite.addTest(_CachedTestClassSample('example_test')) + suite.addTest(_CachedTestClassSample('example_test_2')) + suite.addTest(_CachedTestClassSample('example_test_3')) + unittest.TextTestRunner().run(suite) + + assert len(traces) == 5 + assert traces[0] == 'create A' + assert traces[1] == 'create B' + assert traces[2] == 'create C' + assert traces[3] == 'create C' + assert traces[4] == 'remove C' # One of the C's is cached, one is not. + + # Note: unit test runner doesn't trigger the pytest session fixture that deletes resources when all tests are done. + # let's run that manually now to test it. + AbstractPreparer._perform_pending_deletes() + + assert len(traces) == 8 + # we're technically relying on an implementation detail (for earlier versions of python + # dicts did not guarantee ordering by insertion order, later versions do) + # to order removal by relying on dict ordering. 
+ assert traces[5] == 'remove C' + assert traces[6] == 'remove B' + assert traces[7] == 'remove A' + + +def test_cached_preparer_failure(): + # Mimic a real test runner, for better compat 2.7 / 3.x + suite = unittest.TestSuite() + suite.addTest(_CachedTestClassSample('fail_test')) + suite.addTest(_CachedTestClassSample('example_test')) + suite.addTest(_CachedTestClassSample('example_test_2')) + suite.addTest(_CachedTestClassSample('example_test_3')) + unittest.TextTestRunner().run(suite) + AbstractPreparer._perform_pending_deletes() + # the key here is that the cached A and noncached C is used even though the test failed, and successfully removed later. + assert traces == ['create A', 'create C', 'remove C', 'create B', 'create C', 'create C', 'remove C', 'remove C', 'remove B', 'remove A'] + + +def test_cached_preparer_parent_cache_keying(): + # Mimic a real test runner, for better compat 2.7 / 3.x + suite = unittest.TestSuite() + suite.addTest(_CachedTestClassSample('example_test_2')) + suite.addTest(_CachedTestClassSample('example_test_3')) + suite.addTest(_CachedTestClassSample('parent_cache_test')) + unittest.TextTestRunner().run(suite) + AbstractPreparer._perform_pending_deletes() + # The key here is to observe that changing a parent preparer means the child preparers can't utilize a cache from a cache-stack not including that parent. 
+ assert traces == ['create A', 'create C', 'create C', 'remove C', 'create PARENT', 'create A', 'create C', 'remove C', 'remove A', 'remove PARENT', 'remove C', 'remove A'] diff --git a/src/azure_devtools/scenario_tests/tests/test_recording_processor.py b/src/azure_devtools/scenario_tests/tests/test_recording_processor.py index 4d08f5a..7787369 100644 --- a/src/azure_devtools/scenario_tests/tests/test_recording_processor.py +++ b/src/azure_devtools/scenario_tests/tests/test_recording_processor.py @@ -37,6 +37,25 @@ def test_recording_processor_base_class(self): rp.replace_header_fn(request_sample, 'beta', lambda v: 'customized') self.assertSequenceEqual(request_sample['headers']['beta'], ['customized', 'customized']) + def test_recording_processor_headers_as_string(self): + rp = RecordingProcessor() + response_sample = {'body': 'something', 'headers': {'charlie': 'value_1'}} + + rp.replace_header(response_sample, 'charlie', 'value_1', 'replaced_1') + assert response_sample['headers']['charlie'] == 'replaced_1' + + rp.replace_header(response_sample, 'Charlie', 'value_1', 'replaced_1') # case insensitive + assert response_sample['headers']['charlie'] == 'replaced_1' + + rp.replace_header(response_sample, 'Charlie', 'replaced_1', 'replaced_2') # case insensitive + assert response_sample['headers']['charlie'] == 'replaced_2' + + rp.replace_header(response_sample, 'sigma', 'replaced_2', 'replaced_3') # ignore KeyError + assert response_sample['headers']['charlie'] == 'replaced_2' + + rp.replace_header_fn(response_sample, 'charlie', lambda v: 'customized') + assert response_sample['headers']['charlie'] == 'customized' + def test_access_token_processor(self): replaced_subscription_id = 'test_fake_token' rp = AccessTokenReplacer(replaced_subscription_id) diff --git a/src/azure_devtools/scenario_tests/tests/test_replayable_test.py b/src/azure_devtools/scenario_tests/tests/test_replayable_test.py new file mode 100644 index 0000000..f6b7cf5 --- /dev/null +++ 
b/src/azure_devtools/scenario_tests/tests/test_replayable_test.py @@ -0,0 +1,47 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure_devtools.scenario_tests.base import ReplayableTest + +import pytest + +try: + from unittest import mock +except ImportError: # python < 3.3 + import mock # type: ignore + +VCR = ReplayableTest.__module__ + ".vcr.VCR" + + +def test_default_match_configuration(): + """ReplayableTest should default to VCR's default matching configuration""" + + with mock.patch(VCR) as mock_vcr: + ReplayableTest("__init__") + + assert not any("match_on" in call.kwargs for call in mock_vcr.call_args_list) + + +@pytest.mark.parametrize("opt_in", (True, False, None)) +def test_match_body(opt_in): + """match_body should control opting in to vcr.py's in-box body matching, and default to False""" + + mock_vcr = mock.Mock(match_on=()) + with mock.patch(VCR, lambda *_, **__: mock_vcr): + ReplayableTest("__init__", match_body=opt_in) + + assert ("body" in mock_vcr.match_on) == (opt_in == True) + + +def test_custom_request_matchers(): + """custom request matchers should be registered with vcr.py and added to the default matchers""" + + matcher = mock.Mock(__name__="mock matcher") + + mock_vcr = mock.Mock(match_on=()) + with mock.patch(VCR, lambda *_, **__: mock_vcr): + ReplayableTest("__init__", custom_request_matchers=[matcher]) + + assert mock.call(matcher.__name__, matcher) in mock_vcr.register_matcher.call_args_list + assert matcher.__name__ in mock_vcr.match_on diff --git a/src/azure_devtools/scenario_tests/tests/test_utilities.py b/src/azure_devtools/scenario_tests/tests/test_utilities.py index 8fa6771..898c3d0 100644 --- a/src/azure_devtools/scenario_tests/tests/test_utilities.py +++ b/src/azure_devtools/scenario_tests/tests/test_utilities.py @@ -4,7 +4,10 @@ # 
-------------------------------------------------------------------------------------------- import unittest -import mock +try: + from unittest import mock +except ImportError: + import mock from azure_devtools.scenario_tests.utilities import (create_random_name, get_sha1_hash, is_text_payload, is_json_payload) @@ -17,7 +20,7 @@ def test_create_random_name_default_value(self): self.assertTrue(isinstance(default_generated_name, str)) def test_create_random_name_randomness(self): - self.assertEqual(100, len({create_random_name() for _ in range(100)})) + self.assertEqual(100, len(set([create_random_name() for _ in range(100)]))) def test_create_random_name_customization(self): customized_name = create_random_name(prefix='pauline', length=61) @@ -81,7 +84,7 @@ def test_get_sha1_hash(self): f.write(content) f.seek(0) hash_value = get_sha1_hash(f.name) - self.assertEqual('1a9ea462ce80aac3f1cacbdf59d3a630df01b933593a2c53bccc25ecc2569e31', hash_value) + self.assertEqual('6487bbdbd848686338d729e6076da1a795d1ae747642bf906469c6ccd9e642f9', hash_value) def test_text_payload(self): http_entity = mock.MagicMock() diff --git a/src/azure_devtools/scenario_tests/utilities.py b/src/azure_devtools/scenario_tests/utilities.py index 98be0bc..f517722 100644 --- a/src/azure_devtools/scenario_tests/utilities.py +++ b/src/azure_devtools/scenario_tests/utilities.py @@ -62,6 +62,12 @@ def is_text_payload(entity): return True +def is_batch_payload(entity): + if _get_content_type(entity) == "multipart/mixed" and "&comp=batch" in entity.uri: + return True + return False + + def is_json_payload(entity): return _get_content_type(entity) == 'application/json' @@ -70,7 +76,10 @@ def trim_kwargs_from_test_function(fn, kwargs): # the next function is the actual test function. the kwargs need to be trimmed so # that parameters which are not required will not be passed to it. 
if not is_preparer_func(fn): - args, _, kw, _ = inspect.getargspec(fn) # pylint: disable=deprecated-method + try: + args, _, kw, _, _, _, _ = inspect.getfullargspec(fn) + except AttributeError: + args, _, kw, _ = inspect.getargspec(fn) # pylint: disable=deprecated-method if kw is None: args = set(args) for key in [k for k in kwargs if k not in args]: