diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md
index 0f41e3a4..1545fc53 100644
--- a/ARCHITECTURE.md
+++ b/ARCHITECTURE.md
@@ -337,6 +337,91 @@ let tasks: Vec>> = nodes
 - Uses system known_hosts file (~/.ssh/known_hosts)
 - SSH agent authentication with auto-detection
 
+### 4.0.1 Command Output Streaming Infrastructure
+
+**Status:** Implemented (2025-10-29) as part of Phase 1 of Issue #68
+
+**Design Motivation:**
+Real-time command output streaming enables future UI features such as live progress bars, per-node output display, and streaming aggregation. The infrastructure provides the foundation for responsive UIs while maintaining full backward compatibility with existing synchronous APIs.
+
+**Architecture:**
+
+The streaming infrastructure consists of three key components:
+
+1. **CommandOutput Enum** (`tokio_client/channel_manager.rs`)
+   ```rust
+   pub enum CommandOutput {
+       StdOut(CryptoVec),
+       StdErr(CryptoVec),
+   }
+   ```
+   - Represents streaming output events
+   - Separates stdout and stderr streams
+   - Carries each chunk in russh's `CryptoVec` (cloned from the channel message when it is sent)
+
+2. **CommandOutputBuffer** (`tokio_client/channel_manager.rs`)
+   ```rust
+   pub(crate) struct CommandOutputBuffer {
+       sender: Sender<CommandOutput>,
+       receiver_task: JoinHandle<(Vec<u8>, Vec<u8>)>,
+   }
+   ```
+   - Internal buffer for collecting streaming output
+   - Background task aggregates stdout and stderr
+   - Channel capacity: 100 events (tunable)
+   - Used by synchronous `execute()` for backward compatibility
+
+3. 
**Streaming API Methods**
+   - `Client::execute_streaming(command, sender)` - Low-level streaming API
+   - `SshClient::connect_and_execute_with_output_streaming()` - High-level streaming API
+   - The high-level API applies the configured command timeout; the low-level API leaves timeouts to its caller. Both handle errors consistently
+
+**Implementation Pattern:**
+
+```rust
+// Streaming execution (new in Phase 1)
+let (sender, receiver_task) = build_output_buffer();
+let exit_status = client.execute_streaming("command", sender).await?;
+let (stdout, stderr) = receiver_task.await?;
+
+// Backward-compatible execution (refactored to use streaming)
+let result = client.execute("command").await?;
+// Internally uses execute_streaming() + CommandOutputBuffer
+```
+
+**Backward Compatibility:**
+
+The existing `execute()` method was refactored to use `execute_streaming()` internally:
+- Same function signature
+- Same return type (`CommandExecutedResult`)
+- Same error handling behavior
+- Same timeout behavior
+- Zero breaking changes to existing code
+
+**Performance Characteristics:**
+- Channel-based architecture with bounded buffer (100 events)
+- SSH channel data is forwarded as `CryptoVec` chunks (each chunk is cloned once before it is sent)
+- Background task for output aggregation (non-blocking)
+- Memory overhead: ~16KB per streaming command (8KB stdout + 1KB stderr + buffer)
+- Latency: Real-time streaming with minimal buffering delay
+
+**Error Handling:**
+- New `JoinError` variant in `tokio_client::Error`
+- Handles task join failures gracefully
+- Timeout handling preserved from original implementation
+- Channel send errors are deliberately ignored (the receiver may already have been dropped)
+
+**Testing:**
+- Integration tests cover streaming with stdout/stderr separation
+- Backward compatibility test ensures no behavioral changes
+- Tests use localhost SSH for reproducible validation
+- All existing tests pass with zero modifications
+
+**Future Phases (Issue #68):**
+- Phase 2: Executor integration for parallel streaming
+- Phase 3: UI components (progress bars, live updates)
+- Phase 4: Advanced 
features (filtering, aggregation)
+
 ### 4.1 Authentication Module (`ssh/auth.rs`)
 
 **Status:** Implemented (2025-10-17) as part of code deduplication refactoring (Issue #34)
diff --git a/src/ssh/client/command.rs b/src/ssh/client/command.rs
index 0746cc85..1b5b0815 100644
--- a/src/ssh/client/command.rs
+++ b/src/ssh/client/command.rs
@@ -16,9 +16,11 @@ use super::config::ConnectionConfig;
 use super::core::SshClient;
 use super::result::CommandResult;
 use crate::ssh::known_hosts::StrictHostKeyChecking;
+use crate::ssh::tokio_client::CommandOutput;
 use anyhow::{Context, Result};
 use std::path::Path;
 use std::time::Duration;
+use tokio::sync::mpsc::Sender;
 
 // SSH command execution timeout design:
 // - 5 minutes (300s) handles long-running commands
@@ -160,4 +162,108 @@ impl SshClient {
                 .with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))
         }
     }
