From 9b9e0b75a92bea6487506e3dfc2047f28d504b37 Mon Sep 17 00:00:00 2001 From: 824750130 Date: Wed, 26 Aug 2020 21:18:38 +0800 Subject: [PATCH] Improve the logic of automatic refresh token --- alibabacloud_credentials/client.py | 6 +- alibabacloud_credentials/credentials.py | 80 ++++++++++++++++--------- alibabacloud_credentials/providers.py | 7 ++- tests/test_credentials.py | 12 ++-- tests/test_providers.py | 11 ++-- 5 files changed, 70 insertions(+), 46 deletions(-) diff --git a/alibabacloud_credentials/client.py b/alibabacloud_credentials/client.py index 5038840..74d4f05 100644 --- a/alibabacloud_credentials/client.py +++ b/alibabacloud_credentials/client.py @@ -47,15 +47,15 @@ def get_provider(config): @attribute_error_return_none def get_access_key_id(self): - return self.cloud_credential.access_key_id + return self.cloud_credential.get_access_key_id() @attribute_error_return_none def get_access_key_secret(self): - return self.cloud_credential.access_key_secret + return self.cloud_credential.get_access_key_secret() @attribute_error_return_none def get_security_token(self): - return self.cloud_credential.security_token + return self.cloud_credential.get_security_token() @attribute_error_return_none def get_type(self): diff --git a/alibabacloud_credentials/credentials.py b/alibabacloud_credentials/credentials.py index 9ca1a6a..256b315 100644 --- a/alibabacloud_credentials/credentials.py +++ b/alibabacloud_credentials/credentials.py @@ -4,13 +4,12 @@ class _AutomaticallyRefreshCredentials: - def __init__(self, expiration, provider, refresh_fields): + def __init__(self, expiration, provider): self.expiration = expiration self.provider = provider - self._REFRESH_FIELDS = refresh_fields def _with_should_refresh(self): - return int(time.mktime(time.localtime())) >= (object.__getattribute__(self, 'expiration') - 180) + return int(time.mktime(time.localtime())) >= (self.expiration - 180) def _get_new_credential(self): return self.provider.get_credentials() @@ -19,11 +18,6 @@ def _refresh_credential(self): if self._with_should_refresh(): return self._get_new_credential() - def __getattribute__(self, item): - if item in object.__getattribute__(self, '__dict__')['_REFRESH_FIELDS']: - self._refresh_credential() - return object.__getattribute__(self, item) - class AccessKeyCredential: """AccessKeyCredential""" @@ -33,6 +27,12 @@ def __init__(self, access_key_id, access_key_secret): self.access_key_secret = access_key_secret self.credential_type = ac.ACCESS_KEY + def get_access_key_id(self): + return self.access_key_id + + def get_access_key_secret(self): + return self.access_key_secret + class BearerTokenCredential: """BearerTokenCredential""" @@ -46,13 +46,7 @@ class EcsRamRoleCredential(_AutomaticallyRefreshCredentials): """EcsRamRoleCredential""" def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider): - refresh_fields = ( - 'access_key_id', - 'access_key_secret', - 'security_token', - 'expiration' - ) - super().__init__(expiration, provider, refresh_fields) + super().__init__(expiration, provider) self.access_key_id = access_key_id self.access_key_secret = access_key_secret self.security_token = security_token @@ -66,18 +60,24 @@ def _refresh_credential(self): self.expiration = credential.expiration self.security_token = credential.security_token + def get_access_key_id(self): + self._refresh_credential() + return self.access_key_id + + def get_access_key_secret(self): + self._refresh_credential() + return self.access_key_secret + + def get_security_token(self): + self._refresh_credential() + return self.security_token + class RamRoleArnCredential(_AutomaticallyRefreshCredentials): """RamRoleArnCredential""" def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider): - refresh_fields = ( - 'access_key_id', - 'access_key_secret', - 'security_token', - 'expiration' - ) - super().__init__(expiration, provider, refresh_fields) + super().__init__(expiration, provider) self.access_key_id = access_key_id self.access_key_secret = access_key_secret self.security_token = security_token @@ -91,15 +91,22 @@ def _refresh_credential(self): self.expiration = credential.expiration self.security_token = credential.security_token + def get_access_key_id(self): + self._refresh_credential() + return self.access_key_id + + def get_access_key_secret(self): + self._refresh_credential() + return self.access_key_secret + + def get_security_token(self): + self._refresh_credential() + return self.security_token + class RsaKeyPairCredential(_AutomaticallyRefreshCredentials): def __init__(self, access_key_id, access_key_secret, expiration, provider): - refresh_fields = ( - 'access_key_id', - 'access_key_secret', - 'expiration' - ) - super().__init__(expiration, provider, refresh_fields) + super().__init__(expiration, provider) self.access_key_id = access_key_id self.access_key_secret = access_key_secret self.credential_type = ac.RSA_KEY_PAIR @@ -111,6 +118,14 @@ def _refresh_credential(self): self.access_key_secret = credential.access_key_secret self.expiration = credential.expiration + def get_access_key_id(self): + self._refresh_credential() + return self.access_key_id + + def get_access_key_secret(self): + self._refresh_credential() + return self.access_key_secret + class StsCredential: def __init__(self, access_key_id, access_key_secret, security_token): @@ -118,3 +133,12 @@ def __init__(self, access_key_id, access_key_secret, security_token): self.access_key_secret = access_key_secret self.security_token = security_token self.credential_type = ac.STS + + def get_access_key_id(self): + return self.access_key_id + + def get_access_key_secret(self): + return self.access_key_secret + + def get_security_token(self): + return self.security_token diff --git a/alibabacloud_credentials/providers.py b/alibabacloud_credentials/providers.py index ad384ab..6bc7f4c 100644 --- a/alibabacloud_credentials/providers.py +++ b/alibabacloud_credentials/providers.py @@ -2,6 +2,7 @@ import json import time import configparser +import calendar from alibabacloud_credentials.utils import auth_util as au, \ auth_constant as ac, \ @@ -123,7 +124,7 @@ def _create_credential(self, url=None): # 先转换为时间数组 time_array = time.strptime(expiration_str, "%Y-%m-%d %H:%M:%S") # 转换为时间戳 - time_stamp = int(time.mktime(time_array)) + time_stamp = calendar.timegm(time_array) return credentials.EcsRamRoleCredential(content_access_key_id, content_access_key_secret, content_security_token, time_stamp, self) @@ -184,7 +185,7 @@ def _create_credentials(self, turl=None): # 先转换为时间数组 time_array = time.strptime(expiration_str, "%Y-%m-%d %H:%M:%S") # 转换为时间戳 - expiration = int(time.mktime(time_array)) + expiration = calendar.timegm(time_array) return credentials.RamRoleArnCredential(cre.get("AccessKeyId"), cre.get("AccessKeySecret"), cre.get("SecurityToken"), expiration, self) raise CredentialException(response.text) @@ -225,7 +226,7 @@ def _create_credential(self, turl=None): cre = dic.get("SessionAccessKey") expiration_str = cre.get("Expiration").replace("T", " ").replace("Z", "") time_array = time.strptime(expiration_str, "%Y-%m-%d %H:%M:%S") - expiration = int(time.mktime(time_array)) + expiration = calendar.timegm(time_array) return credentials.RsaKeyPairCredential(cre.get("SessionAccessKeyId"), cre.get("SessionAccessKeySecret"), expiration, self) raise CredentialException(resp.text) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 7b78e44..b07529f 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -33,9 +33,9 @@ def test_EcsRamRoleCredential(self): provider ) - self.assertEqual('access_key_id', cred.access_key_id) - self.assertEqual('access_key_secret', cred.access_key_secret) - self.assertEqual('security_token', cred.security_token) + self.assertEqual('access_key_id', cred.get_access_key_id()) + self.assertEqual('access_key_secret', cred.get_access_key_secret()) + self.assertEqual('security_token', cred.get_security_token()) self.assertEqual(900000000000, cred.expiration) self.assertIsInstance(cred.provider, providers.EcsRamRoleCredentialProvider) self.assertEqual('ecs_ram_role', cred.credential_type) @@ -48,7 +48,7 @@ def test_EcsRamRoleCredential(self): self.TestEcsRamRoleProvider() ) # refresh token - self.assertEqual('accessKeyId', cred.access_key_id) + self.assertEqual('accessKeyId', cred.get_access_key_id()) self.assertEqual('accessKeySecret', cred.access_key_secret) self.assertEqual('securityToken', cred.security_token) self.assertEqual(100000000000, cred.expiration) @@ -93,7 +93,7 @@ def test_RamRoleArnCredential(self): # refresh token self.assertTrue(cred._with_should_refresh()) - self.assertEqual('accessKeyId', cred.access_key_id) + self.assertEqual('accessKeyId', cred.get_access_key_id()) self.assertEqual('accessKeySecret', cred.access_key_secret) self.assertEqual('securityToken', cred.security_token) self.assertEqual(100000000000, cred.expiration) @@ -128,7 +128,7 @@ def test_RsaKeyPairCredential(self): ) # refresh token - self.assertEqual('accessKeyId', cred.access_key_id) + self.assertEqual('accessKeyId', cred.get_access_key_id()) self.assertEqual('accessKeySecret', cred.access_key_secret) self.assertEqual(100000000000, cred.expiration) diff --git a/tests/test_providers.py b/tests/test_providers.py index e5e5dc9..01202df 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -19,8 +19,8 @@ def do_GET(self): self.end_headers() self.wfile.write(b'{"Code": "Success", "AccessKeyId": "ak",' b' "Expiration": "3999-08-07T20:20:20Z", "Credentials":' - b' {"Expiration": "3999-08-07T20:20:20Z"}, "SessionAccessKey":' - b' {"Expiration": "3999-08-07T20:20:20Z"}}') + b' {"Expiration": "3999-08-07T20:20:20Z", "AccessKeyId": "AccessKeyId"}, "SessionAccessKey":' + b' {"Expiration": "3999-08-07T20:20:20Z", "SessionAccessKeyId": "SessionAccessKeyId"}}') def run_server(): @@ -55,9 +55,8 @@ def test_EcsRamRoleCredentialProvider(self): # prov._create_credential(url='http://www.aliyun.com') cred = prov._create_credential(url='http://127.0.0.1:8888') self.assertEqual('ak', cred.access_key_id) - self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration)) - prov._get_role_name(url='http://www.aliyun.com') + prov._get_role_name(url='http://127.0.0.1:8888') self.assertIsNotNone(prov.role_name) prov.role_name = 'role_name' prov._set_credential_url() @@ -123,7 +122,7 @@ def test_RamRoleArnCredentialProvider(self): self.assertIsNone(prov.policy) cred = prov._create_credentials(turl='http://127.0.0.1:8888') - self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration)) + self.assertEqual('AccessKeyId', cred.access_key_id) def test_RsaKeyPairCredentialProvider(self): access_key_id, access_key_secret, region_id = \ @@ -145,7 +144,7 @@ def test_RsaKeyPairCredentialProvider(self): self.assertEqual('cn-hangzhou', prov.region_id) cred = prov._create_credential(turl='http://127.0.0.1:8888') - self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration)) + self.assertEqual('SessionAccessKeyId', cred.access_key_id) def test_ProfileCredentialsProvider(self): prov = providers.ProfileCredentialsProvider(ini_file)