From f44d2233614b080ae373f610255702bf840469bf Mon Sep 17 00:00:00 2001 From: yndu13 Date: Thu, 4 Jul 2024 15:25:25 +0800 Subject: [PATCH] refactor: solve the inconsistency of credentials refresh --- alibabacloud_credentials/client.py | 30 +++- alibabacloud_credentials/credentials.py | 181 +++++++++++++++++++++--- alibabacloud_credentials/models.py | 56 ++++++++ setup.py | 2 +- tests/test_client.py | 10 +- tests/test_credentials.py | 49 +++++++ tests/test_model.py | 38 ++++- tests/test_providers.py | 2 +- 8 files changed, 334 insertions(+), 34 deletions(-) diff --git a/alibabacloud_credentials/client.py b/alibabacloud_credentials/client.py index 587d74c..80bd909 100644 --- a/alibabacloud_credentials/client.py +++ b/alibabacloud_credentials/client.py @@ -1,8 +1,8 @@ from functools import wraps -from alibabacloud_credentials import credentials, providers -from alibabacloud_credentials.models import Config +from alibabacloud_credentials import credentials, providers, models from alibabacloud_credentials.utils import auth_constant as ac +from Tea.decorators import deprecated def attribute_error_return_none(f): @@ -24,10 +24,24 @@ def __init__(self, config=None): provider = providers.DefaultCredentialsProvider() self.cloud_credential = provider.get_credentials() return - self.cloud_credential = self.get_credential(config) + self.cloud_credential = Client.get_credentials(config) + + def get_credential(self) -> models.CredentialModel: + """ + Get credential + @return: the whole credential + """ + return self.cloud_credential.get_credential() + + async def get_credential_async(self) -> models.CredentialModel: + """ + Get credential + @return: the whole credential + """ + return await self.cloud_credential.get_credential_async() @staticmethod - def get_credential(config): + def get_credentials(config): if config.type == ac.ACCESS_KEY: return credentials.AccessKeyCredential(config.access_key_id, config.access_key_secret) elif config.type == ac.STS: @@ -68,28 +82,36 @@ def get_credential(config): providers.OIDCRoleArnCredentialProvider(config=config)) return providers.DefaultCredentialsProvider().get_credentials() + @deprecated("Use 'get_credential().access_key_id' instead") def get_access_key_id(self): return self.cloud_credential.get_access_key_id() + @deprecated("Use 'get_credential().access_key_secret' instead") def get_access_key_secret(self): return self.cloud_credential.get_access_key_secret() + @deprecated("Use 'get_credential().security_token' instead") def get_security_token(self): return self.cloud_credential.get_security_token() + @deprecated("Use 'get_credential_async().access_key_id' instead") async def get_access_key_id_async(self): return await self.cloud_credential.get_access_key_id_async() + @deprecated("Use 'get_credential_async().access_key_secret' instead") async def get_access_key_secret_async(self): return await self.cloud_credential.get_access_key_secret_async() + @deprecated("Use 'get_credential_async().security_token' instead") async def get_security_token_async(self): return await self.cloud_credential.get_security_token_async() + @deprecated("Use 'get_credential().type' instead") @attribute_error_return_none def get_type(self): return self.cloud_credential.credential_type + @deprecated("Use 'get_credential().bearer_token' instead") @attribute_error_return_none def get_bearer_token(self): return self.cloud_credential.bearer_token diff --git a/alibabacloud_credentials/credentials.py b/alibabacloud_credentials/credentials.py index bac424a..956b36a 100644 --- a/alibabacloud_credentials/credentials.py +++ b/alibabacloud_credentials/credentials.py @@ -8,6 +8,8 @@ from alibabacloud_credentials.utils import auth_constant as ac from alibabacloud_credentials.exceptions import CredentialException +from alibabacloud_credentials.models import CredentialModel + class Credential: def get_access_key_id(self): @@ -28,6 +30,12 @@ async def get_access_key_secret_async(self): async def get_security_token_async(self): return + def get_credential(self): + return + + async def get_credential_async(self): + return + class _AutomaticallyRefreshCredentials: def __init__(self, expiration, provider): @@ -35,6 +43,8 @@ def __init__(self, expiration, provider): self.provider = provider def _with_should_refresh(self): + if self.expiration is None: + return True return int(time.mktime(time.localtime())) >= (self.expiration - 180) def _get_new_credential(self): @@ -47,6 +57,10 @@ def _refresh_credential(self): async def _get_new_credential_async(self): return await self.provider.get_credentials_async() + async def _refresh_credential_async(self): + if self._with_should_refresh(): + return await self._get_new_credential_async() + class AccessKeyCredential(Credential): """AccessKeyCredential""" @@ -68,6 +82,20 @@ async def get_access_key_id_async(self): async def get_access_key_secret_async(self): return self.access_key_secret + def get_credential(self): + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + type=ac.ACCESS_KEY + ) + + async def get_credential_async(self): + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + type=ac.ACCESS_KEY + ) + class BearerTokenCredential(Credential): """BearerTokenCredential""" @@ -76,6 +104,18 @@ def __init__(self, bearer_token): self.bearer_token = bearer_token self.credential_type = ac.BEARER + def get_credential(self): + return CredentialModel( + bearer_token=self.bearer_token, + type=ac.BEARER + ) + + async def get_credential_async(self): + return CredentialModel( + bearer_token=self.bearer_token, + type=ac.BEARER + ) + class EcsRamRoleCredential(Credential, _AutomaticallyRefreshCredentials): """EcsRamRoleCredential""" @@ -89,7 +129,6 @@ def __init__(self, access_key_id, access_key_secret, security_token, expiration, def _refresh_credential(self): credential = super()._refresh_credential() - if credential: self.access_key_id = credential.access_key_id self.access_key_secret = credential.access_key_secret @@ -97,10 +136,7 @@ def _refresh_credential(self): self.security_token = credential.security_token async def _refresh_credential_async(self): - credential = None - if self._with_should_refresh(): - credential = await self._get_new_credential_async() - + credential = await super()._refresh_credential_async() if credential: self.access_key_id = credential.access_key_id self.access_key_secret = credential.access_key_secret @@ -131,6 +167,24 @@ async def get_security_token_async(self): await self._refresh_credential_async() return self.security_token + def get_credential(self): + self._refresh_credential() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.ECS_RAM_ROLE + ) + + async def get_credential_async(self): + await self._refresh_credential_async() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.ECS_RAM_ROLE + ) + class RamRoleArnCredential(Credential, _AutomaticallyRefreshCredentials): """RamRoleArnCredential""" @@ -151,10 +205,7 @@ def _refresh_credential(self): self.security_token = credential.security_token async def _refresh_credential_async(self): - credential = None - if self._with_should_refresh(): - credential = await self._get_new_credential_async() - + credential = await super()._refresh_credential_async() if credential: self.access_key_id = credential.access_key_id self.access_key_secret = credential.access_key_secret @@ -185,6 +236,25 @@ async def get_security_token_async(self): await self._refresh_credential_async() return self.security_token + def get_credential(self): + self._refresh_credential() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.RAM_ROLE_ARN + ) + + async def get_credential_async(self): + await self._refresh_credential_async() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.RAM_ROLE_ARN + ) + + class OIDCRoleArnCredential(Credential, _AutomaticallyRefreshCredentials): """OIDCRoleArnCredential""" @@ -204,10 +274,7 @@ def _refresh_credential(self): self.security_token = credential.security_token async def _refresh_credential_async(self): - credential = None - if self._with_should_refresh(): - credential = await self._get_new_credential_async() - + credential = await super()._refresh_credential_async() if credential: self.access_key_id = credential.access_key_id self.access_key_secret = credential.access_key_secret @@ -238,8 +305,26 @@ async def get_security_token_async(self): await self._refresh_credential_async() return self.security_token + def get_credential(self): + self._refresh_credential() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.OIDC_ROLE_ARN + ) + + async def get_credential_async(self): + await self._refresh_credential_async() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.OIDC_ROLE_ARN + ) -class CredentialsURICredential(): + +class CredentialsURICredential(Credential): """CredentialsURICredential""" def __init__(self, credentials_uri): @@ -276,7 +361,8 @@ def _get_new_credential(self): tea_request.query[key] = value response = TeaCore.do_action(tea_request) if response.status_code != 200: - raise CredentialException("Get credentials from " + self.credentials_uri + " failed, HttpCode=" + str(response.status_code)) + raise CredentialException( + "Get credentials from " + self.credentials_uri + " failed, HttpCode=" + str(response.status_code)) body = response.body.decode('utf-8') dic = json.loads(body) @@ -287,7 +373,8 @@ def _get_new_credential(self): content_expiration = dic.get('Expiration') if content_code != "Success": - raise CredentialException("Get credentials from " + self.credentials_uri + " failed, Code is " + content_code) + raise CredentialException( + "Get credentials from " + self.credentials_uri + " failed, Code is " + content_code) # 先转换为时间数组 time_array = time.strptime(content_expiration, "%Y-%m-%dT%H:%M:%SZ") @@ -307,7 +394,8 @@ async def _get_new_credential_async(self): tea_request.query = parse_qs(r.query) response = await TeaCore.async_do_action(tea_request) if response.status_code != 200: - raise CredentialException("Get credentials from " + self.credentials_uri + " failed, HttpCode=" + str(response.status_code)) + raise CredentialException( + "Get credentials from " + self.credentials_uri + " failed, HttpCode=" + str(response.status_code)) body = response.body.decode('utf-8') dic = json.loads(body) @@ -318,7 +406,8 @@ async def _get_new_credential_async(self): content_expiration = dic.get('Expiration') if content_code != "Success": - raise CredentialException("Get credentials from " + self.credentials_uri + " failed, Code is " + content_code) + raise CredentialException( + "Get credentials from " + self.credentials_uri + " failed, Code is " + content_code) # 先转换为时间数组 time_array = time.strptime(content_expiration, "%Y-%m-%dT%H:%M:%SZ") @@ -353,6 +442,25 @@ async def get_security_token_async(self): await self._ensure_credential_async() return self.security_token + def get_credential(self): + self._ensure_credential() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.CREDENTIALS_URI + ) + + async def get_credential_async(self): + await self._ensure_credential_async() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.CREDENTIALS_URI + ) + + class RsaKeyPairCredential(Credential, _AutomaticallyRefreshCredentials): def __init__(self, access_key_id, access_key_secret, expiration, provider): super().__init__(expiration, provider) @@ -368,10 +476,7 @@ def _refresh_credential(self): self.expiration = credential.expiration async def _refresh_credential_async(self): - credential = None - if self._with_should_refresh(): - credential = await self._get_new_credential_async() - + credential = await super()._refresh_credential_async() if credential: self.access_key_id = credential.access_key_id self.access_key_secret = credential.access_key_secret @@ -394,6 +499,22 @@ async def get_access_key_secret_async(self): await self._refresh_credential_async() return self.access_key_secret + def get_credential(self): + self._refresh_credential() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + type=ac.RSA_KEY_PAIR + ) + + async def get_credential_async(self): + await self._refresh_credential_async() + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + type=ac.RSA_KEY_PAIR + ) + class StsCredential(Credential): def __init__(self, access_key_id, access_key_secret, security_token): @@ -419,3 +540,19 @@ async def get_access_key_secret_async(self): async def get_security_token_async(self): return self.security_token + + def get_credential(self): + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.STS + ) + + async def get_credential_async(self): + return CredentialModel( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + type=ac.STS + ) diff --git a/alibabacloud_credentials/models.py b/alibabacloud_credentials/models.py index 28e75ca..424f304 100644 --- a/alibabacloud_credentials/models.py +++ b/alibabacloud_credentials/models.py @@ -170,3 +170,59 @@ def from_map(self, m: dict = None): if m.get('credentialsUri') is not None: self.credentials_uri = m.get('credentials_uri') return self + + +class CredentialModel(TeaModel): + def __init__( + self, + access_key_id: str = None, + access_key_secret: str = None, + security_token: str = None, + bearer_token: str = None, + type: str = None, + ): + # accesskey id + self.access_key_id = access_key_id + # accesskey secret + self.access_key_secret = access_key_secret + # security token + self.security_token = security_token + # bearer token + self.bearer_token = bearer_token + # type + self.type = type + + def validate(self): + pass + + def to_map(self): + _map = super().to_map() + if _map is not None: + return _map + + result = dict() + if self.access_key_id is not None: + result['accessKeyId'] = self.access_key_id + if self.access_key_secret is not None: + result['accessKeySecret'] = self.access_key_secret + if self.security_token is not None: + result['securityToken'] = self.security_token + if self.bearer_token is not None: + result['bearerToken'] = self.bearer_token + if self.type is not None: + result['type'] = self.type + return result + + def from_map(self, m: dict = None): + m = m or dict() + if m.get('accessKeyId') is not None: + self.access_key_id = m.get('accessKeyId') + if m.get('accessKeySecret') is not None: + self.access_key_secret = m.get('accessKeySecret') + if m.get('securityToken') is not None: + self.security_token = m.get('securityToken') + if m.get('bearerToken') is not None: + self.bearer_token = m.get('bearerToken') + if m.get('type') is not None: + self.type = m.get('type') + return self diff --git a/setup.py b/setup.py index e986b5b..1980646 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ 'packages': find_packages(exclude=["tests*"]), 'platforms': 'any', 'python_requires': '>=3.6', - 'install_requires': ['alibabacloud-tea'], + 'install_requires': ['alibabacloud-tea>=0.3.9'], 'classifiers': ( 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', diff --git a/tests/test_client.py b/tests/test_client.py index 34630da..8aed346 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -40,23 +40,23 @@ def test_client_bearer(self): def test_client_ecs_ram_role(self): conf = Config(type='ecs_ram_role') - self.assertIsInstance(Client.get_credential(conf), credentials.EcsRamRoleCredential) + self.assertIsInstance(Client.get_credentials(conf), credentials.EcsRamRoleCredential) def test_client_credentials_uri(self): conf = Config(type='credentials_uri') - self.assertIsInstance(Client.get_credential(conf), credentials.CredentialsURICredential) + self.assertIsInstance(Client.get_credentials(conf), credentials.CredentialsURICredential) def test_client_ram_role_arn(self): conf = Config(type='ram_role_arn') - self.assertIsInstance(Client.get_credential(conf), credentials.RamRoleArnCredential) + self.assertIsInstance(Client.get_credentials(conf), credentials.RamRoleArnCredential) def test_client_oidc_role_arn(self): conf = Config(type='oidc_role_arn', oidc_token_file_path='oidc_token_file_path') - self.assertIsInstance(Client.get_credential(conf), credentials.OIDCRoleArnCredential) + self.assertIsInstance(Client.get_credentials(conf), credentials.OIDCRoleArnCredential) def test_client_rsa_key_pair(self): conf = Config(type='rsa_key_pair') - self.assertIsInstance(Client.get_credential(conf), credentials.RsaKeyPairCredential) + self.assertIsInstance(Client.get_credentials(conf), credentials.RsaKeyPairCredential) def test_async_call(self): conf = Config( diff --git a/tests/test_credentials.py b/tests/test_credentials.py index bd461d6..9c3c3d2 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -38,6 +38,12 @@ def test_EcsRamRoleCredential(self): provider ) + model = cred.get_credential() + self.assertEqual('access_key_id', model.access_key_id) + self.assertEqual('access_key_secret', model.access_key_secret) + self.assertEqual('security_token', model.security_token) + self.assertEqual(900000000000, cred.expiration) + self.assertEqual('access_key_id', cred.get_access_key_id()) self.assertEqual('access_key_secret', cred.get_access_key_secret()) self.assertEqual('security_token', cred.get_security_token()) @@ -52,7 +58,14 @@ def test_EcsRamRoleCredential(self): 100, self.TestEcsRamRoleProvider() ) + # refresh token + model = cred.get_credential() + self.assertEqual('accessKeyId', model.access_key_id) + self.assertEqual('accessKeySecret', model.access_key_secret) + self.assertEqual('securityToken', model.security_token) + self.assertEqual(100000000000, cred.expiration) + self.assertEqual('accessKeyId', cred.get_access_key_id()) self.assertEqual('accessKeySecret', cred.get_access_key_secret()) self.assertEqual('securityToken', cred.get_security_token()) @@ -65,6 +78,11 @@ def test_AccessKeyCredential(self): access_key_id=access_key_id, access_key_secret=access_key_secret ) + model = cred.get_credential() + self.assertEqual('access_key_id', model.access_key_id) + self.assertEqual('access_key_secret', model.access_key_secret) + self.assertEqual('access_key', model.type) + self.assertEqual('access_key_id', cred.access_key_id) self.assertEqual('access_key_secret', cred.access_key_secret) self.assertEqual('access_key', cred.credential_type) @@ -72,6 +90,10 @@ def test_AccessKeyCredential(self): def test_BearerTokenCredential(self): bearer_token = 'bearer_token' cred = credentials.BearerTokenCredential(bearer_token=bearer_token) + model = cred.get_credential() + self.assertEqual('bearer_token', model.bearer_token) + self.assertEqual('bearer', model.type) + self.assertEqual('bearer_token', cred.bearer_token) self.assertEqual('bearer', cred.credential_type) @@ -98,6 +120,13 @@ def test_RamRoleArnCredential(self): # refresh token self.assertTrue(cred._with_should_refresh()) + model = cred.get_credential() + self.assertEqual('accessKeyId', model.access_key_id) + self.assertEqual('accessKeySecret', model.access_key_secret) + self.assertEqual('securityToken', model.security_token) + self.assertEqual('ram_role_arn', model.type) + self.assertEqual(100000000000, cred.expiration) + self.assertEqual('accessKeyId', cred.get_access_key_id()) self.assertEqual('accessKeySecret', cred.access_key_secret) self.assertEqual('securityToken', cred.security_token) @@ -136,6 +165,13 @@ def test_OIDCRoleArnCredential(self): # refresh token self.assertTrue(cred._with_should_refresh()) + model = cred.get_credential() + self.assertEqual('accessKeyId', model.access_key_id) + self.assertEqual('accessKeySecret', model.access_key_secret) + self.assertEqual('securityToken', model.security_token) + self.assertEqual('oidc_role_arn', model.type) + self.assertEqual(100000000000, cred.expiration) + self.assertEqual('accessKeyId', cred.get_access_key_id()) self.assertEqual('accessKeySecret', cred.access_key_secret) self.assertEqual('securityToken', cred.security_token) @@ -171,6 +207,13 @@ def test_RsaKeyPairCredential(self): ) # refresh token + + model = cred.get_credential() + self.assertEqual('accessKeyId', model.access_key_id) + self.assertEqual('accessKeySecret', model.access_key_secret) + self.assertEqual('rsa_key_pair', model.type) + self.assertEqual(100000000000, cred.expiration) + self.assertEqual('accessKeyId', cred.get_access_key_id()) self.assertEqual('accessKeySecret', cred.access_key_secret) self.assertEqual(100000000000, cred.expiration) @@ -196,3 +239,9 @@ def test_StsCredential(self): self.assertEqual('access_key_secret', cred.access_key_secret) self.assertEqual('security_token', cred.security_token) self.assertEqual('sts', cred.credential_type) + + model = cred.get_credential() + self.assertEqual('access_key_id', model.access_key_id) + self.assertEqual('access_key_secret', model.access_key_secret) + self.assertEqual('security_token', model.security_token) + self.assertEqual('sts', model.type) diff --git a/tests/test_model.py b/tests/test_model.py index f52ebfc..e01743a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,5 @@ import unittest -from alibabacloud_credentials.models import Config +from alibabacloud_credentials.models import Config, CredentialModel class TestModel(unittest.TestCase): @@ -21,3 +21,39 @@ def test_model_config(self): ) self.assertEqual('access_key_id', conf2.access_key_id) self.assertEqual('access_key_secret', conf2.access_key_secret) + + def test_model_credential(self): + cred = CredentialModel() + self.assertIsNone(cred.access_key_id) + self.assertIsNone(cred.access_key_secret) + self.assertIsNone(cred.security_token) + self.assertIsNone(cred.bearer_token) + self.assertIsNone(cred.type) + + cred = CredentialModel( + access_key_id='access_key_id', + access_key_secret='access_key_secret', + security_token='security_token', + bearer_token='bearer_token', + type='type', + ) + self.assertEqual('access_key_id', cred.access_key_id) + self.assertEqual('access_key_secret', cred.access_key_secret) + self.assertEqual('security_token', cred.security_token) + self.assertEqual('bearer_token', cred.bearer_token) + self.assertEqual('type', cred.type) + + cred_map = cred.to_map() + self.assertEqual('access_key_id', cred_map['accessKeyId']) + self.assertEqual('access_key_secret', cred_map['accessKeySecret']) + self.assertEqual('security_token', cred_map['securityToken']) + self.assertEqual('bearer_token', cred_map['bearerToken']) + self.assertEqual('type', cred_map['type']) + + cred = CredentialModel() + cred.from_map(cred_map) + self.assertEqual('access_key_id', cred.access_key_id) + self.assertEqual('access_key_secret', cred.access_key_secret) + self.assertEqual('security_token', cred.security_token) + self.assertEqual('bearer_token', cred.bearer_token) + self.assertEqual('type', cred.type) diff --git a/tests/test_providers.py b/tests/test_providers.py index 3e648fc..fabc788 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -375,4 +375,4 @@ def test_EnvironmentVariableCredentialsProvider(self): auth_util.environment_access_key_id = None auth_util.environment_access_key_secret = None - auth_util.environment_security_token = None \ No newline at end of file + auth_util.environment_security_token = None