diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 88c1860c..164b52f7 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -17,6 +17,7 @@ mod connection_manager; mod execution_strategy; mod output_mode; +mod output_sync; mod parallel; mod result_types; mod stream_manager; diff --git a/src/executor/output_sync.rs b/src/executor/output_sync.rs new file mode 100644 index 00000000..2e3d7d6f --- /dev/null +++ b/src/executor/output_sync.rs @@ -0,0 +1,168 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Thread-safe output synchronization for preventing race conditions +//! when multiple nodes write to stdout/stderr simultaneously. + +use once_cell::sync::Lazy; +use std::io::{self, Write}; +use std::sync::Mutex; + +/// Global stdout mutex to prevent interleaved output +static STDOUT_MUTEX: Lazy> = Lazy::new(|| Mutex::new(io::stdout())); + +/// Global stderr mutex to prevent interleaved output +static STDERR_MUTEX: Lazy> = Lazy::new(|| Mutex::new(io::stderr())); + +/// Thread-safe println! that prevents output interleaving +/// +/// This function acquires a mutex lock before writing to ensure +/// that the entire line is written atomically without interruption +/// from other threads. +pub fn synchronized_println(text: &str) -> io::Result<()> { + let mut stdout = STDOUT_MUTEX.lock().unwrap(); + writeln!(stdout, "{text}")?; + stdout.flush()?; + Ok(()) +} + +/// Thread-safe eprintln! that prevents output interleaving +/// +/// This function acquires a mutex lock before writing to ensure +/// that the entire line is written atomically without interruption +/// from other threads. +#[allow(dead_code)] +pub fn synchronized_eprintln(text: &str) -> io::Result<()> { + let mut stderr = STDERR_MUTEX.lock().unwrap(); + writeln!(stderr, "{text}")?; + stderr.flush()?; + Ok(()) +} + +/// Batch write multiple lines to stdout atomically +/// +/// This function writes multiple lines while holding the lock, +/// ensuring that all lines from the same node appear together. +#[allow(dead_code)] +pub fn synchronized_print_lines<'a, I>(lines: I) -> io::Result<()> +where + I: Iterator, +{ + let mut stdout = STDOUT_MUTEX.lock().unwrap(); + for line in lines { + writeln!(stdout, "{line}")?; + } + stdout.flush()?; + Ok(()) +} + +/// Batch write multiple lines to stderr atomically +/// +/// This function writes multiple lines while holding the lock, +/// ensuring that all lines from the same node appear together. +#[allow(dead_code)] +pub fn synchronized_eprint_lines<'a, I>(lines: I) -> io::Result<()> +where + I: Iterator, +{ + let mut stderr = STDERR_MUTEX.lock().unwrap(); + for line in lines { + writeln!(stderr, "{line}")?; + } + stderr.flush()?; + Ok(()) +} + +/// Synchronized output writer for node prefixed output +pub struct NodeOutputWriter { + node_prefix: String, +} + +impl NodeOutputWriter { + /// Create a new writer with a node prefix + pub fn new(node_host: &str) -> Self { + Self { + node_prefix: format!("[{node_host}]"), + } + } + + /// Write stdout lines with node prefix atomically + pub fn write_stdout_lines(&self, text: &str) -> io::Result<()> { + let lines: Vec = text + .lines() + .map(|line| format!("{} {}", self.node_prefix, line)) + .collect(); + + if !lines.is_empty() { + let mut stdout = STDOUT_MUTEX.lock().unwrap(); + for line in lines { + writeln!(stdout, "{line}")?; + } + stdout.flush()?; + } + Ok(()) + } + + /// Write stderr lines with node prefix atomically + pub fn write_stderr_lines(&self, text: &str) -> io::Result<()> { + let lines: Vec = text + .lines() + .map(|line| format!("{} {}", self.node_prefix, line)) + .collect(); + + if !lines.is_empty() { + let mut stderr = STDERR_MUTEX.lock().unwrap(); + for line in lines { + writeln!(stderr, "{line}")?; + } + stderr.flush()?; + } + Ok(()) + } + + /// Write a single stdout line with node prefix + pub fn write_stdout(&self, line: &str) -> io::Result<()> { + synchronized_println(&format!("{} {}", self.node_prefix, line)) + } + + /// Write a single stderr line with node prefix + #[allow(dead_code)] + pub fn write_stderr(&self, line: &str) -> io::Result<()> { + synchronized_eprintln(&format!("{} {}", self.node_prefix, line)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_output_writer() { + let writer = NodeOutputWriter::new("test-host"); + assert_eq!(writer.node_prefix, "[test-host]"); + } + + #[test] + fn test_synchronized_output() { + // These tests just verify the functions compile and don't panic + // Actual thread safety is tested through integration tests + + let _ = synchronized_println("test"); + let _ = synchronized_eprintln("test error"); + + let lines = ["line1", "line2"]; + let _ = synchronized_print_lines(lines.iter().copied()); + let _ = synchronized_eprint_lines(lines.iter().copied()); + } +} diff --git a/src/executor/parallel.rs b/src/executor/parallel.rs index fdedff19..f882ccad 100644 --- a/src/executor/parallel.rs +++ b/src/executor/parallel.rs @@ -488,10 +488,12 @@ impl ParallelExecutor { let semaphore = Arc::new(Semaphore::new(self.max_parallel)); let mut manager = MultiNodeStreamManager::new(); let mut handles = Vec::new(); + let mut channels = Vec::new(); // Keep track of senders for cleanup // Spawn tasks for each node with streaming for node in &self.nodes { let (tx, rx) = mpsc::channel(1000); + channels.push(tx.clone()); // Keep a reference for cleanup manager.add_stream(node.clone(), rx); let node_clone = node.clone(); @@ -507,8 +509,32 @@ impl ParallelExecutor { let semaphore = Arc::clone(&semaphore); let handle = tokio::spawn(async move { - // Acquire semaphore - let _permit = semaphore.acquire().await.ok(); + // Use defer pattern to ensure cleanup even on panic + struct CleanupGuard { + _permit: Option, + } + + impl Drop for CleanupGuard { + fn drop(&mut self) { + tracing::trace!("Releasing semaphore permit in cleanup guard"); + } + } + + // Acquire semaphore with guard + let permit = match semaphore.acquire().await { + Ok(p) => p, + Err(e) => { + tracing::error!("Failed to acquire semaphore: {}", e); + return ( + node_clone, + Err(anyhow::anyhow!("Semaphore acquisition failed")), + ); + } + }; + + let _guard = CleanupGuard { + _permit: Some(permit), + }; let mut client = SshClient::new( node_clone.host.clone(), @@ -527,30 +553,50 @@ impl ParallelExecutor { jump_hosts_spec: jump_hosts.as_deref(), }; - match client - .connect_and_execute_with_output_streaming(&command, &config, tx) + // Ensure channel is closed on all paths + let result = match client + .connect_and_execute_with_output_streaming(&command, &config, tx.clone()) .await { - Ok(exit_status) => (node_clone, Ok(exit_status)), - Err(e) => (node_clone, Err(e)), - } + Ok(exit_status) => { + tracing::debug!( + "Command completed for {}: exit code {}", + node_clone.host, + exit_status + ); + (node_clone, Ok(exit_status)) + } + Err(e) => { + tracing::error!("Command failed for {}: {}", node_clone.host, e); + (node_clone, Err(e)) + } + }; + + // Explicitly drop the channel to signal completion + drop(tx); + result }); handles.push(handle); } - // Stream mode: output in real-time with [node] prefixes - if output_mode.is_stream() { + // Execute based on mode and ensure cleanup + let result = if output_mode.is_stream() { + // Stream mode: output in real-time with [node] prefixes self.handle_stream_mode(&mut manager, handles).await - } - // File mode: save to per-node files - else if let Some(output_dir) = output_mode.output_dir() { + } else if let Some(output_dir) = output_mode.output_dir() { + // File mode: save to per-node files self.handle_file_mode(&mut manager, handles, output_dir) .await } else { // Fallback to normal mode self.execute(command).await - } + }; + + // Ensure all channels are closed (important for cleanup) + drop(channels); + + result } /// Handle stream mode output with [node] prefixes @@ -559,6 +605,7 @@ impl ParallelExecutor { manager: &mut super::stream_manager::MultiNodeStreamManager, handles: Vec)>>, ) -> Result> { + use super::output_sync::NodeOutputWriter; use std::time::Duration; let mut pending_handles = handles; @@ -569,30 +616,44 @@ impl ParallelExecutor { // Poll all streams for new output manager.poll_all(); - // Output any new data with [node] prefixes + // Output any new data with [node] prefixes using synchronized writes for stream in manager.streams_mut() { let stdout = stream.take_stdout(); let stderr = stream.take_stderr(); if !stdout.is_empty() { - if let Ok(text) = String::from_utf8(stdout) { - for line in text.lines() { - println!("[{}] {}", stream.node.host, line); - } + // Use lossy conversion to handle non-UTF8 data gracefully + let text = String::from_utf8_lossy(&stdout); + let writer = NodeOutputWriter::new(&stream.node.host); + if let Err(e) = writer.write_stdout_lines(&text) { + tracing::error!("Failed to write stdout for {}: {}", stream.node.host, e); } } if !stderr.is_empty() { - if let Ok(text) = String::from_utf8(stderr) { - for line in text.lines() { - eprintln!("[{}] {}", stream.node.host, line); - } + // Use lossy conversion to handle non-UTF8 data gracefully + let text = String::from_utf8_lossy(&stderr); + let writer = NodeOutputWriter::new(&stream.node.host); + if let Err(e) = writer.write_stderr_lines(&text) { + tracing::error!("Failed to write stderr for {}: {}", stream.node.host, e); } } } - // Check for completed tasks - pending_handles.retain_mut(|handle| !handle.is_finished()); + // Check for completed tasks and handle panics + let mut i = 0; + while i < pending_handles.len() { + if pending_handles[i].is_finished() { + let handle = pending_handles.remove(i); + // Check if task panicked + if let Err(e) = &handle.await { + tracing::error!("Task panicked: {}", e); + // Continue processing other nodes + } + } else { + i += 1; + } + } // Small sleep to avoid busy waiting tokio::time::sleep(Duration::from_millis(50)).await; @@ -634,8 +695,44 @@ impl ParallelExecutor { use std::time::Duration; use tokio::fs; - // Create output directory if it doesn't exist - fs::create_dir_all(output_dir).await?; + // Validate output directory + if output_dir.exists() && !output_dir.is_dir() { + return Err(anyhow::anyhow!( + "Output path exists but is not a directory: {}", + output_dir.display() + )); + } + + // Create output directory if it doesn't exist with proper error handling + if let Err(e) = fs::create_dir_all(output_dir).await { + return Err(anyhow::anyhow!( + "Failed to create output directory '{}': {} - Check permissions", + output_dir.display(), + e + )); + } + + // Check if we can write to the directory + let test_file = output_dir.join(".bssh_test_write"); + match fs::File::create(&test_file).await { + Ok(_) => { + // Clean up test file + let _ = fs::remove_file(&test_file).await; + } + Err(e) => { + return Err(anyhow::anyhow!( + "Output directory '{}' is not writable: {}", + output_dir.display(), + e + )); + } + } + + // Log output directory for user reference + tracing::info!( + "Writing node outputs to directory: {}", + output_dir.display() + ); let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S"); @@ -661,24 +758,60 @@ impl ParallelExecutor { let stdout_path = output_dir.join(format!("{hostname}_{timestamp}.stdout")); let stderr_path = output_dir.join(format!("{hostname}_{timestamp}.stderr")); - // Write stdout + // Write stdout with error handling if !stream.stdout().is_empty() { - fs::write(&stdout_path, stream.stdout()).await?; - println!( - "[{}] Output saved to {}", - stream.node.host, - stdout_path.display() - ); + match fs::write(&stdout_path, stream.stdout()).await { + Ok(_) => { + // Use synchronized output to prevent interleaving + let writer = super::output_sync::NodeOutputWriter::new(&stream.node.host); + if let Err(e) = writer + .write_stdout(&format!("Output saved to {}", stdout_path.display())) + { + tracing::error!( + "Failed to write status for {}: {}", + stream.node.host, + e + ); + } + } + Err(e) => { + tracing::error!( + "Failed to write stdout for {} to {}: {}", + stream.node.host, + stdout_path.display(), + e + ); + // Continue processing other nodes despite error + } + } } - // Write stderr + // Write stderr with error handling if !stream.stderr().is_empty() { - fs::write(&stderr_path, stream.stderr()).await?; - println!( - "[{}] Errors saved to {}", - stream.node.host, - stderr_path.display() - ); + match fs::write(&stderr_path, stream.stderr()).await { + Ok(_) => { + // Use synchronized output to prevent interleaving + let writer = super::output_sync::NodeOutputWriter::new(&stream.node.host); + if let Err(e) = writer + .write_stdout(&format!("Errors saved to {}", stderr_path.display())) + { + tracing::error!( + "Failed to write status for {}: {}", + stream.node.host, + e + ); + } + } + Err(e) => { + tracing::error!( + "Failed to write stderr for {} to {}: {}", + stream.node.host, + stderr_path.display(), + e + ); + // Continue processing other nodes despite error + } + } } let result = diff --git a/src/executor/stream_manager.rs b/src/executor/stream_manager.rs index 5640a85c..3344f473 100644 --- a/src/executor/stream_manager.rs +++ b/src/executor/stream_manager.rs @@ -22,6 +22,66 @@ use crate::node::Node; use crate::ssh::tokio_client::CommandOutput; use tokio::sync::mpsc; +/// Maximum buffer size per stream (10MB) +/// This prevents memory exhaustion when nodes produce large amounts of output +const MAX_BUFFER_SIZE: usize = 10 * 1024 * 1024; // 10MB + +/// A rolling buffer that maintains a fixed maximum size +/// When the buffer exceeds MAX_BUFFER_SIZE, old data is discarded +#[derive(Debug)] +struct RollingBuffer { + data: Vec, + total_bytes_received: usize, + bytes_dropped: usize, +} + +impl RollingBuffer { + fn new() -> Self { + Self { + data: Vec::new(), + total_bytes_received: 0, + bytes_dropped: 0, + } + } + + /// Append data to the buffer, dropping old data if necessary + fn append(&mut self, new_data: &[u8]) { + self.total_bytes_received += new_data.len(); + self.data.extend_from_slice(new_data); + + // If buffer exceeds maximum size, keep only the most recent data + if self.data.len() > MAX_BUFFER_SIZE { + let overflow = self.data.len() - MAX_BUFFER_SIZE; + self.bytes_dropped += overflow; + + // Remove old data from the beginning + self.data.drain(0..overflow); + + // Log warning about dropped data + tracing::warn!( + "Buffer overflow: dropped {} bytes (total dropped: {})", + overflow, + self.bytes_dropped + ); + } + } + + /// Get the current buffer contents + fn as_slice(&self) -> &[u8] { + &self.data + } + + /// Take the buffer contents and clear it + fn take(&mut self) -> Vec { + std::mem::take(&mut self.data) + } + + /// Check if data has been dropped + fn has_overflow(&self) -> bool { + self.bytes_dropped > 0 + } +} + /// Execution status for a node's command #[derive(Debug, Clone, PartialEq, Eq)] pub enum ExecutionStatus { @@ -41,15 +101,18 @@ pub enum ExecutionStatus { /// along with execution status and exit code. This allows for /// independent processing of each node's output without blocking /// on other nodes. +/// +/// Buffers are limited to MAX_BUFFER_SIZE to prevent memory exhaustion. +/// When buffers exceed this limit, old data is automatically discarded. pub struct NodeStream { /// The node this stream is associated with pub node: Node, /// Channel receiver for command output receiver: mpsc::Receiver, - /// Buffer for standard output - stdout_buffer: Vec, - /// Buffer for standard error - stderr_buffer: Vec, + /// Buffer for standard output (with overflow protection) + stdout_buffer: RollingBuffer, + /// Buffer for standard error (with overflow protection) + stderr_buffer: RollingBuffer, /// Current execution status status: ExecutionStatus, /// Exit code (if completed) @@ -64,8 +127,8 @@ impl NodeStream { Self { node, receiver, - stdout_buffer: Vec::new(), - stderr_buffer: Vec::new(), + stdout_buffer: RollingBuffer::new(), + stderr_buffer: RollingBuffer::new(), status: ExecutionStatus::Pending, exit_code: None, closed: false, @@ -90,10 +153,22 @@ impl NodeStream { received_data = true; match output { CommandOutput::StdOut(data) => { - self.stdout_buffer.extend_from_slice(&data); + self.stdout_buffer.append(&data); + if self.stdout_buffer.has_overflow() { + tracing::warn!( + "Node {} stdout buffer overflow - old data discarded", + self.node.host + ); + } } CommandOutput::StdErr(data) => { - self.stderr_buffer.extend_from_slice(&data); + self.stderr_buffer.append(&data); + if self.stderr_buffer.has_overflow() { + tracing::warn!( + "Node {} stderr buffer overflow - old data discarded", + self.node.host + ); + } } } } @@ -107,6 +182,7 @@ impl NodeStream { if self.status != ExecutionStatus::Failed(String::new()) { self.status = ExecutionStatus::Completed; } + tracing::debug!("Channel disconnected for node {}", self.node.host); break; } } @@ -117,26 +193,26 @@ impl NodeStream { /// Get reference to stdout buffer pub fn stdout(&self) -> &[u8] { - &self.stdout_buffer + self.stdout_buffer.as_slice() } /// Get reference to stderr buffer pub fn stderr(&self) -> &[u8] { - &self.stderr_buffer + self.stderr_buffer.as_slice() } /// Take stdout buffer and clear it /// /// This is useful for consuming output in chunks while streaming pub fn take_stdout(&mut self) -> Vec { - std::mem::take(&mut self.stdout_buffer) + self.stdout_buffer.take() } /// Take stderr buffer and clear it /// /// This is useful for consuming output in chunks while streaming pub fn take_stderr(&mut self) -> Vec { - std::mem::take(&mut self.stderr_buffer) + self.stderr_buffer.take() } /// Get current execution status