diff --git a/src/commands/interactive/connection.rs b/src/commands/interactive/connection.rs index a68f8e24..ea4ce61b 100644 --- a/src/commands/interactive/connection.rs +++ b/src/commands/interactive/connection.rs @@ -20,12 +20,13 @@ use russh::client::Msg; use russh::Channel; use std::io::{self, Write}; use tokio::time::{timeout, Duration}; +use zeroize::Zeroizing; use crate::jump::{parse_jump_hosts, JumpHostChain}; use crate::node::Node; use crate::ssh::{ known_hosts::get_check_method, - tokio_client::{AuthMethod, Client, ServerCheckMethod}, + tokio_client::{AuthMethod, Client, Error as SshError, ServerCheckMethod}, }; use super::types::{InteractiveCommand, NodeSession}; @@ -33,6 +34,9 @@ use super::types::{InteractiveCommand, NodeSession}; impl InteractiveCommand { /// Helper function to establish SSH connection with proper error handling and rate limiting /// This eliminates code duplication across different connection paths and prevents brute-force attacks + /// + /// If `allow_password_fallback` is true and key authentication fails, it will prompt for password + /// and retry with password authentication (matching OpenSSH behavior). async fn establish_connection( addr: (&str, u16), username: &str, @@ -40,6 +44,7 @@ impl InteractiveCommand { check_method: ServerCheckMethod, host: &str, port: u16, + allow_password_fallback: bool, ) -> Result { const SSH_CONNECT_TIMEOUT_SECS: u64 = 30; let connect_timeout = Duration::from_secs(SSH_CONNECT_TIMEOUT_SECS); @@ -56,15 +61,47 @@ impl InteractiveCommand { let result = timeout( connect_timeout, - Client::connect(addr, username, auth_method, check_method), + Client::connect(addr, username, auth_method, check_method.clone()), ) .await .with_context(|| { format!( "Connection timeout: Failed to connect to {host}:{port} after {SSH_CONNECT_TIMEOUT_SECS} seconds" ) - })? - .with_context(|| format!("SSH connection failed to {host}:{port}")); + })?; + + // Check if key authentication failed and password fallback is allowed + let result = match result { + Err(SshError::KeyAuthFailed) + if allow_password_fallback && atty::is(atty::Stream::Stdin) => + { + tracing::debug!( + "SSH key authentication failed for {username}@{host}:{port}, attempting password fallback" + ); + + // Prompt for password (matching OpenSSH behavior) + let password = Self::prompt_password(username, host).await?; + + // Retry with password authentication + let password_auth = AuthMethod::with_password(&password); + + // Small delay before retry to prevent rapid attempts + tokio::time::sleep(Duration::from_millis(500)).await; + + timeout( + connect_timeout, + Client::connect(addr, username, password_auth, check_method), + ) + .await + .with_context(|| { + format!( + "Connection timeout: Failed to connect to {host}:{port} after {SSH_CONNECT_TIMEOUT_SECS} seconds" + ) + })? + .with_context(|| format!("SSH connection failed to {host}:{port}")) + } + other => other.with_context(|| format!("SSH connection failed to {host}:{port}")), + }; // SECURITY: Normalize timing to prevent timing attacks // Ensure all authentication attempts take at least 500ms to complete @@ -79,6 +116,22 @@ impl InteractiveCommand { result } + /// Prompt for password with secure handling + async fn prompt_password(username: &str, host: &str) -> Result> { + let username = username.to_string(); + let host = host.to_string(); + + tokio::task::spawn_blocking(move || { + let password = Zeroizing::new( + rpassword::prompt_password(format!("{username}@{host}'s password: ")) + .with_context(|| "Failed to read password")?, + ); + Ok(password) + }) + .await + .with_context(|| "Password prompt task failed")? + } + /// Determine authentication method based on node and config (same logic as exec mode) pub(super) async fn determine_auth_method(&self, node: &Node) -> Result { // Use centralized authentication logic from auth module @@ -164,6 +217,7 @@ impl InteractiveCommand { tracing::debug!("No valid jump hosts found, using direct connection"); // Use the helper function to establish connection + // Enable password fallback for interactive mode (matches OpenSSH behavior) Self::establish_connection( addr, &node.username, @@ -171,6 +225,7 @@ impl InteractiveCommand { check_method.clone(), &node.host, node.port, + !self.use_password, // Allow fallback unless explicit password mode ) .await? } else { @@ -239,6 +294,7 @@ impl InteractiveCommand { tracing::debug!("Using direct connection (no jump hosts)"); // Use the helper function to establish connection + // Enable password fallback for interactive mode (matches OpenSSH behavior) Self::establish_connection( addr, &node.username, @@ -246,6 +302,7 @@ impl InteractiveCommand { check_method, &node.host, node.port, + !self.use_password, // Allow fallback unless explicit password mode ) .await? }; @@ -300,6 +357,7 @@ impl InteractiveCommand { tracing::debug!("No valid jump hosts found, using direct connection for PTY"); // Use the helper function to establish connection + // Enable password fallback for interactive mode (matches OpenSSH behavior) Self::establish_connection( addr, &node.username, @@ -307,6 +365,7 @@ impl InteractiveCommand { check_method.clone(), &node.host, node.port, + !self.use_password, // Allow fallback unless explicit password mode ) .await? } else { @@ -375,6 +434,7 @@ impl InteractiveCommand { tracing::debug!("Using direct connection for PTY (no jump hosts)"); // Use the helper function to establish connection + // Enable password fallback for interactive mode (matches OpenSSH behavior) Self::establish_connection( addr, &node.username, @@ -382,6 +442,7 @@ impl InteractiveCommand { check_method, &node.host, node.port, + !self.use_password, // Allow fallback unless explicit password mode ) .await? }; diff --git a/src/ssh/auth.rs b/src/ssh/auth.rs index 61aad338..b21ba023 100644 --- a/src/ssh/auth.rs +++ b/src/ssh/auth.rs @@ -218,15 +218,20 @@ impl AuthContext { } } - // Priority 3: Key file authentication + // Priority 3: Key file authentication (explicit -i flag) if let Some(ref key_path) = self.key_path { return self.key_file_auth(key_path).await; } - // Priority 4: SSH agent auto-detection (if use_agent is true) + // Priority 4: SSH agent auto-detection (like OpenSSH behavior) + // OpenSSH tries SSH agent first when available, as it can try all registered keys #[cfg(not(target_os = "windows"))] - if self.use_agent { + if !self.use_agent { + // Auto-detect SSH agent even without --use-agent flag if let Some(auth) = self.agent_auth()? { + tracing::debug!( + "Using SSH agent (auto-detected) - agent will try all registered keys" + ); return Ok(auth); } } diff --git a/src/utils/logging.rs b/src/utils/logging.rs index 95899cf9..65fa6546 100644 --- a/src/utils/logging.rs +++ b/src/utils/logging.rs @@ -15,15 +15,24 @@ use tracing_subscriber::EnvFilter; pub fn init_logging(verbosity: u8) { - let filter = match verbosity { - 0 => EnvFilter::new("bssh=warn"), - 1 => EnvFilter::new("bssh=info"), - 2 => EnvFilter::new("bssh=debug"), - _ => EnvFilter::new("bssh=trace"), + // Priority: RUST_LOG environment variable > verbosity flag + let filter = if std::env::var("RUST_LOG").is_ok() { + // Use RUST_LOG if set (allows debugging russh and other dependencies) + EnvFilter::from_default_env() + } else { + // Fall back to verbosity-based filter + match verbosity { + 0 => EnvFilter::new("bssh=warn"), + 1 => EnvFilter::new("bssh=info"), + // -vv: Include russh debug logs for SSH troubleshooting + 2 => EnvFilter::new("bssh=debug,russh=debug"), + // -vvv: Full trace including all dependencies + _ => EnvFilter::new("bssh=trace,russh=trace,russh_sftp=debug"), + } }; tracing_subscriber::fmt() .with_env_filter(filter) - .with_target(false) + .with_target(true) // Show module targets for better debugging .init(); }