diff --git a/msal/application.py b/msal/application.py index 9d22eace..410086d1 100644 --- a/msal/application.py +++ b/msal/application.py @@ -752,6 +752,11 @@ def _acquire_token_silent_by_finding_specific_refresh_token( response = client.obtain_token_by_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], on_removing_rt=rt_remover or self.token_cache.remove_rt, + on_obtaining_tokens=lambda event: self.token_cache.add(dict( + event, + environment=authority.instance, + skip_account_creation=True, # To honor a concurrent remove_account() + )), scope=scopes, headers={ CLIENT_REQUEST_ID: correlation_id or _get_new_correlation_id(), diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 1d9c21d5..90c1d31b 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -462,6 +462,7 @@ def __init__(self, def _obtain_token( self, grant_type, params=None, data=None, also_save_rt=False, + on_obtaining_tokens=None, *args, **kwargs): _data = data.copy() # to prevent side effect resp = super(Client, self)._obtain_token( @@ -481,7 +482,7 @@ def _obtain_token( # but our obtain_token_by_authorization_code(...) encourages # app developer to still explicitly provide a scope here. scope = _data.get("scope") - self.on_obtaining_tokens({ + (on_obtaining_tokens or self.on_obtaining_tokens)({ "client_id": self.client_id, "scope": scope, "token_endpoint": self.configuration["token_endpoint"], @@ -495,6 +496,7 @@ def obtain_token_by_refresh_token(self, token_item, scope=None, rt_getter=lambda token_item: token_item["refresh_token"], on_removing_rt=None, on_updating_rt=None, + on_obtaining_tokens=None, **kwargs): # type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict """This is an overload which will trigger token storage callbacks. diff --git a/msal/token_cache.py b/msal/token_cache.py index 83fc1891..b7ebbb99 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -172,7 +172,7 @@ def __add(self, event, now=None): at["key_id"] = data.get("key_id") self.modify(self.CredentialType.ACCESS_TOKEN, at, at) - if client_info: + if client_info and not event.get("skip_account_creation"): account = { "home_account_id": home_account_id, "environment": environment,