diff --git a/src/commands/interactive/connection.rs b/src/commands/interactive/connection.rs index ea4ce61b..4b229798 100644 --- a/src/commands/interactive/connection.rs +++ b/src/commands/interactive/connection.rs @@ -70,13 +70,18 @@ impl InteractiveCommand { ) })?; - // Check if key authentication failed and password fallback is allowed + // Check if authentication failed and password fallback is allowed + // This matches SSH key failures as well as SSH agent authentication failures let result = match result { - Err(SshError::KeyAuthFailed) - if allow_password_fallback && atty::is(atty::Stream::Stdin) => - { + Err( + SshError::KeyAuthFailed + | SshError::AgentAuthenticationFailed + | SshError::AgentNoIdentities + | SshError::AgentConnectionFailed + | SshError::AgentRequestIdentitiesFailed, + ) if allow_password_fallback && atty::is(atty::Stream::Stdin) => { tracing::debug!( - "SSH key authentication failed for {username}@{host}:{port}, attempting password fallback" + "SSH authentication failed for {username}@{host}:{port}, attempting password fallback" ); // Prompt for password (matching OpenSSH behavior) @@ -459,3 +464,114 @@ impl InteractiveCommand { Ok(channel) } } + +/// Check if an SSH error indicates an authentication failure that should trigger password fallback. +/// +/// This function returns true for errors that occur when: +/// - SSH key authentication fails (server rejects the key) +/// - SSH agent authentication fails (agent has keys but server rejects them) +/// - SSH agent has no identities loaded +/// - SSH agent connection fails +/// - SSH agent identity request fails +/// +/// These are all cases where falling back to password authentication makes sense, +/// matching OpenSSH's behavior. +pub fn is_auth_error_for_password_fallback(error: &SshError) -> bool { + matches!( + error, + SshError::KeyAuthFailed + | SshError::AgentAuthenticationFailed + | SshError::AgentNoIdentities + | SshError::AgentConnectionFailed + | SshError::AgentRequestIdentitiesFailed + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_key_auth_failed_triggers_password_fallback() { + let error = SshError::KeyAuthFailed; + assert!( + is_auth_error_for_password_fallback(&error), + "KeyAuthFailed should trigger password fallback" + ); + } + + #[test] + fn test_agent_auth_failed_triggers_password_fallback() { + let error = SshError::AgentAuthenticationFailed; + assert!( + is_auth_error_for_password_fallback(&error), + "AgentAuthenticationFailed should trigger password fallback" + ); + } + + #[test] + fn test_agent_no_identities_triggers_password_fallback() { + let error = SshError::AgentNoIdentities; + assert!( + is_auth_error_for_password_fallback(&error), + "AgentNoIdentities should trigger password fallback" + ); + } + + #[test] + fn test_agent_connection_failed_triggers_password_fallback() { + let error = SshError::AgentConnectionFailed; + assert!( + is_auth_error_for_password_fallback(&error), + "AgentConnectionFailed should trigger password fallback" + ); + } + + #[test] + fn test_agent_request_identities_failed_triggers_password_fallback() { + let error = SshError::AgentRequestIdentitiesFailed; + assert!( + is_auth_error_for_password_fallback(&error), + "AgentRequestIdentitiesFailed should trigger password fallback" + ); + } + + #[test] + fn test_password_wrong_does_not_trigger_fallback() { + let error = SshError::PasswordWrong; + assert!( + !is_auth_error_for_password_fallback(&error), + "PasswordWrong should NOT trigger password fallback (already tried password)" + ); + } + + #[test] + fn test_server_check_failed_does_not_trigger_fallback() { + let error = SshError::ServerCheckFailed; + assert!( + !is_auth_error_for_password_fallback(&error), + "ServerCheckFailed should NOT trigger password fallback (host key issue)" + ); + } + + #[test] + fn test_io_error_does_not_trigger_fallback() { + let error = SshError::IoError(std::io::Error::new( + std::io::ErrorKind::ConnectionRefused, + "connection refused", + )); + assert!( + !is_auth_error_for_password_fallback(&error), + "IoError should NOT trigger password fallback (network issue)" + ); + } + + #[test] + fn test_keyboard_interactive_auth_failed_does_not_trigger_fallback() { + let error = SshError::KeyboardInteractiveAuthFailed; + assert!( + !is_auth_error_for_password_fallback(&error), + "KeyboardInteractiveAuthFailed should NOT trigger password fallback" + ); + } +} diff --git a/src/commands/interactive/mod.rs b/src/commands/interactive/mod.rs index d51c2430..8950e739 100644 --- a/src/commands/interactive/mod.rs +++ b/src/commands/interactive/mod.rs @@ -53,7 +53,7 @@ //! - `InteractiveResult`: Summary of interactive session execution mod commands; -mod connection; +pub mod connection; mod execution; mod multiplex; mod single_node; diff --git a/tests/password_fallback_test.rs b/tests/password_fallback_test.rs new file mode 100644 index 00000000..fffb6a62 --- /dev/null +++ b/tests/password_fallback_test.rs @@ -0,0 +1,107 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Tests for password fallback functionality in SSH connections. +//! +//! These tests verify that the password fallback mechanism correctly triggers +//! for various SSH authentication error types, including SSH agent errors. + +use bssh::commands::interactive::connection::is_auth_error_for_password_fallback; +use bssh::ssh::tokio_client::Error as SshError; + +/// Test that all SSH agent-related authentication failures trigger password fallback +#[test] +fn test_all_agent_errors_trigger_password_fallback() { + let agent_errors = vec![ + ( + SshError::AgentAuthenticationFailed, + "AgentAuthenticationFailed", + ), + (SshError::AgentNoIdentities, "AgentNoIdentities"), + (SshError::AgentConnectionFailed, "AgentConnectionFailed"), + ( + SshError::AgentRequestIdentitiesFailed, + "AgentRequestIdentitiesFailed", + ), + ]; + + for (error, name) in agent_errors { + assert!( + is_auth_error_for_password_fallback(&error), + "{} should trigger password fallback", + name + ); + } +} + +/// Test that key authentication failure triggers password fallback +#[test] +fn test_key_auth_failure_triggers_password_fallback() { + let error = SshError::KeyAuthFailed; + assert!( + is_auth_error_for_password_fallback(&error), + "KeyAuthFailed should trigger password fallback" + ); +} + +/// Test that non-authentication errors do NOT trigger password fallback +#[test] +fn test_non_auth_errors_do_not_trigger_fallback() { + let non_auth_errors: Vec<(SshError, &str)> = vec![ + (SshError::PasswordWrong, "PasswordWrong"), + (SshError::ServerCheckFailed, "ServerCheckFailed"), + (SshError::CommandDidntExit, "CommandDidntExit"), + ( + SshError::KeyboardInteractiveAuthFailed, + "KeyboardInteractiveAuthFailed", + ), + ( + SshError::IoError(std::io::Error::new( + std::io::ErrorKind::ConnectionRefused, + "connection refused", + )), + "IoError", + ), + ]; + + for (error, name) in non_auth_errors { + assert!( + !is_auth_error_for_password_fallback(&error), + "{} should NOT trigger password fallback", + name + ); + } +} + +/// Test that PasswordWrong specifically does not trigger fallback +/// (to prevent infinite password retry loops) +#[test] +fn test_password_wrong_prevents_infinite_loop() { + let error = SshError::PasswordWrong; + assert!( + !is_auth_error_for_password_fallback(&error), + "PasswordWrong must NOT trigger password fallback to prevent infinite retry loops" + ); +} + +/// Test that ServerCheckFailed (host key verification) does not trigger password fallback +/// (security: host key issues should not be bypassed) +#[test] +fn test_host_key_verification_not_bypassed() { + let error = SshError::ServerCheckFailed; + assert!( + !is_auth_error_for_password_fallback(&error), + "ServerCheckFailed must NOT trigger password fallback - host key verification is a security feature" + ); +}