Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions src/agent/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2244,15 +2244,22 @@ impl Channel {
success,
..
} => {
let mut workers = self.state.active_workers.write().await;
if workers.remove(worker_id).is_none() {
// Use worker_handles as the source of truth for active workers.
// (active_workers is never populated because Worker is consumed by .run())
if self
.state
.worker_handles
.write()
.await
.remove(worker_id)
.is_none()
{
return Ok(());
}
drop(workers);

run_logger.log_worker_completed(*worker_id, result, *success);

self.state.worker_handles.write().await.remove(worker_id);
self.state.active_workers.write().await.remove(worker_id);
self.state.worker_inputs.write().await.remove(worker_id);

if *notify {
Expand Down Expand Up @@ -2619,4 +2626,49 @@ mod tests {

assert!(should_process_event_for_channel(&event, &channel_id));
}

#[test]
fn worker_complete_event_matches_own_channel() {
let channel_id: ChannelId = Arc::from("channel-a");
let event = ProcessEvent::WorkerComplete {
agent_id: Arc::from("agent"),
worker_id: uuid::Uuid::new_v4(),
channel_id: Some(channel_id.clone()),
result: "done".to_string(),
notify: true,
success: true,
};

assert!(should_process_event_for_channel(&event, &channel_id));
}

#[test]
fn worker_complete_event_ignored_for_other_channel() {
let channel_id: ChannelId = Arc::from("channel-a");
let event = ProcessEvent::WorkerComplete {
agent_id: Arc::from("agent"),
worker_id: uuid::Uuid::new_v4(),
channel_id: Some(Arc::from("channel-b")),
result: "done".to_string(),
notify: true,
success: true,
};

assert!(!should_process_event_for_channel(&event, &channel_id));
}

#[test]
fn worker_complete_event_ignored_when_no_channel() {
let channel_id: ChannelId = Arc::from("channel-a");
let event = ProcessEvent::WorkerComplete {
agent_id: Arc::from("agent"),
worker_id: uuid::Uuid::new_v4(),
channel_id: None,
result: "done".to_string(),
notify: true,
success: true,
};

assert!(!should_process_event_for_channel(&event, &channel_id));
}
}
36 changes: 36 additions & 0 deletions src/agent/channel_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,4 +739,40 @@ mod tests {
other => panic!("unexpected event: {other:?}"),
}
}

#[tokio::test]
async fn spawn_worker_task_carries_channel_id() {
let (event_tx, mut event_rx) = broadcast::channel(8);
let worker_id: WorkerId = Uuid::new_v4();
let channel_id: crate::ChannelId = Arc::from("test-channel");

let handle = spawn_worker_task(
worker_id,
event_tx,
Arc::<str>::from("agent"),
Some(channel_id.clone()),
None,
async { Ok::<String, crate::Error>("result".to_string()) },
);

let event = tokio::time::timeout(Duration::from_secs(2), event_rx.recv())
.await
.expect("worker completion event should be delivered")
.expect("broadcast receive should succeed");
handle.await.expect("worker task should join cleanly");

match event {
ProcessEvent::WorkerComplete {
channel_id: event_channel_id,
worker_id: completed_worker_id,
success,
..
} => {
assert_eq!(completed_worker_id, worker_id);
assert_eq!(event_channel_id, Some(channel_id));
assert!(success);
}
other => panic!("unexpected event: {other:?}"),
}
}
}