Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions sdk/identity/azure-identity/azure/identity/_auth_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
SUPPORTED_VERSIONS = {"1.0"}


class AuthenticationRecord(object):
class AuthenticationRecord:
"""Non-secret account information for an authenticated user

This class enables :class:`DeviceCodeCredential` and :class:`InteractiveBrowserCredential` to access
Expand All @@ -32,8 +32,7 @@ def __init__(
self._username = username

@property
def authority(self):
# type: () -> str
def authority(self) -> str:
return self._authority

@property
Expand All @@ -54,8 +53,7 @@ def username(self) -> str:
return self._username

@classmethod
def deserialize(cls, data):
# type: (str) -> AuthenticationRecord
def deserialize(cls, data: str) -> "AuthenticationRecord":
"""Deserialize a record.

:param str data: a serialized record
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------
import functools
import os
from typing import Optional, Dict
from typing import Optional, Dict, Any

from azure.core.pipeline.transport import HttpRequest

Expand All @@ -14,7 +14,7 @@


class AppServiceCredential(ManagedIdentityBase):
def get_client(self, **kwargs) -> Optional[ManagedIdentityClient]:
def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
client_args = _get_client_args(**kwargs)
if client_args:
return ManagedIdentityClient(**client_args)
Expand All @@ -24,7 +24,7 @@ def get_unavailable_message(self) -> str:
return "App Service managed identity configuration not found in environment"


def _get_client_args(**kwargs) -> Optional[Dict]:
def _get_client_args(**kwargs: Any) -> Optional[Dict]:
identity_config = kwargs.pop("identity_config", None) or {}

url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# ------------------------------------
import logging
import os
from typing import Any

from azure.core.credentials import AccessToken
from .chained import ChainedTokenCredential
Expand Down Expand Up @@ -51,7 +52,7 @@ class AzureApplicationCredential(ChainedTokenCredential):
of the environment variable AZURE_CLIENT_ID, if any. If not specified, a system-assigned identity will be used.
"""

def __init__(self, **kwargs) -> None:
def __init__(self, **kwargs: Any) -> None:
authority = kwargs.pop("authority", None)
authority = normalize_authority(authority) if authority else get_default_authority()
managed_identity_client_id = kwargs.pop(
Expand All @@ -62,7 +63,7 @@ def __init__(self, **kwargs) -> None:
ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs),
)

def get_token(self, *scopes: str, **kwargs) -> AccessToken:
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import Optional
from typing import Optional, Any

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
Expand Down Expand Up @@ -37,9 +37,9 @@ def __init__(
client_id: str,
authorization_code: str,
redirect_uri: str,
**kwargs
**kwargs: Any
) -> None:
self._authorization_code = authorization_code # type: Optional[str]
self._authorization_code: Optional[str] = authorization_code
self._client_id = client_id
self._client_secret = kwargs.pop("client_secret", None)
self._client = kwargs.pop("client", None) or AadClient(tenant_id, client_id, **kwargs)
Expand All @@ -57,7 +57,7 @@ def close(self) -> None:
"""Close the credential's transport session."""
self.__exit__()

def get_token(self, *scopes: str, **kwargs) -> AccessToken:
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __enter__(self):
def __exit__(self, *args):
self._client.__exit__(*args)

def close(self):
def close(self) -> None:
self.__exit__()

