From 4a876cab4e15944a6f63ededbbad5f049159a9f0 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Wed, 29 Jun 2022 14:22:38 -0700 Subject: [PATCH] Implement an app_token_provider callback --- msal/application.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/msal/application.py b/msal/application.py index 838a28d8..bef3b62f 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1687,6 +1687,14 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs): class ConfidentialClientApplication(ClientApplication): # server-side web app + def __init__(self, client_id, **kwargs): + self._app_token_provider = kwargs.pop("app_token_provider", None) + if self._app_token_provider: + warnings.warn( + "The undocumented app_token_provider param is subject to change", + PendingDeprecationWarning) + super(ConfidentialClientApplication, self).__init__(client_id, **kwargs) + def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): """Acquires token for the current confidential client, not for an end user. @@ -1704,6 +1712,21 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): - an error response would contain "error" and usually "error_description". """ # TBD: force_refresh behavior + if self._app_token_provider: + response = self._app_token_provider( + scopes=scopes, + client_id=self.client_id, + ) # Return value should be like + # {"access_token": "...", "expires_in": 123, ...} + # or {"error": "...", "error_description": "..."} + if "error" not in response: + self.token_cache.add(dict( + client_id=self.client_id, + scope=response["scope"].split() if "scope" in response else scopes, + token_endpoint=self.authority.token_endpoint, + response=response.copy(), + )) + return response if self.authority.tenant.lower() in ["common", "organizations"]: warnings.warn( "Using /common or /organizations authority "