diff --git a/sdk/identity/azure-identity/HISTORY.md b/sdk/identity/azure-identity/HISTORY.md index 80e0ff93b793..dd5207477ab5 100644 --- a/sdk/identity/azure-identity/HISTORY.md +++ b/sdk/identity/azure-identity/HISTORY.md @@ -1,6 +1,12 @@ # Release History ## 1.0.0b4 +### New features: +- `AuthorizationCodeCredential` authenticates with a previously obtained +authorization code. See Azure Active Directory's +[authorization code documentation](https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow) +for more information about this authentication flow. + ### Fixes and improvements: - `UsernamePasswordCredential` correctly handles environment configuration with no tenant information (#7260) diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 3e586121bfb9..60d2ba4af91c 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -2,14 +2,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from ._constants import EnvironmentVariables, KnownAuthorities from ._credentials import ( - InteractiveBrowserCredential, + AuthorizationCodeCredential, + CertificateCredential, ChainedTokenCredential, ClientSecretCredential, DefaultAzureCredential, DeviceCodeCredential, EnvironmentCredential, + InteractiveBrowserCredential, ManagedIdentityCredential, SharedTokenCacheCredential, UsernamePasswordCredential, @@ -17,13 +20,16 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ChainedTokenCredential", "ClientSecretCredential", "DefaultAzureCredential", "DeviceCodeCredential", "EnvironmentCredential", + "EnvironmentVariables", "InteractiveBrowserCredential", + "KnownAuthorities", "ManagedIdentityCredential", "SharedTokenCacheCredential", "UsernamePasswordCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index f497282c826e..ffb0ed644b58 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -7,6 +7,13 @@ AZURE_CLI_CLIENT_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" +class KnownAuthorities: + AZURE_CHINA = "login.chinacloudapi.cn" + AZURE_GERMANY = "login.microsoftonline.de" + AZURE_GOVERNMENT = "login.microsoftonline.us" + AZURE_PUBLIC_CLOUD = "login.microsoftonline.com" + + class EnvironmentVariables: AZURE_CLIENT_ID = "AZURE_CLIENT_ID" AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET" @@ -28,5 +35,4 @@ class Endpoints: # https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http IMDS = "http://169.254.169.254/metadata/identity/oauth2/token" - # TODO: other clouds have other endpoints - AAD_OAUTH2_V2_FORMAT = "https://login.microsoftonline.com/{}/oauth2/v2.0/token" + AAD_OAUTH2_V2_FORMAT = "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD + "/{}/oauth2/v2.0/token" diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py b/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py index 45f6450fe18d..28e0969c8cd4 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .authorization_code import AuthorizationCodeCredential from .browser import InteractiveBrowserCredential from .chained import ChainedTokenCredential from .client_credential import CertificateCredential, ClientSecretCredential @@ -12,6 +13,7 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ChainedTokenCredential", "ClientSecretCredential", diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py new file mode 100644 index 000000000000..82a23634b2fb --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -0,0 +1,75 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import TYPE_CHECKING + +from azure.core.exceptions import ClientAuthenticationError +from .._internal.aad_client import AadClient + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Iterable, Optional + from azure.core.credentials import AccessToken + + +class AuthorizationCodeCredential(object): + """ + Authenticates by redeeming an authorization code previously obtained from Azure Active Directory. + See https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow for more information + about the authentication flow. + + :param str client_id: the application's client ID + :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its 'directory' ID. + :param str authorization_code: the authorization code from the user's log-in + :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. + :param str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. + """ + + def __init__(self, client_id, tenant_id, authorization_code, redirect_uri, client_secret=None, **kwargs): + # type: (str, str, str, str, Optional[str], **Any) -> None + self._authorization_code = authorization_code # type: Optional[str] + self._client_id = client_id + self._client_secret = client_secret + self._client = kwargs.pop("client", None) or AadClient(client_id, tenant_id, **kwargs) + self._redirect_uri = redirect_uri + + def get_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + """ + Request an access token for ``scopes``. The first time this method is called, the credential will redeem its + authorization code. On subsequent calls the credential will return a cached access token or redeem a refresh + token, if it acquired a refresh token upon redeeming the authorization code. + + :param str scopes: desired scopes for the access token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + """ + + if self._authorization_code: + token = self._client.obtain_token_by_authorization_code( + code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, **kwargs + ) + self._authorization_code = None # auth codes are single-use + return token + + token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs) + if not token: + raise ClientAuthenticationError( + message="No authorization code, cached access token, or refresh token available." + ) + + return token + + def _redeem_refresh_token(self, scopes, **kwargs): + # type: (Iterable[str], **Any) -> Optional[AccessToken] + for refresh_token in self._client.get_cached_refresh_tokens(scopes): + token = self._client.obtain_token_by_refresh_token(refresh_token, scopes, **kwargs) + if token: + return token + return None diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index 6108ffdaa895..b6234d1eee67 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -2,12 +2,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .aad_client import AadClient +from .aad_client_base import AadClientBase from .auth_code_redirect_handler import AuthCodeRedirectServer from .exception_wrapper import wrap_exceptions from .msal_credentials import ConfidentialClientCredential, PublicClientCredential from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse __all__ = [ + "AadClient", + "AadClientBase", "AuthCodeRedirectServer", "ConfidentialClientCredential", "MsalTransportAdapter", diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py new file mode 100644 index 000000000000..aa25b4a4b1a2 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -0,0 +1,30 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""A thin wrapper around MSAL's token cache and OAuth 2 client""" + +import time +from typing import TYPE_CHECKING + +from azure.core.credentials import AccessToken + +from .aad_client_base import AadClientBase +from .msal_transport_adapter import MsalTransportAdapter +from .exception_wrapper import wrap_exceptions + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Callable, Iterable + + +class AadClient(AadClientBase): + def _get_client_session(self, **kwargs): + return MsalTransportAdapter(**kwargs) + + @wrap_exceptions + def _obtain_token(self, scopes, fn, **kwargs): # pylint:disable=unused-argument + # type: (Iterable[str], Callable, **Any) -> AccessToken + now = int(time.time()) + response = fn() + return self._process_response(response=response, scopes=scopes, now=now) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py new file mode 100644 index 000000000000..d289cb3e58b3 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -0,0 +1,119 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import abc +import functools +import time + +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +from msal import TokenCache +from msal.oauth2cli.oauth2 import Client + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError +from .._constants import KnownAuthorities + +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Callable, Iterable, Optional + + +class AadClientBase(ABC): + """Sans I/O methods for AAD clients wrapping MSAL's OAuth client""" + + def __init__(self, client_id, tenant, **kwargs): + # type: (str, str, **Any) -> None + authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD) + if authority[-1] == "/": + authority = authority[:-1] + token_endpoint = "https://" + "/".join((authority, tenant, "oauth2/v2.0/token")) + config = {"token_endpoint": token_endpoint} + + self._client = Client(server_configuration=config, client_id=client_id) + self._client.session.close() + self._client.session = self._get_client_session(**kwargs) + self._cache = TokenCache() + + def get_cached_access_token(self, scopes): + # type: (Iterable[str]) -> Optional[AccessToken] + tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes)) + for token in tokens: + expires_on = int(token["expires_on"]) + if expires_on - 300 > int(time.time()): + return AccessToken(token["secret"], expires_on) + return None + + def get_cached_refresh_tokens(self, scopes): + """Assumes all cached refresh tokens belong to the same user""" + return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes)) + + def obtain_token_by_authorization_code(self, code, redirect_uri, scopes, **kwargs): + # type: (str, str, Iterable[str], **Any) -> AccessToken + fn = functools.partial( + self._client.obtain_token_by_authorization_code, code=code, redirect_uri=redirect_uri, **kwargs + ) + return self._obtain_token(scopes, fn, **kwargs) + + def obtain_token_by_refresh_token(self, refresh_token, scopes, **kwargs): + # type: (str, Iterable[str], **Any) -> AccessToken + fn = functools.partial( + self._client.obtain_token_by_refresh_token, + token_item=refresh_token, + scope=scopes, + rt_getter=lambda token: token["secret"], + **kwargs + ) + return self._obtain_token(scopes, fn, **kwargs) + + def _process_response(self, response, scopes, now): + # type: (dict, Iterable[str], int) -> AccessToken + _raise_for_error(response) + self._cache.add(event={"response": response, "scope": scopes}, now=now) + if "expires_on" in response: + expires_on = int(response["expires_on"]) + elif "expires_in" in response: + expires_on = now + int(response["expires_in"]) + else: + _scrub_secrets(response) + raise ClientAuthenticationError( + message="Unexpected response from Azure Active Directory: {}".format(response) + ) + return AccessToken(response["access_token"], expires_on) + + @abc.abstractmethod + def _get_client_session(self, **kwargs): + pass + + @abc.abstractmethod + def _obtain_token(self, scopes, fn, **kwargs): + # type: (Iterable[str], Callable, **Any) -> AccessToken + pass + + +def _scrub_secrets(response): + for secret in ("access_token", "refresh_token"): + if secret in response: + response[secret] = "***" + + +def _raise_for_error(response): + # type: (dict) -> None + if "error" not in response: + return + + _scrub_secrets(response) + if "error_description" in response: + message = "Azure Active Directory error '({}) {}'".format(response["error"], response["error_description"]) + else: + message = "Azure Active Directory error '{}'".format(response) + raise ClientAuthenticationError(message=message) diff --git a/sdk/identity/azure-identity/azure/identity/aio/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/__init__.py index 57a52755df27..8a63109e49ff 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/__init__.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ from ._credentials import ( + AuthorizationCodeCredential, CertificateCredential, ChainedTokenCredential, ClientSecretCredential, @@ -14,6 +15,7 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ClientSecretCredential", "DefaultAzureCredential", diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py index 55ed64d00b1e..e2bb64f0fa4e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .authorization_code import AuthorizationCodeCredential from .chained import ChainedTokenCredential from .default import DefaultAzureCredential from .environment import EnvironmentCredential @@ -11,6 +12,7 @@ __all__ = [ + "AuthorizationCodeCredential", "CertificateCredential", "ChainedTokenCredential", "ClientSecretCredential", diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py new file mode 100644 index 000000000000..0ef0f0d0d969 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -0,0 +1,84 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import asyncio +from typing import TYPE_CHECKING + +from azure.core.exceptions import ClientAuthenticationError +from .._internal import AadClient + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Collection, Optional + from azure.core.credentials import AccessToken + + +class AuthorizationCodeCredential(object): + """ + Authenticates by redeeming an authorization code previously obtained from Azure Active Directory. + See https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow for more information + about the authentication flow. + + :param str client_id: the application's client ID + :param str tenant_id: ID of the application's Azure Active Directory tenant. Also called its 'directory' ID. + :param str authorization_code: the authorization code from the user's log-in + :param str redirect_uri: The application's redirect URI. Must match the URI used to request the authorization code. + :param str client_secret: One of the application's client secrets. Required only for web apps and web APIs. + + Keyword arguments + - **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the + authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines + authorities for other clouds. + """ + + def __init__(self, client_id, tenant_id, authorization_code, redirect_uri, client_secret=None, **kwargs): + # type: (str, str, str, str, Optional[str], **Any) -> None + self._authorization_code = authorization_code # type: Optional[str] + self._client_id = client_id + self._client_secret = client_secret + self._client = kwargs.pop("client", None) or AadClient(client_id, tenant_id, **kwargs) + self._redirect_uri = redirect_uri + + async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + """ + Request an access token for ``scopes``. The first time this method is called, the credential will redeem its + authorization code. On subsequent calls the credential will return a cached access token or redeem a refresh + token, if it acquired a refresh token upon redeeming the authorization code. + + :param str scopes: desired scopes for the access token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: :class:`azure.core.exceptions.ClientAuthenticationError` + + Keyword arguments: + - **executor**: (optional) a :class:`concurrent.futures.Executor` used to execute asynchronous calls + - **loop**: (optional) an event loop on which to schedule network I/O. If not provided, the currently + running loop will be used. + """ + + if self._authorization_code: + loop = kwargs.pop("loop", None) or asyncio.get_event_loop() + token = await self._client.obtain_token_by_authorization_code( + code=self._authorization_code, redirect_uri=self._redirect_uri, scopes=scopes, loop=loop, **kwargs + ) + self._authorization_code = None # auth codes are single-use + return token + + token = self._client.get_cached_access_token(scopes) + if not token: + token = await self._redeem_refresh_token(scopes, **kwargs) + + if not token: + raise ClientAuthenticationError( + message="No authorization code, cached access token, or refresh token available." + ) + + return token + + async def _redeem_refresh_token(self, scopes: "Collection[str]", **kwargs: "Any") -> "Optional[AccessToken]": + loop = kwargs.pop("loop", None) or asyncio.get_event_loop() + for refresh_token in self._client.get_cached_refresh_tokens(scopes): + token = await self._client.obtain_token_by_refresh_token(refresh_token, scopes, loop=loop, **kwargs) + if token: + return token + return None diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py index 7ca58029041f..097b56a96c85 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py @@ -2,6 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from .aad_client import AadClient from .exception_wrapper import wrap_exceptions +from .msal_transport_adapter import MsalTransportAdapter -__all__ = ["wrap_exceptions"] +__all__ = ["AadClient", "MsalTransportAdapter", "wrap_exceptions"] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py new file mode 100644 index 000000000000..853af81ec319 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -0,0 +1,45 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""A thin wrapper around MSAL's token cache and OAuth 2 client""" + +import asyncio +import time +from typing import TYPE_CHECKING + +from azure.identity._internal import AadClientBase +from .msal_transport_adapter import MsalTransportAdapter +from .exception_wrapper import wrap_exceptions + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any, Callable, Iterable + from azure.core.credentials import AccessToken + + +class AadClient(AadClientBase): + # pylint:disable=arguments-differ + + def obtain_token_by_authorization_code( + self, *args: "Any", loop: "asyncio.AbstractEventLoop" = None, **kwargs: "Any" + ) -> "AccessToken": + # 'loop' will reach the transport adapter as a kwarg, so here we ensure it's passed + loop = loop or asyncio.get_event_loop() + return super().obtain_token_by_authorization_code(*args, loop=loop, **kwargs) + + def obtain_token_by_refresh_token(self, *args, loop: "asyncio.AbstractEventLoop" = None, **kwargs) -> "AccessToken": + # 'loop' will reach the transport adapter as a kwarg, so here we ensure it's passed + loop = loop or asyncio.get_event_loop() + return super().obtain_token_by_refresh_token(*args, loop=loop, **kwargs) + + def _get_client_session(self, **kwargs): + return MsalTransportAdapter(**kwargs) + + @wrap_exceptions + async def _obtain_token( + self, scopes: "Iterable[str]", fn: "Callable", loop: "asyncio.AbstractEventLoop", executor=None, **kwargs: "Any" + ) -> "AccessToken": + now = int(time.time()) + response = await loop.run_in_executor(executor, fn) + return self._process_response(response=response, scopes=scopes, now=now) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py new file mode 100644 index 000000000000..a1be9df87c77 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/msal_transport_adapter.py @@ -0,0 +1,108 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Adapter to substitute an async azure-core pipeline for Requests in MSAL application token acquisition methods.""" + +import asyncio +import atexit +from typing import TYPE_CHECKING + +from azure.core.configuration import Configuration +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.policies import AsyncRetryPolicy, DistributedTracingPolicy, NetworkTraceLoggingPolicy +from azure.core.pipeline.transport import AioHttpTransport, HttpRequest + +from azure.identity._internal import MsalTransportResponse + +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Any, Dict, Iterable, Optional + from azure.core.pipeline.policies import AsyncHTTPPolicy + from azure.core.pipeline.transport import AsyncHttpTransport + + +class MsalTransportAdapter: + """Wraps an async azure-core pipeline with the shape of a (synchronous) Requests Session""" + + def __init__( + self, + config: "Optional[Configuration]" = None, + policies: "Optional[Iterable[AsyncHTTPPolicy]]" = None, + transport: "Optional[AsyncHttpTransport]" = None, + **kwargs: "Any" + ) -> None: + + config = config or self._create_config(**kwargs) + policies = policies or [config.retry_policy, config.logging_policy, DistributedTracingPolicy()] + self._transport = transport or AioHttpTransport(configuration=config) + atexit.register(self._close_transport_session) # prevent aiohttp warnings + self._pipeline = AsyncPipeline(transport=self._transport, policies=policies) + + def _close_transport_session(self) -> None: + """If transport has a 'close' method, invoke it.""" + + close = getattr(self._transport, "close", None) + if not callable(close): + return + + if asyncio.iscoroutinefunction(close): + # we expect no loop is running because this method should be called only when the interpreter is exiting + asyncio.new_event_loop().run_until_complete(close()) + else: + close() + + def get( + self, + url: str, + loop: "asyncio.AbstractEventLoop", + headers: "Optional[Dict[str, str]]" = None, + params: "Optional[Dict[str, str]]" = None, + timeout: "Optional[float]" = None, + verify: "Optional[bool]" = None, + **kwargs: "Any" + ) -> MsalTransportResponse: + + request = HttpRequest("GET", url, headers=headers) + if params: + request.format_parameters(params) + + future = asyncio.run_coroutine_threadsafe( # type: ignore + self._pipeline.run(request, connection_timeout=timeout, connection_verify=verify, **kwargs), loop + ) + response = future.result(timeout=timeout) + + return MsalTransportResponse(response) + + def post( + self, + url: str, + loop: "asyncio.AbstractEventLoop", + data: "Any" = None, + headers: "Optional[Dict[str, str]]" = None, + params: "Optional[Dict[str, str]]" = None, + timeout: "Optional[float]" = None, + verify: "Optional[bool]" = None, + **kwargs: "Any" + ) -> MsalTransportResponse: + + request = HttpRequest("POST", url, headers=headers) + if params: + request.format_parameters(params) + if data: + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + request.set_formdata_body(data) + + future = asyncio.run_coroutine_threadsafe( # type: ignore + self._pipeline.run(request, connection_timeout=timeout, connection_verify=verify, **kwargs), loop + ) + response = future.result(timeout=timeout) + + return MsalTransportResponse(response) + + @staticmethod + def _create_config(**kwargs: "Any") -> Configuration: + config = Configuration(**kwargs) + config.logging_policy = NetworkTraceLoggingPolicy(**kwargs) + config.retry_policy = AsyncRetryPolicy(**kwargs) + return config diff --git a/sdk/identity/azure-identity/dev_requirements.txt b/sdk/identity/azure-identity/dev_requirements.txt index 811e306ce457..bdb60dd6d780 100644 --- a/sdk/identity/azure-identity/dev_requirements.txt +++ b/sdk/identity/azure-identity/dev_requirements.txt @@ -1,4 +1,5 @@ -e ../../core/azure-core +aiohttp;python_full_version>="3.5.2" typing_extensions>=3.7.2 pytest pytest-asyncio;python_full_version>="3.5.2" diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py new file mode 100644 index 000000000000..1c067ce9d1eb --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -0,0 +1,115 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import functools + +from azure.core.exceptions import ClientAuthenticationError +from azure.identity._internal.aad_client import AadClient +import pytest +from six.moves.urllib_parse import urlparse + +from helpers import mock_response + +try: + from unittest.mock import Mock +except ImportError: # python < 3.3 + from mock import Mock # type: ignore + + +class MockClient(AadClient): + def __init__(self, *args, **kwargs): + self.session = kwargs.pop("session") + super(MockClient, self).__init__(*args, **kwargs) + + def _get_client_session(self, **kwargs): + return self.session + + +def test_uses_msal_correctly(): + session = Mock() + transport = Mock() + session.get = session.post = transport + + client = MockClient("client id", "tenant id", session=session) + + # MSAL will raise on each call because the mock transport returns nothing useful. + # That's okay because we only want to verify the transport was called, i.e. that + # the client used the MSAL API correctly, such that MSAL tried to send a request. + with pytest.raises(ClientAuthenticationError): + client.obtain_token_by_authorization_code("code", "redirect uri", "scope") + assert transport.call_count == 1 + + transport.reset_mock() + + with pytest.raises(ClientAuthenticationError): + client.obtain_token_by_refresh_token("refresh token", "scope") + assert transport.call_count == 1 + + +def test_error_reporting(): + error_name = "everything's sideways" + error_description = "something went wrong" + error_response = {"error": error_name, "error_description": error_description} + + response = Mock(status_code=403, json=lambda: error_response) + transport = Mock(return_value=response) + session = Mock(get=transport, post=transport) + client = MockClient("client id", "tenant id", session=session) + + fns = [ + functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), + functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), + ] + + # exceptions raised for AAD errors should contain AAD's error description + for fn in fns: + with pytest.raises(ClientAuthenticationError) as ex: + fn() + message = str(ex.value) + assert error_name in message and error_description in message + + +def test_exceptions_do_not_expose_secrets(): + secret = "secret" + body = {"error": "bad thing", "access_token": secret, "refresh_token": secret} + response = Mock(status_code=403, json=lambda: body) + transport = Mock(return_value=response) + session = Mock(get=transport, post=transport) + client = MockClient("client id", "tenant id", session=session) + + fns = [ + functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), + functools.partial(client.obtain_token_by_refresh_token, {"secret": "refresh token"}, "scope"), + ] + + def assert_secrets_not_exposed(): + for fn in fns: + with pytest.raises(ClientAuthenticationError) as ex: + fn() + assert secret not in str(ex.value) + assert secret not in repr(ex.value) + + # AAD errors shouldn't provoke exceptions exposing secrets + assert_secrets_not_exposed() + + # neither should unexpected AAD responses + del body["error"] + assert_secrets_not_exposed() + + +def test_request_url(): + authority = "authority.com" + tenant = "expected_tenant" + + def send(request, **_): + scheme, netloc, path, _, _, _ = urlparse(request.url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + + client = AadClient("client id", tenant, transport=Mock(send=send), authority=authority) + + client.obtain_token_by_authorization_code("code", "uri", "scope") + client.obtain_token_by_refresh_token("refresh token", "scope") diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py new file mode 100644 index 000000000000..d4101e841b9a --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -0,0 +1,59 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from unittest.mock import Mock +from urllib.parse import urlparse + +from azure.identity.aio._internal.aad_client import AadClient +import pytest + +from helpers import mock_response + + +class MockClient(AadClient): + def __init__(self, *args, **kwargs): + self.session = kwargs.pop("session") + super(MockClient, self).__init__(*args, **kwargs) + + def _get_client_session(self, **kwargs): + return self.session + + +@pytest.mark.asyncio +async def test_uses_msal_correctly(): + transport = Mock() + session = Mock(get=transport, post=transport) + + client = MockClient("client id", "tenant id", session=session) + + # MSAL will raise on each call because the mock transport returns nothing useful. + # That's okay because we only want to verify the transport was called, i.e. that + # the client used the MSAL API correctly, such that MSAL tried to send a request. + with pytest.raises(Exception): + await client.obtain_token_by_authorization_code("code", "redirect uri", "scope") + assert transport.call_count == 1 + + transport.reset_mock() + + with pytest.raises(Exception): + await client.obtain_token_by_refresh_token("refresh token", "scope") + assert transport.call_count == 1 + + +@pytest.mark.asyncio +async def test_request_url(): + authority = "authority.com" + tenant = "expected_tenant" + + async def send(request, **_): + scheme, netloc, path, _, _, _ = urlparse(request.url) + assert scheme == "https" + assert netloc == authority + assert path.startswith("/" + tenant) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + + client = AadClient("client id", tenant, transport=Mock(send=send), authority=authority) + + await client.obtain_token_by_authorization_code("code", "uri", "scope") + await client.obtain_token_by_refresh_token("refresh token", "scope") diff --git a/sdk/identity/azure-identity/tests/test_auth_code.py b/sdk/identity/azure-identity/tests/test_auth_code.py new file mode 100644 index 000000000000..5300368ef334 --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_auth_code.py @@ -0,0 +1,50 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.core.credentials import AccessToken +from azure.identity import AuthorizationCodeCredential + +try: + from unittest.mock import Mock +except ImportError: # python < 3.3 + from mock import Mock # type: ignore + + +def test_auth_code_credential(): + client_id = "client id" + tenant_id = "tenant" + expected_code = "auth code" + redirect_uri = "https://foo.bar" + expected_token = AccessToken("token", 42) + + mock_client = Mock(spec=object) + mock_client.obtain_token_by_authorization_code = Mock(return_value=expected_token) + + credential = AuthorizationCodeCredential( + client_id=client_id, + tenant_id=tenant_id, + authorization_code=expected_code, + redirect_uri=redirect_uri, + client=mock_client, + ) + + # first call should redeem the auth code + token = credential.get_token("scope") + assert token is expected_token + assert mock_client.obtain_token_by_authorization_code.call_count == 1 + _, kwargs = mock_client.obtain_token_by_authorization_code.call_args + assert kwargs["code"] == expected_code + + # no auth code -> credential should return cached token + mock_client.obtain_token_by_authorization_code = None # raise if credential calls this again + mock_client.get_cached_access_token = lambda *_: expected_token + token = credential.get_token("scope") + assert token is expected_token + + # no auth code, no cached token -> credential should use refresh token + mock_client.get_cached_access_token = lambda *_: None + mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"] + mock_client.obtain_token_by_refresh_token = lambda *_, **__: expected_token + token = credential.get_token("scope") + assert token is expected_token diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py new file mode 100644 index 000000000000..53bd5266485c --- /dev/null +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -0,0 +1,97 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import asyncio +from unittest.mock import Mock + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError +from azure.identity.aio import AuthorizationCodeCredential +import pytest + + +@pytest.mark.asyncio +async def test_auth_code_credential(): + client_id = "client id" + tenant_id = "tenant" + expected_code = "auth code" + redirect_uri = "https://foo.bar" + expected_token = AccessToken("token", 42) + + mock_client = Mock(spec=object) + obtain_by_auth_code = Mock(return_value=expected_token) + mock_client.obtain_token_by_authorization_code = asyncio.coroutine(obtain_by_auth_code) + + credential = AuthorizationCodeCredential( + client_id=client_id, + tenant_id=tenant_id, + authorization_code=expected_code, + redirect_uri=redirect_uri, + client=mock_client, + ) + + # first call should redeem the auth code + token = await credential.get_token("scope") + assert token is expected_token + assert obtain_by_auth_code.call_count == 1 + _, kwargs = obtain_by_auth_code.call_args + assert kwargs["code"] == expected_code + + # no auth code -> credential should return cached token + mock_client.obtain_token_by_authorization_code = None # raise if credential calls this again + mock_client.get_cached_access_token = lambda *_: expected_token + token = await credential.get_token("scope") + assert token is expected_token + + # no auth code, no cached token -> credential should use refresh token + mock_client.get_cached_access_token = lambda *_: None + mock_client.get_cached_refresh_tokens = lambda *_: ["this is a refresh token"] + mock_client.obtain_token_by_refresh_token = asyncio.coroutine(lambda *_, **__: expected_token) + token = await credential.get_token("scope") + assert token is expected_token + + +@pytest.mark.asyncio +async def test_custom_executor_used(): + credential = AuthorizationCodeCredential( + client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" + ) + + executor = Mock() + + with pytest.raises(ClientAuthenticationError): + await credential.get_token("scope", executor=executor) + + assert executor.submit.call_count == 1 + + +@pytest.mark.asyncio +async def test_custom_loop_used(): + credential = AuthorizationCodeCredential( + client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" + ) + + loop = Mock() + + with pytest.raises(ClientAuthenticationError): + await credential.get_token("scope", loop=loop) + + assert loop.run_in_executor.call_count == 1 + + +@pytest.mark.asyncio +async def test_custom_loop_and_executor_used(): + credential = AuthorizationCodeCredential( + client_id="client id", tenant_id="tenant id", authorization_code="auth code", redirect_uri="https://foo.bar" + ) + + executor = Mock() + loop = Mock() + + with pytest.raises(ClientAuthenticationError): + await credential.get_token("scope", executor=executor, loop=loop) + + assert loop.run_in_executor.call_count == 1 + executor_arg, _ = loop.run_in_executor.call_args[0] + assert executor_arg is executor