From cf0d2fee72e68940d0fde008b7496396a71920b5 Mon Sep 17 00:00:00 2001 From: yndu13 Date: Fri, 23 Jan 2026 18:22:38 +0800 Subject: [PATCH] fix: update OAuth token when has source profile --- .../provider/cli_profile.py | 64 +++-- tests/provider/test_cli_profile.py | 258 ++++++++++++++++++ 2 files changed, 293 insertions(+), 29 deletions(-) diff --git a/alibabacloud_credentials/provider/cli_profile.py b/alibabacloud_credentials/provider/cli_profile.py index e7a32fd..b7c52af 100644 --- a/alibabacloud_credentials/provider/cli_profile.py +++ b/alibabacloud_credentials/provider/cli_profile.py @@ -555,39 +555,45 @@ async def _update_oauth_tokens_async(self, refresh_token: str, access_token: str security_token: str, access_token_expire: int, sts_expire: int) -> None: """异步更新 OAuth 令牌并写回配置文件""" - try: - with self._file_lock: - cfg_path = self._profile_file - conf = await _load_config_async(cfg_path) + def _find_source_oauth_profile(config: dict, profile_name: str) -> dict: + profiles = config.get('profiles', []) + profile = next((p for p in profiles if p.get('name') == profile_name), None) + if not profile: + raise CredentialException(f"unable to get profile with name '{profile_name}' from cli credentials file.") + + if profile.get('mode') == 'OAuth': + return profile + elif source_profile := profile.get('source_profile'): + return _find_source_oauth_profile(config, source_profile) + + raise CredentialException(f"unable to get OAuth profile with name '{profile_name}' from cli credentials file.") + + with self._file_lock: + try: + # 读取现有配置 + config = await _load_config_async(self._profile_file) # 找到当前 profile 并更新 OAuth 令牌 - profile_name = self._profile_name + profile_name = self._profile_name or config.get('current') if not profile_name: - profile_name = conf.get('current') - profiles = conf.get('profiles', []) - profile_tag = False - for profile in profiles: - if profile.get('name') == profile_name: - profile_tag = True - # 更新 OAuth 相关字段 - profile['oauth_refresh_token'] = refresh_token - profile['oauth_access_token'] = access_token - profile['oauth_access_token_expire'] = access_token_expire - # 更新 STS 凭据 - profile['access_key_id'] = access_key - profile['access_key_secret'] = secret - profile['sts_token'] = security_token - profile['sts_expiration'] = sts_expire - break - - if not profile_tag: - raise CredentialException(f"Profile '{profile_name}' not found in config file") - - # 异步写回配置文件 - await self._write_configuration_to_file_with_lock_async(cfg_path, conf) + raise CredentialException(f"unable to get profile to updated.") - except Exception as e: - raise CredentialException(f"failed to update OAuth tokens in config file: {e}") + source_profile = _find_source_oauth_profile(config, profile_name) + + # 更新 OAuth 令牌 + source_profile['oauth_refresh_token'] = refresh_token + source_profile['oauth_access_token'] = access_token + source_profile['oauth_access_token_expire'] = access_token_expire + # 更新 STS 凭据 + source_profile['access_key_id'] = access_key + source_profile['access_key_secret'] = secret + source_profile['sts_token'] = security_token + source_profile['sts_expiration'] = sts_expire + + await self._write_configuration_to_file_with_lock_async(self._profile_file, config) + + except Exception as e: + raise CredentialException(f"failed to update OAuth tokens in config file: {e}") def _get_oauth_token_update_callback_async(self) -> OAuthTokenUpdateCallbackAsync: """获取异步 OAuth 令牌更新回调函数""" diff --git a/tests/provider/test_cli_profile.py b/tests/provider/test_cli_profile.py index ffbd802..f765b8d 100644 --- a/tests/provider/test_cli_profile.py +++ b/tests/provider/test_cli_profile.py @@ -1446,3 +1446,261 @@ async def run_test(): finally: import shutil shutil.rmtree(temp_dir, ignore_errors=True) + + def test_chainable_ram_role_with_oauth_source_profile_sync(self): + """测试 ChainableRamRoleArn 使用 OAuth source profile 时,token 更新到正确的 profile (同步)""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + # 配置:OAuth profile <- ChainableRamRoleArn profile + test_config = { + "current": "chainable_oauth", + "profiles": [ + { + "name": "oauth_source", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + }, + { + "name": "chainable_oauth", + "mode": "ChainableRamRoleArn", + "source_profile": "oauth_source", + "ram_role_arn": "acs:ram::123456789012:role/test-role", + "ram_session_name": "test-session" + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="chainable_oauth", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 更新令牌(模拟 OAuth 刷新) + new_refresh_token = "new_refresh_token" + new_access_token = "new_access_token" + new_access_key = "new_access_key" + new_secret = "new_secret" + new_security_token = "new_security_token" + new_expire_time = int(time.time()) + 7200 + new_sts_expire = int(time.time()) + 10800 + + provider._update_oauth_tokens(new_refresh_token, new_access_token, new_access_key, new_secret, + new_security_token, new_expire_time, new_sts_expire) + + # 验证配置文件已更新到 OAuth source profile + with open(config_path, 'r') as f: + updated_config = json.load(f) + + # 找到 oauth_source profile + oauth_profile = next((p for p in updated_config['profiles'] if p.get('name') == 'oauth_source'), None) + self.assertIsNotNone(oauth_profile, "OAuth source profile should exist") + + # 验证 OAuth tokens 更新到了 source profile + self.assertEqual(oauth_profile['oauth_refresh_token'], new_refresh_token) + self.assertEqual(oauth_profile['oauth_access_token'], new_access_token) + self.assertEqual(oauth_profile['access_key_id'], new_access_key) + self.assertEqual(oauth_profile['access_key_secret'], new_secret) + self.assertEqual(oauth_profile['sts_token'], new_security_token) + self.assertEqual(oauth_profile['oauth_access_token_expire'], new_expire_time) + self.assertEqual(oauth_profile['sts_expiration'], new_sts_expire) + + # 验证 chainable profile 没有被错误更新 + chainable_profile = next((p for p in updated_config['profiles'] if p.get('name') == 'chainable_oauth'), None) + self.assertIsNotNone(chainable_profile, "Chainable profile should exist") + self.assertNotIn('oauth_refresh_token', chainable_profile, "Chainable profile should not have oauth_refresh_token") + self.assertNotIn('oauth_access_token', chainable_profile, "Chainable profile should not have oauth_access_token") + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_chainable_ram_role_with_oauth_source_profile_async(self): + """测试 ChainableRamRoleArn 使用 OAuth source profile 时,token 更新到正确的 profile (异步)""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + # 配置:OAuth profile <- ChainableRamRoleArn profile + test_config = { + "current": "chainable_oauth", + "profiles": [ + { + "name": "oauth_source", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + }, + { + "name": "chainable_oauth", + "mode": "ChainableRamRoleArn", + "source_profile": "oauth_source", + "ram_role_arn": "acs:ram::123456789012:role/test-role", + "ram_session_name": "test-session" + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="chainable_oauth", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 更新令牌(模拟 OAuth 刷新) + new_refresh_token = "async_new_refresh_token" + new_access_token = "async_new_access_token" + new_access_key = "async_new_access_key" + new_secret = "async_new_secret" + new_security_token = "async_new_security_token" + new_expire_time = int(time.time()) + 7200 + new_sts_expire = int(time.time()) + 10800 + + async def run_test(): + await provider._update_oauth_tokens_async(new_refresh_token, new_access_token, new_access_key, + new_secret, new_security_token, new_expire_time, + new_sts_expire) + + # 使用 asyncio.run() 替代 get_event_loop() + asyncio.run(run_test()) + + # 验证配置文件已更新到 OAuth source profile + with open(config_path, 'r') as f: + updated_config = json.load(f) + + # 找到 oauth_source profile + oauth_profile = next((p for p in updated_config['profiles'] if p.get('name') == 'oauth_source'), None) + self.assertIsNotNone(oauth_profile, "OAuth source profile should exist") + + # 验证 OAuth tokens 更新到了 source profile + self.assertEqual(oauth_profile['oauth_refresh_token'], new_refresh_token) + self.assertEqual(oauth_profile['oauth_access_token'], new_access_token) + self.assertEqual(oauth_profile['access_key_id'], new_access_key) + self.assertEqual(oauth_profile['access_key_secret'], new_secret) + self.assertEqual(oauth_profile['sts_token'], new_security_token) + self.assertEqual(oauth_profile['oauth_access_token_expire'], new_expire_time) + self.assertEqual(oauth_profile['sts_expiration'], new_sts_expire) + + # 验证 chainable profile 没有被错误更新 + chainable_profile = next((p for p in updated_config['profiles'] if p.get('name') == 'chainable_oauth'), None) + self.assertIsNotNone(chainable_profile, "Chainable profile should exist") + self.assertNotIn('oauth_refresh_token', chainable_profile, "Chainable profile should not have oauth_refresh_token") + self.assertNotIn('oauth_access_token', chainable_profile, "Chainable profile should not have oauth_access_token") + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_nested_chainable_ram_role_with_oauth_source_profile_sync(self): + """测试多层嵌套 ChainableRamRoleArn 使用 OAuth source profile 时,token 更新到正确的 profile (同步)""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + # 配置:OAuth profile <- ChainableRamRoleArn level1 <- ChainableRamRoleArn level2 + test_config = { + "current": "chainable_level2", + "profiles": [ + { + "name": "oauth_source", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + }, + { + "name": "chainable_level1", + "mode": "ChainableRamRoleArn", + "source_profile": "oauth_source", + "ram_role_arn": "acs:ram::123456789012:role/level1-role", + "ram_session_name": "level1-session" + }, + { + "name": "chainable_level2", + "mode": "ChainableRamRoleArn", + "source_profile": "chainable_level1", + "ram_role_arn": "acs:ram::123456789012:role/level2-role", + "ram_session_name": "level2-session" + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="chainable_level2", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 更新令牌(模拟 OAuth 刷新) + new_refresh_token = "nested_refresh_token" + new_access_token = "nested_access_token" + new_access_key = "nested_access_key" + new_secret = "nested_secret" + new_security_token = "nested_security_token" + new_expire_time = int(time.time()) + 7200 + new_sts_expire = int(time.time()) + 10800 + + provider._update_oauth_tokens(new_refresh_token, new_access_token, new_access_key, new_secret, + new_security_token, new_expire_time, new_sts_expire) + + # 验证配置文件已更新到最深层的 OAuth source profile + with open(config_path, 'r') as f: + updated_config = json.load(f) + + # 找到 oauth_source profile + oauth_profile = next((p for p in updated_config['profiles'] if p.get('name') == 'oauth_source'), None) + self.assertIsNotNone(oauth_profile, "OAuth source profile should exist") + + # 验证 OAuth tokens 更新到了最深层的 source profile + self.assertEqual(oauth_profile['oauth_refresh_token'], new_refresh_token) + self.assertEqual(oauth_profile['oauth_access_token'], new_access_token) + self.assertEqual(oauth_profile['access_key_id'], new_access_key) + self.assertEqual(oauth_profile['access_key_secret'], new_secret) + self.assertEqual(oauth_profile['sts_token'], new_security_token) + self.assertEqual(oauth_profile['oauth_access_token_expire'], new_expire_time) + self.assertEqual(oauth_profile['sts_expiration'], new_sts_expire) + + # 验证中间的 chainable profiles 没有被错误更新 + level1_profile = next((p for p in updated_config['profiles'] if p.get('name') == 'chainable_level1'), None) + self.assertIsNotNone(level1_profile, "Level1 chainable profile should exist") + self.assertNotIn('oauth_refresh_token', level1_profile, "Level1 profile should not have oauth_refresh_token") + + level2_profile = next((p for p in updated_config['profiles'] if p.get('name') == 'chainable_level2'), None) + self.assertIsNotNone(level2_profile, "Level2 chainable profile should exist") + self.assertNotIn('oauth_refresh_token', level2_profile, "Level2 profile should not have oauth_refresh_token") + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True)