Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions alibabacloud_credentials/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ def get_provider(config):

@attribute_error_return_none
def get_access_key_id(self):
return self.cloud_credential.access_key_id
return self.cloud_credential.get_access_key_id()

@attribute_error_return_none
def get_access_key_secret(self):
return self.cloud_credential.access_key_secret
return self.cloud_credential.get_access_key_secret()

@attribute_error_return_none
def get_security_token(self):
return self.cloud_credential.security_token
return self.cloud_credential.get_security_token()

@attribute_error_return_none
def get_type(self):
Expand Down
80 changes: 52 additions & 28 deletions alibabacloud_credentials/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@


class _AutomaticallyRefreshCredentials:
def __init__(self, expiration, provider, refresh_fields):
def __init__(self, expiration, provider):
self.expiration = expiration
self.provider = provider
self._REFRESH_FIELDS = refresh_fields

def _with_should_refresh(self):
return int(time.mktime(time.localtime())) >= (object.__getattribute__(self, 'expiration') - 180)
return int(time.mktime(time.localtime())) >= (self.expiration - 180)

def _get_new_credential(self):
return self.provider.get_credentials()
Expand All @@ -19,11 +18,6 @@ def _refresh_credential(self):
if self._with_should_refresh():
return self._get_new_credential()

def __getattribute__(self, item):
if item in object.__getattribute__(self, '__dict__')['_REFRESH_FIELDS']:
self._refresh_credential()
return object.__getattribute__(self, item)


class AccessKeyCredential:
"""AccessKeyCredential"""
Expand All @@ -33,6 +27,12 @@ def __init__(self, access_key_id, access_key_secret):
self.access_key_secret = access_key_secret
self.credential_type = ac.ACCESS_KEY

def get_access_key_id(self):
return self.access_key_id

def get_access_key_secret(self):
return self.access_key_secret


class BearerTokenCredential:
"""BearerTokenCredential"""
Expand All @@ -46,13 +46,7 @@ class EcsRamRoleCredential(_AutomaticallyRefreshCredentials):
"""EcsRamRoleCredential"""

def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider):
refresh_fields = (
'access_key_id',
'access_key_secret',
'security_token',
'expiration'
)
super().__init__(expiration, provider, refresh_fields)
super().__init__(expiration, provider)
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.security_token = security_token
Expand All @@ -66,18 +60,24 @@ def _refresh_credential(self):
self.expiration = credential.expiration
self.security_token = credential.security_token

def get_access_key_id(self):
self._refresh_credential()
return self.access_key_id

def get_access_key_secret(self):
self._refresh_credential()
return self.access_key_secret

def get_security_token(self):
self._refresh_credential()
return self.security_token


class RamRoleArnCredential(_AutomaticallyRefreshCredentials):
"""RamRoleArnCredential"""

def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider):
refresh_fields = (
'access_key_id',
'access_key_secret',
'security_token',
'expiration'
)
super().__init__(expiration, provider, refresh_fields)
super().__init__(expiration, provider)
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.security_token = security_token
Expand All @@ -91,15 +91,22 @@ def _refresh_credential(self):
self.expiration = credential.expiration
self.security_token = credential.security_token

def get_access_key_id(self):
self._refresh_credential()
return self.access_key_id

def get_access_key_secret(self):
self._refresh_credential()
return self.access_key_secret

def get_security_token(self):
self._refresh_credential()
return self.security_token


class RsaKeyPairCredential(_AutomaticallyRefreshCredentials):
def __init__(self, access_key_id, access_key_secret, expiration, provider):
refresh_fields = (
'access_key_id',
'access_key_secret',
'expiration'
)
super().__init__(expiration, provider, refresh_fields)
super().__init__(expiration, provider)
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.credential_type = ac.RSA_KEY_PAIR
Expand All @@ -111,10 +118,27 @@ def _refresh_credential(self):
self.access_key_secret = credential.access_key_secret
self.expiration = credential.expiration

def get_access_key_id(self):
self._refresh_credential()
return self.access_key_id

def get_access_key_secret(self):
self._refresh_credential()
return self.access_key_secret


class StsCredential:
def __init__(self, access_key_id, access_key_secret, security_token):
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.security_token = security_token
self.credential_type = ac.STS

def get_access_key_id(self):
return self.access_key_id

def get_access_key_secret(self):
return self.access_key_secret

def get_security_token(self):
return self.security_token
7 changes: 4 additions & 3 deletions alibabacloud_credentials/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import time
import configparser
import calendar

from alibabacloud_credentials.utils import auth_util as au, \
auth_constant as ac, \
Expand Down Expand Up @@ -123,7 +124,7 @@ def _create_credential(self, url=None):
# 先转换为时间数组
time_array = time.strptime(expiration_str, "%Y-%m-%d %H:%M:%S")
# 转换为时间戳
time_stamp = int(time.mktime(time_array))
time_stamp = calendar.timegm(time_array)
return credentials.EcsRamRoleCredential(content_access_key_id, content_access_key_secret,
content_security_token, time_stamp, self)

