Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion codex-rs/core/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,12 @@
"description": "Whether this provider supports the Responses API WebSocket transport.",
"type": "boolean"
},
"websocket_connect_timeout_ms": {
"description": "Maximum time (in milliseconds) to wait for a websocket connection attempt before treating it as failed.",
"format": "uint64",
"minimum": 0.0,
"type": "integer"
},
"wire_api": {
"allOf": [
{
Expand Down Expand Up @@ -2473,4 +2479,4 @@
},
"title": "ConfigToml",
"type": "object"
}
}
100 changes: 62 additions & 38 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ const RESPONSES_WEBSOCKETS_V2_BETA_HEADER_VALUE: &str = "responses_websockets=20
const RESPONSES_ENDPOINT: &str = "/responses";
const RESPONSES_COMPACT_ENDPOINT: &str = "/responses/compact";
const MEMORIES_SUMMARIZE_ENDPOINT: &str = "/memories/trace_summarize";
#[cfg(test)]
pub(crate) const WEBSOCKET_CONNECT_TIMEOUT: Duration =
Duration::from_millis(crate::model_provider_info::DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS);
pub fn ws_version_from_features(config: &Config) -> bool {
config
.features
Expand Down Expand Up @@ -310,6 +313,27 @@ impl ModelClient {
.unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session;
}

pub(crate) fn force_http_fallback(
&self,
session_telemetry: &SessionTelemetry,
model_info: &ModelInfo,
) -> bool {
let websocket_enabled = self.responses_websocket_enabled(model_info);
let activated =
websocket_enabled && !self.state.disable_websockets.swap(true, Ordering::Relaxed);
if activated {
warn!("falling back to HTTP");
session_telemetry.counter(
"codex.transport.fallback_to_http",
/*inc*/ 1,
&[("from_wire_api", "responses_websocket")],
);
}

self.store_cached_websocket_session(WebsocketSession::default());
activated
}

/// Compacts the current conversation history using the Compact endpoint.
///
/// This is a unary call (no streaming) that returns a new list of
Expand Down Expand Up @@ -538,15 +562,22 @@ impl ModelClient {
auth_context,
request_route_telemetry,
);
let websocket_connect_timeout = self.state.provider.websocket_connect_timeout();
let start = Instant::now();
let result = ApiWebSocketResponsesClient::new(api_provider, api_auth)
.connect(
let result = match tokio::time::timeout(
websocket_connect_timeout,
ApiWebSocketResponsesClient::new(api_provider, api_auth).connect(
headers,
crate::default_client::default_headers(),
turn_state,
Some(websocket_telemetry),
)
.await;
),
)
.await
{
Ok(result) => result,
Err(_) => Err(ApiError::Transport(TransportError::Timeout)),
};
let error_message = result.as_ref().err().map(telemetry_api_error_message);
let response_debug = result
.as_ref()
Expand Down Expand Up @@ -637,13 +668,12 @@ impl Drop for ModelClientSession {
}

impl ModelClientSession {
fn activate_http_fallback(&self, websocket_enabled: bool) -> bool {
websocket_enabled
&& !self
.client
.state
.disable_websockets
.swap(true, Ordering::Relaxed)
fn reset_websocket_session(&mut self) {
self.websocket_session.connection = None;
self.websocket_session.last_request = None;
self.websocket_session.last_response_rx = None;
self.websocket_session
.set_connection_reused(/*connection_reused*/ false);
}

fn build_responses_request(
Expand Down Expand Up @@ -896,7 +926,7 @@ impl ModelClientSession {
.turn_state
.clone()
.unwrap_or_else(|| Arc::clone(&self.turn_state));
let new_conn = self
let new_conn = match self
.client
.connect_websocket(
session_telemetry,
Expand All @@ -907,7 +937,16 @@ impl ModelClientSession {
auth_context,
request_route_telemetry,
)
.await?;
.await
{
Ok(new_conn) => new_conn,
Err(err) => {
if matches!(err, ApiError::Transport(TransportError::Timeout)) {
self.reset_websocket_session();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Don't we auto retry?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so a retry does happen but at the turn level, not by creating a brand new websocket client from scratch

So when a request fails, Codex retries the same turn using the same turn-scoped session object. If we do not call reset_websocket_session(), that retry can carry leftover websocket state from the timed-out attempt and act like it is continuing a connection that never really came up.

}
return Err(err);
}
};
self.websocket_session.connection = Some(new_conn);
self.websocket_session
.set_connection_reused(/*connection_reused*/ false);
Expand Down Expand Up @@ -1130,15 +1169,12 @@ impl ModelClientSession {

let ws_request = self.prepare_websocket_request(ws_payload, &request);
self.websocket_session.last_request = Some(request);
let stream_result = self
.websocket_session
.connection
.as_ref()
.ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
})?
let stream_result = self.websocket_session.connection.as_ref().ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
})?;
let stream_result = stream_result
.stream_request(ws_request, self.websocket_session.connection_reused())
.await
.map_err(map_api_error)?;
Expand Down Expand Up @@ -1296,22 +1332,10 @@ impl ModelClientSession {
session_telemetry: &SessionTelemetry,
model_info: &ModelInfo,
) -> bool {
let websocket_enabled = self.client.responses_websocket_enabled(model_info);
let activated = self.activate_http_fallback(websocket_enabled);
if activated {
warn!("falling back to HTTP");
session_telemetry.counter(
"codex.transport.fallback_to_http",
/*inc*/ 1,
&[("from_wire_api", "responses_websocket")],
);

self.websocket_session.connection = None;
self.websocket_session.last_request = None;
self.websocket_session.last_response_rx = None;
self.websocket_session
.set_connection_reused(/*connection_reused*/ false);
}
let activated = self
.client
.force_http_fallback(session_telemetry, model_info);
self.websocket_session = WebsocketSession::default();
activated
}
}
Expand Down
90 changes: 16 additions & 74 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnContextItem;
use codex_protocol::protocol::TurnContextNetworkItem;
use codex_protocol::protocol::TurnStartedEvent;
use codex_protocol::protocol::W3cTraceContext;
use codex_protocol::request_permissions::PermissionGrantScope;
use codex_protocol::request_permissions::RequestPermissionProfile;
Expand Down Expand Up @@ -267,6 +266,7 @@ use crate::rollout::RolloutRecorderParams;
use crate::rollout::map_session_init_error;
use crate::rollout::metadata;
use crate::rollout::policy::EventPersistenceMode;
use crate::session_startup_prewarm::SessionStartupPrewarmHandle;
use crate::shell;
use crate::shell_snapshot::ShellSnapshot;
use crate::skills::SkillError;
Expand All @@ -286,7 +286,6 @@ use crate::state::SessionServices;
use crate::state::SessionState;
use crate::state_db;
use crate::tasks::GhostSnapshotTask;
use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
Expand Down Expand Up @@ -2411,70 +2410,17 @@ impl Session {
.await
}

pub(crate) async fn take_startup_regular_task(&self) -> Option<RegularTask> {
let startup_regular_task = {
let mut state = self.state.lock().await;
state.take_startup_regular_task()
};
let startup_regular_task = startup_regular_task?;
match startup_regular_task.await {
Ok(Ok(regular_task)) => Some(regular_task),
Ok(Err(err)) => {
warn!("startup websocket prewarm setup failed: {err:#}");
None
}
Err(err) => {
warn!("startup websocket prewarm setup join failed: {err}");
None
}
}
}

async fn schedule_startup_prewarm(self: &Arc<Self>, base_instructions: String) {
let sess = Arc::clone(self);
let startup_regular_task: JoinHandle<CodexResult<RegularTask>> =
tokio::spawn(
async move { sess.schedule_startup_prewarm_inner(base_instructions).await },
);
pub(crate) async fn set_session_startup_prewarm(
&self,
startup_prewarm: SessionStartupPrewarmHandle,
) {
let mut state = self.state.lock().await;
state.set_startup_regular_task(startup_regular_task);
state.set_session_startup_prewarm(startup_prewarm);
}

async fn schedule_startup_prewarm_inner(
self: &Arc<Self>,
base_instructions: String,
) -> CodexResult<RegularTask> {
let startup_turn_context = self
.new_default_turn_with_sub_id(INITIAL_SUBMIT_ID.to_owned())
.await;
let startup_cancellation_token = CancellationToken::new();
let startup_router = built_tools(
self,
startup_turn_context.as_ref(),
&[],
&HashSet::new(),
/*skills_outcome*/ None,
&startup_cancellation_token,
)
.await?;
let startup_prompt = build_prompt(
Vec::new(),
startup_router.as_ref(),
startup_turn_context.as_ref(),
BaseInstructions {
text: base_instructions,
},
);
let startup_turn_metadata_header = startup_turn_context
.turn_metadata_state
.current_header_value();
RegularTask::with_startup_prewarm(
self.services.model_client.clone(),
startup_prompt,
startup_turn_context,
startup_turn_metadata_header,
)
.await
pub(crate) async fn take_session_startup_prewarm(&self) -> Option<SessionStartupPrewarmHandle> {
let mut state = self.state.lock().await;
state.take_session_startup_prewarm()
}

pub(crate) async fn get_config(&self) -> std::sync::Arc<Config> {
Expand Down Expand Up @@ -4553,9 +4499,12 @@ mod handlers {
{
sess.refresh_mcp_servers_if_requested(&current_context)
.await;
let regular_task = sess.take_startup_regular_task().await.unwrap_or_default();
sess.spawn_task(Arc::clone(&current_context), items, regular_task)
.await;
sess.spawn_task(
Arc::clone(&current_context),
items,
crate::tasks::RegularTask::new(),
)
.await;
}
}

Expand Down Expand Up @@ -5485,13 +5434,6 @@ pub(crate) async fn run_turn(

let model_info = turn_context.model_info.clone();
let auto_compact_limit = model_info.auto_compact_token_limit().unwrap_or(i64::MAX);

let event = EventMsg::TurnStarted(TurnStartedEvent {
turn_id: turn_context.sub_id.clone(),
model_context_window: turn_context.model_context_window(),
collaboration_mode_kind: turn_context.collaboration_mode.mode,
});
sess.send_event(&turn_context, event).await;
// TODO(ccunningham): Pre-turn compaction runs before context updates and the
// new user message are recorded. Estimate pending incoming items (context
// diffs/full reinjection + user input) and trigger compaction preemptively
Expand Down Expand Up @@ -6236,7 +6178,7 @@ fn codex_apps_connector_id(tool: &crate::mcp_connection_manager::ToolInfo) -> Op
tool.connector_id.as_deref()
}

fn build_prompt(
pub(crate) fn build_prompt(
input: Vec<ResponseItem>,
router: &ToolRouter,
turn_context: &TurnContext,
Expand Down
Loading
Loading