diff --git a/implants/imix/src/task.rs b/implants/imix/src/task.rs index e04a184ea..bd4ef5a67 100644 --- a/implants/imix/src/task.rs +++ b/implants/imix/src/task.rs @@ -13,11 +13,12 @@ use tokio::sync::mpsc::{self, UnboundedSender}; #[derive(Debug)] struct StreamPrinter { tx: UnboundedSender, + error_tx: UnboundedSender, } impl StreamPrinter { - fn new(tx: UnboundedSender) -> Self { - Self { tx } + fn new(tx: UnboundedSender, error_tx: UnboundedSender) -> Self { + Self { tx, error_tx } } } @@ -29,7 +30,7 @@ impl Printer for StreamPrinter { fn print_err(&self, _span: &Span, s: &str) { // We format with newline to match BufferPrinter behavior - let _ = self.tx.send(format!("{}\n", s)); + let _ = self.error_tx.send(format!("{}\n", s)); } } @@ -168,7 +169,8 @@ fn execute_task( ) { // Setup StreamPrinter and Interpreter let (tx, rx) = mpsc::unbounded_channel(); - let printer = Arc::new(StreamPrinter::new(tx)); + let (error_tx, error_rx) = mpsc::unbounded_channel(); + let printer = Arc::new(StreamPrinter::new(tx, error_tx)); let mut interp = setup_interpreter(task_context.clone(), &tome, agent.clone(), printer.clone()); // Report Start @@ -180,6 +182,7 @@ fn execute_task( agent.clone(), runtime_handle.clone(), rx, + error_rx, ); // Run Interpreter with panic protection @@ -267,27 +270,71 @@ fn spawn_output_consumer( agent: Arc, runtime_handle: tokio::runtime::Handle, mut rx: mpsc::UnboundedReceiver, + mut error_rx: mpsc::UnboundedReceiver, ) -> tokio::task::JoinHandle<()> { runtime_handle.spawn(async move { #[cfg(debug_assertions)] log::info!("task={} Started output stream", task_context.task_id); let task_id = task_context.task_id; - while let Some(msg) = rx.recv().await { - match agent.report_task_output(ReportTaskOutputRequest { - output: Some(TaskOutput { - id: task_id, - output: msg, - error: None, - exec_started_at: None, - exec_finished_at: None, - }), - context: Some(task_context.clone().into()), - }) { - Ok(_) => {} - Err(_e) => { - #[cfg(debug_assertions)] - log::error!("task={task_id} failed to report output: {_e}"); + let mut rx_open = true; + let mut error_rx_open = true; + + loop { + tokio::select! { + val = rx.recv(), if rx_open => { + match val { + Some(msg) => { + match agent.report_task_output(ReportTaskOutputRequest { + output: Some(TaskOutput { + id: task_id, + output: msg, + error: None, + exec_started_at: None, + exec_finished_at: None, + }), + context: Some(task_context.clone().into()), + }) { + Ok(_) => {} + Err(_e) => { + #[cfg(debug_assertions)] + log::error!("task={task_id} failed to report output: {_e}"); + } + } + } + None => { + rx_open = false; + } + } } + val = error_rx.recv(), if error_rx_open => { + match val { + Some(msg) => { + match agent.report_task_output(ReportTaskOutputRequest { + output: Some(TaskOutput { + id: task_id, + output: String::new(), + error: Some(TaskError { msg }), + exec_started_at: None, + exec_finished_at: None, + }), + context: Some(task_context.clone().into()), + }) { + Ok(_) => {} + Err(_e) => { + #[cfg(debug_assertions)] + log::error!("task={task_id} failed to report error: {_e}"); + } + } + } + None => { + error_rx_open = false; + } + } + } + } + + if !rx_open && !error_rx_open { + break; } } }) diff --git a/implants/imix/src/tests/task_tests.rs b/implants/imix/src/tests/task_tests.rs index 520240dca..f537a5e25 100644 --- a/implants/imix/src/tests/task_tests.rs +++ b/implants/imix/src/tests/task_tests.rs @@ -274,3 +274,61 @@ async fn test_task_registry_list_and_stop() { "Task should be removed from list" ); } + +#[tokio::test] +async fn test_task_eprint_behavior() { + let agent = Arc::new(MockAgent::new()); + let task_id = 111; + let code = "eprint(\"This is an error\")\nprint(\"This is output\")"; + + let task = c2::Task { + id: task_id, + tome: Some(Tome { + eldritch: code.to_string(), + ..Default::default() + }), + quest_name: "eprint_test".to_string(), + ..Default::default() + }; + + let registry = TaskRegistry::new(); + registry.spawn(task, agent.clone()); + + tokio::time::sleep(Duration::from_secs(3)).await; + + let reports = agent.output_reports.lock().unwrap(); + + // Check if "This is an error" appears in output or error field + let error_in_output = reports.iter().any(|r| { + r.output + .as_ref() + .map(|o| o.output.contains("This is an error")) + .unwrap_or(false) + }); + + let error_in_error = reports.iter().any(|r| { + r.output + .as_ref() + .map(|o| { + if let Some(err) = &o.error { + err.msg.contains("This is an error") + } else { + false + } + }) + .unwrap_or(false) + }); + + println!("Error in output: {}", error_in_output); + println!("Error in error field: {}", error_in_error); + + // Current behavior (before fix): eprint goes to output + // Desired behavior: eprint goes to error field + + // So if I assert what I want: + assert!(error_in_error, "eprint should be reported as TaskError"); + assert!( + !error_in_output, + "eprint should NOT be reported as regular output" + ); +}