From 229115c5b1072513c5741fe61e9998a32cde6f84 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 10:46:56 -0700 Subject: [PATCH 1/6] aio EnvironmentCredential doesn't support user auth --- .../azure/identity/aio/_credentials/environment.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py index b045d07ddbb4..d037452f9e5c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py @@ -24,11 +24,6 @@ class EnvironmentCredential: - **AZURE_CLIENT_ID**: the service principal's client ID - **AZURE_CLIENT_CERTIFICATE_PATH**: path to a PEM-encoded certificate file including the private key - **AZURE_TENANT_ID**: ID of the service principal's tenant. Also called its 'directory' ID. - - User with username and password: - - **AZURE_CLIENT_ID**: the application's client ID - - **AZURE_USERNAME**: a username (usually an email address) - - **AZURE_PASSWORD**: that user's password """ def __init__(self, **kwargs: "Any") -> None: From 4099c62621c90ad2ee8db63cc83c71525f56baba Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 12:02:06 -0700 Subject: [PATCH 2/6] known authorities --- sdk/identity/azure-identity/azure/identity/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 60d2ba4af91c..54cda2ee9f10 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -2,10 +2,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from ._constants import EnvironmentVariables, KnownAuthorities +from ._constants import KnownAuthorities from ._credentials import ( AuthorizationCodeCredential, - CertificateCredential, ChainedTokenCredential, ClientSecretCredential, @@ -27,7 +26,6 @@ "DefaultAzureCredential", "DeviceCodeCredential", "EnvironmentCredential", - "EnvironmentVariables", "InteractiveBrowserCredential", "KnownAuthorities", "ManagedIdentityCredential", From 0a1986cf56952a93bf3b641eeed3ee1ae9423a61 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 12:06:18 -0700 Subject: [PATCH 3/6] AuthnClient accepts endpoint or authority + tenant --- .../azure/identity/_authn_client.py | 25 +++++++--- .../_credentials/client_credential.py | 15 ++++-- .../identity/_credentials/managed_identity.py | 8 +-- .../azure/identity/_credentials/user.py | 8 ++- .../azure/identity/aio/_authn_client.py | 3 +- .../aio/_credentials/client_credential.py | 15 ++++-- .../azure/identity/aio/_credentials/user.py | 8 ++- .../azure-identity/tests/test_authn_client.py | 49 ++++++++++++------- .../tests/test_authn_client_async.py | 28 +++++++++++ 9 files changed, 118 insertions(+), 41 deletions(-) create mode 100644 sdk/identity/azure-identity/tests/test_authn_client_async.py diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index dbb0d0028fa9..9ff536ebc4b3 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -15,7 +15,7 @@ from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, ProxyPolicy, RetryPolicy from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy from azure.core.pipeline.transport import RequestsTransport -from azure.identity._constants import AZURE_CLI_CLIENT_ID +from azure.identity._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities try: ABC = abc.ABC @@ -39,12 +39,22 @@ class AuthnClientBase(ABC): """Sans I/O authentication client methods""" - def __init__(self, auth_url, **kwargs): # pylint:disable=unused-argument - # type: (str, **Any) -> None - if not auth_url: - raise ValueError("auth_url should be the URL of an OAuth endpoint") + def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pylint:disable=unused-argument + # type: (Optional[str], Optional[str], Optional[str], **Any) -> None super(AuthnClientBase, self).__init__() - self._auth_url = auth_url + if authority and endpoint: + raise ValueError( + "'authority' and 'endpoint' are mutually exclusive. 'authority' should be the authority of an AAD" + + " endpoint, whereas 'endpoint' should be the endpoint's full URL." + ) + + if endpoint: + self._auth_url = endpoint + else: + if not tenant: + raise ValueError("'tenant' is required") + authority = authority or KnownAuthorities.AZURE_PUBLIC_CLOUD + self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token")) self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache def get_cached_token(self, scopes): @@ -165,7 +175,6 @@ class AuthnClient(AuthnClientBase): # pylint:disable=missing-client-constructor-parameter-credential def __init__( self, - auth_url, # type: str config=None, # type: Optional[Configuration] policies=None, # type: Optional[Iterable[HTTPPolicy]] transport=None, # type: Optional[HttpTransport] @@ -182,7 +191,7 @@ def __init__( if not transport: transport = RequestsTransport(**kwargs) self._pipeline = Pipeline(transport=transport, policies=policies) - super(AuthnClient, self).__init__(auth_url, **kwargs) + super(AuthnClient, self).__init__(**kwargs) def request_token( self, diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py index 514adfc03249..d3afd2dff697 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_credential.py @@ -4,7 +4,6 @@ # ------------------------------------ from .._authn_client import AuthnClient from .._base import ClientSecretCredentialBase, CertificateCredentialBase -from .._constants import Endpoints try: from typing import TYPE_CHECKING @@ -24,12 +23,17 @@ class ClientSecretCredential(ClientSecretCredentialBase): :param str client_id: the service principal's client ID :param str secret: one of the service principal's client secrets :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. """ def __init__(self, client_id, secret, tenant_id, **kwargs): # type: (str, str, str, Mapping[str, Any]) -> None super(ClientSecretCredential, self).__init__(client_id, secret, tenant_id, **kwargs) - self._client = AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), **kwargs) + self._client = AuthnClient(tenant=tenant_id, **kwargs) def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (*str, **Any) -> AccessToken @@ -54,11 +58,16 @@ class CertificateCredential(CertificateCredentialBase): :param str client_id: the service principal's client ID :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. :param str certificate_path: path to a PEM-encoded certificate file including the private key + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. """ def __init__(self, client_id, tenant_id, certificate_path, **kwargs): # type: (str, str, str, Mapping[str, Any]) -> None - self._client = AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), **kwargs) + self._client = AuthnClient(tenant=tenant_id, **kwargs) super(CertificateCredential, self).__init__(client_id, tenant_id, certificate_path, **kwargs) def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index feb14df041a4..e6258ba52413 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -61,7 +61,7 @@ def __init__(self, endpoint, client_cls, config=None, client_id=None, **kwargs): self._client_id = client_id config = config or self._create_config(**kwargs) policies = [ContentDecodePolicy(), config.headers_policy, config.retry_policy, config.logging_policy] - self._client = client_cls(endpoint, config, policies, **kwargs) + self._client = client_cls(endpoint=endpoint, config=config, policies=policies, **kwargs) @staticmethod def _create_config(**kwargs): @@ -105,9 +105,9 @@ class ImdsCredential(_ManagedIdentityBase): :type config: :class:`azure.core.configuration` """ - def __init__(self, config=None, **kwargs): - # type: (Optional[Configuration], Any) -> None - super(ImdsCredential, self).__init__(endpoint=Endpoints.IMDS, client_cls=AuthnClient, config=config, **kwargs) + def __init__(self, **kwargs): + # type: (**Any) -> None + super(ImdsCredential, self).__init__(endpoint=Endpoints.IMDS, client_cls=AuthnClient, **kwargs) self._endpoint_available = None # type: Optional[bool] def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/user.py b/sdk/identity/azure-identity/azure/identity/_credentials/user.py index 1b6db4a81e1e..be1287906859 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/user.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/user.py @@ -10,7 +10,6 @@ from azure.core.exceptions import ClientAuthenticationError from .._authn_client import AuthnClient -from .._constants import Endpoints from .._internal import PublicClientCredential, wrap_exceptions try: @@ -112,6 +111,11 @@ class SharedTokenCacheCredential(object): :param str username: Username (typically an email address) of the user to authenticate as. This is required because the local cache may contain tokens for multiple identities. + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. """ def __init__(self, username, **kwargs): # pylint:disable=unused-argument @@ -166,7 +170,7 @@ def supported(): @staticmethod def _get_auth_client(cache): # type: (msal_extensions.FileTokenCache) -> AuthnClientBase - return AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format("common"), cache=cache) + return AuthnClient(tenant="common", cache=cache) class UsernamePasswordCredential(PublicClientCredential): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py index 7012ba264bce..5cb00f5e6e11 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py @@ -30,7 +30,6 @@ class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name # pylint:disable=missing-client-constructor-parameter-credential def __init__( self, - auth_url: str, config: "Optional[Configuration]" = None, policies: Optional[Iterable[HTTPPolicy]] = None, transport: Optional[AsyncHttpTransport] = None, @@ -46,7 +45,7 @@ def __init__( if not transport: transport = AsyncioRequestsTransport(**kwargs) self._pipeline = AsyncPipeline(transport=transport, policies=policies) - super(AsyncAuthnClient, self).__init__(auth_url, **kwargs) + super().__init__(**kwargs) async def request_token( self, diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py index 3fa1db426657..5ff081ad211d 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_credential.py @@ -6,7 +6,6 @@ from .._authn_client import AsyncAuthnClient from ..._base import ClientSecretCredentialBase, CertificateCredentialBase -from ..._constants import Endpoints if TYPE_CHECKING: from typing import Any, Mapping @@ -20,11 +19,16 @@ class ClientSecretCredential(ClientSecretCredentialBase): :param str client_id: the service principal's client ID :param str secret: one of the service principal's client secrets :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. """ def __init__(self, client_id: str, secret: str, tenant_id: str, **kwargs: "Mapping[str, Any]") -> None: super(ClientSecretCredential, self).__init__(client_id, secret, tenant_id, **kwargs) - self._client = AsyncAuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), **kwargs) + self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs) async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument """ @@ -48,11 +52,16 @@ class CertificateCredential(CertificateCredentialBase): :param str client_id: the service principal's client ID :param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID. :param str certificate_path: path to a PEM-encoded certificate file including the private key + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. """ def __init__(self, client_id: str, tenant_id: str, certificate_path: str, **kwargs: "Mapping[str, Any]") -> None: super(CertificateCredential, self).__init__(client_id, tenant_id, certificate_path, **kwargs) - self._client = AsyncAuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), **kwargs) + self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs) async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument """ diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py index 40f9d38a32d4..f1edc37a2632 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/user.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING from azure.core.exceptions import ClientAuthenticationError -from azure.identity._constants import Endpoints from ... import SharedTokenCacheCredential as SyncSharedTokenCacheCredential from .._authn_client import AsyncAuthnClient from .._internal.exception_wrapper import wrap_exceptions @@ -37,6 +36,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py :raises: :class:`azure.core.exceptions.ClientAuthenticationError` when the cache is unavailable or no access token can be acquired from it + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', + the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` + defines authorities for other clouds. """ if not self._client: @@ -50,4 +54,4 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py @staticmethod def _get_auth_client(cache: "msal_extensions.FileTokenCache") -> "AuthnClientBase": - return AsyncAuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format("common"), cache=cache) + return AsyncAuthnClient(tenant="common", cache=cache) diff --git a/sdk/identity/azure-identity/tests/test_authn_client.py b/sdk/identity/azure-identity/tests/test_authn_client.py index 1aca947b3d89..dac27dd12785 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client.py +++ b/sdk/identity/azure-identity/tests/test_authn_client.py @@ -4,7 +4,6 @@ # ------------------------------------ """These tests use the synchronous AuthnClient as a driver to test functionality of the sans I/O AuthnClientBase shared with AsyncAuthnClient.""" - import json import time @@ -15,7 +14,7 @@ from azure.core.credentials import AccessToken from azure.identity._authn_client import AuthnClient - +from six.moves.urllib_parse import urlparse from helpers import mock_response @@ -27,16 +26,14 @@ def test_authn_client_deserialization(): expected_access_token = AccessToken(access_token, expires_on) scope = "scope" - mock_response = Mock( - headers={"content-type": "application/json"}, status_code=200, content_type="application/json" - ) + mock_response = Mock(headers={"content-type": "application/json"}, status_code=200, content_type="application/json") mock_send = Mock(return_value=mock_response) # response with expires_on only mock_response.text = lambda: json.dumps( {"access_token": access_token, "expires_on": expires_on, "token_type": "Bearer", "resource": scope} ) - token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope) + token = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)).request_token(scope) assert token == expected_access_token # response with expires_on only and it's a datetime string (App Service MSI) @@ -48,7 +45,7 @@ def test_authn_client_deserialization(): "resource": scope, } ) - token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope) + token = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)).request_token(scope) assert token == expected_access_token # response with string expires_in and expires_on (IMDS, Cloud Shell) @@ -61,7 +58,7 @@ def test_authn_client_deserialization(): "resource": scope, } ) - token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope) + token = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)).request_token(scope) assert token == expected_access_token # response with int expires_in (AAD) @@ -70,7 +67,7 @@ def test_authn_client_deserialization(): ) with patch("azure.identity._authn_client.time.time") as mock_time: mock_time.return_value = now - token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token(scope) + token = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)).request_token(scope) assert token == expected_access_token @@ -91,7 +88,7 @@ def test_caching_when_only_expires_in_set(): ) mock_send = Mock(return_value=mock_response) - client = AuthnClient("http://foo", transport=Mock(send=mock_send)) + client = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)) with patch("azure.identity._authn_client.time.time") as mock_time: mock_time.return_value = 42 token = client.request_token(["scope"]) @@ -105,9 +102,7 @@ def test_caching_when_only_expires_in_set(): def test_expires_in_strings(): expected_token = "token" - mock_response = Mock( - headers={"content-type": "application/json"}, status_code=200, content_type="application/json" - ) + mock_response = Mock(headers={"content-type": "application/json"}, status_code=200, content_type="application/json") mock_send = Mock(return_value=mock_response) mock_response.text = lambda: json.dumps( @@ -117,7 +112,7 @@ def test_expires_in_strings(): now = int(time.time()) with patch("azure.identity._authn_client.time.time") as mock_time: mock_time.return_value = now - token = AuthnClient("http://foo", transport=Mock(send=mock_send)).request_token("scope") + token = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)).request_token("scope") assert token.token == expected_token assert token.expires_on == now + 42 @@ -137,7 +132,7 @@ def test_cache_expiry(): ) mock_send = Mock(return_value=mock_response) - client = AuthnClient("http://foo", transport=Mock(send=mock_send)) + client = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)) with patch("azure.identity._authn_client.time.time") as mock_time: # populate the cache with a valid token mock_time.return_value = now @@ -178,7 +173,7 @@ def mock_send(request, **kwargs): token = expected_tokens[request.data["resource"]] return mock_response(json_payload=token) - client = AuthnClient("http://foo", transport=Mock(send=mock_send)) + client = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)) # if the cache has a token for a & b, it should hit for a, b, a & b token = client.request_token([scope_a, scope_b], form_data={"resource": scope_ab}) @@ -188,9 +183,29 @@ def mock_send(request, **kwargs): assert client.get_cached_token([scope_a, scope_b]).token == scope_ab # if the cache has only tokens for a and b alone, a & b should miss - client = AuthnClient("http://foo", transport=Mock(send=mock_send)) + client = AuthnClient(endpoint="http://foo", transport=Mock(send=mock_send)) for scope in (scope_a, scope_b): token = client.request_token([scope], form_data={"resource": scope}) assert token.token == scope assert client.get_cached_token([scope]).token == scope assert not client.get_cached_token([scope_a, scope_b]) + + +def test_request_url(): + authority = "authority.com" + tenant = "expected_tenant" + + def validate_url(url): + scheme, netloc, path, _, _, _ = urlparse(url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant) + + def mock_send(request, **kwargs): + validate_url(request.url) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + + client = AuthnClient(tenant=tenant, transport=Mock(send=mock_send), authority=authority) + client.request_token(("scope",)) + request = client.get_refresh_token_grant_request({"secret": "***"}, "scope") + validate_url(request.url) diff --git a/sdk/identity/azure-identity/tests/test_authn_client_async.py b/sdk/identity/azure-identity/tests/test_authn_client_async.py new file mode 100644 index 000000000000..06e63dc9dba9 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_authn_client_async.py @@ -0,0 +1,28 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import asyncio +from unittest.mock import Mock +from urllib.parse import urlparse + +import pytest +from azure.identity.aio._authn_client import AsyncAuthnClient + +from helpers import mock_response + + +@pytest.mark.asyncio +async def test_request_url(): + authority = "authority.com" + tenant = "expected_tenant" + + def mock_send(request, **kwargs): + scheme, netloc, path, _, _, _ = urlparse(request.url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + + client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=asyncio.coroutine(mock_send)), authority=authority) + await client.request_token(("scope",)) From 58511a9f8714ff7b2de462c7ef3e2e12e8b7d3f0 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 12:21:06 -0700 Subject: [PATCH 4/6] conditional typing imports --- .../azure/identity/aio/_authn_client.py | 37 +++++++++---------- .../identity/aio/_credentials/environment.py | 7 ++++ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py index 5cb00f5e6e11..b18917d2b72b 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py @@ -3,26 +3,25 @@ # Licensed under the MIT License. # ------------------------------------ import time -from typing import Any, Dict, Iterable, Mapping, Optional +from typing import TYPE_CHECKING + from msal import TokenCache from azure.core import Configuration from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline import AsyncPipeline -from azure.core.pipeline.policies import ( - AsyncRetryPolicy, - ContentDecodePolicy, - HTTPPolicy, - NetworkTraceLoggingPolicy, - ProxyPolicy, -) +from azure.core.pipeline.policies import AsyncRetryPolicy, ContentDecodePolicy, NetworkTraceLoggingPolicy, ProxyPolicy from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy -from azure.core.pipeline.transport import AsyncHttpTransport from azure.core.pipeline.transport.requests_asyncio import AsyncioRequestsTransport from .._authn_client import AuthnClientBase +if TYPE_CHECKING: + from typing import Any, Dict, Iterable, Mapping, Optional + from azure.core.pipeline.policies import HTTPPolicy + from azure.core.pipeline.transport import AsyncHttpTransport + class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name """Async authentication client""" @@ -31,9 +30,9 @@ class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name def __init__( self, config: "Optional[Configuration]" = None, - policies: Optional[Iterable[HTTPPolicy]] = None, - transport: Optional[AsyncHttpTransport] = None, - **kwargs: Mapping[str, Any] + policies: "Optional[Iterable[HTTPPolicy]]" = None, + transport: "Optional[AsyncHttpTransport]" = None, + **kwargs: "Any" ) -> None: config = config or self._create_config(**kwargs) policies = policies or [ @@ -49,11 +48,11 @@ def __init__( async def request_token( self, - scopes: Iterable[str], - method: Optional[str] = "POST", - headers: Optional[Mapping[str, str]] = None, - form_data: Optional[Mapping[str, str]] = None, - params: Optional[Dict[str, str]] = None, + scopes: "Iterable[str]", + method: "Optional[str]" = "POST", + headers: "Optional[Mapping[str, str]]" = None, + form_data: "Optional[Mapping[str, str]]" = None, + params: "Optional[Dict[str, str]]" = None, **kwargs: "Any" ) -> AccessToken: request = self._prepare_request(method, headers=headers, form_data=form_data, params=params) @@ -62,7 +61,7 @@ async def request_token( token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time) return token - async def obtain_token_by_refresh_token(self, scopes: Iterable[str], username: str) -> Optional[AccessToken]: + async def obtain_token_by_refresh_token(self, scopes: "Iterable[str]", username: str) -> "Optional[AccessToken]": """Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else.""" @@ -90,7 +89,7 @@ async def obtain_token_by_refresh_token(self, scopes: Iterable[str], username: s return None @staticmethod - def _create_config(**kwargs: Mapping[str, Any]) -> Configuration: + def _create_config(**kwargs: "Any") -> Configuration: config = Configuration(**kwargs) config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) config.retry_policy = AsyncRetryPolicy(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py index d037452f9e5c..61aa2c358a6f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py @@ -3,11 +3,18 @@ # Licensed under the MIT License. # ------------------------------------ import os +from typing import TYPE_CHECKING from azure.core.exceptions import ClientAuthenticationError from ..._constants import EnvironmentVariables from .client_credential import CertificateCredential, ClientSecretCredential +if TYPE_CHECKING: + from typing import Any, Optional, Union + from azure.core.credentials import AccessToken + from azure.core.pipeline.policies import HTTPPolicy + from azure.core.pipeline.transport import AsyncHttpTransport + class EnvironmentCredential: """ From 59070977546846f3e6579147f8612cd35e554a0a Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 12:21:19 -0700 Subject: [PATCH 5/6] update HISTORY --- sdk/identity/azure-identity/HISTORY.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sdk/identity/azure-identity/HISTORY.md b/sdk/identity/azure-identity/HISTORY.md index 029c4583f2a9..9e5226d3431c 100644 --- a/sdk/identity/azure-identity/HISTORY.md +++ b/sdk/identity/azure-identity/HISTORY.md @@ -6,6 +6,16 @@ authorization code. See Azure Active Directory's [authorization code documentation](https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow) for more information about this authentication flow. +- Multi-cloud support: client credentials accept the authority of an Azure Active +Directory authentication endpoint as an `authority` keyword argument. Known +authorities are defined in `azure.identity.KnownAuthorities`. The default +authority is for Azure Public Cloud, `login.microsoftonline.com` +(`KnownAuthorities.AZURE_PUBLIC_CLOUD`). An application running in Azure +Government would use `KnownAuthorities.AZURE_GOVERNMENT` instead: +>``` +>from azure.identity import DefaultAzureCredential, KnownAuthorities +>credential = DefaultAzureCredential(authority=KnownAuthorities.AZURE_GOVERNMENT) +>``` ### Breaking changes: - Removed `client_secret` parameter from `InteractiveBrowserCredential` From 82c6e867529ffd2341a5c3f74817941dc8502c6c Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Wed, 2 Oct 2019 14:30:59 -0700 Subject: [PATCH 6/6] MSAL credentials accept authority kwarg --- .../azure/identity/_credentials/browser.py | 5 ++++- .../azure/identity/_credentials/user.py | 6 ++++++ .../identity/_internal/msal_credentials.py | 19 ++++++++----------- .../azure-identity/tests/test_live.py | 5 +++-- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py index d419d4a4d1ca..cad67642c9b7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py @@ -33,6 +33,9 @@ class InteractiveBrowserCredential(PublicClientCredential): :param str client_id: the application's client ID Keyword arguments + - *authority*: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. - *tenant (str)*: a tenant ID or a domain associated with a tenant. Defaults to the 'organizations' tenant, which can authenticate work or school accounts. - *timeout (int)*: seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes). @@ -40,7 +43,7 @@ class InteractiveBrowserCredential(PublicClientCredential): """ def __init__(self, client_id, **kwargs): - # type: (str, Any) -> None + # type: (str, **Any) -> None self._timeout = kwargs.pop("timeout", 300) self._server_class = kwargs.pop("server_class", AuthCodeRedirectServer) # facilitate mocking super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/user.py b/sdk/identity/azure-identity/azure/identity/_credentials/user.py index be1287906859..ef966fa40f61 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/user.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/user.py @@ -46,6 +46,9 @@ class DeviceCodeCredential(PublicClientCredential): If not provided, the credential will print instructions to stdout. Keyword arguments + - *authority*: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. - *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, defaults to the 'organizations' tenant, which supports only Azure Active Directory work or school accounts. - *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device @@ -190,6 +193,9 @@ class UsernamePasswordCredential(PublicClientCredential): :param str password: the user's password Keyword arguments + - *authority*: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. - *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, defaults to the 'organizations' tenant, which supports only Azure Active Directory work or school accounts. diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 29e3a8f1cd62..afe9fc11ce55 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -14,6 +14,7 @@ from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter +from .._constants import KnownAuthorities try: ABC = abc.ABC @@ -33,9 +34,11 @@ class MsalCredential(ABC): """Base class for credentials wrapping MSAL applications""" - def __init__(self, client_id, authority, client_credential=None, **kwargs): - # type: (str, str, Optional[Union[str, Mapping[str, str]]], Any) -> None - self._authority = authority + def __init__(self, client_id, client_credential=None, **kwargs): + # type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None + tenant = kwargs.pop("tenant", None) or "organizations" + authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD) + self._base_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"))) self._client_credential = client_credential self._client_id = client_id @@ -60,7 +63,8 @@ def _create_app(self, cls): # MSAL application initializers use msal.authority to send AAD tenant discovery requests with self._adapter: - app = cls(client_id=self._client_id, client_credential=self._client_credential, authority=self._authority) + # MSAL's "authority" is a URL e.g. https://login.microsoftonline.com/common + app = cls(client_id=self._client_id, client_credential=self._client_credential, authority=self._base_url) # monkeypatch the app to replace requests.Session with MsalTransportAdapter app.client.session.close() @@ -100,13 +104,6 @@ def _get_app(self): class PublicClientCredential(MsalCredential): """Wraps an MSAL PublicClientApplication with the TokenCredential API""" - def __init__(self, **kwargs): - # type: (Any) -> None - tenant = kwargs.pop("tenant", None) or "organizations" - super(PublicClientCredential, self).__init__( - authority="https://login.microsoftonline.com/" + tenant, **kwargs - ) - @abc.abstractmethod def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (*str, **Any) -> AccessToken diff --git a/sdk/identity/azure-identity/tests/test_live.py b/sdk/identity/azure-identity/tests/test_live.py index ddff3d83fa3a..51394dfd2bbf 100644 --- a/sdk/identity/azure-identity/tests/test_live.py +++ b/sdk/identity/azure-identity/tests/test_live.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from azure.identity import DefaultAzureCredential, CertificateCredential, ClientSecretCredential +from azure.identity import DefaultAzureCredential, CertificateCredential, ClientSecretCredential, KnownAuthorities from azure.identity._internal import ConfidentialClientCredential ARM_SCOPE = "https://management.azure.com/.default" @@ -44,7 +44,8 @@ def test_confidential_client_credential(live_identity_settings): credential = ConfidentialClientCredential( client_id=live_identity_settings["client_id"], client_credential=live_identity_settings["client_secret"], - authority="https://login.microsoftonline.com/" + live_identity_settings["tenant_id"], + authority=KnownAuthorities.AZURE_PUBLIC_CLOUD, + tenant=live_identity_settings["tenant_id"], ) token = credential.get_token(ARM_SCOPE) assert token