Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Change Log

### Unreleased
Modelerfour version: 4.13.351

**New Features**

- We have added a `--credential-default-policy-type` flag. Its default value is `BearerTokenCredentialPolicy`, but it can also accept
`AzureKeyCredentialPolicy`. The value passed in will be the default authentication policy in the client's config, so users using the
generated library will use that auth policy unless they pass in a separate one through kwargs #686

### 2020-06-08 - 5.1.0-preview.2
Modelerfour version: 4.13.351

Expand Down
40 changes: 36 additions & 4 deletions autorest/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@
from .serializers import JinjaSerializer


_LOGGER = logging.getLogger(__name__)

def _get_credential_default_policy_type_has_async_version(credential_default_policy_type: str) -> bool:
mapping = {
"BearerTokenCredentialPolicy": True,
"AzureKeyCredentialPolicy": False
}
return mapping[credential_default_policy_type]

_LOGGER = logging.getLogger(__name__)
class CodeGenerator(Plugin):
@staticmethod
def remove_cloud_errors(yaml_data: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -138,8 +143,31 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"For example: --credential-scopes=https://cognitiveservices.azure.com/.default"
)

passed_in_credential_default_policy_type = (
self._autorestapi.get_value("credential-default-policy-type") or "BearerTokenCredentialPolicy"
)

# right now, we only allow BearerTokenCredentialPolicy and AzureKeyCredentialPolicy
allowed_policies = ["BearerTokenCredentialPolicy", "AzureKeyCredentialPolicy"]
try:
credential_default_policy_type = [
cp for cp in allowed_policies if cp.lower() == passed_in_credential_default_policy_type.lower()
][0]
except IndexError:
raise ValueError(
"The credential you pass in with --credential-default-policy-type must be either "
"BearerTokenCredentialPolicy or AzureKeyCredentialPolicy"
)

if credential_scopes and credential_default_policy_type != "BearerTokenCredentialPolicy":
_LOGGER.warning(
"You have --credential-default-policy-type not set as BearerTokenCredentialPolicy and a value for "
"--credential-scopes. Since credential scopes are tied to the BearerTokenCredentialPolicy, "
"we will ignore your credential scopes."
)
credential_scopes = []

if not credential_scopes:
elif not credential_scopes and credential_default_policy_type == "BearerTokenCredentialPolicy":
if azure_arm:
credential_scopes = ["https://management.azure.com/.default"]
elif credential:
Expand Down Expand Up @@ -176,7 +204,11 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"package_version": self._autorestapi.get_value("package-version"),
"client_side_validation": self._autorestapi.get_boolean_value("client-side-validation", False),
"tracing": self._autorestapi.get_boolean_value("trace", False),
"multiapi": self._autorestapi.get_boolean_value("multiapi", False)
"multiapi": self._autorestapi.get_boolean_value("multiapi", False),
"credential_default_policy_type": credential_default_policy_type,
"credential_default_policy_type_has_async_version": (
_get_credential_default_policy_type_has_async_version(credential_default_policy_type)
)
}