+
+    /// Execute a command with streaming output support
+    ///
+    /// This method provides real-time command output streaming through the provided sender channel.
+    /// Output is sent as `CommandOutput::StdOut` or `CommandOutput::StdErr` variants.
+    ///
+    /// # Arguments
+    /// * `command` - The command to execute
+    /// * `config` - Connection configuration
+    /// * `output_sender` - Channel sender for streaming output
+    ///
+    /// # Returns
+    /// The exit status of the command
+    pub async fn connect_and_execute_with_output_streaming(
+        &mut self,
+        command: &str,
+        config: &ConnectionConfig<'_>,
+        output_sender: Sender<CommandOutput>,
+    ) -> Result<u32> {
+        tracing::debug!("Connecting to {}:{}", self.host, self.port);
+
+        // Determine authentication method based on parameters
+        let auth_method = self
+            .determine_auth_method(
+                config.key_path,
+                config.use_agent,
+                config.use_password,
+                #[cfg(target_os = "macos")]
+                config.use_keychain,
+            )
+            .await?;
+
+        let strict_mode = config
+            .strict_mode
+            .unwrap_or(StrictHostKeyChecking::AcceptNew);
+
+        // Create client connection - either direct or through jump hosts
+        let client = self
+            .establish_connection(
+                &auth_method,
+                strict_mode,
+                config.jump_hosts_spec,
+                config.key_path,
+                config.use_agent,
+                config.use_password,
+            )
+            .await?;
+
+        tracing::debug!("Connected and authenticated successfully");
+        tracing::debug!("Executing command with streaming: {}", command);
+
+        // Execute command with streaming and timeout
+        let exit_status = self
+            .execute_streaming_with_timeout(&client, command, config.timeout_seconds, output_sender)
+            .await?;
+
+        tracing::debug!("Command execution completed with status: {}", exit_status);
+
+        Ok(exit_status)
+    }
+
+    /// Execute a command with streaming output; `None` uses the default timeout, `Some(0)` disables it
+    async fn execute_streaming_with_timeout(
+        &self,
+        client: &crate::ssh::tokio_client::Client,
+        command: &str,
+        timeout_seconds: Option<u64>,
+        output_sender: Sender<CommandOutput>,
+    ) -> Result<u32> {
+        if let Some(timeout_secs) = timeout_seconds {
+            if timeout_secs == 0 {
+                // No timeout (unlimited)
+                tracing::debug!("Executing command with streaming, no timeout (unlimited)");
+                client.execute_streaming(command, output_sender)
+                    .await
+                    .with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))
+            } else {
+                // With timeout
+                let command_timeout = Duration::from_secs(timeout_secs);
+                tracing::debug!(
+                    "Executing command with streaming, timeout of {} seconds",
+                    timeout_secs
+                );
+                tokio::time::timeout(
+                    command_timeout,
+                    client.execute_streaming(command, output_sender)
+                )
+                .await
+                .with_context(|| format!("Command execution timeout: The command '{}' did not complete within {} seconds on {}:{}", command, timeout_secs, self.host, self.port))?
+                .with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))
+            }
+        } else {
+            // Default timeout if not specified
+            let command_timeout = Duration::from_secs(DEFAULT_COMMAND_TIMEOUT_SECS);
+            tracing::debug!("Executing command with streaming, default timeout of 300 seconds");
+            tokio::time::timeout(
+                command_timeout,
+                client.execute_streaming(command, output_sender)
+            )
+            .await
+            .with_context(|| format!("Command execution timeout: The command '{}' did not complete within 5 minutes on {}:{}", command, self.host, self.port))?
+            .with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))
+        }
+    }
 }
