Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion alibabacloud_credentials/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(
proxy: str = '',
credentials_uri: str = '',
enable_imds_v2: bool = False,
metadata_token_duration: int = 21600
metadata_token_duration: int = 21600,
sts_endpoint: str = None
):
# accesskey id
self.access_key_id = access_key_id
Expand Down Expand Up @@ -71,6 +72,8 @@ def __init__(
self.proxy = proxy
# credentials uri
self.credentials_uri = credentials_uri
# STS Endpoint
self.sts_endpoint = sts_endpoint

def validate(self):
pass
Expand Down Expand Up @@ -121,6 +124,8 @@ def to_map(self):
result['proxy'] = self.proxy
if self.credentials_uri is not None:
result['credentialsUri'] = self.credentials_uri
if self.sts_endpoint is not None:
result['stsEndpoint'] = self.sts_endpoint
return result

def from_map(self, m: dict = None):
Expand Down Expand Up @@ -169,6 +174,8 @@ def from_map(self, m: dict = None):
self.proxy = m.get('proxy')
if m.get('credentialsUri') is not None:
self.credentials_uri = m.get('credentials_uri')
if m.get('stsEndpoint') is not None:
self.sts_endpoint = m.get('stsEndpoint')
return self


Expand Down
42 changes: 25 additions & 17 deletions alibabacloud_credentials/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class AlibabaCloudCredentialsProvider:
"""BaseProvider class"""
duration_seconds = 3600
timeout = 2000
region_id = 'cn-hangzhou'

def __init__(self, config=None):
if isinstance(config, Config):
Expand All @@ -41,6 +40,7 @@ def __init__(self, config=None):
self.timeout = config.timeout + config.connect_timeout
self.connect_timeout = config.connect_timeout
self.proxy = config.proxy
self.sts_endpoint = config.sts_endpoint

def _set_arg(self, key, value):
if value is not None:
Expand Down Expand Up @@ -267,11 +267,18 @@ def __init__(self, access_key_id=None, access_key_secret=None, role_session_name
self._set_arg('region_id', region_id)
self._set_arg('role_session_name', role_session_name)
self._set_arg('policy', policy)
if region_id is None and au.environment_sts_region is not None:
self._set_arg('region_id', au.environment_sts_region)
if self.region_id is not None:
self._set_arg('sts_endpoint', f'sts.{self.region_id}.aliyuncs.com')
else:
self._set_arg('sts_endpoint',
'sts.aliyuncs.com' if config is None or config.sts_endpoint is None else config.sts_endpoint)

def get_credentials(self):
return self._create_credentials()

def _create_credentials(self, turl=None):
def _create_credentials(self):
# 获取credential 先实现签名用工具类
tea_request = ph.get_new_request()
tea_request.query = {
Expand All @@ -281,7 +288,6 @@ def _create_credentials(self, turl=None):
'DurationSeconds': str(self.duration_seconds),
'RoleArn': self.role_arn,
'AccessKeyId': self.access_key_id,
'RegionId': self.region_id,
'RoleSessionName': self.role_session_name,
'SignatureMethod': 'HMAC-SHA1',
'SignatureVersion': '1.0'
Expand All @@ -294,7 +300,7 @@ def _create_credentials(self, turl=None):
signature = ph.sign_string(string_to_sign, self.access_key_secret + "&")
tea_request.query["Signature"] = signature
tea_request.protocol = 'https'
tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com'
tea_request.headers['host'] = self.sts_endpoint
# request
response = TeaCore.do_action(tea_request)
if response.status_code == 200:
Expand All @@ -312,7 +318,7 @@ def _create_credentials(self, turl=None):
async def get_credentials_async(self):
return await self._create_credentials_async()

async def _create_credentials_async(self, turl=None):
async def _create_credentials_async(self):
# 获取credential 先实现签名用工具类
tea_request = ph.get_new_request()
tea_request.query = {
Expand All @@ -322,7 +328,6 @@ async def _create_credentials_async(self, turl=None):
'DurationSeconds': str(self.duration_seconds),
'RoleArn': self.role_arn,
'AccessKeyId': self.access_key_id,
'RegionId': self.region_id,
'RoleSessionName': self.role_session_name,
'SignatureMethod': 'HMAC-SHA1',
'SignatureVersion': '1.0'
Expand All @@ -335,7 +340,7 @@ async def _create_credentials_async(self, turl=None):
signature = ph.sign_string(string_to_sign, self.access_key_secret + "&")
tea_request.query["Signature"] = signature
tea_request.protocol = 'https'
tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com'
tea_request.headers['host'] = self.sts_endpoint
# request
response = await TeaCore.async_do_action(tea_request)
if response.status_code == 200:
Expand Down Expand Up @@ -375,31 +380,37 @@ def __init__(self, role_session_name=None, role_arn=None,
self._set_arg('region_id', region_id)
self._set_arg('role_session_name', role_session_name)
self._set_arg('policy', policy)
if region_id is None and au.environment_sts_region is not None:
self._set_arg('region_id', au.environment_sts_region)
if self.region_id is not None:
self._set_arg('sts_endpoint', f'sts.{self.region_id}.aliyuncs.com')
else:
self._set_arg('sts_endpoint',
'sts.aliyuncs.com' if config is None or config.sts_endpoint is None else config.sts_endpoint)

def get_credentials(self):
return self._create_credentials()

def _create_credentials(self, turl=None):
def _create_credentials(self):
# 获取credential 先实现签名用工具类
oidc_token = au.get_private_key(self.oidc_token_file_path)
tea_request = ph.get_new_request()
tea_request.query = {
'Action': 'AssumeRoleWithOIDC',
'Format': 'JSON',
'Version': '2015-04-01',
'RegionId': self.region_id,
'DurationSeconds': str(self.duration_seconds),
'RoleArn': self.role_arn,
'OIDCProviderArn': self.oidc_provider_arn,
'OIDCToken': oidc_token,
'RoleSessionName': self.role_session_name if self.role_session_name else 'defaultSessionName'
'RoleSessionName': self.role_session_name or 'defaultSessionName'
}
tea_request.query["Timestamp"] = ph.get_iso_8061_date()
tea_request.query["SignatureNonce"] = ph.get_uuid()
if self.policy is not None:
tea_request.query["Policy"] = self.policy
tea_request.protocol = 'https'
tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com'
tea_request.headers['host'] = self.sts_endpoint
# request
response = TeaCore.do_action(tea_request)
if response.status_code == 200:
Expand All @@ -417,27 +428,26 @@ def _create_credentials(self, turl=None):
async def get_credentials_async(self):
return await self._create_credentials_async()

async def _create_credentials_async(self, turl=None):
async def _create_credentials_async(self):
# 获取credential 先实现签名用工具类
oidc_token = au.get_private_key(self.oidc_token_file_path)
tea_request = ph.get_new_request()
tea_request.query = {
'Action': 'AssumeRoleWithOIDC',
'Format': 'JSON',
'Version': '2015-04-01',
'RegionId': self.region_id,
'DurationSeconds': str(self.duration_seconds),
'RoleArn': self.role_arn,
'OIDCProviderArn': self.oidc_provider_arn,
'OIDCToken': oidc_token,
'RoleSessionName': self.role_session_name if self.role_session_name else 'defaultSessionName'
'RoleSessionName': self.role_session_name or 'defaultSessionName'
}
tea_request.query["Timestamp"] = ph.get_iso_8061_date()
tea_request.query["SignatureNonce"] = ph.get_uuid()
if self.policy is not None:
tea_request.query["Policy"] = self.policy
tea_request.protocol = 'https'
tea_request.headers['host'] = turl if turl else 'sts.aliyuncs.com'
tea_request.headers['host'] = self.sts_endpoint
# request
response = await TeaCore.async_do_action(tea_request)
if response.status_code == 200:
Expand Down Expand Up @@ -473,7 +483,6 @@ async def _create_credential_async(self, turl=None):
'Version': '2015-04-01',
'DurationSeconds': str(self.duration_seconds),
'AccessKeyId': self.access_key_id,
'RegionId': self.region_id,
'SignatureMethod': 'HMAC-SHA1',
'SignatureVersion': '1.0'
}
Expand Down Expand Up @@ -509,7 +518,6 @@ def _create_credential(self, turl=None):
'Version': '2015-04-01',
'DurationSeconds': str(self.duration_seconds),
'AccessKeyId': self.access_key_id,
'RegionId': self.region_id,
'SignatureMethod': 'HMAC-SHA1',
'SignatureVersion': '1.0'
}
Expand Down
2 changes: 2 additions & 0 deletions alibabacloud_credentials/utils/auth_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
environment_oidc_provider_arn = os.environ.get('ALIBABA_CLOUD_OIDC_PROVIDER_ARN')
environment_role_session_name = os.environ.get('ALIBABA_CLOUD_ROLE_SESSION_NAME')

environment_sts_region = os.environ.get('ALIBABA_CLOUD_STS_REGION')

enable_oidc_credential = environment_oidc_token_file is not None \
and environment_role_arn is not None \
and environment_oidc_provider_arn is not None
Expand Down
1 change: 1 addition & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def test_model_config(self):
self.assertEqual('', conf1.access_key_secret)
self.assertEqual('', conf1.role_name)
self.assertEqual(1000, conf1.timeout)
self.assertIsNone(conf1.sts_endpoint)

conf1.timeout = 0
conf1.access_key_id = 'access_key_id'
Expand Down
40 changes: 29 additions & 11 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,27 @@ def test_RamRoleArnCredentialProvider(self):
access_key_id=access_key_id,
access_key_secret=access_key_secret,
role_session_name=role_session_name,
role_arn=role_arn
role_arn=role_arn,
sts_endpoint='http://127.0.0.1:8888'
)
prov = providers.RamRoleArnCredentialProvider(config=conf)
self.assertEqual('access_key_id', prov.access_key_id)
self.assertEqual('access_key_secret', prov.access_key_secret)
self.assertEqual('role_session_name', prov.role_session_name)
self.assertEqual('role_arn', prov.role_arn)
self.assertEqual('cn-hangzhou', prov.region_id)
self.assertIsNone(prov.region_id)
self.assertIsNone(prov.policy)
self.assertEqual('http://127.0.0.1:8888', prov.sts_endpoint)

cred = prov._create_credentials(turl='http://127.0.0.1:8888')
cred = prov._create_credentials()
self.assertEqual('AccessKeyId', cred.access_key_id)

auth_util.environment_sts_region = 'cn-hangzhou'
prov = providers.RamRoleArnCredentialProvider(config=conf)
self.assertEqual('cn-hangzhou', prov.region_id)
self.assertEqual('sts.cn-hangzhou.aliyuncs.com', prov.sts_endpoint)
auth_util.environment_sts_region = None

def test_OIDCRoleArnCredentialProvider(self):
access_key_id, access_key_secret, role_session_name, role_arn, oidc_provider_arn, oidc_token_file_path, region_id, policy = \
'access_key_id', 'access_key_secret', 'role_session_name', 'role_arn', 'oidc_provider_arn', 'tests/private_key.txt', 'region_id', 'policy'
Expand All @@ -225,7 +233,8 @@ def test_OIDCRoleArnCredentialProvider(self):
role_session_name=role_session_name,
role_arn=role_arn,
oidc_provider_arn=oidc_provider_arn,
oidc_token_file_path=oidc_token_file_path
oidc_token_file_path=oidc_token_file_path,
sts_endpoint='http://127.0.0.1:8888'
)
prov = providers.OIDCRoleArnCredentialProvider(config=conf)
self.assertEqual('access_key_id', prov.access_key_id)
Expand All @@ -234,12 +243,19 @@ def test_OIDCRoleArnCredentialProvider(self):
self.assertEqual('role_arn', prov.role_arn)
self.assertEqual('oidc_provider_arn', prov.oidc_provider_arn)
self.assertEqual('tests/private_key.txt', prov.oidc_token_file_path)
self.assertEqual('cn-hangzhou', prov.region_id)
self.assertIsNone(prov.region_id)
self.assertIsNone(prov.policy)
self.assertEqual('http://127.0.0.1:8888', prov.sts_endpoint)

cred = prov._create_credentials(turl='http://127.0.0.1:8888')
cred = prov._create_credentials()
self.assertEqual('AccessKeyId', cred.access_key_id)

auth_util.environment_sts_region = 'cn-hangzhou'
prov = providers.OIDCRoleArnCredentialProvider(config=conf)
self.assertEqual('cn-hangzhou', prov.region_id)
self.assertEqual('sts.cn-hangzhou.aliyuncs.com', prov.sts_endpoint)
auth_util.environment_sts_region = None

def test_RamRoleArnCredentialProvider_async(self):
async def main():
access_key_id, access_key_secret, role_session_name, role_arn, region_id, policy = \
Expand All @@ -258,17 +274,19 @@ async def main():
access_key_id=access_key_id,
access_key_secret=access_key_secret,
role_session_name=role_session_name,
role_arn=role_arn
role_arn=role_arn,
sts_endpoint='http://127.0.0.1:8888'
)
prov = providers.RamRoleArnCredentialProvider(config=conf)
self.assertEqual('access_key_id', prov.access_key_id)
self.assertEqual('access_key_secret', prov.access_key_secret)
self.assertEqual('role_session_name', prov.role_session_name)
self.assertEqual('role_arn', prov.role_arn)
self.assertEqual('cn-hangzhou', prov.region_id)
self.assertIsNone(prov.region_id)
self.assertIsNone(prov.policy)
self.assertEqual('http://127.0.0.1:8888', prov.sts_endpoint)

