Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 61 additions & 27 deletions alibabacloud_credentials/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,28 @@
from alibabacloud_credentials.utils import auth_constant as ac


class _AutomaticallyRefreshCredentials:
def __init__(self, expiration, provider, refresh_fields):
self.expiration = expiration
self.provider = provider
self._REFRESH_FIELDS = refresh_fields

def _with_should_refresh(self):
return int(time.mktime(time.localtime())) >= (object.__getattribute__(self, 'expiration') - 180)

def _get_new_credential(self):
return self.provider.get_credentials()

def _refresh_credential(self):
if self._with_should_refresh():
return self._get_new_credential()

def __getattribute__(self, item):
if item in object.__getattribute__(self, '__dict__')['_REFRESH_FIELDS']:
self._refresh_credential()
return object.__getattribute__(self, item)


class AccessKeyCredential:
"""AccessKeyCredential"""

Expand All @@ -20,63 +42,75 @@ def __init__(self, bearer_token):
self.credential_type = ac.BEARER


class EcsRamRoleCredential:
class EcsRamRoleCredential(_AutomaticallyRefreshCredentials):
"""EcsRamRoleCredential"""

def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider):
refresh_fields = (
'access_key_id',
'access_key_secret',
'security_token',
'expiration'
)
super().__init__(expiration, provider, refresh_fields)
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.security_token = security_token
self.expiration = expiration
self.provider = provider
self.credential_type = ac.ECS_RAM_ROLE

def _refresh_credential(self):
credential = super()._refresh_credential()
if credential:
self.access_key_id = credential.access_key_id
self.access_key_secret = credential.access_key_secret
self.expiration = credential.expiration
self.security_token = credential.security_token


class RamRoleArnCredential:
class RamRoleArnCredential(_AutomaticallyRefreshCredentials):
"""RamRoleArnCredential"""

