diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index c4c2b5a18f44..7adaffb06ba5 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -1,6 +1,10 @@ # Release History -## 1.12.1 (Unreleased) +## 1.13.0 (Unreleased) + +### Features + +- Supported adding custom policies #16519 ## 1.12.0 (2021-03-08) diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index a75271d14b2a..8e6301347457 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -25,11 +25,18 @@ # -------------------------------------------------------------------------- import logging +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable from .configuration import Configuration from .pipeline import Pipeline from .pipeline.transport._base import PipelineClientBase from .pipeline.policies import ( - ContentDecodePolicy, DistributedTracingPolicy, HttpLoggingPolicy, RequestIdPolicy + ContentDecodePolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, + RequestIdPolicy, ) from .pipeline.transport import RequestsTransport @@ -64,6 +71,10 @@ class PipelineClient(PipelineClientBase): :keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used. :keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned. :keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. + :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy + :paramtype per_call_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]] + :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy + :paramtype per_retry_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]] :keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport. :return: A pipeline object. :rtype: ~azure.core.pipeline.Pipeline @@ -102,20 +113,34 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use policies = kwargs.get('policies') if policies is None: # [] is a valid policy list + per_call_policies = kwargs.get('per_call_policies', []) + per_retry_policies = kwargs.get('per_retry_policies', []) policies = [ RequestIdPolicy(**kwargs), config.headers_policy, config.user_agent_policy, config.proxy_policy, - ContentDecodePolicy(**kwargs), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.custom_hook_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - config.http_logging_policy or HttpLoggingPolicy(**kwargs) + ContentDecodePolicy(**kwargs) ] + if isinstance(per_call_policies, Iterable): + for policy in per_call_policies: + policies.append(policy) + else: + policies.append(per_call_policies) + + policies = policies + [config.redirect_policy, + config.retry_policy, + config.authentication_policy, + config.custom_hook_policy] + if isinstance(per_retry_policies, Iterable): + for policy in per_retry_policies: + policies.append(policy) + else: + policies.append(per_retry_policies) + + policies = policies + [config.logging_policy, + DistributedTracingPolicy(**kwargs), + config.http_logging_policy or HttpLoggingPolicy(**kwargs)] if not transport: transport = RequestsTransport(**kwargs) diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index 3c6917a5e401..f91ae4229243 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -25,11 +25,15 @@ # -------------------------------------------------------------------------- import logging +from collections.abc import Iterable from .configuration import Configuration from .pipeline import AsyncPipeline from .pipeline.transport._base import PipelineClientBase from .pipeline.policies import ( - ContentDecodePolicy, DistributedTracingPolicy, HttpLoggingPolicy, RequestIdPolicy + ContentDecodePolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, + RequestIdPolicy, ) try: @@ -62,8 +66,14 @@ class AsyncPipelineClient(PipelineClientBase): :param str base_url: URL for the request. :keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used. :keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned. - :keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. - :keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport. + :keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. + :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy + :paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, + list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] + :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy + :paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, + list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] + :keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for synchronous transport. :return: An async pipeline object. :rtype: ~azure.core.pipeline.AsyncPipeline @@ -101,16 +111,34 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use policies = kwargs.get('policies') if policies is None: # [] is a valid policy list + per_call_policies = kwargs.get('per_call_policies', []) + per_retry_policies = kwargs.get('per_retry_policies', []) policies = [ RequestIdPolicy(**kwargs), config.headers_policy, config.user_agent_policy, config.proxy_policy, - ContentDecodePolicy(**kwargs), + ContentDecodePolicy(**kwargs) + ] + if isinstance(per_call_policies, Iterable): + for policy in per_call_policies: + policies.append(policy) + else: + policies.append(per_call_policies) + + policies = policies + [ config.redirect_policy, config.retry_policy, config.authentication_policy, - config.custom_hook_policy, + config.custom_hook_policy + ] + if isinstance(per_retry_policies, Iterable): + for policy in per_retry_policies: + policies.append(policy) + else: + policies.append(per_retry_policies) + + policies = policies + [ config.logging_policy, DistributedTracingPolicy(**kwargs), config.http_logging_policy or HttpLoggingPolicy(**kwargs) diff --git a/sdk/core/azure-core/azure/core/_version.py b/sdk/core/azure-core/azure/core/_version.py index e89e943b0813..7a5c38e4c326 100644 --- a/sdk/core/azure-core/azure/core/_version.py +++ b/sdk/core/azure-core/azure/core/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.12.1" +VERSION = "1.13.0" diff --git a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py index 66ec79abe307..7fdc48057b92 100644 --- a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py @@ -29,6 +29,7 @@ from azure.core.pipeline.policies import ( SansIOHTTPPolicy, UserAgentPolicy, + AsyncRetryPolicy, AsyncRedirectPolicy, AsyncHTTPPolicy, AsyncRetryPolicy, @@ -219,4 +220,69 @@ def send(*args): policies = [AsyncRetryPolicy(), NaughtyPolicy()] pipeline = AsyncPipeline(policies=policies, transport=None) with pytest.raises(AzureError): - await pipeline.run(HttpRequest('GET', url='https://foo.bar')) \ No newline at end of file + await pipeline.run(HttpRequest('GET', url='https://foo.bar')) + +@pytest.mark.asyncio +async def test_add_custom_policy(): + class BooPolicy(AsyncHTTPPolicy): + def send(*args): + raise AzureError('boo') + + class FooPolicy(AsyncHTTPPolicy): + def send(*args): + raise AzureError('boo') + + config = Configuration() + retry_policy = AsyncRetryPolicy() + config.retry_policy = retry_policy + boo_policy = BooPolicy() + foo_policy = FooPolicy() + client = AsyncPipelineClient(base_url="test", config=config, per_call_policies=boo_policy) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + + client = AsyncPipelineClient(base_url="test", config=config, per_call_policies=[boo_policy]) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + + client = AsyncPipelineClient(base_url="test", config=config, per_retry_policies=boo_policy) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo > pos_retry + + client = AsyncPipelineClient(base_url="test", config=config, per_retry_policies=[boo_policy]) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo > pos_retry + + client = AsyncPipelineClient(base_url="test", config=config, per_call_policies=boo_policy, + per_retry_policies=foo_policy) + policies = client._pipeline._impl_policies + assert boo_policy in policies + assert foo_policy in policies + pos_boo = policies.index(boo_policy) + pos_foo = policies.index(foo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + assert pos_foo > pos_retry + + client = AsyncPipelineClient(base_url="test", config=config, per_call_policies=[boo_policy], + per_retry_policies=[foo_policy]) + policies = client._pipeline._impl_policies + assert boo_policy in policies + assert foo_policy in policies + pos_boo = policies.index(boo_policy) + pos_foo = policies.index(foo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + assert pos_foo > pos_retry diff --git a/sdk/core/azure-core/tests/test_pipeline.py b/sdk/core/azure-core/tests/test_pipeline.py index 43cf9def610b..153775c8d992 100644 --- a/sdk/core/azure-core/tests/test_pipeline.py +++ b/sdk/core/azure-core/tests/test_pipeline.py @@ -51,7 +51,10 @@ SansIOHTTPPolicy, UserAgentPolicy, RedirectPolicy, - HttpLoggingPolicy + RetryPolicy, + HttpLoggingPolicy, + HTTPPolicy, + SansIOHTTPPolicy ) from azure.core.pipeline.transport._base import PipelineClientBase from azure.core.pipeline.transport import ( @@ -332,6 +335,69 @@ def test_repr(self): request = HttpRequest("GET", "hello.com") assert repr(request) == "" + def test_add_custom_policy(self): + class BooPolicy(HTTPPolicy): + def send(*args): + raise AzureError('boo') + + class FooPolicy(HTTPPolicy): + def send(*args): + raise AzureError('boo') + + config = Configuration() + retry_policy = RetryPolicy() + config.retry_policy = retry_policy + boo_policy = BooPolicy() + foo_policy = FooPolicy() + client = PipelineClient(base_url="test", config=config, per_call_policies=boo_policy) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + + client = PipelineClient(base_url="test", config=config, per_call_policies=[boo_policy]) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + + client = PipelineClient(base_url="test", config=config, per_retry_policies=boo_policy) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo > pos_retry + + client = PipelineClient(base_url="test", config=config, per_retry_policies=[boo_policy]) + policies = client._pipeline._impl_policies + assert boo_policy in policies + pos_boo = policies.index(boo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo > pos_retry + + client = PipelineClient(base_url="test", config=config, per_call_policies=boo_policy, per_retry_policies=foo_policy) + policies = client._pipeline._impl_policies + assert boo_policy in policies + assert foo_policy in policies + pos_boo = policies.index(boo_policy) + pos_foo = policies.index(foo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + assert pos_foo > pos_retry + + client = PipelineClient(base_url="test", config=config, per_call_policies=[boo_policy], + per_retry_policies=[foo_policy]) + policies = client._pipeline._impl_policies + assert boo_policy in policies + assert foo_policy in policies + pos_boo = policies.index(boo_policy) + pos_foo = policies.index(foo_policy) + pos_retry = policies.index(retry_policy) + assert pos_boo < pos_retry + assert pos_foo > pos_retry + if __name__ == "__main__": unittest.main()