diff --git a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs index 1c30ee530d1..bfb28a227dc 100644 --- a/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs +++ b/codex-rs/app-server/tests/suite/v2/realtime_conversation.rs @@ -190,7 +190,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> { read_notification::(&mut mcp, "thread/realtime/closed") .await?; assert_eq!(closed.thread_id, output_audio.thread_id); - assert_eq!(closed.reason.as_deref(), Some("transport_closed")); + assert_eq!(closed.reason.as_deref(), Some("error")); let connections = realtime_server.connections(); assert_eq!(connections.len(), 1); diff --git a/codex-rs/core/src/realtime_conversation.rs b/codex-rs/core/src/realtime_conversation.rs index 2a8a6337a79..c1c117b2f1e 100644 --- a/codex-rs/core/src/realtime_conversation.rs +++ b/codex-rs/core/src/realtime_conversation.rs @@ -56,6 +56,18 @@ const REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET: usize = 5_000; const ACTIVE_RESPONSE_CONFLICT_ERROR_PREFIX: &str = "Conversation already has an active response in progress:"; +#[derive(Debug)] +enum RealtimeConversationEnd { + Requested, + TransportClosed, + Error, +} + +enum RealtimeFanoutTaskStop { + Abort, + Detach, +} + pub(crate) struct RealtimeConversationManager { state: Mutex>, } @@ -120,7 +132,8 @@ struct ConversationState { user_text_tx: Sender, writer: RealtimeWebsocketWriter, handoff: RealtimeHandoffState, - task: JoinHandle<()>, + input_task: JoinHandle<()>, + fanout_task: Option>, realtime_active: Arc, } @@ -150,9 +163,7 @@ impl RealtimeConversationManager { guard.take() }; if let Some(state) = previous_state { - state.realtime_active.store(false, Ordering::Relaxed); - state.task.abort(); - let _ = state.task.await; + stop_conversation_state(state, RealtimeFanoutTaskStop::Abort).await; } let session_kind = match session_config.event_parser { RealtimeEventParser::V1 => RealtimeSessionKind::V1, @@ -199,12 +210,48 @@ impl RealtimeConversationManager { user_text_tx, writer, handoff, - task, + input_task: task, + fanout_task: None, realtime_active: Arc::clone(&realtime_active), }); Ok((events_rx, realtime_active)) } + pub(crate) async fn register_fanout_task( + &self, + realtime_active: &Arc, + fanout_task: JoinHandle<()>, + ) { + let mut fanout_task = Some(fanout_task); + { + let mut guard = self.state.lock().await; + if let Some(state) = guard.as_mut() + && Arc::ptr_eq(&state.realtime_active, realtime_active) + { + state.fanout_task = fanout_task.take(); + } + } + + if let Some(fanout_task) = fanout_task { + fanout_task.abort(); + let _ = fanout_task.await; + } + } + + pub(crate) async fn finish_if_active(&self, realtime_active: &Arc) { + let state = { + let mut guard = self.state.lock().await; + match guard.as_ref() { + Some(state) if Arc::ptr_eq(&state.realtime_active, realtime_active) => guard.take(), + _ => None, + } + }; + + if let Some(state) = state { + stop_conversation_state(state, RealtimeFanoutTaskStop::Detach).await; + } + } + pub(crate) async fn audio_in(&self, frame: RealtimeAudioFrame) -> CodexResult<()> { let sender = { let guard = self.state.lock().await; @@ -332,19 +379,78 @@ impl RealtimeConversationManager { }; if let Some(state) = state { - state.realtime_active.store(false, Ordering::Relaxed); - state.task.abort(); - let _ = state.task.await; + stop_conversation_state(state, RealtimeFanoutTaskStop::Abort).await; } Ok(()) } } +async fn stop_conversation_state( + mut state: ConversationState, + fanout_task_stop: RealtimeFanoutTaskStop, +) { + state.realtime_active.store(false, Ordering::Relaxed); + state.input_task.abort(); + let _ = state.input_task.await; + + if let Some(fanout_task) = state.fanout_task.take() { + match fanout_task_stop { + RealtimeFanoutTaskStop::Abort => { + fanout_task.abort(); + let _ = fanout_task.await; + } + RealtimeFanoutTaskStop::Detach => {} + } + } +} + pub(crate) async fn handle_start( sess: &Arc, sub_id: String, params: ConversationStartParams, ) -> CodexResult<()> { + let prepared_start = match prepare_realtime_start(sess, params).await { + Ok(prepared_start) => prepared_start, + Err(err) => { + error!("failed to prepare realtime conversation: {err}"); + let message = err.to_string(); + sess.send_event_raw(Event { + id: sub_id, + msg: EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::Error(message), + }), + }) + .await; + return Ok(()); + } + }; + + if let Err(err) = handle_start_inner(sess, &sub_id, prepared_start).await { + error!("failed to start realtime conversation: {err}"); + let message = err.to_string(); + sess.send_event_raw(Event { + id: sub_id.clone(), + msg: EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::Error(message), + }), + }) + .await; + } + Ok(()) +} + +struct PreparedRealtimeConversationStart { + api_provider: ApiProvider, + extra_headers: Option, + requested_session_id: Option, + version: RealtimeWsVersion, + session_config: RealtimeSessionConfig, +} + +async fn prepare_realtime_start( + sess: &Arc, + params: ConversationStartParams, +) -> CodexResult { let provider = sess.provider().await; let auth = sess.services.auth_manager.auth().await; let realtime_api_key = realtime_api_key(auth.as_ref(), &provider)?; @@ -380,9 +486,7 @@ pub(crate) async fn handle_start( RealtimeWsMode::Conversational => RealtimeSessionMode::Conversational, RealtimeWsMode::Transcription => RealtimeSessionMode::Transcription, }; - let requested_session_id = params - .session_id - .or_else(|| Some(sess.conversation_id.to_string())); + let requested_session_id = params.session_id.or(Some(sess.conversation_id.to_string())); let session_config = RealtimeSessionConfig { instructions: prompt, model, @@ -392,24 +496,37 @@ pub(crate) async fn handle_start( }; let extra_headers = realtime_request_headers(requested_session_id.as_deref(), realtime_api_key.as_str())?; + Ok(PreparedRealtimeConversationStart { + api_provider, + extra_headers, + requested_session_id, + version, + session_config, + }) +} + +async fn handle_start_inner( + sess: &Arc, + sub_id: &str, + prepared_start: PreparedRealtimeConversationStart, +) -> CodexResult<()> { + let PreparedRealtimeConversationStart { + api_provider, + extra_headers, + requested_session_id, + version, + session_config, + } = prepared_start; info!("starting realtime conversation"); - let (events_rx, realtime_active) = match sess + let (events_rx, realtime_active) = sess .conversation .start(api_provider, extra_headers, session_config) - .await - { - Ok(events_rx) => events_rx, - Err(err) => { - error!("failed to start realtime conversation: {err}"); - send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::Other).await; - return Ok(()); - } - }; + .await?; info!("realtime conversation started"); sess.send_event_raw(Event { - id: sub_id.clone(), + id: sub_id.to_string(), msg: EventMsg::RealtimeConversationStarted(RealtimeConversationStartedEvent { session_id: requested_session_id, version, @@ -418,12 +535,18 @@ pub(crate) async fn handle_start( .await; let sess_clone = Arc::clone(sess); - tokio::spawn(async move { + let sub_id = sub_id.to_string(); + let fanout_realtime_active = Arc::clone(&realtime_active); + let fanout_task = tokio::spawn(async move { let ev = |msg| Event { id: sub_id.clone(), msg, }; + let mut end = RealtimeConversationEnd::TransportClosed; while let Ok(event) = events_rx.recv().await { + if !fanout_realtime_active.load(Ordering::Relaxed) { + break; + } // if not audio out, log the event if !matches!(event, RealtimeEvent::AudioOut(_)) { info!( @@ -431,6 +554,9 @@ pub(crate) async fn handle_start( "received realtime conversation event" ); } + if matches!(event, RealtimeEvent::Error(_)) { + end = RealtimeConversationEnd::Error; + } let maybe_routed_text = match &event { RealtimeEvent::HandoffRequested(handoff) => { realtime_text_from_handoff_request(handoff) @@ -442,6 +568,9 @@ pub(crate) async fn handle_start( let sess_for_routed_text = Arc::clone(&sess_clone); sess_for_routed_text.route_realtime_text_input(text).await; } + if !fanout_realtime_active.load(Ordering::Relaxed) { + break; + } sess_clone .send_event_raw(ev(EventMsg::RealtimeConversationRealtime( RealtimeConversationRealtimeEvent { @@ -450,17 +579,20 @@ pub(crate) async fn handle_start( ))) .await; } - if realtime_active.swap(false, Ordering::Relaxed) { - info!("realtime conversation transport closed"); + if fanout_realtime_active.swap(false, Ordering::Relaxed) { + if matches!(end, RealtimeConversationEnd::TransportClosed) { + info!("realtime conversation transport closed"); + } sess_clone - .send_event_raw(ev(EventMsg::RealtimeConversationClosed( - RealtimeConversationClosedEvent { - reason: Some("transport_closed".to_string()), - }, - ))) + .conversation + .finish_if_active(&fanout_realtime_active) .await; + send_realtime_conversation_closed(&sess_clone, sub_id, end).await; } }); + sess.conversation + .register_fanout_task(&realtime_active, fanout_task) + .await; Ok(()) } @@ -472,7 +604,12 @@ pub(crate) async fn handle_audio( ) { if let Err(err) = sess.conversation.audio_in(params.frame).await { error!("failed to append realtime audio: {err}"); - send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest).await; + if sess.conversation.running_state().await.is_some() { + warn!("realtime audio input failed while the session was already ending"); + } else { + send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest) + .await; + } } } @@ -480,14 +617,12 @@ fn realtime_text_from_handoff_request(handoff: &RealtimeHandoffRequested) -> Opt let active_transcript = handoff .active_transcript .iter() - .map(|entry| format!("{}: {}", entry.role, entry.text)) + .map(|entry| format!("{role}: {text}", role = entry.role, text = entry.text)) .collect::>() .join("\n"); (!active_transcript.is_empty()) .then_some(active_transcript) - .or_else(|| { - (!handoff.input_transcript.is_empty()).then(|| handoff.input_transcript.clone()) - }) + .or((!handoff.input_transcript.is_empty()).then_some(handoff.input_transcript.clone())) } fn realtime_api_key( @@ -547,25 +682,17 @@ pub(crate) async fn handle_text( debug!(text = %params.text, "[realtime-text] appending realtime conversation text input"); if let Err(err) = sess.conversation.text_in(params.text).await { error!("failed to append realtime text: {err}"); - send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest).await; + if sess.conversation.running_state().await.is_some() { + warn!("realtime text input failed while the session was already ending"); + } else { + send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest) + .await; + } } } pub(crate) async fn handle_close(sess: &Arc, sub_id: String) { - match sess.conversation.shutdown().await { - Ok(()) => { - sess.send_event_raw(Event { - id: sub_id, - msg: EventMsg::RealtimeConversationClosed(RealtimeConversationClosedEvent { - reason: Some("requested".to_string()), - }), - }) - .await; - } - Err(err) => { - send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::Other).await; - } - } + end_realtime_conversation(sess, sub_id, RealtimeConversationEnd::Requested).await; } fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { @@ -593,6 +720,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { if let Err(err) = writer.send_conversation_item_create(text).await { let mapped_error = map_api_error(err); warn!("failed to send input text: {mapped_error}"); + let _ = events_tx + .send(RealtimeEvent::Error(mapped_error.to_string())) + .await; break; } if matches!(session_kind, RealtimeSessionKind::V2) { @@ -601,6 +731,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { } else if let Err(err) = writer.send_response_create().await { let mapped_error = map_api_error(err); warn!("failed to send text response.create: {mapped_error}"); + let _ = events_tx + .send(RealtimeEvent::Error(mapped_error.to_string())) + .await; break; } else { pending_response_create = false; @@ -625,6 +758,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { { let mapped_error = map_api_error(err); warn!("failed to send handoff output: {mapped_error}"); + let _ = events_tx + .send(RealtimeEvent::Error(mapped_error.to_string())) + .await; break; } } @@ -638,6 +774,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { { let mapped_error = map_api_error(err); warn!("failed to send handoff output: {mapped_error}"); + let _ = events_tx + .send(RealtimeEvent::Error(mapped_error.to_string())) + .await; break; } if matches!(session_kind, RealtimeSessionKind::V2) { @@ -648,6 +787,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { warn!( "failed to send handoff response.create: {mapped_error}" ); + let _ = events_tx + .send(RealtimeEvent::Error(mapped_error.to_string())) + .await; break; } else { pending_response_create = false; @@ -685,6 +827,11 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { warn!( "failed to send deferred response.create: {mapped_error}" ); + let _ = events_tx + .send(RealtimeEvent::Error( + mapped_error.to_string(), + )) + .await; break; } pending_response_create = false; @@ -732,6 +879,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { warn!( "failed to send deferred response.create after cancellation: {mapped_error}" ); + let _ = events_tx + .send(RealtimeEvent::Error(mapped_error.to_string())) + .await; break; } pending_response_create = false; @@ -773,11 +923,6 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { } } Ok(None) => { - let _ = events_tx - .send(RealtimeEvent::Error( - "realtime websocket connection is closed".to_string(), - )) - .await; break; } Err(err) => { @@ -800,6 +945,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> { if let Err(err) = writer.send_audio_frame(frame).await { let mapped_error = map_api_error(err); error!("failed to send input audio: {mapped_error}"); + let _ = events_tx + .send(RealtimeEvent::Error(mapped_error.to_string())) + .await; break; } } @@ -839,7 +987,7 @@ fn update_output_audio_state( fn audio_duration_ms(frame: &RealtimeAudioFrame) -> u32 { let Some(samples_per_channel) = frame .samples_per_channel - .or_else(|| decoded_samples_per_channel(frame)) + .or(decoded_samples_per_channel(frame)) else { return 0; }; @@ -870,6 +1018,33 @@ async fn send_conversation_error( .await; } +async fn end_realtime_conversation( + sess: &Arc, + sub_id: String, + end: RealtimeConversationEnd, +) { + let _ = sess.conversation.shutdown().await; + send_realtime_conversation_closed(sess, sub_id, end).await; +} + +async fn send_realtime_conversation_closed( + sess: &Arc, + sub_id: String, + end: RealtimeConversationEnd, +) { + let reason = match end { + RealtimeConversationEnd::Requested => Some("requested".to_string()), + RealtimeConversationEnd::TransportClosed => Some("transport_closed".to_string()), + RealtimeConversationEnd::Error => Some("error".to_string()), + }; + + sess.send_event_raw(Event { + id: sub_id, + msg: EventMsg::RealtimeConversationClosed(RealtimeConversationClosedEvent { reason }), + }) + .await; +} + #[cfg(test)] #[path = "realtime_conversation_tests.rs"] mod tests; diff --git a/codex-rs/core/tests/suite/realtime_conversation.rs b/codex-rs/core/tests/suite/realtime_conversation.rs index 8d156d17dd4..ad38c193b34 100644 --- a/codex-rs/core/tests/suite/realtime_conversation.rs +++ b/codex-rs/core/tests/suite/realtime_conversation.rs @@ -30,15 +30,17 @@ use core_test_support::wait_for_event_match; use pretty_assertions::assert_eq; use serde_json::Value; use serde_json::json; -use serial_test::serial; -use std::ffi::OsString; use std::fs; +use std::process::Command; use std::time::Duration; use tokio::sync::oneshot; +use tokio::time::timeout; const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex."; const MEMORY_PROMPT_PHRASE: &str = "You have access to a memory folder with guidance from prior runs."; +const REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR: &str = + "CODEX_REALTIME_CONVERSATION_TEST_SUBPROCESS"; fn websocket_request_text( request: &core_test_support::responses::WebSocketRequest, ) -> Option { @@ -82,6 +84,33 @@ where tokio::time::sleep(Duration::from_millis(10)).await; } } + +fn run_realtime_conversation_test_in_subprocess( + test_name: &str, + openai_api_key: Option<&str>, +) -> Result<()> { + let mut command = Command::new(std::env::current_exe()?); + command + .arg("--exact") + .arg(test_name) + .env(REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR, "1"); + match openai_api_key { + Some(openai_api_key) => { + command.env(OPENAI_API_KEY_ENV_VAR, openai_api_key); + } + None => { + command.env_remove(OPENAI_API_KEY_ENV_VAR); + } + } + let output = command.output()?; + assert!( + output.status.success(), + "subprocess test `{test_name}` failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr), + ); + Ok(()) +} async fn seed_recent_thread( test: &TestCodex, title: &str, @@ -260,11 +289,16 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -#[serial(openai_api_key_env)] async fn conversation_start_uses_openai_env_key_fallback_with_chatgpt_auth() -> Result<()> { + if std::env::var_os(REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR).is_none() { + return run_realtime_conversation_test_in_subprocess( + "suite::realtime_conversation::conversation_start_uses_openai_env_key_fallback_with_chatgpt_auth", + Some("env-realtime-key"), + ); + } + skip_if_no_network!(Ok(())); - let _env_guard = EnvGuard::set(OPENAI_API_KEY_ENV_VAR, "env-realtime-key"); let server = start_websocket_server(vec![ vec![], vec![vec![json!({ @@ -369,34 +403,6 @@ async fn conversation_transport_close_emits_closed_event() -> Result<()> { Ok(()) } -struct EnvGuard { - key: &'static str, - original: Option, -} - -impl EnvGuard { - fn set(key: &'static str, value: &str) -> Self { - let original = std::env::var_os(key); - // SAFETY: this guard restores the original value before the test exits. - unsafe { - std::env::set_var(key, value); - } - Self { key, original } - } -} - -impl Drop for EnvGuard { - fn drop(&mut self) { - // SAFETY: this guard restores the original value for the modified env var. - unsafe { - match &self.original { - Some(value) => std::env::set_var(self.key, value), - None => std::env::remove_var(self.key), - } - } - } -} - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn conversation_audio_before_start_emits_error() -> Result<()> { skip_if_no_network!(Ok(())); @@ -429,6 +435,91 @@ async fn conversation_audio_before_start_emits_error() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn conversation_start_preflight_failure_emits_realtime_error_only() -> Result<()> { + if std::env::var_os(REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR).is_none() { + return run_realtime_conversation_test_in_subprocess( + "suite::realtime_conversation::conversation_start_preflight_failure_emits_realtime_error_only", + None, + ); + } + + skip_if_no_network!(Ok(())); + + let server = start_websocket_server(vec![]).await; + let mut builder = test_codex().with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let test = builder.build_with_websocket_server(&server).await?; + + test.codex + .submit(Op::RealtimeConversationStart(ConversationStartParams { + prompt: "backend prompt".to_string(), + session_id: None, + })) + .await?; + + let err = wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::Error(message), + }) => Some(message.clone()), + _ => None, + }) + .await; + assert_eq!(err, "realtime conversation requires API key auth"); + + let closed = timeout(Duration::from_millis(200), async { + wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationClosed(closed) => Some(closed.clone()), + _ => None, + }) + .await + }) + .await; + assert!(closed.is_err(), "preflight failure should not emit closed"); + + server.shutdown().await; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn conversation_start_connect_failure_emits_realtime_error_only() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_websocket_server(vec![]).await; + let mut builder = test_codex().with_config(|config| { + config.experimental_realtime_ws_base_url = Some("http://127.0.0.1:1".to_string()); + }); + let test = builder.build_with_websocket_server(&server).await?; + + test.codex + .submit(Op::RealtimeConversationStart(ConversationStartParams { + prompt: "backend prompt".to_string(), + session_id: None, + })) + .await?; + + let err = wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent { + payload: RealtimeEvent::Error(message), + }) => Some(message.clone()), + _ => None, + }) + .await; + assert!(!err.is_empty()); + + let closed = timeout(Duration::from_millis(200), async { + wait_for_event_match(&test.codex, |msg| match msg { + EventMsg::RealtimeConversationClosed(closed) => Some(closed.clone()), + _ => None, + }) + .await + }) + .await; + assert!(closed.is_err(), "connect failure should not emit closed"); + + server.shutdown().await; + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn conversation_text_before_start_emits_error() -> Result<()> { skip_if_no_network!(Ok(()));