Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
authorization code. See Azure Active Directory's
[authorization code documentation](https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow)
for more information about this authentication flow.
- Multi-cloud support: client credentials accept the authority of an Azure Active
Directory authentication endpoint as an `authority` keyword argument. Known
authorities are defined in `azure.identity.KnownAuthorities`. The default
authority is for Azure Public Cloud, `login.microsoftonline.com`
(`KnownAuthorities.AZURE_PUBLIC_CLOUD`). An application running in Azure
Government would use `KnownAuthorities.AZURE_GOVERNMENT` instead:
>```
>from azure.identity import DefaultAzureCredential, KnownAuthorities
>credential = DefaultAzureCredential(authority=KnownAuthorities.AZURE_GOVERNMENT)
>```

### Breaking changes:
- Removed `client_secret` parameter from `InteractiveBrowserCredential`
Expand Down
4 changes: 1 addition & 3 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from ._constants import EnvironmentVariables, KnownAuthorities
from ._constants import KnownAuthorities
from ._credentials import (
AuthorizationCodeCredential,

CertificateCredential,
ChainedTokenCredential,
ClientSecretCredential,
Expand All @@ -27,7 +26,6 @@
"DefaultAzureCredential",
"DeviceCodeCredential",
"EnvironmentCredential",
"EnvironmentVariables",
"InteractiveBrowserCredential",
"KnownAuthorities",
"ManagedIdentityCredential",
Expand Down
25 changes: 17 additions & 8 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, ProxyPolicy, RetryPolicy
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from azure.core.pipeline.transport import RequestsTransport
from azure.identity._constants import AZURE_CLI_CLIENT_ID
from azure.identity._constants import AZURE_CLI_CLIENT_ID, KnownAuthorities

try:
ABC = abc.ABC
Expand All @@ -39,12 +39,22 @@
class AuthnClientBase(ABC):
"""Sans I/O authentication client methods"""

def __init__(self, auth_url, **kwargs): # pylint:disable=unused-argument
# type: (str, **Any) -> None
if not auth_url:
raise ValueError("auth_url should be the URL of an OAuth endpoint")
def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], Optional[str], Optional[str], **Any) -> None
super(AuthnClientBase, self).__init__()
self._auth_url = auth_url
if authority and endpoint:
raise ValueError(
"'authority' and 'endpoint' are mutually exclusive. 'authority' should be the authority of an AAD"
+ " endpoint, whereas 'endpoint' should be the endpoint's full URL."
)

if endpoint:
self._auth_url = endpoint
else:
if not tenant:
raise ValueError("'tenant' is required")
authority = authority or KnownAuthorities.AZURE_PUBLIC_CLOUD
self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token"))
self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache

