Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,17 @@ MPI-compatible exit code handling:
- Automatic main rank detection (Backend.AI integration)
- Preserves actual exit codes (SIGSEGV=139, OOM=137, etc.)

### Shared Module

Common utilities for code reuse between bssh client and potential server implementations:

- **Validation**: Input validation for usernames, hostnames, paths with security checks
- **Rate Limiting**: Generic token bucket rate limiter for connection/auth throttling
- **Authentication Types**: Common auth result types and user info structures
- **Error Types**: Shared error types for validation, auth, connection, and rate limiting

The `security` and `jump::rate_limiter` modules re-export from shared for backward compatibility.

## Data Flow

### Command Execution Flow
Expand Down
2 changes: 2 additions & 0 deletions docs/architecture/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ src/
├── interactive/ → Interactive Mode
├── jump/ → Jump Host Support
├── forward/ → Port Forwarding
├── shared/ → Shared utilities (validation, rate limiting, auth types, errors)
├── security/ → Security utilities (re-exports from shared for compatibility)
└── commands/ → Command Implementations
```

Expand Down
4 changes: 2 additions & 2 deletions src/jump/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl JumpHostChain {
/// * `max_burst` - Maximum number of connections allowed in a burst
/// * `refill_rate` - Number of connections allowed per second (sustained rate)
pub fn with_rate_limit(mut self, max_burst: u32, refill_rate: f64) -> Self {
self.rate_limiter = ConnectionRateLimiter::with_config(max_burst, refill_rate);
self.rate_limiter = ConnectionRateLimiter::with_simple_config(max_burst, refill_rate);
self
}

Expand Down Expand Up @@ -354,7 +354,7 @@ impl JumpHostChain {

// Apply rate limiting to prevent DoS attacks on jump hosts
self.rate_limiter
.try_acquire(&jump_host.host)
.try_acquire(&jump_host.host.clone())
.await
.with_context(|| format!("Rate limited for jump host {}", jump_host.host))?;

Expand Down
2 changes: 1 addition & 1 deletion src/jump/chain/chain_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub(super) async fn connect_direct(

// Apply rate limiting to prevent DoS attacks
rate_limiter
.try_acquire(host)
.try_acquire(&host.to_string())
.await
.with_context(|| format!("Rate limited for host {host}"))?;

Expand Down
4 changes: 2 additions & 2 deletions src/jump/chain/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub(super) async fn connect_through_tunnel(

// Apply rate limiting for intermediate jump hosts
rate_limiter
.try_acquire(&jump_host.host)
.try_acquire(&jump_host.host.clone())
.await
.with_context(|| format!("Rate limited for jump host {}", jump_host.host))?;

Expand Down Expand Up @@ -184,7 +184,7 @@ pub(super) async fn connect_to_destination(

// Apply rate limiting for final destination
rate_limiter
.try_acquire(destination_host)
.try_acquire(&destination_host.to_string())
.await
.with_context(|| format!("Rate limited for destination {destination_host}"))?;

Expand Down
9 changes: 9 additions & 0 deletions src/jump/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@
//! * Connection reuse for multiple operations
//! * Automatic retry with exponential backoff
//! * Integration with existing host verification and authentication
//!
//! # Rate Limiter
//!
//! The rate limiter has been moved to `crate::shared::rate_limit` to enable
//! code reuse between the bssh client and server implementations. This module
//! continues to re-export it for backward compatibility.

pub mod chain;
pub mod connection;
pub mod parser;

// Keep the rate_limiter module for backward compatibility but it now re-exports
// from shared
pub mod rate_limiter;

pub use chain::{JumpConnection, JumpHostChain};
Expand Down
194 changes: 43 additions & 151 deletions src/jump/rate_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,186 +12,78 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::{bail, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::warn;

/// Token bucket rate limiter for connection attempts
///
/// Prevents DoS attacks by limiting the rate of connection attempts
/// per host. Uses a token bucket algorithm with configurable capacity
/// and refill rate.
#[derive(Debug, Clone)]
pub struct ConnectionRateLimiter {
/// Token buckets per host
buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
/// Maximum tokens per bucket (burst capacity)
max_tokens: u32,
/// Tokens refilled per second
refill_rate: f64,
/// Duration after which inactive buckets are cleaned up
cleanup_after: Duration,
}

#[derive(Debug)]
struct TokenBucket {
/// Current token count
tokens: f64,
/// Last refill timestamp
last_refill: Instant,
/// Last access timestamp (for cleanup)
last_access: Instant,
}

impl ConnectionRateLimiter {
/// Create a new rate limiter with default settings
///
/// Default: 10 connections burst, 2 connections/second sustained
pub fn new() -> Self {
Self {
buckets: Arc::new(RwLock::new(HashMap::new())),
max_tokens: 10, // Allow burst of 10 connections
refill_rate: 2.0, // 2 connections per second sustained
cleanup_after: Duration::from_secs(300), // Clean up after 5 minutes
}
}

/// Create a new rate limiter with custom settings
pub fn with_config(max_tokens: u32, refill_rate: f64) -> Self {
Self {
buckets: Arc::new(RwLock::new(HashMap::new())),
max_tokens,
refill_rate,
cleanup_after: Duration::from_secs(300),
}
}

/// Try to acquire a token for a connection attempt
///
/// Returns Ok(()) if a token was acquired, or an error if rate limited
pub async fn try_acquire(&self, host: &str) -> Result<()> {
let mut buckets = self.buckets.write().await;
let now = Instant::now();

// Clean up old buckets periodically
if buckets.len() > 100 {
self.cleanup_old_buckets(&mut buckets, now);
}

let bucket = buckets
.entry(host.to_string())
.or_insert_with(|| TokenBucket {
tokens: self.max_tokens as f64,
last_refill: now,
last_access: now,
});

// Refill tokens based on time elapsed
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
let tokens_to_add = elapsed * self.refill_rate;
bucket.tokens = (bucket.tokens + tokens_to_add).min(self.max_tokens as f64);
bucket.last_refill = now;
bucket.last_access = now;

// Try to consume a token
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
Ok(())
} else {
let wait_time = (1.0 - bucket.tokens) / self.refill_rate;
warn!(
"Rate limit exceeded for host {}: wait {:.1}s before retry",
host, wait_time
);
bail!(
"Connection rate limit exceeded for {host}. Please wait {wait_time:.1} seconds before retrying."
)
}
}

/// Check if a host is currently rate limited without consuming a token
pub async fn is_rate_limited(&self, host: &str) -> bool {
let buckets = self.buckets.read().await;
if let Some(bucket) = buckets.get(host) {
let now = Instant::now();
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
let tokens_available =
(bucket.tokens + elapsed * self.refill_rate).min(self.max_tokens as f64);
tokens_available < 1.0
} else {
false
}
}

/// Clean up old token buckets that haven't been used recently
fn cleanup_old_buckets(&self, buckets: &mut HashMap<String, TokenBucket>, now: Instant) {
buckets.retain(|_host, bucket| now.duration_since(bucket.last_access) < self.cleanup_after);
}

/// Reset rate limit for a specific host (useful for testing or admin override)
pub async fn reset_host(&self, host: &str) {
let mut buckets = self.buckets.write().await;
buckets.remove(host);
}

/// Clear all rate limit data
pub async fn clear_all(&self) {
let mut buckets = self.buckets.write().await;
buckets.clear();
}
}

impl Default for ConnectionRateLimiter {
fn default() -> Self {
Self::new()
}
}
//! Token bucket rate limiter for connection attempts.
//!
//! This module re-exports the rate limiter from the shared module for
//! backward compatibility. New code should prefer importing directly from
//! `crate::shared::rate_limit`.
//!
//! # Migration Note
//!
//! The rate limiter has been moved to `crate::shared::rate_limit` and
//! generalized to work with any hashable key type. This module continues
//! to export `ConnectionRateLimiter` (which is `RateLimiter<String>`) for
//! backward compatibility.
//!
//! # Examples
//!
//! ```rust
//! // Old style (still works)
//! use bssh::jump::rate_limiter::ConnectionRateLimiter;
//!
//! // New style (preferred for new code)
//! use bssh::shared::rate_limit::{RateLimiter, RateLimitConfig};
//! ```

