From fca3c6bd8f7f3ffb0910520f933a1f599f219bd4 Mon Sep 17 00:00:00 2001 From: yndu13 Date: Mon, 8 Jul 2024 21:07:09 +0800 Subject: [PATCH] feat: support region or endpoint for sts requests --- alibabacloud_credentials/models.py | 9 ++++- alibabacloud_credentials/providers.py | 42 ++++++++++++--------- alibabacloud_credentials/utils/auth_util.py | 2 + tests/test_model.py | 1 + tests/test_providers.py | 40 ++++++++++++++------ 5 files changed, 65 insertions(+), 29 deletions(-) diff --git a/alibabacloud_credentials/models.py b/alibabacloud_credentials/models.py index 5536df5..5a91722 100644 --- a/alibabacloud_credentials/models.py +++ b/alibabacloud_credentials/models.py @@ -31,7 +31,8 @@ def __init__( proxy: str = '', credentials_uri: str = '', enable_imds_v2: bool = False, - metadata_token_duration: int = 21600 + metadata_token_duration: int = 21600, + sts_endpoint: str = None ): # accesskey id self.access_key_id = access_key_id @@ -71,6 +72,8 @@ def __init__( self.proxy = proxy # credentials uri self.credentials_uri = credentials_uri + # STS Endpoint + self.sts_endpoint = sts_endpoint def validate(self): pass @@ -121,6 +124,8 @@ def to_map(self): result['proxy'] = self.proxy if self.credentials_uri is not None: result['credentialsUri'] = self.credentials_uri + if self.sts_endpoint is not None: + result['stsEndpoint'] = self.sts_endpoint return result def from_map(self, m: dict = None): @@ -169,6 +174,8 @@ def from_map(self, m: dict = None): self.proxy = m.get('proxy') if m.get('credentialsUri') is not None: self.credentials_uri = m.get('credentials_uri') + if m.get('stsEndpoint') is not None: + self.sts_endpoint = m.get('stsEndpoint') return self diff --git a/alibabacloud_credentials/providers.py b/alibabacloud_credentials/providers.py index 5f5ee9f..5f8ee36 100644 --- a/alibabacloud_credentials/providers.py +++ b/alibabacloud_credentials/providers.py @@ -19,7 +19,6 @@ class AlibabaCloudCredentialsProvider: """BaseProvider class""" duration_seconds = 3600 timeout = 2000 - region_id = 'cn-hangzhou' def __init__(self, config=None): if isinstance(config, Config): @@ -41,6 +40,7 @@ def __init__(self, config=None): self.timeout = config.timeout + config.connect_timeout self.connect_timeout = config.connect_timeout self.proxy = config.proxy + self.sts_endpoint = config.sts_endpoint def _set_arg(self, key, value): if value is not None: @@ -267,11 +267,18 @@ def __init__(self, access_key_id=None, access_key_secret=None, role_session_name self._set_arg('region_id', region_id) self._set_arg('role_session_name', role_session_name) self._set_arg('policy', policy) + if region_id is None and au.environment_sts_region is not None: + self._set_arg('region_id', au.environment_sts_region) + if self.region_id is not None: + self._set_arg('sts_endpoint', f'sts.{self.region_id}.aliyuncs.com') + else: + self._set_arg('sts_endpoint', + 'sts.aliyuncs.com' if config is None or config.sts_endpoint is None else config.sts_endpoint) def get_credentials(self): return self._create_credentials() - def _create_credentials(self, turl=None): + def _create_credentials(self): # 获取credential 先实现签名用工具类 tea_request = ph.get_new_request() tea_request.query = { @@ -281,7 +288,6 @@ def _create_credentials(self, turl=None): 'DurationSeconds': str(self.duration_seconds), 'RoleArn': self.role_arn, 'AccessKeyId': self.access_key_id, - 'RegionId': self.region_id, 'RoleSessionName': self.role_session_name, 'SignatureMethod': 'HMAC-SHA1', 'SignatureVersion': '1.0' @@ -294,7 +300,7 @@ def _create_credentials(self, turl=None): signature = ph.sign_string(string_to_sign, self.access_key_secret + "&") tea_request.query["Signature"] = signature tea_request.protocol = 'https' - tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' + tea_request.headers['host'] = self.sts_endpoint # request response = TeaCore.do_action(tea_request) if response.status_code == 200: @@ -312,7 +318,7 @@ def _create_credentials(self, turl=None): async def get_credentials_async(self): return await self._create_credentials_async() - async def _create_credentials_async(self, turl=None): + async def _create_credentials_async(self): # 获取credential 先实现签名用工具类 tea_request = ph.get_new_request() tea_request.query = { @@ -322,7 +328,6 @@ async def _create_credentials_async(self, turl=None): 'DurationSeconds': str(self.duration_seconds), 'RoleArn': self.role_arn, 'AccessKeyId': self.access_key_id, - 'RegionId': self.region_id, 'RoleSessionName': self.role_session_name, 'SignatureMethod': 'HMAC-SHA1', 'SignatureVersion': '1.0' @@ -335,7 +340,7 @@ async def _create_credentials_async(self, turl=None): signature = ph.sign_string(string_to_sign, self.access_key_secret + "&") tea_request.query["Signature"] = signature tea_request.protocol = 'https' - tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' + tea_request.headers['host'] = self.sts_endpoint # request response = await TeaCore.async_do_action(tea_request) if response.status_code == 200: @@ -375,11 +380,18 @@ def __init__(self, role_session_name=None, role_arn=None, self._set_arg('region_id', region_id) self._set_arg('role_session_name', role_session_name) self._set_arg('policy', policy) + if region_id is None and au.environment_sts_region is not None: + self._set_arg('region_id', au.environment_sts_region) + if self.region_id is not None: + self._set_arg('sts_endpoint', f'sts.{self.region_id}.aliyuncs.com') + else: + self._set_arg('sts_endpoint', + 'sts.aliyuncs.com' if config is None or config.sts_endpoint is None else config.sts_endpoint) def get_credentials(self): return self._create_credentials() - def _create_credentials(self, turl=None): + def _create_credentials(self): # 获取credential 先实现签名用工具类 oidc_token = au.get_private_key(self.oidc_token_file_path) tea_request = ph.get_new_request() @@ -387,19 +399,18 @@ def _create_credentials(self, turl=None): 'Action': 'AssumeRoleWithOIDC', 'Format': 'JSON', 'Version': '2015-04-01', - 'RegionId': self.region_id, 'DurationSeconds': str(self.duration_seconds), 'RoleArn': self.role_arn, 'OIDCProviderArn': self.oidc_provider_arn, 'OIDCToken': oidc_token, - 'RoleSessionName': self.role_session_name if self.role_session_name else 'defaultSessionName' + 'RoleSessionName': self.role_session_name or 'defaultSessionName' } tea_request.query["Timestamp"] = ph.get_iso_8061_date() tea_request.query["SignatureNonce"] = ph.get_uuid() if self.policy is not None: tea_request.query["Policy"] = self.policy tea_request.protocol = 'https' - tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' + tea_request.headers['host'] = self.sts_endpoint # request response = TeaCore.do_action(tea_request) if response.status_code == 200: @@ -417,7 +428,7 @@ def _create_credentials(self, turl=None): async def get_credentials_async(self): return await self._create_credentials_async() - async def _create_credentials_async(self, turl=None): + async def _create_credentials_async(self): # 获取credential 先实现签名用工具类 oidc_token = au.get_private_key(self.oidc_token_file_path) tea_request = ph.get_new_request() @@ -425,19 +436,18 @@ async def _create_credentials_async(self, turl=None): 'Action': 'AssumeRoleWithOIDC', 'Format': 'JSON', 'Version': '2015-04-01', - 'RegionId': self.region_id, 'DurationSeconds': str(self.duration_seconds), 'RoleArn': self.role_arn, 'OIDCProviderArn': self.oidc_provider_arn, 'OIDCToken': oidc_token, - 'RoleSessionName': self.role_session_name if self.role_session_name else 'defaultSessionName' + 'RoleSessionName': self.role_session_name or 'defaultSessionName' } tea_request.query["Timestamp"] = ph.get_iso_8061_date() tea_request.query["SignatureNonce"] = ph.get_uuid() if self.policy is not None: tea_request.query["Policy"] = self.policy tea_request.protocol = 'https' - tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com' + tea_request.headers['host'] = self.sts_endpoint # request response = await TeaCore.async_do_action(tea_request) if response.status_code == 200: @@ -473,7 +483,6 @@ async def _create_credential_async(self, turl=None): 'Version': '2015-04-01', 'DurationSeconds': str(self.duration_seconds), 'AccessKeyId': self.access_key_id, - 'RegionId': self.region_id, 'SignatureMethod': 'HMAC-SHA1', 'SignatureVersion': '1.0' } @@ -509,7 +518,6 @@ def _create_credential(self, turl=None): 'Version': '2015-04-01', 'DurationSeconds': str(self.duration_seconds), 'AccessKeyId': self.access_key_id, - 'RegionId': self.region_id, 'SignatureMethod': 'HMAC-SHA1', 'SignatureVersion': '1.0' } diff --git a/alibabacloud_credentials/utils/auth_util.py b/alibabacloud_credentials/utils/auth_util.py index 078ece7..189b15f 100644 --- a/alibabacloud_credentials/utils/auth_util.py +++ b/alibabacloud_credentials/utils/auth_util.py @@ -12,6 +12,8 @@ environment_oidc_provider_arn = os.environ.get('ALIBABA_CLOUD_OIDC_PROVIDER_ARN') environment_role_session_name = os.environ.get('ALIBABA_CLOUD_ROLE_SESSION_NAME') +environment_sts_region = os.environ.get('ALIBABA_CLOUD_STS_REGION') + enable_oidc_credential = environment_oidc_token_file is not None \ and environment_role_arn is not None \ and environment_oidc_provider_arn is not None diff --git a/tests/test_model.py b/tests/test_model.py index 61c9bda..dd5c888 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,6 +9,7 @@ def test_model_config(self): self.assertEqual('', conf1.access_key_secret) self.assertEqual('', conf1.role_name) self.assertEqual(1000, conf1.timeout) + self.assertIsNone(conf1.sts_endpoint) conf1.timeout = 0 conf1.access_key_id = 'access_key_id' diff --git a/tests/test_providers.py b/tests/test_providers.py index fabc788..6bca466 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -192,19 +192,27 @@ def test_RamRoleArnCredentialProvider(self): access_key_id=access_key_id, access_key_secret=access_key_secret, role_session_name=role_session_name, - role_arn=role_arn + role_arn=role_arn, + sts_endpoint='http://127.0.0.1:8888' ) prov = providers.RamRoleArnCredentialProvider(config=conf) self.assertEqual('access_key_id', prov.access_key_id) self.assertEqual('access_key_secret', prov.access_key_secret) self.assertEqual('role_session_name', prov.role_session_name) self.assertEqual('role_arn', prov.role_arn) - self.assertEqual('cn-hangzhou', prov.region_id) + self.assertIsNone(prov.region_id) self.assertIsNone(prov.policy) + self.assertEqual('http://127.0.0.1:8888', prov.sts_endpoint) - cred = prov._create_credentials(turl='http://127.0.0.1:8888') + cred = prov._create_credentials() self.assertEqual('AccessKeyId', cred.access_key_id) + auth_util.environment_sts_region = 'cn-hangzhou' + prov = providers.RamRoleArnCredentialProvider(config=conf) + self.assertEqual('cn-hangzhou', prov.region_id) + self.assertEqual('sts.cn-hangzhou.aliyuncs.com', prov.sts_endpoint) + auth_util.environment_sts_region = None + def test_OIDCRoleArnCredentialProvider(self): access_key_id, access_key_secret, role_session_name, role_arn, oidc_provider_arn, oidc_token_file_path, region_id, policy = \ 'access_key_id', 'access_key_secret', 'role_session_name', 'role_arn', 'oidc_provider_arn', 'tests/private_key.txt', 'region_id', 'policy' @@ -225,7 +233,8 @@ def test_OIDCRoleArnCredentialProvider(self): role_session_name=role_session_name, role_arn=role_arn, oidc_provider_arn=oidc_provider_arn, - oidc_token_file_path=oidc_token_file_path + oidc_token_file_path=oidc_token_file_path, + sts_endpoint='http://127.0.0.1:8888' ) prov = providers.OIDCRoleArnCredentialProvider(config=conf) self.assertEqual('access_key_id', prov.access_key_id) @@ -234,12 +243,19 @@ def test_OIDCRoleArnCredentialProvider(self): self.assertEqual('role_arn', prov.role_arn) self.assertEqual('oidc_provider_arn', prov.oidc_provider_arn) self.assertEqual('tests/private_key.txt', prov.oidc_token_file_path) - self.assertEqual('cn-hangzhou', prov.region_id) + self.assertIsNone(prov.region_id) self.assertIsNone(prov.policy) + self.assertEqual('http://127.0.0.1:8888', prov.sts_endpoint) - cred = prov._create_credentials(turl='http://127.0.0.1:8888') + cred = prov._create_credentials() self.assertEqual('AccessKeyId', cred.access_key_id) + auth_util.environment_sts_region = 'cn-hangzhou' + prov = providers.OIDCRoleArnCredentialProvider(config=conf) + self.assertEqual('cn-hangzhou', prov.region_id) + self.assertEqual('sts.cn-hangzhou.aliyuncs.com', prov.sts_endpoint) + auth_util.environment_sts_region = None + def test_RamRoleArnCredentialProvider_async(self): async def main(): access_key_id, access_key_secret, role_session_name, role_arn, region_id, policy = \ @@ -258,17 +274,19 @@ async def main(): access_key_id=access_key_id, access_key_secret=access_key_secret, role_session_name=role_session_name, - role_arn=role_arn + role_arn=role_arn, + sts_endpoint='http://127.0.0.1:8888' ) prov = providers.RamRoleArnCredentialProvider(config=conf) self.assertEqual('access_key_id', prov.access_key_id) self.assertEqual('access_key_secret', prov.access_key_secret) self.assertEqual('role_session_name', prov.role_session_name) self.assertEqual('role_arn', prov.role_arn) - self.assertEqual('cn-hangzhou', prov.region_id) + self.assertIsNone(prov.region_id) self.assertIsNone(prov.policy) + self.assertEqual('http://127.0.0.1:8888', prov.sts_endpoint) - cred = await prov._create_credentials_async(turl='http://127.0.0.1:8888') + cred = await prov._create_credentials_async() self.assertEqual('AccessKeyId', cred.access_key_id) loop.run_until_complete(main()) @@ -290,7 +308,7 @@ def test_RsaKeyPairCredentialProvider(self): prov = providers.RsaKeyPairCredentialProvider(config=conf) self.assertEqual('access_key_id', prov.access_key_id) self.assertEqual('access_key_secret', prov.access_key_secret) - self.assertEqual('cn-hangzhou', prov.region_id) + self.assertIsNone(prov.region_id) cred = prov._create_credential(turl='http://127.0.0.1:8888') self.assertEqual('SessionAccessKeyId', cred.access_key_id) @@ -313,7 +331,7 @@ async def main(): prov = providers.RsaKeyPairCredentialProvider(config=conf) self.assertEqual('access_key_id', prov.access_key_id) self.assertEqual('access_key_secret', prov.access_key_secret) - self.assertEqual('cn-hangzhou', prov.region_id) + self.assertIsNone(prov.region_id) cred = await prov._create_credential_async(turl='http://127.0.0.1:8888') self.assertEqual('SessionAccessKeyId', cred.access_key_id)