From c6773bdf0e3768db62955159ed28f209a0e6e991 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 16 Aug 2021 15:34:30 -0700 Subject: [PATCH 1/4] Add AccessToken.refresh_on --- sdk/core/azure-core/azure/core/credentials.py | 34 +++++++++++++++--- sdk/core/azure-core/tests/test_credentials.py | 36 +++++++++++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 sdk/core/azure-core/tests/test_credentials.py diff --git a/sdk/core/azure-core/azure/core/credentials.py b/sdk/core/azure-core/azure/core/credentials.py index e5146d7f947a..6407902bfcad 100644 --- a/sdk/core/azure-core/azure/core/credentials.py +++ b/sdk/core/azure-core/azure/core/credentials.py @@ -8,11 +8,9 @@ import six if TYPE_CHECKING: - from typing import Any, NamedTuple + from typing import Any, Optional from typing_extensions import Protocol - AccessToken = NamedTuple("AccessToken", [("token", str), ("expires_on", int)]) - class TokenCredential(Protocol): """Protocol for classes able to provide OAuth tokens. @@ -25,8 +23,34 @@ def get_token(self, *scopes, **kwargs): pass -else: - AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) +class AccessToken(namedtuple("_AccessToken", ["token", "expires_on"])): + """An access token. + + :param str token: the access token itself + :param int expires_on: the Unix timestamp at which the token expires + :param int refresh_on: (optional) a Unix timestamp after which the token should be refreshed. Defaults to 5 minutes + before **expires_on**. + """ + + def __new__(cls, token, expires_on, refresh_on=None): # pylint:disable=unused-argument + # type: (str, int, Optional[int]) -> AccessToken + # AccessToken began as a namedtuple with "token" and "expires_on" fields. This class inherits that namedtuple + # to maintain API compatibility. This override enables adding "refresh_on" to the class but not the namedtuple. + return super(AccessToken, cls).__new__(cls, token, expires_on) + + def __init__(self, token, expires_on, refresh_on=None): # pylint:disable=super-init-not-called,unused-argument + # type: (str, int, Optional[int]) -> None + self._refresh_on = refresh_on or max(0, expires_on - 300) + + @property + def refresh_on(self): + # type: () -> int + """A Unix timestamp after which the token should be refreshed. + + :rtype: int + """ + return self._refresh_on + AzureNamedKey = namedtuple("AzureNamedKey", ["name", "key"]) diff --git a/sdk/core/azure-core/tests/test_credentials.py b/sdk/core/azure-core/tests/test_credentials.py new file mode 100644 index 000000000000..031c7e426e64 --- /dev/null +++ b/sdk/core/azure-core/tests/test_credentials.py @@ -0,0 +1,36 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.core.credentials import AccessToken + + +def test_accesstoken_compatibility(): + """AccessToken should remain compatible with its original namedtuple implementation""" + + access_token = "***" + expires_on = 42000 + + def assert_namedtuple_api_compatibility(token): + # AccessToken should have the same API as a namedtuple with "token" and "expires_on" fields + assert token == (access_token, expires_on) + assert token.token == access_token + assert token.expires_on == expires_on + assert token._asdict() == {"token": access_token, "expires_on": expires_on} + + # should be able to construct AccessToken with only "token" and "expires_on" + for token in (AccessToken(access_token, expires_on), AccessToken(token=access_token, expires_on=expires_on)): + assert_namedtuple_api_compatibility(token) + + # refresh_on defaults to expires_on - 300 + assert token.refresh_on == expires_on - 300 + + # refresh_on is an optional positional parameter AccessToken doesn't return during unpacking or iteration + refresh_on = 42 + for token in ( + AccessToken(access_token, expires_on, refresh_on), + AccessToken(access_token, expires_on, refresh_on=refresh_on), + ): + assert_namedtuple_api_compatibility(token) + assert token.refresh_on == refresh_on From bebd4c82930633dd46d1e1f3eed5ae3d478155f5 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 17 Aug 2021 09:11:29 -0700 Subject: [PATCH 2/4] BearerTokenCredentialPolicy observes refresh_on --- .../core/pipeline/policies/_authentication.py | 2 +- .../pipeline/policies/_authentication_async.py | 2 +- .../async_tests/test_authentication_async.py | 17 +++++++++++++++++ .../azure-core/tests/test_authentication.py | 17 +++++++++++++++++ 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 228e3fd20f58..ccf220958858 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -68,7 +68,7 @@ def _update_headers(headers, token): @property def _need_new_token(self): # type: () -> bool - return not self._token or self._token.expires_on - time.time() < 300 + return not self._token or self._token.refresh_on <= int(time.time()) class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy): diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 76564320b742..1551a682dc67 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -129,4 +129,4 @@ def on_exception(self, request: "PipelineRequest") -> "Union[bool, Awaitable[boo return False def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + return not self._token or self._token.refresh_on <= int(time.time()) diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index 7230018aa37f..576f180b402d 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -215,6 +215,23 @@ async def fake_send(*args, **kwargs): policy.on_exception.assert_called_once_with(policy.request) +async def test_bearer_policy_token_refresh(): + """AsyncBearerTokenCredentialPolicy should observe a token's refresh_on value""" + now = int(time.time()) + + async def get_token(*_, **__): + return AccessToken("***", expires_on=now + 3600, refresh_on=now) + + credential = Mock(get_token=Mock(wraps=get_token)) + policy = AsyncBearerTokenCredentialPolicy(credential, "scope") + pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=[policy]) + + # the policy should call get_token for every request because each token's refresh_on is past + for n in range(4): + assert credential.get_token.call_count == n + await pipeline.run(HttpRequest("GET", "https://localhost")) + + def get_completed_future(result=None): fut = asyncio.Future() fut.set_result(result) diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index de029e8ea352..cc640d300286 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -252,6 +252,23 @@ def raise_the_second_time(*args, **kwargs): policy.on_exception.assert_called_once_with(policy.request) +def test_bearer_policy_token_refresh(): + """BearerTokenCredentialPolicy should observe a token's refresh_on value""" + now = int(time.time()) + + def get_token(*_, **__): + return AccessToken("***", expires_on=now + 3600, refresh_on=now) + + credential = Mock(get_token=Mock(wraps=get_token)) + policy = BearerTokenCredentialPolicy(credential, "scope") + pipeline = Pipeline(transport=Mock(), policies=[policy]) + + # the policy should call get_token for every request because each token's refresh_on is past + for n in range(4): + assert credential.get_token.call_count == n + pipeline.run(HttpRequest("GET", "https://localhost")) + + @pytest.mark.skipif(azure.core.__version__ >= "2", reason="this test applies only to azure-core 1.x") def test_key_vault_regression(): """Test for regression affecting azure-keyvault-* 4.0.0. This test must pass, unmodified, for all 1.x versions.""" From 72595b646f4131cc92592c1fe9a5e87e9c699cad Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 16 Aug 2021 16:48:54 -0700 Subject: [PATCH 3/4] raise minor version --- sdk/core/azure-core/azure/core/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/_version.py b/sdk/core/azure-core/azure/core/_version.py index 691adad6d1be..8b847f69ad1b 100644 --- a/sdk/core/azure-core/azure/core/_version.py +++ b/sdk/core/azure-core/azure/core/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.17.1" +VERSION = "1.18.0" From 1db75faf262510d8976264e0b6b9ce26b10dc496 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 16 Aug 2021 16:50:16 -0700 Subject: [PATCH 4/4] changelog --- sdk/core/azure-core/CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index d9d8db2e1ada..7886d918d9d0 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -1,9 +1,12 @@ # Release History -## 1.17.1 (Unreleased) +## 1.18.0 (Unreleased) ### Features Added +- Added `AccessToken.refresh_on`, a timestamp after which the `AccessToken` + should be refreshed + ### Breaking Changes ### Bugs Fixed