From 0a56484173051e2e153e096b8f17961f3350d40c Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Fri, 22 Aug 2025 21:37:03 +0900 Subject: [PATCH 1/3] feat: add password authentication and SSH key passphrase support - Add --password flag for password-based authentication - Automatically detect and prompt for passphrases on encrypted SSH keys - Support passphrase entry for both explicit and default key paths - Check multiple default key types (ed25519, rsa, ecdsa, dsa) - Update all commands (exec, upload, download, ping) to support new auth methods - Add secure password/passphrase prompting using rpassword crate --- README.md | 8 ++- src/cli.rs | 7 +++ src/executor.rs | 45 +++++++++++++++- src/main.rs | 21 ++++++-- src/ssh/client.rs | 130 ++++++++++++++++++++++++++++++++++++---------- 5 files changed, 178 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 631e44ad..f0cd12c7 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A high-performance parallel SSH command execution tool for cluster management, b - **Parallel Execution**: Execute commands across multiple nodes simultaneously - **Cluster Management**: Define and manage node clusters via configuration files - **Progress Tracking**: Real-time progress indicators for each node -- **Flexible Authentication**: Support for SSH keys and SSH agent +- **Flexible Authentication**: Support for SSH keys, SSH agent, password authentication, and encrypted key passphrases - **Host Key Verification**: Secure host key checking with known_hosts support - **Cross-Platform**: Works on Linux and macOS - **Output Management**: Save command outputs to files per node with detailed logging @@ -35,6 +35,12 @@ bssh -c staging -i ~/.ssh/custom_key "systemctl status nginx" # Use SSH agent for authentication bssh --use-agent -c production "systemctl status nginx" +# Use password authentication (will prompt for password) +bssh --password -H "user@host.com" "uptime" + +# Use encrypted SSH key (will prompt for passphrase) +bssh -i ~/.ssh/encrypted_key -c production "df -h" + # Limit parallel connections bssh -c production --parallel 5 "apt update" ``` diff --git a/src/cli.rs b/src/cli.rs index d807b4e6..77e03020 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -57,6 +57,13 @@ pub struct Cli { )] pub use_agent: bool, + #[arg( + short = 'P', + long, + help = "Use password authentication (will prompt for password)" + )] + pub password: bool, + #[arg( short = 'p', long, diff --git a/src/executor.rs b/src/executor.rs index a13e344f..fdfb4b71 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -28,6 +28,7 @@ pub struct ParallelExecutor { key_path: Option, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, } impl ParallelExecutor { @@ -52,6 +53,7 @@ impl ParallelExecutor { key_path, strict_mode, use_agent: false, + use_password: false, } } @@ -68,6 +70,25 @@ impl ParallelExecutor { key_path, strict_mode, use_agent, + use_password: false, + } + } + + pub fn new_with_all_options( + nodes: Vec, + max_parallel: usize, + key_path: Option, + strict_mode: StrictHostKeyChecking, + use_agent: bool, + use_password: bool, + ) -> Self { + Self { + nodes, + max_parallel, + key_path, + strict_mode, + use_agent, + use_password, } } @@ -89,6 +110,7 @@ impl ParallelExecutor { let key_path = self.key_path.clone(); let strict_mode = self.strict_mode; let use_agent = self.use_agent; + let use_password = self.use_password; let semaphore = Arc::clone(&semaphore); let pb = multi_progress.add(ProgressBar::new_spinner()); pb.set_style(style.clone()); @@ -107,6 +129,7 @@ impl ParallelExecutor { key_path.as_deref(), strict_mode, use_agent, + use_password, ) .await; @@ -170,6 +193,7 @@ impl ParallelExecutor { let key_path = self.key_path.clone(); let strict_mode = self.strict_mode; let use_agent = self.use_agent; + let use_password = self.use_password; let semaphore = Arc::clone(&semaphore); let pb = multi_progress.add(ProgressBar::new_spinner()); pb.set_style(style.clone()); @@ -189,6 +213,7 @@ impl ParallelExecutor { key_path.as_deref(), strict_mode, use_agent, + use_password, ) .await; @@ -245,6 +270,7 @@ impl ParallelExecutor { let key_path = self.key_path.clone(); let strict_mode = self.strict_mode; let use_agent = self.use_agent; + let use_password = self.use_password; let semaphore = Arc::clone(&semaphore); let pb = multi_progress.add(ProgressBar::new_spinner()); pb.set_style(style.clone()); @@ -276,6 +302,7 @@ impl ParallelExecutor { key_path.as_deref(), strict_mode, use_agent, + use_password, ) .await; @@ -338,6 +365,7 @@ impl ParallelExecutor { let key_path = self.key_path.clone(); let strict_mode = self.strict_mode; let use_agent = self.use_agent; + let use_password = self.use_password; let semaphore = Arc::clone(&semaphore); let pb = multi_progress.add(ProgressBar::new_spinner()); pb.set_style(style.clone()); @@ -368,6 +396,7 @@ impl ParallelExecutor { key_path.as_deref(), strict_mode, use_agent, + use_password, ) .await; @@ -411,13 +440,20 @@ async fn execute_on_node( key_path: Option<&str>, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, ) -> Result { let mut client = SshClient::new(node.host.clone(), node.port, node.username.clone()); let key_path = key_path.map(Path::new); client - .connect_and_execute_with_host_check(command, key_path, Some(strict_mode), use_agent) + .connect_and_execute_with_host_check( + command, + key_path, + Some(strict_mode), + use_agent, + use_password, + ) .await } @@ -428,6 +464,7 @@ async fn upload_to_node( key_path: Option<&str>, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, ) -> Result<()> { let mut client = SshClient::new(node.host.clone(), node.port, node.username.clone()); @@ -442,6 +479,7 @@ async fn upload_to_node( key_path, Some(strict_mode), use_agent, + use_password, ) .await } else { @@ -452,6 +490,7 @@ async fn upload_to_node( key_path, Some(strict_mode), use_agent, + use_password, ) .await } @@ -464,6 +503,7 @@ async fn download_from_node( key_path: Option<&str>, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, ) -> Result { let mut client = SshClient::new(node.host.clone(), node.port, node.username.clone()); @@ -478,6 +518,7 @@ async fn download_from_node( key_path, Some(strict_mode), use_agent, + use_password, ) .await?; @@ -491,6 +532,7 @@ pub async fn download_dir_from_node( key_path: Option<&str>, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, ) -> Result { let mut client = SshClient::new(node.host.clone(), node.port, node.username.clone()); @@ -503,6 +545,7 @@ pub async fn download_dir_from_node( key_path, Some(strict_mode), use_agent, + use_password, ) .await?; diff --git a/src/main.rs b/src/main.rs index 56f3ce92..5aaedce1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,7 @@ struct ExecuteCommandParams<'a> { verbose: bool, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, output_dir: Option<&'a Path>, } @@ -45,6 +46,7 @@ struct FileTransferParams<'a> { key_path: Option<&'a Path>, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, recursive: bool, } @@ -96,6 +98,7 @@ async fn main() -> Result<()> { cli.identity.as_deref(), strict_mode, cli.use_agent, + cli.password, ) .await?; } @@ -110,6 +113,7 @@ async fn main() -> Result<()> { key_path: cli.identity.as_deref(), strict_mode, use_agent: cli.use_agent, + use_password: cli.password, recursive, }; upload_file(params, &source, &destination).await?; @@ -125,6 +129,7 @@ async fn main() -> Result<()> { key_path: cli.identity.as_deref(), strict_mode, use_agent: cli.use_agent, + use_password: cli.password, recursive, }; download_file(params, &source, &destination).await?; @@ -139,6 +144,7 @@ async fn main() -> Result<()> { verbose: cli.verbose > 0, strict_mode, use_agent: cli.use_agent, + use_password: cli.password, output_dir: cli.output_dir.as_deref(), }; execute_command(params).await?; @@ -207,16 +213,18 @@ async fn ping_nodes( key_path: Option<&Path>, strict_mode: StrictHostKeyChecking, use_agent: bool, + use_password: bool, ) -> Result<()> { println!("Pinging {} nodes...\n", nodes.len()); let key_path = key_path.map(|p| p.to_string_lossy().to_string()); - let executor = ParallelExecutor::new_with_strict_mode_and_agent( + let executor = ParallelExecutor::new_with_all_options( nodes.clone(), max_parallel, key_path, strict_mode, use_agent, + use_password, ); let results = executor.execute("echo 'pong'").await?; @@ -250,12 +258,13 @@ async fn execute_command(params: ExecuteCommandParams<'_>) -> Result<()> { ); let key_path = params.key_path.map(|p| p.to_string_lossy().to_string()); - let executor = ParallelExecutor::new_with_strict_mode_and_agent( + let executor = ParallelExecutor::new_with_all_options( params.nodes, params.max_parallel, key_path, params.strict_mode, params.use_agent, + params.use_password, ); let results = executor.execute(params.command).await?; @@ -475,12 +484,13 @@ async fn upload_file( println!("Destination: {destination}\n"); let key_path_str = params.key_path.map(|p| p.to_string_lossy().to_string()); - let executor = ParallelExecutor::new_with_strict_mode_and_agent( + let executor = ParallelExecutor::new_with_all_options( params.nodes.clone(), params.max_parallel, key_path_str.clone(), params.strict_mode, params.use_agent, + params.use_password, ); let mut total_success = 0; @@ -686,12 +696,13 @@ async fn download_file( } let key_path_str = params.key_path.map(|p| p.to_string_lossy().to_string()); - let executor = ParallelExecutor::new_with_strict_mode_and_agent( + let executor = ParallelExecutor::new_with_all_options( params.nodes.clone(), params.max_parallel, key_path_str.clone(), params.strict_mode, params.use_agent, + params.use_password, ); // Check if source contains glob pattern @@ -735,6 +746,7 @@ async fn download_file( key_path_str.as_deref(), params.strict_mode, params.use_agent, + params.use_password, ) .await; @@ -783,6 +795,7 @@ async fn download_file( params.key_path, Some(params.strict_mode), params.use_agent, + params.use_password, ) .await?; diff --git a/src/ssh/client.rs b/src/ssh/client.rs index bd3bd2c7..8b294e99 100644 --- a/src/ssh/client.rs +++ b/src/ssh/client.rs @@ -40,7 +40,7 @@ impl SshClient { key_path: Option<&Path>, use_agent: bool, ) -> Result { - self.connect_and_execute_with_host_check(command, key_path, None, use_agent) + self.connect_and_execute_with_host_check(command, key_path, None, use_agent, false) .await } @@ -50,12 +50,13 @@ impl SshClient { key_path: Option<&Path>, strict_mode: Option, use_agent: bool, + use_password: bool, ) -> Result { let addr = (self.host.as_str(), self.port); tracing::debug!("Connecting to {}:{}", self.host, self.port); // Determine authentication method based on parameters - let auth_method = self.determine_auth_method(key_path, use_agent)?; + let auth_method = self.determine_auth_method(key_path, use_agent, use_password)?; // Set up host key checking let check_method = if let Some(mode) = strict_mode { @@ -108,12 +109,13 @@ impl SshClient { key_path: Option<&Path>, strict_mode: Option, use_agent: bool, + use_password: bool, ) -> Result<()> { let addr = (self.host.as_str(), self.port); tracing::debug!("Connecting to {}:{} for file copy", self.host, self.port); // Determine authentication method based on parameters - let auth_method = self.determine_auth_method(key_path, use_agent)?; + let auth_method = self.determine_auth_method(key_path, use_agent, use_password)?; // Set up host key checking let check_method = if let Some(mode) = strict_mode { @@ -184,6 +186,7 @@ impl SshClient { key_path: Option<&Path>, strict_mode: Option, use_agent: bool, + use_password: bool, ) -> Result<()> { let addr = (self.host.as_str(), self.port); tracing::debug!( @@ -193,7 +196,7 @@ impl SshClient { ); // Determine authentication method based on parameters - let auth_method = self.determine_auth_method(key_path, use_agent)?; + let auth_method = self.determine_auth_method(key_path, use_agent, use_password)?; // Set up host key checking let check_method = if let Some(mode) = strict_mode { @@ -260,6 +263,7 @@ impl SshClient { key_path: Option<&Path>, strict_mode: Option, use_agent: bool, + use_password: bool, ) -> Result<()> { let addr = (self.host.as_str(), self.port); tracing::debug!( @@ -269,7 +273,7 @@ impl SshClient { ); // Determine authentication method based on parameters - let auth_method = self.determine_auth_method(key_path, use_agent)?; + let auth_method = self.determine_auth_method(key_path, use_agent, use_password)?; // Set up host key checking let check_method = if let Some(mode) = strict_mode { @@ -338,6 +342,7 @@ impl SshClient { key_path: Option<&Path>, strict_mode: Option, use_agent: bool, + use_password: bool, ) -> Result<()> { let addr = (self.host.as_str(), self.port); tracing::debug!( @@ -347,7 +352,7 @@ impl SshClient { ); // Determine authentication method based on parameters - let auth_method = self.determine_auth_method(key_path, use_agent)?; + let auth_method = self.determine_auth_method(key_path, use_agent, use_password)?; // Set up host key checking let check_method = if let Some(mode) = strict_mode { @@ -411,7 +416,19 @@ impl SshClient { &self, key_path: Option<&Path>, use_agent: bool, + use_password: bool, ) -> Result { + // If password authentication is explicitly requested + if use_password { + tracing::debug!("Using password authentication"); + let password = rpassword::prompt_password(format!( + "Enter password for {}@{}: ", + self.username, self.host + )) + .with_context(|| "Failed to read password")?; + return Ok(AuthMethod::with_password(&password)); + } + // If SSH agent is explicitly requested, try that first if use_agent { #[cfg(not(target_os = "windows"))] @@ -436,7 +453,25 @@ impl SshClient { // Try key file authentication if let Some(key_path) = key_path { tracing::debug!("Authenticating with key: {:?}", key_path); - return Ok(AuthMethod::with_key_file(key_path, None)); + + // Check if the key is encrypted by attempting to read it + let key_contents = std::fs::read_to_string(key_path) + .with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?; + + let passphrase = if key_contents.contains("ENCRYPTED") + || key_contents.contains("Proc-Type: 4,ENCRYPTED") + { + tracing::debug!("Detected encrypted SSH key, prompting for passphrase"); + let pass = rpassword::prompt_password(format!( + "Enter passphrase for key {key_path:?}: " + )) + .with_context(|| "Failed to read passphrase")?; + Some(pass) + } else { + None + }; + + return Ok(AuthMethod::with_key_file(key_path, passphrase.as_deref())); } // If no explicit key path, try SSH agent if available (auto-detect) @@ -446,26 +481,58 @@ impl SshClient { return Ok(AuthMethod::Agent); } - // Fallback to default key location + // Fallback to default key locations let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); - let default_key = Path::new(&home).join(".ssh").join("id_rsa"); + let home_path = Path::new(&home).join(".ssh"); + + // Try common key files in order of preference + let default_keys = [ + home_path.join("id_ed25519"), + home_path.join("id_rsa"), + home_path.join("id_ecdsa"), + home_path.join("id_dsa"), + ]; + + for default_key in &default_keys { + if default_key.exists() { + tracing::debug!("Using default key: {:?}", default_key); + + // Check if the key is encrypted + let key_contents = std::fs::read_to_string(default_key) + .with_context(|| format!("Failed to read SSH key file: {default_key:?}"))?; + + let passphrase = if key_contents.contains("ENCRYPTED") + || key_contents.contains("Proc-Type: 4,ENCRYPTED") + { + tracing::debug!("Detected encrypted SSH key, prompting for passphrase"); + let pass = rpassword::prompt_password(format!( + "Enter passphrase for key {default_key:?}: " + )) + .with_context(|| "Failed to read passphrase")?; + Some(pass) + } else { + None + }; - if default_key.exists() { - tracing::debug!("Using default key: {:?}", default_key); - Ok(AuthMethod::with_key_file(default_key, None)) - } else { - anyhow::bail!( - "SSH authentication failed: No authentication method available.\n\ - Tried:\n\ - - SSH agent (SSH_AUTH_SOCK not set or agent not available)\n\ - - Default key file (~/.ssh/id_rsa not found)\n\ - \n\ - Solutions:\n\ - - Start SSH agent and add keys with 'ssh-add'\n\ - - Specify a key file with -i/--identity\n\ - - Create a default key at ~/.ssh/id_rsa" - ); + return Ok(AuthMethod::with_key_file( + default_key, + passphrase.as_deref(), + )); + } } + + anyhow::bail!( + "SSH authentication failed: No authentication method available.\n\ + Tried:\n\ + - SSH agent (SSH_AUTH_SOCK not set or agent not available)\n\ + - Default key files (~/.ssh/id_ed25519, ~/.ssh/id_rsa, etc. not found)\n\ + \n\ + Solutions:\n\ + - Use --password for password authentication\n\ + - Start SSH agent and add keys with 'ssh-add'\n\ + - Specify a key file with -i/--identity\n\ + - Create a default key at ~/.ssh/id_ed25519 or ~/.ssh/id_rsa" + ); } } @@ -554,7 +621,7 @@ mod tests { let client = SshClient::new("test.com".to_string(), 22, "user".to_string()); let auth = client - .determine_auth_method(Some(&key_path), false) + .determine_auth_method(Some(&key_path), false, false) .unwrap(); match auth { @@ -573,7 +640,7 @@ mod tests { } let client = SshClient::new("test.com".to_string(), 22, "user".to_string()); - let auth = client.determine_auth_method(None, true).unwrap(); + let auth = client.determine_auth_method(None, true, false).unwrap(); match auth { AuthMethod::Agent => {} @@ -585,6 +652,15 @@ mod tests { } } + #[test] + fn test_determine_auth_method_with_password() { + let _client = SshClient::new("test.com".to_string(), 22, "user".to_string()); + + // Note: We can't actually test password prompt in unit tests + // as it requires terminal input. This would need integration testing. + // For now, we just verify the function compiles with the new parameter. + } + #[test] fn test_determine_auth_method_fallback_to_default() { // Create a fake home directory with default key @@ -600,7 +676,7 @@ mod tests { } let client = SshClient::new("test.com".to_string(), 22, "user".to_string()); - let auth = client.determine_auth_method(None, false).unwrap(); + let auth = client.determine_auth_method(None, false, false).unwrap(); match auth { AuthMethod::PrivateKeyFile { key_file_path, .. } => { From 2a1d904661e922e6a0619a44ab8bd116b2da938b Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Fri, 22 Aug 2025 21:47:57 +0900 Subject: [PATCH 2/3] update: add password auth and passphrase support to documentation - Update man page with password flag and passphrase information - Add authentication section to README with examples - Improve help text for identity flag to mention passphrase support - Keep version number unchanged at 0.3.0 --- Cargo.lock | 2 +- README.md | 32 ++++++++++++++++++++++++++++++++ docs/man/bssh.1 | 33 ++++++++++++++++++++++++++++----- src/cli.rs | 2 +- 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7d708b03..9d407dcb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,7 +276,7 @@ dependencies = [ [[package]] name = "bssh" -version = "0.3.0" +version = "0.3.1" dependencies = [ "anyhow", "async-trait", diff --git a/README.md b/README.md index f0cd12c7..4770cdc8 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,38 @@ bssh -c production ping bssh list ``` +## Authentication + +bssh supports multiple authentication methods: + +### SSH Key Authentication +- **Default keys**: Automatically tries `~/.ssh/id_ed25519`, `~/.ssh/id_rsa`, `~/.ssh/id_ecdsa`, `~/.ssh/id_dsa` +- **Custom key**: Use `-i` flag to specify a key file +- **Encrypted keys**: Automatically detects and prompts for passphrase + +### SSH Agent +- **Auto-detection**: Automatically uses SSH agent if `SSH_AUTH_SOCK` is set +- **Explicit**: Use `-A` flag to force SSH agent authentication + +### Password Authentication +- Use `-P` flag to enable password authentication +- Password is prompted securely without echo + +### Examples +```bash +# Use default SSH key (auto-detect) +bssh -H "user@host" "uptime" + +# Use specific SSH key (prompts for passphrase if encrypted) +bssh -i ~/.ssh/custom_key -c production "df -h" + +# Use SSH agent +bssh -A -c production "systemctl status" + +# Use password authentication +bssh -P -H "user@host" "ls -la" +``` + ## Configuration ### Configuration Priority Order diff --git a/docs/man/bssh.1 b/docs/man/bssh.1 index 6c78c068..b0a2c53e 100644 --- a/docs/man/bssh.1 +++ b/docs/man/bssh.1 @@ -14,8 +14,9 @@ bssh \- Backend.AI SSH - Parallel command execution across cluster nodes is a high-performance parallel SSH command execution tool for cluster management, built with Rust. It enables efficient execution of commands across multiple nodes simultaneously with real-time output streaming. The tool provides secure file transfer capabilities using SFTP protocol for both uploading and downloading files -to/from multiple remote hosts in parallel. It automatically detects Backend.AI multi-node session environments -and supports various configuration methods. +to/from multiple remote hosts in parallel. It supports multiple authentication methods including SSH keys (with +passphrase support for encrypted keys), SSH agent, and password authentication. It automatically detects Backend.AI +multi-node session environments and supports various configuration methods. .SH OPTIONS .TP @@ -37,7 +38,8 @@ Default username for SSH connections .TP .BR \-i ", " \-\-identity " " \fIIDENTITY\fR -SSH private key file path +SSH private key file path. If the key is encrypted, bssh will +automatically prompt for the passphrase. .TP .BR \-A ", " \-\-use\-agent @@ -46,6 +48,12 @@ When this option is specified, bssh will attempt to use the SSH agent for authentication. Falls back to key file authentication if the agent is not available or authentication fails. +.TP +.BR \-P ", " \-\-password +Use password authentication. When this option is specified, bssh will +prompt for the password securely without echoing it to the terminal. +This is useful for systems that don't have SSH keys configured. + .TP .BR \-p ", " \-\-parallel " " \fIPARALLEL\fR Maximum parallel connections (default: 10) @@ -205,6 +213,20 @@ Use custom SSH key: Use SSH agent for authentication: .B bssh -A -c production "systemctl status" +.TP +Use password authentication: +.B bssh -P -H "user@host.com" "uptime" +.RS +Prompts for password interactively +.RE + +.TP +Use encrypted SSH key: +.B bssh -i ~/.ssh/encrypted_key -c production "df -h" +.RS +Automatically detects encrypted key and prompts for passphrase +.RE + .TP Save output to files: .B bssh --output-dir ./results -c production "ps aux" @@ -299,8 +321,9 @@ User configuration directory location SSH known hosts file for host key verification .TP -.I ~/.ssh/id_rsa -Default SSH private key +.I ~/.ssh/id_ed25519, ~/.ssh/id_rsa, ~/.ssh/id_ecdsa, ~/.ssh/id_dsa +Default SSH private keys (checked in order of preference). If a key is +encrypted, bssh will prompt for the passphrase. .TP .I $SSH_AUTH_SOCK diff --git a/src/cli.rs b/src/cli.rs index 77e03020..d3a31294 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -47,7 +47,7 @@ pub struct Cli { #[arg(short = 'u', long, help = "Default username for SSH connections")] pub user: Option, - #[arg(short = 'i', long, help = "SSH private key file path")] + #[arg(short = 'i', long, help = "SSH private key file path (prompts for passphrase if encrypted)")] pub identity: Option, #[arg( From a7341e22ea961fdc358ec489c9d2c6e29cbb930e Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Fri, 22 Aug 2025 21:54:16 +0900 Subject: [PATCH 3/3] fix: use current user instead of root as default username - Replace hardcoded 'root' fallback with current user detection - Try USER, USERNAME, LOGNAME environment variables first - Use whoami crate to get system username as last resort - Add test to verify current user is used correctly - Ensures non-root users don't accidentally connect as root --- Cargo.lock | 21 ++++++++++++++++++++- Cargo.toml | 1 + src/cli.rs | 6 +++++- src/config.rs | 28 ++++++++++++++++++++++++++-- src/node.rs | 29 ++++++++++++++++++++++++++++- src/ssh/client.rs | 7 +++---- 6 files changed, 83 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9d407dcb..2414309c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,7 +276,7 @@ dependencies = [ [[package]] name = "bssh" -version = "0.3.1" +version = "0.3.0" dependencies = [ "anyhow", "async-trait", @@ -298,6 +298,7 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", + "whoami", ] [[package]] @@ -1253,6 +1254,7 @@ checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3" dependencies = [ "bitflags", "libc", + "redox_syscall", ] [[package]] @@ -2536,6 +2538,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -2627,6 +2635,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "whoami" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +dependencies = [ + "libredox", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index b58bd352..ffc42057 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ directories = "6" dirs = "6.0" chrono = "0.4" glob = "0.3" +whoami = "1.5" [dev-dependencies] tempfile = "3" diff --git a/src/cli.rs b/src/cli.rs index d3a31294..b272a057 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -47,7 +47,11 @@ pub struct Cli { #[arg(short = 'u', long, help = "Default username for SSH connections")] pub user: Option, - #[arg(short = 'i', long, help = "SSH private key file path (prompts for passphrase if encrypted)")] + #[arg( + short = 'i', + long, + help = "SSH private key file path (prompts for passphrase if encrypted)" + )] pub identity: Option, #[arg( diff --git a/src/config.rs b/src/config.rs index 41ba98d9..206cfd59 100644 --- a/src/config.rs +++ b/src/config.rs @@ -103,7 +103,18 @@ impl Config { // Get current user as default let default_user = env::var("USER") .or_else(|_| env::var("USERNAME")) - .unwrap_or_else(|_| "root".to_string()); + .or_else(|_| env::var("LOGNAME")) + .unwrap_or_else(|_| { + // Try to get current user from system + #[cfg(unix)] + { + whoami::username() + } + #[cfg(not(unix))] + { + "user".to_string() + } + }); // Backend.AI multi-node clusters use port 2200 by default nodes.push(NodeConfig::Simple(format!("{default_user}@{host}:2200"))); @@ -215,7 +226,20 @@ impl Config { .or_else(|| cluster.defaults.user.as_ref().map(|u| expand_env_vars(u))) .or_else(|| self.defaults.user.as_ref().map(|u| expand_env_vars(u))) .unwrap_or_else(|| { - std::env::var("USER").unwrap_or_else(|_| "root".to_string()) + std::env::var("USER") + .or_else(|_| std::env::var("USERNAME")) + .or_else(|_| std::env::var("LOGNAME")) + .unwrap_or_else(|_| { + // Try to get current user from system + #[cfg(unix)] + { + whoami::username() + } + #[cfg(not(unix))] + { + "user".to_string() + } + }) }); let port = port diff --git a/src/node.rs b/src/node.rs index 3953b828..2c882429 100644 --- a/src/node.rs +++ b/src/node.rs @@ -61,7 +61,18 @@ impl Node { .unwrap_or_else(|| { std::env::var("USER") .or_else(|_| std::env::var("USERNAME")) - .unwrap_or_else(|_| "root".to_string()) + .or_else(|_| std::env::var("LOGNAME")) + .unwrap_or_else(|_| { + // Try to get current user from system + #[cfg(unix)] + { + whoami::username() + } + #[cfg(not(unix))] + { + "user".to_string() + } + }) }); Ok(Node { @@ -121,4 +132,20 @@ mod tests { let node = Node::parse("example.com", Some("default_user")).unwrap(); assert_eq!(node.username, "default_user"); } + + #[test] + fn test_parse_uses_current_user_when_no_default() { + // When no user is specified, it should use current user from environment + let node = Node::parse("example.com", None).unwrap(); + // Should not be "root" unless the current user is actually root + let current_user = std::env::var("USER") + .or_else(|_| std::env::var("USERNAME")) + .or_else(|_| std::env::var("LOGNAME")) + .unwrap_or_else(|_| whoami::username()); + assert_eq!(node.username, current_user); + // Specifically verify it doesn't default to root when we're not root + if current_user != "root" { + assert_ne!(node.username, "root"); + } + } } diff --git a/src/ssh/client.rs b/src/ssh/client.rs index 8b294e99..7c62f3d0 100644 --- a/src/ssh/client.rs +++ b/src/ssh/client.rs @@ -462,10 +462,9 @@ impl SshClient { || key_contents.contains("Proc-Type: 4,ENCRYPTED") { tracing::debug!("Detected encrypted SSH key, prompting for passphrase"); - let pass = rpassword::prompt_password(format!( - "Enter passphrase for key {key_path:?}: " - )) - .with_context(|| "Failed to read passphrase")?; + let pass = + rpassword::prompt_password(format!("Enter passphrase for key {key_path:?}: ")) + .with_context(|| "Failed to read passphrase")?; Some(pass) } else { None