From 2645ed73c1f727b7f4694cfa9359a245f41bf66b Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 19 Sep 2019 09:36:50 -0700 Subject: [PATCH 01/14] synchronous AAD client wrapping MSAL's oauth2 client --- .../azure/identity/_internal/__init__.py | 4 + .../azure/identity/_internal/aad_client.py | 30 +++++ .../identity/_internal/aad_client_base.py | 112 ++++++++++++++++++ .../identity/aio/_internal/aad_client.py | 31 +++++ 4 files changed, 177 insertions(+) create mode 100644 sdk/identity/azure-identity/azure/identity/_internal/aad_client.py create mode 100644 sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index 6108ffdaa895..b6234d1eee67 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -2,12 +2,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .aad_client import AadClient +from .aad_client_base import AadClientBase from .auth_code_redirect_handler import AuthCodeRedirectServer from .exception_wrapper import wrap_exceptions from .msal_credentials import ConfidentialClientCredential, PublicClientCredential from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse __all__ = [ + "AadClient", + "AadClientBase", "AuthCodeRedirectServer", "ConfidentialClientCredential", "MsalTransportAdapter", diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py new file mode 100644 index 000000000000..aa25b4a4b1a2 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -0,0 +1,30 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""A thin wrapper around MSAL's token cache and OAuth 2 client""" + +import time +from typing import TYPE_CHECKING + +from azure.core.credentials import AccessToken + +from .aad_client_base import AadClientBase +from .msal_transport_adapter import MsalTransportAdapter +from .exception_wrapper import wrap_exceptions + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Callable, Iterable + + +class AadClient(AadClientBase): + def _get_client_session(self, **kwargs): + return MsalTransportAdapter(**kwargs) + + @wrap_exceptions + def _obtain_token(self, scopes, fn, **kwargs): # pylint:disable=unused-argument + # type: (Iterable[str], Callable, **Any) -> AccessToken + now = int(time.time()) + response = fn() + return self._process_response(response=response, scopes=scopes, now=now) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py new file mode 100644 index 000000000000..a894a9cfcd5c --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -0,0 +1,112 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import abc +import functools +import time + +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +from msal import TokenCache +from msal.oauth2cli.oauth2 import Client + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError +from .._constants import Endpoints + +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Callable, Iterable, Optional + + +class AadClientBase(ABC): + """Sans I/O methods for AAD clients wrapping MSAL's OAuth client""" + + def __init__(self, client_id, tenant_id, **kwargs): + # type: (str, str, **Any) -> None + config = {"token_endpoint": Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id)} + self._client = Client(server_configuration=config, client_id=client_id) + self._client.session.close() + self._client.session = self._get_client_session(**kwargs) + self._cache = TokenCache() + + def get_cached_access_token(self, scopes): + # type: (Iterable[str]) -> Optional[AccessToken] + tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes)) + for token in tokens: + expires_on = int(token["expires_on"]) + if expires_on - 300 > int(time.time()): + return AccessToken(token["secret"], expires_on) + return None + + def get_cached_refresh_tokens(self, scopes): + """Assumes all cached refresh tokens belong to the same user""" + return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes)) + + def obtain_token_by_authorization_code(self, code, redirect_uri, scopes, **kwargs): + # type: (str, str, Iterable[str], **Any) -> AccessToken + fn = functools.partial( + self._client.obtain_token_by_authorization_code, code=code, redirect_uri=redirect_uri, **kwargs + ) + return self._obtain_token(scopes, fn, **kwargs) + + def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs): + # type: (str, Iterable[str], **Any) -> AccessToken + fn = functools.partial( + self._client.obtain_token_by_refresh_token, + token_item=refresh_token, + scope=scopes, + rt_getter=lambda token: token["secret"], + **kwargs + ) + return self._obtain_token(scopes, fn) + + def _process_response(self, response, scopes, now): + # type: (dict, Iterable[str], int) -> AccessToken + _raise_for_error(response) + self._cache.add(event={"response": response, "scope": scopes}, now=now) + if "expires_on" in response: + expires_on = int(response["expires_on"]) + elif "expires_in" in response: + expires_on = now + int(response["expires_in"]) + else: + for secret in ("access_token", "refresh_token"): + if secret in response: + response[secret] = "***" + raise ClientAuthenticationError( + message="Unexpected response from Azure Active Directory: {}".format(response) + ) + return AccessToken(response["access_token"], expires_on) + + @abc.abstractmethod + def _get_client_session(self, **kwargs): + pass + + @abc.abstractmethod + def _obtain_token(self, scopes, fn, **kwargs): + # type: (Iterable[str], Callable, **Any) -> AccessToken + pass + + +def _raise_for_error(response): + # type: (dict) -> None + if "error" not in response: + return + + if "error_description" in response: + message = "Azure Active Directory error '({}) {}'".format(response["error"], response["error_description"]) + else: + for secret in ("access_token", "refresh_token"): + if secret in response: + response[secret] = "***" + message = "Azure Active Directory error '{}'".format(response) + raise ClientAuthenticationError(message=message) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py new file mode 100644 index 000000000000..444dade7d13f --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -0,0 +1,31 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""A thin wrapper around MSAL's token cache and OAuth 2 client""" + +import asyncio +import time +from typing import TYPE_CHECKING + +from azure.identity._internal import AadClientBase +from .msal_transport_adapter import MsalTransportAdapter +from .exception_wrapper import wrap_exceptions + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Callable, Iterable + from azure.core.credentials import AccessToken + + +class AadClient(AadClientBase): + def _get_client_session(self, **kwargs): + return MsalTransportAdapter(**kwargs) + + @wrap_exceptions + async def _obtain_token(self, scopes: "Iterable[str]", fn: "Callable", **kwargs: "Any") -> "AccessToken": + now = int(time.time()) + executor = kwargs.get("executor", None) + loop = kwargs.get("loop", None) or asyncio.get_event_loop() + response = await loop.run_in_executor(executor, fn) + return self._process_response(response=response, scopes=scopes, now=now) From d2994012b342760373916455ad952b8441bd6da2 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 19 Sep 2019 09:42:01 -0700 Subject: [PATCH 02/14] synchronous AuthorizationCodeCredential --- .../azure-identity/azure/identity/__init__.py | 1 + .../azure/identity/_credentials/__init__.py | 2 + .../_credentials/authorization_code.py | 70 +++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 3e586121bfb9..4cbec9a5867f 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -17,6 +17,7 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ChainedTokenCredential", "ClientSecretCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py b/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py index 45f6450fe18d..28e0969c8cd4 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .authorization_code import AuthorizationCodeCredential from .browser import InteractiveBrowserCredential from .chained import ChainedTokenCredential from .client_credential import CertificateCredential, ClientSecretCredential @@ -12,6 +13,7 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ChainedTokenCredential", "ClientSecretCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py new file mode 100644 index 000000000000..2827b573cc6c --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -0,0 +1,70 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import TYPE_CHECKING + +from azure.core.exceptions import ClientAuthenticationError +from .._internal.aad_client import AadClient + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Iterable, Optional + from azure.core.credentials import AccessToken + + +class AuthorizationCodeCredential(object): + """ + Authenticates by redeeming an authorization code previously obtained from Azure Active Directory. + See https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow for more information + about the authentication flow. + + :param str client_id: the application's client ID + :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its 'directory' ID. + :param str authorization_code: the authorization code from the user's log-in + :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. + :param str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + """ + + def __init__(self, client_id, tenant_id, authorization_code, redirect_uri, client_secret=None, **kwargs): + # type: (str, str, str, str, Optional[str], **Any) -> None + self._authorization_code = authorization_code # type: Optional[str] + self._client_id = client_id + self._client_secret = client_secret + self._client = kwargs.pop("client", None) or AadClient(client_id, tenant_id, **kwargs) + self._redirect_uri = redirect_uri + + def get_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + """ + Request an access token for ``scopes``. The first time this method is called, the credential will redeem its + authorization code. On subsequent calls the credential will return a cached access token or redeem a refresh + token, if it acquired a refresh token upon redeeming the authorization code. + + :param str scopes: desired scopes for the access token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + """ + + if self._authorization_code: + token = self._client.obtain_token_by_authorization_code( + code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, **kwargs + ) + self._authorization_code = None # auth codes are single-use + return token + + token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs) + if not token: + raise ClientAuthenticationError( + message="No authorization code, cached access token, or refresh token available." + ) + + return token + + def _redeem_refresh_token(self, scopes, **kwargs): + # type: (Iterable[str], **Any) -> Optional[AccessToken] + for refresh_token in self._client.get_cached_refresh_tokens(scopes): + token = self._client.obtain_token_by_refresh_token(refresh_token, scopes, **kwargs) + if token: + return token + return None From 55e2da526db84faa6565ebfe5835ef34c41eb6c7 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 19 Sep 2019 12:32:12 -0700 Subject: [PATCH 03/14] async wrapper for requests Session --- .../azure/identity/aio/_internal/__init__.py | 3 +- .../aio/_internal/msal_transport_adapter.py | 108 ++++++++++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py index 7ca58029041f..bb2e886d8475 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py @@ -3,5 +3,6 @@ # Licensed under the MIT License. # ------------------------------------ from .exception_wrapper import wrap_exceptions +from .msal_transport_adapter import MsalTransportAdapter -__all__ = ["wrap_exceptions"] +__all__ = ["MsalTransportAdapter", "wrap_exceptions"] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py new file mode 100644 index 000000000000..e85baca7d3c9 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py @@ -0,0 +1,108 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Adapter to substitute an async azure-core pipeline for Requests in MSAL application token acquisition methods.""" + +import asyncio +import atexit +from typing import TYPE_CHECKING + +from azure.core.configuration import Configuration +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.policies import AsyncRetryPolicy, DistributedTracingPolicy, NetworkTraceLoggingPolicy +from azure.core.pipeline.transport import AioHttpTransport, HttpRequest + +from azure.identity._internal import MsalTransportResponse + +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Any, Dict, Iterable, Optional + from azure.core.pipeline.policies import AsyncHTTPPolicy + from azure.core.pipeline.transport import AsyncHttpTransport + + +class MsalTransportAdapter: + """Wraps an async azure-core pipeline with the shape of a (synchronous) Requests Session""" + + def __init__( + self, + config: "Optional[Configuration]" = None, + policies: "Optional[Iterable[AsyncHTTPPolicy]]" = None, + transport: "Optional[AsyncHttpTransport]" = None, + **kwargs: "Any" + ) -> None: + + config = config or self._create_config(**kwargs) + policies = policies or [config.retry_policy, config.logging_policy, DistributedTracingPolicy()] + self._transport = transport or AioHttpTransport(configuration=config) + atexit.register(self._close_transport_session) # prevent aiohttp warnings + self._pipeline = AsyncPipeline(transport=self._transport, policies=policies) + + def _close_transport_session(self) -> None: + """If transport has a 'close' method, invoke it.""" + + close = getattr(self._transport, "close", None) + if not callable(close): + return + + if asyncio.iscoroutinefunction(close): + # we expect no loop is running because this method should be called only when the interpreter is exiting + asyncio.new_event_loop().run_until_complete(close()) + else: + close() + + def get( + self, + url: str, + headers: "Optional[Dict[str, str]]" = None, + params: "Optional[Dict[str, str]]" = None, + timeout: "Optional[float]" = None, + verify: "Optional[bool]" = None, + **kwargs: "Any" + ) -> MsalTransportResponse: + + request = HttpRequest("GET", url, headers=headers) + if params: + request.format_parameters(params) + + loop = kwargs.pop("loop", None) + future = asyncio.run_coroutine_threadsafe( # type: ignore + self._pipeline.run(request, connection_timeout=timeout, connection_verify=verify, **kwargs), loop + ) + response = future.result(timeout=timeout) + + return MsalTransportResponse(response) + + def post( + self, + url: str, + data: "Any" = None, + headers: "Optional[Dict[str, str]]" = None, + params: "Optional[Dict[str, str]]" = None, + timeout: "Optional[float]" = None, + verify: "Optional[bool]" = None, + **kwargs: "Any" + ) -> MsalTransportResponse: + + request = HttpRequest("POST", url, headers=headers) + if params: + request.format_parameters(params) + if data: + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + request.set_formdata_body(data) + + loop = kwargs.pop("loop", None) + future = asyncio.run_coroutine_threadsafe( # type: ignore + self._pipeline.run(request, connection_timeout=timeout, connection_verify=verify, **kwargs), loop + ) + response = future.result(timeout=timeout) + + return MsalTransportResponse(response) + + @staticmethod + def _create_config(**kwargs: "Any") -> Configuration: + config = Configuration(**kwargs) + config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) + config.retry_policy = AsyncRetryPolicy(**kwargs) + return config From 5343e2fec1aecb0416bc8b6021e84f4632d20ff9 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 19 Sep 2019 12:32:27 -0700 Subject: [PATCH 04/14] async AuthorizationCodeCredential --- .../azure/identity/aio/__init__.py | 2 + .../identity/aio/_credentials/__init__.py | 2 + .../aio/_credentials/authorization_code.py | 79 +++++++++++++++++++ .../azure/identity/aio/_internal/__init__.py | 3 +- .../identity/aio/_internal/aad_client.py | 4 +- 5 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py diff --git a/sdk/identity/azure-identity/azure/identity/aio/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/__init__.py index 57a52755df27..8a63109e49ff 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/__init__.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ from ._credentials import ( + AuthorizationCodeCredential, CertificateCredential, ChainedTokenCredential, ClientSecretCredential, @@ -14,6 +15,7 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ClientSecretCredential", "DefaultAzureCredential", diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py index 55ed64d00b1e..e2bb64f0fa4e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .authorization_code import AuthorizationCodeCredential from .chained import ChainedTokenCredential from .default import DefaultAzureCredential from .environment import EnvironmentCredential @@ -11,6 +12,7 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ChainedTokenCredential", "ClientSecretCredential", diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py new file mode 100644 index 000000000000..21a95c51e67d --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -0,0 +1,79 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import asyncio +from typing import TYPE_CHECKING + +from azure.core.exceptions import ClientAuthenticationError +from .._internal import AadClient + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Collection, Optional + from azure.core.credentials import AccessToken + + +class AuthorizationCodeCredential(object): + """ + Authenticates by redeeming an authorization code previously obtained from Azure Active Directory. + See https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow for more information + about the authentication flow. + + :param str client_id: the application's client ID + :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its 'directory' ID. + :param str authorization_code: the authorization code from the user's log-in + :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. + :param str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + """ + + def __init__(self, client_id, tenant_id, authorization_code, redirect_uri, client_secret=None, **kwargs): + # type: (str, str, str, str, Optional[str], **Any) -> None + self._authorization_code = authorization_code # type: Optional[str] + self._client_id = client_id + self._client_secret = client_secret + self._client = kwargs.pop("client", None) or AadClient(client_id, tenant_id, **kwargs) + self._redirect_uri = redirect_uri + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + """ + Request an access token for ``scopes``. The first time this method is called, the credential will redeem its + authorization code. On subsequent calls the credential will return a cached access token or redeem a refresh + token, if it acquired a refresh token upon redeeming the authorization code. + + :param str scopes: desired scopes for the access token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + + Keyword arguments: + - **executor**: (optional) a :class:`concurrent.futures.Executor` used to execute asynchronous calls + - **loop**: (optional) an event loop on which to schedule network I/O. If not provided, the currently + running loop will be used. + """ + + if self._authorization_code: + loop = kwargs.pop("loop", None) or asyncio.get_event_loop() + token = await self._client.obtain_token_by_authorization_code( + code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, loop=loop, **kwargs + ) + self._authorization_code = None # auth codes are single-use + return token + + token = self._client.get_cached_access_token(scopes) + if not token: + token = await self._redeem_refresh_token(scopes, **kwargs) + + if not token: + raise ClientAuthenticationError( + message="No authorization code, cached access token, or refresh token available." + ) + + return token + + async def _redeem_refresh_token(self, scopes: "Collection[str]", **kwargs: "Any") -> "Optional[AccessToken]": + loop = kwargs.pop("loop", None) or asyncio.get_event_loop() + for refresh_token in self._client.get_cached_refresh_tokens(scopes): + token = await self._client.obtain_token_by_refresh_token(refresh_token, scopes, loop=loop, **kwargs) + if token: + return token + return None diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py index bb2e886d8475..097b56a96c85 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py @@ -2,7 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .aad_client import AadClient from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter -__all__ = ["MsalTransportAdapter", "wrap_exceptions"] +__all__ = ["AadClient", "MsalTransportAdapter", "wrap_exceptions"] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 444dade7d13f..17e997db834e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -23,7 +23,9 @@ def _get_client_session(self, **kwargs): return MsalTransportAdapter(**kwargs) @wrap_exceptions - async def _obtain_token(self, scopes: "Iterable[str]", fn: "Callable", **kwargs: "Any") -> "AccessToken": + async def _obtain_token( + self, scopes: "Iterable[str]", fn: "Callable", **kwargs: "Any" + ) -> "AccessToken": now = int(time.time()) executor = kwargs.get("executor", None) loop = kwargs.get("loop", None) or asyncio.get_event_loop() From a591a778d73f4dfb01d27cfc2cf78f80204fdc20 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 19 Sep 2019 16:30:05 -0700 Subject: [PATCH 05/14] aiohttp is a dev requirement --- sdk/identity/azure-identity/dev_requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/identity/azure-identity/dev_requirements.txt b/sdk/identity/azure-identity/dev_requirements.txt index 811e306ce457..bdb60dd6d780 100644 --- a/sdk/identity/azure-identity/dev_requirements.txt +++ b/sdk/identity/azure-identity/dev_requirements.txt @@ -1,4 +1,5 @@ -e ../../core/azure-core +aiohttp;python_full_version>="3.5.2" typing_extensions>=3.7.2 pytest pytest-asyncio;python_full_version>="3.5.2" From 15735f2749392bb8382b473b028488739386f04b Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Thu, 19 Sep 2019 16:30:11 -0700 Subject: [PATCH 06/14] tests --- .../azure-identity/tests/test_aad_client.py | 95 ++++++++++++++++++ .../tests/test_aad_client_async.py | 39 ++++++++ .../azure-identity/tests/test_auth_code.py | 50 ++++++++++ .../tests/test_auth_code_async.py | 97 +++++++++++++++++++ 4 files changed, 281 insertions(+) create mode 100644 sdk/identity/azure-identity/tests/test_aad_client.py create mode 100644 sdk/identity/azure-identity/tests/test_aad_client_async.py create mode 100644 sdk/identity/azure-identity/tests/test_auth_code.py create mode 100644 sdk/identity/azure-identity/tests/test_auth_code_async.py diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py new file mode 100644 index 000000000000..50004d132484 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -0,0 +1,95 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import functools + +from azure.core.exceptions import ClientAuthenticationError +from azure.identity._internal.aad_client import AadClient +import pytest + +try: + from unittest.mock import Mock +except ImportError: # python < 3.3 + from mock import Mock # type: ignore + + +class MockClient(AadClient): + def __init__(self, *args, **kwargs): + self.session = kwargs.pop("session") + super(MockClient, self).__init__(*args, **kwargs) + + def _get_client_session(self, **kwargs): + return self.session + + +def test_uses_msal_correctly(): + session = Mock() + transport = Mock() + session.get = session.post = transport + + client = MockClient("client id", "tenant id", session=session) + + # MSAL will raise on each call because the mock transport returns nothing useful. + # That's okay because we only want to verify the transport was called, i.e. that + # the client used the MSAL API correctly, such that MSAL tried to send a request. + with pytest.raises(ClientAuthenticationError): + client.obtain_token_by_authorization_code("code", "redirect uri", "scope") + assert transport.call_count == 1 + + transport.reset_mock() + + with pytest.raises(ClientAuthenticationError): + client.obtain_token_by_refresh_token("refresh token", "scope") + assert transport.call_count == 1 + + +def test_error_reporting(): + error_name = "everything's sideways" + error_description = "something went wrong" + error_response = {"error": error_name, "error_description": error_description} + + response = Mock(status_code=403, json=lambda: error_response) + transport = Mock(return_value=response) + session = Mock(get=transport, post=transport) + client = MockClient("client id", "tenant id", session=session) + + fns = [ + functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), + functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), + ] + + # exceptions raised for AAD errors should contain AAD's error description + for fn in fns: + with pytest.raises(ClientAuthenticationError) as ex: + fn() + message = str(ex.value) + assert error_name in message and error_description in message + + +def test_exceptions_do_not_expose_secrets(): + secret = "secret" + body = {"error": "bad thing", "access_token": secret, "refresh_token": secret} + response = Mock(status_code=403, json=lambda: body) + transport = Mock(return_value=response) + session = Mock(get=transport, post=transport) + client = MockClient("client id", "tenant id", session=session) + + fns = [ + functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), + functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), + ] + + def assert_secrets_not_exposed(): + for fn in fns: + with pytest.raises(ClientAuthenticationError) as ex: + fn() + assert secret not in str(ex.value) + assert secret not in repr(ex.value) + + # AAD errors shouldn't provoke exceptions exposing secrets + assert_secrets_not_exposed() + + # neither should unexpected AAD responses + del body["error"] + assert_secrets_not_exposed() diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py new file mode 100644 index 000000000000..515d24957162 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -0,0 +1,39 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.identity.aio._internal.aad_client import AadClient +import pytest + +from unittest.mock import Mock + + +class MockClient(AadClient): + def __init__(self, *args, **kwargs): + self.session = kwargs.pop("session") + super(MockClient, self).__init__(*args, **kwargs) + + def _get_client_session(self, **kwargs): + return self.session + + +@pytest.mark.asyncio +async def test_uses_msal_correctly(): + session = Mock() + transport = Mock() + session.get = session.post = transport + + client = MockClient("client id", "tenant id", session=session) + + # MSAL will raise on each call because the mock transport returns nothing useful. + # That's okay because we only want to verify the transport was called, i.e. that + # the client used the MSAL API correctly, such that MSAL tried to send a request. + with pytest.raises(Exception): + await client.obtain_token_by_authorization_code("code", "redirect uri", "scope") + assert transport.call_count == 1 + + transport.reset_mock() + + with pytest.raises(Exception): + await client.obtain_token_by_refresh_token("refresh token", "scope") + assert transport.call_count == 1 diff --git a/sdk/identity/azure-identity/tests/test_auth_code.py b/sdk/identity/azure-identity/tests/test_auth_code.py new file mode 100644 index 000000000000..5300368ef334 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_auth_code.py @@ -0,0 +1,50 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.core.credentials import AccessToken +from azure.identity import AuthorizationCodeCredential + +try: + from unittest.mock import Mock +except ImportError: # python < 3.3 + from mock import Mock # type: ignore + + +def test_auth_code_credential(): + client_id = "client id" + tenant_id = "tenant" + expected_code = "auth code" + redirect_uri = "https://foo.bar" + expected_token = AccessToken("token", 42) + + mock_client = Mock(spec=object) + mock_client.obtain_token_by_authorization_code = Mock(return_value=expected_token) + + credential = AuthorizationCodeCredential( + client_id=client_id, + tenant_id=tenant_id, + authorization_code=expected_code, + redirect_uri=redirect_uri, + client=mock_client, + ) + + # first call should redeem the auth code + token = credential.get_token("scope") + assert token is expected_token + assert mock_client.obtain_token_by_authorization_code.call_count == 1 + _, kwargs = mock_client.obtain_token_by_authorization_code.call_args + assert kwargs["code"] == expected_code + + # no auth code -> credential should return cached token + mock_client.obtain_token_by_authorization_code = None # raise if credential calls this again + mock_client.get_cached_access_token = lambda *_: expected_token + token = credential.get_token("scope") + assert token is expected_token + + # no auth code, no cached token -> credential should use refresh token + mock_client.get_cached_access_token = lambda *_: None + mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"] + mock_client.obtain_token_by_refresh_token = lambda *_, **__: expected_token + token = credential.get_token("scope") + assert token is expected_token diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py new file mode 100644 index 000000000000..53bd5266485c --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -0,0 +1,97 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import asyncio +from unittest.mock import Mock + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError +from azure.identity.aio import AuthorizationCodeCredential +import pytest + + +@pytest.mark.asyncio +async def test_auth_code_credential(): + client_id = "client id" + tenant_id = "tenant" + expected_code = "auth code" + redirect_uri = "https://foo.bar" + expected_token = AccessToken("token", 42) + + mock_client = Mock(spec=object) + obtain_by_auth_code = Mock(return_value=expected_token) + mock_client.obtain_token_by_authorization_code = asyncio.coroutine(obtain_by_auth_code) + + credential = AuthorizationCodeCredential( + client_id=client_id, + tenant_id=tenant_id, + authorization_code=expected_code, + redirect_uri=redirect_uri, + client=mock_client, + ) + + # first call should redeem the auth code + token = await credential.get_token("scope") + assert token is expected_token + assert obtain_by_auth_code.call_count == 1 + _, kwargs = obtain_by_auth_code.call_args + assert kwargs["code"] == expected_code + + # no auth code -> credential should return cached token + mock_client.obtain_token_by_authorization_code = None # raise if credential calls this again + mock_client.get_cached_access_token = lambda *_: expected_token + token = await credential.get_token("scope") + assert token is expected_token + + # no auth code, no cached token -> credential should use refresh token + mock_client.get_cached_access_token = lambda *_: None + mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"] + mock_client.obtain_token_by_refresh_token = asyncio.coroutine(lambda *_, **__: expected_token) + token = await credential.get_token("scope") + assert token is expected_token + + +@pytest.mark.asyncio +async def test_custom_executor_used(): + credential = AuthorizationCodeCredential( + client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" + ) + + executor = Mock() + + with pytest.raises(ClientAuthenticationError): + await credential.get_token("scope", executor=executor) + + assert executor.submit.call_count == 1 + + +@pytest.mark.asyncio +async def test_custom_loop_used(): + credential = AuthorizationCodeCredential( + client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" + ) + + loop = Mock() + + with pytest.raises(ClientAuthenticationError): + await credential.get_token("scope", loop=loop) + + assert loop.run_in_executor.call_count == 1 + + +@pytest.mark.asyncio +async def test_custom_loop_and_executor_used(): + credential = AuthorizationCodeCredential( + client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" + ) + + executor = Mock() + loop = Mock() + + with pytest.raises(ClientAuthenticationError): + await credential.get_token("scope", executor=executor, loop=loop) + + assert loop.run_in_executor.call_count == 1 + executor_arg, _ = loop.run_in_executor.call_args[0] + assert executor_arg is executor From 6370135e1358e9650ba3d8c1315f1a1dd4a53129 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 20 Sep 2019 14:16:02 -0700 Subject: [PATCH 07/14] update HISTORY --- sdk/identity/azure-identity/HISTORY.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sdk/identity/azure-identity/HISTORY.md b/sdk/identity/azure-identity/HISTORY.md index 80e0ff93b793..dd5207477ab5 100644 --- a/sdk/identity/azure-identity/HISTORY.md +++ b/sdk/identity/azure-identity/HISTORY.md @@ -1,6 +1,12 @@ # Release History ## 1.0.0b4 +### New features: +- `AuthorizationCodeCredential` authenticates with a previously obtained +authorization code. See Azure Active Directory's +[authorization code documentation](https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow) +for more information about this authentication flow. + ### Fixes and improvements: - `UsernamePasswordCredential` correctly handles environment configuration with no tenant information (#7260) From afb7d716c0a4dc27e0122257a1208ee0d6f24b9f Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Fri, 20 Sep 2019 15:25:31 -0700 Subject: [PATCH 08/14] factor out secret scrubbing --- .../azure/identity/_internal/aad_client_base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index a894a9cfcd5c..6111b03965db 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -79,9 +79,7 @@ def _process_response(self, response, scopes, now): elif "expires_in" in response: expires_on = now + int(response["expires_in"]) else: - for secret in ("access_token", "refresh_token"): - if secret in response: - response[secret] = "***" + _scrub_secrets(response) raise ClientAuthenticationError( message="Unexpected response from Azure Active Directory: {}".format(response) ) @@ -97,16 +95,20 @@ def _obtain_token(self, scopes, fn, **kwargs): pass +def _scrub_secrets(response): + for secret in ("access_token", "refresh_token"): + if secret in response: + response[secret] = "***" + + def _raise_for_error(response): # type: (dict) -> None if "error" not in response: return + _scrub_secrets(response) if "error_description" in response: message = "Azure Active Directory error '({}) {}'".format(response["error"], response["error_description"]) else: - for secret in ("access_token", "refresh_token"): - if secret in response: - response[secret] = "***" message = "Azure Active Directory error '{}'".format(response) raise ClientAuthenticationError(message=message) From 0d301ad8f81403d2eac5a4ded831c80d8d74a32a Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 23 Sep 2019 10:14:53 -0700 Subject: [PATCH 09/14] configurable authority --- .../azure-identity/azure/identity/__init__.py | 7 ++++- .../azure/identity/_constants.py | 10 +++++-- .../_credentials/authorization_code.py | 5 ++++ .../identity/_internal/aad_client_base.py | 11 +++++-- .../aio/_credentials/authorization_code.py | 5 ++++ .../azure-identity/tests/test_aad_client.py | 26 ++++++++++++++++ .../tests/test_aad_client_async.py | 30 ++++++++++++++++++- 7 files changed, 87 insertions(+), 7 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 4cbec9a5867f..60d2ba4af91c 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -2,14 +2,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from ._constants import EnvironmentVariables, KnownAuthorities from ._credentials import ( - InteractiveBrowserCredential, + AuthorizationCodeCredential, + CertificateCredential, ChainedTokenCredential, ClientSecretCredential, DefaultAzureCredential, DeviceCodeCredential, EnvironmentCredential, + InteractiveBrowserCredential, ManagedIdentityCredential, SharedTokenCacheCredential, UsernamePasswordCredential, @@ -24,7 +27,9 @@ "DefaultAzureCredential", "DeviceCodeCredential", "EnvironmentCredential", + "EnvironmentVariables", "InteractiveBrowserCredential", + "KnownAuthorities", "ManagedIdentityCredential", "SharedTokenCacheCredential", "UsernamePasswordCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index f497282c826e..ffb0ed644b58 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -7,6 +7,13 @@ AZURE_CLI_CLIENT_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" +class KnownAuthorities: + AZURE_CHINA = "login.chinacloudapi.cn" + AZURE_GERMANY = "login.microsoftonline.de" + AZURE_GOVERNMENT = "login.microsoftonline.us" + AZURE_PUBLIC_CLOUD = "login.microsoftonline.com" + + class EnvironmentVariables: AZURE_CLIENT_ID = "AZURE_CLIENT_ID" AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET" @@ -28,5 +35,4 @@ class Endpoints: # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http IMDS = "http://169.254.169.254/metadata/identity/oauth2/token" - # TODO: other clouds have other endpoints - AAD_OAUTH2_V2_FORMAT = "https://login.microsoftonline.com/{}/oauth2/v2.0/token" + AAD_OAUTH2_V2_FORMAT = "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD + "/{}/oauth2/v2.0/token" diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index 2827b573cc6c..82a23634b2fb 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -24,6 +24,11 @@ class AuthorizationCodeCredential(object): :param str authorization_code: the authorization code from the user's log-in :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. :param str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. """ def __init__(self, client_id, tenant_id, authorization_code, redirect_uri, client_secret=None, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 6111b03965db..67deda4bbe0b 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -16,7 +16,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError -from .._constants import Endpoints +from .._constants import KnownAuthorities try: ABC = abc.ABC @@ -31,9 +31,14 @@ class AadClientBase(ABC): """Sans I/O methods for AAD clients wrapping MSAL's OAuth client""" - def __init__(self, client_id, tenant_id, **kwargs): + def __init__(self, client_id, tenant, **kwargs): # type: (str, str, **Any) -> None - config = {"token_endpoint": Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id)} + authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD) + if authority[-1] == "/": + authority = authority[:-1] + token_endpoint = "https://" + "/".join((authority, tenant, "oauth2/v2.0/token")) + config = {"token_endpoint": token_endpoint} + self._client = Client(server_configuration=config, client_id=client_id) self._client.session.close() self._client.session = self._get_client_session(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index 21a95c51e67d..0ef0f0d0d969 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -25,6 +25,11 @@ class AuthorizationCodeCredential(object): :param str authorization_code: the authorization code from the user's log-in :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. :param str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. """ def __init__(self, client_id, tenant_id, authorization_code, redirect_uri, client_secret=None, **kwargs): diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 50004d132484..90162af00265 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -93,3 +93,29 @@ def assert_secrets_not_exposed(): # neither should unexpected AAD responses del body["error"] assert_secrets_not_exposed() + + +def test_respects_authority(): + my_authority = "my.authority.com" + + class Authority: + respected = False + + def check_url(url, **kwargs): + Authority.respected = url.startswith("https://" + my_authority) + + transport = Mock(side_effect=check_url) + session = Mock(get=transport, post=transport) + client = MockClient("client id", "tenant id", session=session, authority=my_authority) + + fns = [ + functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), + functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), + ] + + for fn in fns: + Authority.respected = False + with pytest.raises(ClientAuthenticationError): + # raises because the mock transport returns nothing + fn() + assert Authority.respected diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 515d24957162..414b678c6b56 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -2,10 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import functools +from unittest.mock import Mock + +from azure.core.exceptions import ClientAuthenticationError from azure.identity.aio._internal.aad_client import AadClient import pytest -from unittest.mock import Mock class MockClient(AadClient): @@ -37,3 +40,28 @@ async def test_uses_msal_correctly(): with pytest.raises(Exception): await client.obtain_token_by_refresh_token("refresh token", "scope") assert transport.call_count == 1 + + +@pytest.mark.asyncio +async def test_respects_authority(): + my_authority = "my.authority.com" + + def check_url(url, **kwargs): + nonlocal authority_respected + authority_respected = url.startswith("https://" + my_authority) + + transport = Mock(side_effect=check_url) + session = Mock(get=transport, post=transport) + client = MockClient("client id", "tenant id", session=session, authority=my_authority) + + coros = [ + client.obtain_token_by_authorization_code("code", "uri", "scope"), + client.obtain_token_by_refresh_token({"secret": "refresh token"}, "scope"), + ] + + for coro in coros: + authority_respected = False + with pytest.raises(ClientAuthenticationError): + # raises because the mock transport returns nothing + await coro + assert authority_respected From 40961ce961e1f5d1dc7a08b83c99b4fb1c28dcdd Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 23 Sep 2019 10:45:56 -0700 Subject: [PATCH 10/14] paint it black --- .../azure-identity/azure/identity/aio/_internal/aad_client.py | 4 +--- sdk/identity/azure-identity/tests/test_aad_client_async.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 17e997db834e..444dade7d13f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -23,9 +23,7 @@ def _get_client_session(self, **kwargs): return MsalTransportAdapter(**kwargs) @wrap_exceptions - async def _obtain_token( - self, scopes: "Iterable[str]", fn: "Callable", **kwargs: "Any" - ) -> "AccessToken": + async def _obtain_token(self, scopes: "Iterable[str]", fn: "Callable", **kwargs: "Any") -> "AccessToken": now = int(time.time()) executor = kwargs.get("executor", None) loop = kwargs.get("loop", None) or asyncio.get_event_loop() diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 414b678c6b56..8648553169ee 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -10,7 +10,6 @@ import pytest - class MockClient(AadClient): def __init__(self, *args, **kwargs): self.session = kwargs.pop("session") From 7a6249852dd93899ad56058a4e1a492fe500a787 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 12:48:09 -0700 Subject: [PATCH 11/14] test URLs more thoroughly --- .../azure-identity/tests/test_aad_client.py | 27 +++++++++++-------- .../tests/test_aad_client_async.py | 19 ++++++------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 90162af00265..e0cd5d9b93e1 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -7,6 +7,9 @@ from azure.core.exceptions import ClientAuthenticationError from azure.identity._internal.aad_client import AadClient import pytest +from six.moves.urllib_parse import urlparse + +from helpers import mock_response try: from unittest.mock import Mock @@ -95,18 +98,19 @@ def assert_secrets_not_exposed(): assert_secrets_not_exposed() -def test_respects_authority(): - my_authority = "my.authority.com" - - class Authority: - respected = False +def test_request_url(): + authority = "authority.com" + tenant = "expected_tenant" - def check_url(url, **kwargs): - Authority.respected = url.startswith("https://" + my_authority) + def validate_url(url, **kwargs): + scheme, netloc, path, _, _, _ = urlparse(url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant) - transport = Mock(side_effect=check_url) + transport = Mock(side_effect=validate_url) session = Mock(get=transport, post=transport) - client = MockClient("client id", "tenant id", session=session, authority=my_authority) + client = MockClient("client id", tenant, session=session, authority=authority) fns = [ functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), @@ -114,8 +118,9 @@ def check_url(url, **kwargs): ] for fn in fns: - Authority.respected = False with pytest.raises(ClientAuthenticationError): # raises because the mock transport returns nothing fn() - assert Authority.respected + +if __name__ == "__main__": + test_request_url() \ No newline at end of file diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 8648553169ee..c3bfed4c4a66 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -42,16 +42,19 @@ async def test_uses_msal_correctly(): @pytest.mark.asyncio -async def test_respects_authority(): - my_authority = "my.authority.com" +async def test_request_url(): + authority = "authority.com" + tenant = "expected_tenant" - def check_url(url, **kwargs): - nonlocal authority_respected - authority_respected = url.startswith("https://" + my_authority) + def validate_url(url, **kwargs): + scheme, netloc, path, _, _, _ = urlparse(url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant) - transport = Mock(side_effect=check_url) + transport = Mock(side_effect=validate_url) session = Mock(get=transport, post=transport) - client = MockClient("client id", "tenant id", session=session, authority=my_authority) + client = MockClient("client id", "tenant id", session=session, authority=authority) coros = [ client.obtain_token_by_authorization_code("code", "uri", "scope"), @@ -59,8 +62,6 @@ def check_url(url, **kwargs): ] for coro in coros: - authority_respected = False with pytest.raises(ClientAuthenticationError): # raises because the mock transport returns nothing await coro - assert authority_respected From 4cea49f88dbedc82ca4e869b1af4b1d06a872985 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 15:54:24 -0700 Subject: [PATCH 12/14] obtain_token_by_refresh_token must pass kwargs --- .../azure-identity/azure/identity/_internal/aad_client_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 67deda4bbe0b..d289cb3e58b3 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -73,7 +73,7 @@ def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs): rt_getter=lambda token: token["secret"], **kwargs ) - return self._obtain_token(scopes, fn) + return self._obtain_token(scopes, fn, **kwargs) def _process_response(self, response, scopes, now): # type: (dict, Iterable[str], int) -> AccessToken From 600c61f765a037a99f0351b81d03e34d7aaffceb Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 16:24:27 -0700 Subject: [PATCH 13/14] ensure loop is passed in kwargs --- .../identity/aio/_internal/aad_client.py | 20 ++++++++++++++++--- .../aio/_internal/msal_transport_adapter.py | 4 ++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 444dade7d13f..853af81ec319 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -19,13 +19,27 @@ class AadClient(AadClientBase): + # pylint:disable=arguments-differ + + def obtain_token_by_authorization_code( + self, *args: "Any", loop: "asyncio.AbstractEventLoop" = None, **kwargs: "Any" + ) -> "AccessToken": + # 'loop' will reach the transport adapter as a kwarg, so here we ensure it's passed + loop = loop or asyncio.get_event_loop() + return super().obtain_token_by_authorization_code(*args, loop=loop, **kwargs) + + def obtain_token_by_refresh_token(self, *args, loop: "asyncio.AbstractEventLoop" = None, **kwargs) -> "AccessToken": + # 'loop' will reach the transport adapter as a kwarg, so here we ensure it's passed + loop = loop or asyncio.get_event_loop() + return super().obtain_token_by_refresh_token(*args, loop=loop, **kwargs) + def _get_client_session(self, **kwargs): return MsalTransportAdapter(**kwargs) @wrap_exceptions - async def _obtain_token(self, scopes: "Iterable[str]", fn: "Callable", **kwargs: "Any") -> "AccessToken": + async def _obtain_token( + self, scopes: "Iterable[str]", fn: "Callable", loop: "asyncio.AbstractEventLoop", executor=None, **kwargs: "Any" + ) -> "AccessToken": now = int(time.time()) - executor = kwargs.get("executor", None) - loop = kwargs.get("loop", None) or asyncio.get_event_loop() response = await loop.run_in_executor(executor, fn) return self._process_response(response=response, scopes=scopes, now=now) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py index e85baca7d3c9..a1be9df87c77 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py @@ -55,6 +55,7 @@ def _close_transport_session(self) -> None: def get( self, url: str, + loop: "asyncio.AbstractEventLoop", headers: "Optional[Dict[str, str]]" = None, params: "Optional[Dict[str, str]]" = None, timeout: "Optional[float]" = None, @@ -66,7 +67,6 @@ def get( if params: request.format_parameters(params) - loop = kwargs.pop("loop", None) future = asyncio.run_coroutine_threadsafe( # type: ignore self._pipeline.run(request, connection_timeout=timeout, connection_verify=verify, **kwargs), loop ) @@ -77,6 +77,7 @@ def get( def post( self, url: str, + loop: "asyncio.AbstractEventLoop", data: "Any" = None, headers: "Optional[Dict[str, str]]" = None, params: "Optional[Dict[str, str]]" = None, @@ -92,7 +93,6 @@ def post( request.headers["Content-Type"] = "application/x-www-form-urlencoded" request.set_formdata_body(data) - loop = kwargs.pop("loop", None) future = asyncio.run_coroutine_threadsafe( # type: ignore self._pipeline.run(request, connection_timeout=timeout, connection_verify=verify, **kwargs), loop ) From f7c50bc9df50feb028fe2b2b385051e69e09d6fd Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 30 Sep 2019 16:25:22 -0700 Subject: [PATCH 14/14] return a passable mock response --- .../azure-identity/tests/test_aad_client.py | 23 ++++----------- .../tests/test_aad_client_async.py | 28 +++++++------------ 2 files changed, 16 insertions(+), 35 deletions(-) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index e0cd5d9b93e1..1c067ce9d1eb 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -102,25 +102,14 @@ def test_request_url(): authority = "authority.com" tenant = "expected_tenant" - def validate_url(url, **kwargs): - scheme, netloc, path, _, _, _ = urlparse(url) + def send(request, **_): + scheme, netloc, path, _, _, _ = urlparse(request.url) assert scheme == "https" assert netloc == authority assert path.startswith("/" + tenant) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - transport = Mock(side_effect=validate_url) - session = Mock(get=transport, post=transport) - client = MockClient("client id", tenant, session=session, authority=authority) - - fns = [ - functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), - functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), - ] - - for fn in fns: - with pytest.raises(ClientAuthenticationError): - # raises because the mock transport returns nothing - fn() + client = AadClient("client id", tenant, transport=Mock(send=send), authority=authority) -if __name__ == "__main__": - test_request_url() \ No newline at end of file + client.obtain_token_by_authorization_code("code", "uri", "scope") + client.obtain_token_by_refresh_token("refresh token", "scope") diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index c3bfed4c4a66..d4101e841b9a 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -2,13 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -import functools from unittest.mock import Mock +from urllib.parse import urlparse -from azure.core.exceptions import ClientAuthenticationError from azure.identity.aio._internal.aad_client import AadClient import pytest +from helpers import mock_response + class MockClient(AadClient): def __init__(self, *args, **kwargs): @@ -21,9 +22,8 @@ def _get_client_session(self, **kwargs): @pytest.mark.asyncio async def test_uses_msal_correctly(): - session = Mock() transport = Mock() - session.get = session.post = transport + session = Mock(get=transport, post=transport) client = MockClient("client id", "tenant id", session=session) @@ -46,22 +46,14 @@ async def test_request_url(): authority = "authority.com" tenant = "expected_tenant" - def validate_url(url, **kwargs): - scheme, netloc, path, _, _, _ = urlparse(url) + async def send(request, **_): + scheme, netloc, path, _, _, _ = urlparse(request.url) assert scheme == "https" assert netloc == authority assert path.startswith("/" + tenant) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - transport = Mock(side_effect=validate_url) - session = Mock(get=transport, post=transport) - client = MockClient("client id", "tenant id", session=session, authority=authority) - - coros = [ - client.obtain_token_by_authorization_code("code", "uri", "scope"), - client.obtain_token_by_refresh_token({"secret": "refresh token"}, "scope"), - ] + client = AadClient("client id", tenant, transport=Mock(send=send), authority=authority) - for coro in coros: - with pytest.raises(ClientAuthenticationError): - # raises because the mock transport returns nothing - await coro + await client.obtain_token_by_authorization_code("code", "uri", "scope") + await client.obtain_token_by_refresh_token("refresh token", "scope")