def get_cached_token(self, scopes):
Expand Down Expand Up @@ -165,7 +175,6 @@ class AuthnClient(AuthnClientBase):
# pylint:disable=missing-client-constructor-parameter-credential
def __init__(
self,
auth_url, # type: str
config=None, # type: Optional[Configuration]
policies=None, # type: Optional[Iterable[HTTPPolicy]]
transport=None, # type: Optional[HttpTransport]
Expand All @@ -182,7 +191,7 @@ def __init__(
if not transport:
transport = RequestsTransport(**kwargs)
self._pipeline = Pipeline(transport=transport, policies=policies)
super(AuthnClient, self).__init__(auth_url, **kwargs)
super(AuthnClient, self).__init__(**kwargs)

def request_token(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@ class InteractiveBrowserCredential(PublicClientCredential):
:param str client_id: the application's client ID

Keyword arguments
- *authority*: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the
authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines
authorities for other clouds.
- *tenant (str)*: a tenant ID or a domain associated with a tenant. Defaults to the 'organizations' tenant,
which can authenticate work or school accounts.
- *timeout (int)*: seconds to wait for the user to complete authentication. Defaults to 300 (5 minutes).

"""

def __init__(self, client_id, **kwargs):
# type: (str, Any) -> None
# type: (str, **Any) -> None
self._timeout = kwargs.pop("timeout", 300)
self._server_class = kwargs.pop("server_class", AuthCodeRedirectServer) # facilitate mocking
super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# ------------------------------------
from .._authn_client import AuthnClient
from .._base import ClientSecretCredentialBase, CertificateCredentialBase
from .._constants import Endpoints

try:
from typing import TYPE_CHECKING
Expand All @@ -24,12 +23,17 @@ class ClientSecretCredential(ClientSecretCredentialBase):
:param str client_id: the service principal's client ID
:param str secret: one of the service principal's client secrets
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.

Keyword arguments
- **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the
authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines
authorities for other clouds.
"""

def __init__(self, client_id, secret, tenant_id, **kwargs):
# type: (str, str, str, Mapping[str, Any]) -> None
super(ClientSecretCredential, self).__init__(client_id, secret, tenant_id, **kwargs)
self._client = AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), **kwargs)
self._client = AuthnClient(tenant=tenant_id, **kwargs)

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
Expand All @@ -54,11 +58,16 @@ class CertificateCredential(CertificateCredentialBase):
:param str client_id: the service principal's client ID
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
:param str certificate_path: path to a PEM-encoded certificate file including the private key

Keyword arguments
- **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the
authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines
authorities for other clouds.
"""

def __init__(self, client_id, tenant_id, certificate_path, **kwargs):
# type: (str, str, str, Mapping[str, Any]) -> None
self._client = AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format(tenant_id), **kwargs)
self._client = AuthnClient(tenant=tenant_id, **kwargs)
super(CertificateCredential, self).__init__(client_id, tenant_id, certificate_path, **kwargs)

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, endpoint, client_cls, config=None, client_id=None, **kwargs):
self._client_id = client_id
config = config or self._create_config(**kwargs)
policies = [ContentDecodePolicy(), config.headers_policy, config.retry_policy, config.logging_policy]
self._client = client_cls(endpoint, config, policies, **kwargs)
self._client = client_cls(endpoint=endpoint, config=config, policies=policies, **kwargs)

@staticmethod
def _create_config(**kwargs):
Expand Down Expand Up @@ -105,9 +105,9 @@ class ImdsCredential(_ManagedIdentityBase):
:type config: :class:`azure.core.configuration`
"""

def __init__(self, config=None, **kwargs):
# type: (Optional[Configuration], Any) -> None
super(ImdsCredential, self).__init__(endpoint=Endpoints.IMDS, client_cls=AuthnClient, config=config, **kwargs)
def __init__(self, **kwargs):
# type: (**Any) -> None
super(ImdsCredential, self).__init__(endpoint=Endpoints.IMDS, client_cls=AuthnClient, **kwargs)
self._endpoint_available = None # type: Optional[bool]

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
Expand Down
14 changes: 12 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_credentials/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from azure.core.exceptions import ClientAuthenticationError

from .._authn_client import AuthnClient
from .._constants import Endpoints
from .._internal import PublicClientCredential, wrap_exceptions

try:
Expand Down Expand Up @@ -47,6 +46,9 @@ class DeviceCodeCredential(PublicClientCredential):
If not provided, the credential will print instructions to stdout.

Keyword arguments
- *authority*: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the
authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines
authorities for other clouds.
- *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, defaults to the
'organizations' tenant, which supports only Azure Active Directory work or school accounts.
- *timeout (int)* - seconds to wait for the user to authenticate. Defaults to the validity period of the device
Expand Down Expand Up @@ -112,6 +114,11 @@ class SharedTokenCacheCredential(object):
:param str username:
Username (typically an email address) of the user to authenticate as. This is required because the local cache
may contain tokens for multiple identities.

Keyword arguments
- **authority**: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the
authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines
authorities for other clouds.
"""

def __init__(self, username, **kwargs): # pylint:disable=unused-argument
Expand Down Expand Up @@ -166,7 +173,7 @@ def supported():
@staticmethod
def _get_auth_client(cache):
# type: (msal_extensions.FileTokenCache) -> AuthnClientBase
return AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format("common"), cache=cache)
return AuthnClient(tenant="common", cache=cache)


class UsernamePasswordCredential(PublicClientCredential):
Expand All @@ -186,6 +193,9 @@ class UsernamePasswordCredential(PublicClientCredential):
:param str password: the user's password

Keyword arguments
- *authority*: Authority of an Azure Active Directory endpoint, for example 'login.microsoftonline.com', the
authority for Azure Public Cloud (which is the default). :class:`~azure.identity.KnownAuthorities` defines
authorities for other clouds.
- *tenant (str)* - tenant ID or a domain associated with a tenant. If not provided, defaults to the
'organizations' tenant, which supports only Azure Active Directory work or school accounts.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter
from .._constants import KnownAuthorities

try:
ABC = abc.ABC
Expand All @@ -33,9 +34,11 @@
class MsalCredential(ABC):
"""Base class for credentials wrapping MSAL applications"""

def __init__(self, client_id, authority, client_credential=None, **kwargs):
# type: (str, str, Optional[Union[str, Mapping[str, str]]], Any) -> None
self._authority = authority
def __init__(self, client_id, client_credential=None, **kwargs):
# type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None
tenant = kwargs.pop("tenant", None) or "organizations"
authority = kwargs.pop("authority", KnownAuthorities.AZURE_PUBLIC_CLOUD)
self._base_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/")))
self._client_credential = client_credential
self._client_id = client_id

Expand All @@ -60,7 +63,8 @@ def _create_app(self, cls):

# MSAL application initializers use msal.authority to send AAD tenant discovery requests
with self._adapter:
app = cls(client_id=self._client_id, client_credential=self._client_credential, authority=self._authority)
# MSAL's "authority" is a URL e.g. https://login.microsoftonline.com/common
app = cls(client_id=self._client_id, client_credential=self._client_credential, authority=self._base_url)

# monkeypatch the app to replace requests.Session with MsalTransportAdapter
app.client.session.close()
Expand Down Expand Up @@ -100,13 +104,6 @@ def _get_app(self):
class PublicClientCredential(MsalCredential):
"""Wraps an MSAL PublicClientApplication with the TokenCredential API"""

def __init__(self, **kwargs):
# type: (Any) -> None
tenant = kwargs.pop("tenant", None) or "organizations"
super(PublicClientCredential, self).__init__(
authority="https://login.microsoftonline.com/" + tenant, **kwargs
)

@abc.abstractmethod
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
Expand Down
40 changes: 19 additions & 21 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,36 @@
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import Any, Dict, Iterable, Mapping, Optional
from typing import TYPE_CHECKING


from msal import TokenCache
from azure.core import Configuration
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies import (
AsyncRetryPolicy,
ContentDecodePolicy,
HTTPPolicy,
NetworkTraceLoggingPolicy,
ProxyPolicy,
)
from azure.core.pipeline.policies import AsyncRetryPolicy, ContentDecodePolicy, NetworkTraceLoggingPolicy, ProxyPolicy
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from azure.core.pipeline.transport import AsyncHttpTransport
from azure.core.pipeline.transport.requests_asyncio import AsyncioRequestsTransport

from .._authn_client import AuthnClientBase

if TYPE_CHECKING:
from typing import Any, Dict, Iterable, Mapping, Optional
from azure.core.pipeline.policies import HTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport


class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name
"""Async authentication client"""

# pylint:disable=missing-client-constructor-parameter-credential
def __init__(
self,
auth_url: str,
config: "Optional[Configuration]" = None,
policies: Optional[Iterable[HTTPPolicy]] = None,
transport: Optional[AsyncHttpTransport] = None,
**kwargs: Mapping[str, Any]
policies: "Optional[Iterable[HTTPPolicy]]" = None,
transport: "Optional[AsyncHttpTransport]" = None,
**kwargs: "Any"
) -> None:
config = config or self._create_config(**kwargs)
policies = policies or [
Expand All @@ -46,15 +44,15 @@ def __init__(
if not transport:
transport = AsyncioRequestsTransport(**kwargs)
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
super(AsyncAuthnClient, self).__init__(auth_url, **kwargs)
super().__init__(**kwargs)

async def request_token(
self,
scopes: Iterable[str],
method: Optional[str] = "POST",
headers: Optional[Mapping[str, str]] = None,
form_data: Optional[Mapping[str, str]] = None,
params: Optional[Dict[str, str]] = None,
scopes: "Iterable[str]",
method: "Optional[str]" = "POST",
headers: "Optional[Mapping[str, str]]" = None,
form_data: "Optional[Mapping[str, str]]" = None,
params: "Optional[Dict[str, str]]" = None,
**kwargs: "Any"
) -> AccessToken:
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
Expand All @@ -63,7 +61,7 @@ async def request_token(
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
return token

async def obtain_token_by_refresh_token(self, scopes: Iterable[str], username: str) -> Optional[AccessToken]:
async def obtain_token_by_refresh_token(self, scopes: "Iterable[str]", username: str) -> "Optional[AccessToken]":
"""Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no
refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else."""

Expand Down Expand Up @@ -91,7 +89,7 @@ async def obtain_token_by_refresh_token(self, scopes: Iterable[str], username: s
return None

@staticmethod
def _create_config(**kwargs: Mapping[str, Any]) -> Configuration:
def _create_config(**kwargs: "Any") -> Configuration:
config = Configuration(**kwargs)
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
config.retry_policy = AsyncRetryPolicy(**kwargs)
Expand Down
Loading