cred = await prov._create_credentials_async(turl='http://127.0.0.1:8888')
cred = await prov._create_credentials_async()
self.assertEqual('AccessKeyId', cred.access_key_id)

loop.run_until_complete(main())
Expand All @@ -290,7 +308,7 @@ def test_RsaKeyPairCredentialProvider(self):
prov = providers.RsaKeyPairCredentialProvider(config=conf)
self.assertEqual('access_key_id', prov.access_key_id)
self.assertEqual('access_key_secret', prov.access_key_secret)
self.assertEqual('cn-hangzhou', prov.region_id)
self.assertIsNone(prov.region_id)

cred = prov._create_credential(turl='http://127.0.0.1:8888')
self.assertEqual('SessionAccessKeyId', cred.access_key_id)
Expand All @@ -313,7 +331,7 @@ async def main():
prov = providers.RsaKeyPairCredentialProvider(config=conf)
self.assertEqual('access_key_id', prov.access_key_id)
self.assertEqual('access_key_secret', prov.access_key_secret)
self.assertEqual('cn-hangzhou', prov.region_id)
self.assertIsNone(prov.region_id)

cred = await prov._create_credential_async(turl='http://127.0.0.1:8888')
self.assertEqual('SessionAccessKeyId', cred.access_key_id)
Expand Down