From d51abb4f6c4a9761c3efb885f4bb9408323d4a3d Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Mon, 5 Aug 2024 12:34:56 -0700 Subject: [PATCH] CAE for MIv1 CAE team and MSI team are working on turning on CAE by default for MSI v1. So what that means is, App developers will start seeing CAE even without setting the capability - "CP1". Update msal/application.py Co-authored-by: Den Delimarsky <53200638+localden@users.noreply.github.com> Update msal/application.py Co-authored-by: Den Delimarsky <53200638+localden@users.noreply.github.com> Update msal/application.py Co-authored-by: Den Delimarsky <53200638+localden@users.noreply.github.com> Update msal/managed_identity.py Co-authored-by: Den Delimarsky <53200638+localden@users.noreply.github.com> Update msal/managed_identity.py Co-authored-by: Den Delimarsky <53200638+localden@users.noreply.github.com> Update msal/managed_identity.py Co-authored-by: Den Delimarsky <53200638+localden@users.noreply.github.com> Update msal/managed_identity.py Co-authored-by: Den Delimarsky <53200638+localden@users.noreply.github.com> --- msal/application.py | 8 +++++--- msal/managed_identity.py | 33 ++++++++++++++++++++++++++++----- tests/test_mi.py | 15 ++++++++------- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/msal/application.py b/msal/application.py index b3c07a47..75ca6c83 100644 --- a/msal/application.py +++ b/msal/application.py @@ -411,9 +411,11 @@ def __init__( (STS) what this client is capable for, so STS can decide to turn on certain features. For example, if client is capable to handle *claims challenge*, - STS can then issue CAE access tokens to resources - knowing when the resource emits *claims challenge* - the client will be capable to handle. + STS may issue + `Continuous Access Evaluation (CAE) `_ + access tokens to resources, + knowing that when the resource emits a *claims challenge* + the client will be able to handle those challenges. Implementation details: Client capability is implemented using "claims" parameter on the wire, diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 5636f564..181d34c3 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -10,7 +10,7 @@ import time from urllib.parse import urlparse # Python 3+ from collections import UserDict # Python 3+ -from typing import Union # Needed in Python 3.7 & 3.8 +from typing import Optional, Union # Needed in Python 3.7 & 3.8 from .token_cache import TokenCache from .individual_cache import _IndividualCache as IndividualCache from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser @@ -145,6 +145,9 @@ class ManagedIdentityClient(object): not a token with application permissions for an app. """ __instance, _tenant = None, "managed_identity" # Placeholders + _TOKEN_SOURCE = "token_source" + _TOKEN_SOURCE_IDP = "identity_provider" + _TOKEN_SOURCE_CACHE = "cache" def __init__( self, @@ -237,12 +240,31 @@ def _get_instance(self): self.__instance = socket.getfqdn() # Moved from class definition to here return self.__instance - def acquire_token_for_client(self, *, resource): # We may support scope in the future + def acquire_token_for_client( + self, + *, + resource: str, # If/when we support scope, resource will become optional + claims_challenge: Optional[str] = None, + ): """Acquire token for the managed identity. The result will be automatically cached. Subsequent calls will automatically search from cache first. + :param resource: The resource for which the token is acquired. + + :param claims_challenge: + Optional. + It is a string representation of a JSON object + (which contains lists of claims being requested). + + The tenant admin may choose to revoke all Managed Identity tokens, + and then a *claims challenge* will be returned by the target resource, + as a `claims_challenge` directive in the `www-authenticate` header, + even if the app developer did not opt in for the "CP1" client capability. + Upon receiving a `claims_challenge`, MSAL will skip a token cache read, + and will attempt to acquire a new token. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -255,8 +277,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the access_token_from_cache = None client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") - if True: # Does not offer an "if not force_refresh" option, because - # there would be built-in token cache in the service side anyway + now = time.time() + if not claims_challenge: # Then attempt token cache search matches = self._token_cache.find( self._token_cache.CredentialType.ACCESS_TOKEN, target=[resource], @@ -267,7 +289,6 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the home_account_id=None, ), ) - now = time.time() for entry in matches: expires_in = int(entry["expires_on"]) - now if expires_in < 5*60: # Then consider it expired @@ -277,6 +298,7 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the "access_token": entry["secret"], "token_type": entry.get("token_type", "Bearer"), "expires_in": int(expires_in), # OAuth2 specs defines it as int + self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE, } if "refresh_on" in entry: access_token_from_cache["refresh_on"] = int(entry["refresh_on"]) @@ -300,6 +322,7 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the )) if "refresh_in" in result: result["refresh_on"] = int(now + result["refresh_in"]) + result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP if (result and "error" not in result) or (not access_token_from_cache): return result except: # The exact HTTP exception is transportation-layer dependent diff --git a/tests/test_mi.py b/tests/test_mi.py index f3182c7b..2041419d 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -82,20 +82,17 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): self.assertTrue( is_subdict_of(expected_result, result), # We will test refresh_on later "Should obtain a token response") + self.assertTrue(result["token_source"], "identity_provider") self.assertEqual(expires_in, result["expires_in"], "Should have expected expires_in") if expires_in >= 7200: expected_refresh_on = int(time.time() + expires_in / 2) self.assertTrue( expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1, "Should have a refresh_on time around the middle of the token's life") - self.assertEqual( - result["access_token"], - app.acquire_token_for_client(resource=resource).get("access_token"), - "Should hit the same token from cache") - - self.assertCacheStatus(app) result = app.acquire_token_for_client(resource=resource) + self.assertCacheStatus(app) + self.assertEqual("cache", result["token_source"], "Should hit cache") self.assertEqual( call_count, mocked_http.call_count, "No new call to the mocked http should be made for a cache hit") @@ -110,6 +107,9 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on, "Should have a refresh_on time around the middle of the token's life") + result = app.acquire_token_for_client(resource=resource, claims_challenge="foo") + self.assertEqual("identity_provider", result["token_source"], "Should miss cache") + class VmTestCase(ClientTestCase): @@ -249,7 +249,8 @@ def test_happy_path(self, mocked_stat): status_code=200, text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in, ), - ]) as mocked_method: + ] * 2, # Duplicate a pair of mocks for _test_happy_path()'s CAE check + ) as mocked_method: try: self._test_happy_path(self.app, mocked_method, expires_in) mocked_stat.assert_called_with(os.path.join(