From b57608643ba3fabd73a3e4580fc1505d80e20881 Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Sat, 24 Jan 2026 12:30:08 +0900 Subject: [PATCH 1/3] feat: Implement authentication rate limiting (fail2ban-like) Add AuthRateLimiter with ban support to protect against brute-force attacks: - Track failed authentication attempts per IP address - Automatically ban IPs that exceed max attempts within time window - Configurable max attempts, time window, and ban duration - IP whitelist support for trusted addresses - Automatic cleanup of expired bans and failure records - Background cleanup task running every 60 seconds Configuration options added to SecurityConfig: - auth_window: Time window for counting attempts (default: 300s) - whitelist_ips: IPs exempt from rate limiting Integration with SSH handler: - Check if IP is banned before authentication - Record failures and trigger bans on threshold - Record success to reset failure counter - Logging for ban events Closes #140 --- src/server/config/types.rs | 22 ++ src/server/handler.rs | 125 ++++++ src/server/mod.rs | 25 +- src/server/security/mod.rs | 57 +++ src/server/security/rate_limit.rs | 613 ++++++++++++++++++++++++++++++ 5 files changed, 841 insertions(+), 1 deletion(-) create mode 100644 src/server/security/mod.rs create mode 100644 src/server/security/rate_limit.rs diff --git a/src/server/config/types.rs b/src/server/config/types.rs index 0c6dbd25..cb898ab5 100644 --- a/src/server/config/types.rs +++ b/src/server/config/types.rs @@ -368,12 +368,28 @@ pub struct SecurityConfig { #[serde(default = "default_max_auth_attempts")] pub max_auth_attempts: u32, + /// Time window in seconds for counting authentication attempts. + /// + /// Failed attempts outside this window are not counted toward the ban threshold. + /// + /// Default: 300 (5 minutes) + #[serde(default = "default_auth_window")] + pub auth_window: u64, + /// Ban duration in seconds after exceeding max auth attempts. /// /// Default: 300 (5 minutes) #[serde(default = "default_ban_time")] pub ban_time: u64, + /// IP addresses that are never banned (whitelist). + /// + /// These IPs are exempt from rate limiting and banning. + /// + /// Example: ["127.0.0.1", "::1"] + #[serde(default)] + pub whitelist_ips: Vec, + /// Maximum number of concurrent sessions per user. /// /// Default: 10 @@ -449,6 +465,10 @@ fn default_max_auth_attempts() -> u32 { 5 } +fn default_auth_window() -> u64 { + 300 +} + fn default_ban_time() -> u64 { 300 } @@ -517,7 +537,9 @@ impl Default for SecurityConfig { fn default() -> Self { Self { max_auth_attempts: default_max_auth_attempts(), + auth_window: default_auth_window(), ban_time: default_ban_time(), + whitelist_ips: Vec::new(), max_sessions_per_user: default_max_sessions(), idle_timeout: default_idle_timeout(), allowed_ips: Vec::new(), diff --git a/src/server/handler.rs b/src/server/handler.rs index 4f5442ec..eb476da9 100644 --- a/src/server/handler.rs +++ b/src/server/handler.rs @@ -32,6 +32,7 @@ use super::auth::AuthProvider; use super::config::ServerConfig; use super::exec::CommandExecutor; use super::pty::PtyConfig as PtyMasterConfig; +use super::security::AuthRateLimiter; use super::session::{ChannelState, PtyConfig, SessionId, SessionInfo, SessionManager}; use super::sftp::SftpHandler; use super::shell::ShellSession; @@ -57,6 +58,9 @@ pub struct SshHandler { /// Rate limiter for authentication attempts. rate_limiter: RateLimiter, + /// Auth rate limiter with ban support (fail2ban-like). + auth_rate_limiter: Option, + /// Session information for this connection. session_info: Option, @@ -83,6 +87,7 @@ impl SshHandler { sessions, auth_provider, rate_limiter, + auth_rate_limiter: None, session_info: Some(SessionInfo::new(peer_addr)), channels: HashMap::new(), } @@ -106,6 +111,33 @@ impl SshHandler { sessions, auth_provider, rate_limiter, + auth_rate_limiter: None, + session_info: Some(SessionInfo::new(peer_addr)), + channels: HashMap::new(), + } + } + + /// Create a new SSH handler with shared rate limiters including auth ban support. + /// + /// This is the preferred constructor for production use as it shares + /// both rate limiters across all handlers, providing server-wide rate limiting + /// and fail2ban-like functionality. + pub fn with_rate_limiters( + peer_addr: Option, + config: Arc, + sessions: Arc>, + rate_limiter: RateLimiter, + auth_rate_limiter: AuthRateLimiter, + ) -> Self { + let auth_provider = config.create_auth_provider(); + + Self { + peer_addr, + config, + sessions, + auth_provider, + rate_limiter, + auth_rate_limiter: Some(auth_rate_limiter), session_info: Some(SessionInfo::new(peer_addr)), channels: HashMap::new(), } @@ -128,6 +160,7 @@ impl SshHandler { sessions, auth_provider, rate_limiter, + auth_rate_limiter: None, session_info: Some(SessionInfo::new(peer_addr)), channels: HashMap::new(), } @@ -284,6 +317,7 @@ impl russh::server::Handler for SshHandler { // Clone what we need for the async block let auth_provider = Arc::clone(&self.auth_provider); let rate_limiter = self.rate_limiter.clone(); + let auth_rate_limiter = self.auth_rate_limiter.clone(); let peer_addr = self.peer_addr; let user = user.to_string(); let public_key = public_key.clone(); @@ -292,6 +326,23 @@ impl russh::server::Handler for SshHandler { let session_info = &mut self.session_info; async move { + // Check if IP is banned (fail2ban-like check) + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + if limiter.is_banned(&ip).await { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "Rejected auth from banned IP" + ); + return Ok(Auth::Reject { + proceed_with_methods: None, + partial_success: false, + }); + } + } + } + if exceeded { tracing::warn!( user = %user, @@ -349,6 +400,13 @@ impl russh::server::Handler for SshHandler { info.authenticate(&user); } + // Record success to reset failure counter + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + limiter.record_success(&ip).await; + } + } + Ok(Auth::Accept) } Ok(_) => { @@ -359,6 +417,20 @@ impl russh::server::Handler for SshHandler { "Public key authentication rejected" ); + // Record failure for ban tracking + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + let banned = limiter.record_failure(ip).await; + if banned { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "IP banned due to too many failed auth attempts" + ); + } + } + } + let proceed = if methods.is_empty() { None } else { @@ -378,6 +450,13 @@ impl russh::server::Handler for SshHandler { "Error during public key verification" ); + // Record failure for ban tracking + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + limiter.record_failure(ip).await; + } + } + let proceed = if methods.is_empty() { None } else { @@ -421,6 +500,7 @@ impl russh::server::Handler for SshHandler { // Clone what we need for the async block let auth_provider = Arc::clone(&self.auth_provider); let rate_limiter = self.rate_limiter.clone(); + let auth_rate_limiter = self.auth_rate_limiter.clone(); let peer_addr = self.peer_addr; let user = user.to_string(); // Use Zeroizing to ensure password is securely cleared from memory when dropped @@ -431,6 +511,23 @@ impl russh::server::Handler for SshHandler { let session_info = &mut self.session_info; async move { + // Check if IP is banned (fail2ban-like check) + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + if limiter.is_banned(&ip).await { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "Rejected password auth from banned IP" + ); + return Ok(Auth::Reject { + proceed_with_methods: None, + partial_success: false, + }); + } + } + } + // Check if password auth is enabled if !allow_password { tracing::debug!( @@ -504,6 +601,13 @@ impl russh::server::Handler for SshHandler { info.authenticate(&user); } + // Record success to reset failure counter + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + limiter.record_success(&ip).await; + } + } + Ok(Auth::Accept) } Ok(_) => { @@ -513,6 +617,20 @@ impl russh::server::Handler for SshHandler { "Password authentication rejected" ); + // Record failure for ban tracking + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + let banned = limiter.record_failure(ip).await; + if banned { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "IP banned due to too many failed password auth attempts" + ); + } + } + } + let proceed = if methods.is_empty() { None } else { @@ -532,6 +650,13 @@ impl russh::server::Handler for SshHandler { "Error during password verification" ); + // Record failure for ban tracking + if let Some(ref limiter) = auth_rate_limiter { + if let Some(ip) = peer_addr.map(|a| a.ip()) { + limiter.record_failure(ip).await; + } + } + let proceed = if methods.is_empty() { None } else { diff --git a/src/server/mod.rs b/src/server/mod.rs index b15af8e7..562e77d7 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -49,6 +49,7 @@ pub mod config; pub mod exec; pub mod handler; pub mod pty; +pub mod security; pub mod session; pub mod sftp; pub mod shell; @@ -69,6 +70,7 @@ pub use self::config::{ServerConfig, ServerConfigBuilder}; pub use self::exec::{CommandExecutor, ExecConfig}; pub use self::handler::SshHandler; pub use self::pty::{PtyConfig as PtyMasterConfig, PtyMaster}; +pub use self::security::{AuthRateLimitConfig, AuthRateLimiter}; pub use self::session::{ ChannelMode, ChannelState, PtyConfig, SessionId, SessionInfo, SessionManager, }; @@ -214,10 +216,28 @@ impl BsshServer { // This allows rapid testing while still providing protection against brute force let rate_limiter = RateLimiter::with_simple_config(100, 10.0); + // Create auth rate limiter with configuration + let auth_rate_limiter = AuthRateLimiter::new(AuthRateLimitConfig::new( + self.config.max_auth_attempts, + 300, // Default 5 minute window + 300, // Default 5 minute ban + )); + + // Start background cleanup task for auth rate limiter + let cleanup_limiter = auth_rate_limiter.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { + interval.tick().await; + cleanup_limiter.cleanup().await; + } + }); + let mut server = BsshServerRunner { config: Arc::clone(&self.config), sessions: Arc::clone(&self.sessions), rate_limiter, + auth_rate_limiter, }; // Use run_on_socket which handles the server loop @@ -248,6 +268,8 @@ struct BsshServerRunner { sessions: Arc>, /// Shared rate limiter for authentication attempts across all handlers rate_limiter: RateLimiter, + /// Auth rate limiter with ban support (fail2ban-like) + auth_rate_limiter: AuthRateLimiter, } impl russh::server::Server for BsshServerRunner { @@ -259,11 +281,12 @@ impl russh::server::Server for BsshServerRunner { "New client connection" ); - SshHandler::with_rate_limiter( + SshHandler::with_rate_limiters( peer_addr, Arc::clone(&self.config), Arc::clone(&self.sessions), self.rate_limiter.clone(), + self.auth_rate_limiter.clone(), ) } diff --git a/src/server/security/mod.rs b/src/server/security/mod.rs new file mode 100644 index 00000000..bb4736a0 --- /dev/null +++ b/src/server/security/mod.rs @@ -0,0 +1,57 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Security module for bssh-server. +//! +//! This module provides security features including: +//! +//! - [`AuthRateLimiter`]: Authentication rate limiting with ban support (fail2ban-like) +//! +//! # Authentication Rate Limiting +//! +//! The `AuthRateLimiter` tracks failed authentication attempts per IP address +//! and automatically bans IPs that exceed the configured threshold. +//! +//! ## Example +//! +//! ``` +//! use bssh::server::security::{AuthRateLimiter, AuthRateLimitConfig}; +//! use std::net::IpAddr; +//! +//! #[tokio::main] +//! async fn main() { +//! let config = AuthRateLimitConfig::default(); +//! let limiter = AuthRateLimiter::new(config); +//! +//! let ip: IpAddr = "192.168.1.100".parse().unwrap(); +//! +//! // Check if banned before auth +//! if limiter.is_banned(&ip).await { +//! println!("IP is banned"); +//! return; +//! } +//! +//! // On auth failure +//! if limiter.record_failure(ip).await { +//! println!("IP has been banned after too many failures"); +//! } +//! +//! // On auth success +//! limiter.record_success(&ip).await; +//! } +//! ``` + +mod rate_limit; + +pub use rate_limit::{AuthRateLimitConfig, AuthRateLimiter}; diff --git a/src/server/security/rate_limit.rs b/src/server/security/rate_limit.rs new file mode 100644 index 00000000..e16701ba --- /dev/null +++ b/src/server/security/rate_limit.rs @@ -0,0 +1,613 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Authentication rate limiter with ban support. +//! +//! This module provides fail2ban-like functionality for protecting the SSH server +//! against brute-force attacks. It tracks failed authentication attempts per IP +//! and automatically bans IPs that exceed the configured threshold. + +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +/// Configuration for authentication rate limiting. +/// +/// This struct defines the parameters for the fail2ban-like functionality: +/// - `max_attempts`: Maximum failed attempts before ban +/// - `window`: Time window for counting attempts +/// - `ban_duration`: How long to ban an IP +/// - `whitelist`: IPs that are never banned +#[derive(Debug, Clone)] +pub struct AuthRateLimitConfig { + /// Maximum failed attempts before ban. + pub max_attempts: u32, + /// Time window for counting attempts. + pub window: Duration, + /// Ban duration after exceeding max attempts. + pub ban_duration: Duration, + /// Whitelist IPs (never banned). + pub whitelist: Vec, +} + +impl Default for AuthRateLimitConfig { + fn default() -> Self { + Self { + max_attempts: 5, + window: Duration::from_secs(300), // 5 minutes + ban_duration: Duration::from_secs(300), // 5 minutes + whitelist: vec![], + } + } +} + +impl AuthRateLimitConfig { + /// Create a new configuration with specified parameters. + /// + /// # Arguments + /// + /// * `max_attempts` - Maximum failed attempts before ban + /// * `window_secs` - Time window in seconds for counting attempts + /// * `ban_duration_secs` - Ban duration in seconds + pub fn new(max_attempts: u32, window_secs: u64, ban_duration_secs: u64) -> Self { + Self { + max_attempts, + window: Duration::from_secs(window_secs), + ban_duration: Duration::from_secs(ban_duration_secs), + whitelist: vec![], + } + } + + /// Add an IP to the whitelist. + pub fn add_whitelist(&mut self, ip: IpAddr) { + if !self.whitelist.contains(&ip) { + self.whitelist.push(ip); + } + } + + /// Set the whitelist from a list of IPs. + pub fn with_whitelist(mut self, whitelist: Vec) -> Self { + self.whitelist = whitelist; + self + } +} + +/// Record of failed authentication attempts for an IP. +#[derive(Debug)] +struct FailureRecord { + /// Number of failed attempts. + count: u32, + /// Timestamp of the first failure in the current window. + first_failure: Instant, + /// Timestamp of the most recent failure. + last_failure: Instant, +} + +/// Authentication rate limiter with ban support. +/// +/// This struct provides fail2ban-like functionality for the SSH server. +/// It tracks failed authentication attempts per IP address and automatically +/// bans IPs that exceed the configured maximum attempts within the time window. +/// +/// # Features +/// +/// - **Failure tracking**: Counts failed authentication attempts per IP +/// - **Automatic banning**: Bans IPs that exceed the threshold +/// - **Time-based window**: Failures outside the window are not counted +/// - **Configurable ban duration**: Bans expire after the configured time +/// - **IP whitelist**: Whitelisted IPs are never banned +/// - **Automatic cleanup**: Expired records are cleaned up periodically +/// +/// # Thread Safety +/// +/// The rate limiter is thread-safe and can be shared across async tasks. +#[derive(Debug)] +pub struct AuthRateLimiter { + /// Failed attempt records per IP. + failures: Arc>>, + /// Banned IPs with expiration time. + bans: Arc>>, + /// Configuration. + config: AuthRateLimitConfig, +} + +impl AuthRateLimiter { + /// Create a new authentication rate limiter with the given configuration. + pub fn new(config: AuthRateLimitConfig) -> Self { + Self { + failures: Arc::new(RwLock::new(HashMap::new())), + bans: Arc::new(RwLock::new(HashMap::new())), + config, + } + } + + /// Check if an IP address is currently banned. + /// + /// Returns `true` if the IP is banned and the ban has not expired. + /// Whitelisted IPs always return `false`. + pub async fn is_banned(&self, ip: &IpAddr) -> bool { + // Whitelisted IPs are never banned + if self.config.whitelist.contains(ip) { + return false; + } + + let bans = self.bans.read().await; + if let Some(expiry) = bans.get(ip) { + if Instant::now() < *expiry { + return true; + } + } + false + } + + /// Record a failed authentication attempt. + /// + /// Increments the failure count for the IP. If the IP exceeds the maximum + /// allowed attempts within the time window, it will be banned. + /// + /// # Returns + /// + /// Returns `true` if the IP was banned as a result of this failure, + /// `false` otherwise. + pub async fn record_failure(&self, ip: IpAddr) -> bool { + // Skip whitelisted IPs + if self.config.whitelist.contains(&ip) { + return false; + } + + let mut failures = self.failures.write().await; + let now = Instant::now(); + + let record = failures.entry(ip).or_insert_with(|| FailureRecord { + count: 0, + first_failure: now, + last_failure: now, + }); + + // Reset if window expired + if now.duration_since(record.first_failure) > self.config.window { + record.count = 1; + record.first_failure = now; + } else { + record.count += 1; + } + record.last_failure = now; + + // Check if should ban + if record.count >= self.config.max_attempts { + drop(failures); // Release lock before acquiring ban lock + self.ban(ip).await; + return true; + } + + false + } + + /// Record a successful authentication. + /// + /// Clears the failure record for the IP, allowing a fresh start. + pub async fn record_success(&self, ip: &IpAddr) { + let mut failures = self.failures.write().await; + failures.remove(ip); + } + + /// Ban an IP address. + /// + /// The IP will be banned for the configured ban duration. + /// Also clears the failure record for the IP. + pub async fn ban(&self, ip: IpAddr) { + tracing::warn!( + ip = %ip, + duration_secs = self.config.ban_duration.as_secs(), + "Banning IP due to too many failed auth attempts" + ); + + let mut bans = self.bans.write().await; + let expiry = Instant::now() + self.config.ban_duration; + bans.insert(ip, expiry); + + // Clean up failure record + drop(bans); + let mut failures = self.failures.write().await; + failures.remove(&ip); + } + + /// Manually unban an IP address. + pub async fn unban(&self, ip: &IpAddr) { + let mut bans = self.bans.write().await; + if bans.remove(ip).is_some() { + tracing::info!(ip = %ip, "Manually unbanned IP"); + } + } + + /// Get the remaining attempts before ban for an IP. + /// + /// Returns the maximum attempts if the IP has no failure record. + pub async fn remaining_attempts(&self, ip: &IpAddr) -> u32 { + let failures = self.failures.read().await; + if let Some(record) = failures.get(ip) { + let now = Instant::now(); + // If window expired, return max attempts + if now.duration_since(record.first_failure) > self.config.window { + return self.config.max_attempts; + } + self.config.max_attempts.saturating_sub(record.count) + } else { + self.config.max_attempts + } + } + + /// Clean up expired records. + /// + /// This should be called periodically to prevent unbounded memory growth. + /// It removes: + /// - Expired bans + /// - Failure records outside the time window + pub async fn cleanup(&self) { + let now = Instant::now(); + + // Clean expired bans + { + let mut bans = self.bans.write().await; + let before = bans.len(); + bans.retain(|_, expiry| now < *expiry); + let after = bans.len(); + if before > after { + tracing::debug!( + removed = before - after, + remaining = after, + "Cleaned up expired bans" + ); + } + } + + // Clean old failure records + { + let mut failures = self.failures.write().await; + let before = failures.len(); + failures.retain(|_, record| { + now.duration_since(record.last_failure) < self.config.window + }); + let after = failures.len(); + if before > after { + tracing::debug!( + removed = before - after, + remaining = after, + "Cleaned up expired failure records" + ); + } + } + } + + /// Get the current list of banned IPs with remaining ban duration. + pub async fn get_bans(&self) -> Vec<(IpAddr, Duration)> { + let now = Instant::now(); + let bans = self.bans.read().await; + bans.iter() + .filter_map(|(ip, expiry)| { + if now < *expiry { + Some((*ip, *expiry - now)) + } else { + None + } + }) + .collect() + } + + /// Get the number of currently banned IPs. + pub async fn banned_count(&self) -> usize { + let now = Instant::now(); + let bans = self.bans.read().await; + bans.values().filter(|expiry| now < **expiry).count() + } + + /// Get the number of IPs with failure records. + pub async fn tracked_count(&self) -> usize { + self.failures.read().await.len() + } + + /// Get the configuration. + pub fn config(&self) -> &AuthRateLimitConfig { + &self.config + } + + /// Check if an IP is whitelisted. + pub fn is_whitelisted(&self, ip: &IpAddr) -> bool { + self.config.whitelist.contains(ip) + } +} + +impl Clone for AuthRateLimiter { + fn clone(&self) -> Self { + Self { + failures: Arc::clone(&self.failures), + bans: Arc::clone(&self.bans), + config: self.config.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::Ipv4Addr; + + fn test_ip() -> IpAddr { + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)) + } + + fn test_ip2() -> IpAddr { + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 101)) + } + + fn localhost() -> IpAddr { + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)) + } + + #[tokio::test] + async fn test_failure_counting() { + let config = AuthRateLimitConfig::new(5, 300, 300); + let limiter = AuthRateLimiter::new(config); + + let ip = test_ip(); + + // Record failures without triggering ban + for i in 1..5 { + let banned = limiter.record_failure(ip).await; + assert!(!banned, "Should not be banned after {i} failures"); + assert_eq!( + limiter.remaining_attempts(&ip).await, + 5 - i, + "Should have {} remaining attempts", + 5 - i + ); + } + } + + #[tokio::test] + async fn test_ban_after_max_attempts() { + let config = AuthRateLimitConfig::new(3, 300, 300); + let limiter = AuthRateLimiter::new(config); + + let ip = test_ip(); + + // First two failures + assert!(!limiter.record_failure(ip).await); + assert!(!limiter.record_failure(ip).await); + assert!(!limiter.is_banned(&ip).await); + + // Third failure should trigger ban + assert!(limiter.record_failure(ip).await); + assert!(limiter.is_banned(&ip).await); + } + + #[tokio::test] + async fn test_ban_expiration() { + let config = AuthRateLimitConfig::new(2, 300, 0); // 0 second ban + let limiter = AuthRateLimiter::new(config); + + let ip = test_ip(); + + // Trigger ban + limiter.record_failure(ip).await; + assert!(limiter.record_failure(ip).await); + + // Ban should have expired immediately (or very quickly) + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(!limiter.is_banned(&ip).await); + } + + #[tokio::test] + async fn test_whitelist_ips() { + let config = AuthRateLimitConfig::new(1, 300, 300) + .with_whitelist(vec![localhost()]); + let limiter = AuthRateLimiter::new(config); + + let whitelisted = localhost(); + let not_whitelisted = test_ip(); + + // Whitelisted IP should never be banned + assert!(!limiter.record_failure(whitelisted).await); + assert!(!limiter.is_banned(&whitelisted).await); + + // Non-whitelisted should be banned after 1 failure + assert!(limiter.record_failure(not_whitelisted).await); + assert!(limiter.is_banned(¬_whitelisted).await); + + assert!(limiter.is_whitelisted(&whitelisted)); + assert!(!limiter.is_whitelisted(¬_whitelisted)); + } + + #[tokio::test] + async fn test_success_resets_failures() { + let config = AuthRateLimitConfig::new(3, 300, 300); + let limiter = AuthRateLimiter::new(config); + + let ip = test_ip(); + + // Record 2 failures + limiter.record_failure(ip).await; + limiter.record_failure(ip).await; + assert_eq!(limiter.remaining_attempts(&ip).await, 1); + + // Successful auth resets failures + limiter.record_success(&ip).await; + assert_eq!(limiter.remaining_attempts(&ip).await, 3); + + // Should need 3 more failures to ban + limiter.record_failure(ip).await; + limiter.record_failure(ip).await; + assert!(!limiter.is_banned(&ip).await); + limiter.record_failure(ip).await; + assert!(limiter.is_banned(&ip).await); + } + + #[tokio::test] + async fn test_window_expiration() { + // Use a very short window for testing + let config = AuthRateLimitConfig { + max_attempts: 3, + window: Duration::from_millis(50), + ban_duration: Duration::from_secs(300), + whitelist: vec![], + }; + let limiter = AuthRateLimiter::new(config); + + let ip = test_ip(); + + // Record 2 failures + limiter.record_failure(ip).await; + limiter.record_failure(ip).await; + assert_eq!(limiter.remaining_attempts(&ip).await, 1); + + // Wait for window to expire + tokio::time::sleep(Duration::from_millis(60)).await; + + // Window expired, should have full attempts again + assert_eq!(limiter.remaining_attempts(&ip).await, 3); + + // New failure should start fresh count + assert!(!limiter.record_failure(ip).await); + } + + #[tokio::test] + async fn test_cleanup() { + let config = AuthRateLimitConfig { + max_attempts: 2, + window: Duration::from_millis(10), + ban_duration: Duration::from_millis(10), + whitelist: vec![], + }; + let limiter = AuthRateLimiter::new(config); + + let ip1 = test_ip(); + let ip2 = test_ip2(); + + // Create some records + limiter.record_failure(ip1).await; + limiter.record_failure(ip2).await; + limiter.record_failure(ip2).await; // This triggers ban + + assert_eq!(limiter.tracked_count().await, 1); // ip1 still tracked + assert_eq!(limiter.banned_count().await, 1); // ip2 banned + + // Wait for records to expire + tokio::time::sleep(Duration::from_millis(20)).await; + + // Cleanup + limiter.cleanup().await; + + assert_eq!(limiter.tracked_count().await, 0); + assert_eq!(limiter.banned_count().await, 0); + } + + #[tokio::test] + async fn test_manual_ban_unban() { + let config = AuthRateLimitConfig::new(5, 300, 300); + let limiter = AuthRateLimiter::new(config); + + let ip = test_ip(); + + // Manual ban + limiter.ban(ip).await; + assert!(limiter.is_banned(&ip).await); + + // Manual unban + limiter.unban(&ip).await; + assert!(!limiter.is_banned(&ip).await); + } + + #[tokio::test] + async fn test_get_bans() { + let config = AuthRateLimitConfig::new(1, 300, 300); + let limiter = AuthRateLimiter::new(config); + + let ip1 = test_ip(); + let ip2 = test_ip2(); + + // Ban two IPs + limiter.record_failure(ip1).await; + limiter.record_failure(ip2).await; + + let bans = limiter.get_bans().await; + assert_eq!(bans.len(), 2); + + // Check that both IPs are in the list + let ips: Vec = bans.iter().map(|(ip, _)| *ip).collect(); + assert!(ips.contains(&ip1)); + assert!(ips.contains(&ip2)); + + // Check that remaining durations are positive + for (_, duration) in &bans { + assert!(duration.as_secs() > 0); + } + } + + #[tokio::test] + async fn test_clone_shares_state() { + let config = AuthRateLimitConfig::new(3, 300, 300); + let limiter1 = AuthRateLimiter::new(config); + let limiter2 = limiter1.clone(); + + let ip = test_ip(); + + // Record failures on limiter1 + limiter1.record_failure(ip).await; + limiter1.record_failure(ip).await; + + // limiter2 should see the same state + assert_eq!(limiter2.remaining_attempts(&ip).await, 1); + + // Ban via limiter2 + limiter2.record_failure(ip).await; + + // limiter1 should see the ban + assert!(limiter1.is_banned(&ip).await); + } + + #[tokio::test] + async fn test_per_ip_isolation() { + let config = AuthRateLimitConfig::new(2, 300, 300); + let limiter = AuthRateLimiter::new(config); + + let ip1 = test_ip(); + let ip2 = test_ip2(); + + // Record failure for ip1 + limiter.record_failure(ip1).await; + + // ip2 should be unaffected + assert_eq!(limiter.remaining_attempts(&ip2).await, 2); + assert!(!limiter.is_banned(&ip2).await); + + // Ban ip1 + limiter.record_failure(ip1).await; + assert!(limiter.is_banned(&ip1).await); + + // ip2 still unaffected + assert!(!limiter.is_banned(&ip2).await); + } + + #[tokio::test] + async fn test_config_accessors() { + let config = AuthRateLimitConfig::new(10, 600, 1800); + let limiter = AuthRateLimiter::new(config); + + assert_eq!(limiter.config().max_attempts, 10); + assert_eq!(limiter.config().window.as_secs(), 600); + assert_eq!(limiter.config().ban_duration.as_secs(), 1800); + } +} From 0dee594e68e5c55b4a3ca5d4339dea9114d932fb Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Sat, 24 Jan 2026 12:35:37 +0900 Subject: [PATCH 2/3] fix: Address PR review feedback for auth rate limiting - Use configuration values instead of hardcoded values for auth_window and ban_time - Integrate whitelist_ips from configuration with validation and logging - Fix TOCTOU race condition in record_failure by removing entry atomically - Add capacity limit (max_tracked_ips) to prevent memory exhaustion DoS - Use HashSet for whitelist O(1) lookups instead of Vec O(n) - Add auth rate limit config fields to ServerConfig - Propagate security config from ServerFileConfig to ServerConfig - Add test for capacity limit enforcement --- src/server/config/mod.rs | 30 +++++++ src/server/mod.rs | 31 ++++++- src/server/security/rate_limit.rs | 143 ++++++++++++++++++++++-------- 3 files changed, 165 insertions(+), 39 deletions(-) diff --git a/src/server/config/mod.rs b/src/server/config/mod.rs index e5d36d15..12c40dc9 100644 --- a/src/server/config/mod.rs +++ b/src/server/config/mod.rs @@ -149,6 +149,22 @@ pub struct ServerConfig { /// Configuration for command execution. #[serde(default)] pub exec: ExecConfig, + + /// Time window for counting authentication attempts in seconds. + /// + /// Default: 300 (5 minutes) + #[serde(default = "default_auth_window_secs")] + pub auth_window_secs: u64, + + /// Ban duration in seconds after exceeding max auth attempts. + /// + /// Default: 300 (5 minutes) + #[serde(default = "default_ban_time_secs")] + pub ban_time_secs: u64, + + /// IP addresses that are never banned (whitelist). + #[serde(default)] + pub whitelist_ips: Vec, } /// Serializable configuration for public key authentication. @@ -213,6 +229,14 @@ fn default_idle_timeout_secs() -> u64 { 0 // 0 means no timeout } +fn default_auth_window_secs() -> u64 { + 300 // 5 minutes +} + +fn default_ban_time_secs() -> u64 { + 300 // 5 minutes +} + fn default_true() -> bool { true } @@ -233,6 +257,9 @@ impl Default for ServerConfig { publickey_auth: PublicKeyAuthConfigSerde::default(), password_auth: PasswordAuthConfigSerde::default(), exec: ExecConfig::default(), + auth_window_secs: default_auth_window_secs(), + ban_time_secs: default_ban_time_secs(), + whitelist_ips: Vec::new(), } } } @@ -521,6 +548,9 @@ impl ServerFileConfig { allowed_commands: None, blocked_commands: Vec::new(), }, + auth_window_secs: self.security.auth_window, + ban_time_secs: self.security.ban_time, + whitelist_ips: self.security.whitelist_ips, } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 562e77d7..e9b413e5 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -217,11 +217,34 @@ impl BsshServer { let rate_limiter = RateLimiter::with_simple_config(100, 10.0); // Create auth rate limiter with configuration - let auth_rate_limiter = AuthRateLimiter::new(AuthRateLimitConfig::new( + // Parse whitelist IPs from config + let whitelist_ips: Vec = self + .config + .whitelist_ips + .iter() + .filter_map(|s| { + s.parse().map_err(|e| { + tracing::warn!(ip = %s, error = %e, "Invalid whitelist IP address in config, skipping"); + e + }).ok() + }) + .collect(); + + let auth_config = AuthRateLimitConfig::new( self.config.max_auth_attempts, - 300, // Default 5 minute window - 300, // Default 5 minute ban - )); + self.config.auth_window_secs, + self.config.ban_time_secs, + ).with_whitelist(whitelist_ips); + + let auth_rate_limiter = AuthRateLimiter::new(auth_config); + + tracing::info!( + max_attempts = self.config.max_auth_attempts, + auth_window_secs = self.config.auth_window_secs, + ban_time_secs = self.config.ban_time_secs, + whitelist_count = self.config.whitelist_ips.len(), + "Auth rate limiter configured" + ); // Start background cleanup task for auth rate limiter let cleanup_limiter = auth_rate_limiter.clone(); diff --git a/src/server/security/rate_limit.rs b/src/server/security/rate_limit.rs index e16701ba..a68787d8 100644 --- a/src/server/security/rate_limit.rs +++ b/src/server/security/rate_limit.rs @@ -18,7 +18,7 @@ //! against brute-force attacks. It tracks failed authentication attempts per IP //! and automatically bans IPs that exceed the configured threshold. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -31,6 +31,7 @@ use tokio::sync::RwLock; /// - `window`: Time window for counting attempts /// - `ban_duration`: How long to ban an IP /// - `whitelist`: IPs that are never banned +/// - `max_tracked_ips`: Maximum IPs to track (prevents memory exhaustion) #[derive(Debug, Clone)] pub struct AuthRateLimitConfig { /// Maximum failed attempts before ban. @@ -39,8 +40,11 @@ pub struct AuthRateLimitConfig { pub window: Duration, /// Ban duration after exceeding max attempts. pub ban_duration: Duration, - /// Whitelist IPs (never banned). - pub whitelist: Vec, + /// Whitelist IPs (never banned). Uses HashSet for O(1) lookups. + pub whitelist: HashSet, + /// Maximum number of IPs to track (prevents memory exhaustion). + /// When exceeded, oldest entries are removed. + pub max_tracked_ips: usize, } impl Default for AuthRateLimitConfig { @@ -49,7 +53,8 @@ impl Default for AuthRateLimitConfig { max_attempts: 5, window: Duration::from_secs(300), // 5 minutes ban_duration: Duration::from_secs(300), // 5 minutes - whitelist: vec![], + whitelist: HashSet::new(), + max_tracked_ips: 10000, // Limit memory usage } } } @@ -67,20 +72,25 @@ impl AuthRateLimitConfig { max_attempts, window: Duration::from_secs(window_secs), ban_duration: Duration::from_secs(ban_duration_secs), - whitelist: vec![], + whitelist: HashSet::new(), + max_tracked_ips: 10000, } } /// Add an IP to the whitelist. pub fn add_whitelist(&mut self, ip: IpAddr) { - if !self.whitelist.contains(&ip) { - self.whitelist.push(ip); - } + self.whitelist.insert(ip); } /// Set the whitelist from a list of IPs. pub fn with_whitelist(mut self, whitelist: Vec) -> Self { - self.whitelist = whitelist; + self.whitelist = whitelist.into_iter().collect(); + self + } + + /// Set the maximum number of IPs to track. + pub fn with_max_tracked_ips(mut self, max: usize) -> Self { + self.max_tracked_ips = max; self } } @@ -168,28 +178,57 @@ impl AuthRateLimiter { return false; } - let mut failures = self.failures.write().await; - let now = Instant::now(); + let should_ban; + { + let mut failures = self.failures.write().await; + let now = Instant::now(); - let record = failures.entry(ip).or_insert_with(|| FailureRecord { - count: 0, - first_failure: now, - last_failure: now, - }); + // Enforce capacity limit to prevent memory exhaustion + // If at capacity and this is a new IP, remove oldest entry + if failures.len() >= self.config.max_tracked_ips && !failures.contains_key(&ip) { + // Find and remove the oldest entry by last_failure time + if let Some(oldest_ip) = failures + .iter() + .min_by_key(|(_, record)| record.last_failure) + .map(|(ip, _)| *ip) + { + failures.remove(&oldest_ip); + tracing::debug!( + removed_ip = %oldest_ip, + capacity = self.config.max_tracked_ips, + "Removed oldest failure record due to capacity limit" + ); + } + } - // Reset if window expired - if now.duration_since(record.first_failure) > self.config.window { - record.count = 1; - record.first_failure = now; - } else { - record.count += 1; - } - record.last_failure = now; + let record = failures.entry(ip).or_insert_with(|| FailureRecord { + count: 0, + first_failure: now, + last_failure: now, + }); + + // Reset if window expired + if now.duration_since(record.first_failure) > self.config.window { + record.count = 1; + record.first_failure = now; + } else { + record.count += 1; + } + record.last_failure = now; + + // Check if should ban - record the decision while holding the lock + should_ban = record.count >= self.config.max_attempts; + + // If banning, remove from failures while we still hold the lock + // This prevents race conditions with concurrent record_failure calls + if should_ban { + failures.remove(&ip); + } + } // failures lock released here - // Check if should ban - if record.count >= self.config.max_attempts { - drop(failures); // Release lock before acquiring ban lock - self.ban(ip).await; + // Now apply the ban if needed + if should_ban { + self.ban_internal(ip).await; return true; } @@ -209,6 +248,18 @@ impl AuthRateLimiter { /// The IP will be banned for the configured ban duration. /// Also clears the failure record for the IP. pub async fn ban(&self, ip: IpAddr) { + // Clean up failure record first + { + let mut failures = self.failures.write().await; + failures.remove(&ip); + } + + self.ban_internal(ip).await; + } + + /// Internal method to apply a ban without modifying failure records. + /// Used by record_failure which has already cleaned up the failure record. + async fn ban_internal(&self, ip: IpAddr) { tracing::warn!( ip = %ip, duration_secs = self.config.ban_duration.as_secs(), @@ -218,11 +269,6 @@ impl AuthRateLimiter { let mut bans = self.bans.write().await; let expiry = Instant::now() + self.config.ban_duration; bans.insert(ip, expiry); - - // Clean up failure record - drop(bans); - let mut failures = self.failures.write().await; - failures.remove(&ip); } /// Manually unban an IP address. @@ -462,7 +508,8 @@ mod tests { max_attempts: 3, window: Duration::from_millis(50), ban_duration: Duration::from_secs(300), - whitelist: vec![], + whitelist: HashSet::new(), + max_tracked_ips: 10000, }; let limiter = AuthRateLimiter::new(config); @@ -489,7 +536,8 @@ mod tests { max_attempts: 2, window: Duration::from_millis(10), ban_duration: Duration::from_millis(10), - whitelist: vec![], + whitelist: HashSet::new(), + max_tracked_ips: 10000, }; let limiter = AuthRateLimiter::new(config); @@ -610,4 +658,29 @@ mod tests { assert_eq!(limiter.config().window.as_secs(), 600); assert_eq!(limiter.config().ban_duration.as_secs(), 1800); } + + #[tokio::test] + async fn test_capacity_limit() { + // Test that capacity limit prevents unbounded memory growth + let config = AuthRateLimitConfig::new(5, 300, 300).with_max_tracked_ips(3); + let limiter = AuthRateLimiter::new(config); + + let ip1: IpAddr = "192.168.1.1".parse().unwrap(); + let ip2: IpAddr = "192.168.1.2".parse().unwrap(); + let ip3: IpAddr = "192.168.1.3".parse().unwrap(); + let ip4: IpAddr = "192.168.1.4".parse().unwrap(); + + // Record failures for first 3 IPs + limiter.record_failure(ip1).await; + limiter.record_failure(ip2).await; + limiter.record_failure(ip3).await; + assert_eq!(limiter.tracked_count().await, 3); + + // Recording for 4th IP should evict the oldest + limiter.record_failure(ip4).await; + assert_eq!(limiter.tracked_count().await, 3); + + // ip4 should be tracked, ip1 should be evicted (it was oldest) + assert_eq!(limiter.remaining_attempts(&ip4).await, 4); + } } From 082ffc649b7802cc21187b46f02bac738a65e373 Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Sat, 24 Jan 2026 12:38:24 +0900 Subject: [PATCH 3/3] chore: finalize auth rate limiter with docs and formatting fixes - Fix code formatting (cargo fmt) - Update ARCHITECTURE.md with Server Security Module documentation - Update server-configuration.md with auth_window and whitelist_ips options - All 930 tests passing, clippy clean --- ARCHITECTURE.md | 17 ++++++++++++++++- docs/architecture/server-configuration.md | 10 ++++++++++ src/server/mod.rs | 3 ++- src/server/security/rate_limit.rs | 12 +++++------- 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 6e61fd5b..dd9bf342 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -189,6 +189,20 @@ Common utilities for code reuse between bssh client and server implementations: The `security` and `jump::rate_limiter` modules re-export from shared for backward compatibility. +### Server Security Module + +Security features for the SSH server (`src/server/security/`): + +- **AuthRateLimiter**: Fail2ban-like authentication rate limiting + - Tracks failed authentication attempts per IP address + - Automatic banning after exceeding configurable threshold + - Time-windowed failure counting (failures outside window not counted) + - Configurable ban duration with automatic expiration + - IP whitelist for exempting trusted addresses from banning + - Memory-safe with configurable maximum tracked IPs + - Automatic cleanup of expired records via background task + - Thread-safe async implementation with `Arc>` + ### Server CLI Binary **Binary**: `bssh-server` @@ -284,7 +298,8 @@ SSH server implementation using the russh library for accepting incoming connect - **SshHandler**: Per-connection handler for SSH protocol events - Public key authentication via AuthProvider trait - - Rate limiting for authentication attempts + - Rate limiting for authentication attempts (token bucket) + - Auth rate limiting with ban support (fail2ban-like) - Channel operations (open, close, EOF, data) - PTY, exec, shell, and subsystem request handling - Command execution with stdout/stderr streaming diff --git a/docs/architecture/server-configuration.md b/docs/architecture/server-configuration.md index 2b466e96..9b27cbe8 100644 --- a/docs/architecture/server-configuration.md +++ b/docs/architecture/server-configuration.md @@ -173,9 +173,19 @@ security: # Max auth attempts before banning IP max_auth_attempts: 5 # Default: 5 + # Time window for counting auth attempts (seconds) + # Failed attempts outside this window are not counted + auth_window: 300 # Default: 300 (5 minutes) + # Ban duration after exceeding max attempts (seconds) ban_time: 300 # Default: 300 (5 minutes) + # IPs that are never banned (whitelist) + # These IPs are exempt from rate limiting and banning + whitelist_ips: + - "127.0.0.1" + - "::1" + # Max concurrent sessions per user max_sessions_per_user: 10 # Default: 10 diff --git a/src/server/mod.rs b/src/server/mod.rs index e9b413e5..b6a6fd9d 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -234,7 +234,8 @@ impl BsshServer { self.config.max_auth_attempts, self.config.auth_window_secs, self.config.ban_time_secs, - ).with_whitelist(whitelist_ips); + ) + .with_whitelist(whitelist_ips); let auth_rate_limiter = AuthRateLimiter::new(auth_config); diff --git a/src/server/security/rate_limit.rs b/src/server/security/rate_limit.rs index a68787d8..aea6bfff 100644 --- a/src/server/security/rate_limit.rs +++ b/src/server/security/rate_limit.rs @@ -51,7 +51,7 @@ impl Default for AuthRateLimitConfig { fn default() -> Self { Self { max_attempts: 5, - window: Duration::from_secs(300), // 5 minutes + window: Duration::from_secs(300), // 5 minutes ban_duration: Duration::from_secs(300), // 5 minutes whitelist: HashSet::new(), max_tracked_ips: 10000, // Limit memory usage @@ -324,9 +324,8 @@ impl AuthRateLimiter { { let mut failures = self.failures.write().await; let before = failures.len(); - failures.retain(|_, record| { - now.duration_since(record.last_failure) < self.config.window - }); + failures + .retain(|_, record| now.duration_since(record.last_failure) < self.config.window); let after = failures.len(); if before > after { tracing::debug!( @@ -458,8 +457,7 @@ mod tests { #[tokio::test] async fn test_whitelist_ips() { - let config = AuthRateLimitConfig::new(1, 300, 300) - .with_whitelist(vec![localhost()]); + let config = AuthRateLimitConfig::new(1, 300, 300).with_whitelist(vec![localhost()]); let limiter = AuthRateLimiter::new(config); let whitelisted = localhost(); @@ -550,7 +548,7 @@ mod tests { limiter.record_failure(ip2).await; // This triggers ban assert_eq!(limiter.tracked_count().await, 1); // ip1 still tracked - assert_eq!(limiter.banned_count().await, 1); // ip2 banned + assert_eq!(limiter.banned_count().await, 1); // ip2 banned // Wait for records to expire tokio::time::sleep(Duration::from_millis(20)).await;