diff --git a/msal/application.py b/msal/application.py index b3c07a47..75ca6c83 100644 --- a/msal/application.py +++ b/msal/application.py @@ -411,9 +411,11 @@ def __init__( (STS) what this client is capable for, so STS can decide to turn on certain features. For example, if client is capable to handle *claims challenge*, - STS can then issue CAE access tokens to resources - knowing when the resource emits *claims challenge* - the client will be capable to handle. + STS may issue + `Continuous Access Evaluation (CAE) `_ + access tokens to resources, + knowing that when the resource emits a *claims challenge* + the client will be able to handle those challenges. Implementation details: Client capability is implemented using "claims" parameter on the wire, diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 5636f564..181d34c3 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -10,7 +10,7 @@ import time from urllib.parse import urlparse # Python 3+ from collections import UserDict # Python 3+ -from typing import Union # Needed in Python 3.7 & 3.8 +from typing import Optional, Union # Needed in Python 3.7 & 3.8 from .token_cache import TokenCache from .individual_cache import _IndividualCache as IndividualCache from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser @@ -145,6 +145,9 @@ class ManagedIdentityClient(object): not a token with application permissions for an app. """ __instance, _tenant = None, "managed_identity" # Placeholders + _TOKEN_SOURCE = "token_source" + _TOKEN_SOURCE_IDP = "identity_provider" + _TOKEN_SOURCE_CACHE = "cache" def __init__( self, @@ -237,12 +240,31 @@ def _get_instance(self): self.__instance = socket.getfqdn() # Moved from class definition to here return self.__instance - def acquire_token_for_client(self, *, resource): # We may support scope in the future + def acquire_token_for_client( + self, + *, + resource: str, # If/when we support scope, resource will become optional + claims_challenge: Optional[str] = None, + ): """Acquire token for the managed identity. The result will be automatically cached. Subsequent calls will automatically search from cache first. + :param resource: The resource for which the token is acquired. + + :param claims_challenge: + Optional. + It is a string representation of a JSON object + (which contains lists of claims being requested). + + The tenant admin may choose to revoke all Managed Identity tokens, + and then a *claims challenge* will be returned by the target resource, + as a `claims_challenge` directive in the `www-authenticate` header, + even if the app developer did not opt in for the "CP1" client capability. + Upon receiving a `claims_challenge`, MSAL will skip a token cache read, + and will attempt to acquire a new token. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -255,8 +277,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the access_token_from_cache = None client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") - if True: # Does not offer an "if not force_refresh" option, because - # there would be built-in token cache in the service side anyway + now = time.time() + if not claims_challenge: # Then attempt token cache search matches = self._token_cache.find( self._token_cache.CredentialType.ACCESS_TOKEN, target=[resource], @@ -267,7 +289,6 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the home_account_id=None, ), ) - now = time.time() for entry in matches: expires_in = int(entry["expires_on"]) - now if expires_in < 5*60: # Then consider it expired @@ -277,6 +298,7 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the "access_token": entry["secret"], "token_type": entry.get("token_type", "Bearer"), "expires_in": int(expires_in), # OAuth2 specs defines it as int + self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE, } if "refresh_on" in entry: access_token_from_cache["refresh_on"] = int(entry["refresh_on"]) @@ -300,6 +322,7 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the )) if "refresh_in" in result: result["refresh_on"] = int(now + result["refresh_in"]) + result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP if (result and "error" not in result) or (not access_token_from_cache): return result except: # The exact HTTP exception is transportation-layer dependent diff --git a/tests/test_mi.py b/tests/test_mi.py index f3182c7b..2041419d 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -82,20 +82,17 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): self.assertTrue( is_subdict_of(expected_result, result), # We will test refresh_on later "Should obtain a token response") + self.assertTrue(result["token_source"], "identity_provider") self.assertEqual(expires_in, result["expires_in"], "Should have expected expires_in") if expires_in >= 7200: expected_refresh_on = int(time.time() + expires_in / 2) self.assertTrue( expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1, "Should have a refresh_on time around the middle of the token's life") - self.assertEqual( - result["access_token"], - app.acquire_token_for_client(resource=resource).get("access_token"), - "Should hit the same token from cache") - - self.assertCacheStatus(app) result = app.acquire_token_for_client(resource=resource) + self.assertCacheStatus(app) + self.assertEqual("cache", result["token_source"], "Should hit cache") self.assertEqual( call_count, mocked_http.call_count, "No new call to the mocked http should be made for a cache hit") @@ -110,6 +107,9 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on, "Should have a refresh_on time around the middle of the token's life") + result = app.acquire_token_for_client(resource=resource, claims_challenge="foo") + self.assertEqual("identity_provider", result["token_source"], "Should miss cache") + class VmTestCase(ClientTestCase): @@ -249,7 +249,8 @@ def test_happy_path(self, mocked_stat): status_code=200, text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in, ), - ]) as mocked_method: + ] * 2, # Duplicate a pair of mocks for _test_happy_path()'s CAE check + ) as mocked_method: try: self._test_happy_path(self.app, mocked_method, expires_in) mocked_stat.assert_called_with(os.path.join(