// Re-export the ConnectionRateLimiter type alias for backward compatibility
pub use crate::shared::rate_limit::ConnectionRateLimiter;

// Also re-export RateLimitConfig for users who want to configure the limiter
pub use crate::shared::rate_limit::RateLimitConfig;

#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;

#[tokio::test]
async fn test_rate_limiter_allows_burst() {
let limiter = ConnectionRateLimiter::with_config(3, 1.0);
let limiter = ConnectionRateLimiter::with_simple_config(3, 1.0);

// Should allow 3 connections in burst
assert!(limiter.try_acquire("test.com").await.is_ok());
assert!(limiter.try_acquire("test.com").await.is_ok());
assert!(limiter.try_acquire("test.com").await.is_ok());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok());

// 4th should fail
assert!(limiter.try_acquire("test.com").await.is_err());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_err());
}

#[tokio::test]
async fn test_rate_limiter_refills() {
let limiter = ConnectionRateLimiter::with_config(2, 10.0); // Fast refill for testing
let limiter = ConnectionRateLimiter::with_simple_config(2, 10.0); // Fast refill for testing

// Use up tokens
assert!(limiter.try_acquire("test.com").await.is_ok());
assert!(limiter.try_acquire("test.com").await.is_ok());
assert!(limiter.try_acquire("test.com").await.is_err());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_err());

// Wait for refill
tokio::time::sleep(Duration::from_millis(150)).await;

// Should have refilled
assert!(limiter.try_acquire("test.com").await.is_ok());
assert!(limiter.try_acquire(&"test.com".to_string()).await.is_ok());
}

#[tokio::test]
async fn test_rate_limiter_per_host() {
let limiter = ConnectionRateLimiter::with_config(1, 1.0);
let limiter = ConnectionRateLimiter::with_simple_config(1, 1.0);

// Different hosts should have separate buckets
assert!(limiter.try_acquire("host1.com").await.is_ok());
assert!(limiter.try_acquire("host2.com").await.is_ok());
assert!(limiter.try_acquire(&"host1.com".to_string()).await.is_ok());
assert!(limiter.try_acquire(&"host2.com".to_string()).await.is_ok());

// But same host should be limited
assert!(limiter.try_acquire("host1.com").await.is_err());
assert!(limiter.try_acquire(&"host1.com".to_string()).await.is_err());
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub mod jump;
pub mod node;
pub mod pty;
pub mod security;
pub mod shared;
pub mod ssh;
pub mod ui;
pub mod utils;
Expand Down
15 changes: 12 additions & 3 deletions src/security/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@

//! Security utilities for validating and sanitizing user input and handling
//! sensitive data securely.
//!
//! # Re-exports
//!
//! This module re-exports validation functions from the shared module for
//! backward compatibility. New code should prefer importing directly from
//! `crate::shared::validation`.

mod sudo;
mod validation;

// Re-export validation functions
pub use validation::{
// Keep the validation module for backward compatibility but it now re-exports
// from shared
pub mod validation;

// Re-export validation functions from shared module for backward compatibility
pub use crate::shared::validation::{
sanitize_error_message, validate_hostname, validate_local_path, validate_remote_path,
validate_username,
};
Expand Down
Loading