diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 84c826f5..3a960cc8 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -178,6 +178,17 @@ MPI-compatible exit code handling: - Automatic main rank detection (Backend.AI integration) - Preserves actual exit codes (SIGSEGV=139, OOM=137, etc.) +### Shared Module + +Common utilities for code reuse between bssh client and potential server implementations: + +- **Validation**: Input validation for usernames, hostnames, paths with security checks +- **Rate Limiting**: Generic token bucket rate limiter for connection/auth throttling +- **Authentication Types**: Common auth result types and user info structures +- **Error Types**: Shared error types for validation, auth, connection, and rate limiting + +The `security` and `jump::rate_limiter` modules re-export from shared for backward compatibility. + ## Data Flow ### Command Execution Flow diff --git a/docs/architecture/README.md b/docs/architecture/README.md index c3f2a2b9..abd2fd9c 100644 --- a/docs/architecture/README.md +++ b/docs/architecture/README.md @@ -71,6 +71,8 @@ src/ ├── interactive/ → Interactive Mode ├── jump/ → Jump Host Support ├── forward/ → Port Forwarding +├── shared/ → Shared utilities (validation, rate limiting, auth types, errors) +├── security/ → Security utilities (re-exports from shared for compatibility) └── commands/ → Command Implementations ``` diff --git a/src/jump/chain.rs b/src/jump/chain.rs index 0922c3fd..200d7b13 100644 --- a/src/jump/chain.rs +++ b/src/jump/chain.rs @@ -145,7 +145,7 @@ impl JumpHostChain { /// * `max_burst` - Maximum number of connections allowed in a burst /// * `refill_rate` - Number of connections allowed per second (sustained rate) pub fn with_rate_limit(mut self, max_burst: u32, refill_rate: f64) -> Self { - self.rate_limiter = ConnectionRateLimiter::with_config(max_burst, refill_rate); + self.rate_limiter = ConnectionRateLimiter::with_simple_config(max_burst, refill_rate); self } @@ -354,7 +354,7 @@ impl JumpHostChain { // Apply rate limiting to prevent DoS attacks on jump hosts self.rate_limiter - .try_acquire(&jump_host.host) + .try_acquire(&jump_host.host.clone()) .await .with_context(|| format!("Rate limited for jump host {}", jump_host.host))?; diff --git a/src/jump/chain/chain_connection.rs b/src/jump/chain/chain_connection.rs index 35e0848a..8aa369d7 100644 --- a/src/jump/chain/chain_connection.rs +++ b/src/jump/chain/chain_connection.rs @@ -35,7 +35,7 @@ pub(super) async fn connect_direct( // Apply rate limiting to prevent DoS attacks rate_limiter - .try_acquire(host) + .try_acquire(&host.to_string()) .await .with_context(|| format!("Rate limited for host {host}"))?; diff --git a/src/jump/chain/tunnel.rs b/src/jump/chain/tunnel.rs index 02e58d7f..3c5c5e6d 100644 --- a/src/jump/chain/tunnel.rs +++ b/src/jump/chain/tunnel.rs @@ -46,7 +46,7 @@ pub(super) async fn connect_through_tunnel( // Apply rate limiting for intermediate jump hosts rate_limiter - .try_acquire(&jump_host.host) + .try_acquire(&jump_host.host.clone()) .await .with_context(|| format!("Rate limited for jump host {}", jump_host.host))?; @@ -184,7 +184,7 @@ pub(super) async fn connect_to_destination( // Apply rate limiting for final destination rate_limiter - .try_acquire(destination_host) + .try_acquire(&destination_host.to_string()) .await .with_context(|| format!("Rate limited for destination {destination_host}"))?; diff --git a/src/jump/mod.rs b/src/jump/mod.rs index a1491ac7..b411ebf8 100644 --- a/src/jump/mod.rs +++ b/src/jump/mod.rs @@ -25,10 +25,19 @@ //! * Connection reuse for multiple operations //! * Automatic retry with exponential backoff //! * Integration with existing host verification and authentication +//! +//! # Rate Limiter +//! +//! The rate limiter has been moved to `crate::shared::rate_limit` to enable +//! code reuse between the bssh client and server implementations. This module +//! continues to re-export it for backward compatibility. pub mod chain; pub mod connection; pub mod parser; + +// Keep the rate_limiter module for backward compatibility but it now re-exports +// from shared pub mod rate_limiter; pub use chain::{JumpConnection, JumpHostChain}; diff --git a/src/jump/rate_limiter.rs b/src/jump/rate_limiter.rs index eea21fbd..35ce4537 100644 --- a/src/jump/rate_limiter.rs +++ b/src/jump/rate_limiter.rs @@ -12,186 +12,78 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::{bail, Result}; -use std::collections::HashMap; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::sync::RwLock; -use tracing::warn; - -/// Token bucket rate limiter for connection attempts -/// -/// Prevents DoS attacks by limiting the rate of connection attempts -/// per host. Uses a token bucket algorithm with configurable capacity -/// and refill rate. -#[derive(Debug, Clone)] -pub struct ConnectionRateLimiter { - /// Token buckets per host - buckets: Arc>>, - /// Maximum tokens per bucket (burst capacity) - max_tokens: u32, - /// Tokens refilled per second - refill_rate: f64, - /// Duration after which inactive buckets are cleaned up - cleanup_after: Duration, -} - -#[derive(Debug)] -struct TokenBucket { - /// Current token count - tokens: f64, - /// Last refill timestamp - last_refill: Instant, - /// Last access timestamp (for cleanup) - last_access: Instant, -} - -impl ConnectionRateLimiter { - /// Create a new rate limiter with default settings - /// - /// Default: 10 connections burst, 2 connections/second sustained - pub fn new() -> Self { - Self { - buckets: Arc::new(RwLock::new(HashMap::new())), - max_tokens: 10, // Allow burst of 10 connections - refill_rate: 2.0, // 2 connections per second sustained - cleanup_after: Duration::from_secs(300), // Clean up after 5 minutes - } - } - - /// Create a new rate limiter with custom settings - pub fn with_config(max_tokens: u32, refill_rate: f64) -> Self { - Self { - buckets: Arc::new(RwLock::new(HashMap::new())), - max_tokens, - refill_rate, - cleanup_after: Duration::from_secs(300), - } - } - - /// Try to acquire a token for a connection attempt - /// - /// Returns Ok(()) if a token was acquired, or an error if rate limited - pub async fn try_acquire(&self, host: &str) -> Result<()> { - let mut buckets = self.buckets.write().await; - let now = Instant::now(); - - // Clean up old buckets periodically - if buckets.len() > 100 { - self.cleanup_old_buckets(&mut buckets, now); - } - - let bucket = buckets - .entry(host.to_string()) - .or_insert_with(|| TokenBucket { - tokens: self.max_tokens as f64, - last_refill: now, - last_access: now, - }); - - // Refill tokens based on time elapsed - let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); - let tokens_to_add = elapsed * self.refill_rate; - bucket.tokens = (bucket.tokens + tokens_to_add).min(self.max_tokens as f64); - bucket.last_refill = now; - bucket.last_access = now; - - // Try to consume a token - if bucket.tokens >= 1.0 { - bucket.tokens -= 1.0; - Ok(()) - } else { - let wait_time = (1.0 - bucket.tokens) / self.refill_rate; - warn!( - "Rate limit exceeded for host {}: wait {:.1}s before retry", - host, wait_time - ); - bail!( - "Connection rate limit exceeded for {host}. Please wait {wait_time:.1} seconds before retrying." - ) - } - } - - /// Check if a host is currently rate limited without consuming a token - pub async fn is_rate_limited(&self, host: &str) -> bool { - let buckets = self.buckets.read().await; - if let Some(bucket) = buckets.get(host) { - let now = Instant::now(); - let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); - let tokens_available = - (bucket.tokens + elapsed * self.refill_rate).min(self.max_tokens as f64); - tokens_available < 1.0 - } else { - false - } - } - - /// Clean up old token buckets that haven't been used recently - fn cleanup_old_buckets(&self, buckets: &mut HashMap, now: Instant) { - buckets.retain(|_host, bucket| now.duration_since(bucket.last_access) < self.cleanup_after); - } - - /// Reset rate limit for a specific host (useful for testing or admin override) - pub async fn reset_host(&self, host: &str) { - let mut buckets = self.buckets.write().await; - buckets.remove(host); - } - - /// Clear all rate limit data - pub async fn clear_all(&self) { - let mut buckets = self.buckets.write().await; - buckets.clear(); - } -} - -impl Default for ConnectionRateLimiter { - fn default() -> Self { - Self::new() - } -} +//! Token bucket rate limiter for connection attempts. +//! +//! This module re-exports the rate limiter from the shared module for +//! backward compatibility. New code should prefer importing directly from +//! `crate::shared::rate_limit`. +//! +//! # Migration Note +//! +//! The rate limiter has been moved to `crate::shared::rate_limit` and +//! generalized to work with any hashable key type. This module continues +//! to export `ConnectionRateLimiter` (which is `RateLimiter`) for +//! backward compatibility. +//! +//! # Examples +//! +//! ```rust +//! // Old style (still works) +//! use bssh::jump::rate_limiter::ConnectionRateLimiter; +//! +//! // New style (preferred for new code) +//! use bssh::shared::rate_limit::{RateLimiter, RateLimitConfig}; +//! ``` + +// Re-export the ConnectionRateLimiter type alias for backward compatibility +pub use crate::shared::rate_limit::ConnectionRateLimiter; + +// Also re-export RateLimitConfig for users who want to configure the limiter +pub use crate::shared::rate_limit::RateLimitConfig; #[cfg(test)] mod tests { use super::*; + use std::time::Duration; #[tokio::test] async fn test_rate_limiter_allows_burst() { - let limiter = ConnectionRateLimiter::with_config(3, 1.0); + let limiter = ConnectionRateLimiter::with_simple_config(3, 1.0); // Should allow 3 connections in burst - assert!(limiter.try_acquire("test.com").await.is_ok()); - assert!(limiter.try_acquire("test.com").await.is_ok()); - assert!(limiter.try_acquire("test.com").await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); // 4th should fail - assert!(limiter.try_acquire("test.com").await.is_err()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_err()); } #[tokio::test] async fn test_rate_limiter_refills() { - let limiter = ConnectionRateLimiter::with_config(2, 10.0); // Fast refill for testing + let limiter = ConnectionRateLimiter::with_simple_config(2, 10.0); // Fast refill for testing // Use up tokens - assert!(limiter.try_acquire("test.com").await.is_ok()); - assert!(limiter.try_acquire("test.com").await.is_ok()); - assert!(limiter.try_acquire("test.com").await.is_err()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_err()); // Wait for refill tokio::time::sleep(Duration::from_millis(150)).await; // Should have refilled - assert!(limiter.try_acquire("test.com").await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); } #[tokio::test] async fn test_rate_limiter_per_host() { - let limiter = ConnectionRateLimiter::with_config(1, 1.0); + let limiter = ConnectionRateLimiter::with_simple_config(1, 1.0); // Different hosts should have separate buckets - assert!(limiter.try_acquire("host1.com").await.is_ok()); - assert!(limiter.try_acquire("host2.com").await.is_ok()); + assert!(limiter.try_acquire(&"host1.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"host2.com".to_string()).await.is_ok()); // But same host should be limited - assert!(limiter.try_acquire("host1.com").await.is_err()); + assert!(limiter.try_acquire(&"host1.com".to_string()).await.is_err()); } } diff --git a/src/lib.rs b/src/lib.rs index 8a179a51..f876bfbb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ pub mod jump; pub mod node; pub mod pty; pub mod security; +pub mod shared; pub mod ssh; pub mod ui; pub mod utils; diff --git a/src/security/mod.rs b/src/security/mod.rs index 6fd8a8d3..7adf3f24 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -14,12 +14,21 @@ //! Security utilities for validating and sanitizing user input and handling //! sensitive data securely. +//! +//! # Re-exports +//! +//! This module re-exports validation functions from the shared module for +//! backward compatibility. New code should prefer importing directly from +//! `crate::shared::validation`. mod sudo; -mod validation; -// Re-export validation functions -pub use validation::{ +// Keep the validation module for backward compatibility but it now re-exports +// from shared +pub mod validation; + +// Re-export validation functions from shared module for backward compatibility +pub use crate::shared::validation::{ sanitize_error_message, validate_hostname, validate_local_path, validate_remote_path, validate_username, }; diff --git a/src/security/validation.rs b/src/security/validation.rs index fa5cf717..bdca0daf 100644 --- a/src/security/validation.rs +++ b/src/security/validation.rs @@ -12,275 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Security utilities for validating and sanitizing user input - -use anyhow::{Context, Result}; -use std::path::{Path, PathBuf}; - -/// Validate and sanitize a local file path to prevent path traversal attacks -/// -/// This function ensures: -/// - No path traversal sequences (..) -/// - No double slashes (//) -/// - Path is canonical and resolved -/// - No symlink attacks -pub fn validate_local_path(path: &Path) -> Result { - // Convert to string to check for dangerous patterns - let path_str = path.to_string_lossy(); - - // Check for path traversal attempts - if path_str.contains("..") { - anyhow::bail!("Path traversal detected: path contains '..'"); - } - - // Check for double slashes - if path_str.contains("//") { - anyhow::bail!("Invalid path: contains double slashes"); - } - - // Get canonical path (resolves symlinks, .., ., etc.) - // This will fail if the path doesn't exist yet, so we handle that case - let canonical = if path.exists() { - path.canonicalize() - .with_context(|| format!("Failed to canonicalize path: {path:?}"))? - } else { - // For non-existent paths, validate the parent directory - if let Some(parent) = path.parent() { - if parent.as_os_str().is_empty() { - // Parent is empty, use current directory - std::env::current_dir() - .with_context(|| "Failed to get current directory")? - .join(path) - } else if parent.exists() { - let canonical_parent = parent - .canonicalize() - .with_context(|| format!("Failed to canonicalize parent path: {parent:?}"))?; - - // Get the file name - let file_name = path - .file_name() - .ok_or_else(|| anyhow::anyhow!("Invalid path: no file name component"))?; - - // Validate file name doesn't contain path separators - let file_name_str = file_name.to_string_lossy(); - if file_name_str.contains('/') || file_name_str.contains('\\') { - anyhow::bail!("Invalid file name: contains path separator"); - } - - canonical_parent.join(file_name) - } else { - // Parent doesn't exist, recursively create and validate - validate_local_path(parent)?; - validate_local_path(path)? - } - } else { - // No parent, treat as relative to current directory - std::env::current_dir() - .with_context(|| "Failed to get current directory")? - .join(path) - } - }; - - Ok(canonical) -} - -/// Validate a remote path string to prevent injection attacks -/// -/// This function ensures: -/// - No shell metacharacters that could cause command injection -/// - No path traversal sequences -/// - Only valid characters for file paths -pub fn validate_remote_path(path: &str) -> Result { - // Check for empty path - if path.is_empty() { - anyhow::bail!("Remote path cannot be empty"); - } - - // Check path length to prevent DoS - const MAX_PATH_LENGTH: usize = 4096; - if path.len() > MAX_PATH_LENGTH { - anyhow::bail!("Remote path too long (max {MAX_PATH_LENGTH} characters)"); - } - - // Check for shell metacharacters that could cause injection - const DANGEROUS_CHARS: &[char] = &[ - ';', '&', '|', '`', '$', '(', ')', '{', '}', '<', '>', '\n', '\r', '\0', '!', '*', '?', - '[', ']', // Shell wildcards that could cause issues - ]; - - for &ch in DANGEROUS_CHARS { - if path.contains(ch) { - anyhow::bail!("Remote path contains invalid character: '{ch}'"); - } - } - - // Check for command substitution patterns - if path.contains("$(") || path.contains("${") || path.contains("`)") { - anyhow::bail!("Remote path contains potential command substitution"); - } - - // Check for path traversal - if path.contains("../") || path.starts_with("..") || path.ends_with("..") { - anyhow::bail!("Remote path contains path traversal sequence"); - } - - // Check for double slashes (could indicate protocol bypasses) - if path.contains("//") && !path.starts_with("//") { - anyhow::bail!("Remote path contains double slashes"); - } - - // Validate that path contains only allowed characters - // Allow: alphanumeric, spaces, and common path characters - let valid_chars = path.chars().all(|c| { - c.is_ascii_alphanumeric() - || c == '/' - || c == '\\' - || c == '.' - || c == '-' - || c == '_' - || c == ' ' - || c == '~' - || c == '=' - || c == ',' - || c == ':' - || c == '@' - }); - - if !valid_chars { - anyhow::bail!("Remote path contains invalid characters"); - } - - Ok(path.to_string()) -} - -/// Sanitize a hostname to prevent injection attacks -pub fn validate_hostname(hostname: &str) -> Result { - // Check for empty hostname - if hostname.is_empty() { - anyhow::bail!("Hostname cannot be empty"); - } - - // Check hostname length (RFC 1123) - const MAX_HOSTNAME_LENGTH: usize = 253; - if hostname.len() > MAX_HOSTNAME_LENGTH { - anyhow::bail!("Hostname too long (max {MAX_HOSTNAME_LENGTH} characters)"); - } - - // Validate hostname format (RFC 1123) - // Allow alphanumeric, dots, hyphens, and colons (for IPv6) - let valid_chars = hostname.chars().all(|c| { - c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == ':' || c == '[' || c == ']' - }); - - if !valid_chars { - anyhow::bail!("Hostname contains invalid characters"); - } - - // Check for suspicious patterns - if hostname.contains("..") || hostname.contains("--") { - anyhow::bail!("Hostname contains suspicious repeated characters"); - } - - Ok(hostname.to_string()) -} - -/// Validate a username to prevent injection attacks -pub fn validate_username(username: &str) -> Result { - // Check for empty username - if username.is_empty() { - anyhow::bail!("Username cannot be empty"); - } - - // Check username length - const MAX_USERNAME_LENGTH: usize = 32; - if username.len() > MAX_USERNAME_LENGTH { - anyhow::bail!("Username too long (max {MAX_USERNAME_LENGTH} characters)"); - } - - // Validate username format (POSIX-compliant) - // Allow alphanumeric, underscore, hyphen, and dot - let valid_chars = username - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.'); - - if !valid_chars { - anyhow::bail!("Username contains invalid characters"); - } - - // Username should not start with a hyphen - if username.starts_with('-') { - anyhow::bail!("Username cannot start with a hyphen"); - } - - Ok(username.to_string()) -} - -/// Sanitize error messages to prevent information leakage -/// -/// This function redacts sensitive information like usernames, hostnames, -/// and ports from error messages to prevent information disclosure. -pub fn sanitize_error_message(message: &str) -> String { - // Replace specific user mentions with generic text - let mut sanitized = message.to_string(); - - // Remove specific usernames (format: user 'username') - if let Some(start) = sanitized.find("user '") { - if let Some(end) = sanitized[start + 6..].find('\'') { - let before = &sanitized[..start + 5]; - let after = &sanitized[start + 6 + end + 1..]; - sanitized = format!("{before}{after}"); - } - } - - // Remove hostname:port combinations - // Match patterns like "on hostname:port" or "to hostname:port" - let re_patterns = [ - r" on [a-zA-Z0-9\.\-]+:[0-9]+", - r" to [a-zA-Z0-9\.\-]+:[0-9]+", - r" at [a-zA-Z0-9\.\-]+:[0-9]+", - r" from [a-zA-Z0-9\.\-]+:[0-9]+", - ]; - - for _pattern in &re_patterns { - // Simple pattern matching without regex for security - // This is a simplified approach - in production, consider using a proper regex library - if sanitized.contains(" on ") - || sanitized.contains(" to ") - || sanitized.contains(" at ") - || sanitized.contains(" from ") - { - // Replace with generic message - sanitized = sanitized - .replace(" on ", " on ") - .replace(" to ", " to ") - .replace(" at ", " at ") - .replace(" from ", " from "); - } - } - - // Remove any remaining IP addresses - // Simple check for IPv4 pattern - let parts: Vec<&str> = sanitized.split_whitespace().collect(); - let mut result_parts = Vec::new(); - - for part in parts { - if part.split('.').count() == 4 - && part - .split('.') - .all(|p| p.parse::().is_ok() || p.contains(':')) - { - result_parts.push(""); - } else { - result_parts.push(part); - } - } - - result_parts.join(" ") -} - +//! Security utilities for validating and sanitizing user input. +//! +//! This module re-exports validation functions from the shared module for +//! backward compatibility. New code should prefer importing directly from +//! `crate::shared::validation`. +//! +//! # Migration Note +//! +//! The validation utilities have been moved to `crate::shared::validation` +//! to enable code reuse between the bssh client and server implementations. +//! This module continues to work for backward compatibility. + +// Re-export all validation functions from the shared module +pub use crate::shared::validation::{ + sanitize_error_message, validate_hostname, validate_local_path, validate_remote_path, + validate_username, +}; + +// Re-export tests to ensure they still run #[cfg(test)] mod tests { use super::*; + use std::path::Path; #[test] fn test_validate_local_path() { diff --git a/src/shared/auth_types.rs b/src/shared/auth_types.rs new file mode 100644 index 00000000..69852075 --- /dev/null +++ b/src/shared/auth_types.rs @@ -0,0 +1,358 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Shared authentication types for client and server implementations. +//! +//! This module provides common types used in authentication flows that can +//! be shared between the bssh client and server implementations. +//! +//! # Types +//! +//! - [`AuthResult`]: The outcome of an authentication attempt +//! - [`UserInfo`]: Information about an authenticated user +//! - [`AuthMethod`]: Available authentication methods + +use std::path::PathBuf; + +/// The result of an authentication attempt. +/// +/// This enum represents the three possible outcomes of attempting to +/// authenticate a user, following the SSH protocol semantics. +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::auth_types::AuthResult; +/// +/// fn check_password(username: &str, password: &str) -> AuthResult { +/// if password == "secret" { +/// AuthResult::Accept +/// } else { +/// AuthResult::Reject +/// } +/// } +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum AuthResult { + /// Authentication succeeded - user is fully authenticated. + Accept, + + /// Authentication failed - access denied. + #[default] + Reject, + + /// Partial authentication success - more methods required. + /// + /// This is used for multi-factor authentication scenarios where + /// one authentication method succeeded but additional methods + /// are required to complete authentication. + Partial { + /// The list of remaining authentication methods the user must complete. + remaining_methods: Vec, + }, +} + +impl AuthResult { + /// Returns `true` if the authentication was fully successful. + pub fn is_accepted(&self) -> bool { + matches!(self, AuthResult::Accept) + } + + /// Returns `true` if the authentication was rejected. + pub fn is_rejected(&self) -> bool { + matches!(self, AuthResult::Reject) + } + + /// Returns `true` if partial authentication occurred. + pub fn is_partial(&self) -> bool { + matches!(self, AuthResult::Partial { .. }) + } + + /// Creates a partial result with the specified remaining methods. + /// + /// # Arguments + /// + /// * `methods` - An iterator of method names that are still required + /// + /// # Examples + /// + /// ``` + /// use bssh::shared::auth_types::AuthResult; + /// + /// let result = AuthResult::partial(["keyboard-interactive", "publickey"]); + /// assert!(result.is_partial()); + /// ``` + pub fn partial(methods: I) -> Self + where + I: IntoIterator, + S: Into, + { + AuthResult::Partial { + remaining_methods: methods.into_iter().map(Into::into).collect(), + } + } +} + +/// Information about an authenticated user. +/// +/// This struct contains user information that is commonly needed after +/// successful authentication, both for client-side display and server-side +/// session setup. +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::auth_types::UserInfo; +/// use std::path::PathBuf; +/// +/// let user = UserInfo::new("johndoe") +/// .with_home_dir("/home/johndoe") +/// .with_shell("/bin/bash") +/// .with_uid(1000) +/// .with_gid(1000); +/// +/// assert_eq!(user.username, "johndoe"); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UserInfo { + /// The username of the authenticated user. + pub username: String, + + /// The user's home directory. + pub home_dir: PathBuf, + + /// The user's default shell. + pub shell: PathBuf, + + /// The user's numeric user ID (Unix-specific). + pub uid: Option, + + /// The user's primary group ID (Unix-specific). + pub gid: Option, + + /// Additional group IDs the user belongs to (Unix-specific). + pub groups: Vec, + + /// Display name or full name of the user. + pub display_name: Option, +} + +impl UserInfo { + /// Create a new UserInfo with just a username. + /// + /// Other fields are initialized to sensible defaults: + /// - home_dir: `/home/` on Unix, empty on other platforms + /// - shell: `/bin/sh` on Unix, empty on other platforms + /// - uid/gid: None + /// + /// # Arguments + /// + /// * `username` - The username + /// + /// # Examples + /// + /// ``` + /// use bssh::shared::auth_types::UserInfo; + /// + /// let user = UserInfo::new("alice"); + /// assert_eq!(user.username, "alice"); + /// ``` + pub fn new(username: impl Into) -> Self { + let username = username.into(); + + #[cfg(unix)] + let (home_dir, shell) = ( + PathBuf::from(format!("/home/{username}")), + PathBuf::from("/bin/sh"), + ); + + #[cfg(not(unix))] + let (home_dir, shell) = (PathBuf::new(), PathBuf::new()); + + Self { + username, + home_dir, + shell, + uid: None, + gid: None, + groups: Vec::new(), + display_name: None, + } + } + + /// Set the home directory. + pub fn with_home_dir(mut self, path: impl Into) -> Self { + self.home_dir = path.into(); + self + } + + /// Set the default shell. + pub fn with_shell(mut self, path: impl Into) -> Self { + self.shell = path.into(); + self + } + + /// Set the user ID. + pub fn with_uid(mut self, uid: u32) -> Self { + self.uid = Some(uid); + self + } + + /// Set the primary group ID. + pub fn with_gid(mut self, gid: u32) -> Self { + self.gid = Some(gid); + self + } + + /// Set the additional group IDs. + pub fn with_groups(mut self, groups: impl Into>) -> Self { + self.groups = groups.into(); + self + } + + /// Set the display name. + pub fn with_display_name(mut self, name: impl Into) -> Self { + self.display_name = Some(name.into()); + self + } +} + +/// Common SSH authentication method identifiers. +/// +/// These constants represent the standard SSH authentication method names +/// as defined in RFC 4252. +pub mod auth_method_names { + /// Password authentication (RFC 4252) + pub const PASSWORD: &str = "password"; + + /// Public key authentication (RFC 4252) + pub const PUBLICKEY: &str = "publickey"; + + /// Keyboard-interactive authentication (RFC 4256) + pub const KEYBOARD_INTERACTIVE: &str = "keyboard-interactive"; + + /// Host-based authentication (RFC 4252) + pub const HOSTBASED: &str = "hostbased"; + + /// No authentication required + pub const NONE: &str = "none"; +} + +/// Server host key verification methods. +/// +/// These methods control how the client verifies the server's host key +/// during connection. This type is re-exported from the authentication +/// module for convenience. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] +#[non_exhaustive] +pub enum ServerCheckMethod { + /// No verification - accept any host key (insecure, for testing only). + NoCheck, + + /// Verify against a specific base64 encoded public key. + PublicKey(String), + + /// Verify against a public key file. + PublicKeyFile(String), + + /// Use default known_hosts file (~/.ssh/known_hosts). + #[default] + DefaultKnownHostsFile, + + /// Use a specific known_hosts file path. + KnownHostsFile(String), +} + +impl ServerCheckMethod { + /// Create a ServerCheckMethod from a base64 encoded public key. + /// + /// # Arguments + /// + /// * `key` - The base64 encoded public key + pub fn with_public_key(key: impl Into) -> Self { + Self::PublicKey(key.into()) + } + + /// Create a ServerCheckMethod from a public key file path. + /// + /// # Arguments + /// + /// * `path` - Path to the public key file + pub fn with_public_key_file(path: impl Into) -> Self { + Self::PublicKeyFile(path.into()) + } + + /// Create a ServerCheckMethod from a known_hosts file path. + /// + /// # Arguments + /// + /// * `path` - Path to the known_hosts file + pub fn with_known_hosts_file(path: impl Into) -> Self { + Self::KnownHostsFile(path.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_auth_result_states() { + let accept = AuthResult::Accept; + assert!(accept.is_accepted()); + assert!(!accept.is_rejected()); + assert!(!accept.is_partial()); + + let reject = AuthResult::Reject; + assert!(!reject.is_accepted()); + assert!(reject.is_rejected()); + assert!(!reject.is_partial()); + + let partial = AuthResult::partial(["password", "publickey"]); + assert!(!partial.is_accepted()); + assert!(!partial.is_rejected()); + assert!(partial.is_partial()); + } + + #[test] + fn test_user_info_builder() { + let user = UserInfo::new("testuser") + .with_home_dir("/custom/home") + .with_shell("/bin/zsh") + .with_uid(1001) + .with_gid(1001) + .with_groups(vec![100, 101]) + .with_display_name("Test User"); + + assert_eq!(user.username, "testuser"); + assert_eq!(user.home_dir, PathBuf::from("/custom/home")); + assert_eq!(user.shell, PathBuf::from("/bin/zsh")); + assert_eq!(user.uid, Some(1001)); + assert_eq!(user.gid, Some(1001)); + assert_eq!(user.groups, vec![100, 101]); + assert_eq!(user.display_name, Some("Test User".to_string())); + } + + #[test] + fn test_server_check_method() { + let default = ServerCheckMethod::default(); + assert_eq!(default, ServerCheckMethod::DefaultKnownHostsFile); + + let key = ServerCheckMethod::with_public_key("ssh-rsa AAAA..."); + assert!(matches!(key, ServerCheckMethod::PublicKey(_))); + + let file = ServerCheckMethod::with_known_hosts_file("/path/to/known_hosts"); + assert!(matches!(file, ServerCheckMethod::KnownHostsFile(_))); + } +} diff --git a/src/shared/error.rs b/src/shared/error.rs new file mode 100644 index 00000000..cecb1f8c --- /dev/null +++ b/src/shared/error.rs @@ -0,0 +1,341 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Shared error types for client and server implementations. +//! +//! This module provides common error types that can be used by both +//! the bssh client and server implementations. +//! +//! # Error Categories +//! +//! - [`ValidationError`]: Input validation failures +//! - [`AuthError`]: Authentication-related errors +//! - [`ConnectionError`]: Connection and network errors +//! - [`RateLimitError`]: Rate limiting errors + +use std::fmt; +use std::io; + +/// Error type for input validation failures. +/// +/// This error is returned when user input fails validation checks. +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::error::ValidationError; +/// +/// let err = ValidationError::new("username", "contains invalid characters"); +/// assert!(err.to_string().contains("username")); +/// ``` +#[derive(Debug, Clone)] +pub struct ValidationError { + /// The field or input that failed validation + pub field: String, + /// Description of why validation failed + pub message: String, +} + +impl ValidationError { + /// Create a new validation error. + /// + /// # Arguments + /// + /// * `field` - The name of the field that failed validation + /// * `message` - Description of the validation failure + pub fn new(field: impl Into, message: impl Into) -> Self { + Self { + field: field.into(), + message: message.into(), + } + } + + /// Create an error for an empty field. + pub fn empty(field: impl Into) -> Self { + let field = field.into(); + Self { + message: format!("{field} cannot be empty"), + field, + } + } + + /// Create an error for a field that is too long. + pub fn too_long(field: impl Into, max_length: usize) -> Self { + let field = field.into(); + Self { + message: format!("{field} exceeds maximum length of {max_length}"), + field, + } + } + + /// Create an error for invalid characters. + pub fn invalid_characters(field: impl Into) -> Self { + let field = field.into(); + Self { + message: format!("{field} contains invalid characters"), + field, + } + } +} + +impl fmt::Display for ValidationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Validation error for '{}': {}", self.field, self.message) + } +} + +impl std::error::Error for ValidationError {} + +/// Error type for authentication failures. +/// +/// This enum represents various authentication-related errors that can +/// occur during the SSH authentication process. +#[derive(Debug)] +pub enum AuthError { + /// Invalid credentials (wrong password, invalid key, etc.) + InvalidCredentials, + + /// Authentication method not supported by server + MethodNotSupported(String), + + /// User not found or not allowed + UserNotAllowed(String), + + /// Account is locked or disabled + AccountLocked(String), + + /// Too many authentication attempts + TooManyAttempts, + + /// Authentication timeout + Timeout, + + /// SSH agent not available + AgentNotAvailable, + + /// No identities available in SSH agent + NoIdentities, + + /// Key file not found or unreadable + KeyFileError(String), + + /// Key format invalid or corrupted + KeyInvalid(String), + + /// Passphrase required but not provided + PassphraseRequired, + + /// Passphrase incorrect + PassphraseIncorrect, + + /// Server rejected the connection + ServerRejected(String), + + /// Internal error during authentication + Internal(String), +} + +impl fmt::Display for AuthError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthError::InvalidCredentials => write!(f, "Invalid credentials"), + AuthError::MethodNotSupported(method) => { + write!(f, "Authentication method '{method}' not supported") + } + AuthError::UserNotAllowed(_) => { + write!(f, "User is not allowed to connect") + } + AuthError::AccountLocked(_) => write!(f, "Account is locked"), + AuthError::TooManyAttempts => write!(f, "Too many authentication attempts"), + AuthError::Timeout => write!(f, "Authentication timed out"), + AuthError::AgentNotAvailable => write!(f, "SSH agent not available"), + AuthError::NoIdentities => write!(f, "No identities available in SSH agent"), + AuthError::KeyFileError(path) => write!(f, "Cannot read key file: {path}"), + AuthError::KeyInvalid(reason) => write!(f, "Invalid key: {reason}"), + AuthError::PassphraseRequired => write!(f, "Passphrase required for key"), + AuthError::PassphraseIncorrect => write!(f, "Incorrect passphrase"), + AuthError::ServerRejected(reason) => write!(f, "Server rejected: {reason}"), + AuthError::Internal(reason) => write!(f, "Internal authentication error: {reason}"), + } + } +} + +impl std::error::Error for AuthError {} + +/// Error type for connection failures. +/// +/// This enum represents various connection-related errors that can occur +/// during SSH connection establishment. +#[derive(Debug)] +pub enum ConnectionError { + /// Could not resolve hostname + DnsResolutionFailed(String), + + /// Connection refused by server + ConnectionRefused(String), + + /// Connection timed out + Timeout(String), + + /// Network unreachable + NetworkUnreachable(String), + + /// Host unreachable + HostUnreachable(String), + + /// Server closed connection unexpectedly + ConnectionClosed(String), + + /// Protocol version mismatch + ProtocolMismatch(String), + + /// Host key verification failed + HostKeyVerificationFailed(String), + + /// Rate limited + RateLimited(String), + + /// TLS/SSL error + TlsError(String), + + /// IO error + Io(io::Error), + + /// Other error + Other(String), +} + +impl fmt::Display for ConnectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ConnectionError::DnsResolutionFailed(host) => { + write!(f, "Could not resolve hostname: {host}") + } + ConnectionError::ConnectionRefused(host) => { + write!(f, "Connection refused by {host}") + } + ConnectionError::Timeout(msg) => write!(f, "Connection timed out: {msg}"), + ConnectionError::NetworkUnreachable(msg) => write!(f, "Network unreachable: {msg}"), + ConnectionError::HostUnreachable(host) => write!(f, "Host unreachable: {host}"), + ConnectionError::ConnectionClosed(msg) => write!(f, "Connection closed: {msg}"), + ConnectionError::ProtocolMismatch(msg) => write!(f, "Protocol mismatch: {msg}"), + ConnectionError::HostKeyVerificationFailed(msg) => { + write!(f, "Host key verification failed: {msg}") + } + ConnectionError::RateLimited(msg) => write!(f, "Rate limited: {msg}"), + ConnectionError::TlsError(msg) => write!(f, "TLS error: {msg}"), + ConnectionError::Io(err) => write!(f, "IO error: {err}"), + ConnectionError::Other(msg) => write!(f, "Connection error: {msg}"), + } + } +} + +impl std::error::Error for ConnectionError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + ConnectionError::Io(err) => Some(err), + _ => None, + } + } +} + +impl From for ConnectionError { + fn from(err: io::Error) -> Self { + ConnectionError::Io(err) + } +} + +/// Error type for rate limiting. +/// +/// This error is returned when a rate limit is exceeded. +#[derive(Debug, Clone)] +pub struct RateLimitError { + /// The identifier that was rate limited (e.g., hostname, IP) + pub identifier: String, + /// Estimated wait time in seconds before retry + pub wait_seconds: f64, +} + +impl RateLimitError { + /// Create a new rate limit error. + /// + /// # Arguments + /// + /// * `identifier` - The identifier that was rate limited + /// * `wait_seconds` - Estimated wait time before retry + pub fn new(identifier: impl Into, wait_seconds: f64) -> Self { + Self { + identifier: identifier.into(), + wait_seconds, + } + } +} + +impl fmt::Display for RateLimitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Rate limit exceeded for '{}'. Please wait {:.1} seconds before retrying.", + self.identifier, self.wait_seconds + ) + } +} + +impl std::error::Error for RateLimitError {} + +/// A result type using our shared error types. +pub type SharedResult = Result>; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validation_error() { + let err = ValidationError::new("username", "contains spaces"); + assert_eq!(err.field, "username"); + assert!(err.to_string().contains("Validation error")); + + let empty_err = ValidationError::empty("password"); + assert!(empty_err.message.contains("cannot be empty")); + + let long_err = ValidationError::too_long("hostname", 253); + assert!(long_err.message.contains("253")); + } + + #[test] + fn test_auth_error_display() { + let err = AuthError::InvalidCredentials; + assert!(err.to_string().contains("Invalid credentials")); + + let method_err = AuthError::MethodNotSupported("kerberos".to_string()); + assert!(method_err.to_string().contains("kerberos")); + } + + #[test] + fn test_connection_error() { + let err = ConnectionError::ConnectionRefused("example.com:22".to_string()); + assert!(err.to_string().contains("Connection refused")); + + let io_err = ConnectionError::from(io::Error::new(io::ErrorKind::NotFound, "test")); + assert!(io_err.to_string().contains("IO error")); + } + + #[test] + fn test_rate_limit_error() { + let err = RateLimitError::new("192.168.1.1", 5.5); + assert!(err.to_string().contains("192.168.1.1")); + assert!(err.to_string().contains("5.5")); + } +} diff --git a/src/shared/mod.rs b/src/shared/mod.rs new file mode 100644 index 00000000..af552ede --- /dev/null +++ b/src/shared/mod.rs @@ -0,0 +1,72 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Shared module for code reuse between bssh client and server. +//! +//! This module contains utilities and types that are used by both the +//! bssh SSH client and the bssh-server implementations. By centralizing +//! this shared code, we ensure: +//! +//! - Consistent behavior between client and server +//! - No code duplication +//! - Easier maintenance +//! +//! # Modules +//! +//! - [`validation`]: Input validation utilities for usernames, hostnames, paths +//! - [`rate_limit`]: Generic token bucket rate limiter +//! - [`auth_types`]: Common authentication types and results +//! - [`error`]: Shared error types +//! +//! # Usage +//! +//! The shared module is designed to be used transparently by other modules. +//! For backward compatibility, the existing modules (`security`, `jump`) +//! re-export these shared utilities, so existing code continues to work. +//! +//! ## Direct Usage +//! +//! ```rust +//! use bssh::shared::validation::validate_hostname; +//! use bssh::shared::rate_limit::RateLimiter; +//! use bssh::shared::auth_types::{AuthResult, UserInfo}; +//! +//! // Validate a hostname +//! let hostname = validate_hostname("example.com").unwrap(); +//! +//! // Use the generic rate limiter +//! let limiter: RateLimiter = RateLimiter::new(); +//! ``` +//! +//! ## Via Re-exports (Backward Compatible) +//! +//! ```rust +//! // These continue to work as before +//! use bssh::security::validate_hostname; +//! use bssh::jump::rate_limiter::ConnectionRateLimiter; +//! ``` + +pub mod auth_types; +pub mod error; +pub mod rate_limit; +pub mod validation; + +// Re-export commonly used items at the module level for convenience +pub use auth_types::{AuthResult, ServerCheckMethod, UserInfo}; +pub use error::{AuthError, ConnectionError, RateLimitError, ValidationError}; +pub use rate_limit::{ConnectionRateLimiter, RateLimitConfig, RateLimiter}; +pub use validation::{ + sanitize_error_message, validate_hostname, validate_local_path, validate_remote_path, + validate_username, +}; diff --git a/src/shared/rate_limit.rs b/src/shared/rate_limit.rs new file mode 100644 index 00000000..acff82ae --- /dev/null +++ b/src/shared/rate_limit.rs @@ -0,0 +1,505 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Generic rate limiting using the token bucket algorithm. +//! +//! This module provides a reusable rate limiter that can be used for: +//! - Client: Connection attempt rate limiting +//! - Server: Authentication attempt rate limiting (fail2ban-like) +//! +//! # Token Bucket Algorithm +//! +//! The token bucket algorithm allows for bursting while maintaining +//! a sustained rate limit. Each bucket: +//! - Has a maximum capacity (burst size) +//! - Refills at a configured rate +//! - Requires one token per operation +//! +//! # Examples +//! +//! ``` +//! use bssh::shared::rate_limit::{RateLimiter, RateLimitConfig}; +//! use std::time::Duration; +//! +//! // Create a rate limiter for string keys (e.g., hostnames) +//! let config = RateLimitConfig::new(10, 2.0, Duration::from_secs(300)); +//! let limiter: RateLimiter = RateLimiter::with_config(config); +//! +//! // Use with different key types +//! // For IP addresses: RateLimiter +//! // For user IDs: RateLimiter +//! ``` + +use anyhow::{bail, Result}; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; +use tracing::warn; + +/// Configuration for rate limiting. +/// +/// This struct defines the parameters for the token bucket algorithm: +/// - `max_tokens`: Maximum tokens (burst capacity) +/// - `refill_rate`: Tokens added per second +/// - `cleanup_after`: Duration after which inactive buckets are removed +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + /// Maximum tokens per bucket (burst capacity) + pub max_tokens: u32, + /// Tokens refilled per second + pub refill_rate: f64, + /// Duration after which inactive buckets are cleaned up + pub cleanup_after: Duration, +} + +impl RateLimitConfig { + /// Create a new rate limit configuration. + /// + /// # Arguments + /// + /// * `max_tokens` - Maximum tokens per bucket (burst capacity) + /// * `refill_rate` - Tokens refilled per second (sustained rate) + /// * `cleanup_after` - Duration after which inactive buckets are removed + /// + /// # Examples + /// + /// ``` + /// use bssh::shared::rate_limit::RateLimitConfig; + /// use std::time::Duration; + /// + /// // Allow burst of 10, sustained 2/sec, cleanup after 5 minutes + /// let config = RateLimitConfig::new(10, 2.0, Duration::from_secs(300)); + /// ``` + pub fn new(max_tokens: u32, refill_rate: f64, cleanup_after: Duration) -> Self { + Self { + max_tokens, + refill_rate, + cleanup_after, + } + } +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + max_tokens: 10, // Allow burst of 10 operations + refill_rate: 2.0, // 2 operations per second sustained + cleanup_after: Duration::from_secs(300), // Clean up after 5 minutes + } + } +} + +/// Token bucket for a single key. +#[derive(Debug)] +struct TokenBucket { + /// Current token count + tokens: f64, + /// Last refill timestamp + last_refill: Instant, + /// Last access timestamp (for cleanup) + last_access: Instant, +} + +/// Generic token bucket rate limiter. +/// +/// This rate limiter can be used with any hashable key type, making it +/// suitable for various use cases: +/// - Connection rate limiting by hostname (String) +/// - Authentication rate limiting by IP address (IpAddr) +/// - API rate limiting by user ID (u64) +/// +/// # Type Parameters +/// +/// * `K` - The key type used to identify rate limit buckets. +/// Must implement `Hash`, `Eq`, `Clone`, and `Send + Sync`. +/// +/// # Thread Safety +/// +/// The rate limiter is thread-safe and can be shared across async tasks. +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::rate_limit::{RateLimiter, RateLimitConfig}; +/// use std::time::Duration; +/// +/// #[tokio::main] +/// async fn main() { +/// let limiter: RateLimiter = RateLimiter::new(); +/// +/// // Acquire a token for a host +/// if limiter.try_acquire(&"example.com".to_string()).await.is_ok() { +/// println!("Allowed"); +/// } else { +/// println!("Rate limited"); +/// } +/// } +/// ``` +#[derive(Debug)] +pub struct RateLimiter +where + K: Hash + Eq + Clone + Send + Sync, +{ + /// Token buckets per key + buckets: Arc>>, + /// Rate limit configuration + config: RateLimitConfig, +} + +impl RateLimiter +where + K: Hash + Eq + Clone + Send + Sync + std::fmt::Display, +{ + /// Create a new rate limiter with default settings. + /// + /// Default: 10 operations burst, 2 operations/second sustained, + /// cleanup after 5 minutes of inactivity. + pub fn new() -> Self { + Self { + buckets: Arc::new(RwLock::new(HashMap::new())), + config: RateLimitConfig::default(), + } + } + + /// Create a new rate limiter with custom configuration. + /// + /// # Arguments + /// + /// * `config` - The rate limit configuration + /// + /// # Examples + /// + /// ``` + /// use bssh::shared::rate_limit::{RateLimiter, RateLimitConfig}; + /// use std::time::Duration; + /// + /// let config = RateLimitConfig::new(5, 1.0, Duration::from_secs(60)); + /// let limiter: RateLimiter = RateLimiter::with_config(config); + /// ``` + pub fn with_config(config: RateLimitConfig) -> Self { + Self { + buckets: Arc::new(RwLock::new(HashMap::new())), + config, + } + } + + /// Create a new rate limiter with simple configuration. + /// + /// # Arguments + /// + /// * `max_tokens` - Maximum tokens (burst capacity) + /// * `refill_rate` - Tokens refilled per second + /// + /// # Examples + /// + /// ``` + /// use bssh::shared::rate_limit::RateLimiter; + /// + /// // Allow burst of 5, sustained 1/sec + /// let limiter: RateLimiter = RateLimiter::with_simple_config(5, 1.0); + /// ``` + pub fn with_simple_config(max_tokens: u32, refill_rate: f64) -> Self { + Self { + buckets: Arc::new(RwLock::new(HashMap::new())), + config: RateLimitConfig { + max_tokens, + refill_rate, + cleanup_after: Duration::from_secs(300), + }, + } + } + + /// Try to acquire a token for the given key. + /// + /// Returns `Ok(())` if a token was acquired, or an error if rate limited. + /// + /// # Arguments + /// + /// * `key` - The key identifying the rate limit bucket + /// + /// # Returns + /// + /// - `Ok(())` if the operation is allowed + /// - `Err(...)` if rate limited, with the wait time in the error message + /// + /// # Examples + /// + /// ``` + /// use bssh::shared::rate_limit::RateLimiter; + /// + /// #[tokio::main] + /// async fn main() { + /// let limiter: RateLimiter = RateLimiter::new(); + /// + /// match limiter.try_acquire(&"key".to_string()).await { + /// Ok(()) => println!("Allowed"), + /// Err(e) => println!("Rate limited: {e}"), + /// } + /// } + /// ``` + pub async fn try_acquire(&self, key: &K) -> Result<()> { + let mut buckets = self.buckets.write().await; + let now = Instant::now(); + + // Clean up old buckets periodically to prevent unbounded memory growth + // Use a lower threshold to trigger cleanup more frequently + if buckets.len() > 10 { + self.cleanup_old_buckets(&mut buckets, now); + } + + let bucket = buckets.entry(key.clone()).or_insert_with(|| TokenBucket { + tokens: self.config.max_tokens as f64, + last_refill: now, + last_access: now, + }); + + // Refill tokens based on time elapsed + let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); + let tokens_to_add = elapsed * self.config.refill_rate; + bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.max_tokens as f64); + bucket.last_refill = now; + bucket.last_access = now; + + // Try to consume a token + if bucket.tokens >= 1.0 { + bucket.tokens -= 1.0; + Ok(()) + } else { + let wait_time = (1.0 - bucket.tokens) / self.config.refill_rate; + // Log without exposing the key to prevent information disclosure + warn!("Rate limit exceeded: wait {:.1}s before retry", wait_time); + bail!("Rate limit exceeded. Please wait {wait_time:.1} seconds before retrying.") + } + } + + /// Check if a key is currently rate limited without consuming a token. + /// + /// This is useful for checking rate limit status without affecting the bucket. + /// + /// # Arguments + /// + /// * `key` - The key identifying the rate limit bucket + /// + /// # Returns + /// + /// `true` if the key is rate limited, `false` otherwise. + pub async fn is_rate_limited(&self, key: &K) -> bool { + let buckets = self.buckets.read().await; + if let Some(bucket) = buckets.get(key) { + let now = Instant::now(); + let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); + let tokens_available = (bucket.tokens + elapsed * self.config.refill_rate) + .min(self.config.max_tokens as f64); + tokens_available < 1.0 + } else { + false + } + } + + /// Get the current token count for a key. + /// + /// Returns `None` if the key has no bucket (never been rate limited). + /// + /// # Arguments + /// + /// * `key` - The key identifying the rate limit bucket + /// + /// # Returns + /// + /// The current (estimated) token count, or `None` if no bucket exists. + pub async fn get_tokens(&self, key: &K) -> Option { + let buckets = self.buckets.read().await; + buckets.get(key).map(|bucket| { + let now = Instant::now(); + let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); + (bucket.tokens + elapsed * self.config.refill_rate).min(self.config.max_tokens as f64) + }) + } + + /// Clean up old token buckets that haven't been used recently. + fn cleanup_old_buckets(&self, buckets: &mut HashMap, now: Instant) { + buckets.retain(|_key, bucket| { + now.duration_since(bucket.last_access) < self.config.cleanup_after + }); + } + + /// Reset rate limit for a specific key. + /// + /// This removes the bucket for the key, allowing a fresh start. + /// Useful for testing or administrative overrides. + /// + /// # Arguments + /// + /// * `key` - The key to reset + pub async fn reset_key(&self, key: &K) { + let mut buckets = self.buckets.write().await; + buckets.remove(key); + } + + /// Clear all rate limit data. + /// + /// This removes all buckets, resetting the rate limiter to initial state. + pub async fn clear_all(&self) { + let mut buckets = self.buckets.write().await; + buckets.clear(); + } + + /// Get the number of tracked keys. + /// + /// Returns the number of keys currently being rate limited. + pub async fn tracked_key_count(&self) -> usize { + let buckets = self.buckets.read().await; + buckets.len() + } + + /// Get the rate limit configuration. + pub fn config(&self) -> &RateLimitConfig { + &self.config + } +} + +impl Default for RateLimiter +where + K: Hash + Eq + Clone + Send + Sync + std::fmt::Display, +{ + fn default() -> Self { + Self::new() + } +} + +impl Clone for RateLimiter +where + K: Hash + Eq + Clone + Send + Sync, +{ + fn clone(&self) -> Self { + Self { + buckets: Arc::clone(&self.buckets), + config: self.config.clone(), + } + } +} + +/// Type alias for connection rate limiting by hostname. +/// +/// This is a convenience type for the common use case of rate limiting +/// connection attempts per hostname string. +pub type ConnectionRateLimiter = RateLimiter; + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_rate_limiter_allows_burst() { + let limiter: RateLimiter = RateLimiter::with_simple_config(3, 1.0); + + // Should allow 3 operations in burst + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + + // 4th should fail + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_err()); + } + + #[tokio::test] + async fn test_rate_limiter_refills() { + let limiter: RateLimiter = RateLimiter::with_simple_config(2, 10.0); // Fast refill for testing + + // Use up tokens + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_err()); + + // Wait for refill + tokio::time::sleep(Duration::from_millis(150)).await; + + // Should have refilled + assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok()); + } + + #[tokio::test] + async fn test_rate_limiter_per_key() { + let limiter: RateLimiter = RateLimiter::with_simple_config(1, 1.0); + + // Different keys should have separate buckets + assert!(limiter.try_acquire(&"host1.com".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"host2.com".to_string()).await.is_ok()); + + // But same key should be limited + assert!(limiter.try_acquire(&"host1.com".to_string()).await.is_err()); + } + + #[tokio::test] + async fn test_rate_limiter_with_numeric_key() { + // Test with numeric keys (e.g., user IDs) + let limiter: RateLimiter = RateLimiter::with_simple_config(2, 1.0); + + assert!(limiter.try_acquire(&1).await.is_ok()); + assert!(limiter.try_acquire(&1).await.is_ok()); + assert!(limiter.try_acquire(&1).await.is_err()); + assert!(limiter.try_acquire(&2).await.is_ok()); // Different key + } + + #[tokio::test] + async fn test_is_rate_limited() { + let limiter: RateLimiter = RateLimiter::with_simple_config(1, 1.0); + + // Initially not limited + assert!(!limiter.is_rate_limited(&"test".to_string()).await); + + // Use up tokens + assert!(limiter.try_acquire(&"test".to_string()).await.is_ok()); + + // Now limited + assert!(limiter.is_rate_limited(&"test".to_string()).await); + } + + #[tokio::test] + async fn test_reset_key() { + let limiter: RateLimiter = RateLimiter::with_simple_config(1, 1.0); + + // Use up tokens + assert!(limiter.try_acquire(&"test".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"test".to_string()).await.is_err()); + + // Reset + limiter.reset_key(&"test".to_string()).await; + + // Should work again + assert!(limiter.try_acquire(&"test".to_string()).await.is_ok()); + } + + #[tokio::test] + async fn test_clear_all() { + let limiter: RateLimiter = RateLimiter::with_simple_config(1, 1.0); + + // Use up tokens for multiple keys + assert!(limiter.try_acquire(&"host1".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"host2".to_string()).await.is_ok()); + assert_eq!(limiter.tracked_key_count().await, 2); + + // Clear all + limiter.clear_all().await; + + // Should be empty + assert_eq!(limiter.tracked_key_count().await, 0); + + // All should work again + assert!(limiter.try_acquire(&"host1".to_string()).await.is_ok()); + assert!(limiter.try_acquire(&"host2".to_string()).await.is_ok()); + } +} diff --git a/src/shared/validation.rs b/src/shared/validation.rs new file mode 100644 index 00000000..d8124c34 --- /dev/null +++ b/src/shared/validation.rs @@ -0,0 +1,553 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Shared validation utilities for validating and sanitizing user input. +//! +//! This module provides security utilities for validating user input that can +//! be reused between the bssh client and server implementations. +//! +//! # Security +//! +//! These functions are designed to prevent: +//! - Path traversal attacks +//! - Command injection +//! - Information leakage through error messages + +use anyhow::{Context, Result}; +use std::path::{Path, PathBuf}; + +/// Maximum recursion depth for validating non-existent paths +const MAX_PATH_VALIDATION_DEPTH: u32 = 20; + +/// Helper function to validate non-existent paths with recursion depth limit. +/// +/// This prevents infinite recursion when validating paths with non-existent parents. +fn validate_nonexistent_path(path: &Path, depth: u32) -> Result { + // Check recursion depth to prevent stack overflow + if depth >= MAX_PATH_VALIDATION_DEPTH { + anyhow::bail!("Path validation depth exceeded (max {MAX_PATH_VALIDATION_DEPTH} levels)"); + } + + if let Some(parent) = path.parent() { + if parent.as_os_str().is_empty() { + // Parent is empty, use current directory + Ok(std::env::current_dir() + .with_context(|| "Failed to get current directory")? + .join(path)) + } else if parent.exists() { + let canonical_parent = parent + .canonicalize() + .with_context(|| format!("Failed to canonicalize parent path: {parent:?}"))?; + + // Get the file name + let file_name = path + .file_name() + .ok_or_else(|| anyhow::anyhow!("Invalid path: no file name component"))?; + + // Validate file name doesn't contain path separators + let file_name_str = file_name.to_string_lossy(); + if file_name_str.contains('/') || file_name_str.contains('\\') { + anyhow::bail!("Invalid file name: contains path separator"); + } + + Ok(canonical_parent.join(file_name)) + } else { + // Parent doesn't exist, recursively validate with depth tracking + let canonical_parent = validate_nonexistent_path(parent, depth + 1)?; + + // Get the file name + let file_name = path + .file_name() + .ok_or_else(|| anyhow::anyhow!("Invalid path: no file name component"))?; + + // Validate file name doesn't contain path separators + let file_name_str = file_name.to_string_lossy(); + if file_name_str.contains('/') || file_name_str.contains('\\') { + anyhow::bail!("Invalid file name: contains path separator"); + } + + Ok(canonical_parent.join(file_name)) + } + } else { + // No parent, treat as relative to current directory + Ok(std::env::current_dir() + .with_context(|| "Failed to get current directory")? + .join(path)) + } +} + +/// Validate and sanitize a local file path to prevent path traversal attacks. +/// +/// This function ensures: +/// - No path traversal sequences (..) +/// - No double slashes (//) +/// - Path is canonical and resolved +/// - No symlink attacks +/// +/// # Arguments +/// +/// * `path` - The local file path to validate +/// +/// # Returns +/// +/// Returns the canonical path if validation succeeds. +/// +/// # Errors +/// +/// Returns an error if: +/// - Path contains traversal sequences (..) +/// - Path contains double slashes (//) +/// - Path cannot be canonicalized +/// +/// # Examples +/// +/// ``` +/// use std::path::Path; +/// use bssh::shared::validation::validate_local_path; +/// +/// // Valid path +/// let result = validate_local_path(Path::new("/tmp/test.txt")); +/// assert!(result.is_ok()); +/// +/// // Invalid path with traversal +/// let result = validate_local_path(Path::new("../etc/passwd")); +/// assert!(result.is_err()); +/// ``` +pub fn validate_local_path(path: &Path) -> Result { + // Convert to string to check for dangerous patterns + let path_str = path.to_string_lossy(); + + // Check for path traversal attempts + if path_str.contains("..") { + anyhow::bail!("Path traversal detected: path contains '..'"); + } + + // Check for double slashes + if path_str.contains("//") { + anyhow::bail!("Invalid path: contains double slashes"); + } + + // Get canonical path (resolves symlinks, .., ., etc.) + // This will fail if the path doesn't exist yet, so we handle that case + let canonical = if path.exists() { + path.canonicalize() + .with_context(|| format!("Failed to canonicalize path: {path:?}"))? + } else { + // For non-existent paths, validate the parent directory + validate_nonexistent_path(path, 0)? + }; + + Ok(canonical) +} + +/// Validate a remote path string to prevent injection attacks. +/// +/// This function ensures: +/// - No shell metacharacters that could cause command injection +/// - No path traversal sequences +/// - Only valid characters for file paths +/// +/// # Arguments +/// +/// * `path` - The remote path string to validate +/// +/// # Returns +/// +/// Returns the validated path string if validation succeeds. +/// +/// # Errors +/// +/// Returns an error if: +/// - Path is empty +/// - Path is too long (>4096 characters) +/// - Path contains shell metacharacters +/// - Path contains command substitution patterns +/// - Path contains path traversal sequences +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::validation::validate_remote_path; +/// +/// // Valid paths +/// assert!(validate_remote_path("/home/user/file.txt").is_ok()); +/// assert!(validate_remote_path("~/documents/report.pdf").is_ok()); +/// +/// // Invalid paths +/// assert!(validate_remote_path("/tmp/$(whoami)").is_err()); +/// assert!(validate_remote_path("../etc/passwd").is_err()); +/// ``` +pub fn validate_remote_path(path: &str) -> Result { + // Check for empty path + if path.is_empty() { + anyhow::bail!("Remote path cannot be empty"); + } + + // Check path length to prevent DoS + const MAX_PATH_LENGTH: usize = 4096; + if path.len() > MAX_PATH_LENGTH { + anyhow::bail!("Remote path too long (max {MAX_PATH_LENGTH} characters)"); + } + + // Check for shell metacharacters that could cause injection + const DANGEROUS_CHARS: &[char] = &[ + ';', '&', '|', '`', '$', '(', ')', '{', '}', '<', '>', '\n', '\r', '\0', '!', '*', '?', + '[', ']', // Shell wildcards that could cause issues + ]; + + for &ch in DANGEROUS_CHARS { + if path.contains(ch) { + anyhow::bail!("Remote path contains invalid character: '{ch}'"); + } + } + + // Check for command substitution patterns + if path.contains("$(") || path.contains("${") || path.contains("`)") { + anyhow::bail!("Remote path contains potential command substitution"); + } + + // Check for path traversal - all possible patterns + if path.contains("../") + || path.contains("/..") + || path.starts_with("../") + || path.starts_with("/..") + || path.ends_with("/..") + || path == ".." + { + anyhow::bail!("Remote path contains path traversal sequence"); + } + + // Check for double slashes (could indicate protocol bypasses) + if path.contains("//") && !path.starts_with("//") { + anyhow::bail!("Remote path contains double slashes"); + } + + // Validate that path contains only allowed characters + // Allow: alphanumeric, spaces, and common path characters + let valid_chars = path.chars().all(|c| { + c.is_ascii_alphanumeric() + || c == '/' + || c == '\\' + || c == '.' + || c == '-' + || c == '_' + || c == ' ' + || c == '~' + || c == '=' + || c == ',' + || c == ':' + || c == '@' + }); + + if !valid_chars { + anyhow::bail!("Remote path contains invalid characters"); + } + + Ok(path.to_string()) +} + +/// Sanitize a hostname to prevent injection attacks. +/// +/// This function validates that hostnames conform to RFC 1123 and don't contain +/// characters that could be used for command injection. +/// +/// # Arguments +/// +/// * `hostname` - The hostname string to validate +/// +/// # Returns +/// +/// Returns the validated hostname if validation succeeds. +/// +/// # Errors +/// +/// Returns an error if: +/// - Hostname is empty +/// - Hostname is too long (>253 characters, per RFC 1123) +/// - Hostname contains invalid characters +/// - Hostname contains suspicious patterns +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::validation::validate_hostname; +/// +/// // Valid hostnames +/// assert!(validate_hostname("example.com").is_ok()); +/// assert!(validate_hostname("192.168.1.1").is_ok()); +/// assert!(validate_hostname("[::1]").is_ok()); +/// +/// // Invalid hostnames +/// assert!(validate_hostname("example..com").is_err()); +/// assert!(validate_hostname("example.com; ls").is_err()); +/// ``` +pub fn validate_hostname(hostname: &str) -> Result { + // Check for empty hostname + if hostname.is_empty() { + anyhow::bail!("Hostname cannot be empty"); + } + + // Check hostname length (RFC 1123) + const MAX_HOSTNAME_LENGTH: usize = 253; + if hostname.len() > MAX_HOSTNAME_LENGTH { + anyhow::bail!("Hostname too long (max {MAX_HOSTNAME_LENGTH} characters)"); + } + + // Validate hostname format (RFC 1123) + // Allow alphanumeric, dots, hyphens, and colons (for IPv6) + let valid_chars = hostname.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == ':' || c == '[' || c == ']' + }); + + if !valid_chars { + anyhow::bail!("Hostname contains invalid characters"); + } + + // Check for suspicious patterns + if hostname.contains("..") || hostname.contains("--") { + anyhow::bail!("Hostname contains suspicious repeated characters"); + } + + Ok(hostname.to_string()) +} + +/// Validate a username to prevent injection attacks. +/// +/// This function validates that usernames conform to POSIX standards and don't +/// contain characters that could be used for command injection. +/// +/// # Arguments +/// +/// * `username` - The username string to validate +/// +/// # Returns +/// +/// Returns the validated username if validation succeeds. +/// +/// # Errors +/// +/// Returns an error if: +/// - Username is empty +/// - Username is too long (>32 characters) +/// - Username contains invalid characters +/// - Username starts with a hyphen +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::validation::validate_username; +/// +/// // Valid usernames +/// assert!(validate_username("john_doe").is_ok()); +/// assert!(validate_username("user123").is_ok()); +/// +/// // Invalid usernames +/// assert!(validate_username("-user").is_err()); +/// assert!(validate_username("user@domain").is_err()); +/// ``` +pub fn validate_username(username: &str) -> Result { + // Check for empty username + if username.is_empty() { + anyhow::bail!("Username cannot be empty"); + } + + // Check username length + const MAX_USERNAME_LENGTH: usize = 32; + if username.len() > MAX_USERNAME_LENGTH { + anyhow::bail!("Username too long (max {MAX_USERNAME_LENGTH} characters)"); + } + + // Validate username format (POSIX-compliant) + // Allow alphanumeric, underscore, hyphen, and dot + let valid_chars = username + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.'); + + if !valid_chars { + anyhow::bail!("Username contains invalid characters"); + } + + // Username should not start with a hyphen + if username.starts_with('-') { + anyhow::bail!("Username cannot start with a hyphen"); + } + + Ok(username.to_string()) +} + +/// Sanitize error messages to prevent information leakage. +/// +/// This function redacts sensitive information like usernames, hostnames, +/// and ports from error messages to prevent information disclosure. +/// +/// # Arguments +/// +/// * `message` - The error message to sanitize +/// +/// # Returns +/// +/// Returns the sanitized error message with sensitive information redacted. +/// +/// # Examples +/// +/// ``` +/// use bssh::shared::validation::sanitize_error_message; +/// +/// let message = "Failed to connect to 192.168.1.1:22"; +/// let sanitized = sanitize_error_message(message); +/// // IP address is redacted +/// ``` +pub fn sanitize_error_message(message: &str) -> String { + let mut sanitized = message.to_string(); + + // Remove specific usernames (format: user 'username') + if let Some(start) = sanitized.find("user '") { + if let Some(end) = sanitized[start + 6..].find('\'') { + let before = &sanitized[..start + 5]; + let after = &sanitized[start + 6 + end + 1..]; + sanitized = format!("{before}{after}"); + } + } + + // Remove hostname:port combinations in common patterns + // We process these sequentially since each replacement may affect subsequent ones + let patterns = [ + (" on ", " on "), + (" to ", " to "), + (" at ", " at "), + (" from ", " from "), + ]; + + for (pattern, replacement) in &patterns { + if sanitized.contains(pattern) { + // Find pattern and replace following hostname:port + let parts: Vec<&str> = sanitized.split(pattern).collect(); + let mut result = String::new(); + + for (i, part) in parts.iter().enumerate() { + result.push_str(part); + if i < parts.len() - 1 { + result.push_str(replacement); + // Skip the actual hostname:port in the next part + if let Some(next_space) = parts[i + 1].find(' ') { + result.push_str(&parts[i + 1][next_space..]); + } + } + } + sanitized = result; + } + } + + // Remove any remaining IP addresses + // Simple check for IPv4 pattern + let parts: Vec<&str> = sanitized.split_whitespace().collect(); + let mut result_parts = Vec::new(); + + for part in parts { + if part.split('.').count() == 4 + && part + .split('.') + .all(|p| p.parse::().is_ok() || p.contains(':')) + { + result_parts.push(""); + } else { + result_parts.push(part); + } + } + + result_parts.join(" ") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_local_path() { + // Valid paths + assert!(validate_local_path(Path::new("/tmp/test.txt")).is_ok()); + assert!(validate_local_path(Path::new("./test.txt")).is_ok()); + + // Invalid paths with traversal + assert!(validate_local_path(Path::new("../etc/passwd")).is_err()); + assert!(validate_local_path(Path::new("/tmp/../etc/passwd")).is_err()); + assert!(validate_local_path(Path::new("/tmp//test")).is_err()); + } + + #[test] + fn test_validate_remote_path() { + // Valid paths + assert!(validate_remote_path("/home/user/file.txt").is_ok()); + assert!(validate_remote_path("~/documents/report.pdf").is_ok()); + assert!(validate_remote_path("C:\\Users\\test\\file.txt").is_ok()); + + // Invalid paths + assert!(validate_remote_path("../etc/passwd").is_err()); + assert!(validate_remote_path("/tmp/$(whoami)").is_err()); + assert!(validate_remote_path("/tmp/test; rm -rf /").is_err()); + assert!(validate_remote_path("/tmp/test`id`").is_err()); + assert!(validate_remote_path("/tmp/test|cat").is_err()); + assert!(validate_remote_path("").is_err()); + } + + #[test] + fn test_validate_hostname() { + // Valid hostnames + assert!(validate_hostname("example.com").is_ok()); + assert!(validate_hostname("192.168.1.1").is_ok()); + assert!(validate_hostname("server-01.example.com").is_ok()); + assert!(validate_hostname("[::1]").is_ok()); + + // Invalid hostnames + assert!(validate_hostname("example..com").is_err()); + assert!(validate_hostname("server--01").is_err()); + assert!(validate_hostname("example.com; ls").is_err()); + assert!(validate_hostname("").is_err()); + } + + #[test] + fn test_validate_username() { + // Valid usernames + assert!(validate_username("john_doe").is_ok()); + assert!(validate_username("user123").is_ok()); + assert!(validate_username("test.user").is_ok()); + + // Invalid usernames + assert!(validate_username("-user").is_err()); + assert!(validate_username("user@domain").is_err()); + assert!(validate_username("user name").is_err()); + assert!(validate_username("").is_err()); + assert!(validate_username(&"a".repeat(50)).is_err()); + } + + #[test] + fn test_sanitize_error_message() { + // Test standalone IP address at start of message + let msg = "192.168.1.1 refused connection"; + let sanitized = sanitize_error_message(msg); + assert!(sanitized.contains("")); + assert!(!sanitized.contains("192.168.1.1")); + + // Test username redaction (user 'name' pattern) + let msg = "Authentication failed for user 'johndoe'"; + let sanitized = sanitize_error_message(msg); + assert!(sanitized.contains("")); + assert!(!sanitized.contains("johndoe")); + + // Test message without sensitive info passes through + let msg = "Connection timed out"; + let sanitized = sanitize_error_message(msg); + assert_eq!(sanitized, "Connection timed out"); + } +}