Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def _str2bytes(raw):
return raw


def _clean_up(result):
if isinstance(result, dict):
result.pop("refresh_in", None) # MSAL handled refresh_in, customers need not
return result


class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
Expand Down Expand Up @@ -507,7 +513,7 @@ def authorize(): # A controller in a web app
return redirect(url_for("index"))
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_auth_code_flow(
return _clean_up(self.client.obtain_token_by_auth_code_flow(
auth_code_flow,
auth_response,
scope=decorate_scope(scopes, self.client_id) if scopes else None,
Expand All @@ -521,7 +527,7 @@ def authorize(): # A controller in a web app
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities,
auth_code_flow.pop("claims_challenge", None))),
**kwargs)
**kwargs))

def acquire_token_by_authorization_code(
self,
Expand Down Expand Up @@ -580,7 +586,7 @@ def acquire_token_by_authorization_code(
"Change your acquire_token_by_authorization_code() "
"to acquire_token_by_auth_code_flow()", DeprecationWarning)
with warnings.catch_warnings(record=True):
return self.client.obtain_token_by_authorization_code(
return _clean_up(self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
scope=decorate_scope(scopes, self.client_id),
headers={
Expand All @@ -593,7 +599,7 @@ def acquire_token_by_authorization_code(
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs)
**kwargs))

def get_accounts(self, username=None):
"""Get a list of accounts which previously signed in, i.e. exists in cache.
Expand Down Expand Up @@ -855,13 +861,13 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
result = self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
force_refresh=force_refresh, claims_challenge=claims_challenge, **kwargs)
result = _clean_up(result)
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
logger.exception("Refresh token failed") # Potential AAD outage?
return access_token_from_cache


def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
self, authority, scopes, account, **kwargs):
query = {
Expand Down Expand Up @@ -987,7 +993,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
* A dict contains no "error" key means migration was successful.
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_refresh_token(
return _clean_up(self.client.obtain_token_by_refresh_token(
refresh_token,
scope=decorate_scope(scopes, self.client_id),
headers={
Expand All @@ -998,7 +1004,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
rt_getter=lambda rt: rt,
on_updating_rt=False,
on_removing_rt=lambda rt_item: None, # No OP
**kwargs)
**kwargs))


class PublicClientApplication(ClientApplication): # browser app or mobile app
Expand Down Expand Up @@ -1072,7 +1078,7 @@ def acquire_token_interactive(
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
claims = _merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)
return self.client.obtain_token_by_browser(
return _clean_up(self.client.obtain_token_by_browser(
scope=decorate_scope(scopes, self.client_id) if scopes else None,
extra_scope_to_consent=extra_scopes_to_consent,
redirect_uri="http://localhost:{port}".format(
Expand All @@ -1091,7 +1097,7 @@ def acquire_token_interactive(
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_INTERACTIVE),
},
**kwargs)
**kwargs))

def initiate_device_flow(self, scopes=None, **kwargs):
"""Initiate a Device Flow instance,
Expand Down Expand Up @@ -1134,7 +1140,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
- A successful response would contain "access_token" key,
- an error response would contain "error" and usually "error_description".
"""
return self.client.obtain_token_by_device_flow(
return _clean_up(self.client.obtain_token_by_device_flow(
flow,
data=dict(
kwargs.pop("data", {}),
Expand All @@ -1150,7 +1156,7 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID),
},
**kwargs)
**kwargs))

def acquire_token_by_username_password(
self, username, password, scopes, claims_challenge=None, **kwargs):
Expand Down Expand Up @@ -1188,15 +1194,15 @@ def acquire_token_by_username_password(
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[CLIENT_REQUEST_ID])
if user_realm_result.get("account_type") == "Federated":
return self._acquire_token_by_username_password_federated(
return _clean_up(self._acquire_token_by_username_password_federated(
user_realm_result, username, password, scopes=scopes,
data=data,
headers=headers, **kwargs)
return self.client.obtain_token_by_username_password(
headers=headers, **kwargs))
return _clean_up(self.client.obtain_token_by_username_password(
username, password, scope=scopes,
headers=headers,
data=data,
**kwargs)
**kwargs))

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
Expand Down Expand Up @@ -1256,7 +1262,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
"""
# TBD: force_refresh behavior
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_for_client(
return _clean_up(self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
Expand All @@ -1267,7 +1273,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
**kwargs)
**kwargs))

def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
"""Acquires token using on-behalf-of (OBO) flow.
Expand Down Expand Up @@ -1297,7 +1303,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
"""
# The implementation is NOT based on Token Exchange
# https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
return self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
return _clean_up(self.client.obtain_token_by_assertion( # bases on assertion RFC 7521
user_assertion,
self.client.GRANT_TYPE_JWT, # IDTs and AAD ATs are all JWTs
scope=decorate_scope(scopes, self.client_id), # Decoration is used for:
Expand All @@ -1316,4 +1322,4 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
},
**kwargs)
**kwargs))
34 changes: 21 additions & 13 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,23 @@ def test_fresh_token_should_be_returned_from_cache(self):
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
access_token = "An access token prepopulated into cache"
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
self.assertEqual(
access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
result = self.app.acquire_token_silent(['s1'], self.account)
self.assertEqual(access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

def test_aging_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt to refresh unexpired token when AAD available
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
new_access_token = "new AT"
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"access_token": new_access_token})
self.assertEqual(
new_access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
def mock_post(*args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
}))
self.app.http_client.post = mock_post
result = self.app.acquire_token_silent(['s1'], self.account)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

def test_aging_token_and_unavailable_aad_should_return_old_token(self):
# a.k.a. Attempt refresh unexpired token when AAD unavailable
Expand All @@ -393,9 +397,13 @@ def test_expired_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt refresh expired token when AAD available
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
new_access_token = "new AT"
self.app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family = (
lambda *args, **kwargs: {"access_token": new_access_token})
self.assertEqual(
new_access_token,
self.app.acquire_token_silent(['s1'], self.account).get("access_token"))
def mock_post(*args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
}))
self.app.http_client.post = mock_post
result = self.app.acquire_token_silent(['s1'], self.account)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")