Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Release History

## 1.17.1 (Unreleased)
## 1.18.0 (Unreleased)

### Features Added

- Added `AccessToken.refresh_on`, a timestamp after which the `AccessToken`
should be refreshed

### Breaking Changes

### Bugs Fixed
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.17.1"
VERSION = "1.18.0"
34 changes: 29 additions & 5 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
import six

if TYPE_CHECKING:
from typing import Any, NamedTuple
from typing import Any, Optional
from typing_extensions import Protocol

AccessToken = NamedTuple("AccessToken", [("token", str), ("expires_on", int)])

class TokenCredential(Protocol):
"""Protocol for classes able to provide OAuth tokens.

Expand All @@ -25,8 +23,34 @@ def get_token(self, *scopes, **kwargs):
pass


else:
AccessToken = namedtuple("AccessToken", ["token", "expires_on"])
class AccessToken(namedtuple("_AccessToken", ["token", "expires_on"])):
"""An access token.

:param str token: the access token itself
:param int expires_on: the Unix timestamp at which the token expires
:param int refresh_on: (optional) a Unix timestamp after which the token should be refreshed. Defaults to 5 minutes
before **expires_on**.
"""

def __new__(cls, token, expires_on, refresh_on=None): # pylint:disable=unused-argument
# type: (str, int, Optional[int]) -> AccessToken
# AccessToken began as a namedtuple with "token" and "expires_on" fields. This class inherits that namedtuple
# to maintain API compatibility. This override enables adding "refresh_on" to the class but not the namedtuple.
return super(AccessToken, cls).__new__(cls, token, expires_on)

def __init__(self, token, expires_on, refresh_on=None): # pylint:disable=super-init-not-called,unused-argument
# type: (str, int, Optional[int]) -> None
self._refresh_on = refresh_on or max(0, expires_on - 300)

@property
def refresh_on(self):
# type: () -> int
"""A Unix timestamp after which the token should be refreshed.

:rtype: int
"""
return self._refresh_on


AzureNamedKey = namedtuple("AzureNamedKey", ["name", "key"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _update_headers(headers, token):
@property
def _need_new_token(self):
# type: () -> bool
return not self._token or self._token.expires_on - time.time() < 300
return not self._token or self._token.refresh_on <= int(time.time())


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,4 @@ def on_exception(self, request: "PipelineRequest") -> "Union[bool, Awaitable[boo
return False

def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300
return not self._token or self._token.refresh_on <= int(time.time())
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,23 @@ async def fake_send(*args, **kwargs):
policy.on_exception.assert_called_once_with(policy.request)


async def test_bearer_policy_token_refresh():
"""AsyncBearerTokenCredentialPolicy should observe a token's refresh_on value"""
now = int(time.time())

async def get_token(*_, **__):
return AccessToken("***", expires_on=now + 3600, refresh_on=now)

credential = Mock(get_token=Mock(wraps=get_token))
policy = AsyncBearerTokenCredentialPolicy(credential, "scope")
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=[policy])

# the policy should call get_token for every request because each token's refresh_on is past
for n in range(4):
assert credential.get_token.call_count == n
await pipeline.run(HttpRequest("GET", "https://localhost"))


def get_completed_future(result=None):
fut = asyncio.Future()
fut.set_result(result)
Expand Down
17 changes: 17 additions & 0 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,23 @@ def raise_the_second_time(*args, **kwargs):
policy.on_exception.assert_called_once_with(policy.request)


def test_bearer_policy_token_refresh():
"""BearerTokenCredentialPolicy should observe a token's refresh_on value"""
now = int(time.time())

def get_token(*_, **__):
return AccessToken("***", expires_on=now + 3600, refresh_on=now)

credential = Mock(get_token=Mock(wraps=get_token))
policy = BearerTokenCredentialPolicy(credential, "scope")
pipeline = Pipeline(transport=Mock(), policies=[policy])

# the policy should call get_token for every request because each token's refresh_on is past
for n in range(4):
assert credential.get_token.call_count == n
pipeline.run(HttpRequest("GET", "https://localhost"))


@pytest.mark.skipif(azure.core.__version__ >= "2", reason="this test applies only to azure-core 1.x")
def test_key_vault_regression():
"""Test for regression affecting azure-keyvault-* 4.0.0. This test must pass, unmodified, for all 1.x versions."""
Expand Down
36 changes: 36 additions & 0 deletions sdk/core/azure-core/tests/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from azure.core.credentials import AccessToken


def test_accesstoken_compatibility():
"""AccessToken should remain compatible with its original namedtuple implementation"""

access_token = "***"
expires_on = 42000

def assert_namedtuple_api_compatibility(token):
# AccessToken should have the same API as a namedtuple with "token" and "expires_on" fields
assert token == (access_token, expires_on)
assert token.token == access_token
assert token.expires_on == expires_on
assert token._asdict() == {"token": access_token, "expires_on": expires_on}

# should be able to construct AccessToken with only "token" and "expires_on"
for token in (AccessToken(access_token, expires_on), AccessToken(token=access_token, expires_on=expires_on)):
assert_namedtuple_api_compatibility(token)

# refresh_on defaults to expires_on - 300
assert token.refresh_on == expires_on - 300

# refresh_on is an optional positional parameter AccessToken doesn't return during unpacking or iteration
refresh_on = 42
for token in (
AccessToken(access_token, expires_on, refresh_on),
AccessToken(access_token, expires_on, refresh_on=refresh_on),
):
assert_namedtuple_api_compatibility(token)
assert token.refresh_on == refresh_on