def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider):
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.security_token = security_token
self.expiration = expiration
self.provider = provider
self.credential_type = ac.RAM_ROLE_ARN
self._REFRESH_FIELDS = (
refresh_fields = (
'access_key_id',
'access_key_secret',
'security_token',
'expiration'
)

def _with_should_refresh(self):
return int(time.mktime(time.localtime())) >= (object.__getattribute__(self, 'expiration') - 180)

def _get_new_credential(self):
return self.provider.get_credentials()
super().__init__(expiration, provider, refresh_fields)
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.security_token = security_token
self.credential_type = ac.RAM_ROLE_ARN

def _refresh_credential(self):
if self._with_should_refresh():
credential = self._get_new_credential()
credential = super()._refresh_credential()
if credential:
self.access_key_id = credential.access_key_id
self.access_key_secret = credential.access_key_secret
self.expiration = credential.expiration
self.security_token = credential.security_token

def __getattribute__(self, item):
if item in object.__getattribute__(self, '__dict__')['_REFRESH_FIELDS']:
self._refresh_credential()
return object.__getattribute__(self, item)


class RsaKeyPairCredential:
class RsaKeyPairCredential(_AutomaticallyRefreshCredentials):
def __init__(self, access_key_id, access_key_secret, expiration, provider):
refresh_fields = (
'access_key_id',
'access_key_secret',
'expiration'
)
super().__init__(expiration, provider, refresh_fields)
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.expiration = expiration
self.provider = provider
self.credential_type = ac.RSA_KEY_PAIR

def _refresh_credential(self):
credential = super()._refresh_credential()
if credential:
self.access_key_id = credential.access_key_id
self.access_key_secret = credential.access_key_secret
self.expiration = credential.expiration


class StsCredential:
def __init__(self, access_key_id, access_key_secret, security_token):
Expand Down
35 changes: 17 additions & 18 deletions alibabacloud_credentials/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@


class Config(TeaModel):
def __init__(self, **kwargs):
super().__init__()
self.type = ""
self.access_key_id = ""
self.access_key_secret = ""
self.role_arn = ""
self.role_session_name = ""
self.public_key_id = ""
self.role_name = ""
self.private_key_file = ""
self.bearer_token = ""
self.security_token = ""
self.host = ""
self.timeout = 1000
self.connect_timeout = 1000
self.proxy = ""
for k, v in kwargs.items():
setattr(self, k, v)
def __init__(self, type='', access_key_id='', access_key_secret='', role_arn='', role_session_name='',
public_key_id='', role_name='', private_key_file='', bearer_token='', security_token='', host='',
timeout=1000, connect_timeout=1000, proxy=''):
self.type = type
self.access_key_id = access_key_id
self.access_key_secret = access_key_secret
self.role_arn = role_arn
self.role_session_name = role_session_name
self.public_key_id = public_key_id
self.role_name = role_name
self.private_key_file = private_key_file
self.bearer_token = bearer_token
self.security_token = security_token
self.host = host
self.timeout = timeout
self.connect_timeout = connect_timeout
self.proxy = proxy
10 changes: 5 additions & 5 deletions alibabacloud_credentials/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def _create_credential(self, url=None):
raise CredentialException(self.__ecs_metadata_fetch_error_msg + " HttpCode=" + str(response.status_code))
response.encoding = 'utf-8'
dic = json.loads(response.text)
content_code = dic.Code
content_access_key_id = dic.AccessKeyId
content_access_key_secret = dic.AccessKeySecret
content_security_token = dic.SecurityToken
content_expiration = dic.Expiration
content_code = dic.get('Code')
content_access_key_id = dic.get('AccessKeyId')
content_access_key_secret = dic.get('AccessKeySecret')
content_security_token = dic.get('SecurityToken')
content_expiration = dic.get('Expiration')

if content_code != "Success":
raise CredentialException(self.__ecs_metadata_fetch_error_msg)
Expand Down
6 changes: 6 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os

test_dir = os.path.dirname(__file__)

txt_file = os.path.join(test_dir, 'private_key.txt')
ini_file = os.path.join(test_dir, 'tests.ini')
23 changes: 22 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from alibabacloud_credentials.utils import auth_constant
from alibabacloud_credentials.client import Client
from alibabacloud_credentials import credentials
from alibabacloud_credentials import providers


class TestCredentials(unittest.TestCase):
def test_client(self):
def test_ak_client(self):
conf = Config()
conf.type = auth_constant.ACCESS_KEY
conf.access_key_id = '123456'
Expand All @@ -22,3 +23,23 @@ def test_client(self):
cred.get_access_key_id()
except Exception as e:
self.assertEqual(str(e), 'not found credentials')

conf = Config(type='sts')
cred = Client(conf)
self.assertIsInstance(cred.cloud_credential, credentials.StsCredential)

conf = Config(type='bearer')
cred = Client(conf)
self.assertIsInstance(cred.cloud_credential, credentials.BearerTokenCredential)

conf = Config(type='ecs_ram_role')
self.assertIsInstance(Client.get_provider(conf), providers.EcsRamRoleCredentialProvider)

conf = Config(type='ram_role_arn')
self.assertIsInstance(Client.get_provider(conf), providers.RamRoleArnCredentialProvider)

conf = Config(type='rsa_key_pair')
self.assertIsInstance(Client.get_provider(conf), providers.RsaKeyPairCredentialProvider)

conf = Config(type='test')
self.assertIsInstance(Client.get_provider(conf), providers.DefaultCredentialsProvider)
72 changes: 62 additions & 10 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,27 @@


class TestCredentials(unittest.TestCase):
class AlibabaCredentialProvider:
class TestEcsRamRoleProvider:
def get_credentials(self):
return credentials.RamRoleArnCredential("accessKeyId", "accessKeySecret", "securityToken", 10000, None)
return credentials.EcsRamRoleCredential("accessKeyId", "accessKeySecret", "securityToken", 100000000000,
None)

class TestRamRoleArnProvider:
def get_credentials(self):
return credentials.RamRoleArnCredential("accessKeyId", "accessKeySecret", "securityToken", 100000000000,
None)

class TestRsaKeyPairProvider:
def get_credentials(self):
return credentials.RsaKeyPairCredential("accessKeyId", "accessKeySecret", 100000000000,
None)

def test_EcsRamRoleCredential(self):
provider = providers.EcsRamRoleCredentialProvider("roleName")
access_key_id = 'access_key_id'
access_key_secret = 'access_key_secret'
security_token = 'security_token'
expiration = 100
expiration = 900000000000
cred = credentials.EcsRamRoleCredential(
access_key_id,
access_key_secret,
Expand All @@ -25,10 +36,23 @@ def test_EcsRamRoleCredential(self):
self.assertEqual('access_key_id', cred.access_key_id)
self.assertEqual('access_key_secret', cred.access_key_secret)
self.assertEqual('security_token', cred.security_token)
self.assertEqual(100, cred.expiration)
self.assertEqual(900000000000, cred.expiration)
self.assertIsInstance(cred.provider, providers.EcsRamRoleCredentialProvider)
self.assertEqual('ecs_ram_role', cred.credential_type)

cred = credentials.EcsRamRoleCredential(
access_key_id,
access_key_secret,
security_token,
100,
self.TestEcsRamRoleProvider()
)
# refresh token
self.assertEqual('accessKeyId', cred.access_key_id)
self.assertEqual('accessKeySecret', cred.access_key_secret)
self.assertEqual('securityToken', cred.security_token)
self.assertEqual(100000000000, cred.expiration)

def test_AccessKeyCredential(self):
access_key_id = 'access_key_id'
access_key_secret = 'access_key_secret'
Expand All @@ -48,17 +72,33 @@ def test_BearerTokenCredential(self):

def test_RamRoleArnCredential(self):
access_key_id, access_key_secret, security_token, expiration = \
'access_key_id', 'access_key_secret', 'security_token', 64090527132000
provider = self.AlibabaCredentialProvider()
'access_key_id', 'access_key_secret', 'security_token', 640900000000
provider = self.TestRamRoleArnProvider()
cred = credentials.RamRoleArnCredential(
access_key_id, access_key_secret, security_token, expiration, provider
)

self.assertEqual('access_key_id', cred.access_key_id)
self.assertEqual('access_key_secret', cred.access_key_secret)
self.assertEqual('security_token', cred.security_token)
self.assertEqual(64090527132000, cred.expiration)
self.assertEqual(640900000000, cred.expiration)

access_key_id, access_key_secret, security_token, expiration = \
'access_key_id', 'access_key_secret', 'security_token', 6409
provider = self.TestRamRoleArnProvider()
cred = credentials.RamRoleArnCredential(
access_key_id, access_key_secret, security_token, expiration, provider
)

# refresh token
self.assertTrue(cred._with_should_refresh())

self.assertEqual('accessKeyId', cred.access_key_id)
self.assertEqual('accessKeySecret', cred.access_key_secret)
self.assertEqual('securityToken', cred.security_token)
self.assertEqual(100000000000, cred.expiration)
self.assertEqual('ram_role_arn', cred.credential_type)
self.assertIsInstance(cred.provider, self.AlibabaCredentialProvider)
self.assertIsInstance(cred.provider, self.TestRamRoleArnProvider)

self.assertFalse(cred._with_should_refresh())

Expand All @@ -69,17 +109,29 @@ def test_RamRoleArnCredential(self):
self.assertIsNotNone(cred)

def test_RsaKeyPairCredential(self):
access_key_id, access_key_secret, expiration = 'access_key_id', 'access_key_secret', 100
access_key_id, access_key_secret, expiration = 'access_key_id', 'access_key_secret', 90000000000
provider = providers.RsaKeyPairCredentialProvider(access_key_id, access_key_secret)
cred = credentials.RsaKeyPairCredential(
access_key_id, access_key_secret, expiration, provider
)
self.assertEqual('access_key_id', cred.access_key_id)
self.assertEqual('access_key_secret', cred.access_key_secret)
self.assertEqual(100, cred.expiration)
self.assertEqual(90000000000, cred.expiration)
self.assertIsInstance(cred.provider, providers.RsaKeyPairCredentialProvider)
self.assertEqual('rsa_key_pair', cred.credential_type)

cred = credentials.RsaKeyPairCredential(
access_key_id,
access_key_secret,
900,
self.TestRsaKeyPairProvider()
)

# refresh token
self.assertEqual('accessKeyId', cred.access_key_id)
self.assertEqual('accessKeySecret', cred.access_key_secret)
self.assertEqual(100000000000, cred.expiration)

def test_StsCredential(self):
access_key_id, access_key_secret, security_token =\
'access_key_id', 'access_key_secret', 'security_token'
Expand Down
Loading