diff --git a/src/ssh/tokio_client/channel_manager.rs b/src/ssh/tokio_client/channel_manager.rs
index e429e043..2e35b047 100644
--- a/src/ssh/tokio_client/channel_manager.rs
+++ b/src/ssh/tokio_client/channel_manager.rs
@@ -22,9 +22,11 @@
 
 use russh::client::Msg;
 use russh::Channel;
+use russh::CryptoVec;
 use std::io;
 use std::net::SocketAddr;
-use tokio::io::AsyncWriteExt;
+use tokio::sync::mpsc::{channel, Receiver, Sender};
+use tokio::task::JoinHandle;
 
 use super::connection::Client;
 use super::ToSocketAddrsWithHostname;
@@ -51,6 +53,53 @@ const SSH_CMD_BUFFER_SIZE: usize = 8192;
 /// - Matches typical terminal line lengths
 const SSH_RESPONSE_BUFFER_SIZE: usize = 1024;
 
+/// Output events channel capacity for streaming
+/// - 100 events provides good buffering without excessive memory
+/// - Balances between latency and throughput
+const OUTPUT_EVENTS_CHANNEL_SIZE: usize = 100;
+
+/// Command output variants for streaming
+#[derive(Debug, Clone)]
+pub enum CommandOutput {
+    /// Standard output data
+    StdOut(CryptoVec),
+    /// Standard error data
+    StdErr(CryptoVec),
+}
+
+/// Buffer for collecting streaming command output
+pub(crate) struct CommandOutputBuffer {
+    pub(crate) sender: Sender<CommandOutput>,
+    pub(crate) receiver_task: JoinHandle<(Vec<u8>, Vec<u8>)>,
+}
+
+impl CommandOutputBuffer {
+    /// Create a new command output buffer with a background task to collect output
+    pub(crate) fn new() -> Self {
+        let (sender, mut receiver): (Sender<CommandOutput>, Receiver<CommandOutput>) =
+            channel(OUTPUT_EVENTS_CHANNEL_SIZE);
+
+        let receiver_task = tokio::task::spawn(async move {
+            let mut stdout = Vec::with_capacity(SSH_CMD_BUFFER_SIZE);
+            let mut stderr = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE);
+
+            while let Some(output) = receiver.recv().await {
+                match output {
+                    CommandOutput::StdOut(buffer) => stdout.extend_from_slice(&buffer),
+                    CommandOutput::StdErr(buffer) => stderr.extend_from_slice(&buffer),
+                }
+            }
+
+            (stdout, stderr)
+        });
+
+        Self {
+            sender,
+            receiver_task,
+        }
+    }
+}
+
 /// Result of a command execution.
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct CommandExecutedResult {
@@ -113,26 +162,34 @@ impl Client {
         Err(connect_err)
     }
 
-    /// Execute a remote command via the ssh connection.
+    /// Execute a remote command via the ssh connection with streaming output.
     ///
-    /// Returns stdout, stderr and the exit code of the command,
-    /// packaged in a [`CommandExecutedResult`] struct.
-    /// If you need the stderr output interleaved within stdout, you should postfix the command with a redirection,
-    /// e.g. `echo foo 2>&1`.
-    /// If you dont want any output at all, use something like `echo foo >/dev/null 2>&1`.
+    /// This method sends command output in real-time to the provided sender channel.
+    /// Output is sent as `CommandOutput::StdOut` or `CommandOutput::StdErr` variants.
+    ///
+    /// Returns only the exit status of the command. Stdout and stderr are streamed
+    /// through the sender channel.
     ///
     /// Make sure your commands don't read from stdin and exit after bounded time.
     ///
     /// Can be called multiple times, but every invocation is a new shell context.
     /// Thus `cd`, setting variables and alike have no effect on future invocations.
-    pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, super::Error> {
+    ///
+    /// # Arguments
+    /// * `command` - The command to execute
+    /// * `sender` - Channel sender for streaming output
+    ///
+    /// # Returns
+    /// The exit status of the command
+    pub async fn execute_streaming(
+        &self,
+        command: &str,
+        sender: Sender<CommandOutput>,
+    ) -> Result<u32, super::Error> {
         // Sanitize command to prevent injection attacks
         let sanitized_command = crate::utils::sanitize_command(command)
             .map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?;
 
-        // Pre-allocate buffers with capacity to avoid frequent reallocations
-        let mut stdout_buffer = Vec::with_capacity(SSH_CMD_BUFFER_SIZE);
-        let mut stderr_buffer = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE);
         let mut channel = self.connection_handle.channel_open_session().await?;
         channel.exec(true, sanitized_command.as_str()).await?;
 
@@ -140,15 +197,14 @@ impl Client {
 
         // While the channel has messages...
         while let Some(msg) = channel.wait().await {
-            //dbg!(&msg);
             match msg {
-                // If we get data, add it to the buffer
+                // If we get data, send it to the streaming channel
                 russh::ChannelMsg::Data { ref data } => {
-                    stdout_buffer.write_all(data).await.unwrap()
+                    let _ = sender.send(CommandOutput::StdOut(data.clone())).await;
                 }
                 russh::ChannelMsg::ExtendedData { ref data, ext } => {
                     if ext == 1 {
-                        stderr_buffer.write_all(data).await.unwrap()
+                        let _ = sender.send(CommandOutput::StdErr(data.clone())).await;
                     }
                 }
 
@@ -157,7 +213,7 @@ impl Client {
                 // not be finished yet!
                 russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
 
-                // We SHOULD get this EOF messagge, but 4254 sec 5.3 also permits
+                // We SHOULD get this EOF message, but 4254 sec 5.3 also permits
                 // the channel to close without it being sent. And sometimes this
                 // message can even precede the Data message, so don't handle it
                 // russh::ChannelMsg::Eof => break,
@@ -165,20 +221,51 @@ impl Client {
             }
         }
 
+        // Drop sender to signal completion to receiver
+        drop(sender);
+
         // If we received an exit code, report it back
         if let Some(result) = result {
-            Ok(CommandExecutedResult {
-                stdout: String::from_utf8_lossy(&stdout_buffer).to_string(),
-                stderr: String::from_utf8_lossy(&stderr_buffer).to_string(),
-                exit_status: result,
-            })
-
+            Ok(result)
         // Otherwise, report an error
         } else {
             Err(super::Error::CommandDidntExit)
         }
     }
 
+    /// Execute a remote command via the ssh connection.
+    ///
+    /// Returns stdout, stderr and the exit code of the command,
+    /// packaged in a [`CommandExecutedResult`] struct.
+    /// If you need the stderr output interleaved within stdout, you should postfix the command with a redirection,
+    /// e.g. `echo foo 2>&1`.
+    /// If you don't want any output at all, use something like `echo foo >/dev/null 2>&1`.
+    ///
+    /// Make sure your commands don't read from stdin and exit after bounded time.
+    ///
+    /// Can be called multiple times, but every invocation is a new shell context.
+    /// Thus `cd`, setting variables and alike have no effect on future invocations.
+    pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, super::Error> {
+        // Use streaming internally but collect all output.
+        //
+        // The buffer is destructured so that its only sender is moved into
+        // execute_streaming(). A sender left alive here would keep the channel
+        // open and the collector task awaited below would never finish.
+        let CommandOutputBuffer { sender, receiver_task } = CommandOutputBuffer::new();
+
+        // Execute with streaming
+        let exit_status = self.execute_streaming(command, sender).await?;
+
+        // Wait for all output to be collected
+        let (stdout_bytes, stderr_bytes) = receiver_task.await.map_err(super::Error::JoinError)?;
+
+        Ok(CommandExecutedResult {
+            stdout: String::from_utf8_lossy(&stdout_bytes).to_string(),
+            stderr: String::from_utf8_lossy(&stderr_bytes).to_string(),
+            exit_status,
+        })
+    }
+
     /// Request an interactive shell channel.
/// /// This method opens a new SSH channel suitable for interactive shell sessions. diff --git a/src/ssh/tokio_client/error.rs b/src/ssh/tokio_client/error.rs index b45ca010..087a89e4 100644 --- a/src/ssh/tokio_client/error.rs +++ b/src/ssh/tokio_client/error.rs @@ -51,4 +51,6 @@ pub enum Error { PortForwardingNotSupported, #[error("Global request failed: {0}")] GlobalRequestFailed(String), + #[error("Task join error: {0}")] + JoinError(#[from] tokio::task::JoinError), } diff --git a/src/ssh/tokio_client/mod.rs b/src/ssh/tokio_client/mod.rs index 9ecf518d..811d71c2 100644 --- a/src/ssh/tokio_client/mod.rs +++ b/src/ssh/tokio_client/mod.rs @@ -23,7 +23,7 @@ mod to_socket_addrs_with_hostname; // Re-export public API types for backward compatibility pub use authentication::{AuthKeyboardInteractive, AuthMethod, ServerCheckMethod}; -pub use channel_manager::CommandExecutedResult; +pub use channel_manager::{CommandExecutedResult, CommandOutput}; pub use connection::{Client, ClientHandler}; pub use error::Error; pub use to_socket_addrs_with_hostname::ToSocketAddrsWithHostname; diff --git a/tests/streaming_test.rs b/tests/streaming_test.rs new file mode 100644 index 00000000..78d29368 --- /dev/null +++ b/tests/streaming_test.rs @@ -0,0 +1,221 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use bssh::ssh::tokio_client::{AuthMethod, Client, CommandOutput, ServerCheckMethod};
+use tokio::sync::mpsc::channel;
+
+/// Sender half of the output channel plus the task that collects (stdout, stderr) bytes
+type OutputBuffer = (
+    tokio::sync::mpsc::Sender<CommandOutput>,
+    tokio::task::JoinHandle<(Vec<u8>, Vec<u8>)>,
+);
+
+/// Helper function to build a test output buffer
+fn build_test_output_buffer() -> OutputBuffer {
+    let (sender, mut receiver) = channel(100);
+
+    let receiver_task = tokio::task::spawn(async move {
+        let mut stdout = Vec::new();
+        let mut stderr = Vec::new();
+
+        while let Some(output) = receiver.recv().await {
+            match output {
+                CommandOutput::StdOut(buffer) => stdout.extend_from_slice(&buffer),
+                CommandOutput::StdErr(buffer) => stderr.extend_from_slice(&buffer),
+            }
+        }
+
+        (stdout, stderr)
+    });
+
+    (sender, receiver_task)
+}
+
+/// Check if SSH is available and can connect to localhost
+fn can_ssh_to_localhost() -> bool {
+    use std::process::Command;
+
+    // Check if SSH server is running and we can connect to localhost
+    let output = Command::new("ssh")
+        .args([
+            "-o",
+            "ConnectTimeout=2",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "UserKnownHostsFile=/dev/null",
+            "-o",
+            "PasswordAuthentication=no",
+            "-o",
+            "BatchMode=yes",
+            "localhost",
+            "echo",
+            "test",
+        ])
+        .output();
+
+    match output {
+        Ok(result) => result.status.success(),
+        Err(_) => false,
+    }
+}
+
+#[tokio::test]
+async fn test_localhost_execute_streaming_output() {
+    if !can_ssh_to_localhost() {
+        eprintln!("Skipping streaming test: Cannot SSH to localhost");
+        return;
+    }
+
+    // Get current username
+    let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
+
+    // Create client
+    let client = Client::connect(
+        ("localhost", 22),
+        &username,
+        AuthMethod::Agent, // Use SSH agent for authentication
+        ServerCheckMethod::NoCheck,
+    )
+    .await;
+
+    if client.is_err() {
+        eprintln!("Skipping streaming test: Cannot connect to localhost");
+        return;
+    }
+
+    let client = client.unwrap();
+
+    // Build output buffer for streaming
+    let (sender, receiver_task) = build_test_output_buffer();
+
+    // Execute command with streaming
+    let exit_status = client
+        .execute_streaming("echo 'Hello from streaming test'", sender)
+        .await;
+
+    assert!(exit_status.is_ok(), "Command should execute successfully");
+    let exit_status = exit_status.unwrap();
+    assert_eq!(exit_status, 0, "Command should exit with status 0");
+
+    // Wait for output collection
+    let (stdout_bytes, stderr_bytes) = receiver_task.await.unwrap();
+    let stdout = String::from_utf8_lossy(&stdout_bytes);
+    let stderr = String::from_utf8_lossy(&stderr_bytes);
+
+    // Verify output
+    assert!(
+        stdout.contains("Hello from streaming test"),
+        "Stdout should contain test message, got: {stdout}"
+    );
+    assert_eq!(stderr, "", "Stderr should be empty, got: {stderr}");
+}
+
+#[tokio::test]
+async fn test_backward_compatibility_execute() {
+    if !can_ssh_to_localhost() {
+        eprintln!("Skipping backward compatibility test: Cannot SSH to localhost");
+        return;
+    }
+
+    // Get current username
+    let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
+
+    // Create client
+    let client = Client::connect(
+        ("localhost", 22),
+        &username,
+        AuthMethod::Agent,
+        ServerCheckMethod::NoCheck,
+    )
+    .await;
+
+    if client.is_err() {
+        eprintln!("Skipping backward compatibility test: Cannot connect to localhost");
+        return;
+    }
+
+    let client = client.unwrap();
+
+    // Execute command using the original execute() method
+    let result = client.execute("echo 'Backward compatibility test'").await;
+
+    assert!(result.is_ok(), "Command should execute successfully");
+    let result = result.unwrap();
+
+    // Verify behavior is exactly the same as before
+    assert_eq!(result.exit_status, 0, "Command should exit with status 0");
+    assert!(
+        result.stdout.contains("Backward compatibility test"),
+        "Stdout should contain test message, got: {}",
+        result.stdout
+    );
+    assert_eq!(
+        result.stderr, "",
+        "Stderr should be empty, got: {}",
+        result.stderr
+    );
+}
+
+#[tokio::test]
+async fn test_streaming_with_stderr() {
+    if !can_ssh_to_localhost() {
+        eprintln!("Skipping stderr streaming test: Cannot SSH to localhost");
+        return;
+    }
+
+    // Get current username
+    let username = std::env::var("USER").unwrap_or_else(|_| "root".to_string());
+
+    // Create client
+    let client = Client::connect(
+        ("localhost", 22),
+        &username,
+        AuthMethod::Agent,
+        ServerCheckMethod::NoCheck,
+    )
+    .await;
+
+    if client.is_err() {
+        eprintln!("Skipping stderr streaming test: Cannot connect to localhost");
+        return;
+    }
+
+    let client = client.unwrap();
+
+    // Build output buffer for streaming
+    let (sender, receiver_task) = build_test_output_buffer();
+
+    // Execute command that outputs to both stdout and stderr
+    let exit_status = client
+        .execute_streaming("echo 'stdout message' && echo 'stderr message' >&2", sender)
+        .await;
+
+    assert!(exit_status.is_ok(), "Command should execute successfully");
+
+    // Wait for output collection
+    let (stdout_bytes, stderr_bytes) = receiver_task.await.unwrap();
+    let stdout = String::from_utf8_lossy(&stdout_bytes);
+    let stderr = String::from_utf8_lossy(&stderr_bytes);
+
+    // Verify both streams
+    assert!(
+        stdout.contains("stdout message"),
+        "Stdout should contain stdout message, got: {stdout}"
+    );
+    assert!(
+        stderr.contains("stderr message"),
+        "Stderr should contain stderr message, got: {stderr}"
+    );
+}