diff --git a/docs/index.rst b/docs/index.rst index 15ee4a0a..1c5d79bc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -168,3 +168,23 @@ You may want to catch them to provide a better error message to your end users. .. autoclass:: msal.IdTokenError + +Managed Identity +================ +MSAL supports +`Managed Identity `_. + +You can create one of these two kinds of managed identity configuration objects: + +.. autoclass:: msal.SystemAssignedManagedIdentity + :members: + +.. autoclass:: msal.UserAssignedManagedIdentity + :members: + +And then feed the configuration object into a :class:`ManagedIdentityClient` object. + +.. autoclass:: msal.ManagedIdentityClient + :members: + + .. automethod:: __init__ diff --git a/msal/__init__.py b/msal/__init__.py index 87d32019..380d584e 100644 --- a/msal/__init__.py +++ b/msal/__init__.py @@ -34,8 +34,15 @@ from .oauth2cli.oidc import Prompt, IdTokenError from .token_cache import TokenCache, SerializableTokenCache from .auth_scheme import PopAuthScheme +from .managed_identity import ( + SystemAssignedManagedIdentity, UserAssignedManagedIdentity, + ManagedIdentityClient, + ManagedIdentityError, + ArcPlatformNotSupportedError, + ) # Putting module-level exceptions into the package namespace, to make them # 1. officially part of the MSAL public API, and # 2. can still be caught by the user code even if we change the module structure. from .oauth2cli.oauth2 import BrowserInteractionTimeoutError + diff --git a/msal/application.py b/msal/application.py index 65022124..b6a1d9d8 100644 --- a/msal/application.py +++ b/msal/application.py @@ -537,7 +537,11 @@ def __init__( self.http_client.mount("https://", a) self.http_client = ThrottledHttpClient( self.http_client, - {} if http_cache is None else http_cache, # Default to an in-memory dict + http_cache=http_cache, + default_throttle_time=60 + # The default value 60 was recommended mainly for PCA at the end of + # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview + if isinstance(self, PublicClientApplication) else 5, ) self.app_name = app_name diff --git a/msal/managed_identity.py b/msal/managed_identity.py new file mode 100644 index 00000000..354fee52 --- /dev/null +++ b/msal/managed_identity.py @@ -0,0 +1,599 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +import json +import logging +import os +import socket +import sys +import time +from urllib.parse import urlparse # Python 3+ +from collections import UserDict # Python 3+ +from typing import Union # Needed in Python 3.7 & 3.8 +from .token_cache import TokenCache +from .individual_cache import _IndividualCache as IndividualCache +from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser +from .cloudshell import _is_running_in_cloud_shell + + +logger = logging.getLogger(__name__) + + +class ManagedIdentityError(ValueError): + pass + + +class ManagedIdentity(UserDict): + """Feed an instance of this class to :class:`msal.ManagedIdentityClient` + to acquire token for the specified managed identity. + """ + # The key names used in config dict + ID_TYPE = "ManagedIdentityIdType" # Contains keyword ManagedIdentity so its json equivalent will be more readable + ID = "Id" + + # Valid values for key ID_TYPE + CLIENT_ID = "ClientId" + RESOURCE_ID = "ResourceId" + OBJECT_ID = "ObjectId" + SYSTEM_ASSIGNED = "SystemAssigned" + + _types_mapping = { # Maps type name in configuration to type name on wire + CLIENT_ID: "client_id", + RESOURCE_ID: "mi_res_id", + OBJECT_ID: "object_id", + } + + @classmethod + def is_managed_identity(cls, unknown): + return isinstance(unknown, ManagedIdentity) or ( + isinstance(unknown, dict) and cls.ID_TYPE in unknown) + + @classmethod + def is_system_assigned(cls, unknown): + return isinstance(unknown, SystemAssignedManagedIdentity) or ( + isinstance(unknown, dict) + and unknown.get(cls.ID_TYPE) == cls.SYSTEM_ASSIGNED) + + @classmethod + def is_user_assigned(cls, unknown): + return isinstance(unknown, UserAssignedManagedIdentity) or ( + isinstance(unknown, dict) + and unknown.get(cls.ID_TYPE) in cls._types_mapping + and unknown.get(cls.ID)) + + def __init__(self, identifier=None, id_type=None): + # Undocumented. Use subclasses instead. + super(ManagedIdentity, self).__init__({ + self.ID_TYPE: id_type, + self.ID: identifier, + }) + + +class SystemAssignedManagedIdentity(ManagedIdentity): + """Represent a system-assigned managed identity. + + It is equivalent to a Python dict of:: + + {"ManagedIdentityIdType": "SystemAssigned", "Id": None} + + or a JSON blob of:: + + {"ManagedIdentityIdType": "SystemAssigned", "Id": null} + """ + def __init__(self): + super(SystemAssignedManagedIdentity, self).__init__(id_type=self.SYSTEM_ASSIGNED) + + +class UserAssignedManagedIdentity(ManagedIdentity): + """Represent a user-assigned managed identity. + + Depends on the id you provided, the outcome is equivalent to one of the below:: + + {"ManagedIdentityIdType": "ClientId", "Id": "foo"} + {"ManagedIdentityIdType": "ResourceId", "Id": "foo"} + {"ManagedIdentityIdType": "ObjectId", "Id": "foo"} + """ + def __init__(self, *, client_id=None, resource_id=None, object_id=None): + if client_id and not resource_id and not object_id: + super(UserAssignedManagedIdentity, self).__init__( + id_type=self.CLIENT_ID, identifier=client_id) + elif not client_id and resource_id and not object_id: + super(UserAssignedManagedIdentity, self).__init__( + id_type=self.RESOURCE_ID, identifier=resource_id) + elif not client_id and not resource_id and object_id: + super(UserAssignedManagedIdentity, self).__init__( + id_type=self.OBJECT_ID, identifier=object_id) + else: + raise ManagedIdentityError( + "You shall specify one of the three parameters: " + "client_id, resource_id, object_id") + + +class _ThrottledHttpClient(ThrottledHttpClientBase): + def __init__(self, http_client, **kwargs): + super(_ThrottledHttpClient, self).__init__(http_client, **kwargs) + self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs + mapping=self._expiring_mapping, + key_maker=lambda func, args, kwargs: "REQ {} hash={} 429/5xx/Retry-After".format( + args[0], # It is the endpoint, typically a constant per MI type + self._hash( + # Managed Identity flavors have inconsistent parameters. + # We simply choose to hash them all. + str(kwargs.get("params")) + str(kwargs.get("data"))), + ), + expires_in=RetryAfterParser(5).parse, # 5 seconds default for non-PCA + )(http_client.get) + + +class ManagedIdentityClient(object): + """This API encapsulates multiple managed identity back-ends: + VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric, + and Azure Arc. + + It also provides token cache support. + + .. note:: + + Cloud Shell support is NOT implemented in this class. + Since MSAL Python 1.18 in May 2022, it has been implemented in + :func:`PublicClientApplication.acquire_token_interactive` via calling pattern + ``PublicClientApplication(...).acquire_token_interactive(scopes=[...], prompt="none")``. + That is appropriate, because Cloud Shell yields a token with + delegated permissions for the end user who has signed in to the Azure Portal + (like what a ``PublicClientApplication`` does), + not a token with application permissions for an app. + """ + _instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders + + def __init__( + self, + managed_identity: Union[ + dict, + ManagedIdentity, # Could use Type[ManagedIdentity] but it is deprecatred in Python 3.9+ + SystemAssignedManagedIdentity, + UserAssignedManagedIdentity, + ], + *, + http_client, + token_cache=None, + http_cache=None, + ): + """Create a managed identity client. + + :param managed_identity: + It accepts an instance of :class:`SystemAssignedManagedIdentity` + or :class:`UserAssignedManagedIdentity`. + They are equivalent to a dict with a certain shape, + which may be loaded from a JSON configuration file or an env var. + + :param http_client: + An http client object. For example, you can use ``requests.Session()``, + optionally with exponential backoff behavior demonstrated in this recipe:: + + import msal, requests + from requests.adapters import HTTPAdapter, Retry + s = requests.Session() + retries = Retry(total=3, backoff_factor=0.1, status_forcelist=[ + 429, 500, 501, 502, 503, 504]) + s.mount('https://', HTTPAdapter(max_retries=retries)) + managed_identity = ... + client = msal.ManagedIdentityClient(managed_identity, http_client=s) + + :param token_cache: + Optional. It accepts a :class:`msal.TokenCache` instance to store tokens. + It will use an in-memory token cache by default. + + :param http_cache: + Optional. It has the same characteristics as the + :paramref:`msal.ClientApplication.http_cache`. + + Recipe 1: Hard code a managed identity for your app:: + + import msal, requests + client = msal.ManagedIdentityClient( + msal.UserAssignedManagedIdentity(client_id="foo"), + http_client=requests.Session(), + ) + token = client.acquire_token_for_client("resource") + + Recipe 2: Write once, run everywhere. + If you use different managed identity on different deployment, + you may use an environment variable (such as MY_MANAGED_IDENTITY_CONFIG) + to store a json blob like + ``{"ManagedIdentityIdType": "ClientId", "Id": "foo"}`` or + ``{"ManagedIdentityIdType": "SystemAssignedManagedIdentity", "Id": null})``. + The following app can load managed identity configuration dynamically:: + + import json, os, msal, requests + config = os.getenv("MY_MANAGED_IDENTITY_CONFIG") + assert config, "An ENV VAR with value should exist" + client = msal.ManagedIdentityClient( + json.loads(config), + http_client=requests.Session(), + ) + token = client.acquire_token_for_client("resource") + """ + self._managed_identity = managed_identity + self._http_client = _ThrottledHttpClient( + # This class only throttles excess token acquisition requests. + # It does not provide retry. + # Retry is the http_client or caller's responsibility, not MSAL's. + # + # FWIW, here is the inconsistent retry recommendation. + # 1. Only MI on VM defines exotic 404 and 410 retry recommendations + # ( https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling ) + # (especially for 410 which was supposed to be a permanent failure). + # 2. MI on Service Fabric specifically suggests to not retry on 404. + # ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling ) + http_client.http_client # Patch the raw (unpatched) http client + if isinstance(http_client, ThrottledHttpClientBase) else http_client, + http_cache=http_cache, + ) + self._token_cache = token_cache or TokenCache() + + def acquire_token_for_client(self, *, resource): # We may support scope in the future + """Acquire token for the managed identity. + + The result will be automatically cached. + Subsequent calls will automatically search from cache first. + + .. note:: + + Known issue: When an Azure VM has only one user-assigned managed identity, + and your app specifies to use system-assigned managed identity, + Azure VM may still return a token for your user-assigned identity. + + This is a service-side behavior that cannot be changed by this library. + `Azure VM docs `_ + """ + access_token_from_cache = None + client_id_in_cache = self._managed_identity.get( + ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") + if True: # Does not offer an "if not force_refresh" option, because + # there would be built-in token cache in the service side anyway + matches = self._token_cache.find( + self._token_cache.CredentialType.ACCESS_TOKEN, + target=[resource], + query=dict( + client_id=client_id_in_cache, + environment=self._instance, + realm=self._tenant, + home_account_id=None, + ), + ) + now = time.time() + for entry in matches: + expires_in = int(entry["expires_on"]) - now + if expires_in < 5*60: # Then consider it expired + continue # Removal is not necessary, it will be overwritten + logger.debug("Cache hit an AT") + access_token_from_cache = { # Mimic a real response + "access_token": entry["secret"], + "token_type": entry.get("token_type", "Bearer"), + "expires_in": int(expires_in), # OAuth2 specs defines it as int + } + if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging + break # With a fallback in hand, we break here to go refresh + return access_token_from_cache # It is still good as new + try: + result = _obtain_token(self._http_client, self._managed_identity, resource) + if "access_token" in result: + expires_in = result.get("expires_in", 3600) + if "refresh_in" not in result and expires_in >= 7200: + result["refresh_in"] = int(expires_in / 2) + self._token_cache.add(dict( + client_id=client_id_in_cache, + scope=[resource], + token_endpoint="https://{}/{}".format(self._instance, self._tenant), + response=result, + params={}, + data={}, + )) + if (result and "error" not in result) or (not access_token_from_cache): + return result + except: # The exact HTTP exception is transportation-layer dependent + # Typically network error. Potential AAD outage? + if not access_token_from_cache: # It means there is no fall back option + raise # We choose to bubble up the exception + return access_token_from_cache + + +def _scope_to_resource(scope): # This is an experimental reasonable-effort approach + u = urlparse(scope) + if u.scheme: + return "{}://{}".format(u.scheme, u.netloc) + return scope # There is no much else we can do here + + +APP_SERVICE = object() +AZURE_ARC = object() +CLOUD_SHELL = object() # In MSAL Python, token acquisition was done by + # PublicClientApplication(...).acquire_token_interactive(..., prompt="none") +MACHINE_LEARNING = object() +SERVICE_FABRIC = object() +DEFAULT_TO_VM = object() # Unknown environment; default to VM; you may want to probe +def get_managed_identity_source(): + """Detect the current environment and return the likely identity source. + + When this function returns ``CLOUD_SHELL``, you should use + :func:`msal.PublicClientApplication.acquire_token_interactive` with ``prompt="none"`` + to obtain a token. + """ + if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ + and "IDENTITY_SERVER_THUMBPRINT" in os.environ + ): + return SERVICE_FABRIC + if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ: + return APP_SERVICE + if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ: + return MACHINE_LEARNING + if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ: + return AZURE_ARC + if _is_running_in_cloud_shell(): + return CLOUD_SHELL + return DEFAULT_TO_VM + + +def _obtain_token(http_client, managed_identity, resource): + # A unified low-level API that talks to different Managed Identity + if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ + and "IDENTITY_SERVER_THUMBPRINT" in os.environ + ): + if managed_identity: + logger.debug( + "Ignoring managed_identity parameter. " + "Managed Identity in Service Fabric is configured in the cluster, " + "not during runtime. See also " + "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service") + return _obtain_token_on_service_fabric( + http_client, + os.environ["IDENTITY_ENDPOINT"], + os.environ["IDENTITY_HEADER"], + os.environ["IDENTITY_SERVER_THUMBPRINT"], + resource, + ) + if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ: + return _obtain_token_on_app_service( + http_client, + os.environ["IDENTITY_ENDPOINT"], + os.environ["IDENTITY_HEADER"], + managed_identity, + resource, + ) + if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ: + # Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py + return _obtain_token_on_machine_learning( + http_client, + os.environ["MSI_ENDPOINT"], + os.environ["MSI_SECRET"], + managed_identity, + resource, + ) + if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ: + if ManagedIdentity.is_user_assigned(managed_identity): + raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too + "Invalid managed_identity parameter. " + "Azure Arc supports only system-assigned managed identity, " + "See also " + "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service") + return _obtain_token_on_arc( + http_client, + os.environ["IDENTITY_ENDPOINT"], + resource, + ) + return _obtain_token_on_azure_vm(http_client, managed_identity, resource) + + +def _adjust_param(params, managed_identity): + # Modify the params dict in place + id_name = ManagedIdentity._types_mapping.get( + managed_identity.get(ManagedIdentity.ID_TYPE)) + if id_name: + params[id_name] = managed_identity[ManagedIdentity.ID] + +def _obtain_token_on_azure_vm(http_client, managed_identity, resource): + # Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + logger.debug("Obtaining token via managed identity on Azure VM") + params = { + "api-version": "2018-02-01", + "resource": resource, + } + _adjust_param(params, managed_identity) + resp = http_client.get( + "http://169.254.169.254/metadata/identity/oauth2/token", + params=params, + headers={"Metadata": "true"}, + ) + try: + payload = json.loads(resp.text) + if payload.get("access_token") and payload.get("expires_in"): + return { # Normalizing the payload into OAuth2 format + "access_token": payload["access_token"], + "expires_in": int(payload["expires_in"]), + "resource": payload.get("resource"), + "token_type": payload.get("token_type", "Bearer"), + } + return payload # Typically an error, but it is undefined in the doc above + except json.decoder.JSONDecodeError: + logger.debug("IMDS emits unexpected payload: %s", resp.text) + raise + +def _obtain_token_on_app_service( + http_client, endpoint, identity_header, managed_identity, resource, +): + """Obtains token for + `App Service `_, + Azure Functions, and Azure Automation. + """ + # Prerequisite: Create your app service https://docs.microsoft.com/en-us/azure/app-service/quickstart-python + # Assign it a managed identity https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp + # SSH into your container for testing https://docs.microsoft.com/en-us/azure/app-service/configure-linux-open-ssh-session + logger.debug("Obtaining token via managed identity on Azure App Service") + params = { + "api-version": "2019-08-01", + "resource": resource, + } + _adjust_param(params, managed_identity) + resp = http_client.get( + endpoint, + params=params, + headers={ + "X-IDENTITY-HEADER": identity_header, + "Metadata": "true", # Unnecessary yet harmless for App Service, + # It will be needed by Azure Automation + # https://docs.microsoft.com/en-us/azure/automation/enable-managed-identity-for-automation#get-access-token-for-system-assigned-managed-identity-using-http-get + }, + ) + try: + payload = json.loads(resp.text) + if payload.get("access_token") and payload.get("expires_on"): + return { # Normalizing the payload into OAuth2 format + "access_token": payload["access_token"], + "expires_in": int(payload["expires_on"]) - int(time.time()), + "resource": payload.get("resource"), + "token_type": payload.get("token_type", "Bearer"), + } + return { + "error": "invalid_scope", # Empirically, wrong resource ends up with a vague statusCode=500 + "error_description": "{}, {}".format( + payload.get("statusCode"), payload.get("message")), + } + except json.decoder.JSONDecodeError: + logger.debug("IMDS emits unexpected payload: %s", resp.text) + raise + +def _obtain_token_on_machine_learning( + http_client, endpoint, secret, managed_identity, resource, +): + # Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning + # The following implementation is back ported from Azure Identity 1.15.0 + logger.debug("Obtaining token via managed identity on Azure Machine Learning") + params = {"api-version": "2017-09-01", "resource": resource} + _adjust_param(params, managed_identity) + if params["api-version"] == "2017-09-01" and "client_id" in params: + # Workaround for a known bug in Azure ML 2017 API + params["clientid"] = params.pop("client_id") + resp = http_client.get( + endpoint, + params=params, + headers={"secret": secret}, + ) + try: + payload = json.loads(resp.text) + if payload.get("access_token") and payload.get("expires_on"): + return { # Normalizing the payload into OAuth2 format + "access_token": payload["access_token"], + "expires_in": int(payload["expires_on"]) - int(time.time()), + "resource": payload.get("resource"), + "token_type": payload.get("token_type", "Bearer"), + } + return { + "error": "invalid_scope", # TODO: To be tested + "error_description": "{}".format(payload), + } + except json.decoder.JSONDecodeError: + logger.debug("IMDS emits unexpected payload: %s", resp.text) + raise + + +def _obtain_token_on_service_fabric( + http_client, endpoint, identity_header, server_thumbprint, resource, +): + """Obtains token for + `Service Fabric `_ + """ + # Deployment https://learn.microsoft.com/en-us/azure/service-fabric/service-fabric-get-started-containers-linux + # See also https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/tests/managed-identity-live/service-fabric/service_fabric.md + # Protocol https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#acquiring-an-access-token-using-rest-api + logger.debug("Obtaining token via managed identity on Azure Service Fabric") + resp = http_client.get( + endpoint, + params={"api-version": "2019-07-01-preview", "resource": resource}, + headers={"Secret": identity_header}, + ) + try: + payload = json.loads(resp.text) + if payload.get("access_token") and payload.get("expires_on"): + return { # Normalizing the payload into OAuth2 format + "access_token": payload["access_token"], + "expires_in": int( # Despite the example in docs shows an integer, + payload["expires_on"] # Azure SDK team's test obtained a string. + ) - int(time.time()), + "resource": payload.get("resource"), + "token_type": payload["token_type"], + } + error = payload.get("error", {}) # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling + error_mapping = { # Map Service Fabric errors into OAuth2 errors https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + "SecretHeaderNotFound": "unauthorized_client", + "ManagedIdentityNotFound": "invalid_client", + "ArgumentNullOrEmpty": "invalid_scope", + } + return { + "error": error_mapping.get(payload["error"]["code"], "invalid_request"), + "error_description": resp.text, + } + except json.decoder.JSONDecodeError: + logger.debug("IMDS emits unexpected payload: %s", resp.text) + raise + + +_supported_arc_platforms_and_their_prefixes = { + "linux": "/var/opt/azcmagent/tokens", + "win32": os.path.expandvars(r"%ProgramData%\AzureConnectedMachineAgent\Tokens"), +} + +class ArcPlatformNotSupportedError(ManagedIdentityError): + pass + +def _obtain_token_on_arc(http_client, endpoint, resource): + # https://learn.microsoft.com/en-us/azure/azure-arc/servers/managed-identity-authentication + logger.debug("Obtaining token via managed identity on Azure Arc") + resp = http_client.get( + endpoint, + params={"api-version": "2020-06-01", "resource": resource}, + headers={"Metadata": "true"}, + ) + www_auth = "www-authenticate" # Header in lower case + challenge = { + # Normalized to lowercase, because header names are case-insensitive + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 + k.lower(): v for k, v in resp.headers.items() if k.lower() == www_auth + }.get(www_auth, "").split("=") # Output will be ["Basic realm", "content"] + if not ( # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + len(challenge) == 2 and challenge[0].lower() == "basic realm"): + raise ManagedIdentityError( + "Unrecognizable WWW-Authenticate header: {}".format(resp.headers)) + if sys.platform not in _supported_arc_platforms_and_their_prefixes: + raise ArcPlatformNotSupportedError( + f"Platform {sys.platform} was undefined and unsupported") + filename = os.path.join( + # This algorithm is documented in an internal doc https://msazure.visualstudio.com/One/_wiki/wikis/One.wiki/233012/VM-Extension-Authoring-for-Arc?anchor=2.-obtaining-tokens + _supported_arc_platforms_and_their_prefixes[sys.platform], + os.path.splitext(os.path.basename(challenge[1]))[0] + ".key") + if os.stat(filename).st_size > 4096: # Check size BEFORE loading its content + raise ManagedIdentityError("Local key file shall not be larger than 4KB") + with open(filename) as f: + secret = f.read() + response = http_client.get( + endpoint, + params={"api-version": "2020-06-01", "resource": resource}, + headers={"Metadata": "true", "Authorization": "Basic {}".format(secret)}, + ) + try: + payload = json.loads(response.text) + if payload.get("access_token") and payload.get("expires_in"): + # Example: https://learn.microsoft.com/en-us/azure/azure-arc/servers/media/managed-identity-authentication/bash-token-output-example.png + return { + "access_token": payload["access_token"], + "expires_in": int(payload["expires_in"]), + "token_type": payload.get("token_type", "Bearer"), + "resource": payload.get("resource"), + } + except json.decoder.JSONDecodeError: + pass + return { + "error": "invalid_request", + "error_description": response.text, + } + diff --git a/msal/throttled_http_client.py b/msal/throttled_http_client.py index 1e285ff8..ebad76c7 100644 --- a/msal/throttled_http_client.py +++ b/msal/throttled_http_client.py @@ -9,35 +9,27 @@ DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code" -def _hash(raw): - return sha256(repr(raw).encode("utf-8")).hexdigest() - - -def _parse_http_429_5xx_retry_after(result=None, **ignored): - """Return seconds to throttle""" - assert result is not None, """ - The signature defines it with a default value None, - only because the its shape is already decided by the - IndividualCache's.__call__(). - In actual code path, the result parameter here won't be None. - """ - response = result - lowercase_headers = {k.lower(): v for k, v in getattr( - # Historically, MSAL's HttpResponse does not always have headers - response, "headers", {}).items()} - if not (response.status_code == 429 or response.status_code >= 500 - or "retry-after" in lowercase_headers): - return 0 # Quick exit - default = 60 # Recommended at the end of - # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview - retry_after = lowercase_headers.get("retry-after", default) - try: - # AAD's retry_after uses integer format only - # https://stackoverflow.microsoft.com/questions/264931/264932 - delay_seconds = int(retry_after) - except ValueError: - delay_seconds = default - return min(3600, delay_seconds) +class RetryAfterParser(object): + def __init__(self, default_value=None): + self._default_value = 5 if default_value is None else default_value + + def parse(self, *, result, **ignored): + """Return seconds to throttle""" + response = result + lowercase_headers = {k.lower(): v for k, v in getattr( + # Historically, MSAL's HttpResponse does not always have headers + response, "headers", {}).items()} + if not (response.status_code == 429 or response.status_code >= 500 + or "retry-after" in lowercase_headers): + return 0 # Quick exit + retry_after = lowercase_headers.get("retry-after", self._default_value) + try: + # AAD's retry_after uses integer format only + # https://stackoverflow.microsoft.com/questions/264931/264932 + delay_seconds = int(retry_after) + except ValueError: + delay_seconds = self._default_value + return min(3600, delay_seconds) def _extract_data(kwargs, key, default=None): @@ -45,31 +37,52 @@ def _extract_data(kwargs, key, default=None): return data.get(key) if isinstance(data, dict) else default -class ThrottledHttpClient(object): - def __init__(self, http_client, http_cache): - """Throttle the given http_client by storing and retrieving data from cache. +class ThrottledHttpClientBase(object): + """Throttle the given http_client by storing and retrieving data from cache. - This wrapper exists so that our patching post() and get() would prevent - re-patching side effect when/if same http_client being reused. - """ - expiring_mapping = ExpiringMapping( # It will automatically clean up + This wrapper exists so that our patching post() and get() would prevent + re-patching side effect when/if same http_client being reused. + + The subclass should implement post() and/or get() + """ + def __init__(self, http_client, *, http_cache=None): + self.http_client = http_client + self._expiring_mapping = ExpiringMapping( # It will automatically clean up mapping=http_cache if http_cache is not None else {}, capacity=1024, # To prevent cache blowing up especially for CCA lock=Lock(), # TODO: This should ideally also allow customization ) + def post(self, *args, **kwargs): + return self.http_client.post(*args, **kwargs) + + def get(self, *args, **kwargs): + return self.http_client.get(*args, **kwargs) + + def close(self): + return self.http_client.close() + + @staticmethod + def _hash(raw): + return sha256(repr(raw).encode("utf-8")).hexdigest() + + +class ThrottledHttpClient(ThrottledHttpClientBase): + def __init__(self, http_client, *, default_throttle_time=None, **kwargs): + super(ThrottledHttpClient, self).__init__(http_client, **kwargs) + _post = http_client.post # We'll patch _post, and keep original post() intact _post = IndividualCache( # Internal specs requires throttling on at least token endpoint, # here we have a generic patch for POST on all endpoints. - mapping=expiring_mapping, + mapping=self._expiring_mapping, key_maker=lambda func, args, kwargs: "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format( args[0], # It is the url, typically containing authority and tenant _extract_data(kwargs, "client_id"), # Per internal specs _extract_data(kwargs, "scope"), # Per internal specs - _hash( + self._hash( # The followings are all approximations of the "account" concept # to support per-account throttling. # TODO: We may want to disable it for confidential client, though @@ -77,14 +90,14 @@ def __init__(self, http_client, http_cache): _extract_data(kwargs, "code", # "account" of auth code grant _extract_data(kwargs, "username")))), # "account" of ROPC ), - expires_in=_parse_http_429_5xx_retry_after, + expires_in=RetryAfterParser(default_throttle_time or 5).parse, )(_post) _post = IndividualCache( # It covers the "UI required cache" - mapping=expiring_mapping, + mapping=self._expiring_mapping, key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format( args[0], # It is the url, typically containing authority and tenant - _hash( + self._hash( # Here we use literally all parameters, even those short-lived # parameters containing timestamps (WS-Trust or POP assertion), # because they will automatically be cleaned up by ExpiringMapping. @@ -120,22 +133,16 @@ def __init__(self, http_client, http_cache): self.post = _post self.get = IndividualCache( # Typically those discovery GETs - mapping=expiring_mapping, + mapping=self._expiring_mapping, key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format( args[0], # It is the url, sometimes containing inline params - _hash(kwargs.get("params", "")), + self._hash(kwargs.get("params", "")), ), expires_in=lambda result=None, **ignored: 3600*24 if 200 <= result.status_code < 300 else 0, )(http_client.get) - self._http_client = http_client - # The following 2 methods have been defined dynamically by __init__() #def post(self, *args, **kwargs): pass #def get(self, *args, **kwargs): pass - def close(self): - """MSAL won't need this. But we allow throttled_http_client.close() anyway""" - return self._http_client.close() - diff --git a/sample/.env.sample.managed_identity b/sample/.env.sample.managed_identity new file mode 100644 index 00000000..8b62f4f1 --- /dev/null +++ b/sample/.env.sample.managed_identity @@ -0,0 +1,17 @@ +# This sample can be configured to work with Microsoft Entra ID's Managed Identity. +# +# A user-assigned managed identity can be represented as a JSON blob. +# Check MSAL Python's API Reference for its syntax. +# https://msal-python.readthedocs.io/en/latest/#managed-identity +# +# Example value when using a user-assigned managed identity: +# {"ManagedIdentityIdType": "ClientId", "Id": "your_managed_identity_id"} +# Leave it empty or absent if you are using a system-assigned managed identity. +MANAGED_IDENTITY= + +# Managed Identity works with resource, not scopes. +RESOURCE= + +# Required if the sample app wants to call an API. +#ENDPOINT=https://graph.microsoft.com/v1.0/me + diff --git a/sample/managed_identity_sample.py b/sample/managed_identity_sample.py new file mode 100644 index 00000000..95bf1b28 --- /dev/null +++ b/sample/managed_identity_sample.py @@ -0,0 +1,77 @@ +""" +This sample demonstrates a daemon application that acquires a token using a +managed identity and then calls a web API with the token. + +This sample loads its configuration from a .env file. + +To make this sample work, you need to choose this template: + + .env.sample.managed_identity + +Copy the chosen template to a new file named .env, and fill in the values. + +You can then run this sample: + + python name_of_this_script.py +""" +import json +import logging +import os +import time + +from dotenv import load_dotenv # Need "pip install python-dotenv" +import msal +import requests + + +# Optional logging +# logging.basicConfig(level=logging.DEBUG) # Enable DEBUG log for entire script +# logging.getLogger("msal").setLevel(logging.INFO) # Optionally disable MSAL DEBUG logs + +load_dotenv() # We use this to load configuration from a .env file + +# If for whatever reason you plan to recreate same ClientApplication periodically, +# you shall create one global token cache and reuse it by each ClientApplication +global_token_cache = msal.TokenCache() # The TokenCache() is in-memory. + # See more options in https://msal-python.readthedocs.io/en/latest/#tokencache + +# Create a managed identity instance based on the environment variable value +if os.getenv('MANAGED_IDENTITY'): + managed_identity = json.loads(os.getenv('MANAGED_IDENTITY')) +else: + managed_identity = msal.SystemAssignedManagedIdentity() + +# Create a preferably long-lived app instance, to avoid the overhead of app creation +global_app = msal.ManagedIdentityClient( + managed_identity, + http_client=requests.Session(), + token_cache=global_token_cache, # Let this app (re)use an existing token cache. + # If absent, ClientApplication will create its own empty token cache + ) +resource = os.getenv("RESOURCE") + + +def acquire_and_use_token(): + # ManagedIdentityClient.acquire_token_for_client(...) will automatically look up + # a token from cache, and fall back to acquire a fresh token when needed. + result = global_app.acquire_token_for_client(resource=resource) + + if "access_token" in result: + if os.getenv('ENDPOINT'): + # Calling a web API using the access token + api_result = requests.get( + os.getenv('ENDPOINT'), + headers={'Authorization': 'Bearer ' + result['access_token']}, + ).json() # Assuming the response is JSON + print("Web API call result", json.dumps(api_result, indent=2)) + else: + print("Token acquisition result", json.dumps(result, indent=2)) + else: + print("Token acquisition failed", result) # Examine result["error_description"] etc. to diagnose error + + +while True: # Here we mimic a long-lived daemon + acquire_and_use_token() + print("Press Ctrl-C to stop.") + time.sleep(5) # Let's say your app would run a workload every X minutes + diff --git a/tests/test_mi.py b/tests/test_mi.py new file mode 100644 index 00000000..d6dcc159 --- /dev/null +++ b/tests/test_mi.py @@ -0,0 +1,286 @@ +import json +import os +import sys +import time +import unittest +try: + from unittest.mock import patch, ANY, mock_open, Mock +except: + from mock import patch, ANY, mock_open, Mock +import requests + +from tests.http_client import MinimalResponse +from msal import ( + SystemAssignedManagedIdentity, UserAssignedManagedIdentity, + ManagedIdentityClient, + ManagedIdentityError, + ArcPlatformNotSupportedError, +) +from msal.managed_identity import ( + _supported_arc_platforms_and_their_prefixes, + get_managed_identity_source, + APP_SERVICE, + AZURE_ARC, + CLOUD_SHELL, + MACHINE_LEARNING, + SERVICE_FABRIC, + DEFAULT_TO_VM, +) + + +class ManagedIdentityTestCase(unittest.TestCase): + def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_from_file_or_env_var(self): + self.assertEqual( + UserAssignedManagedIdentity(client_id="foo"), + {"ManagedIdentityIdType": "ClientId", "Id": "foo"}) + self.assertEqual( + UserAssignedManagedIdentity(resource_id="foo"), + {"ManagedIdentityIdType": "ResourceId", "Id": "foo"}) + self.assertEqual( + UserAssignedManagedIdentity(object_id="foo"), + {"ManagedIdentityIdType": "ObjectId", "Id": "foo"}) + with self.assertRaises(ManagedIdentityError): + UserAssignedManagedIdentity() + with self.assertRaises(ManagedIdentityError): + UserAssignedManagedIdentity(client_id="foo", resource_id="bar") + self.assertEqual( + SystemAssignedManagedIdentity(), + {"ManagedIdentityIdType": "SystemAssigned", "Id": None}) + + +class ClientTestCase(unittest.TestCase): + maxDiff = None + + def setUp(self): + self.app = ManagedIdentityClient( + { # Here we test it with the raw dict form, to test that + # the client has no hard dependency on ManagedIdentity object + "ManagedIdentityIdType": "SystemAssigned", "Id": None, + }, + http_client=requests.Session(), + ) + + def _test_token_cache(self, app): + cache = app._token_cache._cache + self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT") + at = list(cache["AccessToken"].values())[0] + self.assertEqual( + app._managed_identity.get("Id", "SYSTEM_ASSIGNED_MANAGED_IDENTITY"), + at["client_id"], + "Should have expected client_id") + self.assertEqual("managed_identity", at["realm"], "Should have expected realm") + + def _test_happy_path(self, app, mocked_http): + result = app.acquire_token_for_client(resource="R") + mocked_http.assert_called() + self.assertEqual({ + "access_token": "AT", + "expires_in": 1234, + "resource": "R", + "token_type": "Bearer", + }, result, "Should obtain a token response") + self.assertEqual( + result["access_token"], + app.acquire_token_for_client(resource="R").get("access_token"), + "Should hit the same token from cache") + self._test_token_cache(app) + + +class VmTestCase(ClientTestCase): + + def test_happy_path(self): + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', + )) as mocked_method: + self._test_happy_path(self.app, mocked_method) + + def test_vm_error_should_be_returned_as_is(self): + raw_error = '{"raw": "error format is undefined"}' + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=400, + text=raw_error, + )) as mocked_method: + self.assertEqual( + json.loads(raw_error), self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) + + +@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"}) +class AppServiceTestCase(ClientTestCase): + + def test_happy_path(self): + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % ( + int(time.time()) + 1234), + )) as mocked_method: + self._test_happy_path(self.app, mocked_method) + + def test_app_service_error_should_be_normalized(self): + raw_error = '{"statusCode": 500, "message": "error content is undefined"}' + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=500, + text=raw_error, + )) as mocked_method: + self.assertEqual({ + "error": "invalid_scope", + "error_description": "500, error content is undefined", + }, self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) + + +@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"}) +class MachineLearningTestCase(ClientTestCase): + + def test_happy_path(self): + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % ( + int(time.time()) + 1234), + )) as mocked_method: + self._test_happy_path(self.app, mocked_method) + + def test_machine_learning_error_should_be_normalized(self): + raw_error = '{"error": "placeholder", "message": "placeholder"}' + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=500, + text=raw_error, + )) as mocked_method: + self.assertEqual({ + "error": "invalid_scope", + "error_description": "{'error': 'placeholder', 'message': 'placeholder'}", + }, self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) + + +@patch.dict(os.environ, { + "IDENTITY_ENDPOINT": "http://localhost", + "IDENTITY_HEADER": "foo", + "IDENTITY_SERVER_THUMBPRINT": "bar", +}) +class ServiceFabricTestCase(ClientTestCase): + + def _test_happy_path(self, app): + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % ( + int(time.time()) + 1234), + )) as mocked_method: + super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method) + + def test_happy_path(self): + self._test_happy_path(self.app) + + def test_unified_api_service_should_ignore_unnecessary_client_id(self): + self._test_happy_path(ManagedIdentityClient( + {"ManagedIdentityIdType": "ClientId", "Id": "foo"}, + http_client=requests.Session(), + )) + + def test_sf_error_should_be_normalized(self): + raw_error = ''' +{"error": { + "correlationId": "foo", + "code": "SecretHeaderNotFound", + "message": "Secret is not found in the request headers." +}}''' # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=404, + text=raw_error, + )) as mocked_method: + self.assertEqual({ + "error": "unauthorized_client", + "error_description": raw_error, + }, self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) + + +@patch.dict(os.environ, { + "IDENTITY_ENDPOINT": "http://localhost/token", + "IMDS_ENDPOINT": "http://localhost", +}) +@patch( + "builtins.open" if sys.version_info.major >= 3 else "__builtin__.open", + new=mock_open(read_data="secret"), # `new` requires no extra argument on the decorated function. + # https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch +) +@patch("os.stat", return_value=Mock(st_size=4096)) +class ArcTestCase(ClientTestCase): + challenge = MinimalResponse(status_code=401, text="", headers={ + "WWW-Authenticate": "Basic realm=/tmp/foo", + }) + + def test_happy_path(self, mocked_stat): + with patch.object(self.app._http_client, "get", side_effect=[ + self.challenge, + MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', + ), + ]) as mocked_method: + try: + super(ArcTestCase, self)._test_happy_path(self.app, mocked_method) + mocked_stat.assert_called_with(os.path.join( + _supported_arc_platforms_and_their_prefixes[sys.platform], + "foo.key")) + except ArcPlatformNotSupportedError: + if sys.platform in _supported_arc_platforms_and_their_prefixes: + self.fail("Should not raise ArcPlatformNotSupportedError") + + def test_arc_error_should_be_normalized(self, mocked_stat): + with patch.object(self.app._http_client, "get", side_effect=[ + self.challenge, + MinimalResponse(status_code=400, text="undefined"), + ]) as mocked_method: + try: + self.assertEqual({ + "error": "invalid_request", + "error_description": "undefined", + }, self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) + except ArcPlatformNotSupportedError: + if sys.platform in _supported_arc_platforms_and_their_prefixes: + self.fail("Should not raise ArcPlatformNotSupportedError") + + +class GetManagedIdentitySourceTestCase(unittest.TestCase): + + @patch.dict(os.environ, { + "IDENTITY_ENDPOINT": "http://localhost", + "IDENTITY_HEADER": "foo", + "IDENTITY_SERVER_THUMBPRINT": "bar", + }) + def test_service_fabric(self): + self.assertEqual(get_managed_identity_source(), SERVICE_FABRIC) + + @patch.dict(os.environ, { + "IDENTITY_ENDPOINT": "http://localhost", + "IDENTITY_HEADER": "foo", + }) + def test_app_service(self): + self.assertEqual(get_managed_identity_source(), APP_SERVICE) + + @patch.dict(os.environ, { + "MSI_ENDPOINT": "http://localhost", + "MSI_SECRET": "foo", + }) + def test_machine_learning(self): + self.assertEqual(get_managed_identity_source(), MACHINE_LEARNING) + + @patch.dict(os.environ, { + "IDENTITY_ENDPOINT": "http://localhost", + "IMDS_ENDPOINT": "http://localhost", + }) + def test_arc(self): + self.assertEqual(get_managed_identity_source(), AZURE_ARC) + + @patch.dict(os.environ, { + "AZUREPS_HOST_ENVIRONMENT": "cloud-shell-foo", + }) + def test_cloud_shell(self): + self.assertEqual(get_managed_identity_source(), CLOUD_SHELL) + + def test_default_to_vm(self): + self.assertEqual(get_managed_identity_source(), DEFAULT_TO_VM) + diff --git a/tests/test_throttled_http_client.py b/tests/test_throttled_http_client.py index aa20060d..3994719d 100644 --- a/tests/test_throttled_http_client.py +++ b/tests/test_throttled_http_client.py @@ -40,11 +40,10 @@ class CloseMethodCalled(Exception): class TestHttpDecoration(unittest.TestCase): def test_throttled_http_client_should_not_alter_original_http_client(self): - http_cache = {} original_http_client = DummyHttpClient() original_get = original_http_client.get original_post = original_http_client.post - throttled_http_client = ThrottledHttpClient(original_http_client, http_cache) + throttled_http_client = ThrottledHttpClient(original_http_client) goal = """The implementation should wrap original http_client and keep it intact, instead of monkey-patching it""" self.assertNotEqual(throttled_http_client, original_http_client, goal) @@ -54,7 +53,7 @@ def test_throttled_http_client_should_not_alter_original_http_client(self): def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( self, http_client, retry_after): http_cache = {} - http_client = ThrottledHttpClient(http_client, http_cache) + http_client = ThrottledHttpClient(http_client, http_cache=http_cache) resp1 = http_client.post("https://example.com") # We implemented POST only resp2 = http_client.post("https://example.com") # We implemented POST only logger.debug(http_cache) @@ -90,7 +89,7 @@ def test_one_RetryAfter_request_should_block_a_similar_request(self): http_cache = {} http_client = DummyHttpClient( status_code=429, response_headers={"Retry-After": 2}) - http_client = ThrottledHttpClient(http_client, http_cache) + http_client = ThrottledHttpClient(http_client, http_cache=http_cache) resp1 = http_client.post("https://example.com", data={ "scope": "one", "claims": "bar", "grant_type": "authorization_code"}) resp2 = http_client.post("https://example.com", data={ @@ -102,7 +101,7 @@ def test_one_RetryAfter_request_should_not_block_a_different_request(self): http_cache = {} http_client = DummyHttpClient( status_code=429, response_headers={"Retry-After": 2}) - http_client = ThrottledHttpClient(http_client, http_cache) + http_client = ThrottledHttpClient(http_client, http_cache=http_cache) resp1 = http_client.post("https://example.com", data={"scope": "one"}) resp2 = http_client.post("https://example.com", data={"scope": "two"}) logger.debug(http_cache) @@ -112,7 +111,7 @@ def test_one_invalid_grant_should_block_a_similar_request(self): http_cache = {} http_client = DummyHttpClient( status_code=400) # It covers invalid_grant and interaction_required - http_client = ThrottledHttpClient(http_client, http_cache) + http_client = ThrottledHttpClient(http_client, http_cache=http_cache) resp1 = http_client.post("https://example.com", data={"claims": "foo"}) logger.debug(http_cache) resp1_again = http_client.post("https://example.com", data={"claims": "foo"}) @@ -146,7 +145,7 @@ def test_http_get_200_should_be_cached(self): http_cache = {} http_client = DummyHttpClient( status_code=200) # It covers UserRealm discovery and OIDC discovery - http_client = ThrottledHttpClient(http_client, http_cache) + http_client = ThrottledHttpClient(http_client, http_cache=http_cache) resp1 = http_client.get("https://example.com?foo=bar") resp2 = http_client.get("https://example.com?foo=bar") logger.debug(http_cache) @@ -156,7 +155,7 @@ def test_device_flow_retry_should_not_be_cached(self): DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code" http_cache = {} http_client = DummyHttpClient(status_code=400) - http_client = ThrottledHttpClient(http_client, http_cache) + http_client = ThrottledHttpClient(http_client, http_cache=http_cache) resp1 = http_client.post( "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT}) resp2 = http_client.post( @@ -165,9 +164,8 @@ def test_device_flow_retry_should_not_be_cached(self): self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") def test_throttled_http_client_should_provide_close(self): - http_cache = {} http_client = DummyHttpClient(status_code=200) - http_client = ThrottledHttpClient(http_client, http_cache) + http_client = ThrottledHttpClient(http_client) with self.assertRaises(CloseMethodCalled): http_client.close()