From 04868f58eaa3dec1d866380453c590577bc6db9b Mon Sep 17 00:00:00 2001 From: 824750130 Date: Fri, 7 Aug 2020 15:41:41 +0800 Subject: [PATCH] EcsRamRole&RsaKeyPair support refresh token. --- alibabacloud_credentials/credentials.py | 88 +++++++++++++++++-------- alibabacloud_credentials/models.py | 35 +++++----- alibabacloud_credentials/providers.py | 10 +-- tests/__init__.py | 6 ++ tests/test_client.py | 23 ++++++- tests/test_credentials.py | 72 +++++++++++++++++--- tests/test_providers.py | 40 +++++++++-- tests/test_util.py | 3 +- 8 files changed, 210 insertions(+), 67 deletions(-) diff --git a/alibabacloud_credentials/credentials.py b/alibabacloud_credentials/credentials.py index 4033c3b..9ca1a6a 100644 --- a/alibabacloud_credentials/credentials.py +++ b/alibabacloud_credentials/credentials.py @@ -3,6 +3,28 @@ from alibabacloud_credentials.utils import auth_constant as ac +class _AutomaticallyRefreshCredentials: + def __init__(self, expiration, provider, refresh_fields): + self.expiration = expiration + self.provider = provider + self._REFRESH_FIELDS = refresh_fields + + def _with_should_refresh(self): + return int(time.mktime(time.localtime())) >= (object.__getattribute__(self, 'expiration') - 180) + + def _get_new_credential(self): + return self.provider.get_credentials() + + def _refresh_credential(self): + if self._with_should_refresh(): + return self._get_new_credential() + + def __getattribute__(self, item): + if item in object.__getattribute__(self, '__dict__')['_REFRESH_FIELDS']: + self._refresh_credential() + return object.__getattribute__(self, item) + + class AccessKeyCredential: """AccessKeyCredential""" @@ -20,63 +42,75 @@ def __init__(self, bearer_token): self.credential_type = ac.BEARER -class EcsRamRoleCredential: +class EcsRamRoleCredential(_AutomaticallyRefreshCredentials): """EcsRamRoleCredential""" def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider): + refresh_fields = ( + 'access_key_id', + 'access_key_secret', + 'security_token', + 'expiration' + ) + super().__init__(expiration, provider, refresh_fields) self.access_key_id = access_key_id self.access_key_secret = access_key_secret self.security_token = security_token - self.expiration = expiration - self.provider = provider self.credential_type = ac.ECS_RAM_ROLE + def _refresh_credential(self): + credential = super()._refresh_credential() + if credential: + self.access_key_id = credential.access_key_id + self.access_key_secret = credential.access_key_secret + self.expiration = credential.expiration + self.security_token = credential.security_token + -class RamRoleArnCredential: +class RamRoleArnCredential(_AutomaticallyRefreshCredentials): """RamRoleArnCredential""" def __init__(self, access_key_id, access_key_secret, security_token, expiration, provider): - self.access_key_id = access_key_id - self.access_key_secret = access_key_secret - self.security_token = security_token - self.expiration = expiration - self.provider = provider - self.credential_type = ac.RAM_ROLE_ARN - self._REFRESH_FIELDS = ( + refresh_fields = ( 'access_key_id', 'access_key_secret', 'security_token', 'expiration' ) - - def _with_should_refresh(self): - return int(time.mktime(time.localtime())) >= (object.__getattribute__(self, 'expiration') - 180) - - def _get_new_credential(self): - return self.provider.get_credentials() + super().__init__(expiration, provider, refresh_fields) + self.access_key_id = access_key_id + self.access_key_secret = access_key_secret + self.security_token = security_token + self.credential_type = ac.RAM_ROLE_ARN def _refresh_credential(self): - if self._with_should_refresh(): - credential = self._get_new_credential() + credential = super()._refresh_credential() + if credential: self.access_key_id = credential.access_key_id self.access_key_secret = credential.access_key_secret self.expiration = credential.expiration self.security_token = credential.security_token - def __getattribute__(self, item): - if item in object.__getattribute__(self, '__dict__')['_REFRESH_FIELDS']: - self._refresh_credential() - return object.__getattribute__(self, item) - -class RsaKeyPairCredential: +class RsaKeyPairCredential(_AutomaticallyRefreshCredentials): def __init__(self, access_key_id, access_key_secret, expiration, provider): + refresh_fields = ( + 'access_key_id', + 'access_key_secret', + 'expiration' + ) + super().__init__(expiration, provider, refresh_fields) self.access_key_id = access_key_id self.access_key_secret = access_key_secret - self.expiration = expiration - self.provider = provider self.credential_type = ac.RSA_KEY_PAIR + def _refresh_credential(self): + credential = super()._refresh_credential() + if credential: + self.access_key_id = credential.access_key_id + self.access_key_secret = credential.access_key_secret + self.expiration = credential.expiration + class StsCredential: def __init__(self, access_key_id, access_key_secret, security_token): diff --git a/alibabacloud_credentials/models.py b/alibabacloud_credentials/models.py index 374589a..002ef0d 100644 --- a/alibabacloud_credentials/models.py +++ b/alibabacloud_credentials/models.py @@ -2,21 +2,20 @@ class Config(TeaModel): - def __init__(self, **kwargs): - super().__init__() - self.type = "" - self.access_key_id = "" - self.access_key_secret = "" - self.role_arn = "" - self.role_session_name = "" - self.public_key_id = "" - self.role_name = "" - self.private_key_file = "" - self.bearer_token = "" - self.security_token = "" - self.host = "" - self.timeout = 1000 - self.connect_timeout = 1000 - self.proxy = "" - for k, v in kwargs.items(): - setattr(self, k, v) + def __init__(self, type='', access_key_id='', access_key_secret='', role_arn='', role_session_name='', + public_key_id='', role_name='', private_key_file='', bearer_token='', security_token='', host='', + timeout=1000, connect_timeout=1000, proxy=''): + self.type = type + self.access_key_id = access_key_id + self.access_key_secret = access_key_secret + self.role_arn = role_arn + self.role_session_name = role_session_name + self.public_key_id = public_key_id + self.role_name = role_name + self.private_key_file = private_key_file + self.bearer_token = bearer_token + self.security_token = security_token + self.host = host + self.timeout = timeout + self.connect_timeout = connect_timeout + self.proxy = proxy diff --git a/alibabacloud_credentials/providers.py b/alibabacloud_credentials/providers.py index 4767814..ad384ab 100644 --- a/alibabacloud_credentials/providers.py +++ b/alibabacloud_credentials/providers.py @@ -110,11 +110,11 @@ def _create_credential(self, url=None): raise CredentialException(self.__ecs_metadata_fetch_error_msg + " HttpCode=" + str(response.status_code)) response.encoding = 'utf-8' dic = json.loads(response.text) - content_code = dic.Code - content_access_key_id = dic.AccessKeyId - content_access_key_secret = dic.AccessKeySecret - content_security_token = dic.SecurityToken - content_expiration = dic.Expiration + content_code = dic.get('Code') + content_access_key_id = dic.get('AccessKeyId') + content_access_key_secret = dic.get('AccessKeySecret') + content_security_token = dic.get('SecurityToken') + content_expiration = dic.get('Expiration') if content_code != "Success": raise CredentialException(self.__ecs_metadata_fetch_error_msg) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..5a3aa3c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,6 @@ +import os + +test_dir = os.path.dirname(__file__) + +txt_file = os.path.join(test_dir, 'private_key.txt') +ini_file = os.path.join(test_dir, 'tests.ini') diff --git a/tests/test_client.py b/tests/test_client.py index 2d557b4..6ece01c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,10 +4,11 @@ from alibabacloud_credentials.utils import auth_constant from alibabacloud_credentials.client import Client from alibabacloud_credentials import credentials +from alibabacloud_credentials import providers class TestCredentials(unittest.TestCase): - def test_client(self): + def test_ak_client(self): conf = Config() conf.type = auth_constant.ACCESS_KEY conf.access_key_id = '123456' @@ -22,3 +23,23 @@ def test_client(self): cred.get_access_key_id() except Exception as e: self.assertEqual(str(e), 'not found credentials') + + conf = Config(type='sts') + cred = Client(conf) + self.assertIsInstance(cred.cloud_credential, credentials.StsCredential) + + conf = Config(type='bearer') + cred = Client(conf) + self.assertIsInstance(cred.cloud_credential, credentials.BearerTokenCredential) + + conf = Config(type='ecs_ram_role') + self.assertIsInstance(Client.get_provider(conf), providers.EcsRamRoleCredentialProvider) + + conf = Config(type='ram_role_arn') + self.assertIsInstance(Client.get_provider(conf), providers.RamRoleArnCredentialProvider) + + conf = Config(type='rsa_key_pair') + self.assertIsInstance(Client.get_provider(conf), providers.RsaKeyPairCredentialProvider) + + conf = Config(type='test') + self.assertIsInstance(Client.get_provider(conf), providers.DefaultCredentialsProvider) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 4c78a8c..7b78e44 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -4,16 +4,27 @@ class TestCredentials(unittest.TestCase): - class AlibabaCredentialProvider: + class TestEcsRamRoleProvider: def get_credentials(self): - return credentials.RamRoleArnCredential("accessKeyId", "accessKeySecret", "securityToken", 10000, None) + return credentials.EcsRamRoleCredential("accessKeyId", "accessKeySecret", "securityToken", 100000000000, + None) + + class TestRamRoleArnProvider: + def get_credentials(self): + return credentials.RamRoleArnCredential("accessKeyId", "accessKeySecret", "securityToken", 100000000000, + None) + + class TestRsaKeyPairProvider: + def get_credentials(self): + return credentials.RsaKeyPairCredential("accessKeyId", "accessKeySecret", 100000000000, + None) def test_EcsRamRoleCredential(self): provider = providers.EcsRamRoleCredentialProvider("roleName") access_key_id = 'access_key_id' access_key_secret = 'access_key_secret' security_token = 'security_token' - expiration = 100 + expiration = 900000000000 cred = credentials.EcsRamRoleCredential( access_key_id, access_key_secret, @@ -25,10 +36,23 @@ def test_EcsRamRoleCredential(self): self.assertEqual('access_key_id', cred.access_key_id) self.assertEqual('access_key_secret', cred.access_key_secret) self.assertEqual('security_token', cred.security_token) - self.assertEqual(100, cred.expiration) + self.assertEqual(900000000000, cred.expiration) self.assertIsInstance(cred.provider, providers.EcsRamRoleCredentialProvider) self.assertEqual('ecs_ram_role', cred.credential_type) + cred = credentials.EcsRamRoleCredential( + access_key_id, + access_key_secret, + security_token, + 100, + self.TestEcsRamRoleProvider() + ) + # refresh token + self.assertEqual('accessKeyId', cred.access_key_id) + self.assertEqual('accessKeySecret', cred.access_key_secret) + self.assertEqual('securityToken', cred.security_token) + self.assertEqual(100000000000, cred.expiration) + def test_AccessKeyCredential(self): access_key_id = 'access_key_id' access_key_secret = 'access_key_secret' @@ -48,17 +72,33 @@ def test_BearerTokenCredential(self): def test_RamRoleArnCredential(self): access_key_id, access_key_secret, security_token, expiration = \ - 'access_key_id', 'access_key_secret', 'security_token', 64090527132000 - provider = self.AlibabaCredentialProvider() + 'access_key_id', 'access_key_secret', 'security_token', 640900000000 + provider = self.TestRamRoleArnProvider() cred = credentials.RamRoleArnCredential( access_key_id, access_key_secret, security_token, expiration, provider ) + self.assertEqual('access_key_id', cred.access_key_id) self.assertEqual('access_key_secret', cred.access_key_secret) self.assertEqual('security_token', cred.security_token) - self.assertEqual(64090527132000, cred.expiration) + self.assertEqual(640900000000, cred.expiration) + + access_key_id, access_key_secret, security_token, expiration = \ + 'access_key_id', 'access_key_secret', 'security_token', 6409 + provider = self.TestRamRoleArnProvider() + cred = credentials.RamRoleArnCredential( + access_key_id, access_key_secret, security_token, expiration, provider + ) + + # refresh token + self.assertTrue(cred._with_should_refresh()) + + self.assertEqual('accessKeyId', cred.access_key_id) + self.assertEqual('accessKeySecret', cred.access_key_secret) + self.assertEqual('securityToken', cred.security_token) + self.assertEqual(100000000000, cred.expiration) self.assertEqual('ram_role_arn', cred.credential_type) - self.assertIsInstance(cred.provider, self.AlibabaCredentialProvider) + self.assertIsInstance(cred.provider, self.TestRamRoleArnProvider) self.assertFalse(cred._with_should_refresh()) @@ -69,17 +109,29 @@ def test_RamRoleArnCredential(self): self.assertIsNotNone(cred) def test_RsaKeyPairCredential(self): - access_key_id, access_key_secret, expiration = 'access_key_id', 'access_key_secret', 100 + access_key_id, access_key_secret, expiration = 'access_key_id', 'access_key_secret', 90000000000 provider = providers.RsaKeyPairCredentialProvider(access_key_id, access_key_secret) cred = credentials.RsaKeyPairCredential( access_key_id, access_key_secret, expiration, provider ) self.assertEqual('access_key_id', cred.access_key_id) self.assertEqual('access_key_secret', cred.access_key_secret) - self.assertEqual(100, cred.expiration) + self.assertEqual(90000000000, cred.expiration) self.assertIsInstance(cred.provider, providers.RsaKeyPairCredentialProvider) self.assertEqual('rsa_key_pair', cred.credential_type) + cred = credentials.RsaKeyPairCredential( + access_key_id, + access_key_secret, + 900, + self.TestRsaKeyPairProvider() + ) + + # refresh token + self.assertEqual('accessKeyId', cred.access_key_id) + self.assertEqual('accessKeySecret', cred.access_key_secret) + self.assertEqual(100000000000, cred.expiration) + def test_StsCredential(self): access_key_id, access_key_secret, security_token =\ 'access_key_id', 'access_key_secret', 'security_token' diff --git a/tests/test_providers.py b/tests/test_providers.py index 473b6e5..e5e5dc9 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,15 +1,43 @@ import unittest import json +import time import requests from alibabacloud_credentials.credentials import AccessKeyCredential from alibabacloud_credentials import providers, models, credentials, exceptions from alibabacloud_credentials.utils import auth_util +from . import ini_file -ini_file = 'tests/tests.ini' +import threading +from http.server import HTTPServer, BaseHTTPRequestHandler + + +class Request(BaseHTTPRequestHandler): + def do_GET(self): + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(b'{"Code": "Success", "AccessKeyId": "ak",' + b' "Expiration": "3999-08-07T20:20:20Z", "Credentials":' + b' {"Expiration": "3999-08-07T20:20:20Z"}, "SessionAccessKey":' + b' {"Expiration": "3999-08-07T20:20:20Z"}}') + + +def run_server(): + server = HTTPServer(('localhost', 8888), Request) + server.serve_forever() class TestProviders(unittest.TestCase): + @classmethod + def setUpClass(cls): + server = threading.Thread(target=run_server) + server.setDaemon(True) + server.start() + + @staticmethod + def strftime(t): + return time.strftime('%Y-%m-%dT%H:%M:%SZ', time.localtime(t)) def test_EcsRamRoleCredentialProvider(self): prov = providers.EcsRamRoleCredentialProvider("roleName") @@ -25,7 +53,9 @@ def test_EcsRamRoleCredentialProvider(self): self.assertEqual("roleNameConfig", prov.role_name) self.assertEqual(2300, prov.timeout) # prov._create_credential(url='http://www.aliyun.com') - self.assertRaises(json.decoder.JSONDecodeError, prov._create_credential, url='http://www.aliyun.com') + cred = prov._create_credential(url='http://127.0.0.1:8888') + self.assertEqual('ak', cred.access_key_id) + self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration)) prov._get_role_name(url='http://www.aliyun.com') self.assertIsNotNone(prov.role_name) @@ -92,7 +122,8 @@ def test_RamRoleArnCredentialProvider(self): self.assertEqual('cn-hangzhou', prov.region_id) self.assertIsNone(prov.policy) - self.assertRaises(json.decoder.JSONDecodeError, prov._create_credentials, turl='http://www.aliyun.com') + cred = prov._create_credentials(turl='http://127.0.0.1:8888') + self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration)) def test_RsaKeyPairCredentialProvider(self): access_key_id, access_key_secret, region_id = \ @@ -113,7 +144,8 @@ def test_RsaKeyPairCredentialProvider(self): self.assertEqual('access_key_secret', prov.access_key_secret) self.assertEqual('cn-hangzhou', prov.region_id) - self.assertRaises(json.decoder.JSONDecodeError, prov._create_credential,turl='http://www.aliyun.com') + cred = prov._create_credential(turl='http://127.0.0.1:8888') + self.assertEqual('3999-08-07T20:20:20Z', self.strftime(cred.expiration)) def test_ProfileCredentialsProvider(self): prov = providers.ProfileCredentialsProvider(ini_file) diff --git a/tests/test_util.py b/tests/test_util.py index 813dfe4..40d959d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,8 +1,7 @@ from alibabacloud_credentials.utils import auth_util, parameter_helper import unittest - -txt_file = 'tests/private_key.txt' +from . import txt_file class TestCredentials(unittest.TestCase):