diff --git a/msal/application.py b/msal/application.py index 391cd1ae..038badb6 100644 --- a/msal/application.py +++ b/msal/application.py @@ -280,6 +280,49 @@ def _get_authority_aliases(self, instance): return [alias for alias in group if alias != instance] return [] + def remove_account(self, account): + """Sign me out and forget me from token cache""" + self._forget_me(account) + + def _sign_out(self, home_account): + # Remove all relevant RTs and ATs from token cache + owned_by_home_account = { + "environment": home_account["environment"], + "home_account_id": home_account["home_account_id"],} # realm-independent + app_metadata = self._get_app_metadata(home_account["environment"]) + # Remove RTs/FRTs, and they are realm-independent + for rt in [rt for rt in self.token_cache.find( + TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account) + # Do RT's app ownership check as a precaution, in case family apps + # and 3rd-party apps share same token cache, although they should not. + if rt["client_id"] == self.client_id or ( + app_metadata.get("family_id") # Now let's settle family business + and rt.get("family_id") == app_metadata["family_id"]) + ]: + self.token_cache.remove_rt(rt) + for at in self.token_cache.find( # Remove ATs + # Regardless of realm, b/c we've removed realm-independent RTs anyway + TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account): + # To avoid the complexity of locating sibling family app's AT, + # we skip AT's app ownership check. + # It means ATs for other apps will also be removed, it is OK because: + # * non-family apps are not supposed to share token cache to begin with; + # * Even if it happens, we keep other app's RT already, so SSO still works + self.token_cache.remove_at(at) + + def _forget_me(self, home_account): + # It implies signout, and then also remove all relevant accounts and IDTs + self._sign_out(home_account) + owned_by_home_account = { + "environment": home_account["environment"], + "home_account_id": home_account["home_account_id"],} # realm-independent + for idt in self.token_cache.find( # Remove IDTs, regardless of realm + TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account): + self.token_cache.remove_idt(idt) + for a in self.token_cache.find( # Remove Accounts, regardless of realm + TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account): + self.token_cache.remove_account(a) + def acquire_token_silent( self, scopes, # type: List[str] @@ -364,10 +407,7 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( "home_account_id": (account or {}).get("home_account_id"), # "realm": authority.tenant, # AAD RTs are tenant-independent } - apps = self.token_cache.find( # Use find(), rather than token_cache.get(...) - TokenCache.CredentialType.APP_METADATA, query={ - "environment": authority.instance, "client_id": self.client_id}) - app_metadata = apps[0] if apps else {} + app_metadata = self._get_app_metadata(authority.instance) if not app_metadata: # Meaning this app is now used for the first time. # When/if we have a way to directly detect current app's family, # we'll rewrite this block, to support multiple families. @@ -396,6 +436,12 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( return self._acquire_token_silent_by_finding_specific_refresh_token( authority, scopes, dict(query, client_id=self.client_id), **kwargs) + def _get_app_metadata(self, environment): + apps = self.token_cache.find( # Use find(), rather than token_cache.get(...) + TokenCache.CredentialType.APP_METADATA, query={ + "environment": environment, "client_id": self.client_id}) + return apps[0] if apps else {} + def _acquire_token_silent_by_finding_specific_refresh_token( self, authority, scopes, query, rt_remover=None, break_condition=lambda response: False, **kwargs): diff --git a/msal/token_cache.py b/msal/token_cache.py index 8fd79e59..e802eddd 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -39,6 +39,12 @@ class AuthorityType: def __init__(self): self._lock = threading.RLock() self._cache = {} + self.key_makers = { + self.CredentialType.REFRESH_TOKEN: self._build_rt_key, + self.CredentialType.ACCESS_TOKEN: self._build_at_key, + self.CredentialType.ID_TOKEN: self._build_idt_key, + self.CredentialType.ACCOUNT: self._build_account_key, + } def find(self, credential_type, target=None, query=None): target = target or [] @@ -83,14 +89,9 @@ def add(self, event, now=None): with self._lock: if access_token: - key = "-".join([ - home_account_id or "", - environment or "", - self.CredentialType.ACCESS_TOKEN, - event.get("client_id", ""), - realm or "", - target, - ]).lower() + key = self._build_at_key( + home_account_id, environment, event.get("client_id", ""), + realm, target) now = time.time() if now is None else now expires_in = response.get("expires_in", 3599) self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = { @@ -110,11 +111,7 @@ def add(self, event, now=None): if client_info: decoded_id_token = json.loads( base64decode(id_token.split('.')[1])) if id_token else {} - key = "-".join([ - home_account_id or "", - environment or "", - realm or "", - ]).lower() + key = self._build_account_key(home_account_id, environment, realm) self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = { "home_account_id": home_account_id, "environment": environment, @@ -129,14 +126,8 @@ def add(self, event, now=None): } if id_token: - key = "-".join([ - home_account_id or "", - environment or "", - self.CredentialType.ID_TOKEN, - event.get("client_id", ""), - realm or "", - "" # Albeit irrelevant, schema requires an empty scope here - ]).lower() + key = self._build_idt_key( + home_account_id, environment, event.get("client_id", ""), realm) self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = { "credential_type": self.CredentialType.ID_TOKEN, "secret": id_token, @@ -170,6 +161,24 @@ def add(self, event, now=None): "family_id": response.get("foci"), # None is also valid } + def modify(self, credential_type, old_entry, new_key_value_pairs=None): + # Modify the specified old_entry with new_key_value_pairs, + # or remove the old_entry if the new_key_value_pairs is None. + + # This helper exists to consolidate all token modify/remove behaviors, + # so that the sub-classes will have only one method to work on, + # instead of patching a pair of update_xx() and remove_xx() per type. + # You can monkeypatch self.key_makers to support more types on-the-fly. + key = self.key_makers[credential_type](**old_entry) + with self._lock: + if new_key_value_pairs: # Update with them + entries = self._cache.setdefault(credential_type, {}) + entry = entries.get(key, {}) # key usually exists, but we'll survive its absence + entry.update(new_key_value_pairs) + else: # Remove old_entry + self._cache.setdefault(credential_type, {}).pop(key, None) + + @staticmethod def _build_appmetadata_key(environment, client_id): return "appmetadata-{}-{}".format(environment or "", client_id or "") @@ -178,7 +187,7 @@ def _build_appmetadata_key(environment, client_id): def _build_rt_key( cls, home_account_id=None, environment=None, client_id=None, target=None, - **ignored): + **ignored_payload_from_a_real_token): return "-".join([ home_account_id or "", environment or "", @@ -189,16 +198,61 @@ def _build_rt_key( ]).lower() def remove_rt(self, rt_item): - key = self._build_rt_key(**rt_item) - with self._lock: - self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None) + assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN + return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item) def update_rt(self, rt_item, new_rt): - key = self._build_rt_key(**rt_item) - with self._lock: - RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}) - rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence - rt["secret"] = new_rt + assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN + return self.modify( + self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt}) + + @classmethod + def _build_at_key(cls, + home_account_id=None, environment=None, client_id=None, + realm=None, target=None, **ignored_payload_from_a_real_token): + return "-".join([ + home_account_id or "", + environment or "", + cls.CredentialType.ACCESS_TOKEN, + client_id, + realm or "", + target or "", + ]).lower() + + def remove_at(self, at_item): + assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN + return self.modify(self.CredentialType.ACCESS_TOKEN, at_item) + + @classmethod + def _build_idt_key(cls, + home_account_id=None, environment=None, client_id=None, realm=None, + **ignored_payload_from_a_real_token): + return "-".join([ + home_account_id or "", + environment or "", + cls.CredentialType.ID_TOKEN, + client_id or "", + realm or "", + "" # Albeit irrelevant, schema requires an empty scope here + ]).lower() + + def remove_idt(self, idt_item): + assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN + return self.modify(self.CredentialType.ID_TOKEN, idt_item) + + @classmethod + def _build_account_key(cls, + home_account_id=None, environment=None, realm=None, + **ignored_payload_from_a_real_entry): + return "-".join([ + home_account_id or "", + environment or "", + realm or "", + ]).lower() + + def remove_account(self, account_item): + assert "authority_type" in account_item + return self.modify(self.CredentialType.ACCOUNT, account_item) class SerializableTokenCache(TokenCache): @@ -221,7 +275,7 @@ class SerializableTokenCache(TokenCache): ... :var bool has_state_changed: - Indicates whether the cache state has changed since last + Indicates whether the cache state in the memory has changed since last :func:`~serialize` or :func:`~deserialize` call. """ has_state_changed = False @@ -230,12 +284,9 @@ def add(self, event, **kwargs): super(SerializableTokenCache, self).add(event, **kwargs) self.has_state_changed = True - def remove_rt(self, rt_item): - super(SerializableTokenCache, self).remove_rt(rt_item) - self.has_state_changed = True - - def update_rt(self, rt_item, new_rt): - super(SerializableTokenCache, self).update_rt(rt_item, new_rt) + def modify(self, credential_type, old_entry, new_key_value_pairs=None): + super(SerializableTokenCache, self).modify( + credential_type, old_entry, new_key_value_pairs) self.has_state_changed = True def deserialize(self, state): diff --git a/tests/test_application.py b/tests/test_application.py index 6346774a..75d5d27b 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -174,11 +174,14 @@ def setUp(self): self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} self.frt = "what the frt" self.cache = msal.SerializableTokenCache() + self.preexisting_family_app_id = "preexisting_family_app" self.cache.add({ # Pre-populate a FRT - "client_id": "preexisting_family_app", + "client_id": self.preexisting_family_app_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), "response": TokenCacheTestCase.build_response( + access_token="Siblings won't share AT. test_remove_account() will.", + id_token=TokenCacheTestCase.build_id_token(), uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"), }) # The add(...) helper populates correct home_account_id for future searching @@ -239,6 +242,35 @@ def tester(url, data=None, **kwargs): # Will not test scenario of app leaving family. Per specs, it won't happen. + def test_family_app_remove_account(self): + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + app = ClientApplication( + self.preexisting_family_app_id, + authority=self.authority_url, token_cache=self.cache) + account = app.get_accounts()[0] + mine = {"home_account_id": account["home_account_id"]} + + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ACCESS_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.REFRESH_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ID_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ACCOUNT, query=mine)) + + app.remove_account(account) + + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ACCESS_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.REFRESH_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ID_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ACCOUNT, query=mine)) + + class TestClientApplicationForAuthorityMigration(unittest.TestCase): @classmethod