From 9febdb7158785acc9778989f0223038a0363271f Mon Sep 17 00:00:00 2001 From: Joshua Potts <8704475+iamjpotts@users.noreply.github.com> Date: Sun, 21 Sep 2025 17:49:51 -0400 Subject: [PATCH] feat: stream output of executed commands Signed-off-by: Joshua Potts <8704475+iamjpotts@users.noreply.github.com> --- src/ssh/client.rs | 49 +++++++++---- src/ssh/tokio_client/client.rs | 121 +++++++++++++++++++++++++++------ src/ssh/tokio_client/error.rs | 2 + tests/integration_test.rs | 76 +++++++++++++++++++-- 4 files changed, 207 insertions(+), 41 deletions(-) diff --git a/src/ssh/client.rs b/src/ssh/client.rs index 9f3f982d..26ae236b 100644 --- a/src/ssh/client.rs +++ b/src/ssh/client.rs @@ -14,9 +14,11 @@ use super::tokio_client::{AuthMethod, Client}; use crate::jump::{parse_jump_hosts, JumpHostChain}; +use crate::ssh::tokio_client::client::{CommandOutput, CommandOutputBuffer}; use anyhow::{Context, Result}; use std::path::Path; use std::time::Duration; +use tokio::sync::mpsc::Sender; use zeroize::Zeroizing; /// Configuration for SSH connection and command execution @@ -84,6 +86,32 @@ impl SshClient { command: &str, config: &ConnectionConfig<'_>, ) -> Result { + let CommandOutputBuffer { + sender, + receiver_task, + } = CommandOutputBuffer::new(); + + let exit_status = self + .connect_and_execute_with_output_streaming(command, config, sender) + .await?; + + let (output, stderr) = receiver_task.await?; + + // Convert result to our format + Ok(CommandResult { + host: self.host.clone(), + output, + stderr, + exit_status, + }) + } + + pub async fn connect_and_execute_with_output_streaming( + &mut self, + command: &str, + config: &ConnectionConfig<'_>, + output_sender: Sender, + ) -> Result { tracing::debug!("Connecting to {}:{}", self.host, self.port); // Determine authentication method based on parameters @@ -137,11 +165,11 @@ impl SshClient { tracing::debug!("Executing command: {}", command); // Execute command with timeout - let result = if let Some(timeout_secs) = config.timeout_seconds { + let exit_status = if let Some(timeout_secs) = config.timeout_seconds { if timeout_secs == 0 { // No timeout (unlimited) tracing::debug!("Executing command with no timeout (unlimited)"); - client.execute(command) + client.execute_streaming(command, output_sender) .await .with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))? } else { @@ -150,7 +178,7 @@ impl SshClient { tracing::debug!("Executing command with timeout of {} seconds", timeout_secs); tokio::time::timeout( command_timeout, - client.execute(command) + client.execute_streaming(command, output_sender) ) .await .with_context(|| format!("Command execution timeout: The command '{}' did not complete within {} seconds on {}:{}", command, timeout_secs, self.host, self.port))? @@ -168,25 +196,16 @@ impl SshClient { tracing::debug!("Executing command with default timeout of 300 seconds"); tokio::time::timeout( command_timeout, - client.execute(command) + client.execute_streaming(command, output_sender) ) .await .with_context(|| format!("Command execution timeout: The command '{}' did not complete within 5 minutes on {}:{}", command, self.host, self.port))? .with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))? }; - tracing::debug!( - "Command execution completed with status: {}", - result.exit_status - ); + tracing::debug!("Command execution completed with status: {exit_status}",); - // Convert result to our format - Ok(CommandResult { - host: self.host.clone(), - output: result.stdout.into_bytes(), - stderr: result.stderr.into_bytes(), - exit_status: result.exit_status, - }) + Ok(exit_status) } /// Create a direct SSH connection (no jump hosts) diff --git a/src/ssh/tokio_client/client.rs b/src/ssh/tokio_client/client.rs index 87e72e72..b9ec3421 100644 --- a/src/ssh/tokio_client/client.rs +++ b/src/ssh/tokio_client/client.rs @@ -1,7 +1,7 @@ use russh::client::KeyboardInteractiveAuthResponse; use russh::{ client::{Config, Handle, Handler, Msg}, - Channel, + Channel, CryptoVec, }; use russh_sftp::{client::SftpSession, protocol::OpenFlags}; use std::net::SocketAddr; @@ -9,6 +9,8 @@ use std::sync::Arc; use std::{fmt::Debug, path::Path}; use std::{io, path::PathBuf}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::mpsc::Sender; +use tokio::task::JoinHandle; use zeroize::Zeroizing; use super::ToSocketAddrsWithHostname; @@ -813,38 +815,81 @@ impl Client { /// /// Returns stdout, stderr and the exit code of the command, /// packaged in a [`CommandExecutedResult`] struct. + /// /// If you need the stderr output interleaved within stdout, you should postfix the command with a redirection, /// e.g. `echo foo 2>&1`. - /// If you dont want any output at all, use something like `echo foo >/dev/null 2>&1`. + /// If you don't want any output at all, use something like `echo foo >/dev/null 2>&1`. /// /// Make sure your commands don't read from stdin and exit after bounded time. /// /// Can be called multiple times, but every invocation is a new shell context. - /// Thus `cd`, setting variables and alike have no effect on future invocations. + /// Thus, `cd` and setting variables and alike have no effect on future invocations. pub async fn execute(&self, command: &str) -> Result { + let CommandOutputBuffer { + sender, + receiver_task, + } = CommandOutputBuffer::new(); + + let exit_status = self.execute_streaming(command, sender).await?; + + let (stdout, stderr) = receiver_task.await?; + + let result = CommandExecutedResult { + stdout: String::from_utf8_lossy(&stdout).into(), + stderr: String::from_utf8_lossy(&stderr).into(), + exit_status, + }; + + Ok(result) + } + + /// The same as [`Self:: execute`] except that output from stdout and stderr is + /// provided as it is received via callback functions. Once the command has + /// finished, returns its exit code. + pub async fn execute_streaming( + &self, + command: &str, + sender: Sender, + ) -> Result { // Sanitize command to prevent injection attacks let sanitized_command = crate::utils::sanitize_command(command) .map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?; - // Pre-allocate buffers with capacity to avoid frequent reallocations - let mut stdout_buffer = Vec::with_capacity(SSH_CMD_BUFFER_SIZE); - let mut stderr_buffer = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE); let mut channel = self.connection_handle.channel_open_session().await?; channel.exec(true, sanitized_command.as_str()).await?; let mut result: Option = None; + let mut receiver_dropped = false; + // While the channel has messages... while let Some(msg) = channel.wait().await { - //dbg!(&msg); match msg { // If we get data, add it to the buffer - russh::ChannelMsg::Data { ref data } => { - stdout_buffer.write_all(data).await.unwrap() + russh::ChannelMsg::Data { data } => { + if let Err(_send_error) = sender.send(CommandOutput::StdOut(data)).await { + // only log the warning once per command + if !receiver_dropped { + receiver_dropped = true; + + tracing::warn!( + "receiver dropped; cannot send command output to receiver" + ); + } + } } - russh::ChannelMsg::ExtendedData { ref data, ext } => { + russh::ChannelMsg::ExtendedData { data, ext } => { if ext == 1 { - stderr_buffer.write_all(data).await.unwrap() + if let Err(_send_error) = sender.send(CommandOutput::StdErr(data)).await { + // only log the warning once per command + if !receiver_dropped { + receiver_dropped = true; + + tracing::warn!( + "receiver dropped; cannot send command output to receiver" + ); + } + } } } @@ -862,17 +907,7 @@ impl Client { } // If we received an exit code, report it back - if let Some(result) = result { - Ok(CommandExecutedResult { - stdout: String::from_utf8_lossy(&stdout_buffer).to_string(), - stderr: String::from_utf8_lossy(&stderr_buffer).to_string(), - exit_status: result, - }) - - // Otherwise, report an error - } else { - Err(super::Error::CommandDidntExit) - } + result.ok_or(super::Error::CommandDidntExit) } /// Request an interactive shell with PTY support. @@ -1008,6 +1043,48 @@ impl Debug for Client { } } +/// Partial output of a command +pub enum CommandOutput { + /// Partial stdout output of a command + StdOut(CryptoVec), + /// Partial stderr output of a command + StdErr(CryptoVec), +} + +pub(crate) struct CommandOutputBuffer { + pub(crate) sender: Sender, + pub(crate) receiver_task: JoinHandle<(Vec, Vec)>, +} + +impl CommandOutputBuffer { + pub(crate) fn new() -> Self { + // The output collection task should easily keep up with output received from ssh server + const OUTPUT_EVENTS_CHANNEL_SIZE: usize = 100; + + let (sender, mut receiver) = tokio::sync::mpsc::channel(OUTPUT_EVENTS_CHANNEL_SIZE); + + let receiver_task = tokio::task::spawn(async move { + // Pre-allocate buffers with capacity to avoid frequent reallocations + let mut stdout = Vec::with_capacity(SSH_CMD_BUFFER_SIZE); + let mut stderr = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE); + + while let Some(output) = receiver.recv().await { + match output { + CommandOutput::StdOut(buffer) => stdout.extend_from_slice(&buffer), + CommandOutput::StdErr(buffer) => stderr.extend_from_slice(&buffer), + } + } + + (stdout, stderr) + }); + + Self { + sender, + receiver_task, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct CommandExecutedResult { /// The stdout output of the command. diff --git a/src/ssh/tokio_client/error.rs b/src/ssh/tokio_client/error.rs index b45ca010..f5d2ff40 100644 --- a/src/ssh/tokio_client/error.rs +++ b/src/ssh/tokio_client/error.rs @@ -41,6 +41,8 @@ pub enum Error { SftpError(#[from] russh_sftp::client::error::Error), #[error("I/O error")] IoError(#[from] io::Error), + #[error("Task join error: {0}")] + JoinError(#[from] tokio::task::JoinError), #[error("Command validation failed: {0}")] CommandValidationFailed(String), #[error("Port forwarding request failed: {0}")] diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 324b101d..4b93006d 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -14,10 +14,16 @@ use bssh::executor::ParallelExecutor; use bssh::node::Node; +use bssh::ssh::client::ConnectionConfig; +use bssh::ssh::known_hosts::StrictHostKeyChecking; +use bssh::ssh::tokio_client::client::CommandOutput; +use bssh::ssh::SshClient; use std::fs; use std::path::PathBuf; use std::process::Command; use tempfile::TempDir; +use tokio::sync::mpsc::Sender; +use tokio::task::JoinHandle; /// Check if SSH is available and can connect to localhost fn can_ssh_to_localhost() -> bool { @@ -46,6 +52,68 @@ fn can_ssh_to_localhost() -> bool { } } +fn build_test_output_buffer() -> (Sender, JoinHandle<(Vec, Vec)>) { + let (sender, mut output_receiver) = tokio::sync::mpsc::channel(10); + + let receiver_task = tokio::task::spawn(async move { + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + + while let Some(output) = output_receiver.recv().await { + match output { + CommandOutput::StdOut(buffer) => stdout.extend_from_slice(&buffer), + CommandOutput::StdErr(buffer) => stderr.extend_from_slice(&buffer), + } + } + + (stdout, stderr) + }); + + (sender, receiver_task) +} + +fn get_localhost_test_user() -> String { + std::env::var("USER").unwrap_or_else(|_| "root".to_string()) +} + +#[tokio::test] +async fn test_localhost_execute_streaming_output() { + if !can_ssh_to_localhost() { + eprintln!("Skipping integration test: Cannot SSH to localhost"); + return; + } + + let mut client = SshClient::new("localhost".into(), 22, get_localhost_test_user()); + + let config = ConnectionConfig { + key_path: None, + strict_mode: Some(StrictHostKeyChecking::No), + use_agent: false, + use_password: false, + timeout_seconds: None, + jump_hosts_spec: None, + }; + + const COMMAND: &str = "bash -c 'echo a message && echo an error >&2 && exit 123'"; + + let (sender, receiver_task) = build_test_output_buffer(); + + let exit_code = client + .connect_and_execute_with_output_streaming(COMMAND, &config, sender) + .await + .expect("executed command"); + + assert_eq!(exit_code, 123); + + let (stdout, stderr) = receiver_task.await.expect("joined output task"); + + let stdout = String::from_utf8_lossy(&stdout).to_string(); + let stderr = String::from_utf8_lossy(&stderr).to_string(); + + assert_eq!(stdout, "a message\n"); + assert_eq!(stderr, "an error\n"); +} + #[tokio::test] async fn test_localhost_upload_download_roundtrip() { if !can_ssh_to_localhost() { @@ -66,7 +134,7 @@ async fn test_localhost_upload_download_roundtrip() { let nodes = vec![Node::new( "localhost".to_string(), 22, - std::env::var("USER").unwrap_or_else(|_| "root".to_string()), + get_localhost_test_user(), )]; // Try to find an SSH key - use None if not found (will try SSH agent) let ssh_key = dirs::home_dir().and_then(|h| { @@ -141,7 +209,7 @@ async fn test_localhost_multiple_file_upload() { let nodes = vec![Node::new( "localhost".to_string(), 22, - std::env::var("USER").unwrap_or_else(|_| "root".to_string()), + get_localhost_test_user(), )]; // Try to find an SSH key - use None if not found (will try SSH agent) let ssh_key = dirs::home_dir().and_then(|h| { @@ -181,7 +249,7 @@ async fn test_parallel_execution_with_multiple_nodes() { return; } - let user = std::env::var("USER").unwrap_or_else(|_| "root".to_string()); + let user = get_localhost_test_user(); let nodes = vec![ Node::new("localhost".to_string(), 22, user.clone()), Node::new("127.0.0.1".to_string(), 22, user.clone()), @@ -223,7 +291,7 @@ async fn test_download_with_unique_filenames() { fs::write(&source_file, "Shared content").unwrap(); // Create executor with two "different" nodes (both localhost) - let user = std::env::var("USER").unwrap_or_else(|_| "root".to_string()); + let user = get_localhost_test_user(); let nodes = vec![ Node::new("localhost".to_string(), 22, user.clone()), Node::new("127.0.0.1".to_string(), 22, user),