From be145ab59e1ff8dfc7b9cfe06ed5cbb6c15d6cbe Mon Sep 17 00:00:00 2001 From: Owen Lin Date: Mon, 16 Mar 2026 11:32:40 -0700 Subject: [PATCH] fix(core): fallback to default task if websocket warm is not ready --- codex-rs/core/config.schema.json | 8 +- codex-rs/core/src/client.rs | 100 +++++--- codex-rs/core/src/codex.rs | 90 ++----- codex-rs/core/src/codex_tests.rs | 103 ++++++++ codex-rs/core/src/config/config_tests.rs | 2 + codex-rs/core/src/config/schema_tests.rs | 9 +- codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/model_provider_info.rs | 14 + .../core/src/model_provider_info_tests.rs | 16 ++ .../core/src/models_manager/manager_tests.rs | 1 + codex-rs/core/src/session_startup_prewarm.rs | 241 ++++++++++++++++++ codex-rs/core/src/state/session.rs | 23 +- codex-rs/core/src/tasks/mod.rs | 1 - codex-rs/core/src/tasks/regular.rs | 76 ++---- codex-rs/core/tests/responses_headers.rs | 3 + codex-rs/core/tests/suite/client.rs | 3 + .../core/tests/suite/client_websockets.rs | 26 +- .../suite/stream_error_allows_next_turn.rs | 1 + .../core/tests/suite/stream_no_completed.rs | 1 + codex-rs/otel/src/metrics/names.rs | 5 + 20 files changed, 548 insertions(+), 176 deletions(-) create mode 100644 codex-rs/core/src/session_startup_prewarm.rs diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 2235315d431..7d3ecdaa012 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -877,6 +877,12 @@ "description": "Whether this provider supports the Responses API WebSocket transport.", "type": "boolean" }, + "websocket_connect_timeout_ms": { + "description": "Maximum time (in milliseconds) to wait for a websocket connection attempt before treating it as failed.", + "format": "uint64", + "minimum": 0.0, + "type": "integer" + }, "wire_api": { "allOf": [ { @@ -2473,4 +2479,4 @@ }, "title": "ConfigToml", "type": "object" -} \ No newline at end of file +} diff --git a/codex-rs/core/src/client.rs 
b/codex-rs/core/src/client.rs index 72927b1e880..79fec5dfa1f 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -117,6 +117,9 @@ const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=20 const RESPONSES_ENDPOINT: &str = "/responses"; const RESPONSES_COMPACT_ENDPOINT: &str = "/responses/compact"; const MEMORIES_SUMMARIZE_ENDPOINT: &str = "/memories/trace_summarize"; +#[cfg(test)] +pub(crate) const WEBSOCKET_CONNECT_TIMEOUT: Duration = + Duration::from_millis(crate::model_provider_info::DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS); pub fn ws_version_from_features(config: &Config) -> bool { config .features @@ -310,6 +313,27 @@ impl ModelClient { .unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session; } + pub(crate) fn force_http_fallback( + &self, + session_telemetry: &SessionTelemetry, + model_info: &ModelInfo, + ) -> bool { + let websocket_enabled = self.responses_websocket_enabled(model_info); + let activated = + websocket_enabled && !self.state.disable_websockets.swap(true, Ordering::Relaxed); + if activated { + warn!("falling back to HTTP"); + session_telemetry.counter( + "codex.transport.fallback_to_http", + /*inc*/ 1, + &[("from_wire_api", "responses_websocket")], + ); + } + + self.store_cached_websocket_session(WebsocketSession::default()); + activated + } + /// Compacts the current conversation history using the Compact endpoint. 
/// /// This is a unary call (no streaming) that returns a new list of @@ -538,15 +562,22 @@ impl ModelClient { auth_context, request_route_telemetry, ); + let websocket_connect_timeout = self.state.provider.websocket_connect_timeout(); let start = Instant::now(); - let result = ApiWebSocketResponsesClient::new(api_provider, api_auth) - .connect( + let result = match tokio::time::timeout( + websocket_connect_timeout, + ApiWebSocketResponsesClient::new(api_provider, api_auth).connect( headers, crate::default_client::default_headers(), turn_state, Some(websocket_telemetry), - ) - .await; + ), + ) + .await + { + Ok(result) => result, + Err(_) => Err(ApiError::Transport(TransportError::Timeout)), + }; let error_message = result.as_ref().err().map(telemetry_api_error_message); let response_debug = result .as_ref() @@ -637,13 +668,12 @@ impl Drop for ModelClientSession { } impl ModelClientSession { - fn activate_http_fallback(&self, websocket_enabled: bool) -> bool { - websocket_enabled - && !self - .client - .state - .disable_websockets - .swap(true, Ordering::Relaxed) + fn reset_websocket_session(&mut self) { + self.websocket_session.connection = None; + self.websocket_session.last_request = None; + self.websocket_session.last_response_rx = None; + self.websocket_session + .set_connection_reused(/*connection_reused*/ false); } fn build_responses_request( @@ -896,7 +926,7 @@ impl ModelClientSession { .turn_state .clone() .unwrap_or_else(|| Arc::clone(&self.turn_state)); - let new_conn = self + let new_conn = match self .client .connect_websocket( session_telemetry, @@ -907,7 +937,16 @@ impl ModelClientSession { auth_context, request_route_telemetry, ) - .await?; + .await + { + Ok(new_conn) => new_conn, + Err(err) => { + if matches!(err, ApiError::Transport(TransportError::Timeout)) { + self.reset_websocket_session(); + } + return Err(err); + } + }; self.websocket_session.connection = Some(new_conn); self.websocket_session .set_connection_reused(/*connection_reused*/ 
false); @@ -1130,15 +1169,12 @@ impl ModelClientSession { let ws_request = self.prepare_websocket_request(ws_payload, &request); self.websocket_session.last_request = Some(request); - let stream_result = self - .websocket_session - .connection - .as_ref() - .ok_or_else(|| { - map_api_error(ApiError::Stream( - "websocket connection is unavailable".to_string(), - )) - })? + let stream_result = self.websocket_session.connection.as_ref().ok_or_else(|| { + map_api_error(ApiError::Stream( + "websocket connection is unavailable".to_string(), + )) + })?; + let stream_result = stream_result .stream_request(ws_request, self.websocket_session.connection_reused()) .await .map_err(map_api_error)?; @@ -1296,22 +1332,10 @@ impl ModelClientSession { session_telemetry: &SessionTelemetry, model_info: &ModelInfo, ) -> bool { - let websocket_enabled = self.client.responses_websocket_enabled(model_info); - let activated = self.activate_http_fallback(websocket_enabled); - if activated { - warn!("falling back to HTTP"); - session_telemetry.counter( - "codex.transport.fallback_to_http", - /*inc*/ 1, - &[("from_wire_api", "responses_websocket")], - ); - - self.websocket_session.connection = None; - self.websocket_session.last_request = None; - self.websocket_session.last_response_rx = None; - self.websocket_session - .set_connection_reused(/*connection_reused*/ false); - } + let activated = self + .client + .force_http_fallback(session_telemetry, model_info); + self.websocket_session = WebsocketSession::default(); activated } } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 735fe7ec4da..8ffe1d3bd1e 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -105,7 +105,6 @@ use codex_protocol::protocol::SubAgentSource; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnContextItem; use codex_protocol::protocol::TurnContextNetworkItem; -use codex_protocol::protocol::TurnStartedEvent; use 
codex_protocol::protocol::W3cTraceContext; use codex_protocol::request_permissions::PermissionGrantScope; use codex_protocol::request_permissions::RequestPermissionProfile; @@ -267,6 +266,7 @@ use crate::rollout::RolloutRecorderParams; use crate::rollout::map_session_init_error; use crate::rollout::metadata; use crate::rollout::policy::EventPersistenceMode; +use crate::session_startup_prewarm::SessionStartupPrewarmHandle; use crate::shell; use crate::shell_snapshot::ShellSnapshot; use crate::skills::SkillError; @@ -286,7 +286,6 @@ use crate::state::SessionServices; use crate::state::SessionState; use crate::state_db; use crate::tasks::GhostSnapshotTask; -use crate::tasks::RegularTask; use crate::tasks::ReviewTask; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; @@ -2411,70 +2410,17 @@ impl Session { .await } - pub(crate) async fn take_startup_regular_task(&self) -> Option { - let startup_regular_task = { - let mut state = self.state.lock().await; - state.take_startup_regular_task() - }; - let startup_regular_task = startup_regular_task?; - match startup_regular_task.await { - Ok(Ok(regular_task)) => Some(regular_task), - Ok(Err(err)) => { - warn!("startup websocket prewarm setup failed: {err:#}"); - None - } - Err(err) => { - warn!("startup websocket prewarm setup join failed: {err}"); - None - } - } - } - - async fn schedule_startup_prewarm(self: &Arc, base_instructions: String) { - let sess = Arc::clone(self); - let startup_regular_task: JoinHandle> = - tokio::spawn( - async move { sess.schedule_startup_prewarm_inner(base_instructions).await }, - ); + pub(crate) async fn set_session_startup_prewarm( + &self, + startup_prewarm: SessionStartupPrewarmHandle, + ) { let mut state = self.state.lock().await; - state.set_startup_regular_task(startup_regular_task); + state.set_session_startup_prewarm(startup_prewarm); } - async fn schedule_startup_prewarm_inner( - self: &Arc, - base_instructions: String, - ) -> CodexResult { - let startup_turn_context 
= self - .new_default_turn_with_sub_id(INITIAL_SUBMIT_ID.to_owned()) - .await; - let startup_cancellation_token = CancellationToken::new(); - let startup_router = built_tools( - self, - startup_turn_context.as_ref(), - &[], - &HashSet::new(), - /*skills_outcome*/ None, - &startup_cancellation_token, - ) - .await?; - let startup_prompt = build_prompt( - Vec::new(), - startup_router.as_ref(), - startup_turn_context.as_ref(), - BaseInstructions { - text: base_instructions, - }, - ); - let startup_turn_metadata_header = startup_turn_context - .turn_metadata_state - .current_header_value(); - RegularTask::with_startup_prewarm( - self.services.model_client.clone(), - startup_prompt, - startup_turn_context, - startup_turn_metadata_header, - ) - .await + pub(crate) async fn take_session_startup_prewarm(&self) -> Option { + let mut state = self.state.lock().await; + state.take_session_startup_prewarm() } pub(crate) async fn get_config(&self) -> std::sync::Arc { @@ -4553,9 +4499,12 @@ mod handlers { { sess.refresh_mcp_servers_if_requested(¤t_context) .await; - let regular_task = sess.take_startup_regular_task().await.unwrap_or_default(); - sess.spawn_task(Arc::clone(¤t_context), items, regular_task) - .await; + sess.spawn_task( + Arc::clone(¤t_context), + items, + crate::tasks::RegularTask::new(), + ) + .await; } } @@ -5485,13 +5434,6 @@ pub(crate) async fn run_turn( let model_info = turn_context.model_info.clone(); let auto_compact_limit = model_info.auto_compact_token_limit().unwrap_or(i64::MAX); - - let event = EventMsg::TurnStarted(TurnStartedEvent { - turn_id: turn_context.sub_id.clone(), - model_context_window: turn_context.model_context_window(), - collaboration_mode_kind: turn_context.collaboration_mode.mode, - }); - sess.send_event(&turn_context, event).await; // TODO(ccunningham): Pre-turn compaction runs before context updates and the // new user message are recorded. 
Estimate pending incoming items (context // diffs/full reinjection + user input) and trigger compaction preemptively @@ -6236,7 +6178,7 @@ fn codex_apps_connector_id(tool: &crate::mcp_connection_manager::ToolInfo) -> Op tool.connector_id.as_deref() } -fn build_prompt( +pub(crate) fn build_prompt( input: Vec, router: &ToolRouter, turn_context: &TurnContext, diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index 34ed7bcd63b..fab591db7dc 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -39,6 +39,7 @@ use crate::protocol::TokenCountEvent; use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; use crate::protocol::TurnCompleteEvent; +use crate::protocol::TurnStartedEvent; use crate::protocol::UserMessageEvent; use crate::rollout::policy::EventPersistenceMode; use crate::rollout::recorder::RolloutRecorder; @@ -142,6 +143,108 @@ fn skill_message(text: &str) -> ResponseItem { } } +#[tokio::test] +async fn regular_turn_emits_turn_started_without_waiting_for_startup_prewarm() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let (_tx, startup_prewarm_rx) = tokio::sync::oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + let _ = startup_prewarm_rx.await; + Ok(test_model_client_session()) + }); + + sess.set_session_startup_prewarm( + crate::session_startup_prewarm::SessionStartupPrewarmHandle::new( + handle, + std::time::Instant::now(), + crate::client::WEBSOCKET_CONNECT_TIMEOUT, + ), + ) + .await; + sess.spawn_task( + Arc::clone(&tc), + Vec::new(), + crate::tasks::RegularTask::new(), + ) + .await; + + let first = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()) + .await + .expect("expected turn started event without waiting for startup prewarm") + .expect("channel open"); + assert!(matches!( + first.msg, + EventMsg::TurnStarted(TurnStartedEvent { turn_id, .. 
}) if turn_id == tc.sub_id + )); + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; +} + +#[tokio::test] +async fn interrupting_regular_turn_waiting_on_startup_prewarm_emits_turn_aborted() { + let (sess, tc, rx) = make_session_and_context_with_rx().await; + let (_tx, startup_prewarm_rx) = tokio::sync::oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + let _ = startup_prewarm_rx.await; + Ok(test_model_client_session()) + }); + + sess.set_session_startup_prewarm( + crate::session_startup_prewarm::SessionStartupPrewarmHandle::new( + handle, + std::time::Instant::now(), + crate::client::WEBSOCKET_CONNECT_TIMEOUT, + ), + ) + .await; + sess.spawn_task( + Arc::clone(&tc), + Vec::new(), + crate::tasks::RegularTask::new(), + ) + .await; + + let first = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv()) + .await + .expect("expected turn started event without waiting for startup prewarm") + .expect("channel open"); + assert!(matches!( + first.msg, + EventMsg::TurnStarted(TurnStartedEvent { turn_id, .. 
}) if turn_id == tc.sub_id + )); + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("expected turn aborted event") + .expect("channel open"); + assert!(matches!( + second.msg, + EventMsg::TurnAborted(crate::protocol::TurnAbortedEvent { + turn_id: Some(turn_id), + reason: TurnAbortReason::Interrupted, + }) if turn_id == tc.sub_id + )); +} + +fn test_model_client_session() -> crate::client::ModelClientSession { + crate::client::ModelClient::new( + None, + ThreadId::try_from("00000000-0000-4000-8000-000000000001") + .expect("test thread id should be valid"), + crate::model_provider_info::ModelProviderInfo::create_openai_provider( + /* base_url */ None, + ), + codex_protocol::protocol::SessionSource::Exec, + None, + false, + false, + false, + None, + ) + .new_session() +} + fn developer_input_texts(items: &[ResponseItem]) -> Vec<&str> { items .iter() diff --git a/codex-rs/core/src/config/config_tests.rs b/codex-rs/core/src/config/config_tests.rs index 7e82bd7da9d..c6372de258e 100644 --- a/codex-rs/core/src/config/config_tests.rs +++ b/codex-rs/core/src/config/config_tests.rs @@ -4076,6 +4076,7 @@ wire_api = "responses" request_max_retries = 4 # retry failed HTTP requests stream_max_retries = 10 # retry dropped SSE streams stream_idle_timeout_ms = 300000 # 5m idle timeout +websocket_connect_timeout_ms = 15000 [profiles.o3] model = "o3" @@ -4130,6 +4131,7 @@ model_verbosity = "high" request_max_retries: Some(4), stream_max_retries: Some(10), stream_idle_timeout_ms: Some(300_000), + websocket_connect_timeout_ms: Some(15_000), requires_openai_auth: false, supports_websockets: false, }; diff --git a/codex-rs/core/src/config/schema_tests.rs b/codex-rs/core/src/config/schema_tests.rs index 6205d43f40e..31fabd64bd2 100644 --- a/codex-rs/core/src/config/schema_tests.rs +++ b/codex-rs/core/src/config/schema_tests.rs @@ -6,6 +6,10 @@ use pretty_assertions::assert_eq; 
use similar::TextDiff; use tempfile::TempDir; +fn trim_single_trailing_newline(contents: &str) -> &str { + contents.strip_suffix('\n').unwrap_or(contents) +} + #[test] fn config_schema_matches_fixture() { let fixture_path = codex_utils_cargo_bin::find_resource!("config.schema.json") @@ -40,9 +44,12 @@ Run `just write-config-schema` to overwrite with your changes.\n\n{diff}" std::fs::read_to_string(&tmp_path).expect("read back config schema from temp path"); #[cfg(windows)] let fixture = fixture.replace("\r\n", "\n"); + #[cfg(windows)] + let tmp_contents = tmp_contents.replace("\r\n", "\n"); assert_eq!( - fixture, tmp_contents, + trim_single_trailing_newline(&fixture), + trim_single_trailing_newline(&tmp_contents), "fixture should match exactly with generated schema" ); } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 2056f0dfe0a..e02a346545a 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -69,6 +69,7 @@ pub mod plugins; mod sandbox_tags; pub mod sandboxing; mod session_prefix; +mod session_startup_prewarm; mod shell_detect; mod stream_events_utils; pub mod test_support; diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index be7a38d27d1..737a47780d8 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -22,6 +22,7 @@ use std::time::Duration; const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000; const DEFAULT_STREAM_MAX_RETRIES: u64 = 5; const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4; +pub(crate) const DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS: u64 = 15_000; /// Hard cap for user-configured `stream_max_retries`. const MAX_STREAM_MAX_RETRIES: u64 = 100; /// Hard cap for user-configured `request_max_retries`. @@ -112,6 +113,10 @@ pub struct ModelProviderInfo { /// the connection as lost. 
pub stream_idle_timeout_ms: Option, + /// Maximum time (in milliseconds) to wait for a websocket connection attempt before treating + /// it as failed. + pub websocket_connect_timeout_ms: Option, + /// Does this provider require an OpenAI API Key or ChatGPT login token? If true, /// user is presented with login screen on first run, and login preference and token/key /// are stored in auth.json. If false (which is the default), login screen is skipped, @@ -227,6 +232,13 @@ impl ModelProviderInfo { .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS)) } + /// Effective timeout for websocket connect attempts. + pub fn websocket_connect_timeout(&self) -> Duration { + self.websocket_connect_timeout_ms + .map(Duration::from_millis) + .unwrap_or(Duration::from_millis(DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS)) + } + pub fn create_openai_provider(base_url: Option) -> ModelProviderInfo { ModelProviderInfo { name: OPENAI_PROVIDER_NAME.into(), @@ -256,6 +268,7 @@ impl ModelProviderInfo { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: true, supports_websockets: true, } @@ -332,6 +345,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str, wire_api: WireApi) -> M request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, } diff --git a/codex-rs/core/src/model_provider_info_tests.rs b/codex-rs/core/src/model_provider_info_tests.rs index e6d5cea36ba..a5309117ae7 100644 --- a/codex-rs/core/src/model_provider_info_tests.rs +++ b/codex-rs/core/src/model_provider_info_tests.rs @@ -20,6 +20,7 @@ base_url = "http://localhost:11434/v1" request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; @@ -51,6 +52,7 @@ query_params = { api-version = 
"2025-04-01-preview" } request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; @@ -85,6 +87,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; @@ -105,3 +108,16 @@ wire_api = "chat" let err = toml::from_str::(provider_toml).unwrap_err(); assert!(err.to_string().contains(CHAT_WIRE_API_REMOVED_ERROR)); } + +#[test] +fn test_deserialize_websocket_connect_timeout() { + let provider_toml = r#" +name = "OpenAI" +base_url = "https://api.openai.com/v1" +websocket_connect_timeout_ms = 15000 +supports_websockets = true + "#; + + let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap(); + assert_eq!(provider.websocket_connect_timeout_ms, Some(15_000)); +} diff --git a/codex-rs/core/src/models_manager/manager_tests.rs b/codex-rs/core/src/models_manager/manager_tests.rs index 6981d6d799a..da42f2c8596 100644 --- a/codex-rs/core/src/models_manager/manager_tests.rs +++ b/codex-rs/core/src/models_manager/manager_tests.rs @@ -71,6 +71,7 @@ fn provider_for(base_url: String) -> ModelProviderInfo { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(5_000), + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, } diff --git a/codex-rs/core/src/session_startup_prewarm.rs b/codex-rs/core/src/session_startup_prewarm.rs new file mode 100644 index 00000000000..326d864f1e0 --- /dev/null +++ b/codex-rs/core/src/session_startup_prewarm.rs @@ -0,0 +1,241 @@ +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; + +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; +use tracing::info; +use tracing::warn; + +use 
crate::client::ModelClientSession; +use crate::codex::INITIAL_SUBMIT_ID; +use crate::codex::Session; +use crate::codex::build_prompt; +use crate::codex::built_tools; +use crate::error::Result as CodexResult; +use codex_otel::SessionTelemetry; +use codex_otel::metrics::names::STARTUP_PREWARM_AGE_AT_FIRST_TURN_METRIC; +use codex_otel::metrics::names::STARTUP_PREWARM_DURATION_METRIC; +use codex_protocol::models::BaseInstructions; + +pub(crate) struct SessionStartupPrewarmHandle { + task: JoinHandle>, + started_at: Instant, + timeout: Duration, +} + +pub(crate) enum SessionStartupPrewarmResolution { + Cancelled, + Ready(Box), + Unavailable { + status: &'static str, + prewarm_duration: Option, + }, +} + +impl SessionStartupPrewarmHandle { + pub(crate) fn new( + task: JoinHandle>, + started_at: Instant, + timeout: Duration, + ) -> Self { + Self { + task, + started_at, + timeout, + } + } + + async fn resolve( + self, + session_telemetry: &SessionTelemetry, + cancellation_token: &CancellationToken, + ) -> SessionStartupPrewarmResolution { + let Self { + mut task, + started_at, + timeout, + } = self; + let age_at_first_turn = started_at.elapsed(); + let remaining = timeout.saturating_sub(age_at_first_turn); + + let resolution = if task.is_finished() { + Self::resolution_from_join_result(task.await, started_at) + } else { + match tokio::select! 
{ + _ = cancellation_token.cancelled() => None, + result = tokio::time::timeout(remaining, &mut task) => Some(result), + } { + Some(Ok(result)) => Self::resolution_from_join_result(result, started_at), + Some(Err(_elapsed)) => { + task.abort(); + info!("startup websocket prewarm timed out before the first turn could use it"); + SessionStartupPrewarmResolution::Unavailable { + status: "timed_out", + prewarm_duration: Some(started_at.elapsed()), + } + } + None => { + task.abort(); + session_telemetry.record_duration( + STARTUP_PREWARM_AGE_AT_FIRST_TURN_METRIC, + age_at_first_turn, + &[("status", "cancelled")], + ); + session_telemetry.record_duration( + STARTUP_PREWARM_DURATION_METRIC, + started_at.elapsed(), + &[("status", "cancelled")], + ); + return SessionStartupPrewarmResolution::Cancelled; + } + } + }; + + match resolution { + SessionStartupPrewarmResolution::Cancelled => { + SessionStartupPrewarmResolution::Cancelled + } + SessionStartupPrewarmResolution::Ready(prewarmed_session) => { + session_telemetry.record_duration( + STARTUP_PREWARM_AGE_AT_FIRST_TURN_METRIC, + age_at_first_turn, + &[("status", "consumed")], + ); + SessionStartupPrewarmResolution::Ready(prewarmed_session) + } + SessionStartupPrewarmResolution::Unavailable { + status, + prewarm_duration, + } => { + session_telemetry.record_duration( + STARTUP_PREWARM_AGE_AT_FIRST_TURN_METRIC, + age_at_first_turn, + &[("status", status)], + ); + if let Some(prewarm_duration) = prewarm_duration { + session_telemetry.record_duration( + STARTUP_PREWARM_DURATION_METRIC, + prewarm_duration, + &[("status", status)], + ); + } + SessionStartupPrewarmResolution::Unavailable { + status, + prewarm_duration, + } + } + } + } + + fn resolution_from_join_result( + result: std::result::Result, tokio::task::JoinError>, + started_at: Instant, + ) -> SessionStartupPrewarmResolution { + match result { + Ok(Ok(prewarmed_session)) => { + SessionStartupPrewarmResolution::Ready(Box::new(prewarmed_session)) + } + Ok(Err(err)) => { 
+ warn!("startup websocket prewarm setup failed: {err:#}"); + SessionStartupPrewarmResolution::Unavailable { + status: "failed", + prewarm_duration: None, + } + } + Err(err) => { + warn!("startup websocket prewarm setup join failed: {err}"); + SessionStartupPrewarmResolution::Unavailable { + status: "join_failed", + prewarm_duration: Some(started_at.elapsed()), + } + } + } + } +} + +impl Session { + pub(crate) async fn schedule_startup_prewarm(self: &Arc, base_instructions: String) { + let session_telemetry = self.services.session_telemetry.clone(); + let websocket_connect_timeout = self.provider().await.websocket_connect_timeout(); + let started_at = Instant::now(); + let startup_prewarm_session = Arc::clone(self); + let startup_prewarm = tokio::spawn(async move { + let result = + schedule_startup_prewarm_inner(startup_prewarm_session, base_instructions).await; + let status = if result.is_ok() { "ready" } else { "failed" }; + session_telemetry.record_duration( + STARTUP_PREWARM_DURATION_METRIC, + started_at.elapsed(), + &[("status", status)], + ); + result + }); + self.set_session_startup_prewarm(SessionStartupPrewarmHandle::new( + startup_prewarm, + started_at, + websocket_connect_timeout, + )) + .await; + } + + pub(crate) async fn consume_startup_prewarm_for_regular_turn( + &self, + cancellation_token: &CancellationToken, + ) -> SessionStartupPrewarmResolution { + let Some(startup_prewarm) = self.take_session_startup_prewarm().await else { + return SessionStartupPrewarmResolution::Unavailable { + status: "not_scheduled", + prewarm_duration: None, + }; + }; + startup_prewarm + .resolve(&self.services.session_telemetry, cancellation_token) + .await + } +} + +async fn schedule_startup_prewarm_inner( + session: Arc, + base_instructions: String, +) -> CodexResult { + let startup_turn_context = session + .new_default_turn_with_sub_id(INITIAL_SUBMIT_ID.to_owned()) + .await; + let startup_cancellation_token = CancellationToken::new(); + let startup_router = built_tools( 
+ session.as_ref(), + startup_turn_context.as_ref(), + &[], + &HashSet::new(), + /*skills_outcome*/ None, + &startup_cancellation_token, + ) + .await?; + let startup_prompt = build_prompt( + Vec::new(), + startup_router.as_ref(), + startup_turn_context.as_ref(), + BaseInstructions { + text: base_instructions, + }, + ); + let startup_turn_metadata_header = startup_turn_context + .turn_metadata_state + .current_header_value(); + let mut client_session = session.services.model_client.new_session(); + client_session + .prewarm_websocket( + &startup_prompt, + &startup_turn_context.model_info, + &startup_turn_context.session_telemetry, + startup_turn_context.reasoning_effort, + startup_turn_context.reasoning_summary, + startup_turn_context.config.service_tier, + startup_turn_metadata_header.as_deref(), + ) + .await?; + + Ok(client_session) +} diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index 40faa4b8568..563e8b3403c 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -4,17 +4,15 @@ use codex_protocol::models::PermissionProfile; use codex_protocol::models::ResponseItem; use std::collections::HashMap; use std::collections::HashSet; -use tokio::task::JoinHandle; use crate::codex::PreviousTurnSettings; use crate::codex::SessionConfiguration; use crate::context_manager::ContextManager; -use crate::error::Result as CodexResult; use crate::protocol::RateLimitSnapshot; use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; use crate::sandboxing::merge_permission_profiles; -use crate::tasks::RegularTask; +use crate::session_startup_prewarm::SessionStartupPrewarmHandle; use crate::truncate::TruncationPolicy; use codex_protocol::protocol::TurnContextItem; @@ -30,8 +28,8 @@ pub(crate) struct SessionState { /// model/realtime handling on subsequent regular turns (including full-context /// reinjection after resume or `/compact`). 
previous_turn_settings: Option, - /// Startup regular task pre-created during session initialization. - pub(crate) startup_regular_task: Option>>, + /// Startup prewarmed session prepared during session initialization. + pub(crate) startup_prewarm: Option, pub(crate) active_connector_selection: HashSet, pub(crate) pending_session_start_source: Option, granted_permissions: Option, @@ -49,7 +47,7 @@ impl SessionState { dependency_env: HashMap::new(), mcp_dependency_prompted: HashSet::new(), previous_turn_settings: None, - startup_regular_task: None, + startup_prewarm: None, active_connector_selection: HashSet::new(), pending_session_start_source: None, granted_permissions: None, @@ -165,14 +163,15 @@ impl SessionState { self.dependency_env.clone() } - pub(crate) fn set_startup_regular_task(&mut self, task: JoinHandle>) { - self.startup_regular_task = Some(task); + pub(crate) fn set_session_startup_prewarm( + &mut self, + startup_prewarm: SessionStartupPrewarmHandle, + ) { + self.startup_prewarm = Some(startup_prewarm); } - pub(crate) fn take_startup_regular_task( - &mut self, - ) -> Option>> { - self.startup_regular_task.take() + pub(crate) fn take_session_startup_prewarm(&mut self) -> Option { + self.startup_prewarm.take() } // Adds connector IDs to the active set and returns the merged selection. 
diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 2089561199b..c237af4d112 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -384,7 +384,6 @@ impl Session { turn.add_task(task); *active = Some(turn); } - async fn take_active_turn(&self) -> Option { let mut active = self.active_turn.lock().await; active.take() diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index f1851b93478..6deb9abdb6b 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -1,64 +1,27 @@ use std::sync::Arc; -use std::sync::Mutex; -use crate::client::ModelClient; -use crate::client::ModelClientSession; -use crate::client_common::Prompt; +use async_trait::async_trait; +use tokio_util::sync::CancellationToken; + use crate::codex::TurnContext; use crate::codex::run_turn; -use crate::error::Result as CodexResult; +use crate::protocol::EventMsg; +use crate::protocol::TurnStartedEvent; +use crate::session_startup_prewarm::SessionStartupPrewarmResolution; use crate::state::TaskKind; -use async_trait::async_trait; use codex_protocol::user_input::UserInput; -use tokio_util::sync::CancellationToken; use tracing::Instrument; use tracing::trace_span; use super::SessionTask; use super::SessionTaskContext; -pub(crate) struct RegularTask { - prewarmed_session: Mutex>, -} - -impl Default for RegularTask { - fn default() -> Self { - Self { - prewarmed_session: Mutex::new(None), - } - } -} +#[derive(Default)] +pub(crate) struct RegularTask; impl RegularTask { - pub(crate) async fn with_startup_prewarm( - model_client: ModelClient, - prompt: Prompt, - turn_context: Arc, - turn_metadata_header: Option, - ) -> CodexResult { - let mut client_session = model_client.new_session(); - client_session - .prewarm_websocket( - &prompt, - &turn_context.model_info, - &turn_context.session_telemetry, - turn_context.reasoning_effort, - turn_context.reasoning_summary, - 
turn_context.config.service_tier, - turn_metadata_header.as_deref(), - ) - .await?; - - Ok(Self { - prewarmed_session: Mutex::new(Some(client_session)), - }) - } - - async fn take_prewarmed_session(&self) -> Option { - self.prewarmed_session - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .take() + pub(crate) fn new() -> Self { + Self } } @@ -81,8 +44,25 @@ impl SessionTask for RegularTask { ) -> Option { let sess = session.clone_session(); let run_turn_span = trace_span!("run_turn"); + // Regular turns emit `TurnStarted` inline so first-turn lifecycle does + // not wait on startup prewarm resolution. + let event = EventMsg::TurnStarted(TurnStartedEvent { + turn_id: ctx.sub_id.clone(), + model_context_window: ctx.model_context_window(), + collaboration_mode_kind: ctx.collaboration_mode.mode, + }); + sess.send_event(ctx.as_ref(), event).await; sess.set_server_reasoning_included(/*included*/ false).await; - let prewarmed_client_session = self.take_prewarmed_session().await; + let prewarmed_client_session = match sess + .consume_startup_prewarm_for_regular_turn(&cancellation_token) + .await + { + SessionStartupPrewarmResolution::Cancelled => return None, + SessionStartupPrewarmResolution::Unavailable { .. 
} => None, + SessionStartupPrewarmResolution::Ready(prewarmed_client_session) => { + Some(*prewarmed_client_session) + } + }; run_turn( sess, ctx, diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index d5376fcd7d5..d1c73f39d9c 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -53,6 +53,7 @@ async fn responses_stream_includes_subagent_header_on_review() { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(5_000), + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; @@ -165,6 +166,7 @@ async fn responses_stream_includes_subagent_header_on_other() { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(5_000), + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; @@ -272,6 +274,7 @@ async fn responses_respects_model_info_overrides_from_config() { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(5_000), + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index 2fc5f8b9d10..7d785694dcb 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -1792,6 +1792,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(5_000), + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; @@ -2393,6 +2394,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; @@ -2477,6 +2479,7 @@ 
async fn env_var_overrides_loaded_auth() { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; diff --git a/codex-rs/core/tests/suite/client_websockets.rs b/codex-rs/core/tests/suite/client_websockets.rs index 2c7b4d48e10..c2dd16f8b58 100755 --- a/codex-rs/core/tests/suite/client_websockets.rs +++ b/codex-rs/core/tests/suite/client_websockets.rs @@ -1500,6 +1500,13 @@ fn prompt_with_input_and_instructions(input: Vec, instructions: &s } fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo { + websocket_provider_with_connect_timeout(server, None) +} + +fn websocket_provider_with_connect_timeout( + server: &WebSocketTestServer, + websocket_connect_timeout_ms: Option<u64>, +) -> ModelProviderInfo { ModelProviderInfo { name: "mock-ws".into(), base_url: Some(format!("{}/v1", server.uri())), @@ -1513,6 +1520,7 @@ fn websocket_provider(server: &WebSocketTestServer) -> ModelProviderInfo { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(5_000), + websocket_connect_timeout_ms, requires_openai_auth: false, supports_websockets: true, } @@ -1543,7 +1551,23 @@ async fn websocket_harness_with_options( websocket_v2_enabled: bool, prefer_websockets: bool, ) -> WebsocketTestHarness { - let provider = websocket_provider(server); + websocket_harness_with_provider_options( + websocket_provider(server), + runtime_metrics_enabled, + websocket_enabled, + websocket_v2_enabled, + prefer_websockets, + ) + .await +} + +async fn websocket_harness_with_provider_options( + provider: ModelProviderInfo, + runtime_metrics_enabled: bool, + websocket_enabled: bool, + websocket_v2_enabled: bool, + prefer_websockets: bool, +) -> WebsocketTestHarness { let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home).await; config.model = Some(MODEL.to_string()); diff --git
a/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs b/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs index a8d1b379509..23ffc4afb00 100644 --- a/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs +++ b/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs @@ -76,6 +76,7 @@ async fn continue_after_stream_error() { request_max_retries: Some(1), stream_max_retries: Some(1), stream_idle_timeout_ms: Some(2_000), + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; diff --git a/codex-rs/core/tests/suite/stream_no_completed.rs b/codex-rs/core/tests/suite/stream_no_completed.rs index e6fc7ee8cb6..5d1b2148111 100644 --- a/codex-rs/core/tests/suite/stream_no_completed.rs +++ b/codex-rs/core/tests/suite/stream_no_completed.rs @@ -61,6 +61,7 @@ async fn retries_on_early_close() { request_max_retries: Some(0), stream_max_retries: Some(1), stream_idle_timeout_ms: Some(2000), + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; diff --git a/codex-rs/otel/src/metrics/names.rs b/codex-rs/otel/src/metrics/names.rs index 5063001f2c8..569cdc8256e 100644 --- a/codex-rs/otel/src/metrics/names.rs +++ b/codex-rs/otel/src/metrics/names.rs @@ -25,4 +25,9 @@ pub const TURN_TTFM_DURATION_METRIC: &str = "codex.turn.ttfm.duration_ms"; pub const TURN_NETWORK_PROXY_METRIC: &str = "codex.turn.network_proxy"; pub const TURN_TOOL_CALL_METRIC: &str = "codex.turn.tool.call"; pub const TURN_TOKEN_USAGE_METRIC: &str = "codex.turn.token_usage"; +/// Total runtime of a startup prewarm attempt until it completes, tagged by final status. +pub const STARTUP_PREWARM_DURATION_METRIC: &str = "codex.startup_prewarm.duration_ms"; +/// Age of the startup prewarm attempt when the first real turn resolves it, tagged by outcome. 
+pub const STARTUP_PREWARM_AGE_AT_FIRST_TURN_METRIC: &str = + "codex.startup_prewarm.age_at_first_turn_ms"; pub const THREAD_STARTED_METRIC: &str = "codex.thread.started";