Expand Down Expand Up @@ -184,7 +185,7 @@ def _create_credentials(self, turl=None):
# 先转换为时间数组
time_array = time.strptime(expiration_str, "%Y-%m-%d %H:%M:%S")
# 转换为时间戳
expiration = int(time.mktime(time_array))
expiration = calendar.timegm(time_array)
return credentials.RamRoleArnCredential(cre.get("AccessKeyId"), cre.get("AccessKeySecret"),
cre.get("SecurityToken"), expiration, self)
raise CredentialException(response.text)
Expand Down Expand Up @@ -225,7 +226,7 @@ def _create_credential(self, turl=None):
cre = dic.get("SessionAccessKey")
expiration_str = cre.get("Expiration").replace("T", " ").replace("Z", "")
time_array = time.strptime(expiration_str, "%Y-%m-%d %H:%M:%S")
expiration = int(time.mktime(time_array))
expiration = calendar.timegm(time_array)
return credentials.RsaKeyPairCredential(cre.get("SessionAccessKeyId"), cre.get("SessionAccessKeySecret"),
expiration, self)
raise CredentialException(resp.text)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def test_EcsRamRoleCredential(self):
provider
)

self.assertEqual('access_key_id', cred.access_key_id)
self.assertEqual('access_key_secret', cred.access_key_secret)
self.assertEqual('security_token', cred.security_token)
self.assertEqual('access_key_id', cred.get_access_key_id())
self.assertEqual('access_key_secret', cred.get_access_key_secret())
self.assertEqual('security_token', cred.get_security_token())
self.assertEqual(900000000000, cred.expiration)
self.assertIsInstance(cred.provider, providers.EcsRamRoleCredentialProvider)
self.assertEqual('ecs_ram_role', cred.credential_type)
Expand All @@ -48,7 +48,7 @@ def test_EcsRamRoleCredential(self):
self.TestEcsRamRoleProvider()
)
# refresh token
self.assertEqual('accessKeyId', cred.access_key_id)
self.assertEqual('accessKeyId', cred.get_access_key_id())
self.assertEqual('accessKeySecret', cred.access_key_secret)
self.assertEqual('securityToken', cred.security_token)
self.assertEqual(100000000000, cred.expiration)
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_RamRoleArnCredential(self):
# refresh token
self.assertTrue(cred._with_should_refresh())

self.assertEqual('accessKeyId', cred.access_key_id)
self.assertEqual('accessKeyId', cred.get_access_key_id())
self.assertEqual('accessKeySecret', cred.access_key_secret)
self.assertEqual('securityToken', cred.security_token)
self.assertEqual(100000000000, cred.expiration)
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_RsaKeyPairCredential(self):
)

# refresh token
self.assertEqual('accessKeyId', cred.access_key_id)
self.assertEqual('accessKeyId', cred.get_access_key_id())
self.assertEqual('accessKeySecret', cred.access_key_secret)
self.assertEqual(100000000000, cred.expiration)

Expand Down
11 changes: 5 additions & 6 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def do_GET(self):
self.end_headers()
self.wfile.write(b'{"Code": "Success", "AccessKeyId": "ak",'
b' "Expiration": "3999-08-07T20:20:20Z", "Credentials":'
b' {"Expiration": "3999-08-07T20:20:20Z"}, "SessionAccessKey":'
b' {"Expiration": "3999-08-07T20:20:20Z"}}')
b' {"Expiration": "3999-08-07T20:20:20Z", "AccessKeyId": "AccessKeyId"}, "SessionAccessKey":'
b' {"Expiration": "3999-08-07T20:20:20Z", "SessionAccessKeyId": "SessionAccessKeyId"}}')


def run_server():
Expand Down Expand Up @@ -55,9 +55,8 @@ def test_EcsRamRoleCredentialProvider(self):
# prov._create_credential(url='http://www.aliyun.com')
cred = prov._create_credential(url='http://127.0.0.1:8888')
self.assertEqual('ak', cred.access_key_id)
self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration))

prov._get_role_name(url='http://www.aliyun.com')
prov._get_role_name(url='http://127.0.0.1:8888')
self.assertIsNotNone(prov.role_name)
prov.role_name = 'role_name'
prov._set_credential_url()
Expand Down Expand Up @@ -123,7 +122,7 @@ def test_RamRoleArnCredentialProvider(self):
self.assertIsNone(prov.policy)

cred = prov._create_credentials(turl='http://127.0.0.1:8888')
self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration))
self.assertEqual('AccessKeyId', cred.access_key_id)

def test_RsaKeyPairCredentialProvider(self):
access_key_id, access_key_secret, region_id = \
Expand All @@ -145,7 +144,7 @@ def test_RsaKeyPairCredentialProvider(self):
self.assertEqual('cn-hangzhou', prov.region_id)

cred = prov._create_credential(turl='http://127.0.0.1:8888')
self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration))
self.assertEqual('SessionAccessKeyId', cred.access_key_id)

def test_ProfileCredentialsProvider(self):
prov = providers.ProfileCredentialsProvider(ini_file)
Expand Down