From cbed501b177f3060008bbc9281a8f615a32a27be Mon Sep 17 00:00:00 2001 From: 824750130 Date: Thu, 28 Jan 2021 11:28:43 +0800 Subject: [PATCH] Support async refresh credentials --- alibabacloud_credentials/client.py | 65 +++++----- alibabacloud_credentials/credentials.py | 116 +++++++++++++++++- alibabacloud_credentials/providers.py | 154 ++++++++++++++++++++---- tests/test_client.py | 13 +- tests/test_credentials.py | 4 +- tests/test_exceptions.py | 2 +- tests/test_model.py | 2 +- tests/test_providers.py | 86 ++++++++++++- tests/test_util.py | 2 +- 9 files changed, 368 insertions(+), 76 deletions(-) diff --git a/alibabacloud_credentials/client.py b/alibabacloud_credentials/client.py index 1901c0c..96d035d 100644 --- a/alibabacloud_credentials/client.py +++ b/alibabacloud_credentials/client.py @@ -25,38 +25,57 @@ def __init__(self, config=None): self.cloud_credential = provider.get_credentials() self.cloud_credential = self.get_credential(config) - def get_credential(self, config): + @staticmethod + def get_credential(config): if config.type == ac.ACCESS_KEY: return credentials.AccessKeyCredential(config.access_key_id, config.access_key_secret) elif config.type == ac.STS: return credentials.StsCredential(config.access_key_id, config.access_key_secret, config.security_token) elif config.type == ac.BEARER: return credentials.BearerTokenCredential(config.bearer_token) - else: - return self.get_provider(config).get_credentials() - - @staticmethod - def get_provider(config): - if config.type == ac.ECS_RAM_ROLE: - return providers.EcsRamRoleCredentialProvider(config=config) + elif config.type == ac.ECS_RAM_ROLE: + return credentials.EcsRamRoleCredential( + config.access_key_id, + config.access_key_secret, + config.security_token, + 0, + providers.EcsRamRoleCredentialProvider(config=config) + ) elif config.type == ac.RAM_ROLE_ARN: - return providers.RamRoleArnCredentialProvider(config=config) + return credentials.RamRoleArnCredential( + config.access_key_id, + config.access_key_secret, + config.security_token, + 0, + providers.RamRoleArnCredentialProvider(config=config) + ) elif config.type == ac.RSA_KEY_PAIR: - return providers.RsaKeyPairCredentialProvider(config=config) - return providers.DefaultCredentialsProvider() + return credentials.RsaKeyPairCredential( + config.access_key_id, + config.access_key_secret, + 0, + providers.RsaKeyPairCredentialProvider(config=config) + ) + return providers.DefaultCredentialsProvider().get_credentials() - @attribute_error_return_none def get_access_key_id(self): return self.cloud_credential.get_access_key_id() - @attribute_error_return_none def get_access_key_secret(self): return self.cloud_credential.get_access_key_secret() - @attribute_error_return_none def get_security_token(self): return self.cloud_credential.get_security_token() + async def get_access_key_id_async(self): + return await self.cloud_credential.get_access_key_id_async() + + async def get_access_key_secret_async(self): + return await self.cloud_credential.get_access_key_secret_async() + + async def get_security_token_async(self): + return await self.cloud_credential.get_security_token_async() + @attribute_error_return_none def get_type(self): return self.cloud_credential.credential_type @@ -64,21 +83,3 @@ def get_type(self): @attribute_error_return_none def get_bearer_token(self): return self.cloud_credential.bearer_token - - async def get_access_key_id_async(self): - if hasattr(self.cloud_credential, 'get_access_key_id'): - return self.cloud_credential.get_access_key_id() - else: - return - - async def get_access_key_secret_async(self): - if hasattr(self.cloud_credential, 'get_access_key_secret'): - return self.cloud_credential.get_access_key_secret() - else: - return - - async def get_security_token_async(self): - if hasattr(self.cloud_credential, 'get_security_token'): - return self.cloud_credential.get_security_token() - else: - return diff --git a/alibabacloud_credentials/credentials.py b/alibabacloud_credentials/credentials.py index 256b315..4523c04 100644 --- a/alibabacloud_credentials/credentials.py +++ b/alibabacloud_credentials/credentials.py @@ -3,6 +3,26 @@ from alibabacloud_credentials.utils import auth_constant as ac +class Credential: + def get_access_key_id(self): + return + + def get_access_key_secret(self): + return + + def get_security_token(self): + return + + async def get_access_key_id_async(self): + return + + async def get_access_key_secret_async(self): + return + + async def get_security_token_async(self): + return + + class _AutomaticallyRefreshCredentials: def __init__(self, expiration, provider): self.expiration = expiration @@ -18,8 +38,11 @@ def _refresh_credential(self): if self._with_should_refresh(): return self._get_new_credential() + async def _get_new_credential_async(self): + return await self.provider.get_credentials_async() -class AccessKeyCredential: + +class AccessKeyCredential(Credential): """AccessKeyCredential""" def __init__(self, access_key_id, access_key_secret): @@ -33,8 +56,14 @@ def get_access_key_id(self): def get_access_key_secret(self): return self.access_key_secret + async def get_access_key_id_async(self): + return self.access_key_id + + async def get_access_key_secret_async(self): + return self.access_key_secret + -class BearerTokenCredential: +class BearerTokenCredential(Credential): """BearerTokenCredential""" def __init__(self, bearer_token): @@ -42,7 +71,7 @@ def __init__(self, bearer_token): self.credential_type = ac.BEARER -class EcsRamRoleCredential(_AutomaticallyRefreshCredentials): +class EcsRamRoleCredential(Credential, _AutomaticallyRefreshCredentials): """EcsRamRoleCredential""" def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider): @@ -54,6 +83,18 @@ def __init__(self, access_key_id, access_key_secret, security_token, expiration, def _refresh_credential(self): credential = super()._refresh_credential() + + if credential: + self.access_key_id = credential.access_key_id + self.access_key_secret = credential.access_key_secret + self.expiration = credential.expiration + self.security_token = credential.security_token + + async def _refresh_credential_async(self): + credential = None + if self._with_should_refresh(): + credential = await self._get_new_credential_async() + if credential: self.access_key_id = credential.access_key_id self.access_key_secret = credential.access_key_secret @@ -72,8 +113,20 @@ def get_security_token(self): self._refresh_credential() return self.security_token + async def get_access_key_id_async(self): + await self._refresh_credential_async() + return self.access_key_id + + async def get_access_key_secret_async(self): + await self._refresh_credential_async() + return self.access_key_secret -class RamRoleArnCredential(_AutomaticallyRefreshCredentials): + async def get_security_token_async(self): + await self._refresh_credential_async() + return self.security_token + + +class RamRoleArnCredential(Credential, _AutomaticallyRefreshCredentials): """RamRoleArnCredential""" def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider): @@ -91,6 +144,17 @@ def _refresh_credential(self): self.expiration = credential.expiration self.security_token = credential.security_token + async def _refresh_credential_async(self): + credential = None + if self._with_should_refresh(): + credential = await self._get_new_credential_async() + + if credential: + self.access_key_id = credential.access_key_id + self.access_key_secret = credential.access_key_secret + self.expiration = credential.expiration + self.security_token = credential.security_token + def get_access_key_id(self): self._refresh_credential() return self.access_key_id @@ -103,8 +167,20 @@ def get_security_token(self): self._refresh_credential() return self.security_token + async def get_access_key_id_async(self): + await self._refresh_credential_async() + return self.access_key_id + + async def get_access_key_secret_async(self): + await self._refresh_credential_async() + return self.access_key_secret + + async def get_security_token_async(self): + await self._refresh_credential_async() + return self.security_token + -class RsaKeyPairCredential(_AutomaticallyRefreshCredentials): +class RsaKeyPairCredential(Credential, _AutomaticallyRefreshCredentials): def __init__(self, access_key_id, access_key_secret, expiration, provider): super().__init__(expiration, provider) self.access_key_id = access_key_id @@ -118,6 +194,17 @@ def _refresh_credential(self): self.access_key_secret = credential.access_key_secret self.expiration = credential.expiration + async def _refresh_credential_async(self): + credential = None + if self._with_should_refresh(): + credential = await self._get_new_credential_async() + + if credential: + self.access_key_id = credential.access_key_id + self.access_key_secret = credential.access_key_secret + self.expiration = credential.expiration + self.security_token = credential.security_token + def get_access_key_id(self): self._refresh_credential() return self.access_key_id @@ -126,8 +213,16 @@ def get_access_key_secret(self): self._refresh_credential() return self.access_key_secret + async def get_access_key_id_async(self): + await self._refresh_credential_async() + return self.access_key_id + + async def get_access_key_secret_async(self): + await self._refresh_credential_async() + return self.access_key_secret -class StsCredential: + +class StsCredential(Credential): def __init__(self, access_key_id, access_key_secret, security_token): self.access_key_id = access_key_id self.access_key_secret = access_key_secret @@ -142,3 +237,12 @@ def get_access_key_secret(self): def get_security_token(self): return self.security_token + + async def get_access_key_id_async(self): + return self.access_key_id + + async def get_access_key_secret_async(self): + return self.access_key_secret + + async def get_security_token_async(self): + return self.security_token diff --git a/alibabacloud_credentials/providers.py b/alibabacloud_credentials/providers.py index 70a124c..47ccf38 100644 --- a/alibabacloud_credentials/providers.py +++ b/alibabacloud_credentials/providers.py @@ -97,19 +97,30 @@ def __init__(self, role_name=None, config=None): self.__ecs_metadata_fetch_error_msg = "Failed to get RAM session credentials from ECS metadata service." self.__metadata_service_host = "100.100.100.200" self._set_arg('role_name', role_name) - self._set_credential_url() def _get_role_name(self, url=None): - url = url if url else self.credential_url + url = url if url else f'http://{self.__metadata_service_host}{self.__url_in_ecs_metadata}' response = requests.get(url, timeout=self.timeout / 1000) if response.status_code != 200: raise CredentialException(self.__ecs_metadata_fetch_error_msg + " HttpCode=" + str(response.status_code)) response.encoding = 'utf-8' self.role_name = response.text + async def _get_role_name_async(self, url=None): + tea_request = TeaRequest() + tea_request.headers['host'] = url if url else self.__metadata_service_host + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + response = await TeaCore.async_do_action(tea_request) + if response.status_code != 200: + raise CredentialException(self.__ecs_metadata_fetch_error_msg + " HttpCode=" + str(response.status_code)) + self.role_name = response.body.decode('utf-8') + def _create_credential(self, url=None): tea_request = TeaRequest() - tea_request.headers['host'] = url if url else self.credential_url + tea_request.headers['host'] = url if url else self.__metadata_service_host + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + self.role_name # request response = TeaCore.do_action(tea_request) @@ -136,11 +147,41 @@ def _create_credential(self, url=None): def get_credentials(self): if self.role_name == "": self._get_role_name() - self._set_credential_url() return self._create_credential() - def _set_credential_url(self): - self.credential_url = "http://" + self.__metadata_service_host + self.__url_in_ecs_metadata + self.role_name + async def _create_credential_async(self, url=None): + tea_request = TeaRequest() + tea_request.headers['host'] = url if url else self.__metadata_service_host + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + self.role_name + + # request + response = await TeaCore.async_do_action(tea_request) + + if response.status_code != 200: + raise CredentialException(self.__ecs_metadata_fetch_error_msg + " HttpCode=" + str(response.status_code)) + + dic = json.loads(response.body.decode('utf-8')) + content_code = dic.get('Code') + content_access_key_id = dic.get('AccessKeyId') + content_access_key_secret = dic.get('AccessKeySecret') + content_security_token = dic.get('SecurityToken') + content_expiration = dic.get('Expiration') + + if content_code != "Success": + raise CredentialException(self.__ecs_metadata_fetch_error_msg) + + # 先转换为时间数组 + time_array = time.strptime(content_expiration, "%Y-%m-%dT%H:%M:%SZ") + # 转换为时间戳 + time_stamp = calendar.timegm(time_array) + return credentials.EcsRamRoleCredential(content_access_key_id, content_access_key_secret, + content_security_token, time_stamp, self) + + async def get_credentials_async(self): + if self.role_name == "": + await self._get_role_name_async() + return await self._create_credential_async() class RamRoleArnCredentialProvider(AlibabaCloudCredentialsProvider): @@ -179,7 +220,8 @@ def _create_credentials(self, turl=None): string_to_sign = ph.compose_string_to_sign("GET", tea_request.query) signature = ph.sign_string(string_to_sign, self.access_key_secret + "&") tea_request.query["Signature"] = signature - tea_request.headers['host'] = turl if turl else 'https://sts.aliyuncs.com' + tea_request.protocol = 'https' + tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' # request response = TeaCore.do_action(tea_request) if response.status_code == 200: @@ -194,6 +236,43 @@ def _create_credentials(self, turl=None): cre.get("SecurityToken"), expiration, self) raise CredentialException(response.body.decode('utf-8')) + async def get_credentials_async(self): + return await self._create_credentials_async() + + async def _create_credentials_async(self, turl=None): + # 获取credential 先实现签名用工具类 + tea_request = TeaRequest() + tea_request.query = { + 'Action': 'AssumeRole', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self.duration_seconds), + 'RoleArn': self.role_arn, + 'AccessKeyId': self.access_key_id, + 'RegionId': self.region_id, + 'RoleSessionName': self.role_session_name + } + if self.policy is not None: + tea_request.query["Policy"] = self.policy + string_to_sign = ph.compose_string_to_sign("GET", tea_request.query) + signature = ph.sign_string(string_to_sign, self.access_key_secret + "&") + tea_request.query["Signature"] = signature + tea_request.protocol = 'https' + tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' + # request + response = await TeaCore.async_do_action(tea_request) + if response.status_code == 200: + dic = json.loads(response.body.decode('utf-8')) + if "Credentials" in dic: + cre = dic.get("Credentials") + # 先转换为时间数组 + time_array = time.strptime(cre.get("Expiration"), "%Y-%m-%dT%H:%M:%SZ") + # 转换为时间戳 + expiration = calendar.timegm(time_array) + return credentials.RamRoleArnCredential(cre.get("AccessKeyId"), cre.get("AccessKeySecret"), + cre.get("SecurityToken"), expiration, self) + raise CredentialException(response.body.decode('utf-8')) + class RsaKeyPairCredentialProvider(AlibabaCloudCredentialsProvider): @@ -204,6 +283,37 @@ def __init__(self, access_key_id=None, access_key_secret=None, region_id=None, c self._set_arg('access_key_secret', access_key_secret) self._set_arg('region_id', region_id) + async def get_credentials_async(self): + return await self._create_credential_async() + + async def _create_credential_async(self, turl=None): + tea_request = TeaRequest() + tea_request.query = { + 'Action': 'GenerateSessionAccessKey', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self.duration_seconds), + 'AccessKeyId': self.access_key_id, + 'RegionId': self.region_id, + } + + str_to_sign = ph.compose_string_to_sign('GET', tea_request.query) + signature = ph.sign_string(str_to_sign, self.access_key_id + '&') + tea_request.query['Signature'] = signature + tea_request.protocol = 'https' + tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' + # request + response = await TeaCore.async_do_action(tea_request) + if response.status_code == 200: + dic = json.loads(response.body.decode('utf-8')) + if "SessionAccessKey" in dic: + cre = dic.get("SessionAccessKey") + time_array = time.strptime(cre.get("Expiration"), "%Y-%m-%dT%H:%M:%SZ") + expiration = calendar.timegm(time_array) + return credentials.RsaKeyPairCredential(cre.get("SessionAccessKeyId"), cre.get("SessionAccessKeySecret"), + expiration, self) + raise CredentialException(response.body.decode('utf-8')) + def get_credentials(self): return self._create_credential() @@ -221,7 +331,8 @@ def _create_credential(self, turl=None): str_to_sign = ph.compose_string_to_sign('GET', tea_request.query) signature = ph.sign_string(str_to_sign, self.access_key_id + '&') tea_request.query['Signature'] = signature - tea_request.headers['host'] = turl if turl else 'https://sts.aliyuncs.com' + tea_request.protocol = 'https' + tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' # request response = TeaCore.do_action(tea_request) if response.status_code == 200: @@ -240,7 +351,7 @@ def __init__(self, path=None): super().__init__() self._set_arg('file_path', path) - def get_credentials(self): + def parse_ini(self): file_path = self.file_path if self.file_path else au.environment_credentials_file if file_path is None: file_path = ac.DEFAULT_CREDENTIALS_FILE_PATH @@ -260,6 +371,10 @@ def get_credentials(self): option[key] = value.strip() ini_map[k] = option client_config = ini_map.get(au.client_type) + return client_config + + def get_credentials(self): + client_config = self.parse_ini() if client_config is None: return return self._create_credential(client_config) @@ -269,11 +384,11 @@ def _create_credential(self, config): if not config_type: raise CredentialException("The configured client type is empty") elif ac.INI_TYPE_ARN == config_type: - return self._get_sts_assume_role_session_credentials(config) + return self._get_sts_assume_role_session_provider(config).get_credentials() elif ac.INI_TYPE_KEY_PAIR == config_type: - return self._get_sts_get_session_access_key_credentials(config) + return self._get_sts_get_session_access_key_provider(config).get_credentials() elif ac.INI_TYPE_RAM == config_type: - return self._get_instance_profile_credentials(config) + return self._get_instance_profile_provider(config).get_credentials() access_key_id = config.get(ac.INI_ACCESS_KEY_ID) access_key_secret = config.get(ac.INI_ACCESS_KEY_IDSECRET) @@ -282,7 +397,7 @@ def _create_credential(self, config): return credentials.AccessKeyCredential(access_key_id, access_key_secret) @staticmethod - def _get_sts_assume_role_session_credentials(config): + def _get_sts_assume_role_session_provider(config): access_key_id = config.get(ac.INI_ACCESS_KEY_ID) access_key_secret = config.get(ac.INI_ACCESS_KEY_IDSECRET) role_session_name = config.get(ac.INI_ROLE_SESSION_NAME) @@ -294,13 +409,12 @@ def _get_sts_assume_role_session_credentials(config): raise CredentialException("The configured access_key_id or access_key_secret is empty") if not role_session_name or not role_arn: raise CredentialException("The configured role_session_name or role_arn is empty") - provider = RamRoleArnCredentialProvider( + return RamRoleArnCredentialProvider( access_key_id, access_key_secret, role_session_name, role_arn, region_id, policy ) - return provider.get_credentials() @staticmethod - def _get_sts_get_session_access_key_credentials(config): + def _get_sts_get_session_access_key_provider(config): public_key_id = config.get(ac.INI_PUBLIC_KEY_ID) private_key_file = config.get(ac.INI_PRIVATE_KEY_FILE) if not private_key_file: @@ -309,16 +423,14 @@ def _get_sts_get_session_access_key_credentials(config): if not public_key_id or not private_key: raise CredentialException("The configured public_key_id or private_key_file content is empty") - provider = RsaKeyPairCredentialProvider(public_key_id, private_key) - return provider.get_credentials() + return RsaKeyPairCredentialProvider(public_key_id, private_key) @staticmethod - def _get_instance_profile_credentials(config): + def _get_instance_profile_provider(config): role_name = config.get(ac.INI_ROLE_NAME) if not role_name: raise CredentialException("The configured role_name is empty") - provider = EcsRamRoleCredentialProvider(role_name) - return provider.get_credentials() + return EcsRamRoleCredentialProvider(role_name) class EnvironmentVariableCredentialsProvider(AlibabaCloudCredentialsProvider): diff --git a/tests/test_client.py b/tests/test_client.py index 7a72ccd..88762a2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ from alibabacloud_credentials import providers -class TestCredentials(unittest.TestCase): +class TestClient(unittest.TestCase): def test_ak_client(self): conf = Config() conf.type = auth_constant.ACCESS_KEY @@ -23,7 +23,7 @@ def test_ak_client(self): cred = Client() cred.get_access_key_id() except Exception as e: - self.assertEqual(str(e), 'not found credentials') + self.assertEqual('not found credentials', str(e)) conf = Config(type='sts') cred = Client(conf) @@ -34,16 +34,13 @@ def test_ak_client(self): self.assertIsInstance(cred.cloud_credential, credentials.BearerTokenCredential) conf = Config(type='ecs_ram_role') - self.assertIsInstance(Client.get_provider(conf), providers.EcsRamRoleCredentialProvider) + self.assertIsInstance(Client.get_credential(conf), credentials.EcsRamRoleCredential) conf = Config(type='ram_role_arn') - self.assertIsInstance(Client.get_provider(conf), providers.RamRoleArnCredentialProvider) + self.assertIsInstance(Client.get_credential(conf), credentials.RamRoleArnCredential) conf = Config(type='rsa_key_pair') - self.assertIsInstance(Client.get_provider(conf), providers.RsaKeyPairCredentialProvider) - - conf = Config(type='test') - self.assertIsInstance(Client.get_provider(conf), providers.DefaultCredentialsProvider) + self.assertIsInstance(Client.get_credential(conf), credentials.RsaKeyPairCredential) conf = Config( access_key_id='ak1', diff --git a/tests/test_credentials.py b/tests/test_credentials.py index b07529f..d4a8124 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -49,8 +49,8 @@ def test_EcsRamRoleCredential(self): ) # refresh token self.assertEqual('accessKeyId', cred.get_access_key_id()) - self.assertEqual('accessKeySecret', cred.access_key_secret) - self.assertEqual('securityToken', cred.security_token) + self.assertEqual('accessKeySecret', cred.get_access_key_secret()) + self.assertEqual('securityToken', cred.get_security_token()) self.assertEqual(100000000000, cred.expiration) def test_AccessKeyCredential(self): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b676b0e..8bf6d32 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,7 +2,7 @@ from alibabacloud_credentials.exceptions import CredentialException -class TestCredentials(unittest.TestCase): +class TestException(unittest.TestCase): def test_CredentialException(self): try: raise CredentialException('error', 1000, 123456789) diff --git a/tests/test_model.py b/tests/test_model.py index 9c9a403..f3fdd15 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,7 +2,7 @@ from alibabacloud_credentials.models import Config -class TestCredentials(unittest.TestCase): +class TestModel(unittest.TestCase): def test_model_config(self): conf1 = Config() self.assertEqual('', conf1.access_key_id) diff --git a/tests/test_providers.py b/tests/test_providers.py index 3d5c6b2..296ebfc 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -2,6 +2,7 @@ import json import time import requests +import asyncio from alibabacloud_credentials.credentials import AccessKeyCredential from alibabacloud_credentials import providers, models, credentials, exceptions @@ -11,6 +12,8 @@ import threading from http.server import HTTPServer, BaseHTTPRequestHandler +loop = asyncio.get_event_loop() + class Request(BaseHTTPRequestHandler): def do_GET(self): @@ -52,14 +55,33 @@ def test_EcsRamRoleCredentialProvider(self): self.assertIsNotNone(prov) self.assertEqual("roleNameConfig", prov.role_name) self.assertEqual(2300, prov.timeout) - cred = prov._create_credential(url='http://127.0.0.1:8888') + cred = prov._create_credential(url='127.0.0.1:8888') self.assertEqual('ak', cred.access_key_id) prov._get_role_name(url='http://127.0.0.1:8888') self.assertIsNotNone(prov.role_name) - prov.role_name = 'role_name' - prov._set_credential_url() - self.assertEqual('http://100.100.100.200/latest/meta-data/ram/security-credentials/role_name', prov.credential_url) + + def test_EcsRamRoleCredentialProvider_async(self): + async def main(): + prov = providers.EcsRamRoleCredentialProvider("roleName") + self.assertIsNotNone(prov) + self.assertEqual("roleName", prov.role_name) + + cfg = models.Config() + cfg.role_name = "roleNameConfig" + cfg.timeout = 1100 + cfg.connect_timeout = 1200 + prov = providers.EcsRamRoleCredentialProvider(config=cfg) + self.assertIsNotNone(prov) + self.assertEqual("roleNameConfig", prov.role_name) + self.assertEqual(2300, prov.timeout) + cred = await prov._create_credential_async(url='127.0.0.1:8888') + self.assertEqual('ak', cred.access_key_id) + + await prov._get_role_name_async(url='127.0.0.1:8888') + self.assertIsNotNone(prov.role_name) + + loop.run_until_complete(main()) def test_DefaultCredentialsProvider(self): prov = providers.DefaultCredentialsProvider() @@ -123,6 +145,38 @@ def test_RamRoleArnCredentialProvider(self): cred = prov._create_credentials(turl='http://127.0.0.1:8888') self.assertEqual('AccessKeyId', cred.access_key_id) + def test_RamRoleArnCredentialProvider_async(self): + async def main(): + access_key_id, access_key_secret, role_session_name, role_arn, region_id, policy = \ + 'access_key_id', 'access_key_secret', 'role_session_name', 'role_arn', 'region_id', 'policy' + prov = providers.RamRoleArnCredentialProvider( + access_key_id, access_key_secret, role_session_name, role_arn, region_id, policy + ) + self.assertEqual('access_key_id', prov.access_key_id) + self.assertEqual('access_key_secret', prov.access_key_secret) + self.assertEqual('role_session_name', prov.role_session_name) + self.assertEqual('role_arn', prov.role_arn) + self.assertEqual('region_id', prov.region_id) + self.assertEqual('policy', prov.policy) + + conf = models.Config( + access_key_id=access_key_id, + access_key_secret=access_key_secret, + role_session_name=role_session_name, + role_arn=role_arn + ) + prov = providers.RamRoleArnCredentialProvider(config=conf) + self.assertEqual('access_key_id', prov.access_key_id) + self.assertEqual('access_key_secret', prov.access_key_secret) + self.assertEqual('role_session_name', prov.role_session_name) + self.assertEqual('role_arn', prov.role_arn) + self.assertEqual('cn-hangzhou', prov.region_id) + self.assertIsNone(prov.policy) + + cred = await prov._create_credentials_async(turl='http://127.0.0.1:8888') + self.assertEqual('AccessKeyId', cred.access_key_id) + loop.run_until_complete(main()) + def test_RsaKeyPairCredentialProvider(self): access_key_id, access_key_secret, region_id = \ 'access_key_id', 'access_key_secret', 'region_id' @@ -145,6 +199,30 @@ def test_RsaKeyPairCredentialProvider(self): cred = prov._create_credential(turl='http://127.0.0.1:8888') self.assertEqual('SessionAccessKeyId', cred.access_key_id) + def test_RsaKeyPairCredentialProvider_async(self): + async def main(): + access_key_id, access_key_secret, region_id = \ + 'access_key_id', 'access_key_secret', 'region_id' + prov = providers.RsaKeyPairCredentialProvider( + access_key_id, access_key_secret, region_id + ) + self.assertEqual('access_key_id', prov.access_key_id) + self.assertEqual('access_key_secret', prov.access_key_secret) + self.assertEqual('region_id', prov.region_id) + + conf = models.Config( + access_key_id=access_key_id, + access_key_secret=access_key_secret + ) + prov = providers.RsaKeyPairCredentialProvider(config=conf) + self.assertEqual('access_key_id', prov.access_key_id) + self.assertEqual('access_key_secret', prov.access_key_secret) + self.assertEqual('cn-hangzhou', prov.region_id) + + cred = await prov._create_credential_async(turl='http://127.0.0.1:8888') + self.assertEqual('SessionAccessKeyId', cred.access_key_id) + loop.run_until_complete(main()) + def test_ProfileCredentialsProvider(self): prov = providers.ProfileCredentialsProvider(ini_file) auth_util.client_type = 'default' diff --git a/tests/test_util.py b/tests/test_util.py index 40d959d..e965fce 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,7 +4,7 @@ from . import txt_file -class TestCredentials(unittest.TestCase): +class TestUtil(unittest.TestCase): def test_get_private_key(self): key = auth_util.get_private_key(txt_file)