diff --git a/alibabacloud_credentials/models.py b/alibabacloud_credentials/models.py index 5a91722..98dcb37 100644 --- a/alibabacloud_credentials/models.py +++ b/alibabacloud_credentials/models.py @@ -30,6 +30,7 @@ def __init__( connect_timeout: int = 1000, proxy: str = '', credentials_uri: str = '', + disable_imds_v1: bool = False, enable_imds_v2: bool = False, metadata_token_duration: int = 21600, sts_endpoint: str = None @@ -62,6 +63,7 @@ def __init__( self.private_key_file = private_key_file # role name self.role_name = role_name + self.disable_imds_v1 = disable_imds_v1 self.enable_imds_v2 = enable_imds_v2 self.metadata_token_duration = metadata_token_duration # credential type @@ -108,6 +110,8 @@ def to_map(self): result['privateKeyFile'] = self.private_key_file if self.role_name is not None: result['roleName'] = self.role_name + if self.disable_imds_v1 is not None: + result['disableIMDSv1'] = self.disable_imds_v1 if self.enable_imds_v2 is not None: result['enableIMDSv2'] = self.enable_imds_v2 if self.metadata_token_duration is not None: @@ -158,6 +162,8 @@ def from_map(self, m: dict = None): self.private_key_file = m.get('privateKeyFile') if m.get('roleName') is not None: self.role_name = m.get('roleName') + if m.get('disableIMDSv1') is not None: + self.disable_imds_v1 = m.get('disableIMDSv1') if m.get('enableIMDSv2') is not None: self.enable_imds_v2 = m.get('enableIMDSv2') if m.get('metadataTokenDuration') is not None: diff --git a/alibabacloud_credentials/providers.py b/alibabacloud_credentials/providers.py index 5f8ee36..ca2073a 100644 --- a/alibabacloud_credentials/providers.py +++ b/alibabacloud_credentials/providers.py @@ -29,8 +29,7 @@ def __init__(self, config=None): self.role_session_name = config.role_session_name self.public_key_id = config.public_key_id self.role_name = config.role_name - self.enable_imds_v2 = config.enable_imds_v2 - self.metadata_token_duration = config.metadata_token_duration + self.disable_imds_v1 = config.disable_imds_v1 self.oidc_provider_arn = config.oidc_provider_arn self.oidc_token_file_path = config.oidc_token_file_path self.private_key_file = config.private_key_file @@ -114,25 +113,30 @@ def __init__(self, role_name=None, config=None): self.__ecs_metadata_token_fetch_error_msg = "Failed to get token from ECS Metadata Service." self.__metadata_service_host = "100.100.100.200" self._set_arg('role_name', role_name) - self.__metadata_token = None - self.__stale_time = 0 - self.enable_imds_v2 = au.environment_ecs_meta_data_imds_v2_enable and au.environment_ecs_meta_data_imds_v2_enable.lower() == 'true' - self.metadata_token_duration = self.default_metadata_token_duration + self.disable_imds_v1 = au.environment_imds_v1_disabled and au.environment_imds_v1_disabled.lower() == 'true' + if isinstance(config, Config): - self.enable_imds_v2 = config.enable_imds_v2 - self.metadata_token_duration = config.metadata_token_duration + self.disable_imds_v1 = config.disable_imds_v1 is not None and config.disable_imds_v1 == True def _get_role_name(self, url=None): - url = url if url else f'http://{self.__metadata_service_host}{self.__url_in_ecs_metadata}' - response = requests.get(url, timeout=self.timeout / 1000) + tea_request = ph.get_new_request() + tea_request.headers['host'] = url if url else self.__metadata_service_host + metadata_token = self._get_metadata_token(url) + if metadata_token is not None: + tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + response = TeaCore.do_action(tea_request) if response.status_code != 200: raise CredentialException(self.__ecs_metadata_fetch_error_msg + " HttpCode=" + str(response.status_code)) - response.encoding = 'utf-8' - self.role_name = response.text + self.role_name = response.body.decode('utf-8') async def _get_role_name_async(self, url=None): tea_request = ph.get_new_request() tea_request.headers['host'] = url if url else self.__metadata_service_host + metadata_token = await self._get_metadata_token_async(url) + if metadata_token is not None: + tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token if not url: tea_request.pathname = self.__url_in_ecs_metadata response = await TeaCore.async_do_action(tea_request) @@ -140,45 +144,47 @@ async def _get_role_name_async(self, url=None): raise CredentialException(self.__ecs_metadata_fetch_error_msg + " HttpCode=" + str(response.status_code)) self.role_name = response.body.decode('utf-8') - def _need_to_refresh_token(self): - return int(time.mktime(time.localtime())) >= self.__stale_time - def _get_metadata_token(self, url=None): - if self._need_to_refresh_token(): - tmp_time = int(time.mktime(time.localtime())) + self.metadata_token_duration - tea_request = ph.get_new_request() - tea_request.method = 'PUT' - tea_request.headers['host'] = url if url else self.__metadata_service_host - tea_request.headers['X-aliyun-ecs-metadata-token-ttl-seconds'] = str(self.metadata_token_duration) - if not url: - tea_request.pathname = self.__url_in_ecs_metadata_token + tea_request = ph.get_new_request() + tea_request.method = 'PUT' + tea_request.headers['host'] = url if url else self.__metadata_service_host + tea_request.headers['X-aliyun-ecs-metadata-token-ttl-seconds'] = str(self.default_metadata_token_duration) + if not url: + tea_request.pathname = self.__url_in_ecs_metadata_token + try: response = TeaCore.do_action(tea_request) if response.status_code != 200: raise CredentialException( self.__ecs_metadata_token_fetch_error_msg + " HttpCode=" + str(response.status_code)) - self.__stale_time = tmp_time - self.__metadata_token = response.body.decode('utf-8') + return response.body.decode('utf-8') + except Exception as e: + if self.disable_imds_v1: + raise e + return None async def _get_metadata_token_async(self, url=None): - if self._need_to_refresh_token(): - tmp_time = int(time.mktime(time.localtime())) + self.metadata_token_duration - tea_request = ph.get_new_request() - tea_request.method = 'PUT' - tea_request.headers['host'] = url if url else self.__metadata_service_host - tea_request.headers['X-aliyun-ecs-metadata-token-ttl-seconds'] = str(self.metadata_token_duration) - if not url: - tea_request.pathname = self.__url_in_ecs_metadata_token + tea_request = ph.get_new_request() + tea_request.method = 'PUT' + tea_request.headers['host'] = url if url else self.__metadata_service_host + tea_request.headers['X-aliyun-ecs-metadata-token-ttl-seconds'] = str(self.default_metadata_token_duration) + if not url: + tea_request.pathname = self.__url_in_ecs_metadata_token + try: response = await TeaCore.async_do_action(tea_request) if response.status_code != 200: raise CredentialException( self.__ecs_metadata_token_fetch_error_msg + " HttpCode=" + str(response.status_code)) - self.__stale_time = tmp_time - self.__metadata_token = response.body.decode('utf-8') + return response.body.decode('utf-8') + except Exception as e: + if self.disable_imds_v1: + raise e + return None - def _create_credential(self, url=None, metadata_token=None): + def _create_credential(self, url=None): tea_request = ph.get_new_request() tea_request.headers['host'] = url if url else self.__metadata_service_host - if metadata_token: + metadata_token = self._get_metadata_token(url) + if metadata_token is not None: tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token if not url: tea_request.pathname = self.__url_in_ecs_metadata + self.role_name @@ -208,15 +214,13 @@ def _create_credential(self, url=None, metadata_token=None): def get_credentials(self): if self.role_name == "": self._get_role_name() - if self.enable_imds_v2: - self._get_metadata_token() - return self._create_credential(metadata_token=self.__metadata_token) return self._create_credential() - async def _create_credential_async(self, url=None, metadata_token=None): + async def _create_credential_async(self, url=None): tea_request = ph.get_new_request() tea_request.headers['host'] = url if url else self.__metadata_service_host - if metadata_token: + metadata_token = await self._get_metadata_token_async(url) + if metadata_token is not None: tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token if not url: tea_request.pathname = self.__url_in_ecs_metadata + self.role_name @@ -247,9 +251,6 @@ async def _create_credential_async(self, url=None, metadata_token=None): async def get_credentials_async(self): if self.role_name == "": await self._get_role_name_async() - if self.enable_imds_v2: - await self._get_metadata_token_async() - return await self._create_credential_async(metadata_token=self.__metadata_token) return await self._create_credential_async() diff --git a/alibabacloud_credentials/utils/auth_util.py b/alibabacloud_credentials/utils/auth_util.py index 189b15f..713f058 100644 --- a/alibabacloud_credentials/utils/auth_util.py +++ b/alibabacloud_credentials/utils/auth_util.py @@ -5,7 +5,7 @@ environment_access_key_secret = os.environ.get('ALIBABA_CLOUD_ACCESS_KEY_SECRET') environment_security_token = os.environ.get('ALIBABA_CLOUD_SECURITY_TOKEN') environment_ECSMeta_data = os.environ.get('ALIBABA_CLOUD_ECS_METADATA') -environment_ecs_meta_data_imds_v2_enable = os.environ.get('ALIBABA_CLOUD_ECS_IMDSV2_ENABLE') +environment_imds_v1_disabled = os.environ.get('ALIBABA_CLOUD_IMDSV1_DISABLED') environment_credentials_file = os.environ.get('ALIBABA_CLOUD_CREDENTIALS_FILE') environment_oidc_token_file = os.environ.get('ALIBABA_CLOUD_OIDC_TOKEN_FILE') environment_role_arn = os.environ.get('ALIBABA_CLOUD_ROLE_ARN') diff --git a/tests/test_model.py b/tests/test_model.py index dd5c888..bc20c16 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,6 +9,8 @@ def test_model_config(self): self.assertEqual('', conf1.access_key_secret) self.assertEqual('', conf1.role_name) self.assertEqual(1000, conf1.timeout) + self.assertEqual(1000, conf1.connect_timeout) + self.assertFalse(conf1.disable_imds_v1) self.assertIsNone(conf1.sts_endpoint) conf1.timeout = 0 diff --git a/tests/test_providers.py b/tests/test_providers.py index 6bca466..4fbc6c2 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -4,6 +4,7 @@ from Tea.exceptions import RetryError from alibabacloud_credentials import providers, models, credentials, exceptions +from alibabacloud_credentials.exceptions import CredentialException from alibabacloud_credentials.utils import auth_util from . import ini_file @@ -30,17 +31,39 @@ def do_PUT(self): self.wfile.write(b'token') +class RequestError(BaseHTTPRequestHandler): + def do_GET(self): + self.send_response(500) + self.send_header('Content-type', 'text/plain') + self.end_headers() + self.wfile.write(b'error') + + def do_PUT(self): + self.send_response(500) + self.send_header('Content-type', 'text/plain') + self.end_headers() + self.wfile.write(b'error') + + def run_server(): server = HTTPServer(('localhost', 8888), Request) server.serve_forever() +def run_server_error(): + server_error = HTTPServer(('localhost', 9999), RequestError) + server_error.serve_forever() + + class TestProviders(unittest.TestCase): @classmethod def setUpClass(cls): server = threading.Thread(target=run_server) server.setDaemon(True) server.start() + server_error = threading.Thread(target=run_server_error) + server_error.setDaemon(True) + server_error.start() @staticmethod def strftime(t): @@ -51,25 +74,25 @@ def test_EcsRamRoleCredentialProvider(self): self.assertIsNotNone(prov) self.assertEqual("roleName", prov.role_name) - auth_util.environment_ecs_meta_data_imds_v2_enable = 'False' + auth_util.environment_imds_v1_disabled = 'False' prov = providers.EcsRamRoleCredentialProvider("roleName") self.assertIsNotNone(prov) self.assertEqual("roleName", prov.role_name) - self.assertFalse(prov.enable_imds_v2) + self.assertFalse(prov.disable_imds_v1) - auth_util.environment_ecs_meta_data_imds_v2_enable = '1' + auth_util.environment_imds_v1_disabled = '1' prov = providers.EcsRamRoleCredentialProvider("roleName") self.assertIsNotNone(prov) self.assertEqual("roleName", prov.role_name) - self.assertFalse(prov.enable_imds_v2) + self.assertFalse(prov.disable_imds_v1) - auth_util.environment_ecs_meta_data_imds_v2_enable = 'True' + auth_util.environment_imds_v1_disabled = 'True' prov = providers.EcsRamRoleCredentialProvider("roleName") self.assertIsNotNone(prov) self.assertEqual("roleName", prov.role_name) - self.assertTrue(prov.enable_imds_v2) + self.assertTrue(prov.disable_imds_v1) - auth_util.environment_ecs_meta_data_imds_v2_enable = None + auth_util.environment_imds_v1_disabled = None cfg = models.Config() cfg.role_name = "roleNameConfig" @@ -79,29 +102,61 @@ def test_EcsRamRoleCredentialProvider(self): self.assertIsNotNone(prov) self.assertEqual("roleNameConfig", prov.role_name) self.assertEqual(2300, prov.timeout) + token = prov._get_metadata_token(url='127.0.0.1:8888') + self.assertEqual('token', token) + cred = prov._create_credential(url='127.0.0.1:8888') self.assertEqual('ak', cred.access_key_id) prov._get_role_name(url='http://127.0.0.1:8888') self.assertIsNotNone(prov.role_name) - cfg.enable_imds_v2 = True - cfg.metadata_token_duration = 180 + # request error + token = prov._get_metadata_token(url='127.0.0.1:9999') + self.assertIsNone(token) + try: + prov._create_credential(url='127.0.0.1:9999') + self.fail() + except CredentialException as e: + self.assertEqual('Failed to get RAM session credentials from ECS metadata service. HttpCode=500', e.message) + try: + prov._get_role_name(url='http://127.0.0.1:9999') + self.fail() + except CredentialException as e: + self.assertEqual('Failed to get RAM session credentials from ECS metadata service. HttpCode=500', e.message) + + + cfg.disable_imds_v1 = True prov = providers.EcsRamRoleCredentialProvider(config=cfg) self.assertIsNotNone(prov) - self.assertTrue(prov.enable_imds_v2) - self.assertEqual(180, prov.metadata_token_duration) + self.assertTrue(prov.disable_imds_v1) self.assertEqual("roleNameConfig", prov.role_name) self.assertEqual(2300, prov.timeout) prov._get_metadata_token(url='127.0.0.1:8888') cred = prov._create_credential(url='127.0.0.1:8888') - self.assertEqual("token", getattr(prov, '_EcsRamRoleCredentialProvider__metadata_token')) - self.assertNotEqual(0, getattr(prov, '_EcsRamRoleCredentialProvider__stale_time')) self.assertEqual('ak', cred.access_key_id) prov._get_role_name(url='http://127.0.0.1:8888') self.assertIsNotNone(prov.role_name) + # request error + try: + prov._get_metadata_token(url='127.0.0.1:9999') + self.fail() + except CredentialException as e: + self.assertEqual('Failed to get token from ECS Metadata Service. HttpCode=500', e.message) + try: + prov._create_credential(url='127.0.0.1:9999') + self.fail() + except CredentialException as e: + self.assertEqual('Failed to get token from ECS Metadata Service. HttpCode=500', e.message) + try: + prov._get_role_name(url='http://127.0.0.1:9999') + self.fail() + except CredentialException as e: + self.assertEqual('Failed to get token from ECS Metadata Service. HttpCode=500', e.message) + + def test_EcsRamRoleCredentialProvider_async(self): async def main(): prov = providers.EcsRamRoleCredentialProvider("roleName")