From a6dbbfffac83992f97b5a394185900d5a40b2002 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Wed, 1 May 2019 16:19:23 -0700 Subject: [PATCH 1/4] WIP: Accurately remove ATs owned by sibling apps --- msal/application.py | 42 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/msal/application.py b/msal/application.py index 391cd1ae..17749801 100644 --- a/msal/application.py +++ b/msal/application.py @@ -280,6 +280,37 @@ def _get_authority_aliases(self, instance): return [alias for alias in group if alias != instance] return [] + def sign_out(self, account): + """Remove all relevant RTs and ATs from token cache""" + owned_by_account = { + "environment": account["environment"], + "home_account_id": (account or {}).get("home_account_id"),} + + owned_by_account_and_app = dict(owned_by_account, client=self.client_id) + for rt in self.token_cache.find( # Remove RTs + TokenCache.CredentialType.REFRESH_TOKEN, + query=owned_by_account_and_app): + self.token_cache.remove_rt(rt) + for at in self.token_cache.find( # Remove ATs + TokenCache.CredentialType.ACCESS_TOKEN, + query=owned_by_account_and_app): # regardless of realm + self.token_cache.remove_at(at) # TODO + + app_metadata = self._get_app_metadata(account["environment"]) + if app_metadata.get("family_id"): # Now let's settle family business + for rt in self.token_cache.find( # Remove FRTs + TokenCache.CredentialType.REFRESH_TOKEN, query=dict( + owned_by_account, + family_id=app_metadata["family_id"])): + self.token_cache.remove_rt(rt) + for sibling_app in self.token_cache.find( # Remove siblings' ATs + TokenCache.CredentialType.APP_METADATA, + query={"family_id": app_metadata.get["family_id"]}): + for at in self.token_cache.find( # Remove ATs, regardless of realm + TokenCache.CredentialType.ACCESS_TOKEN, query=dict( + owned_by_account, client_id=sibling_app["client_id"])): + self.token_cache.remove_at(at) # TODO + def acquire_token_silent( self, scopes, # type: List[str] @@ -364,10 +395,7 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( "home_account_id": (account or {}).get("home_account_id"), # "realm": authority.tenant, # AAD RTs are tenant-independent } - apps = self.token_cache.find( # Use find(), rather than token_cache.get(...) - TokenCache.CredentialType.APP_METADATA, query={ - "environment": authority.instance, "client_id": self.client_id}) - app_metadata = apps[0] if apps else {} + app_metadata = self._get_app_metadata(authority.instance) if not app_metadata: # Meaning this app is now used for the first time. # When/if we have a way to directly detect current app's family, # we'll rewrite this block, to support multiple families. @@ -396,6 +424,12 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( return self._acquire_token_silent_by_finding_specific_refresh_token( authority, scopes, dict(query, client_id=self.client_id), **kwargs) + def _get_app_metadata(self, environment): + apps = self.token_cache.find( # Use find(), rather than token_cache.get(...) + TokenCache.CredentialType.APP_METADATA, query={ + "environment": environment, "client_id": self.client_id}) + return apps[0] if apps else {} + def _acquire_token_silent_by_finding_specific_refresh_token( self, authority, scopes, query, rt_remover=None, break_condition=lambda response: False, **kwargs): From 167e954c50c11e84b00d633e42ddd94443081fcc Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Wed, 8 May 2019 11:03:26 -0700 Subject: [PATCH 2/4] Remove all AT, RT, FRT belongs to current account --- msal/application.py | 41 +++++++------------ msal/token_cache.py | 85 +++++++++++++++++++++++++++++---------- tests/test_application.py | 30 ++++++++++++++ 3 files changed, 107 insertions(+), 49 deletions(-) diff --git a/msal/application.py b/msal/application.py index 17749801..4290cba6 100644 --- a/msal/application.py +++ b/msal/application.py @@ -280,36 +280,23 @@ def _get_authority_aliases(self, instance): return [alias for alias in group if alias != instance] return [] - def sign_out(self, account): + def remove_account(self, home_account): """Remove all relevant RTs and ATs from token cache""" owned_by_account = { - "environment": account["environment"], - "home_account_id": (account or {}).get("home_account_id"),} - - owned_by_account_and_app = dict(owned_by_account, client=self.client_id) - for rt in self.token_cache.find( # Remove RTs - TokenCache.CredentialType.REFRESH_TOKEN, - query=owned_by_account_and_app): + "environment": home_account["environment"], + "home_account_id": home_account["home_account_id"],} # realm-independent + for rt in self.token_cache.find( # Remove RTs, and RTs are realm-independent + TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_account): self.token_cache.remove_rt(rt) - for at in self.token_cache.find( # Remove ATs - TokenCache.CredentialType.ACCESS_TOKEN, - query=owned_by_account_and_app): # regardless of realm - self.token_cache.remove_at(at) # TODO - - app_metadata = self._get_app_metadata(account["environment"]) - if app_metadata.get("family_id"): # Now let's settle family business - for rt in self.token_cache.find( # Remove FRTs - TokenCache.CredentialType.REFRESH_TOKEN, query=dict( - owned_by_account, - family_id=app_metadata["family_id"])): - self.token_cache.remove_rt(rt) - for sibling_app in self.token_cache.find( # Remove siblings' ATs - TokenCache.CredentialType.APP_METADATA, - query={"family_id": app_metadata.get["family_id"]}): - for at in self.token_cache.find( # Remove ATs, regardless of realm - TokenCache.CredentialType.ACCESS_TOKEN, query=dict( - owned_by_account, client_id=sibling_app["client_id"])): - self.token_cache.remove_at(at) # TODO + for at in self.token_cache.find( # Remove ATs, regardless of realm + TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_account): + self.token_cache.remove_at(at) + for idt in self.token_cache.find( # Remove IDTs, regardless of realm + TokenCache.CredentialType.ID_TOKEN, query=owned_by_account): + self.token_cache.remove_idt(idt) + for a in self.token_cache.find( # Remove Accounts, regardless of realm + TokenCache.CredentialType.ACCOUNT, query=owned_by_account): + self.token_cache.remove_account(a) def acquire_token_silent( self, diff --git a/msal/token_cache.py b/msal/token_cache.py index 8fd79e59..cad8e722 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -83,14 +83,9 @@ def add(self, event, now=None): with self._lock: if access_token: - key = "-".join([ - home_account_id or "", - environment or "", - self.CredentialType.ACCESS_TOKEN, - event.get("client_id", ""), - realm or "", - target, - ]).lower() + key = self._build_at_key( + home_account_id, environment, event.get("client_id", ""), + realm, target) now = time.time() if now is None else now expires_in = response.get("expires_in", 3599) self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {})[key] = { @@ -110,11 +105,7 @@ def add(self, event, now=None): if client_info: decoded_id_token = json.loads( base64decode(id_token.split('.')[1])) if id_token else {} - key = "-".join([ - home_account_id or "", - environment or "", - realm or "", - ]).lower() + key = self._build_account_key(home_account_id, environment, realm) self._cache.setdefault(self.CredentialType.ACCOUNT, {})[key] = { "home_account_id": home_account_id, "environment": environment, @@ -129,14 +120,8 @@ def add(self, event, now=None): } if id_token: - key = "-".join([ - home_account_id or "", - environment or "", - self.CredentialType.ID_TOKEN, - event.get("client_id", ""), - realm or "", - "" # Albeit irrelevant, schema requires an empty scope here - ]).lower() + key = self._build_idt_key( + home_account_id, environment, event.get("client_id", ""), realm) self._cache.setdefault(self.CredentialType.ID_TOKEN, {})[key] = { "credential_type": self.CredentialType.ID_TOKEN, "secret": id_token, @@ -178,7 +163,7 @@ def _build_appmetadata_key(environment, client_id): def _build_rt_key( cls, home_account_id=None, environment=None, client_id=None, target=None, - **ignored): + **ignored_payload_from_a_real_token): return "-".join([ home_account_id or "", environment or "", @@ -189,17 +174,73 @@ def _build_rt_key( ]).lower() def remove_rt(self, rt_item): + assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN key = self._build_rt_key(**rt_item) with self._lock: self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None) def update_rt(self, rt_item, new_rt): + assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN key = self._build_rt_key(**rt_item) with self._lock: RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}) rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence rt["secret"] = new_rt + @classmethod + def _build_at_key(cls, + home_account_id=None, environment=None, client_id=None, + realm=None, target=None, **ignored_payload_from_a_real_token): + return "-".join([ + home_account_id or "", + environment or "", + cls.CredentialType.ACCESS_TOKEN, + client_id, + realm or "", + target or "", + ]).lower() + + def remove_at(self, at_item): + assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN + key = self._build_at_key(**at_item) + with self._lock: + self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {}).pop(key, None) + + @classmethod + def _build_idt_key(cls, + home_account_id=None, environment=None, client_id=None, realm=None, + **ignored_payload_from_a_real_token): + return "-".join([ + home_account_id or "", + environment or "", + cls.CredentialType.ID_TOKEN, + client_id or "", + realm or "", + "" # Albeit irrelevant, schema requires an empty scope here + ]).lower() + + def remove_idt(self, idt_item): + assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN + key = self._build_idt_key(**idt_item) + with self._lock: + self._cache.setdefault(self.CredentialType.ID_TOKEN, {}).pop(key, None) + + @classmethod + def _build_account_key(cls, + home_account_id=None, environment=None, realm=None, + **ignored_payload_from_a_real_entry): + return "-".join([ + home_account_id or "", + environment or "", + realm or "", + ]).lower() + + def remove_account(self, account_item): + assert "authority_type" in account_item + key = self._build_account_key(**account_item) + with self._lock: + self._cache.setdefault(self.CredentialType.ACCOUNT, {}).pop(key, None) + class SerializableTokenCache(TokenCache): """This serialization can be a starting point to implement your own persistence. diff --git a/tests/test_application.py b/tests/test_application.py index 6346774a..a542245f 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -179,6 +179,8 @@ def setUp(self): "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), "response": TokenCacheTestCase.build_response( + access_token="Siblings won't share AT. test_remove_account() will.", + id_token=TokenCacheTestCase.build_id_token(), uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"), }) # The add(...) helper populates correct home_account_id for future searching @@ -239,6 +241,34 @@ def tester(url, data=None, **kwargs): # Will not test scenario of app leaving family. Per specs, it won't happen. + def test_get_remove_account(self): + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + app = ClientApplication( + "family_app_2", authority=self.authority_url, token_cache=self.cache) + account = app.get_accounts()[0] + mine = {"home_account_id": account["home_account_id"]} + + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ACCESS_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.REFRESH_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ID_TOKEN, query=mine)) + self.assertNotEqual([], self.cache.find( + self.cache.CredentialType.ACCOUNT, query=mine)) + + app.remove_account(account) + + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ACCESS_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.REFRESH_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ID_TOKEN, query=mine)) + self.assertEqual([], self.cache.find( + self.cache.CredentialType.ACCOUNT, query=mine)) + + class TestClientApplicationForAuthorityMigration(unittest.TestCase): @classmethod From 6362813af63979055c6b593f79d6db662421a2b4 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Fri, 10 May 2019 14:40:50 -0700 Subject: [PATCH 3/4] Pivot to remove only RTs of this app or its family --- msal/application.py | 43 +++++++++++++++++++++++++++++++-------- tests/test_application.py | 8 +++++--- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/msal/application.py b/msal/application.py index 4290cba6..038badb6 100644 --- a/msal/application.py +++ b/msal/application.py @@ -280,22 +280,47 @@ def _get_authority_aliases(self, instance): return [alias for alias in group if alias != instance] return [] - def remove_account(self, home_account): - """Remove all relevant RTs and ATs from token cache""" - owned_by_account = { + def remove_account(self, account): + """Sign me out and forget me from token cache""" + self._forget_me(account) + + def _sign_out(self, home_account): + # Remove all relevant RTs and ATs from token cache + owned_by_home_account = { "environment": home_account["environment"], "home_account_id": home_account["home_account_id"],} # realm-independent - for rt in self.token_cache.find( # Remove RTs, and RTs are realm-independent - TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_account): + app_metadata = self._get_app_metadata(home_account["environment"]) + # Remove RTs/FRTs, and they are realm-independent + for rt in [rt for rt in self.token_cache.find( + TokenCache.CredentialType.REFRESH_TOKEN, query=owned_by_home_account) + # Do RT's app ownership check as a precaution, in case family apps + # and 3rd-party apps share same token cache, although they should not. + if rt["client_id"] == self.client_id or ( + app_metadata.get("family_id") # Now let's settle family business + and rt.get("family_id") == app_metadata["family_id"]) + ]: self.token_cache.remove_rt(rt) - for at in self.token_cache.find( # Remove ATs, regardless of realm - TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_account): + for at in self.token_cache.find( # Remove ATs + # Regardless of realm, b/c we've removed realm-independent RTs anyway + TokenCache.CredentialType.ACCESS_TOKEN, query=owned_by_home_account): + # To avoid the complexity of locating sibling family app's AT, + # we skip AT's app ownership check. + # It means ATs for other apps will also be removed, it is OK because: + # * non-family apps are not supposed to share token cache to begin with; + # * Even if it happens, we keep other app's RT already, so SSO still works self.token_cache.remove_at(at) + + def _forget_me(self, home_account): + # It implies signout, and then also remove all relevant accounts and IDTs + self._sign_out(home_account) + owned_by_home_account = { + "environment": home_account["environment"], + "home_account_id": home_account["home_account_id"],} # realm-independent for idt in self.token_cache.find( # Remove IDTs, regardless of realm - TokenCache.CredentialType.ID_TOKEN, query=owned_by_account): + TokenCache.CredentialType.ID_TOKEN, query=owned_by_home_account): self.token_cache.remove_idt(idt) for a in self.token_cache.find( # Remove Accounts, regardless of realm - TokenCache.CredentialType.ACCOUNT, query=owned_by_account): + TokenCache.CredentialType.ACCOUNT, query=owned_by_home_account): self.token_cache.remove_account(a) def acquire_token_silent( diff --git a/tests/test_application.py b/tests/test_application.py index a542245f..75d5d27b 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -174,8 +174,9 @@ def setUp(self): self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} self.frt = "what the frt" self.cache = msal.SerializableTokenCache() + self.preexisting_family_app_id = "preexisting_family_app" self.cache.add({ # Pre-populate a FRT - "client_id": "preexisting_family_app", + "client_id": self.preexisting_family_app_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), "response": TokenCacheTestCase.build_response( @@ -241,10 +242,11 @@ def tester(url, data=None, **kwargs): # Will not test scenario of app leaving family. Per specs, it won't happen. - def test_get_remove_account(self): + def test_family_app_remove_account(self): logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) app = ClientApplication( - "family_app_2", authority=self.authority_url, token_cache=self.cache) + self.preexisting_family_app_id, + authority=self.authority_url, token_cache=self.cache) account = app.get_accounts()[0] mine = {"home_account_id": account["home_account_id"]} From 57b1195705bb62667359262428cad9410c316484 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Mon, 13 May 2019 14:03:30 -0700 Subject: [PATCH 4/4] TokenCache now have one modify() to rule them all. --- msal/token_cache.py | 58 ++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index cad8e722..e802eddd 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -39,6 +39,12 @@ class AuthorityType: def __init__(self): self._lock = threading.RLock() self._cache = {} + self.key_makers = { + self.CredentialType.REFRESH_TOKEN: self._build_rt_key, + self.CredentialType.ACCESS_TOKEN: self._build_at_key, + self.CredentialType.ID_TOKEN: self._build_idt_key, + self.CredentialType.ACCOUNT: self._build_account_key, + } def find(self, credential_type, target=None, query=None): target = target or [] @@ -155,6 +161,24 @@ def add(self, event, now=None): "family_id": response.get("foci"), # None is also valid } + def modify(self, credential_type, old_entry, new_key_value_pairs=None): + # Modify the specified old_entry with new_key_value_pairs, + # or remove the old_entry if the new_key_value_pairs is None. + + # This helper exists to consolidate all token modify/remove behaviors, + # so that the sub-classes will have only one method to work on, + # instead of patching a pair of update_xx() and remove_xx() per type. + # You can monkeypatch self.key_makers to support more types on-the-fly. + key = self.key_makers[credential_type](**old_entry) + with self._lock: + if new_key_value_pairs: # Update with them + entries = self._cache.setdefault(credential_type, {}) + entry = entries.get(key, {}) # key usually exists, but we'll survive its absence + entry.update(new_key_value_pairs) + else: # Remove old_entry + self._cache.setdefault(credential_type, {}).pop(key, None) + + @staticmethod def _build_appmetadata_key(environment, client_id): return "appmetadata-{}-{}".format(environment or "", client_id or "") @@ -175,17 +199,12 @@ def _build_rt_key( def remove_rt(self, rt_item): assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN - key = self._build_rt_key(**rt_item) - with self._lock: - self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}).pop(key, None) + return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item) def update_rt(self, rt_item, new_rt): assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN - key = self._build_rt_key(**rt_item) - with self._lock: - RTs = self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {}) - rt = RTs.get(key, {}) # key usually exists, but we'll survive its absence - rt["secret"] = new_rt + return self.modify( + self.CredentialType.REFRESH_TOKEN, rt_item, {"secret": new_rt}) @classmethod def _build_at_key(cls, @@ -202,9 +221,7 @@ def _build_at_key(cls, def remove_at(self, at_item): assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN - key = self._build_at_key(**at_item) - with self._lock: - self._cache.setdefault(self.CredentialType.ACCESS_TOKEN, {}).pop(key, None) + return self.modify(self.CredentialType.ACCESS_TOKEN, at_item) @classmethod def _build_idt_key(cls, @@ -221,9 +238,7 @@ def _build_idt_key(cls, def remove_idt(self, idt_item): assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN - key = self._build_idt_key(**idt_item) - with self._lock: - self._cache.setdefault(self.CredentialType.ID_TOKEN, {}).pop(key, None) + return self.modify(self.CredentialType.ID_TOKEN, idt_item) @classmethod def _build_account_key(cls, @@ -237,9 +252,7 @@ def _build_account_key(cls, def remove_account(self, account_item): assert "authority_type" in account_item - key = self._build_account_key(**account_item) - with self._lock: - self._cache.setdefault(self.CredentialType.ACCOUNT, {}).pop(key, None) + return self.modify(self.CredentialType.ACCOUNT, account_item) class SerializableTokenCache(TokenCache): @@ -262,7 +275,7 @@ class SerializableTokenCache(TokenCache): ... :var bool has_state_changed: - Indicates whether the cache state has changed since last + Indicates whether the cache state in the memory has changed since last :func:`~serialize` or :func:`~deserialize` call. """ has_state_changed = False @@ -271,12 +284,9 @@ def add(self, event, **kwargs): super(SerializableTokenCache, self).add(event, **kwargs) self.has_state_changed = True - def remove_rt(self, rt_item): - super(SerializableTokenCache, self).remove_rt(rt_item) - self.has_state_changed = True - - def update_rt(self, rt_item, new_rt): - super(SerializableTokenCache, self).update_rt(rt_item, new_rt) + def modify(self, credential_type, old_entry, new_key_value_pairs=None): + super(SerializableTokenCache, self).modify( + credential_type, old_entry, new_key_value_pairs) self.has_state_changed = True def deserialize(self, state):