if options["basic_setup_py"] and not options["package_version"]:
Expand Down
4 changes: 3 additions & 1 deletion autorest/codegen/templates/config.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,7 @@ class {{ code_model.class_name }}Configuration(Configuration):
raise ValueError("You must provide either credential_scopes or authentication_policy as kwargs")
{% endif %}
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.{{ async_prefix }}BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs)
{% set credential_default_policy_type = ("Async" if (async_mode and code_model.options['credential_default_policy_type_has_async_version']) else "") + code_model.options['credential_default_policy_type'] %}
{% set bearer_token_specific_params = "*self.credential_scopes, " %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ bearer_token_specific_params if "BearerTokenCredentialPolicy" in credential_default_policy_type }}**kwargs)
{% endif %}
4 changes: 3 additions & 1 deletion autorest/codegen/templates/metadata.json.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
},
"config": {
"credential": {{ code_model.options['credential'] | tojson }},
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }}
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }},
"credential_default_policy_type": {{ code_model.options['credential_default_policy_type'] | tojson }},
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }}
},
"operation_groups": {
{% for operation_group in code_model.operation_groups %}
Expand Down
4 changes: 3 additions & 1 deletion autorest/multiapi/templates/multiapi_config.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,7 @@ class {{ client_name }}Configuration(Configuration):
raise ValueError("You must provide either credential_scopes or authentication_policy as kwargs")
{% endif %}
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.{{ async_prefix }}BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs)
{% set credential_default_policy_type = ("Async" if (async_mode and config['credential_default_policy_type_has_async_version']) else "") + config['credential_default_policy_type'] %}
{% set bearer_token_specific_params = "*self.credential_scopes, " %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ bearer_token_specific_params if "BearerTokenCredentialPolicy" in credential_default_policy_type }}**kwargs)
{% endif %}
53 changes: 40 additions & 13 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def regen_expected(c, opts, debug):
args.append(f"--override-info.title={opts['override-info.title']}")
if opts.get('override-info.description'):
args.append(f"--override-info.description={opts['override-info.description']}")
if opts.get('credential-default-policy-type'):
args.append(f"--credential-default-policy-type={opts['credential-default-policy-type']}")

cmd_line = '{} {}'.format(_AUTOREST_CMD_LINE, " ".join(args))
print(Fore.YELLOW + f'Queuing up: {cmd_line}')
Expand Down Expand Up @@ -244,6 +246,21 @@ def regenerate_namespace_folders_test(c, debug=False):
}
regen_expected(c, opts, debug)

@task
def regenerate_credential_default_policy(c, debug=False):
default_mapping = {'AcceptanceTests/HeadWithAzureKeyCredentialPolicy': 'head.json'}
opts = {
'output_base_dir': 'test/azure',
'input_base_dir': swagger_dir,
'mappings': default_mapping,
'output_dir': 'Expected',
'azure_arm': True,
'flattening_threshold': '1',
'ns_prefix': True,
'credential-default-policy-type': 'AzureKeyCredentialPolicy'
}
regen_expected(c, opts, debug)

@task
def regenerate(c, swagger_name=None, debug=False):
# regenerate expected code for tests
Expand All @@ -253,6 +270,7 @@ def regenerate(c, swagger_name=None, debug=False):
if not swagger_name:
regenerate_namespace_folders_test(c, debug)
regenerate_multiapi(c, debug)
regenerate_credential_default_policy(c, debug)


@task
Expand Down Expand Up @@ -302,16 +320,25 @@ def _multiapi_command_line(location):
)

@task
def regenerate_multiapi(c, debug=False):
cmds = []
# create basic multiapi client (package-name=multapi)
cmds.append(_multiapi_command_line("test/multiapi/specification/multiapi/README.md"))
# create multiapi client with submodule (package-name=multiapi#submodule)
cmds.append(_multiapi_command_line("test/multiapi/specification/multiapiwithsubmodule/README.md"))
# create multiapi client with no aio folder (package-name=multiapinoasync)
cmds.append(_multiapi_command_line("test/multiapi/specification/multiapinoasync/README.md"))
with Pool() as pool:
result = pool.map(run_autorest, cmds)
success = all(result)
if not success:
raise SystemExit("Autorest generation fails")
def regenerate_multiapi(c, debug=False, swagger_name="test"):
# being hacky: making default swagger_name 'test', since it appears in each spec name
available_specifications = [
# create basic multiapi client (package-name=multapi)
"test/multiapi/specification/multiapi/README.md",
# create multiapi client with submodule (package-name=multiapi#submodule)
"test/multiapi/specification/multiapiwithsubmodule/README.md",
# create multiapi client with no aio folder (package-name=multiapinoasync)
"test/multiapi/specification/multiapinoasync/README.md",
# create multiapi client with AzureKeyCredentialPolicy
"test/multiapi/specification/multiapicredentialdefaultpolicy/README.md"
]

