diff --git a/sdk/identity/azure-identity/azure/identity/_auth_record.py b/sdk/identity/azure-identity/azure/identity/_auth_record.py index e6ba08788f7f..8161ee5f20ca 100644 --- a/sdk/identity/azure-identity/azure/identity/_auth_record.py +++ b/sdk/identity/azure-identity/azure/identity/_auth_record.py @@ -8,7 +8,7 @@ SUPPORTED_VERSIONS = {"1.0"} -class AuthenticationRecord(object): +class AuthenticationRecord: """Non-secret account information for an authenticated user This class enables :class:`DeviceCodeCredential` and :class:`InteractiveBrowserCredential` to access @@ -32,8 +32,7 @@ def __init__( self._username = username @property - def authority(self): - # type: () -> str + def authority(self) -> str: return self._authority @property @@ -54,8 +53,7 @@ def username(self) -> str: return self._username @classmethod - def deserialize(cls, data): - # type: (str) -> AuthenticationRecord + def deserialize(cls, data: str) -> "AuthenticationRecord": """Deserialize a record. :param str data: a serialized record diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py index 64b526f7f374..6c1a4d80ec97 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/app_service.py @@ -4,7 +4,7 @@ # ------------------------------------ import functools import os -from typing import Optional, Dict +from typing import Optional, Dict, Any from azure.core.pipeline.transport import HttpRequest @@ -14,7 +14,7 @@ class AppServiceCredential(ManagedIdentityBase): - def get_client(self, **kwargs) -> Optional[ManagedIdentityClient]: + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: client_args = _get_client_args(**kwargs) if client_args: return ManagedIdentityClient(**client_args) @@ -24,7 +24,7 @@ def get_unavailable_message(self) -> str: return "App Service managed identity configuration not found in environment" -def _get_client_args(**kwargs) -> Optional[Dict]: +def _get_client_args(**kwargs: Any) -> Optional[Dict]: identity_config = kwargs.pop("identity_config", None) or {} url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/_credentials/application.py index 11be519a3c02..520d323b4589 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/application.py @@ -4,6 +4,7 @@ # ------------------------------------ import logging import os +from typing import Any from azure.core.credentials import AccessToken from .chained import ChainedTokenCredential @@ -51,7 +52,7 @@ class AzureApplicationCredential(ChainedTokenCredential): of the environment variable AZURE_CLIENT_ID, if any. If not specified, a system-assigned identity will be used. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: authority = kwargs.pop("authority", None) authority = normalize_authority(authority) if authority else get_default_authority() managed_identity_client_id = kwargs.pop( @@ -62,7 +63,7 @@ def __init__(self, **kwargs) -> None: ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs), ) - def get_token(self, *scopes: str, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py index 7185b1784e23..75a1f3680f68 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/authorization_code.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Optional +from typing import Optional, Any from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -37,9 +37,9 @@ def __init__( client_id: str, authorization_code: str, redirect_uri: str, - **kwargs + **kwargs: Any ) -> None: - self._authorization_code = authorization_code # type: Optional[str] + self._authorization_code: Optional[str] = authorization_code self._client_id = client_id self._client_secret = kwargs.pop("client_secret", None) self._client = kwargs.pop("client", None) or AadClient(tenant_id, client_id, **kwargs) @@ -57,7 +57,7 @@ def close(self) -> None: """Close the credential's transport session.""" self.__exit__() - def get_token(self, *scopes: str, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py index 6264244e5acd..93068c9b1a43 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py @@ -35,7 +35,7 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): + def close(self) -> None: self.__exit__() def get_unavailable_message(self) -> str: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py index b7d6cf82c896..873426775bee 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_cli.py @@ -10,7 +10,7 @@ import subprocess import sys import time -from typing import List +from typing import List, Optional, Any import six from azure.core.credentials import AccessToken @@ -27,7 +27,7 @@ NOT_LOGGED_IN = "Please run 'az login' to set up an account" -class AzureCliCredential(object): +class AzureCliCredential: """Authenticates by requesting a token from the Azure CLI. This requires previously logging in to Azure via "az login", and will use the CLI's currently logged in identity. @@ -37,7 +37,12 @@ class AzureCliCredential(object): for which the credential may acquire tokens. Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the application can access. """ - def __init__(self, *, tenant_id: str = "", additionally_allowed_tenants: List[str] = None): + def __init__( + self, + *, + tenant_id: str = "", + additionally_allowed_tenants: Optional[List[str]] = None + ) -> None: self.tenant_id = tenant_id self._additionally_allowed_tenants = additionally_allowed_tenants or [] @@ -52,7 +57,7 @@ def close(self) -> None: """Calling this method is unnecessary.""" @log_get_token("AzureCliCredential") - def get_token(self, *scopes: str, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. Applications calling this method directly must @@ -92,7 +97,7 @@ def get_token(self, *scopes: str, **kwargs) -> AccessToken: return token -def parse_token(output): +def parse_token(output) -> Optional[AccessToken]: """Parse output of 'az account get-access-token' to an AccessToken. In particular, convert the "expiresOn" value to epoch seconds. This value is a naive local datetime as returned by @@ -113,7 +118,7 @@ def parse_token(output): return None -def get_safe_working_dir(): +def get_safe_working_dir() -> str: """Invoke 'az' from a directory controlled by the OS, not the executing program's directory""" if sys.platform.startswith("win"): @@ -125,7 +130,7 @@ def get_safe_working_dir(): return "/bin" -def sanitize_output(output): +def sanitize_output(output: str) -> str: """Redact access tokens from CLI output to prevent error messages revealing them""" return re.sub(r"\"accessToken\": \"(.*?)(\"|$)", "****", output) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py index 4715c8ee3123..0ceef0f01b6e 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py @@ -7,7 +7,7 @@ import platform import subprocess import sys -from typing import List, Tuple +from typing import List, Tuple, Optional, Any import six from azure.core.credentials import AccessToken @@ -42,7 +42,7 @@ """ -class AzurePowerShellCredential(object): +class AzurePowerShellCredential: """Authenticates by requesting a token from Azure PowerShell. This requires previously logging in to Azure via "Connect-AzAccount", and will use the currently logged in identity. @@ -52,7 +52,12 @@ class AzurePowerShellCredential(object): for which the credential may acquire tokens. Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the application can access. """ - def __init__(self, *, tenant_id: str = "", additionally_allowed_tenants: List[str] = None): + def __init__( + self, + *, + tenant_id: str = "", + additionally_allowed_tenants: Optional[List[str]] = None + ) -> None: self.tenant_id = tenant_id self._additionally_allowed_tenants = additionally_allowed_tenants or [] @@ -67,7 +72,7 @@ def close(self) -> None: """Calling this method is unnecessary.""" @log_get_token("AzurePowerShellCredential") - def get_token(self, *scopes: str, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. Applications calling this method directly must @@ -127,8 +132,7 @@ def run_command_line(command_line: List[str]) -> str: return stdout -def start_process(args): - # type: (List[str]) -> subprocess.Popen +def start_process(args: List[str]) -> "subprocess.Popen": working_directory = get_safe_working_dir() proc = subprocess.Popen( args, @@ -140,8 +144,7 @@ def start_process(args): return proc -def parse_token(output): - # type: (str) -> AccessToken +def parse_token(output: str) -> AccessToken: for line in output.split(): if line.startswith("azsdk%"): _, token, expires_on = line.split("%") @@ -150,8 +153,7 @@ def parse_token(output): raise ClientAuthenticationError(message='Unexpected output from Get-AzAccessToken: "{}"'.format(output)) -def get_command_line(scopes, tenant_id): - # type: (Tuple, str) -> List[str] +def get_command_line(scopes: Tuple[str, ...], tenant_id: str) -> List[str]: if tenant_id: tenant_argument = " -TenantId " + tenant_id else: @@ -166,8 +168,7 @@ def get_command_line(scopes, tenant_id): return ["/bin/sh", "-c", command] -def raise_for_error(return_code, stdout, stderr): - # type: (int, str, str) -> None +def raise_for_error(return_code: int, stdout: str, stderr: str) -> None: if return_code == 0: if NO_AZ_ACCOUNT_MODULE in stdout: raise CredentialUnavailableError(AZ_ACCOUNT_NOT_INSTALLED) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py index 670ee0299fcb..0a4b0c8d7aa7 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/browser.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/browser.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import socket -from typing import Dict +from typing import Dict, Any from urllib.parse import urlparse from azure.core.exceptions import ClientAuthenticationError @@ -46,7 +46,7 @@ class InteractiveBrowserCredential(InteractiveCredential): :raises ValueError: invalid **redirect_uri** """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: redirect_uri = kwargs.pop("redirect_uri", None) if redirect_uri: self._parsed_url = urlparse(redirect_uri) @@ -61,7 +61,7 @@ def __init__(self, **kwargs) -> None: super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs) @wrap_exceptions - def _request_token(self, *scopes: str, **kwargs) -> Dict: + def _request_token(self, *scopes: str, **kwargs: Any) -> Dict: scopes = list(scopes) # type: ignore claims = kwargs.get("claims") app = self._get_app(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py index 3fd8025fb2b1..d84754a6d2bc 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py @@ -50,8 +50,8 @@ def __init__( self, tenant_id: str, client_id: str, - certificate_path: str = None, - **kwargs + certificate_path: Optional[str] = None, + **kwargs: Any ) -> None: validate_tenant_id(tenant_id) @@ -82,7 +82,7 @@ def extract_cert_chain(pem_bytes: bytes) -> bytes: def load_pem_certificate( certificate_data: bytes, - password: bytes = None + password: Optional[bytes] = None ) -> _Cert: private_key = serialization.load_pem_private_key(certificate_data, password, backend=default_backend()) cert = x509.load_pem_x509_certificate(certificate_data, default_backend()) @@ -92,7 +92,7 @@ def load_pem_certificate( def load_pkcs12_certificate( certificate_data: bytes, - password: bytes = None + password: Optional[bytes] = None ) -> _Cert: from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, pkcs12, PrivateFormat @@ -121,11 +121,11 @@ def load_pkcs12_certificate( def get_client_credential( - certificate_path: str = None, - password: Union[bytes, str] = None, - certificate_data: bytes = None, + certificate_path: Optional[str] = None, + password: Optional[Union[bytes, str]] = None, + certificate_data: Optional[bytes] = None, send_certificate_chain: bool = False, - **_ + **_: Any ) -> Dict: """Load a certificate from a filesystem path or bytes, return it as a dict suitable for msal.ClientApplication""" diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py index 17f9fb0ce6ca..3ca79d5a15ca 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/chained.py @@ -6,11 +6,12 @@ from typing import Any, Optional, TYPE_CHECKING from azure.core.exceptions import ClientAuthenticationError +from azure.core.credentials import AccessToken from .. import CredentialUnavailableError from .._internal import within_credential_chain if TYPE_CHECKING: - from azure.core.credentials import AccessToken, TokenCredential + from azure.core.credentials import TokenCredential _LOGGER = logging.getLogger(__name__) @@ -28,7 +29,7 @@ def _get_error_message(history): ) -class ChainedTokenCredential(object): +class ChainedTokenCredential: """A sequence of credentials that is itself a credential. Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first @@ -48,20 +49,18 @@ def __init__(self, *credentials): def __enter__(self): for credential in self.credentials: - credential.__enter__() + credential.__enter__() # type: ignore return self - def __exit__(self, *args): + def __exit__(self, *args: Any): for credential in self.credentials: - credential.__exit__(*args) + credential.__exit__(*args) # type: ignore - def close(self): - # type: () -> None + def close(self) -> None: """Close the transport session of each credential in the chain.""" self.__exit__() - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # type: (*str, **Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument """Request a token from each chained credential, in order, returning the first token received. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py index 0c951d2dcd8e..2a327b0ac556 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Callable, Optional +from typing import Callable, Optional, Any from azure.core.credentials import AccessToken from .._internal import AadClient @@ -34,7 +34,7 @@ def __init__( tenant_id: str, client_id: str, func: Callable[[], str], - **kwargs + **kwargs: Any ) -> None: self._func = func self._client = AadClient(tenant_id, client_id, **kwargs) @@ -47,14 +47,13 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): - # type: () -> None + def close(self) -> None: self.__exit__() - def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_access_token(scopes, **kwargs) - def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: assertion = self._func() token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py index fc86cbf1886f..76ac8d6db5df 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_secret.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from typing import Any from .._internal.client_credential_base import ClientCredentialBase @@ -28,7 +29,7 @@ def __init__( tenant_id: str, client_id: str, client_secret: str, - **kwargs + **kwargs: Any ) -> None: if not client_id: raise ValueError("client_id should be the id of an Azure Active Directory application") diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py index ec776fbce3d4..ab3a337bdc62 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/cloud_shell.py @@ -14,8 +14,7 @@ class CloudShellCredential(ManagedIdentityBase): - def get_client(self, **kwargs): - # type: (**Any) -> Optional[ManagedIdentityClient] + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) if url: return ManagedIdentityClient( diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index 2521c7bd620d..2a59837ab360 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -4,7 +4,7 @@ # ------------------------------------ import logging import os -from typing import List, TYPE_CHECKING +from typing import List, TYPE_CHECKING, Any from azure.core.credentials import AccessToken from .._constants import EnvironmentVariables @@ -18,7 +18,6 @@ from .azure_cli import AzureCliCredential from .vscode import VisualStudioCodeCredential - if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -74,7 +73,7 @@ class DefaultAzureCredential(ChainedTokenCredential): Directory work or school accounts. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: if "tenant_id" in kwargs: raise TypeError("'tenant_id' is not supported in DefaultAzureCredential.") diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py index 7834919cb1af..6acdf77dc447 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py @@ -4,7 +4,7 @@ # ------------------------------------ from datetime import datetime import time -from typing import Dict, Optional +from typing import Dict, Optional, Callable, Any from azure.core.exceptions import ClientAuthenticationError @@ -52,14 +52,17 @@ class DeviceCodeCredential(InteractiveCredential): def __init__( self, client_id: str = DEVELOPER_SIGN_ON_CLIENT_ID, - **kwargs + *, + timeout: Optional[int] = None, + prompt_callback: Optional[Callable[[str, str, datetime], None]] = None, + **kwargs: Any ) -> None: - self._timeout = kwargs.pop("timeout", None) # type: Optional[int] - self._prompt_callback = kwargs.pop("prompt_callback", None) + self._timeout = timeout + self._prompt_callback = prompt_callback super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs) @wrap_exceptions - def _request_token(self, *scopes: str, **kwargs) -> Dict: + def _request_token(self, *scopes: str, **kwargs: Any) -> Dict: # MSAL requires scopes be a list scopes = list(scopes) # type: ignore diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py index 4caa3a9a8e4b..687345a802ff 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/environment.py @@ -4,7 +4,7 @@ # ------------------------------------ import logging import os -from typing import Optional, Union, TYPE_CHECKING +from typing import Optional, Union, Any from azure.core.credentials import AccessToken from .. import CredentialUnavailableError @@ -14,14 +14,12 @@ from .client_secret import ClientSecretCredential from .user_password import UsernamePasswordCredential - -if TYPE_CHECKING: - EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"] +EnvironmentCredentialTypes = Union[CertificateCredential, ClientSecretCredential, UsernamePasswordCredential] _LOGGER = logging.getLogger(__name__) -class EnvironmentCredential(object): +class EnvironmentCredential: """A credential configured by environment variables. This credential is capable of authenticating as a service principal using a client secret or a certificate, or as @@ -56,8 +54,8 @@ class EnvironmentCredential(object): when no value is given. """ - def __init__(self, **kwargs) -> None: - self._credential = None # type: Optional[EnvironmentCredentialTypes] + def __init__(self, **kwargs: Any) -> None: + self._credential: Optional[EnvironmentCredentialTypes] = None if all(os.environ.get(v) is not None for v in EnvironmentVariables.CLIENT_SECRET_VARS): self._credential = ClientSecretCredential( @@ -113,7 +111,7 @@ def close(self) -> None: self.__exit__() @log_get_token("EnvironmentCredential") - def get_token(self, *scopes: str, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index 9dcf812fc7f8..45a869223f35 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import os -from typing import Any, Optional +from typing import Any, Optional, Dict import six @@ -29,7 +29,7 @@ } -def get_request(scope, identity_config): +def _get_request(scope: str, identity_config: Dict) -> HttpRequest: url = ( os.environ.get(EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST, IMDS_AUTHORITY).strip("/") + IMDS_TOKEN_PATH @@ -40,15 +40,15 @@ def get_request(scope, identity_config): class ImdsCredential(GetTokenMixin): - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super(ImdsCredential, self).__init__() - self._client = ManagedIdentityClient(get_request, **dict(PIPELINE_SETTINGS, **kwargs)) + self._client = ManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs)) if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ: - self._endpoint_available = True # type: Optional[bool] + self._endpoint_available: Optional[bool] = True else: self._endpoint_available = None - self._error_message = None # type: Optional[str] + self._error_message: Optional[str] = None self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs def __enter__(self): @@ -58,14 +58,13 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): + def close(self) -> None: self.__exit__() - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_token(*scopes) - def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: if self._endpoint_available is None: # Lacking another way to determine whether the IMDS endpoint is listening, # we send a request it would immediately reject (because it lacks the Metadata header), diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 173ba245f347..39ec113d9b60 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -4,7 +4,7 @@ # ------------------------------------ import logging import os -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, Any from azure.core.credentials import AccessToken from .. import CredentialUnavailableError @@ -17,7 +17,7 @@ _LOGGER = logging.getLogger(__name__) -class ManagedIdentityCredential(object): +class ManagedIdentityCredential: """Authenticates with an Azure managed identity in any hosting environment which supports managed identities. This credential defaults to using a system-assigned identity. To configure a user-assigned identity, use one of @@ -33,7 +33,7 @@ class ManagedIdentityCredential(object): :paramtype identity_config: Mapping[str, str] """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: self._credential = None # type: Optional[TokenCredential] if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): if os.environ.get(EnvironmentVariables.IDENTITY_HEADER): @@ -95,7 +95,7 @@ def close(self) -> None: self.__exit__() @log_get_token("ManagedIdentityCredential") - def get_token(self, *scopes: str, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py index 3c39e0af0e3a..71ab072558ac 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/on_behalf_of.py @@ -3,9 +3,10 @@ # Licensed under the MIT License. # ------------------------------------ import time -from typing import TYPE_CHECKING, Any, Optional +from typing import Any, Optional import six +import msal from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -15,11 +16,7 @@ from .._internal.get_token_mixin import GetTokenMixin from .._internal.interactive import _build_auth_record from .._internal.msal_credentials import MsalCredential - -if TYPE_CHECKING: - import msal - from .. import AuthenticationRecord - +from .. import AuthenticationRecord class OnBehalfOfCredential(MsalCredential, GetTokenMixin): """Authenticates a service principal via the on-behalf-of flow. @@ -56,7 +53,7 @@ def __init__( self, tenant_id: str, client_id: str, - **kwargs + **kwargs: Any ) -> None: self._assertion = kwargs.pop("user_assertion", None) if not self._assertion: @@ -87,11 +84,10 @@ def __init__( client_credential=credential, tenant_id=tenant_id, **kwargs) - self._auth_record = None # type: Optional[AuthenticationRecord] + self._auth_record: Optional[AuthenticationRecord] = None @wrap_exceptions - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: if self._auth_record: claims = kwargs.get("claims") app = self._get_app(**kwargs) @@ -107,9 +103,8 @@ def _acquire_token_silently(self, *scopes, **kwargs): return None @wrap_exceptions - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken - app = self._get_app(**kwargs) # type: msal.ConfidentialClientApplication + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + app: msal.ConfidentialClientApplication = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_on_behalf_of(self._assertion, list(scopes), claims_challenge=kwargs.get("claims")) if "access_token" not in result or "expires_in" not in result: diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py index 69b7483a27aa..b4081fbf013b 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/service_fabric.py @@ -4,7 +4,7 @@ # ------------------------------------ import functools import os -from typing import Dict, Optional +from typing import Dict, Optional, Any from azure.core.pipeline.transport import HttpRequest @@ -24,7 +24,7 @@ def get_unavailable_message(self) -> str: return "Service Fabric managed identity configuration not found in environment" -def _get_client_args(**kwargs) -> Optional[Dict]: +def _get_client_args(**kwargs: Any) -> Optional[Dict]: url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT) secret = os.environ.get(EnvironmentVariables.IDENTITY_HEADER) thumbprint = os.environ.get(EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index e64621cea94f..27a34788f9d5 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -2,21 +2,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional +from azure.core.credentials import AccessToken from .silent import SilentAuthenticationCredential from .. import CredentialUnavailableError from .._constants import DEVELOPER_SIGN_ON_CLIENT_ID -from .._internal import AadClient +from .._internal import AadClient, AadClientBase from .._internal.decorators import log_get_token from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase if TYPE_CHECKING: from azure.core.credentials import TokenCredential - from .._internal import AadClientBase - -class SharedTokenCacheCredential(object): +class SharedTokenCacheCredential: """Authenticates using tokens in the local cache shared between Microsoft applications. :param str username: Username (typically an email address) of the user to authenticate as. This is used when the @@ -34,7 +33,7 @@ class SharedTokenCacheCredential(object): :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions """ - def __init__(self, username: str = None, **kwargs) -> None: + def __init__(self, username: Optional[str] = None, **kwargs: Any) -> None: if "authentication_record" in kwargs: self._credential = SilentAuthenticationCredential(**kwargs) # type: TokenCredential else: @@ -52,8 +51,7 @@ def close(self) -> None: self.__exit__() @log_get_token("SharedTokenCacheCredential") - def get_token(self, *scopes, **kwargs): - # type (*str, **Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Get an access token for `scopes` from the shared cache. If no access token is cached, attempt to acquire one using a cached refresh token. @@ -94,8 +92,7 @@ def __exit__(self, *args): if self._client: self._client.__exit__(*args) - def get_token(self, *scopes, **kwargs): - # type (*str, **Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: if not scopes: raise ValueError("'get_token' requires at least one scope") @@ -118,6 +115,5 @@ def get_token(self, *scopes, **kwargs): raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) - def _get_auth_client(self, **kwargs): - # type: (**Any) -> AadClientBase + def _get_auth_client(self, **kwargs: Any) -> AadClientBase: return AadClient(client_id=DEVELOPER_SIGN_ON_CLIENT_ID, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py index a4c2d4104683..a61040b40c57 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/silent.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/silent.py @@ -5,7 +5,7 @@ import os import platform import time -from typing import Dict +from typing import Dict, Optional, Any from msal import PublicClientApplication @@ -21,22 +21,24 @@ from .. import AuthenticationRecord -class SilentAuthenticationCredential(object): +class SilentAuthenticationCredential: """Internal class for authenticating from the default shared cache given an AuthenticationRecord""" def __init__( self, authentication_record: AuthenticationRecord, + *, + tenant_id: Optional[str] = None, **kwargs ) -> None: self._auth_record = authentication_record # authenticate in the tenant that produced the record unless "tenant_id" specifies another - self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id + self._tenant_id = tenant_id or self._auth_record.tenant_id validate_tenant_id(self._tenant_id) self._cache = kwargs.pop("_cache", None) self._cache_persistence_options = kwargs.pop("cache_persistence_options", None) - self._client_applications = {} # type: Dict[str, PublicClientApplication] + self._client_applications: Dict[str, PublicClientApplication] = {} self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", []) self._client = MsalClient(**kwargs) self._initialized = False @@ -48,7 +50,7 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def get_token(self, *scopes: str, **kwargs) -> AccessToken: + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: if not scopes: raise ValueError('"get_token" requires at least one scope') @@ -75,7 +77,7 @@ def _initialize(self): self._initialized = True - def _get_client_application(self, **kwargs): + def _get_client_application(self, **kwargs: Any): tenant_id = resolve_tenant( self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, @@ -94,7 +96,7 @@ def _get_client_application(self, **kwargs): return self._client_applications[tenant_id] @wrap_exceptions - def _acquire_token_silent(self, *scopes: str, **kwargs) -> AccessToken: + def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: """Silently acquire a token from MSAL.""" result = None diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/token_exchange.py b/sdk/identity/azure-identity/azure/identity/_credentials/token_exchange.py index 0ebd5cd9a1b9..6397a5842c2e 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/token_exchange.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/token_exchange.py @@ -8,11 +8,11 @@ from .client_assertion import ClientAssertionCredential -class TokenFileMixin(object): +class TokenFileMixin: def __init__( self, token_file_path: str, - **_ + **_: Any ) -> None: super(TokenFileMixin, self).__init__() self._jwt = "" @@ -34,7 +34,7 @@ def __init__( tenant_id: str, client_id: str, token_file_path: str, - **kwargs + **kwargs: Any ) -> None: super(TokenExchangeCredential, self).__init__( tenant_id=tenant_id, diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py b/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py index 37e2e115950c..c0d2cdc67c7a 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/user_password.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Any +from typing import Any, Dict from .._internal import InteractiveCredential, wrap_exceptions @@ -48,7 +48,7 @@ def __init__( client_id: str, username: str, password: str, - **kwargs + **kwargs: Any ) -> None: # The base class will accept an AuthenticationRecord, allowing this credential to authenticate silently the # first time it's asked for a token. However, we want to ensure this first authentication is not silent, to @@ -60,8 +60,7 @@ def __init__( self._password = password @wrap_exceptions - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> dict + def _request_token(self, *scopes: str, **kwargs: Any) -> Dict: app = self._get_app(**kwargs) return app.acquire_token_by_username_password( username=self._username, diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index 7d72d3f45ba4..824443f9203f 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -5,12 +5,13 @@ import abc import os import sys -from typing import cast, TYPE_CHECKING, Any, Dict, Optional +from typing import cast, Any, Dict, Optional +from azure.core.credentials import AccessToken from .._exceptions import CredentialUnavailableError from .._constants import AzureAuthorityHosts, AZURE_VSCODE_CLIENT_ID, EnvironmentVariables from .._internal import normalize_authority, validate_tenant_id -from .._internal.aad_client import AadClient +from .._internal.aad_client import AadClient, AadClientBase from .._internal.get_token_mixin import GetTokenMixin from .._internal.decorators import log_get_token @@ -21,15 +22,9 @@ else: from .._internal.linux_vscode_adapter import get_refresh_token, get_user_settings -if TYPE_CHECKING: - from azure.core.credentials import AccessToken - from .._internal.aad_client import AadClientBase -ABC = abc.ABC - -class _VSCodeCredentialBase(ABC): - def __init__(self, **kwargs): - # type: (**Any) -> None +class _VSCodeCredentialBase(abc.ABC): + def __init__(self, **kwargs: Any) -> None: super(_VSCodeCredentialBase, self).__init__() user_settings = get_user_settings() @@ -44,20 +39,17 @@ def __init__(self, **kwargs): self._unavailable_reason = "Initialization failed" @abc.abstractmethod - def _get_client(self, **kwargs): - # type: (**Any) -> AadClientBase + def _get_client(self, **kwargs: Any) -> AadClientBase: pass - def _get_refresh_token(self): - # type: () -> str + def _get_refresh_token(self) -> str: if not self._refresh_token: self._refresh_token = get_refresh_token(self._cloud) if not self._refresh_token: raise CredentialUnavailableError(message="Failed to get Azure user details from Visual Studio Code.") return self._refresh_token - def _initialize(self, vscode_user_settings, **kwargs): - # type: (Dict, **Any) -> None + def _initialize(self, vscode_user_settings: Dict, **kwargs: Any) -> None: """Build a client from kwargs merged with VS Code user settings. The first stable version of this credential defaulted to Public Cloud and the "organizations" @@ -134,14 +126,12 @@ def __exit__(self, *args): if self._client: self._client.__exit__(*args) - def close(self): - # type: () -> None + def close(self) -> None: """Close the credential's transport session.""" self.__exit__() @log_get_token("VSCodeCredential") - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. This method is called automatically by Azure SDK clients. @@ -161,17 +151,14 @@ def get_token(self, *scopes, **kwargs): raise CredentialUnavailableError(message=error_message) return super(VisualStudioCodeCredential, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes, **kwargs) - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: refresh_token = self._get_refresh_token() self._client = cast(AadClient, self._client) return self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) - def _get_client(self, **kwargs): - # type: (**Any) -> AadClient + def _get_client(self, **kwargs: Any) -> AadClient: return AadClient(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_exceptions.py b/sdk/identity/azure-identity/azure/identity/_exceptions.py index dc6aa4c018dc..97bde06779d4 100644 --- a/sdk/identity/azure-identity/azure/identity/_exceptions.py +++ b/sdk/identity/azure-identity/azure/identity/_exceptions.py @@ -20,8 +20,13 @@ class AuthenticationRequiredError(CredentialUnavailableError): method. """ - def __init__(self, scopes, message=None, claims=None, **kwargs): - # type: (Iterable[str], Optional[str], Optional[str], **Any) -> None + def __init__( + self, + scopes: Iterable[str], + message: Optional[str] = None, + claims: Optional[str] = None, + **kwargs: Any + ) -> None: self._claims = claims self._scopes = scopes if not message: @@ -29,13 +34,11 @@ def __init__(self, scopes, message=None, claims=None, **kwargs): super(AuthenticationRequiredError, self).__init__(message=message, **kwargs) @property - def scopes(self): - # type: () -> Iterable[str] + def scopes(self) -> Iterable[str]: """Scopes requested during the failed authentication""" return self._scopes @property - def claims(self): - # type: () -> Optional[str] + def claims(self) -> Optional[str]: """Additional claims required in the next authentication""" return self._claims diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index 3207bf4bf623..e5adeff9ddfa 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import time -from typing import Iterable, Union +from typing import Iterable, Union, Optional, Any from azure.core.credentials import AccessToken from azure.core.pipeline import Pipeline @@ -29,8 +29,8 @@ def obtain_token_by_authorization_code( scopes: Iterable[str], code: str, redirect_uri: str, - client_secret: str = None, - **kwargs + client_secret: Optional[str] = None, + **kwargs: Any ) -> AccessToken: request = self._get_auth_code_request( scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs @@ -41,7 +41,7 @@ def obtain_token_by_client_certificate( self, scopes: Iterable[str], certificate: AadClientCertificate, - **kwargs + **kwargs: Any ) -> AccessToken: request = self._get_client_certificate_request(scopes, certificate, **kwargs) return self._run_pipeline(request, **kwargs) @@ -50,7 +50,7 @@ def obtain_token_by_client_secret( self, scopes: Iterable[str], secret: str, - **kwargs + **kwargs: Any ) -> AccessToken: request = self._get_client_secret_request(scopes, secret, **kwargs) return self._run_pipeline(request, **kwargs) @@ -59,7 +59,7 @@ def obtain_token_by_jwt_assertion( self, scopes: Iterable[str], assertion: str, - **kwargs + **kwargs: Any ) -> AccessToken: request = self._get_jwt_assertion_request(scopes, assertion) return self._run_pipeline(request, **kwargs) @@ -68,7 +68,7 @@ def obtain_token_by_refresh_token( self, scopes: Iterable[str], refresh_token: str, - **kwargs + **kwargs: Any ) -> AccessToken: request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) return self._run_pipeline(request, **kwargs) @@ -78,16 +78,16 @@ def obtain_token_on_behalf_of( scopes: Iterable[str], client_credential: Union[str, AadClientCertificate], user_assertion: str, - **kwargs + **kwargs: Any ) -> AccessToken: # no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL raise NotImplementedError() # pylint:disable=no-self-use - def _build_pipeline(self, **kwargs) -> Pipeline: + def _build_pipeline(self, **kwargs: Any) -> Pipeline: return build_pipeline(**kwargs) - def _run_pipeline(self, request: HttpRequest, **kwargs) -> AccessToken: + def _run_pipeline(self, request: HttpRequest, **kwargs: Any) -> AccessToken: # remove tenant_id and claims kwarg that could have been passed from credential's get_token method # tenant_id is already part of `request` at this point kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 728fe68a8254..bdc99aa49772 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -12,6 +12,7 @@ import six from msal import TokenCache +from azure.core.pipeline import PipelineResponse from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.pipeline.transport import HttpRequest from azure.core.credentials import AccessToken @@ -20,7 +21,7 @@ from .aadclient_certificate import AadClientCertificate if TYPE_CHECKING: - from azure.core.pipeline import AsyncPipeline, Pipeline, PipelineResponse + from azure.core.pipeline import AsyncPipeline, Pipeline from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport @@ -42,7 +43,7 @@ def __init__( cache: TokenCache = None, *, additionally_allowed_tenants: List[str] = None, - **kwargs + **kwargs: Any ) -> None: self._authority = normalize_authority(authority) if authority else get_default_authority() @@ -53,7 +54,7 @@ def __init__( self._additionally_allowed_tenants = additionally_allowed_tenants or [] self._pipeline = self._build_pipeline(**kwargs) - def get_cached_access_token(self, scopes: Iterable[str], **kwargs) -> Optional[AccessToken]: + def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optional[AccessToken]: tenant = resolve_tenant( self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, @@ -102,8 +103,7 @@ def obtain_token_on_behalf_of(self, scopes, client_credential, user_assertion, * def _build_pipeline(self, **kwargs): pass - def _process_response(self, response, request_time): - # type: (PipelineResponse, int) -> AccessToken + def _process_response(self, response: PipelineResponse, request_time: int) -> AccessToken: content = response.context.get( ContentDecodePolicy.CONTEXT_NAME ) or ContentDecodePolicy.deserialize_from_http_generics(response.http_response) @@ -156,8 +156,14 @@ def _process_response(self, response, request_time): return token - def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None, **kwargs): - # type: (Iterable[str], str, str, Optional[str], **Any) -> HttpRequest + def _get_auth_code_request( + self, + scopes: Iterable[str], + code: str, + redirect_uri: str, + client_secret: Optional[str] = None, + **kwargs: Any + ) -> HttpRequest: data = { "client_id": self._client_id, "code": code, @@ -171,8 +177,12 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None, request = self._post(data, **kwargs) return request - def _get_jwt_assertion_request(self, scopes, assertion, **kwargs): - # type: (Iterable[str], str, **Any) -> HttpRequest + def _get_jwt_assertion_request( + self, + scopes: Iterable[str], + assertion: str, + **kwargs: Any + ) -> HttpRequest: data = { "client_assertion": assertion, "client_assertion_type": JWT_BEARER_ASSERTION, @@ -184,8 +194,7 @@ def _get_jwt_assertion_request(self, scopes, assertion, **kwargs): request = self._post(data, **kwargs) return request - def _get_client_certificate_assertion(self, certificate, **kwargs): - # type: (AadClientCertificate, **Any) -> str + def _get_client_certificate_assertion(self, certificate: AadClientCertificate, **kwargs: Any) -> str: now = int(time.time()) header = six.ensure_binary( json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8" @@ -208,13 +217,16 @@ def _get_client_certificate_assertion(self, certificate, **kwargs): jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature) return jwt_bytes.decode("utf-8") - def _get_client_certificate_request(self, scopes, certificate, **kwargs): - # type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest + def _get_client_certificate_request( + self, + scopes: Iterable[str], + certificate: AadClientCertificate, + **kwargs: Any + ) -> HttpRequest: assertion = self._get_client_certificate_assertion(certificate, **kwargs) return self._get_jwt_assertion_request(scopes, assertion, **kwargs) - def _get_client_secret_request(self, scopes, secret, **kwargs): - # type: (Iterable[str], str, **Any) -> HttpRequest + def _get_client_secret_request(self, scopes: Iterable[str], secret: str, **kwargs: Any) -> HttpRequest: data = { "client_id": self._client_id, "client_secret": secret, @@ -224,8 +236,13 @@ def _get_client_secret_request(self, scopes, secret, **kwargs): request = self._post(data, **kwargs) return request - def _get_on_behalf_of_request(self, scopes, client_credential, user_assertion, **kwargs): - # type: (Iterable[str], Union[str, AadClientCertificate], str, **Any) -> HttpRequest + def _get_on_behalf_of_request( + self, + scopes: Iterable[str], + client_credential: Union[str, AadClientCertificate], + user_assertion: str, + **kwargs: Any + ) -> HttpRequest: data = { "assertion": user_assertion, "client_id": self._client_id, @@ -242,8 +259,12 @@ def _get_on_behalf_of_request(self, scopes, client_credential, user_assertion, * request = self._post(data, **kwargs) return request - def _get_refresh_token_request(self, scopes, refresh_token, **kwargs): - # type: (Iterable[str], str, **Any) -> HttpRequest + def _get_refresh_token_request( + self, + scopes: Iterable[str], + refresh_token: str, + **kwargs: Any + ) -> HttpRequest: data = { "grant_type": "refresh_token", "refresh_token": refresh_token, @@ -254,8 +275,13 @@ def _get_refresh_token_request(self, scopes, refresh_token, **kwargs): request = self._post(data, **kwargs) return request - def _get_refresh_token_on_behalf_of_request(self, scopes, client_credential, refresh_token, **kwargs): - # type: (Iterable[str], Union[str, AadClientCertificate], str, **Any) -> HttpRequest + def _get_refresh_token_on_behalf_of_request( + self, + scopes: Iterable[str], + client_credential: Union[str, AadClientCertificate], + refresh_token: str, + **kwargs: Any + ) -> HttpRequest: data = { "grant_type": "refresh_token", "refresh_token": refresh_token, @@ -280,21 +306,18 @@ def _get_token_url(self, **kwargs): ) return "/".join((self._authority, tenant, "oauth2/v2.0/token")) - def _post(self, data, **kwargs): - # type: (dict, **Any) -> HttpRequest + def _post(self, data: Dict, **kwargs: Any) -> HttpRequest: url = self._get_token_url(**kwargs) return HttpRequest("POST", url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"}) -def _scrub_secrets(response): - # type: (dict) -> None +def _scrub_secrets(response: Dict) -> None: for secret in ("access_token", "refresh_token"): if secret in response: response[secret] = "***" -def _raise_for_error(response, content): - # type: (PipelineResponse, dict) -> None +def _raise_for_error(response: PipelineResponse, content: Dict) -> None: if "error" not in content: return diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py b/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py index 8a6421559c67..73cefaeb4884 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import base64 - +from typing import Optional from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding @@ -12,7 +12,7 @@ import six -class AadClientCertificate(object): +class AadClientCertificate: """Wraps 'cryptography' to provide the crypto operations AadClient requires for certificate authentication. :param bytes pem_bytes: bytes of a a PEM-encoded certificate including the (RSA) private key @@ -21,7 +21,7 @@ class AadClientCertificate(object): def __init__( self, pem_bytes: bytes, - password: bytes = None + password: Optional[bytes] = None ) -> None: private_key = serialization.load_pem_private_key(pem_bytes, password=password, backend=default_backend()) if not isinstance(private_key, RSAPrivateKey): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py index a4f71b0b8508..4987a8ba1d81 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/client_credential_base.py @@ -17,8 +17,7 @@ class ClientCredentialBase(MsalCredential, GetTokenMixin): """Base class for credentials authenticating a service principal with a certificate or secret""" @wrap_exceptions - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_silent_with_error(list(scopes), account=None, **kwargs) @@ -27,8 +26,7 @@ def _acquire_token_silently(self, *scopes, **kwargs): return None @wrap_exceptions - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] + def _request_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: app = self._get_app(**kwargs) request_time = int(time.time()) result = app.acquire_token_for_client(list(scopes)) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py index e5e581389fa9..3e1b705db44d 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/get_token_mixin.py @@ -5,38 +5,31 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, Optional +from typing import Any, Optional +from azure.core.credentials import AccessToken from .utils import within_credential_chain from .._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY -if TYPE_CHECKING: - from azure.core.credentials import AccessToken - -ABC = abc.ABC _LOGGER = logging.getLogger(__name__) -class GetTokenMixin(ABC): - def __init__(self, *args, **kwargs): - # type: (*Any, **Any) -> None +class GetTokenMixin(abc.ABC): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._last_request_time = 0 # https://github.com/python/mypy/issues/5887 super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @abc.abstractmethod - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: """Attempt to acquire an access token from a cache or by redeeming a refresh token""" @abc.abstractmethod - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token from the STS""" - def _should_refresh(self, token): - # type: (AccessToken) -> bool + def _should_refresh(self, token: AccessToken) -> bool: now = int(time.time()) if token.expires_on - now > DEFAULT_REFRESH_OFFSET: return False @@ -44,8 +37,7 @@ def _should_refresh(self, token): return False return True - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index 782714562e13..79e69c5e65a1 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -78,9 +78,14 @@ def _build_auth_record(response): class InteractiveCredential(MsalCredential, ABC): - def __init__(self, **kwargs): - self._disable_automatic_authentication = kwargs.pop("disable_automatic_authentication", False) - self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord] + def __init__( + self, + *, + authentication_record: Optional[AuthenticationRecord] = None, + disable_automatic_authentication: bool = False, + **kwargs: Any) -> None: + self._disable_automatic_authentication = disable_automatic_authentication + self._auth_record = authentication_record if self._auth_record: kwargs.pop("client_id", None) # authentication_record overrides client_id argument tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id @@ -93,8 +98,7 @@ def __init__(self, **kwargs): else: super(InteractiveCredential, self).__init__(**kwargs) - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. @@ -157,8 +161,7 @@ def get_token(self, *scopes, **kwargs): _LOGGER.info("%s.get_token succeeded", self.__class__.__name__) return AccessToken(result["access_token"], now + int(result["expires_in"])) - def authenticate(self, **kwargs): - # type: (**Any) -> AuthenticationRecord + def authenticate(self, **kwargs: Any) -> AuthenticationRecord: """Interactively authenticate a user. :keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by @@ -185,8 +188,7 @@ def authenticate(self, **kwargs): return self._auth_record # type: ignore @wrap_exceptions - def _acquire_token_silent(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessToken: result = None claims = kwargs.get("claims") if self._auth_record: diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py index ff285b38a7de..33b8b2fd36f6 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py @@ -3,26 +3,23 @@ # Licensed under the MIT License. # ------------------------------------ import abc -from typing import cast, TYPE_CHECKING, Any, Optional +from typing import cast, Any, Optional +from azure.core.credentials import AccessToken from .. import CredentialUnavailableError from .._internal.managed_identity_client import ManagedIdentityClient from .._internal.get_token_mixin import GetTokenMixin -if TYPE_CHECKING: - from azure.core.credentials import AccessToken - class ManagedIdentityBase(GetTokenMixin): """Base class for internal credentials using ManagedIdentityClient""" - def __init__(self, **kwargs): - # type: (**Any) -> None + def __init__(self, **kwargs: Any) -> None: super(ManagedIdentityBase, self).__init__() self._client = self.get_client(**kwargs) @abc.abstractmethod - def get_client(self, **kwargs) -> Optional[ManagedIdentityClient]: + def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]: pass @abc.abstractmethod @@ -39,22 +36,18 @@ def __exit__(self, *args): if self._client: self._client.__exit__(*args) - def close(self): - # type: () -> None + def close(self) -> None: self.__exit__() - def get_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: if not self._client: raise CredentialUnavailableError(message=self.get_unavailable_message()) return super(ManagedIdentityBase, self).get_token(*scopes, **kwargs) - def _acquire_token_silently(self, *scopes, **kwargs): - # type: (*str, **Any) -> Optional[AccessToken] + def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) - def _request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index 173da1fed373..99c4fdaac7e4 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -4,7 +4,7 @@ # ------------------------------------ import abc import time -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional from msal import TokenCache import six @@ -12,22 +12,21 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError, DecodeError from azure.core.pipeline.policies import ContentDecodePolicy +from azure.core.pipeline import PipelineResponse +from azure.core.pipeline.transport import HttpRequest from .._internal import _scopes_to_resource from .._internal.pipeline import build_pipeline -if TYPE_CHECKING: - from azure.core.pipeline import PipelineResponse - from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy - from azure.core.pipeline.transport import HttpRequest - PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy] - -ABC = abc.ABC - -class ManagedIdentityClientBase(ABC): +class ManagedIdentityClientBase(abc.ABC): # pylint:disable=missing-client-constructor-parameter-credential - def __init__(self, request_factory, client_id=None, identity_config=None, **kwargs): - # type: (Callable[[str, dict], HttpRequest], Optional[str], Optional[Dict], **Any) -> None + def __init__( + self, + request_factory: Callable[[str, dict], HttpRequest], + client_id: Optional[str] = None, + identity_config: Optional[Dict] = None, + **kwargs: Any + ) -> None: self._cache = kwargs.pop("_cache", None) or TokenCache() self._content_callback = kwargs.pop("_content_callback", None) self._identity_config = identity_config or {} @@ -36,9 +35,7 @@ def __init__(self, request_factory, client_id=None, identity_config=None, **kwar self._pipeline = self._build_pipeline(**kwargs) self._request_factory = request_factory - def _process_response(self, response, request_time): - # type: (PipelineResponse, int) -> AccessToken - + def _process_response(self, response: PipelineResponse, request_time: int) -> AccessToken: content = response.context.get(ContentDecodePolicy.CONTEXT_NAME) if not content: try: @@ -78,8 +75,7 @@ def _process_response(self, response, request_time): return token - def get_cached_token(self, *scopes): - # type: (*str) -> Optional[AccessToken] + def get_cached_token(self, *scopes: str) -> Optional[AccessToken]: resource = _scopes_to_resource(*scopes) tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]) for token in tokens: @@ -105,12 +101,10 @@ def __enter__(self): def __exit__(self, *args): self._pipeline.__exit__(*args) - def close(self): - # type: () -> None + def close(self) -> None: self.__exit__() - def request_token(self, *scopes, **kwargs): - # type: (*str, **Any) -> AccessToken + def request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: resource = _scopes_to_resource(*scopes) request = self._request_factory(resource, self._identity_config) kwargs.pop("tenant_id", None) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py index 53c25c172f24..7e90c2ac4def 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_client.py @@ -3,25 +3,20 @@ # Licensed under the MIT License. # ------------------------------------ import threading -from typing import Any, Dict, Optional, Union, TYPE_CHECKING +from typing import Any, Dict, Optional, Union import six from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy -from azure.core.pipeline.transport import HttpRequest +from azure.core.pipeline.transport import HttpRequest, HttpResponse from azure.core.pipeline import PipelineResponse from .pipeline import build_pipeline -if TYPE_CHECKING: - from azure.core.pipeline.transport import HttpResponse - - RequestData = Union[Dict[str, str], str] - - +RequestData = Union[Dict[str, str], str] _POST = ["POST"] -class MsalResponse(object): +class MsalResponse: """Wraps HttpResponse according to msal.oauth2cli.http""" def __init__(self, response: PipelineResponse) -> None: @@ -56,11 +51,10 @@ def raise_for_status(self): raise ClientAuthenticationError(message=message, response=self._response.http_response) -class MsalClient(object): # pylint:disable=client-accepts-api-version-keyword +class MsalClient: # pylint:disable=client-accepts-api-version-keyword """Wraps Pipeline according to msal.oauth2cli.http""" - def __init__(self, **kwargs): # pylint:disable=missing-client-constructor-parameter-credential - # type: (**Any) -> None + def __init__(self, **kwargs: Any) -> None: # pylint:disable=missing-client-constructor-parameter-credential self._local = threading.local() self._pipeline = build_pipeline(**kwargs) @@ -71,12 +65,18 @@ def __enter__(self): def __exit__(self, *args): self._pipeline.__exit__(*args) - def close(self): - # type: () -> None + def close(self) -> None: self.__exit__() - def post(self, url, params=None, data=None, headers=None, **kwargs): # pylint:disable=unused-argument - # type: (str, Optional[Dict[str, str]], RequestData, Optional[Dict[str, str]], **Any) -> MsalResponse + def post( + self, + url: str, + params: Optional[Dict[str, str]] = None, + data: Optional[RequestData] = None, + headers: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> MsalResponse: + # pylint:disable=unused-argument request = HttpRequest("POST", url, headers=headers) if params: request.format_parameters(params) @@ -94,8 +94,14 @@ def post(self, url, params=None, data=None, headers=None, **kwargs): # pylint:d self._store_auth_error(response) return MsalResponse(response) - def get(self, url, params=None, headers=None, **kwargs): # pylint:disable=unused-argument - # type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], **Any) -> MsalResponse + def get( + self, + url: str, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> MsalResponse: + # pylint:disable=unused-argument request = HttpRequest("GET", url, headers=headers) if params: request.format_parameters(params) @@ -103,16 +109,14 @@ def get(self, url, params=None, headers=None, **kwargs): # pylint:disable=unuse self._store_auth_error(response) return MsalResponse(response) - def get_error_response(self, msal_result): - # type: (dict) -> Optional[HttpResponse] + def get_error_response(self, msal_result: Dict) -> Optional[HttpResponse]: """Get the HTTP response associated with an MSAL error""" error_code, response = getattr(self._local, "error", (None, None)) if response and error_code == msal_result.get("error"): return response return None - def _store_auth_error(self, response): - # type: (PipelineResponse) -> None + def _store_auth_error(self, response: PipelineResponse) -> None: if response.http_response.status_code >= 400: # if the body doesn't contain "error", this isn't an OAuth 2 error, i.e. this isn't a # response to an auth request, so no credential will want to include it with an exception diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 5dd497b0371e..31442381740b 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import os -from typing import Any, List, Union, Dict +from typing import Any, List, Union, Dict, Optional import msal @@ -13,26 +13,28 @@ from .._persistent_cache import _load_persistent_cache -class MsalCredential(object): # pylint: disable=too-many-instance-attributes +class MsalCredential: # pylint: disable=too-many-instance-attributes """Base class for credentials wrapping MSAL applications""" def __init__( self, client_id: str, - client_credential: Union[str, Dict] = None, + client_credential: Optional[Union[str, Dict]] = None, *, - additionally_allowed_tenants: List[str] = None, - allow_broker: bool = None, + additionally_allowed_tenants: Optional[List[str]] = None, + allow_broker: Optional[bool] = None, + authority: Optional[str] = None, + instance_discovery: Optional[bool] = None, + tenant_id: Optional[str] = None, **kwargs ) -> None: - authority = kwargs.pop("authority", None) - self._instance_discovery = kwargs.pop("instance_discovery", None) + self._instance_discovery = instance_discovery self._authority = normalize_authority(authority) if authority else get_default_authority() self._regional_authority = os.environ.get(EnvironmentVariables.AZURE_REGIONAL_AUTHORITY_NAME) - self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" + self._tenant_id = tenant_id or "organizations" validate_tenant_id(self._tenant_id) self._client = MsalClient(**kwargs) - self._client_applications = {} # type: Dict[str, msal.ClientApplication] + self._client_applications: Dict[str, msal.ClientApplication] = {} self._client_credential = client_credential self._client_id = client_id self._allow_broker = allow_broker @@ -55,8 +57,7 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): - # type: () -> None + def close(self) -> None: self.__exit__() def _get_app(self, **kwargs): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 5b86ea63c783..c49eef8df5c4 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -5,7 +5,7 @@ import abc import platform import time -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any, Iterable, List, Mapping, Optional, cast from urllib.parse import urlparse import six import msal @@ -72,17 +72,22 @@ def _filtered_accounts(accounts, username=None, tenant_id=None): class SharedTokenCacheBase(ABC): - def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument - # type: (Optional[str], **Any) -> None - authority = kwargs.pop("authority", None) + def __init__( + self, + username: Optional[str] = None, + *, + authority: Optional[str] = None, + tenant_id: Optional[str] = None, + **kwargs: Any + ) -> None: # pylint:disable=unused-argument self._authority = normalize_authority(authority) if authority else get_default_authority() environment = urlparse(self._authority).netloc self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) self._username = username - self._tenant_id = kwargs.pop("tenant_id", None) + self._tenant_id = tenant_id self._cache = kwargs.pop("_cache", None) self._cache_persistence_options = kwargs.pop("cache_persistence_options", None) - self._client = None # type: Optional[AadClientBase] + self._client: AadClientBase = cast(AadClientBase, None) self._client_kwargs = kwargs self._client_kwargs["tenant_id"] = "organizations" self._initialized = False @@ -175,8 +180,7 @@ def _get_account(self, username: str = None, tenant_id: str = None) -> CacheIte raise CredentialUnavailableError(message=message) - def _get_cached_access_token(self, scopes, account): - # type: (Iterable[str], CacheItem) -> Optional[AccessToken] + def _get_cached_access_token(self, scopes: Iterable[str], account: CacheItem) -> Optional[AccessToken]: if "home_account_id" not in account: return None @@ -210,8 +214,7 @@ def _get_refresh_tokens(self, account): six.raise_from(CredentialUnavailableError(message=message), ex) @staticmethod - def supported(): - # type: () -> bool + def supported() -> bool: """Whether the shared token cache is supported on the current platform. :rtype: bool diff --git a/sdk/identity/azure-identity/azure/identity/_persistent_cache.py b/sdk/identity/azure-identity/azure/identity/_persistent_cache.py index c0eace113628..79e035334ed7 100644 --- a/sdk/identity/azure-identity/azure/identity/_persistent_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_persistent_cache.py @@ -15,7 +15,7 @@ _LOGGER = logging.getLogger(__name__) -class TokenCachePersistenceOptions(object): +class TokenCachePersistenceOptions: """Options for persistent token caching. Most credentials accept an instance of this class to configure persistent token caching. The default values @@ -46,10 +46,16 @@ class TokenCachePersistenceOptions(object): always try to encrypt its data. """ - def __init__(self, **kwargs): - # type: (**Any) -> None - self.allow_unencrypted_storage = kwargs.get("allow_unencrypted_storage", False) - self.name = kwargs.get("name", "msal.cache") + def __init__( + self, + *, + allow_unencrypted_storage: bool = False, + name: str = "msal.cache", + **kwargs: Any + ) -> None: + # pylint:disable=unused-argument + self.allow_unencrypted_storage = allow_unencrypted_storage + self.name = name def _load_persistent_cache(options): diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py index d9b44f0ee70f..1732802c4033 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/app_service.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Optional +from typing import Optional, Any from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient @@ -10,7 +10,7 @@ class AppServiceCredential(AsyncManagedIdentityBase): - def get_client(self, **kwargs) -> Optional[AsyncManagedIdentityClient]: + def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]: client_args = _get_client_args(**kwargs) if client_args: return AsyncManagedIdentityClient(**client_args) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py index 6075c8ff335b..64a38608c836 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/application.py @@ -4,6 +4,7 @@ # ------------------------------------ import logging import os +from typing import Optional, Any from azure.core.credentials import AccessToken from .chained import ChainedTokenCredential @@ -51,18 +52,21 @@ class AzureApplicationCredential(ChainedTokenCredential): of the environment variable AZURE_CLIENT_ID, if any. If not specified, a system-assigned identity will be used. """ - def __init__(self, **kwargs) -> None: - authority = kwargs.pop("authority", None) + def __init__( + self, + *, + authority: Optional[str] = None, + managed_identity_client_id: Optional[str] = None, + **kwargs: Any + ) -> None: authority = normalize_authority(authority) if authority else get_default_authority() - managed_identity_client_id = kwargs.pop( - "managed_identity_client_id", os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID) - ) + managed_identity_client_id = managed_identity_client_id or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID) super().__init__( EnvironmentCredential(authority=authority, **kwargs), ManagedIdentityCredential(client_id=managed_identity_client_id, **kwargs), ) - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Asynchronously request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index 67b0c48b2292..c29d4223a82c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Optional +from typing import Optional, Any, cast from azure.core.exceptions import ClientAuthenticationError from azure.core.credentials import AccessToken @@ -48,16 +48,19 @@ def __init__( client_id: str, authorization_code: str, redirect_uri: str, + *, + client_secret: Optional[str] = None, + client: Optional[AadClient] = None, **kwargs ) -> None: - self._authorization_code = authorization_code # type: Optional[str] + self._authorization_code: Optional[str] = authorization_code self._client_id = client_id - self._client_secret = kwargs.pop("client_secret", None) - self._client = kwargs.pop("client", None) or AadClient(tenant_id, client_id, **kwargs) + self._client_secret = client_secret + self._client = client or AadClient(tenant_id, client_id, **kwargs) self._redirect_uri = redirect_uri super().__init__() - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. @@ -77,10 +80,10 @@ async def get_token(self, *scopes: str, **kwargs) -> AccessToken: """ return await super().get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: if self._authorization_code: token = await self._client.obtain_token_by_authorization_code( scopes=scopes, code=self._authorization_code, redirect_uri=self._redirect_uri, **kwargs @@ -88,7 +91,7 @@ async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: self._authorization_code = None # auth codes are single-use return token - token = None + token = cast(AccessToken, None) for refresh_token in self._client.get_cached_refresh_tokens(scopes): if "secret" in refresh_token: token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token["secret"], **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py index 14859a076dc0..867df88b0230 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py @@ -4,7 +4,7 @@ # ------------------------------------ import functools import os -from typing import Optional +from typing import Optional, Any from azure.core.pipeline.policies import AsyncHTTPPolicy from azure.core.pipeline import PipelineRequest, PipelineResponse @@ -15,7 +15,7 @@ class AzureArcCredential(AsyncManagedIdentityBase): - def get_client(self, **kwargs) -> Optional[AsyncManagedIdentityClient]: + def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]: url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT) imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT) if url and imds: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py index 28819a5ae699..31b8fa23ee09 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_cli.py @@ -6,7 +6,7 @@ import os import shutil import sys -from typing import List +from typing import List, Any, Optional from azure.core.exceptions import ClientAuthenticationError from azure.core.credentials import AccessToken @@ -36,13 +36,18 @@ class AzureCliCredential(AsyncContextManager): for which the credential may acquire tokens. Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the application can access. """ - def __init__(self, *, tenant_id: str = "", additionally_allowed_tenants: List[str] = None): + def __init__( + self, + *, + tenant_id: str = "", + additionally_allowed_tenants: Optional[List[str]] = None + ) -> None: self.tenant_id = tenant_id self._additionally_allowed_tenants = additionally_allowed_tenants or [] @log_get_token_async - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes`. This method is called automatically by Azure SDK clients. Applications calling this method directly must @@ -83,7 +88,7 @@ async def get_token(self, *scopes: str, **kwargs) -> AccessToken: return token - async def close(self): + async def close(self) -> None: """Calling this method is unnecessary""" diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_ml.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_ml.py index be06b6e453df..ce54d9e80219 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_ml.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_ml.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Optional +from typing import Optional, Any from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient @@ -10,7 +10,7 @@ class AzureMLCredential(AsyncManagedIdentityBase): - def get_client(self, **kwargs) -> Optional[AsyncManagedIdentityClient]: + def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]: client_args = _get_client_args(**kwargs) if client_args: return AsyncManagedIdentityClient(**client_args) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py index b20cd43c5032..cac1a3ab0fff 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py @@ -4,7 +4,7 @@ # ------------------------------------ import asyncio import sys -from typing import cast, List +from typing import cast, List, Any from azure.core.credentials import AccessToken from .._internal import AsyncContextManager @@ -38,7 +38,7 @@ def __init__(self, *, tenant_id: str = "", additionally_allowed_tenants: List[st @log_get_token_async async def get_token( - self, *scopes: str, **kwargs + self, *scopes: str, **kwargs: Any ) -> AccessToken: # pylint:disable=no-self-use,unused-argument """Request an access token for `scopes`. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index 69b85ad6a6e9..e86ff433536e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import TypeVar, Optional +from typing import TypeVar, Optional, Any import msal @@ -49,7 +49,7 @@ def __init__( tenant_id: str, client_id: str, certificate_path: str = None, - **kwargs + **kwargs: Any ) -> None: validate_tenant_id(tenant_id) @@ -69,17 +69,17 @@ def __init__( self._client_id = client_id super().__init__() - async def __aenter__(self:T) -> T: + async def __aenter__(self: T) -> T: await self._client.__aenter__() return self - async def close(self): + async def close(self) -> None: """Close the credential's transport session.""" await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: return await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py index 95b4632a8f08..d642b980bb7e 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/chained.py @@ -4,7 +4,7 @@ # ------------------------------------ import asyncio import logging -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, Any from azure.core.exceptions import ClientAuthenticationError from azure.core.credentials import AccessToken @@ -36,12 +36,12 @@ def __init__(self, *credentials: "AsyncTokenCredential") -> None: self._successful_credential = None # type: Optional[AsyncTokenCredential] self.credentials = credentials - async def close(self): + async def close(self) -> None: """Close the transport sessions of all credentials in the chain.""" await asyncio.gather(*(credential.close() for credential in self.credentials)) - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Asynchronously request a token from each credential, in order, returning the first token received. If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError` diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py index 7fb327cd701f..94f604944726 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py @@ -29,7 +29,7 @@ class ClientAssertionCredential(AsyncContextManager, GetTokenMixin): acquire tokens for any tenant the application can access. """ - def __init__(self, tenant_id: str, client_id: str, func: Callable[[], str], **kwargs) -> None: + def __init__(self, tenant_id: str, client_id: str, func: Callable[[], str], **kwargs: Any) -> None: self._func = func self._client = AadClient(tenant_id, client_id, **kwargs) super().__init__(**kwargs) @@ -42,10 +42,10 @@ async def close(self) -> None: """Close the credential's transport session.""" await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: assertion = self._func() token = await self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py index 290641a44423..c23f0cb38923 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Optional, TypeVar +from typing import Optional, TypeVar, Any import msal @@ -33,7 +33,7 @@ class ClientSecretCredential(AsyncContextManager, GetTokenMixin): acquire tokens for any tenant the application can access. """ - def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs) -> None: + def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs: Any) -> None: if not client_id: raise ValueError("client_id should be the id of an Azure Active Directory application") if not client_secret: @@ -59,13 +59,13 @@ async def __aenter__(self: T) -> T: await self._client.__aenter__() return self - async def close(self): + async def close(self) -> None: """Close the credential's transport session.""" await self._client.__aexit__() - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: return await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py index 1b1fbe90b771..c4aefe44e542 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/cloud_shell.py @@ -4,7 +4,7 @@ # ------------------------------------ import functools import os -from typing import Optional +from typing import Optional, Any from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient @@ -13,7 +13,7 @@ class CloudShellCredential(AsyncManagedIdentityBase): - def get_client(self, **kwargs) -> Optional[AsyncManagedIdentityClient]: + def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]: url = os.environ.get(EnvironmentVariables.MSI_ENDPOINT) if url: return AsyncManagedIdentityClient( diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 2eef508c4cf8..8000e05aa476 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -4,7 +4,7 @@ # ------------------------------------ import logging import os -from typing import List, TYPE_CHECKING +from typing import List, TYPE_CHECKING, Any from azure.core.credentials import AccessToken from ..._constants import EnvironmentVariables @@ -65,7 +65,7 @@ class DefaultAzureCredential(ChainedTokenCredential): Directory work or school accounts. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: if "tenant_id" in kwargs: raise TypeError("'tenant_id' is not supported in DefaultAzureCredential.") @@ -125,7 +125,7 @@ def __init__(self, **kwargs) -> None: super().__init__(*credentials) - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Asynchronously request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py index 504d650ba7e0..7bf5cbff414c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/environment.py @@ -4,7 +4,7 @@ # ------------------------------------ import logging import os -from typing import Optional, Union +from typing import Optional, Union, Any from azure.core.credentials import AccessToken from .._internal.decorators import log_get_token_async @@ -41,8 +41,8 @@ class EnvironmentCredential(AsyncContextManager): when no value is given. """ - def __init__(self, **kwargs) -> None: - self._credential = None # type: Optional[Union[CertificateCredential, ClientSecretCredential]] + def __init__(self, **kwargs: Any) -> None: + self._credential: Optional[Union[CertificateCredential, ClientSecretCredential]] = None if all(os.environ.get(v) is not None for v in EnvironmentVariables.CLIENT_SECRET_VARS): self._credential = ClientSecretCredential( @@ -77,14 +77,14 @@ async def __aenter__(self): await self._credential.__aenter__() return self - async def close(self): + async def close(self) -> None: """Close the credential's transport session.""" if self._credential: await self._credential.__aexit__() @log_get_token_async - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Asynchronously request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 95260a414e4c..ddb8dd4173e1 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import os -from typing import Optional, TypeVar +from typing import Optional, TypeVar, Any from azure.core.exceptions import ClientAuthenticationError, HttpResponseError from azure.core.credentials import AccessToken @@ -12,21 +12,21 @@ from .._internal import AsyncContextManager from .._internal.get_token_mixin import GetTokenMixin from .._internal.managed_identity_client import AsyncManagedIdentityClient -from ..._credentials.imds import get_request, PIPELINE_SETTINGS +from ..._credentials.imds import _get_request, PIPELINE_SETTINGS T = TypeVar("T", bound="ImdsCredential") class ImdsCredential(AsyncContextManager, GetTokenMixin): - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__() - self._client = AsyncManagedIdentityClient(get_request, **dict(PIPELINE_SETTINGS, **kwargs)) + self._client = AsyncManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs)) if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ: - self._endpoint_available = True # type: Optional[bool] + self._endpoint_available: Optional[bool] = True else: self._endpoint_available = None - self._error_message = None # type: Optional[str] + self._error_message: Optional[str] = None self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs async def __aenter__(self: T) -> T: @@ -36,10 +36,10 @@ async def __aenter__(self: T) -> T: async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_token(*scopes) - async def _request_token(self, *scopes, **kwargs) -> AccessToken: # pylint:disable=unused-argument + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument if self._endpoint_available is None: # Lacking another way to determine whether the IMDS endpoint is listening, # we send a request it would immediately reject (because it lacks the Metadata header), diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index 43e5a8ac6cea..c06abfff440c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -4,7 +4,7 @@ # ------------------------------------ import logging import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Any from azure.core.credentials import AccessToken from .._internal import AsyncContextManager @@ -34,7 +34,7 @@ class ManagedIdentityCredential(AsyncContextManager): :paramtype identity_config: Mapping[str, str] """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: self._credential = None # type: Optional[AsyncTokenCredential] if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT): @@ -95,13 +95,13 @@ async def __aenter__(self): await self._credential.__aenter__() return self - async def close(self): + async def close(self) -> None: """Close the credential's transport session.""" if self._credential: - await self._credential.__aexit__() + await self._credential.close() @log_get_token_async - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Asynchronously request an access token for `scopes`. This method is called automatically by Azure SDK clients. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py index 0d68e6d367eb..d412a534438b 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import logging -from typing import Optional, Union +from typing import Optional, Union, Any from azure.core.exceptions import ClientAuthenticationError from azure.core.credentials import AccessToken @@ -54,7 +54,7 @@ def __init__( client_certificate: bytes = None, client_secret: str = None, user_assertion: str, - **kwargs + **kwargs: Any ) -> None: super().__init__() validate_tenant_id(tenant_id) @@ -86,13 +86,13 @@ async def __aenter__(self): await self._client.__aenter__() return self - async def close(self): + async def close(self) -> None: await self._client.close() - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # Note we assume the cache has tokens for one user only. That's okay because each instance of this class is # locked to a single user (assertion). This assumption will become unsafe if this class allows applications # to change an instance's assertion. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py index 3ec87f4045e5..7542ed127afc 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/service_fabric.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Optional +from typing import Optional, Any from .._internal.managed_identity_base import AsyncManagedIdentityBase from .._internal.managed_identity_client import AsyncManagedIdentityClient @@ -10,7 +10,7 @@ class ServiceFabricCredential(AsyncManagedIdentityBase): - def get_client(self, **kwargs) -> Optional[AsyncManagedIdentityClient]: + def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]: client_args = _get_client_args(**kwargs) if client_args: return AsyncManagedIdentityClient(**client_args) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index f7a3c60c4dca..1e034dd20967 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from typing import Any from azure.core.credentials import AccessToken from ..._internal.aad_client import AadClientBase from ... import CredentialUnavailableError @@ -31,17 +32,17 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager): async def __aenter__(self): if self._client: - await self._client.__aenter__() + await self._client.__aenter__() # type: ignore return self - async def close(self): + async def close(self) -> None: """Close the credential's transport session.""" if self._client: - await self._client.__aexit__() + await self._client.__aexit__() # type: ignore @log_get_token_async - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: # pylint:disable=unused-argument + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument """Get an access token for `scopes` from the shared cache. If no access token is cached, attempt to acquire one using a cached refresh token. @@ -81,5 +82,5 @@ async def get_token(self, *scopes: str, **kwargs) -> AccessToken: # pylint:disa raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username"))) - def _get_auth_client(self, **kwargs) -> AadClientBase: + def _get_auth_client(self, **kwargs: Any) -> AadClientBase: return AadClient(client_id=DEVELOPER_SIGN_ON_CLIENT_ID, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/token_exchange.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/token_exchange.py index a9b15f201815..54547aa546e6 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/token_exchange.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/token_exchange.py @@ -2,12 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from typing import Any from .client_assertion import ClientAssertionCredential from ..._credentials.token_exchange import TokenFileMixin class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin): - def __init__(self, tenant_id: str, client_id: str, token_file_path: str, **kwargs) -> None: + def __init__(self, tenant_id: str, client_id: str, token_file_path: str, **kwargs: Any) -> None: super().__init__( tenant_id=tenant_id, client_id=client_id, diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py index ac91a6e47b62..0d01fc0e335f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import cast, Optional +from typing import cast, Optional, Any from azure.core.credentials import AccessToken from ..._exceptions import CredentialUnavailableError @@ -45,7 +45,7 @@ async def close(self) -> None: await self._client.__aexit__() @log_get_token_async - async def get_token(self, *scopes: str, **kwargs) -> AccessToken: + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: """Request an access token for `scopes` as the user currently signed in to Visual Studio Code. This method is called automatically by Azure SDK clients. @@ -69,14 +69,14 @@ async def get_token(self, *scopes: str, **kwargs) -> AccessToken: return await super().get_token(*scopes, **kwargs) - async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessToken]: + async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessToken]: self._client = cast(AadClient, self._client) return self._client.get_cached_access_token(scopes, **kwargs) - async def _request_token(self, *scopes: str, **kwargs) -> AccessToken: + async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: refresh_token = self._get_refresh_token() self._client = cast(AadClient, self._client) return await self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs) - def _get_client(self, **kwargs) -> AadClient: + def _get_client(self, **kwargs: Any) -> AadClient: return AadClient(**kwargs) diff --git a/tools/azure-sdk-tools/ci_tools/environment_exclusions.py b/tools/azure-sdk-tools/ci_tools/environment_exclusions.py index ace017d65bf0..5d2340f0178a 100644 --- a/tools/azure-sdk-tools/ci_tools/environment_exclusions.py +++ b/tools/azure-sdk-tools/ci_tools/environment_exclusions.py @@ -229,7 +229,6 @@ "azure-eventhub-checkpointstoreblob-aio", "azure-eventhub-checkpointstoretable", "azure-ai-formrecognizer", - "azure-identity", "azure-keyvault-administration", "azure-keyvault-certificates", "azure-keyvault-keys",