def get_unavailable_message(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import subprocess
import sys
import time
from typing import List
from typing import List, Optional, Any
import six

from azure.core.credentials import AccessToken
Expand All @@ -27,7 +27,7 @@
NOT_LOGGED_IN = "Please run 'az login' to set up an account"


class AzureCliCredential(object):
class AzureCliCredential:
"""Authenticates by requesting a token from the Azure CLI.

This requires previously logging in to Azure via "az login", and will use the CLI's currently logged in identity.
Expand All @@ -37,7 +37,12 @@ class AzureCliCredential(object):
for which the credential may acquire tokens. Add the wildcard value "*" to allow the credential to
acquire tokens for any tenant the application can access.
"""
def __init__(self, *, tenant_id: str = "", additionally_allowed_tenants: List[str] = None):
def __init__(
self,
*,
tenant_id: str = "",
additionally_allowed_tenants: Optional[List[str]] = None
) -> None:

self.tenant_id = tenant_id
self._additionally_allowed_tenants = additionally_allowed_tenants or []
Expand All @@ -52,7 +57,7 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzureCliCredential")
def get_token(self, *scopes: str, **kwargs) -> AccessToken:
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand Down Expand Up @@ -92,7 +97,7 @@ def get_token(self, *scopes: str, **kwargs) -> AccessToken:
return token


def parse_token(output):
def parse_token(output) -> Optional[AccessToken]:
"""Parse output of 'az account get-access-token' to an AccessToken.

In particular, convert the "expiresOn" value to epoch seconds. This value is a naive local datetime as returned by
Expand All @@ -113,7 +118,7 @@ def parse_token(output):
return None


def get_safe_working_dir():
def get_safe_working_dir() -> str:
"""Invoke 'az' from a directory controlled by the OS, not the executing program's directory"""

if sys.platform.startswith("win"):
Expand All @@ -125,7 +130,7 @@ def get_safe_working_dir():
return "/bin"


def sanitize_output(output):
def sanitize_output(output: str) -> str:
"""Redact access tokens from CLI output to prevent error messages revealing them"""
return re.sub(r"\"accessToken\": \"(.*?)(\"|$)", "****", output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import platform
import subprocess
import sys
from typing import List, Tuple
from typing import List, Tuple, Optional, Any
import six

from azure.core.credentials import AccessToken
Expand Down Expand Up @@ -42,7 +42,7 @@
"""


class AzurePowerShellCredential(object):
class AzurePowerShellCredential:
"""Authenticates by requesting a token from Azure PowerShell.

This requires previously logging in to Azure via "Connect-AzAccount", and will use the currently logged in identity.
Expand All @@ -52,7 +52,12 @@ class AzurePowerShellCredential(object):
for which the credential may acquire tokens. Add the wildcard value "*" to allow the credential to
acquire tokens for any tenant the application can access.
"""
def __init__(self, *, tenant_id: str = "", additionally_allowed_tenants: List[str] = None):
def __init__(
self,
*,
tenant_id: str = "",
additionally_allowed_tenants: Optional[List[str]] = None
) -> None:

self.tenant_id = tenant_id
self._additionally_allowed_tenants = additionally_allowed_tenants or []
Expand All @@ -67,7 +72,7 @@ def close(self) -> None:
"""Calling this method is unnecessary."""

@log_get_token("AzurePowerShellCredential")
def get_token(self, *scopes: str, **kwargs) -> AccessToken:
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""Request an access token for `scopes`.

This method is called automatically by Azure SDK clients. Applications calling this method directly must
Expand Down Expand Up @@ -127,8 +132,7 @@ def run_command_line(command_line: List[str]) -> str:
return stdout


def start_process(args):
# type: (List[str]) -> subprocess.Popen
def start_process(args: List[str]) -> "subprocess.Popen":
working_directory = get_safe_working_dir()
proc = subprocess.Popen(
args,
Expand All @@ -140,8 +144,7 @@ def start_process(args):
return proc


def parse_token(output):
# type: (str) -> AccessToken
def parse_token(output: str) -> AccessToken:
for line in output.split():
if line.startswith("azsdk%"):
_, token, expires_on = line.split("%")
Expand All @@ -150,8 +153,7 @@ def parse_token(output):
raise ClientAuthenticationError(message='Unexpected output from Get-AzAccessToken: "{}"'.format(output))


def get_command_line(scopes, tenant_id):
# type: (Tuple, str) -> List[str]
def get_command_line(scopes: Tuple[str, ...], tenant_id: str) -> List[str]:
if tenant_id:
tenant_argument = " -TenantId " + tenant_id
else:
Expand All @@ -166,8 +168,7 @@ def get_command_line(scopes, tenant_id):
return ["/bin/sh", "-c", command]


def raise_for_error(return_code, stdout, stderr):
# type: (int, str, str) -> None
def raise_for_error(return_code: int, stdout: str, stderr: str) -> None:
if return_code == 0:
if NO_AZ_ACCOUNT_MODULE in stdout:
raise CredentialUnavailableError(AZ_ACCOUNT_NOT_INSTALLED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import socket
from typing import Dict
from typing import Dict, Any
from urllib.parse import urlparse

from azure.core.exceptions import ClientAuthenticationError
Expand Down Expand Up @@ -46,7 +46,7 @@ class InteractiveBrowserCredential(InteractiveCredential):
:raises ValueError: invalid **redirect_uri**
"""

def __init__(self, **kwargs) -> None:
def __init__(self, **kwargs: Any) -> None:
redirect_uri = kwargs.pop("redirect_uri", None)
if redirect_uri:
self._parsed_url = urlparse(redirect_uri)
Expand All @@ -61,7 +61,7 @@ def __init__(self, **kwargs) -> None:
super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs)

@wrap_exceptions
def _request_token(self, *scopes: str, **kwargs) -> Dict:
def _request_token(self, *scopes: str, **kwargs: Any) -> Dict:
scopes = list(scopes) # type: ignore
claims = kwargs.get("claims")
app = self._get_app(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __init__(
self,
tenant_id: str,
client_id: str,
certificate_path: str = None,
**kwargs
certificate_path: Optional[str] = None,
**kwargs: Any
) -> None:
validate_tenant_id(tenant_id)

Expand Down Expand Up @@ -82,7 +82,7 @@ def extract_cert_chain(pem_bytes: bytes) -> bytes:

def load_pem_certificate(
certificate_data: bytes,
password: bytes = None
password: Optional[bytes] = None
) -> _Cert:
private_key = serialization.load_pem_private_key(certificate_data, password, backend=default_backend())
cert = x509.load_pem_x509_certificate(certificate_data, default_backend())
Expand All @@ -92,7 +92,7 @@ def load_pem_certificate(

def load_pkcs12_certificate(
certificate_data: bytes,
password: bytes = None
password: Optional[bytes] = None
) -> _Cert:
from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, pkcs12, PrivateFormat

Expand Down Expand Up @@ -121,11 +121,11 @@ def load_pkcs12_certificate(


def get_client_credential(
certificate_path: str = None,
password: Union[bytes, str] = None,
certificate_data: bytes = None,
certificate_path: Optional[str] = None,
password: Optional[Union[bytes, str]] = None,
certificate_data: Optional[bytes] = None,
send_certificate_chain: bool = False,
**_
**_: Any
) -> Dict:
"""Load a certificate from a filesystem path or bytes, return it as a dict suitable for msal.ClientApplication"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from typing import Any, Optional, TYPE_CHECKING
from azure.core.exceptions import ClientAuthenticationError

from azure.core.credentials import AccessToken
from .. import CredentialUnavailableError
from .._internal import within_credential_chain

if TYPE_CHECKING:
from azure.core.credentials import AccessToken, TokenCredential
from azure.core.credentials import TokenCredential

_LOGGER = logging.getLogger(__name__)

Expand All @@ -28,7 +29,7 @@ def _get_error_message(history):
)


class ChainedTokenCredential(object):
class ChainedTokenCredential:
"""A sequence of credentials that is itself a credential.

Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first
Expand All @@ -48,20 +49,18 @@ def __init__(self, *credentials):

def __enter__(self):
for credential in self.credentials:
credential.__enter__()
credential.__enter__() # type: ignore
return self

def __exit__(self, *args):
def __exit__(self, *args: Any):
for credential in self.credentials:
credential.__exit__(*args)
credential.__exit__(*args) # type: ignore

def close(self):
# type: () -> None
def close(self) -> None:
"""Close the transport session of each credential in the chain."""
self.__exit__()

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
"""Request a token from each chained credential, in order, returning the first token received.

This method is called automatically by Azure SDK clients.
Expand Down
Loading