cmds = [_multiapi_command_line(spec) for spec in available_specifications if swagger_name.lower() in spec]

if len(cmds) == 1:
success = run_autorest(cmds[0], debug=debug)
else:
# Execute actual taks in parallel
with Pool() as pool:
result = pool.map(run_autorest, cmds)
success = all(result)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._auto_rest_head_test_service import AutoRestHeadTestService
from ._version import VERSION

__version__ = VERSION
__all__ = ['AutoRestHeadTestService']

try:
from ._patch import patch_sdk # type: ignore
patch_sdk()
except ImportError:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING

from azure.mgmt.core import ARMPipelineClient
from msrest import Deserializer, Serializer

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Dict, Optional

from azure.core.credentials import TokenCredential

from ._configuration import AutoRestHeadTestServiceConfiguration
from .operations import HttpSuccessOperations


class AutoRestHeadTestService(object):
"""Test Infrastructure for AutoRest.

:ivar http_success: HttpSuccessOperations operations
:vartype http_success: headwithazurekeycredentialpolicy.operations.HttpSuccessOperations
:param credential: Credential needed for the client to connect to Azure.
:type credential: ~azure.core.credentials.TokenCredential
:param str base_url: Service URL
:keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present.
"""

def __init__(
self,
credential, # type: "TokenCredential"
base_url=None, # type: Optional[str]
**kwargs # type: Any
):
# type: (...) -> None
if not base_url:
base_url = 'http://localhost:3000'
self._config = AutoRestHeadTestServiceConfiguration(credential, **kwargs)
self._client = ARMPipelineClient(base_url=base_url, config=self._config, **kwargs)

client_models = {} # type: Dict[str, Any]
self._serialize = Serializer(client_models)
self._deserialize = Deserializer(client_models)

self.http_success = HttpSuccessOperations(
self._client, self._config, self._serialize, self._deserialize)

def close(self):
# type: () -> None
self._client.close()

def __enter__(self):
# type: () -> AutoRestHeadTestService
self._client.__enter__()
return self

def __exit__(self, *exc_details):
# type: (Any) -> None
self._client.__exit__(*exc_details)
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING

from azure.core.configuration import Configuration
from azure.core.pipeline import policies

from ._version import VERSION

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any

from azure.core.credentials import TokenCredential


class AutoRestHeadTestServiceConfiguration(Configuration):
"""Configuration for AutoRestHeadTestService.

Note that all parameters used to create this instance are saved as instance
attributes.

:param credential: Credential needed for the client to connect to Azure.
:type credential: ~azure.core.credentials.TokenCredential
"""

def __init__(
self,
credential, # type: "TokenCredential"
**kwargs # type: Any
):
# type: (...) -> None
if credential is None:
raise ValueError("Parameter 'credential' must not be None.")
super(AutoRestHeadTestServiceConfiguration, self).__init__(**kwargs)

self.credential = credential
kwargs.setdefault('sdk_moniker', 'autorestheadtestservice/{}'.format(VERSION))
self._configure(**kwargs)

def _configure(
self,
**kwargs # type: Any
):
# type: (...) -> None
self.user_agent_policy = kwargs.get('user_agent_policy') or policies.UserAgentPolicy(**kwargs)
self.headers_policy = kwargs.get('headers_policy') or policies.HeadersPolicy(**kwargs)
self.proxy_policy = kwargs.get('proxy_policy') or policies.ProxyPolicy(**kwargs)
self.logging_policy = kwargs.get('logging_policy') or policies.NetworkTraceLoggingPolicy(**kwargs)
self.retry_policy = kwargs.get('retry_policy') or policies.RetryPolicy(**kwargs)
self.custom_hook_policy = kwargs.get('custom_hook_policy') or policies.CustomHookPolicy(**kwargs)
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

VERSION = "0.1.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) AutoRest Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._auto_rest_head_test_service_async import AutoRestHeadTestService
__all__ = ['AutoRestHeadTestService']
Loading