diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index d9983c29d6b4..bb49d3811472 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2927,11 +2927,19 @@ dependencies = [ "codex-api", "codex-aws-auth", "codex-client", + "codex-feedback", "codex-login", "codex-model-provider-info", + "codex-models-manager", + "codex-otel", "codex-protocol", + "codex-response-debug-context", "http 1.4.0", "pretty_assertions", + "serde_json", + "tokio", + "tracing", + "wiremock", ] [[package]] @@ -2955,32 +2963,21 @@ dependencies = [ name = "codex-models-manager" version = "0.0.0" dependencies = [ - "base64 0.22.1", + "async-trait", "chrono", - "codex-api", "codex-app-server-protocol", "codex-collaboration-mode-templates", - "codex-config", - "codex-feedback", "codex-login", - "codex-model-provider", - "codex-model-provider-info", "codex-otel", "codex-protocol", - "codex-response-debug-context", - "codex-utils-absolute-path", "codex-utils-output-truncation", "codex-utils-template", - "core_test_support", - "http 1.4.0", "pretty_assertions", "serde", "serde_json", "tempfile", "tokio", "tracing", - "tracing-subscriber", - "wiremock", ] [[package]] diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 347767e781da..4d7f3c9a5ac9 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -305,6 +305,7 @@ use codex_mcp::resolve_oauth_scopes; use codex_model_provider::ProviderAccountError; use codex_model_provider::create_model_provider; use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig; +use codex_models_manager::collaboration_mode_presets::builtin_collaboration_mode_presets; use codex_protocol::ThreadId; use codex_protocol::config_types::CollaborationMode; use codex_protocol::config_types::ForcedLoginMethod; @@ -792,14 +793,12 @@ impl CodexMessageProcessor { collaboration_modes_config: CollaborationModesConfig, ) -> CollaborationMode { if collaboration_mode.settings.developer_instructions.is_none() - && let Some(instructions) = self - .thread_manager - .get_models_manager() - .list_collaboration_modes_for_config(collaboration_modes_config) - .into_iter() - .find(|preset| preset.mode == Some(collaboration_mode.mode)) - .and_then(|preset| preset.developer_instructions.flatten()) - .filter(|instructions| !instructions.is_empty()) + && let Some(instructions) = + builtin_collaboration_mode_presets(collaboration_modes_config) + .into_iter() + .find(|preset| preset.mode == Some(collaboration_mode.mode)) + .and_then(|preset| preset.developer_instructions.flatten()) + .filter(|instructions| !instructions.is_empty()) { collaboration_mode.settings.developer_instructions = Some(instructions); } diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index f378afad2c0c..2481ecd6fe9e 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -61,7 +61,7 @@ use codex_core::config::find_codex_home; use codex_features::FEATURES; use codex_features::Stage; use codex_features::is_known_feature_key; -use codex_models_manager::AuthManager; +use codex_login::AuthManager; use codex_models_manager::bundled_models_response; use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig; use codex_models_manager::manager::RefreshStrategy; diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index eb3876f60e7a..1fb2f42f2e3c 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -49,7 +49,7 @@ use crate::session::session::Session; use crate::session::turn_context::TurnContext; use crate::session::turn_context::TurnEnvironment; use codex_login::AuthManager; -use codex_models_manager::manager::ModelsManager; +use codex_models_manager::manager::SharedModelsManager; use codex_protocol::error::CodexErr; use codex_protocol::protocol::InitialHistory; @@ -65,7 +65,7 @@ use crate::session::completed_session_loop_termination; pub(crate) async fn run_codex_thread_interactive( config: Config, auth_manager: Arc, - models_manager: Arc, + models_manager: SharedModelsManager, parent_session: Arc, parent_ctx: Arc, cancel_token: CancellationToken, @@ -165,7 +165,7 @@ pub(crate) async fn run_codex_thread_interactive( pub(crate) async fn run_codex_thread_one_shot( config: Config, auth_manager: Arc, - models_manager: Arc, + models_manager: SharedModelsManager, input: Vec, parent_session: Arc, parent_ctx: Arc, diff --git a/codex-rs/core/src/guardian/tests.rs b/codex-rs/core/src/guardian/tests.rs index c679f605a8f4..c78884bcea72 100644 --- a/codex-rs/core/src/guardian/tests.rs +++ b/codex-rs/core/src/guardian/tests.rs @@ -151,11 +151,11 @@ async fn guardian_test_session_and_turn_with_base_url( config.model_provider.base_url = Some(format!("{base_url}/v1")); config.user_instructions = None; let config = Arc::new(config); - let models_manager = Arc::new(test_support::models_manager_with_provider( + let models_manager = test_support::models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn.config = Arc::clone(&config); turn.provider = create_model_provider(config.model_provider.clone(), turn.auth_manager.clone()); @@ -1134,11 +1134,11 @@ async fn guardian_review_request_layout_matches_model_visible_request_snapshot() config.cwd = temp_cwd.abs(); config.model_provider.base_url = Some(format!("{}/v1", server.uri())); let config = Arc::new(config); - let models_manager = Arc::new(test_support::models_manager_with_provider( + let models_manager = test_support::models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn.config = Arc::clone(&config); turn.provider = create_model_provider(config.model_provider.clone(), turn.auth_manager.clone()); @@ -1606,11 +1606,11 @@ async fn guardian_review_surfaces_responses_api_errors_in_rejection_reason() -> config.model_provider.base_url = Some(format!("{}/v1", server.uri())); config.user_instructions = None; let config = Arc::new(config); - let models_manager = Arc::new(test_support::models_manager_with_provider( + let models_manager = test_support::models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); Arc::get_mut(&mut session) .expect("session should be uniquely owned") .services diff --git a/codex-rs/core/src/mcp_tool_call_tests.rs b/codex-rs/core/src/mcp_tool_call_tests.rs index 522686446b01..7dcc1eabe68d 100644 --- a/codex-rs/core/src/mcp_tool_call_tests.rs +++ b/codex-rs/core/src/mcp_tool_call_tests.rs @@ -3,6 +3,7 @@ use crate::config::ConfigBuilder; use crate::session::tests::make_session_and_context; use crate::session::tests::make_session_and_context_with_rx; use crate::state::ActiveTurn; +use crate::test_support::models_manager_with_provider; use codex_config::CONFIG_TOML_FILE; use codex_config::config_toml::ConfigToml; use codex_config::types::AppConfig; @@ -1491,11 +1492,11 @@ async fn guardian_mode_skips_auto_when_annotations_do_not_require_approval() { config.model_provider.base_url = Some(format!("{}/v1", server.uri())); config.approvals_reviewer = ApprovalsReviewer::AutoReview; let config = Arc::new(config); - let models_manager = Arc::new(crate::test_support::models_manager_with_provider( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn_context.config = Arc::clone(&config); turn_context.provider = create_model_provider( @@ -1768,11 +1769,11 @@ async fn guardian_mode_mcp_denial_returns_rationale_message() { config.model_provider.base_url = Some(format!("{}/v1", server.uri())); config.approvals_reviewer = ApprovalsReviewer::AutoReview; let config = Arc::new(config); - let models_manager = Arc::new(crate::test_support::models_manager_with_provider( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn_context.config = Arc::clone(&config); turn_context.provider = create_model_provider( @@ -2231,11 +2232,11 @@ async fn approve_mode_routes_arc_ask_user_to_guardian_when_guardian_reviewer_is_ config.model_provider.base_url = Some(format!("{}/v1", server.uri())); config.approvals_reviewer = ApprovalsReviewer::AutoReview; let config = Arc::new(config); - let models_manager = Arc::new(crate::test_support::models_manager_with_provider( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn_context.config = Arc::clone(&config); turn_context.provider = create_model_provider( diff --git a/codex-rs/core/src/mcp_tool_exposure_test.rs b/codex-rs/core/src/mcp_tool_exposure_test.rs index 3372291df9ac..18bb97642ab4 100644 --- a/codex-rs/core/src/mcp_tool_exposure_test.rs +++ b/codex-rs/core/src/mcp_tool_exposure_test.rs @@ -6,7 +6,7 @@ use codex_features::Feature; use codex_features::Features; use codex_mcp::CODEX_APPS_MCP_SERVER_NAME; use codex_mcp::ToolInfo; -use codex_models_manager::manager::ModelsManager; +use codex_models_manager::test_support::construct_model_info_offline_for_tests; use codex_protocol::config_types::WebSearchMode; use codex_protocol::config_types::WindowsSandboxLevel; use codex_protocol::protocol::SandboxPolicy; @@ -93,10 +93,8 @@ fn numbered_mcp_tools(count: usize) -> HashMap { async fn tools_config_for_mcp_tool_exposure(search_tool: bool) -> ToolsConfig { let config = test_config().await; - let model_info = ModelsManager::construct_model_info_offline_for_tests( - "gpt-5.4", - &config.to_models_manager_config(), - ); + let model_info = + construct_model_info_offline_for_tests("gpt-5.4", &config.to_models_manager_config()); let features = Features::with_defaults(); let available_models = Vec::new(); let mut tools_config = ToolsConfig::new(&ToolsConfigParams { diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index e5c675f16a95..0fffe1a29d94 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -66,10 +66,8 @@ use codex_mcp::McpConnectionManager; use codex_mcp::McpRuntimeEnvironment; use codex_mcp::ToolInfo; use codex_mcp::codex_apps_tools_cache_key; -#[cfg(test)] -use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig; -use codex_models_manager::manager::ModelsManager; use codex_models_manager::manager::RefreshStrategy; +use codex_models_manager::manager::SharedModelsManager; use codex_network_proxy::NetworkProxy; use codex_network_proxy::NetworkProxyAuditMetadata; use codex_network_proxy::normalize_host; @@ -391,7 +389,7 @@ pub struct CodexSpawnOk { pub(crate) struct CodexSpawnArgs { pub(crate) config: Config, pub(crate) auth_manager: Arc, - pub(crate) models_manager: Arc, + pub(crate) models_manager: SharedModelsManager, pub(crate) environment_manager: Arc, pub(crate) skills_manager: Arc, pub(crate) plugins_manager: Arc, diff --git a/codex-rs/core/src/session/session.rs b/codex-rs/core/src/session/session.rs index 2918e21a03e4..f7a36c12825b 100644 --- a/codex-rs/core/src/session/session.rs +++ b/codex-rs/core/src/session/session.rs @@ -270,7 +270,7 @@ impl Session { mut session_configuration: SessionConfiguration, config: Arc, auth_manager: Arc, - models_manager: Arc, + models_manager: SharedModelsManager, exec_policy: Arc, tx_event: Sender, agent_status: watch::Sender, diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index 1b18c8378d13..a36b3a421ab9 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -17,6 +17,7 @@ use crate::function_tool::FunctionCallError; use crate::shell::default_user_shell; use crate::skills::SkillRenderSideEffects; use crate::skills::render::SkillMetadataBudget; +use crate::test_support::models_manager_with_provider; use crate::tools::format_exec_output_str; use codex_features::Feature; @@ -25,6 +26,8 @@ use codex_login::CodexAuth; use codex_model_provider_info::ModelProviderInfo; use codex_models_manager::bundled_models_response; use codex_models_manager::model_info; +use codex_models_manager::test_support::construct_model_info_offline_for_tests; +use codex_models_manager::test_support::get_model_offline_for_tests; use codex_protocol::AgentPath; use codex_protocol::ThreadId; use codex_protocol::account::PlanType as AccountPlanType; @@ -2203,11 +2206,9 @@ async fn set_rate_limits_retains_previous_credits() { let codex_home = tempfile::tempdir().expect("create temp dir"); let config = build_test_config(codex_home.path()).await; let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - model.as_str(), - &config.to_models_manager_config(), - ); + let model = get_model_offline_for_tests(config.model.as_deref()); + let model_info = + construct_model_info_offline_for_tests(model.as_str(), &config.to_models_manager_config()); let reasoning_effort = config.model_reasoning_effort; let collaboration_mode = CollaborationMode { mode: ModeKind::Default, @@ -2309,11 +2310,9 @@ async fn set_rate_limits_updates_plan_type_when_present() { let codex_home = tempfile::tempdir().expect("create temp dir"); let config = build_test_config(codex_home.path()).await; let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - model.as_str(), - &config.to_models_manager_config(), - ); + let model = get_model_offline_for_tests(config.model.as_deref()); + let model_info = + construct_model_info_offline_for_tests(model.as_str(), &config.to_models_manager_config()); let reasoning_effort = config.model_reasoning_effort; let collaboration_mode = CollaborationMode { mode: ModeKind::Default, @@ -2645,7 +2644,7 @@ fn session_telemetry( ) -> SessionTelemetry { SessionTelemetry::new( conversation_id, - ModelsManager::get_model_offline_for_tests(config.model.as_deref()).as_str(), + get_model_offline_for_tests(config.model.as_deref()).as_str(), model_info.slug.as_str(), /*account_id*/ None, Some("test@test.com".to_string()), @@ -2759,11 +2758,9 @@ pub(crate) async fn make_session_configuration_for_tests() -> SessionConfigurati let codex_home = tempfile::tempdir().expect("create temp dir"); let config = build_test_config(codex_home.path()).await; let config = Arc::new(config); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - model.as_str(), - &config.to_models_manager_config(), - ); + let model = get_model_offline_for_tests(config.model.as_deref()); + let model_info = + construct_model_info_offline_for_tests(model.as_str(), &config.to_models_manager_config()); let reasoning_effort = config.model_reasoning_effort; let collaboration_mode = CollaborationMode { mode: ModeKind::Default, @@ -3084,17 +3081,14 @@ async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() { let config = Arc::new(config); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), auth_manager.clone(), - /*model_catalog*/ None, - CollaborationModesConfig::default(), - )); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - model.as_str(), - &config.to_models_manager_config(), + config.model_provider.clone(), ); + let model = get_model_offline_for_tests(config.model.as_deref()); + let model_info = + construct_model_info_offline_for_tests(model.as_str(), &config.to_models_manager_config()); let collaboration_mode = CollaborationMode { mode: ModeKind::Default, settings: Settings { @@ -3185,20 +3179,17 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { let config = Arc::new(config); let conversation_id = ThreadId::default(); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), auth_manager.clone(), - /*model_catalog*/ None, - CollaborationModesConfig::default(), - )); + config.model_provider.clone(), + ); let agent_control = AgentControl::default(); let exec_policy = Arc::new(ExecPolicyManager::default()); let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - model.as_str(), - &config.to_models_manager_config(), - ); + let model = get_model_offline_for_tests(config.model.as_deref()); + let model_info = + construct_model_info_offline_for_tests(model.as_str(), &config.to_models_manager_config()); let reasoning_effort = config.model_reasoning_effort; let collaboration_mode = CollaborationMode { mode: ModeKind::Default, @@ -3247,7 +3238,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) { }; let per_turn_config = Session::build_per_turn_config(&session_configuration, session_configuration.cwd.clone()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( + let model_info = construct_model_info_offline_for_tests( session_configuration.collaboration_mode.model(), &per_turn_config.to_models_manager_config(), ); @@ -3412,17 +3403,14 @@ async fn make_session_with_config_and_rx( mutator(&mut config); let config = Arc::new(config); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), auth_manager.clone(), - /*model_catalog*/ None, - CollaborationModesConfig::default(), - )); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - model.as_str(), - &config.to_models_manager_config(), + config.model_provider.clone(), ); + let model = get_model_offline_for_tests(config.model.as_deref()); + let model_info = + construct_model_info_offline_for_tests(model.as_str(), &config.to_models_manager_config()); let collaboration_mode = CollaborationMode { mode: ModeKind::Default, settings: Settings { @@ -4554,20 +4542,17 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( let config = Arc::new(config); let conversation_id = ThreadId::default(); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), auth_manager.clone(), - /*model_catalog*/ None, - CollaborationModesConfig::default(), - )); + config.model_provider.clone(), + ); let agent_control = AgentControl::default(); let exec_policy = Arc::new(ExecPolicyManager::default()); let (agent_status_tx, _agent_status_rx) = watch::channel(AgentStatus::PendingInit); - let model = ModelsManager::get_model_offline_for_tests(config.model.as_deref()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( - model.as_str(), - &config.to_models_manager_config(), - ); + let model = get_model_offline_for_tests(config.model.as_deref()); + let model_info = + construct_model_info_offline_for_tests(model.as_str(), &config.to_models_manager_config()); let reasoning_effort = config.model_reasoning_effort; let collaboration_mode = CollaborationMode { mode: ModeKind::Default, @@ -4616,7 +4601,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx( }; let per_turn_config = Session::build_per_turn_config(&session_configuration, session_configuration.cwd.clone()); - let model_info = ModelsManager::construct_model_info_offline_for_tests( + let model_info = construct_model_info_offline_for_tests( session_configuration.collaboration_mode.model(), &per_turn_config.to_models_manager_config(), ); diff --git a/codex-rs/core/src/session/tests/guardian_tests.rs b/codex-rs/core/src/session/tests/guardian_tests.rs index 7844070acf5c..6423bee28d53 100644 --- a/codex-rs/core/src/session/tests/guardian_tests.rs +++ b/codex-rs/core/src/session/tests/guardian_tests.rs @@ -8,6 +8,7 @@ use crate::exec::ExecParams; use crate::exec_policy::ExecPolicyManager; use crate::guardian::GUARDIAN_REVIEWER_NAME; use crate::sandboxing::SandboxPermissions; +use crate::test_support::models_manager_with_provider; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolCallSource; use crate::turn_diff_tracker::TurnDiffTracker; @@ -92,11 +93,11 @@ async fn request_permissions_routes_to_guardian_when_reviewer_is_enabled() { config.approvals_reviewer = ApprovalsReviewer::AutoReview; config.model_provider.base_url = Some(format!("{}/v1", server.uri())); let config = Arc::new(config); - let models_manager = Arc::new(crate::test_support::models_manager_with_provider( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn_context_raw.config = Arc::clone(&config); turn_context_raw.provider = create_model_provider( @@ -171,11 +172,11 @@ async fn request_permissions_guardian_review_stops_when_cancelled() { config.approvals_reviewer = ApprovalsReviewer::AutoReview; config.model_provider.base_url = Some(format!("{}/v1", server.uri())); let config = Arc::new(config); - let models_manager = Arc::new(crate::test_support::models_manager_with_provider( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); Arc::get_mut(&mut session) .expect("single session ref") .services @@ -287,11 +288,11 @@ async fn guardian_allows_shell_additional_permissions_requests_past_policy_valid let mut config = (*turn_context_raw.config).clone(); config.model_provider.base_url = Some(format!("{}/v1", server.uri())); let config = Arc::new(config); - let models_manager = Arc::new(crate::test_support::models_manager_with_provider( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn_context_raw.config = Arc::clone(&config); turn_context_raw.provider = create_model_provider( @@ -440,11 +441,11 @@ async fn strict_auto_review_turn_grant_forces_guardian_for_shell_policy_skip() { config.approvals_reviewer = ApprovalsReviewer::User; config.model_provider.base_url = Some(format!("{}/v1", server.uri())); let config = Arc::new(config); - let models_manager = Arc::new(crate::test_support::models_manager_with_provider( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), Arc::clone(&session.services.auth_manager), config.model_provider.clone(), - )); + ); session.services.models_manager = models_manager; turn_context_raw.config = Arc::clone(&config); turn_context_raw.provider = create_model_provider( @@ -736,12 +737,11 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() { ); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let models_manager = Arc::new(ModelsManager::new( + let models_manager = models_manager_with_provider( config.codex_home.to_path_buf(), auth_manager.clone(), - /*model_catalog*/ None, - CollaborationModesConfig::default(), - )); + config.model_provider.clone(), + ); let plugins_manager = Arc::new(PluginsManager::new(config.codex_home.to_path_buf())); let skills_manager = Arc::new(SkillsManager::new( config.codex_home.clone(), diff --git a/codex-rs/core/src/session/turn_context.rs b/codex-rs/core/src/session/turn_context.rs index 14e3ce4b7c7b..0898bd89af51 100644 --- a/codex-rs/core/src/session/turn_context.rs +++ b/codex-rs/core/src/session/turn_context.rs @@ -117,7 +117,11 @@ impl TurnContext { self.features.apps_enabled_for_auth(uses_codex_backend) } - pub(crate) async fn with_model(&self, model: String, models_manager: &ModelsManager) -> Self { + pub(crate) async fn with_model( + &self, + model: String, + models_manager: &SharedModelsManager, + ) -> Self { let mut config = (*self.config).clone(); config.model = Some(model.clone()); let model_info = models_manager @@ -381,7 +385,7 @@ impl Session { main_execve_wrapper_exe: Option<&PathBuf>, per_turn_config: Config, model_info: ModelInfo, - models_manager: &ModelsManager, + models_manager: &SharedModelsManager, network: Option, environment: Option>, environments: Vec, diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index e3086f14a729..fe27d89ae084 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -20,7 +20,7 @@ use codex_exec_server::EnvironmentManager; use codex_hooks::Hooks; use codex_login::AuthManager; use codex_mcp::McpConnectionManager; -use codex_models_manager::manager::ModelsManager; +use codex_models_manager::manager::SharedModelsManager; use codex_otel::SessionTelemetry; use codex_rollout::state_db::StateDbHandle; use codex_rollout_trace::ThreadTraceContext; @@ -49,7 +49,7 @@ pub(crate) struct SessionServices { pub(crate) show_raw_agent_reasoning: bool, pub(crate) exec_policy: Arc, pub(crate) auth_manager: Arc, - pub(crate) models_manager: Arc, + pub(crate) models_manager: SharedModelsManager, pub(crate) session_telemetry: SessionTelemetry, pub(crate) tool_approvals: Mutex, pub(crate) guardian_rejections: Mutex>, diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index b9621d8fabc3..d3142bf779f9 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -31,7 +31,7 @@ use crate::state::RunningTask; use crate::state::TaskKind; use codex_analytics::TurnTokenUsageFact; use codex_login::AuthManager; -use codex_models_manager::manager::ModelsManager; +use codex_models_manager::manager::SharedModelsManager; use codex_otel::SessionTelemetry; use codex_otel::TURN_E2E_DURATION_METRIC; use codex_otel::TURN_MEMORY_METRIC; @@ -128,7 +128,7 @@ impl SessionTaskContext { Arc::clone(&self.session.services.auth_manager) } - pub(crate) fn models_manager(&self) -> Arc { + pub(crate) fn models_manager(&self) -> SharedModelsManager { Arc::clone(&self.session.services.models_manager) } } diff --git a/codex-rs/core/src/test_support.rs b/codex-rs/core/src/test_support.rs index 804f84208b6a..0cb0e9d0cde1 100644 --- a/codex-rs/core/src/test_support.rs +++ b/codex-rs/core/src/test_support.rs @@ -10,10 +10,13 @@ use std::sync::Arc; use codex_exec_server::EnvironmentManager; use codex_login::AuthManager; use codex_login::CodexAuth; +use codex_model_provider::create_model_provider; use codex_model_provider_info::ModelProviderInfo; use codex_models_manager::bundled_models_response; use codex_models_manager::collaboration_mode_presets; -use codex_models_manager::manager::ModelsManager; +use codex_models_manager::manager::SharedModelsManager; +use codex_models_manager::test_support::construct_model_info_offline_for_tests; +use codex_models_manager::test_support::get_model_offline_for_tests; use codex_protocol::config_types::CollaborationModeMask; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelPreset; @@ -101,16 +104,21 @@ pub fn models_manager_with_provider( codex_home: PathBuf, auth_manager: Arc, provider: ModelProviderInfo, -) -> ModelsManager { - ModelsManager::with_provider_for_tests(codex_home, auth_manager, provider) +) -> SharedModelsManager { + let provider = create_model_provider(provider, Some(auth_manager)); + provider.models_manager( + codex_home, + /*config_model_catalog*/ None, + Default::default(), + ) } pub fn get_model_offline(model: Option<&str>) -> String { - ModelsManager::get_model_offline_for_tests(model) + get_model_offline_for_tests(model) } pub fn construct_model_info_offline(model: &str, config: &Config) -> ModelInfo { - ModelsManager::construct_model_info_offline_for_tests(model, &config.to_models_manager_config()) + construct_model_info_offline_for_tests(model, &config.to_models_manager_config()) } pub fn all_model_presets() -> &'static Vec { diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 3988360b9671..2e6ea5f9eb99 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -24,11 +24,11 @@ use codex_app_server_protocol::TurnStatus; use codex_exec_server::EnvironmentManager; use codex_login::AuthManager; use codex_login::CodexAuth; +use codex_model_provider::create_model_provider; use codex_model_provider_info::ModelProviderInfo; -use codex_model_provider_info::OPENAI_PROVIDER_ID; use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig; -use codex_models_manager::manager::ModelsManager; use codex_models_manager::manager::RefreshStrategy; +use codex_models_manager::manager::SharedModelsManager; use codex_protocol::ThreadId; use codex_protocol::config_types::CollaborationModeMask; use codex_protocol::error::CodexErr; @@ -224,7 +224,7 @@ pub(crate) struct ThreadManagerState { threads: Arc>>>, thread_created_tx: broadcast::Sender, auth_manager: Arc, - models_manager: Arc, + models_manager: SharedModelsManager, environment_manager: Arc, skills_manager: Arc, plugins_manager: Arc, @@ -240,20 +240,13 @@ pub fn build_models_manager( config: &Config, auth_manager: Arc, collaboration_modes_config: CollaborationModesConfig, -) -> Arc { - let openai_models_provider = config - .model_providers - .get(OPENAI_PROVIDER_ID) - .cloned() - .unwrap_or_else(|| ModelProviderInfo::create_openai_provider(/*base_url*/ None)); - - Arc::new(ModelsManager::new_with_provider( +) -> SharedModelsManager { + let provider = create_model_provider(config.model_provider.clone(), Some(auth_manager)); + provider.models_manager( config.codex_home.to_path_buf(), - auth_manager, config.model_catalog.clone(), collaboration_modes_config, - openai_models_provider, - )) + ) } fn configured_thread_store(config: &Config) -> Arc { @@ -364,11 +357,12 @@ impl ThreadManager { state: Arc::new(ThreadManagerState { threads: Arc::new(RwLock::new(HashMap::new())), thread_created_tx, - models_manager: Arc::new(ModelsManager::with_provider_for_tests( - codex_home, - auth_manager.clone(), - provider, - )), + models_manager: create_model_provider(provider, Some(auth_manager.clone())) + .models_manager( + codex_home, + /*config_model_catalog*/ None, + CollaborationModesConfig::default(), + ), environment_manager, skills_manager, plugins_manager, @@ -422,7 +416,7 @@ impl ThreadManager { validate_environment_selections(self.state.environment_manager.as_ref(), environments) } - pub fn get_models_manager(&self) -> Arc { + pub fn get_models_manager(&self) -> SharedModelsManager { self.state.models_manager.clone() } diff --git a/codex-rs/core/src/thread_manager_tests.rs b/codex-rs/core/src/thread_manager_tests.rs index dc2cb004f502..0ef7afaff1e2 100644 --- a/codex-rs/core/src/thread_manager_tests.rs +++ b/codex-rs/core/src/thread_manager_tests.rs @@ -412,7 +412,7 @@ async fn resume_and_fork_do_not_restore_thread_environments_from_rollout() { } #[tokio::test] -async fn new_uses_configured_openai_provider_for_model_refresh() { +async fn new_uses_active_provider_for_model_refresh() { let server = MockServer::start().await; let models_mock = mount_models_once(&server, ModelsResponse { models: vec![] }).await; @@ -422,11 +422,7 @@ async fn new_uses_configured_openai_provider_for_model_refresh() { config.cwd = config.codex_home.abs(); std::fs::create_dir_all(&config.codex_home).expect("create codex home"); config.model_catalog = None; - config - .model_providers - .get_mut("openai") - .expect("openai provider should exist") - .base_url = Some(server.uri()); + config.model_provider.base_url = Some(server.uri()); let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); diff --git a/codex-rs/core/tests/suite/personality.rs b/codex-rs/core/tests/suite/personality.rs index 0a1c76295a17..a53242c0dd3d 100644 --- a/codex-rs/core/tests/suite/personality.rs +++ b/codex-rs/core/tests/suite/personality.rs @@ -1,7 +1,7 @@ use codex_config::types::Personality; use codex_features::Feature; -use codex_models_manager::manager::ModelsManager; use codex_models_manager::manager::RefreshStrategy; +use codex_models_manager::manager::SharedModelsManager; use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::ModelInfo; @@ -28,7 +28,6 @@ use core_test_support::skip_if_no_network; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; use pretty_assertions::assert_eq; -use std::sync::Arc; use tempfile::TempDir; use tokio::time::Duration; use tokio::time::Instant; @@ -933,7 +932,7 @@ async fn user_turn_personality_remote_model_template_includes_update_message() - Ok(()) } -async fn wait_for_model_available(manager: &Arc, slug: &str) { +async fn wait_for_model_available(manager: &SharedModelsManager, slug: &str) { let deadline = Instant::now() + Duration::from_secs(2); loop { let models = manager.list_models(RefreshStrategy::OnlineIfUncached).await; diff --git a/codex-rs/core/tests/suite/remote_models.rs b/codex-rs/core/tests/suite/remote_models.rs index 790d9ca4f7f9..07a1bc404d1d 100644 --- a/codex-rs/core/tests/suite/remote_models.rs +++ b/codex-rs/core/tests/suite/remote_models.rs @@ -1,15 +1,12 @@ #![cfg(not(target_os = "windows"))] #![allow(clippy::expect_used)] -// unified exec is not supported on Windows OS -use std::sync::Arc; - use anyhow::Result; use codex_login::CodexAuth; use codex_model_provider_info::ModelProviderInfo; use codex_model_provider_info::built_in_model_providers; use codex_models_manager::bundled_models_response; -use codex_models_manager::manager::ModelsManager; use codex_models_manager::manager::RefreshStrategy; +use codex_models_manager::manager::SharedModelsManager; use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::ModelInfo; @@ -1207,7 +1204,7 @@ async fn remote_models_hide_picker_only_models() -> Result<()> { Ok(()) } -async fn wait_for_model_available(manager: &Arc, slug: &str) -> ModelPreset { +async fn wait_for_model_available(manager: &SharedModelsManager, slug: &str) -> ModelPreset { let deadline = Instant::now() + Duration::from_secs(2); loop { if let Some(model) = { diff --git a/codex-rs/core/tests/suite/spawn_agent_description.rs b/codex-rs/core/tests/suite/spawn_agent_description.rs index a8b3bab95253..031c3135e8a3 100644 --- a/codex-rs/core/tests/suite/spawn_agent_description.rs +++ b/codex-rs/core/tests/suite/spawn_agent_description.rs @@ -4,8 +4,8 @@ use anyhow::Result; use codex_features::Feature; use codex_login::CodexAuth; -use codex_models_manager::manager::ModelsManager; use codex_models_manager::manager::RefreshStrategy; +use codex_models_manager::manager::SharedModelsManager; use codex_protocol::config_types::ReasoningSummary; use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::ModelInfo; @@ -23,7 +23,6 @@ use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::test_codex::test_codex; use serde_json::Value; -use std::sync::Arc; use std::time::Duration; use std::time::Instant; use tokio::time::sleep; @@ -89,7 +88,7 @@ fn test_model_info( } } -async fn wait_for_model_available(manager: &Arc, slug: &str) { +async fn wait_for_model_available(manager: &SharedModelsManager, slug: &str) { let deadline = Instant::now() + Duration::from_secs(2); loop { let available_models = manager.list_models(RefreshStrategy::Online).await; diff --git a/codex-rs/model-provider/Cargo.toml b/codex-rs/model-provider/Cargo.toml index f5ff5b10cc8f..58235ab24d5e 100644 --- a/codex-rs/model-provider/Cargo.toml +++ b/codex-rs/model-provider/Cargo.toml @@ -18,10 +18,19 @@ codex-api = { workspace = true } codex-agent-identity = { workspace = true } codex-aws-auth = { workspace = true } codex-client = { workspace = true } +codex-feedback = { workspace = true } codex-login = { workspace = true } codex-model-provider-info = { workspace = true } +codex-models-manager = { workspace = true } +codex-otel = { workspace = true } codex-protocol = { workspace = true } +codex-response-debug-context = { workspace = true } http = { workspace = true } +tokio = { workspace = true, features = ["sync", "time"] } +tracing = { workspace = true, features = ["log"] } [dev-dependencies] pretty_assertions = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt"] } +wiremock = { workspace = true } diff --git a/codex-rs/model-provider/src/amazon_bedrock/catalog.rs b/codex-rs/model-provider/src/amazon_bedrock/catalog.rs new file mode 100644 index 000000000000..30536bd271e9 --- /dev/null +++ b/codex-rs/model-provider/src/amazon_bedrock/catalog.rs @@ -0,0 +1,143 @@ +use codex_models_manager::bundled_models_response; +use codex_models_manager::model_info::model_info_from_slug; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::openai_models::ConfigShellToolType; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelVisibility; +use codex_protocol::openai_models::ModelsResponse; +use codex_protocol::openai_models::ReasoningEffort; +use codex_protocol::openai_models::ReasoningEffortPreset; +use codex_protocol::openai_models::TruncationPolicyConfig; +use codex_protocol::openai_models::WebSearchToolType; + +const GPT_OSS_CONTEXT_WINDOW: i64 = 128_000; +const GPT_5_4_CMB_MODEL_ID: &str = "openai.gpt-5.4-cmb"; +const GPT_5_4_MODEL_ID: &str = "gpt-5.4"; + +pub(crate) fn static_model_catalog() -> ModelsResponse { + ModelsResponse { + models: vec![ + gpt_5_4_cmb_bedrock_model(/*priority*/ 0), + bedrock_model( + "openai.gpt-oss-120b", + "GPT OSS 120B on Bedrock", + /*priority*/ 1, + ), + bedrock_model( + "openai.gpt-oss-20b", + "GPT OSS 20B on Bedrock", + /*priority*/ 2, + ), + ], + } +} + +fn gpt_5_4_cmb_bedrock_model(priority: i32) -> ModelInfo { + let mut model = bundled_gpt_5_4_model(); + + model.slug = GPT_5_4_CMB_MODEL_ID.to_string(); + model.priority = priority; + model +} + +fn bundled_gpt_5_4_model() -> ModelInfo { + if let Ok(response) = bundled_models_response() + && let Some(model) = response + .models + .into_iter() + .find(|model| model.slug == GPT_5_4_MODEL_ID) + { + return model; + } + + model_info_from_slug(GPT_5_4_MODEL_ID) +} + +fn bedrock_model(slug: &str, display_name: &str, priority: i32) -> ModelInfo { + ModelInfo { + slug: slug.to_string(), + display_name: display_name.to_string(), + description: Some(display_name.to_string()), + default_reasoning_level: Some(ReasoningEffort::Medium), + supported_reasoning_levels: vec![ + reasoning_effort_preset(ReasoningEffort::Low), + reasoning_effort_preset(ReasoningEffort::Medium), + reasoning_effort_preset(ReasoningEffort::High), + ], + shell_type: ConfigShellToolType::ShellCommand, + visibility: ModelVisibility::List, + supported_in_api: true, + priority, + additional_speed_tiers: Vec::new(), + availability_nux: None, + upgrade: None, + base_instructions: codex_models_manager::model_info::BASE_INSTRUCTIONS.to_string(), + model_messages: None, + supports_reasoning_summaries: true, + default_reasoning_summary: ReasoningSummary::None, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: None, + web_search_tool_type: WebSearchToolType::Text, + truncation_policy: TruncationPolicyConfig::tokens(/*limit*/ 10_000), + supports_parallel_tool_calls: true, + supports_image_detail_original: false, + context_window: Some(GPT_OSS_CONTEXT_WINDOW), + max_context_window: Some(GPT_OSS_CONTEXT_WINDOW), + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities: vec![InputModality::Text], + used_fallback_model_metadata: false, + supports_search_tool: false, + } +} + +fn reasoning_effort_preset(effort: ReasoningEffort) -> ReasoningEffortPreset { + ReasoningEffortPreset { + effort, + description: match effort { + ReasoningEffort::None => "No reasoning", + ReasoningEffort::Minimal => "Minimal reasoning", + ReasoningEffort::Low => "Fast responses with lighter reasoning", + ReasoningEffort::Medium => "Balances speed and reasoning depth for everyday tasks", + ReasoningEffort::High => "Greater reasoning depth for complex problems", + ReasoningEffort::XHigh => "Extra high reasoning depth for complex problems", + } + .to_string(), + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn catalog_uses_mantle_model_ids_as_slugs() { + let catalog = static_model_catalog(); + + assert_eq!(catalog.models.len(), 3); + assert_eq!(catalog.models[0].slug, GPT_5_4_CMB_MODEL_ID); + assert_eq!(catalog.models[1].slug, "openai.gpt-oss-120b"); + assert_eq!(catalog.models[2].slug, "openai.gpt-oss-20b"); + } + + #[test] + fn gpt_5_4_cmb_uses_gpt_5_4_spec() { + let catalog = static_model_catalog(); + let cmb_model = catalog + .models + .iter() + .find(|model| model.slug == GPT_5_4_CMB_MODEL_ID) + .expect("Bedrock catalog should include GPT-5.4 CMB"); + let mut gpt_5_4_model = bundled_gpt_5_4_model(); + + gpt_5_4_model.slug = GPT_5_4_CMB_MODEL_ID.to_string(); + gpt_5_4_model.priority = cmb_model.priority; + + assert_eq!(*cmb_model, gpt_5_4_model); + } +} diff --git a/codex-rs/model-provider/src/amazon_bedrock/mod.rs b/codex-rs/model-provider/src/amazon_bedrock/mod.rs index af7ac8714ce1..2c47b2f25d90 100644 --- a/codex-rs/model-provider/src/amazon_bedrock/mod.rs +++ b/codex-rs/model-provider/src/amazon_bedrock/mod.rs @@ -1,6 +1,8 @@ mod auth; +mod catalog; mod mantle; +use std::path::PathBuf; use std::sync::Arc; use codex_api::Provider; @@ -9,14 +11,19 @@ use codex_login::AuthManager; use codex_login::CodexAuth; use codex_model_provider_info::ModelProviderAwsAuthInfo; use codex_model_provider_info::ModelProviderInfo; +use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig; +use codex_models_manager::manager::SharedModelsManager; +use codex_models_manager::manager::StaticModelsManager; use codex_protocol::account::ProviderAccount; use codex_protocol::error::Result; +use codex_protocol::openai_models::ModelsResponse; use crate::provider::ModelProvider; use crate::provider::ProviderAccountResult; use crate::provider::ProviderAccountState; use auth::resolve_provider_auth; use auth::resolve_region; +pub(crate) use catalog::static_model_catalog; use mantle::base_url; /// Runtime provider for Amazon Bedrock's OpenAI-compatible Mantle endpoint. @@ -26,6 +33,22 @@ pub(crate) struct AmazonBedrockModelProvider { pub(crate) aws: ModelProviderAwsAuthInfo, } +impl AmazonBedrockModelProvider { + pub(crate) fn new(provider_info: ModelProviderInfo) -> Self { + let aws = provider_info + .aws + .clone() + .unwrap_or(ModelProviderAwsAuthInfo { + profile: None, + region: None, + }); + Self { + info: provider_info, + aws, + } + } +} + #[async_trait::async_trait] impl ModelProvider for AmazonBedrockModelProvider { fn info(&self) -> &ModelProviderInfo { @@ -57,6 +80,19 @@ impl ModelProvider for AmazonBedrockModelProvider { async fn api_auth(&self) -> Result { resolve_provider_auth(&self.aws).await } + + fn models_manager( + &self, + _codex_home: PathBuf, + config_model_catalog: Option, + collaboration_modes_config: CollaborationModesConfig, + ) -> SharedModelsManager { + Arc::new(StaticModelsManager::new( + /*auth_manager*/ None, + config_model_catalog.unwrap_or_else(static_model_catalog), + collaboration_modes_config, + )) + } } #[cfg(test)] diff --git a/codex-rs/model-provider/src/lib.rs b/codex-rs/model-provider/src/lib.rs index 11c180db1141..ac51968ac961 100644 --- a/codex-rs/model-provider/src/lib.rs +++ b/codex-rs/model-provider/src/lib.rs @@ -1,6 +1,7 @@ mod amazon_bedrock; mod auth; mod bearer_auth_provider; +mod models_endpoint; mod provider; pub use auth::auth_provider_from_auth; diff --git a/codex-rs/model-provider/src/models_endpoint.rs b/codex-rs/model-provider/src/models_endpoint.rs new file mode 100644 index 000000000000..8a72beea7012 --- /dev/null +++ b/codex-rs/model-provider/src/models_endpoint.rs @@ -0,0 +1,247 @@ +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use codex_api::ModelsClient; +use codex_api::RequestTelemetry; +use codex_api::ReqwestTransport; +use codex_api::TransportError; +use codex_api::auth_header_telemetry; +use codex_api::map_api_error; +use codex_feedback::FeedbackRequestTags; +use codex_feedback::emit_feedback_request_tags_with_auth_env; +use codex_login::AuthEnvTelemetry; +use codex_login::AuthManager; +use codex_login::CodexAuth; +use codex_login::collect_auth_env_telemetry; +use codex_login::default_client::build_reqwest_client; +use codex_model_provider_info::ModelProviderInfo; +use codex_models_manager::manager::ModelsEndpointClient; +use codex_otel::TelemetryAuthMode; +use codex_protocol::error::CodexErr; +use codex_protocol::error::Result as CoreResult; +use codex_protocol::openai_models::ModelInfo; +use codex_response_debug_context::extract_response_debug_context; +use codex_response_debug_context::telemetry_transport_error_message; +use http::HeaderMap; +use tokio::time::timeout; + +use crate::auth::resolve_provider_auth; + +const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5); +const MODELS_ENDPOINT: &str = "/models"; + +/// Provider-owned OpenAI-compatible `/models` endpoint. +#[derive(Debug)] +pub(crate) struct OpenAiModelsEndpoint { + provider_info: ModelProviderInfo, + auth_manager: Option>, +} + +impl OpenAiModelsEndpoint { + pub(crate) fn new( + provider_info: ModelProviderInfo, + auth_manager: Option>, + ) -> Self { + Self { + provider_info, + auth_manager, + } + } + + async fn auth(&self) -> Option { + match self.auth_manager.as_ref() { + Some(auth_manager) => auth_manager.auth().await, + None => None, + } + } + + fn auth_env(&self) -> AuthEnvTelemetry { + let codex_api_key_env_enabled = self + .auth_manager + .as_ref() + .is_some_and(|auth_manager| auth_manager.codex_api_key_env_enabled()); + collect_auth_env_telemetry(&self.provider_info, codex_api_key_env_enabled) + } +} + +#[async_trait] +impl ModelsEndpointClient for OpenAiModelsEndpoint { + fn has_command_auth(&self) -> bool { + self.provider_info.has_command_auth() + } + + async fn uses_codex_backend(&self) -> bool { + self.auth() + .await + .as_ref() + .is_some_and(CodexAuth::uses_codex_backend) + } + + async fn list_models( + &self, + client_version: &str, + ) -> CoreResult<(Vec, Option)> { + let _timer = + codex_otel::start_global_timer("codex.remote_models.fetch_update.duration_ms", &[]); + let auth = self.auth().await; + let auth_mode = auth.as_ref().map(CodexAuth::auth_mode); + let api_provider = self.provider_info.to_api_provider(auth_mode)?; + let api_auth = resolve_provider_auth(auth.as_ref(), &self.provider_info)?; + let transport = ReqwestTransport::new(build_reqwest_client()); + let auth_telemetry = auth_header_telemetry(api_auth.as_ref()); + let request_telemetry: Arc = Arc::new(ModelsRequestTelemetry { + auth_mode: auth_mode.map(|mode| TelemetryAuthMode::from(mode).to_string()), + auth_header_attached: auth_telemetry.attached, + auth_header_name: auth_telemetry.name, + auth_env: self.auth_env(), + }); + let client = ModelsClient::new(transport, api_provider, api_auth) + .with_telemetry(Some(request_telemetry)); + + timeout( + MODELS_REFRESH_TIMEOUT, + client.list_models(client_version, HeaderMap::new()), + ) + .await + .map_err(|_| CodexErr::Timeout)? + .map_err(map_api_error) + } +} + +#[derive(Clone)] +struct ModelsRequestTelemetry { + auth_mode: Option, + auth_header_attached: bool, + auth_header_name: Option<&'static str>, + auth_env: AuthEnvTelemetry, +} + +impl RequestTelemetry for ModelsRequestTelemetry { + fn on_request( + &self, + attempt: u64, + status: Option, + error: Option<&TransportError>, + duration: Duration, + ) { + let success = status.is_some_and(|code| code.is_success()) && error.is_none(); + let error_message = error.map(telemetry_transport_error_message); + let response_debug = error + .map(extract_response_debug_context) + .unwrap_or_default(); + let status = status.map(|status| status.as_u16()); + tracing::event!( + target: "codex_otel.log_only", + tracing::Level::INFO, + event.name = "codex.api_request", + duration_ms = %duration.as_millis(), + http.response.status_code = status, + success = success, + error.message = error_message.as_deref(), + attempt = attempt, + endpoint = MODELS_ENDPOINT, + auth.header_attached = self.auth_header_attached, + auth.header_name = self.auth_header_name, + auth.env_openai_api_key_present = self.auth_env.openai_api_key_env_present, + auth.env_codex_api_key_present = self.auth_env.codex_api_key_env_present, + auth.env_codex_api_key_enabled = self.auth_env.codex_api_key_env_enabled, + auth.env_provider_key_name = self.auth_env.provider_env_key_name.as_deref(), + auth.env_provider_key_present = self.auth_env.provider_env_key_present, + auth.env_refresh_token_url_override_present = self.auth_env.refresh_token_url_override_present, + auth.request_id = response_debug.request_id.as_deref(), + auth.cf_ray = response_debug.cf_ray.as_deref(), + auth.error = response_debug.auth_error.as_deref(), + auth.error_code = response_debug.auth_error_code.as_deref(), + auth.mode = self.auth_mode.as_deref(), + ); + tracing::event!( + target: "codex_otel.trace_safe", + tracing::Level::INFO, + event.name = "codex.api_request", + duration_ms = %duration.as_millis(), + http.response.status_code = status, + success = success, + error.message = error_message.as_deref(), + attempt = attempt, + endpoint = MODELS_ENDPOINT, + auth.header_attached = self.auth_header_attached, + auth.header_name = self.auth_header_name, + auth.env_openai_api_key_present = self.auth_env.openai_api_key_env_present, + auth.env_codex_api_key_present = self.auth_env.codex_api_key_env_present, + auth.env_codex_api_key_enabled = self.auth_env.codex_api_key_env_enabled, + auth.env_provider_key_name = self.auth_env.provider_env_key_name.as_deref(), + auth.env_provider_key_present = self.auth_env.provider_env_key_present, + auth.env_refresh_token_url_override_present = self.auth_env.refresh_token_url_override_present, + auth.request_id = response_debug.request_id.as_deref(), + auth.cf_ray = response_debug.cf_ray.as_deref(), + auth.error = response_debug.auth_error.as_deref(), + auth.error_code = response_debug.auth_error_code.as_deref(), + auth.mode = self.auth_mode.as_deref(), + ); + emit_feedback_request_tags_with_auth_env( + &FeedbackRequestTags { + endpoint: MODELS_ENDPOINT, + auth_header_attached: self.auth_header_attached, + auth_header_name: self.auth_header_name, + auth_mode: self.auth_mode.as_deref(), + auth_retry_after_unauthorized: None, + auth_recovery_mode: None, + auth_recovery_phase: None, + auth_connection_reused: None, + auth_request_id: response_debug.request_id.as_deref(), + auth_cf_ray: response_debug.cf_ray.as_deref(), + auth_error: response_debug.auth_error.as_deref(), + auth_error_code: response_debug.auth_error_code.as_deref(), + auth_recovery_followup_success: None, + auth_recovery_followup_status: None, + }, + &self.auth_env, + ); + } +} + +#[cfg(test)] +mod tests { + use std::num::NonZeroU64; + + use super::*; + use codex_protocol::config_types::ModelProviderAuthInfo; + + fn provider_info_with_command_auth() -> ModelProviderInfo { + ModelProviderInfo { + auth: Some(ModelProviderAuthInfo { + command: "print-token".to_string(), + args: Vec::new(), + timeout_ms: NonZeroU64::new(5_000).expect("timeout should be non-zero"), + refresh_interval_ms: 300_000, + cwd: std::env::current_dir() + .expect("current dir should be available") + .try_into() + .expect("current dir should be absolute"), + }), + requires_openai_auth: false, + ..ModelProviderInfo::create_openai_provider(/*base_url*/ None) + } + } + + #[test] + fn command_auth_provider_reports_command_auth_without_cached_auth() { + let endpoint = OpenAiModelsEndpoint::new( + provider_info_with_command_auth(), + /*auth_manager*/ None, + ); + + assert!(endpoint.has_command_auth()); + } + + #[test] + fn provider_without_command_auth_reports_no_command_auth() { + let endpoint = OpenAiModelsEndpoint::new( + ModelProviderInfo::create_openai_provider(/*base_url*/ None), + /*auth_manager*/ None, + ); + + assert!(!endpoint.has_command_auth()); + } +} diff --git a/codex-rs/model-provider/src/provider.rs b/codex-rs/model-provider/src/provider.rs index 7cd14bbc49b4..b845aae5b57e 100644 --- a/codex-rs/model-provider/src/provider.rs +++ b/codex-rs/model-provider/src/provider.rs @@ -1,17 +1,23 @@ use std::fmt; +use std::path::PathBuf; use std::sync::Arc; use codex_api::Provider; use codex_api::SharedAuthProvider; use codex_login::AuthManager; use codex_login::CodexAuth; -use codex_model_provider_info::ModelProviderAwsAuthInfo; use codex_model_provider_info::ModelProviderInfo; +use codex_models_manager::collaboration_mode_presets::CollaborationModesConfig; +use codex_models_manager::manager::OpenAiModelsManager; +use codex_models_manager::manager::SharedModelsManager; +use codex_models_manager::manager::StaticModelsManager; use codex_protocol::account::ProviderAccount; +use codex_protocol::openai_models::ModelsResponse; use crate::amazon_bedrock::AmazonBedrockModelProvider; use crate::auth::auth_manager_for_provider; use crate::auth::resolve_provider_auth; +use crate::models_endpoint::OpenAiModelsEndpoint; /// Current app-visible account state for a model provider. #[derive(Debug, Clone, PartialEq, Eq)] @@ -79,6 +85,14 @@ pub trait ModelProvider: fmt::Debug + Send + Sync { let auth = self.auth().await; resolve_provider_auth(auth.as_ref(), self.info()) } + + /// Creates the model manager implementation appropriate for this provider. + fn models_manager( + &self, + codex_home: PathBuf, + config_model_catalog: Option, + collaboration_modes_config: CollaborationModesConfig, + ) -> SharedModelsManager; } /// Shared runtime model provider handle. @@ -90,24 +104,10 @@ pub fn create_model_provider( auth_manager: Option>, ) -> SharedModelProvider { if provider_info.is_amazon_bedrock() { - let aws = provider_info - .aws - .clone() - .unwrap_or(ModelProviderAwsAuthInfo { - profile: None, - region: None, - }); - return Arc::new(AmazonBedrockModelProvider { - info: provider_info, - aws, - }); + Arc::new(AmazonBedrockModelProvider::new(provider_info)) + } else { + Arc::new(ConfiguredModelProvider::new(provider_info, auth_manager)) } - - let auth_manager = auth_manager_for_provider(auth_manager, &provider_info); - Arc::new(ConfiguredModelProvider { - info: provider_info, - auth_manager, - }) } /// Runtime model provider backed by configured `ModelProviderInfo`. @@ -117,6 +117,16 @@ struct ConfiguredModelProvider { auth_manager: Option>, } +impl ConfiguredModelProvider { + fn new(provider_info: ModelProviderInfo, auth_manager: Option>) -> Self { + let auth_manager = auth_manager_for_provider(auth_manager, &provider_info); + Self { + info: provider_info, + auth_manager, + } + } +} + #[async_trait::async_trait] impl ModelProvider for ConfiguredModelProvider { fn info(&self) -> &ModelProviderInfo { @@ -165,6 +175,33 @@ impl ModelProvider for ConfiguredModelProvider { requires_openai_auth: self.info.requires_openai_auth, }) } + + fn models_manager( + &self, + codex_home: PathBuf, + config_model_catalog: Option, + collaboration_modes_config: CollaborationModesConfig, + ) -> SharedModelsManager { + match config_model_catalog { + Some(model_catalog) => Arc::new(StaticModelsManager::new( + self.auth_manager.clone(), + model_catalog, + collaboration_modes_config, + )), + None => { + let endpoint = Arc::new(OpenAiModelsEndpoint::new( + self.info.clone(), + self.auth_manager.clone(), + )); + Arc::new(OpenAiModelsManager::new( + codex_home, + endpoint, + self.auth_manager.clone(), + collaboration_modes_config, + )) + } + } + } } #[cfg(test)] @@ -173,8 +210,18 @@ mod tests { use codex_model_provider_info::ModelProviderAwsAuthInfo; use codex_model_provider_info::WireApi; + use codex_models_manager::manager::RefreshStrategy; use codex_protocol::config_types::ModelProviderAuthInfo; + use codex_protocol::openai_models::ModelInfo; + use codex_protocol::openai_models::ModelsResponse; use pretty_assertions::assert_eq; + use serde_json::json; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::header_regex; + use wiremock::matchers::method; + use wiremock::matchers::path; use super::*; @@ -195,6 +242,59 @@ mod tests { } } + fn test_codex_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!("codex-model-provider-test-{}", std::process::id())) + } + + fn provider_for(base_url: String) -> ModelProviderInfo { + ModelProviderInfo { + name: "mock".into(), + base_url: Some(base_url), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + auth: None, + aws: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + websocket_connect_timeout_ms: None, + requires_openai_auth: false, + supports_websockets: false, + } + } + + fn remote_model(slug: &str) -> ModelInfo { + serde_json::from_value(json!({ + "slug": slug, + "display_name": slug, + "description": null, + "default_reasoning_level": "medium", + "supported_reasoning_levels": [], + "shell_type": "shell_command", + "visibility": "list", + "supported_in_api": true, + "priority": 0, + "upgrade": null, + "base_instructions": "base instructions", + "supports_reasoning_summaries": false, + "support_verbosity": false, + "default_verbosity": null, + "apply_patch_tool_type": null, + "truncation_policy": {"mode": "bytes", "limit": 10_000}, + "supports_parallel_tool_calls": false, + "supports_image_detail_original": false, + "context_window": 272_000, + "max_context_window": 272_000, + "experimental_supported_tools": [], + })) + .expect("valid model") + } + #[test] fn create_model_provider_builds_command_auth_manager_without_base_manager() { let provider = create_model_provider( @@ -295,4 +395,108 @@ mod tests { }) ); } + + #[tokio::test] + async fn amazon_bedrock_provider_creates_static_models_manager() { + let provider = create_model_provider( + ModelProviderInfo::create_amazon_bedrock_provider(/*aws*/ None), + /*auth_manager*/ None, + ); + let manager = provider.models_manager( + test_codex_home(), + /*config_model_catalog*/ None, + Default::default(), + ); + + let catalog = manager.raw_model_catalog(RefreshStrategy::Online).await; + let model_ids = catalog + .models + .iter() + .map(|model| model.slug.as_str()) + .collect::>(); + + assert_eq!( + model_ids, + vec![ + "openai.gpt-5.4-cmb", + "openai.gpt-oss-120b", + "openai.gpt-oss-20b" + ] + ); + + let default_model = manager + .list_models(RefreshStrategy::Online) + .await + .into_iter() + .find(|preset| preset.is_default) + .expect("Bedrock catalog should have a default model"); + + assert_eq!(default_model.model, "openai.gpt-5.4-cmb"); + } + + #[tokio::test] + async fn amazon_bedrock_provider_uses_configured_static_catalog_when_present() { + let custom_model = + codex_models_manager::model_info::model_info_from_slug("custom-bedrock-model"); + + let provider = create_model_provider( + ModelProviderInfo::create_amazon_bedrock_provider(/*aws*/ None), + /*auth_manager*/ None, + ); + let manager = provider.models_manager( + test_codex_home(), + Some(ModelsResponse { + models: vec![custom_model], + }), + Default::default(), + ); + + let catalog = manager.raw_model_catalog(RefreshStrategy::Online).await; + + assert_eq!(catalog.models.len(), 1); + assert_eq!(catalog.models[0].slug, "custom-bedrock-model"); + } + + #[tokio::test] + async fn configured_provider_models_manager_uses_provider_bearer_token() { + let server = MockServer::start().await; + let remote_models = vec![remote_model("provider-model")]; + + Mock::given(method("GET")) + .and(path("/models")) + .and(header_regex("Authorization", "Bearer provider-token")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("content-type", "application/json") + .set_body_json(ModelsResponse { + models: remote_models.clone(), + }), + ) + .expect(1) + .mount(&server) + .await; + + let mut provider_info = provider_for(server.uri()); + provider_info.experimental_bearer_token = Some("provider-token".to_string()); + let provider = create_model_provider( + provider_info, + Some(AuthManager::from_auth_for_testing( + CodexAuth::create_dummy_chatgpt_auth_for_testing(), + )), + ); + + let manager = provider.models_manager( + test_codex_home(), + /*config_model_catalog*/ None, + Default::default(), + ); + let catalog = manager.raw_model_catalog(RefreshStrategy::Online).await; + + assert!( + catalog + .models + .iter() + .any(|model| model.slug == "provider-model") + ); + } } diff --git a/codex-rs/models-manager/Cargo.toml b/codex-rs/models-manager/Cargo.toml index 59a2bff101c0..f46bf2b285a4 100644 --- a/codex-rs/models-manager/Cargo.toml +++ b/codex-rs/models-manager/Cargo.toml @@ -13,33 +13,21 @@ path = "src/lib.rs" workspace = true [dependencies] +async-trait = { workspace = true } chrono = { workspace = true, features = ["serde"] } -codex-api = { workspace = true } codex-app-server-protocol = { workspace = true } codex-collaboration-mode-templates = { workspace = true } -codex-config = { workspace = true } -codex-feedback = { workspace = true } codex-login = { workspace = true } -codex-model-provider-info = { workspace = true } codex-otel = { workspace = true } -codex-model-provider = { workspace = true } codex-protocol = { workspace = true } -codex-response-debug-context = { workspace = true } codex-utils-output-truncation = { workspace = true } codex-utils-template = { workspace = true } -http = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true, features = ["fs", "sync", "time"] } tracing = { workspace = true, features = ["log"] } [dev-dependencies] -base64 = { workspace = true } -codex-utils-absolute-path = { workspace = true } -core_test_support = { workspace = true } pretty_assertions = { workspace = true } serde_json = { workspace = true } tempfile = { workspace = true } -tracing = { workspace = true, features = ["log"] } -tracing-subscriber = { workspace = true } -wiremock = { workspace = true } diff --git a/codex-rs/models-manager/src/lib.rs b/codex-rs/models-manager/src/lib.rs index e99c33edb900..8bf30d0b602a 100644 --- a/codex-rs/models-manager/src/lib.rs +++ b/codex-rs/models-manager/src/lib.rs @@ -4,12 +4,9 @@ pub(crate) mod config; pub mod manager; pub mod model_info; pub mod model_presets; +pub mod test_support; pub use codex_app_server_protocol::AuthMode; -pub use codex_login::AuthManager; -pub use codex_login::CodexAuth; -pub use codex_model_provider_info::ModelProviderInfo; -pub use codex_model_provider_info::WireApi; pub use config::ModelsManagerConfig; /// Load the bundled model catalog shipped with `codex-models-manager`. diff --git a/codex-rs/models-manager/src/manager.rs b/codex-rs/models-manager/src/manager.rs index 34f9f7a781fe..f13f2df60ddb 100644 --- a/codex-rs/models-manager/src/manager.rs +++ b/codex-rs/models-manager/src/manager.rs @@ -3,137 +3,44 @@ use crate::collaboration_mode_presets::CollaborationModesConfig; use crate::collaboration_mode_presets::builtin_collaboration_mode_presets; use crate::config::ModelsManagerConfig; use crate::model_info; -use codex_api::ModelsClient; -use codex_api::RequestTelemetry; -use codex_api::ReqwestTransport; -use codex_api::TransportError; -use codex_api::auth_header_telemetry; -use codex_api::map_api_error; -use codex_feedback::FeedbackRequestTags; -use codex_feedback::emit_feedback_request_tags_with_auth_env; -use codex_login::AuthEnvTelemetry; +use async_trait::async_trait; use codex_login::AuthManager; -use codex_login::CodexAuth; -use codex_login::collect_auth_env_telemetry; -use codex_login::default_client::build_reqwest_client; -use codex_model_provider::SharedModelProvider; -use codex_model_provider::create_model_provider; -use codex_model_provider_info::ModelProviderInfo; -use codex_otel::TelemetryAuthMode; use codex_protocol::config_types::CollaborationModeMask; -use codex_protocol::error::CodexErr; use codex_protocol::error::Result as CoreResult; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelPreset; use codex_protocol::openai_models::ModelsResponse; -use codex_response_debug_context::extract_response_debug_context; -use codex_response_debug_context::telemetry_transport_error_message; -use http::HeaderMap; use std::fmt; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use tokio::sync::RwLock; use tokio::sync::TryLockError; -use tokio::time::timeout; +use tracing::Instrument as _; use tracing::error; use tracing::info; -use tracing::instrument; const MODEL_CACHE_FILE: &str = "models_cache.json"; const DEFAULT_MODEL_CACHE_TTL: Duration = Duration::from_secs(300); -const MODELS_REFRESH_TIMEOUT: Duration = Duration::from_secs(5); -const MODELS_ENDPOINT: &str = "/models"; -#[derive(Clone)] -struct ModelsRequestTelemetry { - auth_mode: Option, - auth_header_attached: bool, - auth_header_name: Option<&'static str>, - auth_env: AuthEnvTelemetry, -} -impl RequestTelemetry for ModelsRequestTelemetry { - fn on_request( +/// Remote endpoint used by the OpenAI-compatible model manager. +/// +/// Implementations own provider-specific auth and transport details. The model +/// manager owns refresh policy, cache behavior, and catalog merging; it calls +/// this endpoint only when it decides a remote refresh should happen. +#[async_trait] +pub trait ModelsEndpointClient: fmt::Debug + Send + Sync { + /// Returns whether this provider can authenticate command-scoped requests. + fn has_command_auth(&self) -> bool; + + /// Returns whether the currently resolved auth can use Codex backend-only models. + async fn uses_codex_backend(&self) -> bool; + + /// Fetches the latest remote model catalog and optional ETag. + async fn list_models( &self, - attempt: u64, - status: Option, - error: Option<&TransportError>, - duration: Duration, - ) { - let success = status.is_some_and(|code| code.is_success()) && error.is_none(); - let error_message = error.map(telemetry_transport_error_message); - let response_debug = error - .map(extract_response_debug_context) - .unwrap_or_default(); - let status = status.map(|status| status.as_u16()); - tracing::event!( - target: "codex_otel.log_only", - tracing::Level::INFO, - event.name = "codex.api_request", - duration_ms = %duration.as_millis(), - http.response.status_code = status, - success = success, - error.message = error_message.as_deref(), - attempt = attempt, - endpoint = MODELS_ENDPOINT, - auth.header_attached = self.auth_header_attached, - auth.header_name = self.auth_header_name, - auth.env_openai_api_key_present = self.auth_env.openai_api_key_env_present, - auth.env_codex_api_key_present = self.auth_env.codex_api_key_env_present, - auth.env_codex_api_key_enabled = self.auth_env.codex_api_key_env_enabled, - auth.env_provider_key_name = self.auth_env.provider_env_key_name.as_deref(), - auth.env_provider_key_present = self.auth_env.provider_env_key_present, - auth.env_refresh_token_url_override_present = self.auth_env.refresh_token_url_override_present, - auth.request_id = response_debug.request_id.as_deref(), - auth.cf_ray = response_debug.cf_ray.as_deref(), - auth.error = response_debug.auth_error.as_deref(), - auth.error_code = response_debug.auth_error_code.as_deref(), - auth.mode = self.auth_mode.as_deref(), - ); - tracing::event!( - target: "codex_otel.trace_safe", - tracing::Level::INFO, - event.name = "codex.api_request", - duration_ms = %duration.as_millis(), - http.response.status_code = status, - success = success, - error.message = error_message.as_deref(), - attempt = attempt, - endpoint = MODELS_ENDPOINT, - auth.header_attached = self.auth_header_attached, - auth.header_name = self.auth_header_name, - auth.env_openai_api_key_present = self.auth_env.openai_api_key_env_present, - auth.env_codex_api_key_present = self.auth_env.codex_api_key_env_present, - auth.env_codex_api_key_enabled = self.auth_env.codex_api_key_env_enabled, - auth.env_provider_key_name = self.auth_env.provider_env_key_name.as_deref(), - auth.env_provider_key_present = self.auth_env.provider_env_key_present, - auth.env_refresh_token_url_override_present = self.auth_env.refresh_token_url_override_present, - auth.request_id = response_debug.request_id.as_deref(), - auth.cf_ray = response_debug.cf_ray.as_deref(), - auth.error = response_debug.auth_error.as_deref(), - auth.error_code = response_debug.auth_error_code.as_deref(), - auth.mode = self.auth_mode.as_deref(), - ); - emit_feedback_request_tags_with_auth_env( - &FeedbackRequestTags { - endpoint: MODELS_ENDPOINT, - auth_header_attached: self.auth_header_attached, - auth_header_name: self.auth_header_name, - auth_mode: self.auth_mode.as_deref(), - auth_retry_after_unauthorized: None, - auth_recovery_mode: None, - auth_recovery_phase: None, - auth_connection_reused: None, - auth_request_id: response_debug.request_id.as_deref(), - auth_cf_ray: response_debug.cf_ray.as_deref(), - auth_error: response_debug.auth_error.as_deref(), - auth_error_code: response_debug.auth_error_code.as_deref(), - auth_recovery_followup_success: None, - auth_recovery_followup_status: None, - }, - &self.auth_env, - ); - } + client_version: &str, + ) -> CoreResult<(Vec, Option)>; } /// Strategy for refreshing available models. @@ -163,123 +70,64 @@ impl fmt::Display for RefreshStrategy { } } -/// How the manager's base catalog is sourced for the lifetime of the process. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum CatalogMode { - /// Start from bundled `models.json` and allow cache/network refresh updates. - Default, - /// Use a caller-provided catalog as authoritative and do not mutate it via refresh. - Custom, -} - -/// Coordinates remote model discovery plus cached metadata on disk. -#[derive(Debug)] -pub struct ModelsManager { - remote_models: RwLock>, - catalog_mode: CatalogMode, - collaboration_modes_config: CollaborationModesConfig, - etag: RwLock>, - cache_manager: ModelsCacheManager, - provider: SharedModelProvider, -} - -impl ModelsManager { - /// Construct a manager scoped to the provided `AuthManager`. - /// - /// Uses `codex_home` to store cached model metadata and initializes with bundled catalog - /// When `model_catalog` is provided, it becomes the authoritative remote model list and - /// background refreshes from `/models` are disabled. - pub fn new( - codex_home: PathBuf, - auth_manager: Arc, - model_catalog: Option, - collaboration_modes_config: CollaborationModesConfig, - ) -> Self { - Self::new_with_provider( - codex_home, - auth_manager, - model_catalog, - collaboration_modes_config, - ModelProviderInfo::create_openai_provider(/*base_url*/ None), - ) - } - - /// Construct a manager with an explicit provider used for remote model refreshes. - // TODO(celia-oai): Revisit this ownership direction: the model provider should likely - // own or return the models manager instead of requiring the manager to construct and use - // a provider from provider info. - pub fn new_with_provider( - codex_home: PathBuf, - auth_manager: Arc, - model_catalog: Option, - collaboration_modes_config: CollaborationModesConfig, - provider_info: ModelProviderInfo, - ) -> Self { - let model_provider = create_model_provider(provider_info, Some(auth_manager)); - let cache_path = codex_home.join(MODEL_CACHE_FILE); - let cache_manager = ModelsCacheManager::new(cache_path, DEFAULT_MODEL_CACHE_TTL); - let catalog_mode = if model_catalog.is_some() { - CatalogMode::Custom - } else { - CatalogMode::Default - }; - let remote_models = model_catalog - .map(|catalog| catalog.models) - .unwrap_or_else(|| Self::load_remote_models_from_file().unwrap_or_default()); - Self { - remote_models: RwLock::new(remote_models), - catalog_mode, - collaboration_modes_config, - etag: RwLock::new(None), - cache_manager, - provider: model_provider, - } - } +type SharedModelsEndpointClient = Arc; +/// Coordinates model discovery plus cached metadata on disk. +#[async_trait] +pub trait ModelsManager: fmt::Debug + Send + Sync { /// List all available models, refreshing according to the specified strategy. /// /// Returns model presets sorted by priority and filtered by auth mode and visibility. - #[instrument( - level = "info", - skip(self), - fields(refresh_strategy = %refresh_strategy) - )] - pub async fn list_models(&self, refresh_strategy: RefreshStrategy) -> Vec { - if let Err(err) = self.refresh_available_models(refresh_strategy).await { - error!("failed to refresh available models: {err}"); + async fn list_models(&self, refresh_strategy: RefreshStrategy) -> Vec { + async move { + let catalog = self.raw_model_catalog(refresh_strategy).await; + self.build_available_models(catalog.models) } - let remote_models = self.get_remote_models().await; - self.build_available_models(remote_models) + .instrument(tracing::info_span!( + "list_models", + refresh_strategy = %refresh_strategy + )) + .await } /// Return the active raw model catalog, refreshing according to the specified strategy. - pub async fn raw_model_catalog(&self, refresh_strategy: RefreshStrategy) -> ModelsResponse { - if let Err(err) = self.refresh_available_models(refresh_strategy).await { - error!("failed to refresh available models: {err}"); - } - ModelsResponse { - models: self.get_remote_models().await, - } + async fn raw_model_catalog(&self, refresh_strategy: RefreshStrategy) -> ModelsResponse; + + /// Return the current in-memory remote model catalog without refreshing or loading cache state. + async fn get_remote_models(&self) -> Vec; + + /// Attempt to return the current in-memory remote model catalog without blocking. + /// + /// Returns an error if the internal lock cannot be acquired. + fn try_get_remote_models(&self) -> Result, TryLockError>; + + /// Return the auth manager used for picker filtering. + fn auth_manager(&self) -> Option<&AuthManager>; + + /// Build picker-ready presets from the active catalog snapshot. + fn build_available_models(&self, mut remote_models: Vec) -> Vec { + remote_models.sort_by(|a, b| a.priority.cmp(&b.priority)); + + let mut presets: Vec = remote_models.into_iter().map(Into::into).collect(); + let uses_codex_backend = self + .auth_manager() + .is_some_and(AuthManager::current_auth_uses_codex_backend); + presets = ModelPreset::filter_by_auth(presets, uses_codex_backend); + + ModelPreset::mark_default_by_picker_visibility(&mut presets); + + presets } /// List collaboration mode presets. /// /// Returns a static set of presets seeded with the configured model. - pub fn list_collaboration_modes(&self) -> Vec { - self.list_collaboration_modes_for_config(self.collaboration_modes_config) - } - - pub fn list_collaboration_modes_for_config( - &self, - collaboration_modes_config: CollaborationModesConfig, - ) -> Vec { - builtin_collaboration_mode_presets(collaboration_modes_config) - } + fn list_collaboration_modes(&self) -> Vec; /// Attempt to list models without blocking, using the current cached state. /// /// Returns an error if the internal lock cannot be acquired. - pub fn try_list_models(&self) -> Result, TryLockError> { + fn try_list_models(&self) -> Result, TryLockError> { let remote_models = self.try_get_remote_models()?; Ok(self.build_available_models(remote_models)) } @@ -289,104 +137,129 @@ impl ModelsManager { /// /// If `model` is provided, returns it directly. Otherwise selects the default based on /// auth mode and available models. - #[instrument( - level = "info", - skip(self, model), - fields( - model.provided = model.is_some(), - refresh_strategy = %refresh_strategy - ) - )] - pub async fn get_default_model( + async fn get_default_model( &self, model: &Option, refresh_strategy: RefreshStrategy, ) -> String { - if let Some(model) = model.as_ref() { - return model.to_string(); - } - if let Err(err) = self.refresh_available_models(refresh_strategy).await { - error!("failed to refresh available models: {err}"); + async move { + if let Some(model) = model.as_ref() { + return model.to_string(); + } + default_model_from_available(self.list_models(refresh_strategy).await) } - let remote_models = self.get_remote_models().await; - let available = self.build_available_models(remote_models); - available - .iter() - .find(|model| model.is_default) - .or_else(|| available.first()) - .map(|model| model.model.clone()) - .unwrap_or_default() + .instrument(tracing::info_span!( + "get_default_model", + model.provided = model.is_some(), + refresh_strategy = %refresh_strategy + )) + .await } // todo(aibrahim): look if we can tighten it to pub(crate) /// Look up model metadata, applying remote overrides and config adjustments. - #[instrument(level = "info", skip(self, config), fields(model = model))] - pub async fn get_model_info(&self, model: &str, config: &ModelsManagerConfig) -> ModelInfo { - let remote_models = self.get_remote_models().await; - Self::construct_model_info_from_candidates(model, &remote_models, config) + async fn get_model_info(&self, model: &str, config: &ModelsManagerConfig) -> ModelInfo { + async move { + let remote_models = self.get_remote_models().await; + construct_model_info_from_candidates(model, &remote_models, config) + } + .instrument(tracing::info_span!("get_model_info", model = model)) + .await } - fn find_model_by_longest_prefix(model: &str, candidates: &[ModelInfo]) -> Option { - let mut best: Option = None; - for candidate in candidates { - if !model.starts_with(&candidate.slug) { - continue; - } - let is_better_match = if let Some(current) = best.as_ref() { - candidate.slug.len() > current.slug.len() - } else { - true - }; - if is_better_match { - best = Some(candidate.clone()); - } + /// Refresh models if the provided ETag differs from the cached ETag. + /// + /// Uses `Online` strategy to fetch latest models when ETags differ. + async fn refresh_if_new_etag(&self, etag: String); +} + +/// Shared model manager handle used across runtime services. +pub type SharedModelsManager = Arc; + +/// OpenAI-compatible model manager backed by bundled models, cache, and `/models`. +#[derive(Debug)] +pub struct OpenAiModelsManager { + remote_models: RwLock>, + collaboration_modes_config: CollaborationModesConfig, + etag: RwLock>, + cache_manager: ModelsCacheManager, + endpoint_client: SharedModelsEndpointClient, + auth_manager: Option>, +} + +/// Static model manager backed by an authoritative in-process catalog. +#[derive(Debug)] +pub struct StaticModelsManager { + remote_models: Vec, + collaboration_modes_config: CollaborationModesConfig, + auth_manager: Option>, +} + +impl OpenAiModelsManager { + /// Construct an OpenAI-compatible remote model manager. + pub fn new( + codex_home: PathBuf, + endpoint_client: Arc, + auth_manager: Option>, + collaboration_modes_config: CollaborationModesConfig, + ) -> Self { + let cache_path = codex_home.join(MODEL_CACHE_FILE); + let cache_manager = ModelsCacheManager::new(cache_path, DEFAULT_MODEL_CACHE_TTL); + let remote_models = load_remote_models_from_file().unwrap_or_default(); + Self { + remote_models: RwLock::new(remote_models), + collaboration_modes_config, + etag: RwLock::new(None), + cache_manager, + endpoint_client, + auth_manager, } - best } +} - /// Retry metadata lookup for a single namespaced slug like `namespace/model-name`. - /// - /// This only strips one leading namespace segment and only when the namespace is ASCII - /// alphanumeric/underscore (`\\w+`) to avoid broadly matching arbitrary aliases. - fn find_model_by_namespaced_suffix(model: &str, candidates: &[ModelInfo]) -> Option { - let (namespace, suffix) = model.split_once('/')?; - if suffix.contains('/') { - return None; +impl StaticModelsManager { + /// Construct a static model manager from an authoritative catalog. + pub fn new( + auth_manager: Option>, + model_catalog: ModelsResponse, + collaboration_modes_config: CollaborationModesConfig, + ) -> Self { + Self { + remote_models: model_catalog.models, + collaboration_modes_config, + auth_manager, + } + } +} + +#[async_trait] +impl ModelsManager for OpenAiModelsManager { + async fn raw_model_catalog(&self, refresh_strategy: RefreshStrategy) -> ModelsResponse { + if let Err(err) = self.refresh_available_models(refresh_strategy).await { + error!("failed to refresh available models: {err}"); } - if !namespace - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_') - { - return None; + ModelsResponse { + models: self.get_remote_models().await, } - Self::find_model_by_longest_prefix(suffix, candidates) } - fn construct_model_info_from_candidates( - model: &str, - candidates: &[ModelInfo], - config: &ModelsManagerConfig, - ) -> ModelInfo { - // First use the normal longest-prefix match. If that misses, allow a narrowly scoped - // retry for namespaced slugs like `custom/gpt-5.3-codex`. - let remote = Self::find_model_by_longest_prefix(model, candidates) - .or_else(|| Self::find_model_by_namespaced_suffix(model, candidates)); - let model_info = if let Some(remote) = remote { - ModelInfo { - slug: model.to_string(), - used_fallback_model_metadata: false, - ..remote - } - } else { - model_info::model_info_from_slug(model) - }; - model_info::with_config_overrides(model_info, config) + async fn get_remote_models(&self) -> Vec { + self.remote_models.read().await.clone() } - /// Refresh models if the provided ETag differs from the cached ETag. - /// - /// Uses `Online` strategy to fetch latest models when ETags differ. - pub async fn refresh_if_new_etag(&self, etag: String) { + fn try_get_remote_models(&self) -> Result, TryLockError> { + Ok(self.remote_models.try_read()?.clone()) + } + + fn auth_manager(&self) -> Option<&AuthManager> { + self.auth_manager.as_deref() + } + + fn list_collaboration_modes(&self) -> Vec { + builtin_collaboration_mode_presets(self.collaboration_modes_config) + } + + async fn refresh_if_new_etag(&self, etag: String) { let current_etag = self.get_etag().await; if current_etag.clone().is_some() && current_etag.as_deref() == Some(etag.as_str()) { if let Err(err) = self.cache_manager.renew_cache_ttl().await { @@ -398,21 +271,12 @@ impl ModelsManager { error!("failed to refresh available models: {err}"); } } +} +impl OpenAiModelsManager { /// Refresh available models according to the specified strategy. async fn refresh_available_models(&self, refresh_strategy: RefreshStrategy) -> CoreResult<()> { - // don't override the custom model catalog if one was provided by the user - if matches!(self.catalog_mode, CatalogMode::Custom) { - return Ok(()); - } - - let uses_codex_backend = self - .provider - .auth() - .await - .as_ref() - .is_some_and(CodexAuth::uses_codex_backend); - if !uses_codex_backend && !self.provider.info().has_command_auth() { + if !self.should_refresh_models().await { if matches!( refresh_strategy, RefreshStrategy::Offline | RefreshStrategy::OnlineIfUncached @@ -445,37 +309,8 @@ impl ModelsManager { } async fn fetch_and_update_models(&self) -> CoreResult<()> { - let _timer = - codex_otel::start_global_timer("codex.remote_models.fetch_update.duration_ms", &[]); - let auth_manager = self.provider.auth_manager(); - let codex_api_key_env_enabled = auth_manager - .as_ref() - .is_some_and(|auth_manager| auth_manager.codex_api_key_env_enabled()); - let auth = self.provider.auth().await; - let auth_mode = auth.as_ref().map(CodexAuth::auth_mode); - let api_provider = self.provider.api_provider().await?; - let api_auth = self.provider.api_auth().await?; - let auth_env = collect_auth_env_telemetry(self.provider.info(), codex_api_key_env_enabled); - let transport = ReqwestTransport::new(build_reqwest_client()); - let auth_telemetry = auth_header_telemetry(api_auth.as_ref()); - let request_telemetry: Arc = Arc::new(ModelsRequestTelemetry { - auth_mode: auth_mode.map(|mode| TelemetryAuthMode::from(mode).to_string()), - auth_header_attached: auth_telemetry.attached, - auth_header_name: auth_telemetry.name, - auth_env, - }); - let client = ModelsClient::new(transport, api_provider, api_auth) - .with_telemetry(Some(request_telemetry)); - let client_version = crate::client_version_to_whole(); - let (models, etag) = timeout( - MODELS_REFRESH_TIMEOUT, - client.list_models(&client_version, HeaderMap::new()), - ) - .await - .map_err(|_| CodexErr::Timeout)? - .map_err(map_api_error)?; - + let (models, etag) = self.endpoint_client.list_models(&client_version).await?; self.apply_remote_models(models.clone()).await; *self.etag.write().await = etag.clone(); self.cache_manager @@ -484,13 +319,17 @@ impl ModelsManager { Ok(()) } + async fn should_refresh_models(&self) -> bool { + self.endpoint_client.uses_codex_backend().await || self.endpoint_client.has_command_auth() + } + async fn get_etag(&self) -> Option { self.etag.read().await.clone() } /// Replace the cached remote models and rebuild the derived presets list. async fn apply_remote_models(&self, models: Vec) { - let mut existing_models = Self::load_remote_models_from_file().unwrap_or_default(); + let mut existing_models = load_remote_models_from_file().unwrap_or_default(); for model in models { if let Some(existing_index) = existing_models .iter() @@ -504,16 +343,14 @@ impl ModelsManager { *self.remote_models.write().await = existing_models; } - fn load_remote_models_from_file() -> Result, std::io::Error> { - Ok(crate::bundled_models_response()?.models) - } - /// Attempt to satisfy the refresh from the cache when it matches the provider and TTL. async fn try_load_cache(&self) -> bool { let _timer = codex_otel::start_global_timer("codex.remote_models.load_cache.duration_ms", &[]); let client_version = crate::client_version_to_whole(); info!(client_version, "models cache: evaluating cache eligibility"); + // TODO(celia-oai): Include provider identity in cache eligibility so switching + // providers does not reuse a fresh models_cache.json entry from another provider. let cache = match self.cache_manager.load_fresh(&client_version).await { Some(cache) => cache, None => { @@ -531,75 +368,103 @@ impl ModelsManager { ); true } +} - /// Build picker-ready presets from the active catalog snapshot. - fn build_available_models(&self, mut remote_models: Vec) -> Vec { - remote_models.sort_by(|a, b| a.priority.cmp(&b.priority)); - - let mut presets: Vec = remote_models.into_iter().map(Into::into).collect(); - let uses_codex_backend = self - .provider - .auth_manager() - .as_deref() - .is_some_and(AuthManager::current_auth_uses_codex_backend); - presets = ModelPreset::filter_by_auth(presets, uses_codex_backend); - - ModelPreset::mark_default_by_picker_visibility(&mut presets); - - presets +#[async_trait] +impl ModelsManager for StaticModelsManager { + async fn raw_model_catalog(&self, _refresh_strategy: RefreshStrategy) -> ModelsResponse { + ModelsResponse { + models: self.get_remote_models().await, + } } async fn get_remote_models(&self) -> Vec { - self.remote_models.read().await.clone() + self.remote_models.clone() } fn try_get_remote_models(&self) -> Result, TryLockError> { - Ok(self.remote_models.try_read()?.clone()) + Ok(self.remote_models.clone()) } - /// Construct a manager with a specific provider for testing. - pub fn with_provider_for_tests( - codex_home: PathBuf, - auth_manager: Arc, - provider: ModelProviderInfo, - ) -> Self { - Self::new_with_provider( - codex_home, - auth_manager, - /*model_catalog*/ None, - CollaborationModesConfig::default(), - provider, - ) + fn auth_manager(&self) -> Option<&AuthManager> { + self.auth_manager.as_deref() } - /// Get model identifier without consulting remote state or cache. - pub fn get_model_offline_for_tests(model: Option<&str>) -> String { - if let Some(model) = model { - return model.to_string(); - } - let mut models = Self::load_remote_models_from_file().unwrap_or_default(); - models.sort_by(|a, b| a.priority.cmp(&b.priority)); - let presets: Vec = models.into_iter().map(Into::into).collect(); - presets - .iter() - .find(|preset| preset.show_in_picker) - .or_else(|| presets.first()) - .map(|preset| preset.model.clone()) - .unwrap_or_default() + fn list_collaboration_modes(&self) -> Vec { + builtin_collaboration_mode_presets(self.collaboration_modes_config) } - /// Build `ModelInfo` without consulting remote state or cache. - pub fn construct_model_info_offline_for_tests( - model: &str, - config: &ModelsManagerConfig, - ) -> ModelInfo { - let candidates: &[ModelInfo] = if let Some(model_catalog) = config.model_catalog.as_ref() { - &model_catalog.models + async fn refresh_if_new_etag(&self, _etag: String) {} +} + +fn load_remote_models_from_file() -> Result, std::io::Error> { + Ok(crate::bundled_models_response()?.models) +} + +fn default_model_from_available(available: Vec) -> String { + available + .iter() + .find(|model| model.is_default) + .or_else(|| available.first()) + .map(|model| model.model.clone()) + .unwrap_or_default() +} + +fn find_model_by_longest_prefix(model: &str, candidates: &[ModelInfo]) -> Option { + let mut best: Option = None; + for candidate in candidates { + if !model.starts_with(&candidate.slug) { + continue; + } + let is_better_match = if let Some(current) = best.as_ref() { + candidate.slug.len() > current.slug.len() } else { - &[] + true }; - Self::construct_model_info_from_candidates(model, candidates, config) + if is_better_match { + best = Some(candidate.clone()); + } + } + best +} + +fn find_model_by_namespaced_suffix(model: &str, candidates: &[ModelInfo]) -> Option { + // Retry metadata lookup for a single namespaced slug like `namespace/model-name`. + // + // This only strips one leading namespace segment and only when the namespace is ASCII + // alphanumeric/underscore (`\w+`) to avoid broadly matching arbitrary aliases. + let (namespace, suffix) = model.split_once('/')?; + if suffix.contains('/') { + return None; + } + if !namespace + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_') + { + return None; } + find_model_by_longest_prefix(suffix, candidates) +} + +pub(crate) fn construct_model_info_from_candidates( + model: &str, + candidates: &[ModelInfo], + config: &ModelsManagerConfig, +) -> ModelInfo { + // First use the normal longest-prefix match. If that misses, allow a narrowly scoped + // retry for namespaced slugs like `custom/gpt-5.3-codex`. + let remote = find_model_by_longest_prefix(model, candidates) + .or_else(|| find_model_by_namespaced_suffix(model, candidates)); + let model_info = if let Some(remote) = remote { + ModelInfo { + slug: model.to_string(), + used_fallback_model_metadata: false, + ..remote + } + } else { + model_info::model_info_from_slug(model) + }; + model_info::with_config_overrides(model_info, config) } #[cfg(test)] diff --git a/codex-rs/models-manager/src/manager_tests.rs b/codex-rs/models-manager/src/manager_tests.rs index 5966df616d19..4046b7565f54 100644 --- a/codex-rs/models-manager/src/manager_tests.rs +++ b/codex-rs/models-manager/src/manager_tests.rs @@ -1,40 +1,27 @@ use super::*; use crate::ModelsManagerConfig; -use base64::Engine as _; use chrono::Utc; -use codex_api::TransportError; -use codex_config::types::AuthCredentialsStoreMode; +use codex_app_server_protocol::AuthMode; +use codex_login::AuthCredentialsStoreMode; use codex_login::AuthManager; use codex_login::CodexAuth; -use codex_model_provider_info::WireApi; -use codex_protocol::config_types::ModelProviderAuthInfo; +use codex_login::ExternalAuth; +use codex_login::ExternalAuthRefreshContext; +use codex_login::ExternalAuthTokens; +use codex_login::TokenData; +use codex_login::auth::AgentIdentityAuth; +use codex_login::auth::AgentIdentityAuthRecord; +use codex_protocol::account::PlanType; use codex_protocol::openai_models::ModelsResponse; -use codex_utils_absolute_path::AbsolutePathBuf; -use core_test_support::responses::mount_models_once; -use http::HeaderMap; -use http::StatusCode; use pretty_assertions::assert_eq; use serde_json::json; -use std::collections::BTreeMap; -use std::num::NonZeroU64; +use std::collections::VecDeque; +use std::path::Path; use std::sync::Arc; use std::sync::Mutex; -use tempfile::TempDir; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use tempfile::tempdir; -use tracing::Event; -use tracing::Subscriber; -use tracing::field::Visit; -use tracing_subscriber::Layer; -use tracing_subscriber::layer::Context; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::registry::LookupSpan; -use tracing_subscriber::util::SubscriberInitExt; -use wiremock::Mock; -use wiremock::MockServer; -use wiremock::ResponseTemplate; -use wiremock::matchers::header_regex; -use wiremock::matchers::method; -use wiremock::matchers::path; #[path = "model_info_overrides_tests.rs"] mod model_info_overrides_tests; @@ -86,174 +73,188 @@ fn assert_models_contain(actual: &[ModelInfo], expected: &[ModelInfo]) { } } -fn provider_for(base_url: String) -> ModelProviderInfo { - ModelProviderInfo { - name: "mock".into(), - base_url: Some(base_url), - env_key: None, - env_key_instructions: None, - experimental_bearer_token: None, - auth: None, - aws: None, - wire_api: WireApi::Responses, - query_params: None, - http_headers: None, - env_http_headers: None, - request_max_retries: Some(0), - stream_max_retries: Some(0), - stream_idle_timeout_ms: Some(5_000), - websocket_connect_timeout_ms: None, - requires_openai_auth: false, - supports_websockets: false, - } +#[derive(Debug)] +struct TestModelsEndpoint { + has_command_auth: bool, + uses_codex_backend: bool, + responses: Mutex>>, + fetch_count: AtomicUsize, } -struct ProviderAuthScript { - tempdir: TempDir, - command: String, - args: Vec, +impl TestModelsEndpoint { + fn new(responses: Vec>) -> Arc { + Arc::new(Self { + has_command_auth: false, + uses_codex_backend: true, + responses: Mutex::new(responses.into()), + fetch_count: AtomicUsize::new(0), + }) + } + + fn without_refresh(responses: Vec>) -> Arc { + Arc::new(Self { + has_command_auth: false, + uses_codex_backend: false, + responses: Mutex::new(responses.into()), + fetch_count: AtomicUsize::new(0), + }) + } + + fn fetch_count(&self) -> usize { + self.fetch_count.load(Ordering::SeqCst) + } } -impl ProviderAuthScript { - fn new(tokens: &[&str]) -> std::io::Result { - let tempdir = tempfile::tempdir()?; - let tokens_file = tempdir.path().join("tokens.txt"); - // `cmd.exe`'s `set /p` treats LF-only input as one line, so use CRLF on Windows. - let token_line_ending = if cfg!(windows) { "\r\n" } else { "\n" }; - let mut token_file_contents = String::new(); - for token in tokens { - token_file_contents.push_str(token); - token_file_contents.push_str(token_line_ending); - } - std::fs::write(&tokens_file, token_file_contents)?; - - #[cfg(unix)] - let (command, args) = { - let script_path = tempdir.path().join("print-token.sh"); - std::fs::write( - &script_path, - r#"#!/bin/sh -first_line=$(sed -n '1p' tokens.txt) -printf '%s\n' "$first_line" -tail -n +2 tokens.txt > tokens.next -mv tokens.next tokens.txt -"#, - )?; - let mut permissions = std::fs::metadata(&script_path)?.permissions(); - { - use std::os::unix::fs::PermissionsExt; - permissions.set_mode(0o755); - } - std::fs::set_permissions(&script_path, permissions)?; - ("./print-token.sh".to_string(), Vec::new()) - }; - - #[cfg(windows)] - let (command, args) = { - let script_path = tempdir.path().join("print-token.cmd"); - std::fs::write( - &script_path, - r#"@echo off -setlocal EnableExtensions DisableDelayedExpansion -set "first_line=" - AuthMode { + AuthMode::ApiKey } - fn auth_config(&self) -> ModelProviderAuthInfo { - let timeout_ms = if cfg!(windows) { - // Process startup can be slow on loaded Windows CI workers. - 10_000 - } else { - 2_000 - }; - ModelProviderAuthInfo { - command: self.command.clone(), - args: self.args.clone(), - timeout_ms: NonZeroU64::new(timeout_ms).unwrap(), - refresh_interval_ms: 60_000, - cwd: match AbsolutePathBuf::try_from(self.tempdir.path()) { - Ok(cwd) => cwd, - Err(err) => panic!("tempdir should be absolute: {err}"), - }, - } + async fn resolve(&self) -> std::io::Result> { + Ok(Some(ExternalAuthTokens::access_token_only( + "test-external-api-key", + ))) + } + + async fn refresh( + &self, + _context: ExternalAuthRefreshContext, + ) -> std::io::Result { + Ok(ExternalAuthTokens::access_token_only( + "test-external-api-key", + )) } } -#[derive(Default)] -struct TagCollectorVisitor { - tags: BTreeMap, +#[derive(Debug)] +struct TestUnresolvedExternalApiKeyAuth; + +#[async_trait] +impl ExternalAuth for TestUnresolvedExternalApiKeyAuth { + fn auth_mode(&self) -> AuthMode { + AuthMode::ApiKey + } + + async fn refresh( + &self, + _context: ExternalAuthRefreshContext, + ) -> std::io::Result { + Err(std::io::Error::other("unresolved test auth")) + } } -impl Visit for TagCollectorVisitor { - fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { - self.tags - .insert(field.name().to_string(), value.to_string()); +#[async_trait] +impl ModelsEndpointClient for TestModelsEndpoint { + fn has_command_auth(&self) -> bool { + self.has_command_auth } - fn record_str(&mut self, field: &tracing::field::Field, value: &str) { - self.tags - .insert(field.name().to_string(), value.to_string()); + async fn uses_codex_backend(&self) -> bool { + self.uses_codex_backend } - fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { - self.tags - .insert(field.name().to_string(), format!("{value:?}")); + async fn list_models( + &self, + _client_version: &str, + ) -> CoreResult<(Vec, Option)> { + self.fetch_count.fetch_add(1, Ordering::SeqCst); + let models = self + .responses + .lock() + .expect("responses lock should not be poisoned") + .pop_front() + .unwrap_or_default(); + Ok((models, None)) } } -#[derive(Clone)] -struct TagCollectorLayer { - tags: Arc>>, +fn openai_manager_for_tests( + codex_home: std::path::PathBuf, + endpoint_client: Arc, +) -> OpenAiModelsManager { + openai_manager_for_tests_with_auth( + codex_home, + endpoint_client, + Some(AuthManager::from_auth_for_testing( + CodexAuth::create_dummy_chatgpt_auth_for_testing(), + )), + ) } -impl Layer for TagCollectorLayer -where - S: Subscriber + for<'a> LookupSpan<'a>, -{ - fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) { - if event.metadata().target() != "feedback_tags" { - return; - } - let mut visitor = TagCollectorVisitor::default(); - event.record(&mut visitor); - self.tags.lock().unwrap().extend(visitor.tags); - } +fn openai_manager_for_tests_with_auth( + codex_home: std::path::PathBuf, + endpoint_client: Arc, + auth_manager: Option>, +) -> OpenAiModelsManager { + OpenAiModelsManager::new( + codex_home, + endpoint_client, + auth_manager, + CollaborationModesConfig::default(), + ) +} + +fn static_manager_for_tests(model_catalog: ModelsResponse) -> StaticModelsManager { + StaticModelsManager::new( + /*auth_manager*/ None, + model_catalog, + CollaborationModesConfig::default(), + ) +} + +fn chatgpt_auth_tokens_for_tests(codex_home: &Path) -> CodexAuth { + let auth_dot_json = codex_login::AuthDotJson { + auth_mode: Some(AuthMode::ChatgptAuthTokens), + openai_api_key: None, + tokens: Some(TokenData { + id_token: codex_login::token_data::parse_chatgpt_jwt_claims( + "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.\ +eyJlbWFpbCI6InVzZXJAZXhhbXBsZS5jb20iLCJodHRwczovL2FwaS5vcGVuYWkuY29tL2F1dGgiOnsiY2hhdGdwdF9wbGFuX3R5cGUiOiJwcm8iLCJjaGF0Z3B0X3VzZXJfaWQiOiJ1c2VyLWlkIiwiY2hhdGdwdF9hY2NvdW50X2lkIjoiYWNjb3VudC1pZCJ9fQ.\ +c2ln", + ) + .expect("fake id token should parse"), + access_token: "Access Token".to_string(), + refresh_token: "test".to_string(), + account_id: Some("account_id".to_string()), + }), + last_refresh: Some(Utc::now()), + agent_identity: None, + }; + std::fs::create_dir_all(codex_home).expect("codex home should be created"); + std::fs::write( + codex_home.join("auth.json"), + serde_json::to_string(&auth_dot_json).expect("auth should serialize"), + ) + .expect("auth.json should be written"); + + CodexAuth::from_auth_storage(codex_home, AuthCredentialsStoreMode::File) + .expect("auth should load") + .expect("auth should be present") +} + +fn agent_identity_auth_for_tests() -> CodexAuth { + CodexAuth::AgentIdentity(AgentIdentityAuth::new(AgentIdentityAuthRecord { + agent_runtime_id: "agent-runtime-id".to_string(), + agent_private_key: "agent-private-key".to_string(), + account_id: "account-id".to_string(), + chatgpt_user_id: "chatgpt-user-id".to_string(), + email: "agent@example.com".to_string(), + plan_type: PlanType::Pro, + chatgpt_account_is_fedramp: false, + })) } #[tokio::test] async fn get_model_info_tracks_fallback_usage() { let codex_home = tempdir().expect("temp dir"); let config = ModelsManagerConfig::default(); - let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( + let manager = openai_manager_for_tests( codex_home.path().to_path_buf(), - auth_manager, - /*model_catalog*/ None, - CollaborationModesConfig::default(), + TestModelsEndpoint::new(Vec::new()), ); let known_slug = manager .get_remote_models() @@ -276,20 +277,13 @@ async fn get_model_info_tracks_fallback_usage() { #[tokio::test] async fn get_model_info_uses_custom_catalog() { - let codex_home = tempdir().expect("temp dir"); let config = ModelsManagerConfig::default(); let mut overlay = remote_model("gpt-overlay", "Overlay", /*priority*/ 0); overlay.supports_image_detail_original = true; - let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( - codex_home.path().to_path_buf(), - auth_manager, - Some(ModelsResponse { - models: vec![overlay], - }), - CollaborationModesConfig::default(), - ); + let manager = static_manager_for_tests(ModelsResponse { + models: vec![overlay], + }); let model_info = manager .get_model_info("gpt-overlay-experiment", &config) @@ -305,19 +299,12 @@ async fn get_model_info_uses_custom_catalog() { #[tokio::test] async fn get_model_info_matches_namespaced_suffix() { - let codex_home = tempdir().expect("temp dir"); let config = ModelsManagerConfig::default(); let mut remote = remote_model("gpt-image", "Image", /*priority*/ 0); remote.supports_image_detail_original = true; - let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( - codex_home.path().to_path_buf(), - auth_manager, - Some(ModelsResponse { - models: vec![remote], - }), - CollaborationModesConfig::default(), - ); + let manager = static_manager_for_tests(ModelsResponse { + models: vec![remote], + }); let namespaced_model = "custom/gpt-image".to_string(); let model_info = manager.get_model_info(&namespaced_model, &config).await; @@ -331,12 +318,9 @@ async fn get_model_info_matches_namespaced_suffix() { async fn get_model_info_rejects_multi_segment_namespace_suffix_matching() { let codex_home = tempdir().expect("temp dir"); let config = ModelsManagerConfig::default(); - let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let manager = ModelsManager::new( + let manager = openai_manager_for_tests( codex_home.path().to_path_buf(), - auth_manager, - /*model_catalog*/ None, - CollaborationModesConfig::default(), + TestModelsEndpoint::new(Vec::new()), ); let known_slug = manager .get_remote_models() @@ -355,28 +339,13 @@ async fn get_model_info_rejects_multi_segment_namespace_suffix_matching() { #[tokio::test] async fn refresh_available_models_sorts_by_priority() { - let server = MockServer::start().await; let remote_models = vec![ remote_model("priority-low", "Low", /*priority*/ 1), remote_model("priority-high", "High", /*priority*/ 0), ]; - let models_mock = mount_models_once( - &server, - ModelsResponse { - models: remote_models.clone(), - }, - ) - .await; - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); + let endpoint = TestModelsEndpoint::new(vec![remote_models.clone()]); + let manager = openai_manager_for_tests(codex_home.path().to_path_buf(), endpoint.clone()); manager .refresh_available_models(RefreshStrategy::OnlineIfUncached) @@ -398,78 +367,15 @@ async fn refresh_available_models_sorts_by_priority() { high_idx < low_idx, "higher priority should be listed before lower priority" ); - assert_eq!( - models_mock.requests().len(), - 1, - "expected a single /models request" - ); -} - -#[tokio::test] -async fn refresh_available_models_uses_provider_auth_token() { - let server = MockServer::start().await; - let auth_script = ProviderAuthScript::new(&["provider-token"]).unwrap(); - let remote_models = vec![remote_model( - "provider-model", - "Provider", - /*priority*/ 0, - )]; - - Mock::given(method("GET")) - .and(path("/models")) - .and(header_regex("Authorization", "Bearer provider-token")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("content-type", "application/json") - .set_body_json(ModelsResponse { - models: remote_models.clone(), - }), - ) - .expect(1) - .mount(&server) - .await; - - let codex_home = tempdir().expect("temp dir"); - let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("unused")); - let provider = ModelProviderInfo { - auth: Some(auth_script.auth_config()), - ..provider_for(server.uri()) - }; - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); - - manager - .refresh_available_models(RefreshStrategy::Online) - .await - .expect("refresh succeeds"); - - assert_models_contain(&manager.get_remote_models().await, &remote_models); + assert_eq!(endpoint.fetch_count(), 1, "expected a single model fetch"); } #[tokio::test] async fn refresh_available_models_uses_cache_when_fresh() { - let server = MockServer::start().await; let remote_models = vec![remote_model("cached", "Cached", /*priority*/ 5)]; - let models_mock = mount_models_once( - &server, - ModelsResponse { - models: remote_models.clone(), - }, - ) - .await; - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); + let endpoint = TestModelsEndpoint::new(vec![remote_models.clone()]); + let manager = openai_manager_for_tests(codex_home.path().to_path_buf(), endpoint.clone()); manager .refresh_available_models(RefreshStrategy::OnlineIfUncached) @@ -484,33 +390,19 @@ async fn refresh_available_models_uses_cache_when_fresh() { .expect("cached refresh succeeds"); assert_models_contain(&manager.get_remote_models().await, &remote_models); assert_eq!( - models_mock.requests().len(), + endpoint.fetch_count(), 1, - "cache hit should avoid a second /models request" + "cache hit should avoid a second model fetch" ); } #[tokio::test] async fn refresh_available_models_refetches_when_cache_stale() { - let server = MockServer::start().await; let initial_models = vec![remote_model("stale", "Stale", /*priority*/ 1)]; - let initial_mock = mount_models_once( - &server, - ModelsResponse { - models: initial_models.clone(), - }, - ) - .await; - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); + let updated_models = vec![remote_model("fresh", "Fresh", /*priority*/ 9)]; + let endpoint = TestModelsEndpoint::new(vec![initial_models.clone(), updated_models.clone()]); + let manager = openai_manager_for_tests(codex_home.path().to_path_buf(), endpoint.clone()); manager .refresh_available_models(RefreshStrategy::OnlineIfUncached) @@ -526,54 +418,25 @@ async fn refresh_available_models_refetches_when_cache_stale() { .await .expect("cache manipulation succeeds"); - let updated_models = vec![remote_model("fresh", "Fresh", /*priority*/ 9)]; - server.reset().await; - let refreshed_mock = mount_models_once( - &server, - ModelsResponse { - models: updated_models.clone(), - }, - ) - .await; - manager .refresh_available_models(RefreshStrategy::OnlineIfUncached) .await .expect("second refresh succeeds"); assert_models_contain(&manager.get_remote_models().await, &updated_models); assert_eq!( - initial_mock.requests().len(), - 1, - "initial refresh should only hit /models once" - ); - assert_eq!( - refreshed_mock.requests().len(), - 1, - "stale cache refresh should fetch /models once" + endpoint.fetch_count(), + 2, + "stale cache refresh should fetch models again" ); } #[tokio::test] async fn refresh_available_models_refetches_when_version_mismatch() { - let server = MockServer::start().await; let initial_models = vec![remote_model("old", "Old", /*priority*/ 1)]; - let initial_mock = mount_models_once( - &server, - ModelsResponse { - models: initial_models.clone(), - }, - ) - .await; - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); + let updated_models = vec![remote_model("new", "New", /*priority*/ 2)]; + let endpoint = TestModelsEndpoint::new(vec![initial_models.clone(), updated_models.clone()]); + let manager = openai_manager_for_tests(codex_home.path().to_path_buf(), endpoint.clone()); manager .refresh_available_models(RefreshStrategy::OnlineIfUncached) @@ -589,58 +452,33 @@ async fn refresh_available_models_refetches_when_version_mismatch() { .await .expect("cache mutation succeeds"); - let updated_models = vec![remote_model("new", "New", /*priority*/ 2)]; - server.reset().await; - let refreshed_mock = mount_models_once( - &server, - ModelsResponse { - models: updated_models.clone(), - }, - ) - .await; - manager .refresh_available_models(RefreshStrategy::OnlineIfUncached) .await .expect("second refresh succeeds"); assert_models_contain(&manager.get_remote_models().await, &updated_models); assert_eq!( - initial_mock.requests().len(), - 1, - "initial refresh should only hit /models once" - ); - assert_eq!( - refreshed_mock.requests().len(), - 1, - "version mismatch should fetch /models once" + endpoint.fetch_count(), + 2, + "version mismatch should fetch models again" ); } #[tokio::test] async fn refresh_available_models_drops_removed_remote_models() { - let server = MockServer::start().await; let initial_models = vec![remote_model( "remote-old", "Remote Old", /*priority*/ 1, )]; - let initial_mock = mount_models_once( - &server, - ModelsResponse { - models: initial_models, - }, - ) - .await; - let codex_home = tempdir().expect("temp dir"); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let provider = provider_for(server.uri()); - let mut manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); + let refreshed_models = vec![remote_model( + "remote-new", + "Remote New", + /*priority*/ 1, + )]; + let endpoint = TestModelsEndpoint::new(vec![initial_models, refreshed_models]); + let mut manager = openai_manager_for_tests(codex_home.path().to_path_buf(), endpoint.clone()); manager.cache_manager.set_ttl(Duration::ZERO); manager @@ -648,20 +486,6 @@ async fn refresh_available_models_drops_removed_remote_models() { .await .expect("initial refresh succeeds"); - server.reset().await; - let refreshed_models = vec![remote_model( - "remote-new", - "Remote New", - /*priority*/ 1, - )]; - let refreshed_mock = mount_models_once( - &server, - ModelsResponse { - models: refreshed_models, - }, - ) - .await; - manager .refresh_available_models(RefreshStrategy::OnlineIfUncached) .await @@ -679,41 +503,25 @@ async fn refresh_available_models_drops_removed_remote_models() { "removed remote model should not be listed" ); assert_eq!( - initial_mock.requests().len(), - 1, - "initial refresh should only hit /models once" - ); - assert_eq!( - refreshed_mock.requests().len(), - 1, - "second refresh should only hit /models once" + endpoint.fetch_count(), + 2, + "second refresh should fetch models again" ); } #[tokio::test] async fn refresh_available_models_skips_network_without_chatgpt_auth() { - let server = MockServer::start().await; let dynamic_slug = "dynamic-model-only-for-test-noauth"; - let models_mock = mount_models_once( - &server, - ModelsResponse { - models: vec![remote_model(dynamic_slug, "No Auth", /*priority*/ 1)], - }, - ) - .await; - let codex_home = tempdir().expect("temp dir"); - let auth_manager = Arc::new(AuthManager::new( - codex_home.path().to_path_buf(), - /*enable_codex_api_key_env*/ false, - AuthCredentialsStoreMode::File, - /*chatgpt_base_url*/ None, - )); - let provider = provider_for(server.uri()); - let manager = ModelsManager::with_provider_for_tests( + let endpoint = TestModelsEndpoint::without_refresh(vec![vec![remote_model( + dynamic_slug, + "No Auth", + /*priority*/ 1, + )]]); + let manager = openai_manager_for_tests_with_auth( codex_home.path().to_path_buf(), - auth_manager, - provider, + endpoint.clone(), + /*auth_manager*/ None, ); manager @@ -728,120 +536,222 @@ async fn refresh_available_models_skips_network_without_chatgpt_auth() { "remote refresh should be skipped without chatgpt auth" ); assert_eq!( - models_mock.requests().len(), + endpoint.fetch_count(), 0, - "no auth should avoid /models requests" + "endpoint that cannot refresh should avoid model fetches" ); } -#[test] -fn models_request_telemetry_emits_auth_env_feedback_tags_on_failure() { - let tags = Arc::new(Mutex::new(BTreeMap::new())); - let _guard = tracing_subscriber::registry() - .with(TagCollectorLayer { tags: tags.clone() }) - .set_default(); - - let telemetry = ModelsRequestTelemetry { - auth_mode: Some(TelemetryAuthMode::Chatgpt.to_string()), - auth_header_attached: true, - auth_header_name: Some("authorization"), - auth_env: codex_login::AuthEnvTelemetry { - openai_api_key_env_present: false, - codex_api_key_env_present: false, - codex_api_key_env_enabled: false, - provider_env_key_name: Some("configured".to_string()), - provider_env_key_present: Some(false), - refresh_token_url_override_present: false, - }, - }; - let mut headers = HeaderMap::new(); - headers.insert("x-request-id", "req-models-401".parse().unwrap()); - headers.insert("cf-ray", "ray-models-401".parse().unwrap()); - headers.insert( - "x-openai-authorization-error", - "missing_authorization_header".parse().unwrap(), - ); - headers.insert( - "x-error-json", - base64::engine::general_purpose::STANDARD - .encode(r#"{"error":{"code":"token_expired"}}"#) - .parse() - .unwrap(), - ); - telemetry.on_request( - /*attempt*/ 1, - Some(StatusCode::UNAUTHORIZED), - Some(&TransportError::Http { - status: StatusCode::UNAUTHORIZED, - url: Some("https://example.test/models".to_string()), - headers: Some(headers), - body: Some("plain text error".to_string()), - }), - Duration::from_millis(17), +#[derive(Debug)] +struct TestAuthAwareModelsEndpoint { + auth_manager: Option>, + responses: Mutex>>, + fetch_count: AtomicUsize, +} + +impl TestAuthAwareModelsEndpoint { + fn new(auth_manager: Option>, responses: Vec>) -> Arc { + Arc::new(Self { + auth_manager, + responses: Mutex::new(responses.into()), + fetch_count: AtomicUsize::new(0), + }) + } + + fn fetch_count(&self) -> usize { + self.fetch_count.load(Ordering::SeqCst) + } +} + +#[async_trait] +impl ModelsEndpointClient for TestAuthAwareModelsEndpoint { + fn has_command_auth(&self) -> bool { + false + } + + async fn uses_codex_backend(&self) -> bool { + match self.auth_manager.as_ref() { + Some(auth_manager) => auth_manager + .auth() + .await + .as_ref() + .is_some_and(CodexAuth::uses_codex_backend), + None => false, + } + } + + async fn list_models( + &self, + _client_version: &str, + ) -> CoreResult<(Vec, Option)> { + self.fetch_count.fetch_add(1, Ordering::SeqCst); + let models = self + .responses + .lock() + .expect("responses lock should not be poisoned") + .pop_front() + .unwrap_or_default(); + Ok((models, None)) + } +} + +#[tokio::test] +async fn refresh_available_models_skips_network_when_external_api_key_overrides_chatgpt_auth() { + let dynamic_slug = "dynamic-model-only-for-test-external-api-key"; + let codex_home = tempdir().expect("temp dir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + auth_manager.set_external_auth(Arc::new(TestExternalApiKeyAuth)); + let endpoint = TestAuthAwareModelsEndpoint::new( + Some(Arc::clone(&auth_manager)), + vec![vec![remote_model( + dynamic_slug, + "External API Key", + /*priority*/ 1, + )]], + ); + let manager = openai_manager_for_tests_with_auth( + codex_home.path().to_path_buf(), + endpoint.clone(), + Some(auth_manager), ); - let tags = tags.lock().unwrap().clone(); - assert_eq!( - tags.get("endpoint").map(String::as_str), - Some("\"/models\"") + manager + .refresh_available_models(RefreshStrategy::Online) + .await + .expect("refresh should no-op with API key auth"); + let cached_remote = manager.get_remote_models().await; + + assert!( + !cached_remote + .iter() + .any(|candidate| candidate.slug == dynamic_slug), + "remote refresh should be skipped when external API key auth is active" ); assert_eq!( - tags.get("auth_mode").map(String::as_str), - Some("\"Chatgpt\"") + endpoint.fetch_count(), + 0, + "endpoint should avoid model fetches when external API key auth is active" ); - assert_eq!( - tags.get("auth_request_id").map(String::as_str), - Some("\"req-models-401\"") +} + +#[tokio::test] +async fn refresh_available_models_uses_cached_chatgpt_when_external_api_key_is_unresolved() { + let dynamic_slug = "dynamic-model-only-for-test-unresolved-external-api-key"; + let codex_home = tempdir().expect("temp dir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + auth_manager.set_external_auth(Arc::new(TestUnresolvedExternalApiKeyAuth)); + let endpoint = TestAuthAwareModelsEndpoint::new( + Some(Arc::clone(&auth_manager)), + vec![vec![remote_model( + dynamic_slug, + "Unresolved External API Key", + /*priority*/ 1, + )]], + ); + let manager = openai_manager_for_tests_with_auth( + codex_home.path().to_path_buf(), + endpoint.clone(), + Some(auth_manager), ); - assert_eq!( - tags.get("auth_error").map(String::as_str), - Some("\"missing_authorization_header\"") + + manager + .refresh_available_models(RefreshStrategy::Online) + .await + .expect("refresh should fall back to cached ChatGPT auth"); + + assert!( + manager + .get_remote_models() + .await + .iter() + .any(|candidate| candidate.slug == dynamic_slug), + "remote refresh should include models fetched with cached ChatGPT auth" ); assert_eq!( - tags.get("auth_error_code").map(String::as_str), - Some("\"token_expired\"") + endpoint.fetch_count(), + 1, + "endpoint should fetch models when unresolved external API key falls back to ChatGPT auth" ); - assert_eq!( - tags.get("auth_env_openai_api_key_present") - .map(String::as_str), - Some("false") +} + +#[tokio::test] +async fn refresh_available_models_fetches_with_chatgpt_auth_tokens() { + let dynamic_slug = "dynamic-model-only-for-test-chatgpt-auth-tokens"; + let codex_home = tempdir().expect("temp dir"); + let endpoint = TestModelsEndpoint::new(vec![vec![remote_model( + dynamic_slug, + "ChatGPT Auth Tokens", + /*priority*/ 1, + )]]); + let auth = chatgpt_auth_tokens_for_tests(codex_home.path()); + let manager = openai_manager_for_tests_with_auth( + codex_home.path().to_path_buf(), + endpoint.clone(), + Some(AuthManager::from_auth_for_testing(auth)), ); - assert_eq!( - tags.get("auth_env_codex_api_key_present") - .map(String::as_str), - Some("false") + + manager + .refresh_available_models(RefreshStrategy::Online) + .await + .expect("refresh should fetch with ChatGPT auth tokens"); + + assert!( + manager + .get_remote_models() + .await + .iter() + .any(|candidate| candidate.slug == dynamic_slug), + "remote refresh should include models fetched with ChatGPT auth tokens" ); assert_eq!( - tags.get("auth_env_codex_api_key_enabled") - .map(String::as_str), - Some("false") + endpoint.fetch_count(), + 1, + "endpoint should fetch models with ChatGPT auth tokens" ); - assert_eq!( - tags.get("auth_env_provider_key_name").map(String::as_str), - Some("\"configured\"") +} + +#[tokio::test] +async fn refresh_available_models_fetches_with_agent_identity() { + let dynamic_slug = "dynamic-model-only-for-test-agent-identity"; + let codex_home = tempdir().expect("temp dir"); + let endpoint = TestModelsEndpoint::new(vec![vec![remote_model( + dynamic_slug, + "Agent Identity", + /*priority*/ 1, + )]]); + let manager = openai_manager_for_tests_with_auth( + codex_home.path().to_path_buf(), + endpoint.clone(), + Some(AuthManager::from_auth_for_testing( + agent_identity_auth_for_tests(), + )), ); - assert_eq!( - tags.get("auth_env_provider_key_present") - .map(String::as_str), - Some("\"false\"") + + manager + .refresh_available_models(RefreshStrategy::Online) + .await + .expect("refresh should fetch with agent identity"); + + assert!( + manager + .get_remote_models() + .await + .iter() + .any(|candidate| candidate.slug == dynamic_slug), + "remote refresh should include models fetched with agent identity" ); assert_eq!( - tags.get("auth_env_refresh_token_url_override_present") - .map(String::as_str), - Some("false") + endpoint.fetch_count(), + 1, + "endpoint should fetch models with agent identity" ); } #[test] fn build_available_models_picks_default_after_hiding_hidden_models() { - let codex_home = tempdir().expect("temp dir"); - let auth_manager = AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key")); - let provider = provider_for("http://example.test".to_string()); - let manager = ModelsManager::with_provider_for_tests( - codex_home.path().to_path_buf(), - auth_manager, - provider, - ); + let manager = static_manager_for_tests(ModelsResponse { models: Vec::new() }); let hidden_model = remote_model_with_visibility("hidden", "Hidden", /*priority*/ 0, "hide"); @@ -857,6 +767,74 @@ fn build_available_models_picks_default_after_hiding_hidden_models() { assert_eq!(available, vec![expected_hidden, expected_visible]); } +#[tokio::test] +async fn static_manager_treats_agent_identity_as_backend_auth_for_filtering() { + let chatgpt_only_model = { + let mut model = remote_model("chatgpt-only", "ChatGPT Only", /*priority*/ 0); + model.supported_in_api = false; + model + }; + let api_model = remote_model("api-model", "API Model", /*priority*/ 1); + let manager = StaticModelsManager::new( + Some(AuthManager::from_auth_for_testing( + agent_identity_auth_for_tests(), + )), + ModelsResponse { + models: vec![chatgpt_only_model, api_model], + }, + CollaborationModesConfig::default(), + ); + + let agent_identity_models = manager.list_models(RefreshStrategy::Online).await; + + assert_eq!( + agent_identity_models + .iter() + .map(|model| model.model.as_str()) + .collect::>(), + vec!["chatgpt-only", "api-model"] + ); +} + +#[tokio::test] +async fn static_manager_reads_latest_auth_mode() { + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let chatgpt_only_model = { + let mut model = remote_model("chatgpt-only", "ChatGPT Only", /*priority*/ 0); + model.supported_in_api = false; + model + }; + let api_model = remote_model("api-model", "API Model", /*priority*/ 1); + let manager = StaticModelsManager::new( + Some(Arc::clone(&auth_manager)), + ModelsResponse { + models: vec![chatgpt_only_model, api_model], + }, + CollaborationModesConfig::default(), + ); + + let chatgpt_models = manager.list_models(RefreshStrategy::Online).await; + assert_eq!( + chatgpt_models + .iter() + .map(|model| model.model.as_str()) + .collect::>(), + vec!["chatgpt-only", "api-model"] + ); + + auth_manager.set_external_auth(Arc::new(TestExternalApiKeyAuth)); + let api_models = manager.list_models(RefreshStrategy::Online).await; + + assert_eq!( + api_models + .iter() + .map(|model| model.model.as_str()) + .collect::>(), + vec!["api-model"] + ); +} + #[test] fn bundled_models_json_roundtrips() { let response = crate::bundled_models_response() diff --git a/codex-rs/models-manager/src/model_info_overrides_tests.rs b/codex-rs/models-manager/src/model_info_overrides_tests.rs index aaaf2dc44c8f..c499938ed47c 100644 --- a/codex-rs/models-manager/src/model_info_overrides_tests.rs +++ b/codex-rs/models-manager/src/model_info_overrides_tests.rs @@ -1,24 +1,19 @@ -use codex_login::AuthManager; -use codex_login::CodexAuth; - use crate::ModelsManagerConfig; -use crate::collaboration_mode_presets::CollaborationModesConfig; use crate::manager::ModelsManager; use codex_protocol::openai_models::TruncationPolicyConfig; use pretty_assertions::assert_eq; use tempfile::TempDir; +use super::TestModelsEndpoint; +use super::openai_manager_for_tests; + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn offline_model_info_without_tool_output_override() { let codex_home = TempDir::new().expect("create temp dir"); let config = ModelsManagerConfig::default(); - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let manager = ModelsManager::new( + let manager = openai_manager_for_tests( codex_home.path().to_path_buf(), - auth_manager, - /*model_catalog*/ None, - CollaborationModesConfig::default(), + TestModelsEndpoint::new(Vec::new()), ); let model_info = manager.get_model_info("gpt-5.2", &config).await; @@ -36,13 +31,9 @@ async fn offline_model_info_with_tool_output_override() { tool_output_token_limit: Some(123), ..Default::default() }; - let auth_manager = - AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); - let manager = ModelsManager::new( + let manager = openai_manager_for_tests( codex_home.path().to_path_buf(), - auth_manager, - /*model_catalog*/ None, - CollaborationModesConfig::default(), + TestModelsEndpoint::new(Vec::new()), ); let model_info = manager.get_model_info("gpt-5.4", &config).await; diff --git a/codex-rs/models-manager/src/test_support.rs b/codex-rs/models-manager/src/test_support.rs new file mode 100644 index 000000000000..aff28389076b --- /dev/null +++ b/codex-rs/models-manager/src/test_support.rs @@ -0,0 +1,38 @@ +//! Test-only helpers exposed for dependent crate tests. +//! +//! Production code should not depend on this module. + +use crate::ModelsManagerConfig; +use crate::bundled_models_response; +use crate::manager::construct_model_info_from_candidates; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelPreset; + +/// Get model identifier without consulting remote state or cache. +pub fn get_model_offline_for_tests(model: Option<&str>) -> String { + if let Some(model) = model { + return model.to_string(); + } + let mut response = bundled_models_response().unwrap_or_default(); + response.models.sort_by(|a, b| a.priority.cmp(&b.priority)); + let presets: Vec = response.models.into_iter().map(Into::into).collect(); + presets + .iter() + .find(|preset| preset.show_in_picker) + .or_else(|| presets.first()) + .map(|preset| preset.model.clone()) + .unwrap_or_default() +} + +/// Build `ModelInfo` without consulting remote state or cache. +pub fn construct_model_info_offline_for_tests( + model: &str, + config: &ModelsManagerConfig, +) -> ModelInfo { + let candidates: &[ModelInfo] = if let Some(model_catalog) = config.model_catalog.as_ref() { + &model_catalog.models + } else { + &[] + }; + construct_model_info_from_candidates(model, candidates, config) +}