Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codex-rs/app-server/src/codex_message_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8871,6 +8871,7 @@ mod tests {
request_id: sent_request_id,
..
}),
..
} = request_message
else {
panic!("expected tool request to be sent to the subscribed connection");
Expand Down
15 changes: 12 additions & 3 deletions codex-rs/app-server/src/command_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use crate::outgoing_message::ConnectionRequestId;
use crate::outgoing_message::OutgoingMessageSender;

const EXEC_TIMEOUT_EXIT_CODE: i32 = 124;
const OUTPUT_CHUNK_SIZE_HINT: usize = 64 * 1024;

#[derive(Clone)]
pub(crate) struct CommandExecManager {
Expand Down Expand Up @@ -577,13 +578,19 @@ fn spawn_process_output(params: SpawnProcessOutputParams) -> tokio::task::JoinHa
let mut buffer: Vec<u8> = Vec::new();
let mut observed_num_bytes = 0usize;
loop {
let chunk = tokio::select! {
let mut chunk = tokio::select! {
chunk = output_rx.recv() => match chunk {
Some(chunk) => chunk,
None => break,
},
_ = stdio_timeout_rx.wait_for(|&v| v) => break,
};
// Individual chunks are at most 8KiB, so overshooting a bit is acceptable.
while chunk.len() < OUTPUT_CHUNK_SIZE_HINT
&& let Ok(next_chunk) = output_rx.try_recv()
{
chunk.extend_from_slice(&next_chunk);
}
let capped_chunk = match output_bytes_cap {
Some(output_bytes_cap) => {
let capped_chunk_len = output_bytes_cap
Expand All @@ -597,8 +604,8 @@ fn spawn_process_output(params: SpawnProcessOutputParams) -> tokio::task::JoinHa
let cap_reached = Some(observed_num_bytes) == output_bytes_cap;
if let (true, Some(process_id)) = (stream_output, process_id.as_ref()) {
outgoing
.send_server_notification_to_connections(
&[connection_id],
.send_server_notification_to_connection_and_wait(
connection_id,
ServerNotification::CommandExecOutputDelta(
CommandExecOutputDeltaNotification {
process_id: process_id.clone(),
Expand Down Expand Up @@ -809,6 +816,7 @@ mod tests {
let OutgoingEnvelope::ToConnection {
connection_id,
message,
..
} = envelope
else {
panic!("expected connection-scoped outgoing message");
Expand Down Expand Up @@ -891,6 +899,7 @@ mod tests {
let OutgoingEnvelope::ToConnection {
connection_id,
message,
..
} = envelope
else {
panic!("expected connection-scoped outgoing message");
Expand Down
11 changes: 8 additions & 3 deletions codex-rs/app-server/src/in_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingEnvelope;
use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::QueuedOutgoingMessage;
use crate::transport::CHANNEL_CAPACITY;
use crate::transport::OutboundConnectionState;
use crate::transport::route_outgoing_envelope;
Expand Down Expand Up @@ -353,7 +354,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingEnvelope>(channel_capacity);
let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx));

let (writer_tx, mut writer_rx) = mpsc::channel::<OutgoingMessage>(channel_capacity);
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(channel_capacity);
let outbound_initialized = Arc::new(AtomicBool::new(false));
let outbound_experimental_api_enabled = Arc::new(AtomicBool::new(false));
let outbound_opted_out_notification_methods = Arc::new(RwLock::new(HashSet::new()));
Expand Down Expand Up @@ -547,10 +548,11 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
}
}
}
outgoing_message = writer_rx.recv() => {
let Some(outgoing_message) = outgoing_message else {
queued_message = writer_rx.recv() => {
let Some(queued_message) = queued_message else {
break;
};
let outgoing_message = queued_message.message;
match outgoing_message {
OutgoingMessage::Response(response) => {
if let Some(response_tx) = pending_request_responses.remove(&response.id) {
Expand Down Expand Up @@ -629,6 +631,9 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
}
}
}
if let Some(write_complete_tx) = queued_message.write_complete_tx {
let _ = write_complete_tx.send(());
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion codex-rs/app-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::message_processor::MessageProcessorArgs;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::OutgoingEnvelope;
use crate::outgoing_message::OutgoingMessageSender;
use crate::outgoing_message::QueuedOutgoingMessage;
use crate::transport::CHANNEL_CAPACITY;
use crate::transport::ConnectionState;
use crate::transport::OutboundConnectionState;
Expand Down Expand Up @@ -103,7 +104,7 @@ enum OutboundControlEvent {
/// Register a new writer for an opened connection.
Opened {
connection_id: ConnectionId,
writer: mpsc::Sender<crate::outgoing_message::OutgoingMessage>,
writer: mpsc::Sender<QueuedOutgoingMessage>,
disconnect_sender: Option<CancellationToken>,
initialized: Arc<AtomicBool>,
experimental_api_enabled: Arc<AtomicBool>,
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/app-server/src/message_processor/tracing_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ async fn read_response<T: serde::de::DeserializeOwned>(
let crate::outgoing_message::OutgoingEnvelope::ToConnection {
connection_id,
message,
..
} = envelope
else {
continue;
Expand Down Expand Up @@ -420,6 +421,7 @@ async fn read_thread_started_notification(
crate::outgoing_message::OutgoingEnvelope::ToConnection {
connection_id,
message,
..
} => {
if connection_id != TEST_CONNECTION_ID {
continue;
Expand Down
90 changes: 89 additions & 1 deletion codex-rs/app-server/src/outgoing_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,33 @@ impl RequestContext {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub(crate) enum OutgoingEnvelope {
ToConnection {
connection_id: ConnectionId,
message: OutgoingMessage,
write_complete_tx: Option<oneshot::Sender<()>>,
},
Broadcast {
message: OutgoingMessage,
},
}

#[derive(Debug)]
pub(crate) struct QueuedOutgoingMessage {
pub(crate) message: OutgoingMessage,
pub(crate) write_complete_tx: Option<oneshot::Sender<()>>,
}

impl QueuedOutgoingMessage {
pub(crate) fn new(message: OutgoingMessage) -> Self {
Self {
message,
write_complete_tx: None,
}
}
}

/// Sends messages to the client and manages request callbacks.
pub(crate) struct OutgoingMessageSender {
next_server_request_id: AtomicI64,
Expand Down Expand Up @@ -299,6 +315,7 @@ impl OutgoingMessageSender {
.send(OutgoingEnvelope::ToConnection {
connection_id: *connection_id,
message: outgoing_message.clone(),
write_complete_tx: None,
})
.await
{
Expand Down Expand Up @@ -333,6 +350,7 @@ impl OutgoingMessageSender {
.send(OutgoingEnvelope::ToConnection {
connection_id,
message: OutgoingMessage::Request(request),
write_complete_tx: None,
})
.await
{
Expand Down Expand Up @@ -519,6 +537,7 @@ impl OutgoingMessageSender {
.send(OutgoingEnvelope::ToConnection {
connection_id: *connection_id,
message: outgoing_message.clone(),
write_complete_tx: None,
})
.await
{
Expand All @@ -527,6 +546,28 @@ impl OutgoingMessageSender {
}
}

pub(crate) async fn send_server_notification_to_connection_and_wait(
&self,
connection_id: ConnectionId,
notification: ServerNotification,
) {
tracing::trace!("app-server event: {notification}");
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if this was just for temporary debugging, can we remove? IIRC we've had issues with noisy trace logs in the past

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's replicated from existing send_server_notification_to_connections, should I remove it from both?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was added by #12695

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok, fine to keep then. seems deliberate

let outgoing_message = OutgoingMessage::AppServerNotification(notification);
let (write_complete_tx, write_complete_rx) = oneshot::channel();
if let Err(err) = self
.sender
.send(OutgoingEnvelope::ToConnection {
connection_id,
message: outgoing_message,
write_complete_tx: Some(write_complete_tx),
})
.await
{
warn!("failed to send server notification to client: {err:?}");
}
let _ = write_complete_rx.await;
}

pub(crate) async fn send_error(
&self,
request_id: ConnectionRequestId,
Expand Down Expand Up @@ -566,6 +607,7 @@ impl OutgoingMessageSender {
let send_fut = self.sender.send(OutgoingEnvelope::ToConnection {
connection_id,
message,
write_complete_tx: None,
});
let send_result = if let Some(request_context) = request_context {
send_fut.instrument(request_context.span()).await
Expand Down Expand Up @@ -818,6 +860,7 @@ mod tests {
OutgoingEnvelope::ToConnection {
connection_id,
message,
..
} => {
assert_eq!(connection_id, ConnectionId(42));
let OutgoingMessage::Response(response) = message else {
Expand Down Expand Up @@ -880,6 +923,7 @@ mod tests {
OutgoingEnvelope::ToConnection {
connection_id,
message,
..
} => {
assert_eq!(connection_id, ConnectionId(9));
let OutgoingMessage::Error(outgoing_error) = message else {
Expand All @@ -892,6 +936,50 @@ mod tests {
}
}

#[tokio::test]
async fn send_server_notification_to_connection_and_wait_tracks_write_completion() {
let (tx, mut rx) = mpsc::channel::<OutgoingEnvelope>(4);
let outgoing = OutgoingMessageSender::new(tx);
let send_task = tokio::spawn(async move {
outgoing
.send_server_notification_to_connection_and_wait(
ConnectionId(42),
ServerNotification::ModelRerouted(ModelReroutedNotification {
thread_id: "thread-1".to_string(),
turn_id: "turn-1".to_string(),
from_model: "gpt-5.3-codex".to_string(),
to_model: "gpt-5.2".to_string(),
reason: ModelRerouteReason::HighRiskCyberActivity,
}),
)
.await
});

let envelope = timeout(Duration::from_secs(1), rx.recv())
.await
.expect("should receive envelope before timeout")
.expect("channel should contain one message");
let OutgoingEnvelope::ToConnection {
connection_id,
message,
write_complete_tx,
} = envelope
else {
panic!("expected targeted server notification envelope");
};
assert_eq!(connection_id, ConnectionId(42));
assert!(matches!(message, OutgoingMessage::AppServerNotification(_)));
write_complete_tx
.expect("write completion sender should be attached")
.send(())
.expect("receiver should still be waiting");

timeout(Duration::from_secs(1), send_task)
.await
.expect("send task should finish after write completion is signaled")
.expect("send task should not panic");
}

#[tokio::test]
async fn connection_closed_clears_registered_request_contexts() {
let (tx, _rx) = mpsc::channel::<OutgoingEnvelope>(4);
Expand Down
Loading
Loading