diff --git a/src/cortex-cli/src/utils/paths.rs b/src/cortex-cli/src/utils/paths.rs index 8cdf03e..c9654ff 100644 --- a/src/cortex-cli/src/utils/paths.rs +++ b/src/cortex-cli/src/utils/paths.rs @@ -34,10 +34,16 @@ pub fn get_cortex_home() -> PathBuf { /// // Returns: /home/user/documents/file.txt /// ``` pub fn expand_tilde(path: &str) -> String { - if path.starts_with("~/") - && let Some(home) = dirs::home_dir() - { - return home.join(&path[2..]).to_string_lossy().to_string(); + if path == "~" { + // Handle bare "~" - return home directory + if let Some(home) = dirs::home_dir() { + return home.to_string_lossy().to_string(); + } + } else if let Some(suffix) = path.strip_prefix("~/") { + // Handle "~/" prefix - expand to home directory + rest of path + if let Some(home) = dirs::home_dir() { + return home.join(suffix).to_string_lossy().to_string(); + } } path.to_string() } @@ -58,8 +64,12 @@ pub fn expand_tilde(path: &str) -> String { pub fn validate_path_safety(path: &Path, base_dir: Option<&Path>) -> Result<(), String> { let path_str = path.to_string_lossy(); - // Check for path traversal attempts - if path_str.contains("..") { + // Check for path traversal attempts by examining path components + // This correctly handles filenames containing ".." like "file..txt" + if path + .components() + .any(|c| matches!(c, std::path::Component::ParentDir)) + { return Err("Path contains traversal sequence '..'".to_string()); } @@ -257,8 +267,15 @@ mod tests { #[test] fn test_expand_tilde_with_tilde_only() { - // Test tilde alone - should remain unchanged (not "~/") - assert_eq!(expand_tilde("~"), "~"); + // Test bare "~" - should expand to home directory + let result = expand_tilde("~"); + if let Some(home) = dirs::home_dir() { + let expected = home.to_string_lossy().to_string(); + assert_eq!(result, expected); + } else { + // If no home dir, original is returned + assert_eq!(result, "~"); + } } #[test] @@ -320,20 +337,30 @@ mod tests { #[test] fn test_validate_path_safety_detects_various_traversal_patterns() { - // Different traversal patterns - let patterns = ["foo/../bar", "...", "foo/bar/../baz", "./foo/../../../etc"]; + // Patterns that ARE path traversal (contain ".." as a component) + let traversal_patterns = ["foo/../bar", "foo/bar/../baz", "./foo/../../../etc", ".."]; - for pattern in patterns { + for pattern in traversal_patterns { let path = Path::new(pattern); let result = validate_path_safety(path, None); - // Only patterns containing ".." should fail - if pattern.contains("..") { - assert!( - result.is_err(), - "Expected traversal detection for: {}", - pattern - ); - } + assert!( + result.is_err(), + "Expected traversal detection for: {}", + pattern + ); + } + + // Patterns that are NOT path traversal (contain ".." in filenames only) + let safe_patterns = ["file..txt", "..hidden", "test...file", "foo/bar..baz/file"]; + + for pattern in safe_patterns { + let path = Path::new(pattern); + let result = validate_path_safety(path, None); + assert!( + result.is_ok(), + "False positive: '{}' should not be detected as traversal", + pattern + ); } } diff --git a/src/cortex-common/src/file_locking.rs b/src/cortex-common/src/file_locking.rs index f9b78db..d2b4f73 100644 --- a/src/cortex-common/src/file_locking.rs +++ b/src/cortex-common/src/file_locking.rs @@ -557,6 +557,9 @@ pub async fn atomic_write_async( .map_err(|e| FileLockError::AtomicWriteFailed(format!("spawn_blocking failed: {}", e)))? } +/// Maximum number of lock entries before triggering cleanup. 
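+/// When `get_lock` is called with at least this many entries present, entries
+/// whose mutex is no longer referenced outside the map are dropped first
+/// (see `cleanup_stale_entries`).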
+const MAX_LOCK_ENTRIES: usize = 10_000;
+
 /// A file lock manager for coordinating access across multiple operations.
 ///
 /// This is useful when you need to perform multiple operations on a file
@@ -577,15 +580,47 @@ impl FileLockManager {
     ///
     /// This is in addition to the filesystem-level advisory lock and helps
     /// coordinate access within the same process.
+    ///
+    /// Automatically cleans up stale lock entries when the map grows too large.
     pub fn get_lock(&self, path: impl AsRef<Path>) -> Arc<std::sync::Mutex<()>> {
         let path = path.as_ref().to_path_buf();
         let mut locks = self.locks.lock().unwrap();
+
+        // Clean up stale entries if the map is getting large
+        if locks.len() >= MAX_LOCK_ENTRIES {
+            Self::cleanup_stale_entries(&mut locks);
+        }
+
         locks
             .entry(path)
             .or_insert_with(|| Arc::new(std::sync::Mutex::new(())))
             .clone()
     }
 
+    /// Remove lock entries that are no longer in use.
+    ///
+    /// An entry is considered stale when only the HashMap holds a reference
+    /// to it (strong_count == 1), meaning no caller is currently using the lock.
+    fn cleanup_stale_entries(
+        locks: &mut std::collections::HashMap<PathBuf, Arc<std::sync::Mutex<()>>>,
+    ) {
+        locks.retain(|_, arc| Arc::strong_count(arc) > 1);
+    }
+
+    /// Manually trigger cleanup of stale lock entries.
+    ///
+    /// This removes entries where no external reference exists (only the
+    /// manager holds the Arc). Useful for periodic maintenance.
+    pub fn cleanup(&self) {
+        let mut locks = self.locks.lock().unwrap();
+        Self::cleanup_stale_entries(&mut locks);
+    }
+
+    /// Returns the current number of lock entries in the manager.
+    pub fn lock_count(&self) -> usize {
+        self.locks.lock().unwrap().len()
+    }
+
     /// Execute an operation with both process-local and file-system locks.
     pub fn with_lock<F, T>(&self, path: impl AsRef<Path>, mode: LockMode, f: F) -> FileLockResult<T>
     where
diff --git a/src/cortex-engine/src/config/config_discovery.rs b/src/cortex-engine/src/config/config_discovery.rs
index 86e3c64..7e5b97c 100644
--- a/src/cortex-engine/src/config/config_discovery.rs
+++ b/src/cortex-engine/src/config/config_discovery.rs
@@ -4,20 +4,36 @@
 //! with caching support for performance in monorepo environments.
 
 use std::collections::HashMap;
+use std::hash::Hash;
 use std::path::{Path, PathBuf};
 use std::sync::{LazyLock, RwLock};
 use tracing::{debug, trace};
 
+/// Maximum number of entries in each cache to prevent unbounded memory growth.
+const MAX_CACHE_SIZE: usize = 1000;
+
 /// Cache for discovered config paths.
 /// Key is the start directory, value is the found config path (or None).
 static CONFIG_CACHE: LazyLock<RwLock<HashMap<PathBuf, Option<PathBuf>>>> =
-    LazyLock::new(|| RwLock::new(HashMap::new()));
+    LazyLock::new(|| RwLock::new(HashMap::with_capacity(MAX_CACHE_SIZE)));
 
 /// Cache for project roots.
 /// Key is the start directory, value is the project root path (or None).
 static PROJECT_ROOT_CACHE: LazyLock<RwLock<HashMap<PathBuf, Option<PathBuf>>>> =
-    LazyLock::new(|| RwLock::new(HashMap::new()));
+    LazyLock::new(|| RwLock::new(HashMap::with_capacity(MAX_CACHE_SIZE)));
+
+/// Insert a key-value pair into the cache with eviction when full.
+/// When the cache reaches MAX_CACHE_SIZE, an arbitrary entry is removed before inserting.
+fn insert_with_eviction<K: Eq + Hash + Clone, V>(cache: &mut HashMap<K, V>, key: K, value: V) {
+    if cache.len() >= MAX_CACHE_SIZE {
+        // Remove the first entry the iterator yields (simple eviction strategy)
+        if let Some(k) = cache.keys().next().cloned() {
+            cache.remove(&k);
+        }
+    }
+    cache.insert(key, value);
+}
 
 /// Markers that indicate a project root directory.
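+/// The project-root lookup walks upward from the start directory and stops at
+/// the first directory that contains any of these markers.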
 const PROJECT_ROOT_MARKERS: &[&str] = &[
@@ -57,9 +73,9 @@ pub fn find_up(start_dir: &Path, filename: &str) -> Option<PathBuf> {
     let result = find_up_uncached(start_dir, filename);
 
-    // Store in cache
+    // Store in cache with eviction when full
     if let Ok(mut cache) = CONFIG_CACHE.write() {
-        cache.insert(cache_key, result.clone());
+        insert_with_eviction(&mut cache, cache_key, result.clone());
     }
 
     result
@@ -169,9 +185,9 @@ pub fn find_project_root(start_dir: &Path) -> Option<PathBuf> {
     let result = find_project_root_uncached(start_dir);
 
-    // Store in cache
+    // Store in cache with eviction when full
     if let Ok(mut cache) = PROJECT_ROOT_CACHE.write() {
-        cache.insert(start_dir.to_path_buf(), result.clone());
+        insert_with_eviction(&mut cache, start_dir.to_path_buf(), result.clone());
     }
 
     result
diff --git a/src/cortex-engine/src/tokenizer.rs b/src/cortex-engine/src/tokenizer.rs
index 793f5e2..b8aeadc 100644
--- a/src/cortex-engine/src/tokenizer.rs
+++ b/src/cortex-engine/src/tokenizer.rs
@@ -3,9 +3,25 @@
 //! Provides token counting and text tokenization for various models.
 
 use std::collections::HashMap;
+use std::hash::Hash;
 
 use serde::{Deserialize, Serialize};
 
+/// Maximum number of entries in the token cache to prevent unbounded memory growth.
+const MAX_CACHE_SIZE: usize = 1000;
+
+/// Insert a key-value pair into the cache with eviction when full.
+/// When the cache reaches MAX_CACHE_SIZE, an arbitrary entry is removed before inserting.
+fn insert_with_eviction<K: Eq + Hash + Clone, V>(cache: &mut HashMap<K, V>, key: K, value: V) {
+    if cache.len() >= MAX_CACHE_SIZE {
+        // Remove the first entry the iterator yields (simple eviction strategy)
+        if let Some(k) = cache.keys().next().cloned() {
+            cache.remove(&k);
+        }
+    }
+    cache.insert(key, value);
+}
+
 /// Tokenizer type.
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
@@ -58,7 +74,7 @@ pub struct TokenCounter {
     /// Tokenizer type.
     tokenizer: TokenizerType,
-    /// Cache.
+    /// Cache with bounded size to prevent unbounded memory growth.
     cache: HashMap<u64, usize>,
 }
 
@@ -67,7 +83,7 @@ impl TokenCounter {
     pub fn new(tokenizer: TokenizerType) -> Self {
         Self {
             tokenizer,
-            cache: HashMap::new(),
+            cache: HashMap::with_capacity(MAX_CACHE_SIZE),
         }
     }
 
@@ -85,7 +101,7 @@
         }
 
         let count = self.count_uncached(text);
-        self.cache.insert(hash, count);
+        insert_with_eviction(&mut self.cache, hash, count);
         count
     }
 
diff --git a/src/cortex-engine/src/tools/mod.rs b/src/cortex-engine/src/tools/mod.rs
index 9f5b6e1..f693894 100644
--- a/src/cortex-engine/src/tools/mod.rs
+++ b/src/cortex-engine/src/tools/mod.rs
@@ -30,6 +30,7 @@ pub mod artifacts;
 pub mod context;
 pub mod handlers;
 pub mod registry;
+pub mod response_store;
 pub mod router;
 pub mod spec;
 pub mod unified_executor;
@@ -45,6 +46,11 @@ pub use artifacts::{
 pub use context::ToolContext;
 pub use handlers::*;
 pub use registry::{PluginTool, ToolRegistry};
+pub use response_store::{
+    CLEANUP_INTERVAL, DEFAULT_TTL, MAX_STORE_SIZE, StoreInfo, StoreStats, StoredResponse,
+    ToolResponseStore, ToolResponseStoreConfig, create_shared_store,
+    create_shared_store_with_config,
+};
 pub use router::ToolRouter;
 pub use spec::{ToolCall, ToolDefinition, ToolHandler, ToolResult};
 pub use unified_executor::{ExecutorConfig, UnifiedToolExecutor};
diff --git a/src/cortex-engine/src/tools/response_store.rs b/src/cortex-engine/src/tools/response_store.rs
new file mode 100644
index 0000000..9220c86
--- /dev/null
+++ b/src/cortex-engine/src/tools/response_store.rs
@@ -0,0 +1,537 @@
+//! Tool response storage with bounded capacity and automatic cleanup.
+//!
+//! This module provides a bounded store for tool execution results that:
+//! - Limits the maximum number of stored responses to prevent unbounded memory growth
+//! - Removes entries when they are consumed via `take`
+//! - Periodically cleans up stale entries based on TTL
+//!
+//! Fixes #5292 (unbounded growth) and #5293 (missing removal on read).
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use tokio::sync::RwLock;
+use tracing::debug;
+
+use crate::tools::spec::ToolResult;
+
+/// Maximum number of responses to store before eviction.
+/// This prevents unbounded memory growth from accumulated tool responses.
+pub const MAX_STORE_SIZE: usize = 500;
+
+/// Default time-to-live for stored responses (5 minutes).
+pub const DEFAULT_TTL: Duration = Duration::from_secs(300);
+
+/// Interval for periodic cleanup of stale entries (1 minute).
+pub const CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
+
+/// A stored tool response with metadata.
+#[derive(Debug, Clone)]
+pub struct StoredResponse {
+    /// The tool execution result.
+    pub result: ToolResult,
+    /// Tool name that produced this result.
+    pub tool_name: String,
+    /// When the response was stored.
+    pub stored_at: Instant,
+    /// Whether this response has been read (but not yet consumed).
+    pub read: bool,
+}
+
+impl StoredResponse {
+    /// Create a new stored response.
+    pub fn new(tool_name: impl Into<String>, result: ToolResult) -> Self {
+        Self {
+            result,
+            tool_name: tool_name.into(),
+            stored_at: Instant::now(),
+            read: false,
+        }
+    }
+
+    /// Check if the response has expired.
+    pub fn is_expired(&self, ttl: Duration) -> bool {
+        self.stored_at.elapsed() > ttl
+    }
+
+    /// Get the age of this response.
+    pub fn age(&self) -> Duration {
+        self.stored_at.elapsed()
+    }
+}
+
+/// Configuration for the tool response store.
+#[derive(Debug, Clone)]
+pub struct ToolResponseStoreConfig {
+    /// Maximum number of responses to store.
+    pub max_size: usize,
+    /// Time-to-live for stored responses.
+    pub ttl: Duration,
+    /// Whether to remove entries on read (peek vs consume).
+    pub remove_on_read: bool,
+}
+
+impl Default for ToolResponseStoreConfig {
+    fn default() -> Self {
+        Self {
+            max_size: MAX_STORE_SIZE,
+            ttl: DEFAULT_TTL,
+            remove_on_read: true,
+        }
+    }
+}
+
+impl ToolResponseStoreConfig {
+    /// Create a config with custom max size.
+    pub fn with_max_size(mut self, max_size: usize) -> Self {
+        self.max_size = max_size;
+        self
+    }
+
+    /// Create a config with custom TTL.
+    pub fn with_ttl(mut self, ttl: Duration) -> Self {
+        self.ttl = ttl;
+        self
+    }
+
+    /// Set whether to remove entries on read.
+    pub fn with_remove_on_read(mut self, remove: bool) -> Self {
+        self.remove_on_read = remove;
+        self
+    }
+}
+
+/// Bounded storage for tool execution responses.
+///
+/// This store prevents unbounded memory growth by:
+/// 1. Enforcing a maximum number of stored responses
+/// 2. Removing entries when they are consumed
+/// 3. Periodically cleaning up stale entries
+///
+/// # Thread Safety
+///
+/// The store uses `RwLock` for interior mutability and is safe to share
+/// across threads via `Arc`.
+#[derive(Debug)]
+pub struct ToolResponseStore {
+    /// Stored responses keyed by tool call ID.
+    responses: RwLock<HashMap<String, StoredResponse>>,
+    /// Configuration.
+    config: ToolResponseStoreConfig,
+    /// Last cleanup time.
+    last_cleanup: RwLock<Instant>,
+    /// Statistics.
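+    /// Counters accumulate for the lifetime of the store; `clear` empties the
+    /// response map but does not reset them.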
+    stats: RwLock<StoreStats>,
+}
+
+impl ToolResponseStore {
+    /// Create a new tool response store with default configuration.
+    pub fn new() -> Self {
+        Self::with_config(ToolResponseStoreConfig::default())
+    }
+
+    /// Create a tool response store with custom configuration.
+    pub fn with_config(config: ToolResponseStoreConfig) -> Self {
+        Self {
+            responses: RwLock::new(HashMap::new()),
+            config,
+            last_cleanup: RwLock::new(Instant::now()),
+            stats: RwLock::new(StoreStats::default()),
+        }
+    }
+
+    /// Store a tool response.
+    ///
+    /// If the store is at capacity, the oldest entry will be evicted.
+    /// Returns `true` if an entry was evicted to make room.
+    pub async fn store(
+        &self,
+        call_id: impl Into<String>,
+        tool_name: impl Into<String>,
+        result: ToolResult,
+    ) -> bool {
+        let call_id = call_id.into();
+        let tool_name = tool_name.into();
+        let mut evicted = false;
+
+        // Perform periodic cleanup if needed
+        self.maybe_cleanup().await;
+
+        let mut responses = self.responses.write().await;
+
+        // Evict oldest entry if at capacity
+        if responses.len() >= self.config.max_size {
+            if let Some(oldest_key) = self.find_oldest_key(&responses) {
+                responses.remove(&oldest_key);
+                evicted = true;
+                debug!(
+                    evicted_key = %oldest_key,
+                    "Evicted oldest response to make room"
+                );
+            }
+        }
+
+        let response = StoredResponse::new(tool_name, result);
+        responses.insert(call_id.clone(), response);
+
+        // Update stats
+        let mut stats = self.stats.write().await;
+        stats.total_stored += 1;
+        if evicted {
+            stats.evictions += 1;
+        }
+
+        evicted
+    }
+
+    /// Get a response without removing it (peek).
+    ///
+    /// Marks the response as read but does not consume it.
+    pub async fn get(&self, call_id: &str) -> Option<ToolResult> {
+        let mut responses = self.responses.write().await;
+
+        if let Some(response) = responses.get_mut(call_id) {
+            response.read = true;
+            let mut stats = self.stats.write().await;
+            stats.reads += 1;
+            Some(response.result.clone())
+        } else {
+            None
+        }
+    }
+
+    /// Take (consume) a response, removing it from the store.
+    ///
+    /// This is the primary method for retrieving responses as it ensures
+    /// entries are cleaned up after being consumed (#5293).
+    pub async fn take(&self, call_id: &str) -> Option<ToolResult> {
+        let mut responses = self.responses.write().await;
+
+        if let Some(response) = responses.remove(call_id) {
+            let mut stats = self.stats.write().await;
+            stats.takes += 1;
+            Some(response.result)
+        } else {
+            None
+        }
+    }
+
+    /// Check if a response exists for the given call ID.
+    pub async fn contains(&self, call_id: &str) -> bool {
+        self.responses.read().await.contains_key(call_id)
+    }
+
+    /// Get the current number of stored responses.
+    pub async fn len(&self) -> usize {
+        self.responses.read().await.len()
+    }
+
+    /// Check if the store is empty.
+    pub async fn is_empty(&self) -> bool {
+        self.responses.read().await.is_empty()
+    }
+
+    /// Remove all expired entries.
+    ///
+    /// Returns the number of entries removed.
+    pub async fn cleanup_expired(&self) -> usize {
+        let mut responses = self.responses.write().await;
+        let ttl = self.config.ttl;
+        let before = responses.len();
+
+        responses.retain(|_, v| !v.is_expired(ttl));
+
+        let removed = before - responses.len();
+        if removed > 0 {
+            debug!(removed, "Cleaned up expired responses");
+            let mut stats = self.stats.write().await;
+            stats.expired_cleanups += removed as u64;
+        }
+
+        removed
+    }
+
+    /// Remove all read entries that haven't been consumed.
+    ///
+    /// This is useful for cleaning up entries that were peeked but never taken.
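+    ///
+    /// A minimal usage sketch (`ignore`d as a doctest; the tool name and
+    /// output are illustrative):
+    ///
+    /// ```ignore
+    /// let store = ToolResponseStore::new();
+    /// store.store("call-1", "Read", ToolResult::success("ok")).await;
+    /// store.get("call-1").await;                 // peek marks the entry as read
+    /// assert_eq!(store.cleanup_read().await, 1); // the peeked entry is dropped
+    /// ```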
+    pub async fn cleanup_read(&self) -> usize {
+        let mut responses = self.responses.write().await;
+        let before = responses.len();
+
+        responses.retain(|_, v| !v.read);
+
+        let removed = before - responses.len();
+        if removed > 0 {
+            debug!(removed, "Cleaned up read-but-not-consumed responses");
+        }
+
+        removed
+    }
+
+    /// Clear all stored responses.
+    pub async fn clear(&self) {
+        self.responses.write().await.clear();
+    }
+
+    /// Get store statistics.
+    pub async fn stats(&self) -> StoreStats {
+        self.stats.read().await.clone()
+    }
+
+    /// Get detailed store info including current size and config.
+    pub async fn info(&self) -> StoreInfo {
+        let responses = self.responses.read().await;
+        let stats = self.stats.read().await;
+
+        StoreInfo {
+            current_size: responses.len(),
+            max_size: self.config.max_size,
+            ttl_secs: self.config.ttl.as_secs(),
+            oldest_age_secs: responses
+                .values()
+                .map(|r| r.age().as_secs())
+                .max()
+                .unwrap_or(0),
+            stats: stats.clone(),
+        }
+    }
+
+    // Internal helpers
+
+    /// Find the key of the oldest entry.
+    fn find_oldest_key(&self, responses: &HashMap<String, StoredResponse>) -> Option<String> {
+        responses
+            .iter()
+            .min_by_key(|(_, v)| v.stored_at)
+            .map(|(k, _)| k.clone())
+    }
+
+    /// Perform cleanup if enough time has passed since last cleanup.
+    async fn maybe_cleanup(&self) {
+        let should_cleanup = {
+            let last = self.last_cleanup.read().await;
+            last.elapsed() > CLEANUP_INTERVAL
+        };
+
+        if should_cleanup {
+            *self.last_cleanup.write().await = Instant::now();
+            let removed = self.cleanup_expired().await;
+            if removed > 0 {
+                debug!(removed, "Periodic cleanup removed expired entries");
+            }
+        }
+    }
+}
+
+impl Default for ToolResponseStore {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Statistics for the tool response store.
+#[derive(Debug, Clone, Default)]
+pub struct StoreStats {
+    /// Total responses stored.
+    pub total_stored: u64,
+    /// Number of get (peek) operations.
+    pub reads: u64,
+    /// Number of take (consume) operations.
+    pub takes: u64,
+    /// Number of evictions due to capacity limit.
+    pub evictions: u64,
+    /// Number of entries removed by TTL cleanup.
+    pub expired_cleanups: u64,
+}
+
+/// Detailed store information.
+#[derive(Debug, Clone)]
+pub struct StoreInfo {
+    /// Current number of stored responses.
+    pub current_size: usize,
+    /// Maximum allowed size.
+    pub max_size: usize,
+    /// TTL in seconds.
+    pub ttl_secs: u64,
+    /// Age of oldest entry in seconds.
+    pub oldest_age_secs: u64,
+    /// Store statistics.
+    pub stats: StoreStats,
+}
+
+/// Create a shared tool response store.
+pub fn create_shared_store() -> Arc<ToolResponseStore> {
+    Arc::new(ToolResponseStore::new())
+}
+
+/// Create a shared tool response store with custom configuration.
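+///
+/// A short sketch (`ignore`d as a doctest; the values are illustrative):
+///
+/// ```ignore
+/// use std::time::Duration;
+///
+/// let store = create_shared_store_with_config(
+///     ToolResponseStoreConfig::default()
+///         .with_max_size(100)
+///         .with_ttl(Duration::from_secs(60)),
+/// );
+/// ```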
+pub fn create_shared_store_with_config(config: ToolResponseStoreConfig) -> Arc<ToolResponseStore> {
+    Arc::new(ToolResponseStore::with_config(config))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_store_and_take() {
+        let store = ToolResponseStore::new();
+
+        let result = ToolResult::success("test output");
+        store.store("call-1", "Read", result.clone()).await;
+
+        assert!(store.contains("call-1").await);
+        assert_eq!(store.len().await, 1);
+
+        let taken = store.take("call-1").await;
+        assert!(taken.is_some());
+        assert_eq!(taken.unwrap().output, "test output");
+
+        // After take, entry should be gone
+        assert!(!store.contains("call-1").await);
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_store_and_get() {
+        let store = ToolResponseStore::new();
+
+        let result = ToolResult::success("test output");
+        store.store("call-1", "Read", result).await;
+
+        // Get should return result but not remove it
+        let got = store.get("call-1").await;
+        assert!(got.is_some());
+        assert!(store.contains("call-1").await);
+
+        // Second get should still work
+        let got2 = store.get("call-1").await;
+        assert!(got2.is_some());
+    }
+
+    #[tokio::test]
+    async fn test_capacity_eviction() {
+        let config = ToolResponseStoreConfig::default().with_max_size(3);
+        let store = ToolResponseStore::with_config(config);
+
+        // Fill to capacity
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store
+            .store("call-2", "Read", ToolResult::success("2"))
+            .await;
+        store
+            .store("call-3", "Read", ToolResult::success("3"))
+            .await;
+
+        assert_eq!(store.len().await, 3);
+
+        // Add one more, should evict oldest
+        let evicted = store
+            .store("call-4", "Read", ToolResult::success("4"))
+            .await;
+        assert!(evicted);
+        assert_eq!(store.len().await, 3);
+
+        // call-1 should be evicted (oldest)
+        assert!(!store.contains("call-1").await);
+        assert!(store.contains("call-4").await);
+    }
+
+    #[tokio::test]
+    async fn test_expired_cleanup() {
+        let config = ToolResponseStoreConfig::default().with_ttl(Duration::from_millis(50));
+        let store = ToolResponseStore::with_config(config);
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        assert_eq!(store.len().await, 1);
+
+        // Wait for expiration
+        tokio::time::sleep(Duration::from_millis(100)).await;
+
+        let removed = store.cleanup_expired().await;
+        assert_eq!(removed, 1);
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_cleanup_read() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store
+            .store("call-2", "Read", ToolResult::success("2"))
+            .await;
+
+        // Read one entry
+        store.get("call-1").await;
+
+        // Cleanup read entries
+        let removed = store.cleanup_read().await;
+        assert_eq!(removed, 1);
+        assert_eq!(store.len().await, 1);
+        assert!(!store.contains("call-1").await);
+        assert!(store.contains("call-2").await);
+    }
+
+    #[tokio::test]
+    async fn test_stats() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store.get("call-1").await;
+        store.take("call-1").await;
+
+        let stats = store.stats().await;
+        assert_eq!(stats.total_stored, 1);
+        assert_eq!(stats.reads, 1);
+        assert_eq!(stats.takes, 1);
+    }
+
+    #[tokio::test]
+    async fn test_nonexistent_key() {
+        let store = ToolResponseStore::new();
+
+        assert!(store.get("nonexistent").await.is_none());
+        assert!(store.take("nonexistent").await.is_none());
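+        // A failed get/take must not create a placeholder entry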
+        assert!(!store.contains("nonexistent").await);
+    }
+
+    #[tokio::test]
+    async fn test_clear() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store
+            .store("call-2", "Read", ToolResult::success("2"))
+            .await;
+
+        assert_eq!(store.len().await, 2);
+
+        store.clear().await;
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_info() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+
+        let info = store.info().await;
+        assert_eq!(info.current_size, 1);
+        assert_eq!(info.max_size, MAX_STORE_SIZE);
+    }
+}
diff --git a/src/cortex-engine/src/validation.rs b/src/cortex-engine/src/validation.rs
index a8afbec..5ff4d9e 100644
--- a/src/cortex-engine/src/validation.rs
+++ b/src/cortex-engine/src/validation.rs
@@ -269,6 +269,33 @@ pub struct CommandValidator {
     pub allow_shell_operators: bool,
 }
 
+/// Normalize a command string for consistent validation.
+///
+/// This function handles bypass attempts such as:
+/// - Extra whitespace: "rm  -rf" → "rm -rf"
+/// - Quoted parts: "'rm' -rf" → "rm -rf"
+/// - Path variants: "/bin/rm -rf" → "rm -rf"
+fn normalize_command(cmd: &str) -> String {
+    cmd.split_whitespace()
+        .enumerate()
+        .map(|(idx, part)| {
+            // Remove surrounding quotes (single and double)
+            let unquoted = part.trim_matches(|c| c == '\'' || c == '"');
+
+            // For the first part (command), extract basename to handle path variants
+            if idx == 0 {
+                Path::new(unquoted)
+                    .file_name()
+                    .and_then(|name| name.to_str())
+                    .unwrap_or(unquoted)
+            } else {
+                unquoted
+            }
+        })
+        .collect::<Vec<_>>()
+        .join(" ")
+}
+
 impl CommandValidator {
     /// Create a new validator.
     pub fn new() -> Self {
@@ -332,9 +359,12 @@ impl CommandValidator {
             ));
         }
 
-        // Check allowed list
+        // Normalize the command for consistent validation
+        let normalized = normalize_command(command);
+
+        // Check allowed list using normalized command
         if let Some(ref allowed) = self.allowed {
-            let cmd = command.split_whitespace().next().unwrap_or("");
+            let cmd = normalized.split_whitespace().next().unwrap_or("");
             if !allowed.contains(cmd) {
                 result.add_error(ValidationError::new(
                     "command",
                 ));
             }
         }
 
-        // Check blocked commands
+        // Check blocked commands against normalized form
         for blocked in &self.blocked {
-            if command.contains(blocked) {
+            let normalized_blocked = normalize_command(blocked);
+            if normalized.contains(&normalized_blocked) {
                 result.add_error(ValidationError::new(
                     "command",
                     "Command contains blocked pattern",
                 ));
             }
         }
 
-        // Check blocked patterns
+        // Check blocked patterns against both original and normalized
         for pattern in &self.blocked_patterns {
-            if command.contains(pattern) {
+            if command.contains(pattern) || normalized.contains(pattern) {
                 result.add_error(ValidationError::new(
                     "command",
                     "Command contains dangerous pattern",
                 ));
             }
         }
@@ -700,6 +731,97 @@ mod tests {
         assert!(result.valid);
     }
 
+    #[test]
+    fn test_command_validation_whitespace_bypass() {
+        let validator = CommandValidator::new();
+
+        // Extra whitespace should not bypass validation
+        let result = validator.validate("rm  -rf /");
+        assert!(
+            !result.valid,
+            "Extra whitespace should not bypass blocked command"
+        );
+
+        let result = validator.validate("rm    -rf    /");
+        assert!(
+            !result.valid,
+            "Multiple spaces should not bypass blocked command"
+        );
+    }
+
+    #[test]
+    fn test_command_validation_quote_bypass() {
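+        // normalize_command strips surrounding quotes, so quoting the binary
+        // or its arguments must not evade the blocked-command check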
+        let validator = CommandValidator::new();
+
+        // Quoted commands should not bypass validation
+        let result = validator.validate("'rm' -rf /");
+        assert!(
+            !result.valid,
+            "Single quotes should not bypass blocked command"
+        );
+
+        let result = validator.validate("\"rm\" -rf /");
+        assert!(
+            !result.valid,
+            "Double quotes should not bypass blocked command"
+        );
+
+        let result = validator.validate("'rm' '-rf' '/'");
+        assert!(
+            !result.valid,
+            "Fully quoted command should not bypass blocked command"
+        );
+    }
+
+    #[test]
+    fn test_command_validation_path_bypass() {
+        let validator = CommandValidator::new();
+
+        // Path variants should not bypass validation
+        let result = validator.validate("/bin/rm -rf /");
+        assert!(
+            !result.valid,
+            "Absolute path should not bypass blocked command"
+        );
+
+        let result = validator.validate("/usr/bin/rm -rf /");
+        assert!(!result.valid, "Full path should not bypass blocked command");
+
+        let result = validator.validate("./rm -rf /");
+        assert!(
+            !result.valid,
+            "Relative path should not bypass blocked command"
+        );
+    }
+
+    #[test]
+    fn test_command_validation_combined_bypass() {
+        let validator = CommandValidator::new();
+
+        // Combined bypass attempts
+        let result = validator.validate("'/bin/rm' -rf /");
+        assert!(
+            !result.valid,
+            "Combined path and whitespace should not bypass"
+        );
+
+        let result = validator.validate("\"/usr/bin/rm\" '-rf' '/'");
+        assert!(
+            !result.valid,
+            "Combined quotes, path, and whitespace should not bypass"
+        );
+    }
+
+    #[test]
+    fn test_normalize_command() {
+        // Test the normalize function directly
+        assert_eq!(normalize_command("rm -rf /"), "rm -rf /");
+        assert_eq!(normalize_command("rm    -rf    /"), "rm -rf /");
+        assert_eq!(normalize_command("'rm' -rf /"), "rm -rf /");
+        assert_eq!(normalize_command("/bin/rm -rf /"), "rm -rf /");
+        assert_eq!(normalize_command("'/usr/bin/rm' '-rf' '/'"), "rm -rf /");
+    }
+
     #[test]
     fn test_url_validation() {
         let validator = UrlValidator::new();
diff --git a/src/cortex-mcp-server/src/server.rs b/src/cortex-mcp-server/src/server.rs
index 96fb8d8..266b9e4 100644
--- a/src/cortex-mcp-server/src/server.rs
+++ b/src/cortex-mcp-server/src/server.rs
@@ -222,14 +222,17 @@ impl McpServer {
     }
 
     async fn handle_initialize(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
-        // Check state
-        let current_state = *self.state.read().await;
-        if current_state != ServerState::Uninitialized {
-            return Err(JsonRpcError::invalid_request("Server already initialized"));
+        // Atomic check-and-transition: hold the write lock across the state check and
+        // update to prevent a TOCTOU race where multiple concurrent initialize requests
+        // could all pass the uninitialized check before any of them sets the state
+        {
+            let mut state_guard = self.state.write().await;
+            if *state_guard != ServerState::Uninitialized {
+                return Err(JsonRpcError::invalid_request("Server already initialized"));
+            }
+            *state_guard = ServerState::Initializing;
         }
-        *self.state.write().await = ServerState::Initializing;
-
         // Parse params
         let init_params: InitializeParams = params
             .map(serde_json::from_value)
diff --git a/src/cortex-plugins/src/registry.rs b/src/cortex-plugins/src/registry.rs
index 79f961e..f9bb235 100644
--- a/src/cortex-plugins/src/registry.rs
+++ b/src/cortex-plugins/src/registry.rs
@@ -674,18 +674,21 @@ impl PluginRegistry {
         let info = plugin.info().clone();
         let id = info.id.clone();
 
-        {
-            let plugins = self.plugins.read().await;
-            if plugins.contains_key(&id) {
-                return Err(PluginError::AlreadyExists(id));
-            }
-        }
-
+        // Use the entry API to atomically check-and-insert within a single write lock
+        // to prevent a TOCTOU race where multiple concurrent registrations
+        // could all pass the contains_key check before any of them inserts
         let handle = PluginHandle::new(plugin);
 
         {
             let mut plugins = self.plugins.write().await;
-            plugins.insert(id.clone(), handle);
+            use std::collections::hash_map::Entry;
+            match plugins.entry(id.clone()) {
+                Entry::Occupied(_) => {
+                    return Err(PluginError::AlreadyExists(id));
+                }
+                Entry::Vacant(entry) => {
+                    entry.insert(handle);
+                }
+            }
         }
 
         {
diff --git a/src/cortex-resume/src/session_store.rs b/src/cortex-resume/src/session_store.rs
index 48ed5b9..04a1f9f 100644
--- a/src/cortex-resume/src/session_store.rs
+++ b/src/cortex-resume/src/session_store.rs
@@ -13,15 +13,34 @@ use tokio::fs;
 use tokio::sync::{Mutex as AsyncMutex, RwLock};
 use tracing::{debug, info};
 
+/// Maximum number of lock entries before triggering cleanup.
+const MAX_LOCK_ENTRIES: usize = 10_000;
+
 /// Global file lock manager for session store operations.
 /// Prevents concurrent modifications to the same file within the process.
 static FILE_LOCKS: once_cell::sync::Lazy<std::sync::Mutex<HashMap<PathBuf, Arc<AsyncMutex<()>>>>> =
     once_cell::sync::Lazy::new(|| std::sync::Mutex::new(HashMap::new()));
 
+/// Remove lock entries that are no longer in use.
+///
+/// An entry is considered stale when only the HashMap holds a reference
+/// to it (strong_count == 1), meaning no caller is currently using the lock.
+fn cleanup_stale_file_locks(locks: &mut HashMap<PathBuf, Arc<AsyncMutex<()>>>) {
+    locks.retain(|_, arc| Arc::strong_count(arc) > 1);
+}
+
 /// Acquire an async lock for a specific file path.
+///
+/// Automatically cleans up stale lock entries when the map grows too large.
 fn get_file_lock(path: &Path) -> Arc<AsyncMutex<()>> {
     let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
     let mut locks = FILE_LOCKS.lock().unwrap();
+
+    // Clean up stale entries if the map is getting large
+    if locks.len() >= MAX_LOCK_ENTRIES {
+        cleanup_stale_file_locks(&mut locks);
+    }
+
     locks
         .entry(canonical)
         .or_insert_with(|| Arc::new(AsyncMutex::new(())))
diff --git a/src/cortex-storage/src/sessions/storage.rs b/src/cortex-storage/src/sessions/storage.rs
index dd750b0..7f3d38c 100644
--- a/src/cortex-storage/src/sessions/storage.rs
+++ b/src/cortex-storage/src/sessions/storage.rs
@@ -124,20 +124,67 @@ impl SessionStorage {
     }
 
     /// Save a session to disk.
+    ///
+    /// This function ensures data durability by calling sync_all() (fsync)
+    /// after writing to prevent data loss on crash or forceful termination.
     pub async fn save_session(&self, session: &StoredSession) -> Result<()> {
         let path = self.paths.session_path(&session.id);
         let content = serde_json::to_string_pretty(session)?;
-        fs::write(&path, content).await?;
+
+        use tokio::io::AsyncWriteExt;
+
+        // Write content to file
+        let mut file = fs::OpenOptions::new()
+            .create(true)
+            .write(true)
+            .truncate(true)
+            .open(&path)
+            .await?;
+
+        file.write_all(content.as_bytes()).await?;
+        file.flush().await?;
+
+        // Ensure data is durably written to disk (fsync) to prevent data loss on crash
+        file.sync_all().await?;
+
+        // Sync parent directory on Unix for crash safety (ensures directory entry is persisted)
+        #[cfg(unix)]
+        {
+            if let Some(parent) = path.parent() {
+                if let Ok(dir) = fs::File::open(parent).await {
+                    let _ = dir.sync_all().await;
+                }
+            }
+        }
+
         debug!(session_id = %session.id, "Session saved");
         Ok(())
     }
 
     /// Save a session synchronously.
+ /// + /// This function ensures data durability by calling sync_all() (fsync) + /// after writing to prevent data loss on crash or forceful termination. pub fn save_session_sync(&self, session: &StoredSession) -> Result<()> { let path = self.paths.session_path(&session.id); let file = std::fs::File::create(&path)?; - let writer = BufWriter::new(file); - serde_json::to_writer_pretty(writer, session)?; + let mut writer = BufWriter::new(file); + serde_json::to_writer_pretty(&mut writer, session)?; + writer.flush()?; + + // Ensure data is durably written to disk (fsync) to prevent data loss on crash + writer.get_ref().sync_all()?; + + // Sync parent directory on Unix for crash safety (ensures directory entry is persisted) + #[cfg(unix)] + { + if let Some(parent) = path.parent() { + if let Ok(dir) = std::fs::File::open(parent) { + let _ = dir.sync_all(); + } + } + } + debug!(session_id = %session.id, "Session saved"); Ok(()) } diff --git a/src/cortex-tui/src/session/storage.rs b/src/cortex-tui/src/session/storage.rs index 7e1621e..17524a9 100644 --- a/src/cortex-tui/src/session/storage.rs +++ b/src/cortex-tui/src/session/storage.rs @@ -87,6 +87,9 @@ impl SessionStorage { // ======================================================================== /// Saves session metadata. + /// + /// Uses atomic write (temp file + rename) with fsync for durability. + /// This prevents data loss on crash or forceful termination. pub fn save_meta(&self, meta: &SessionMeta) -> Result<()> { self.ensure_session_dir(&meta.id)?; @@ -94,13 +97,35 @@ impl SessionStorage { let content = serde_json::to_string_pretty(meta).context("Failed to serialize session metadata")?; - // Atomic write: write to temp file then rename + // Atomic write: write to temp file, fsync, then rename let temp_path = path.with_extension("json.tmp"); - fs::write(&temp_path, &content) + + // Write and sync temp file + let file = File::create(&temp_path) + .with_context(|| format!("Failed to create temp metadata file: {:?}", temp_path))?; + let mut writer = BufWriter::new(file); + writer + .write_all(content.as_bytes()) .with_context(|| format!("Failed to write temp metadata file: {:?}", temp_path))?; + writer.flush()?; + + // Ensure data is durably written to disk (fsync) before rename + writer.get_ref().sync_all().with_context(|| { + format!("Failed to sync temp metadata file to disk: {:?}", temp_path) + })?; + + // Rename temp file to final path fs::rename(&temp_path, &path) .with_context(|| format!("Failed to rename metadata file: {:?}", path))?; + // Sync parent directory on Unix for crash safety (ensures directory entry is persisted) + #[cfg(unix)] + if let Some(parent) = path.parent() + && let Ok(dir) = File::open(parent) + { + let _ = dir.sync_all(); + } + Ok(()) } @@ -212,6 +237,14 @@ impl SessionStorage { fs::rename(&temp_path, &path) .with_context(|| format!("Failed to rename history file: {:?}", path))?; + // Sync parent directory on Unix for crash safety (ensures directory entry is persisted) + #[cfg(unix)] + if let Some(parent) = path.parent() + && let Ok(dir) = File::open(parent) + { + let _ = dir.sync_all(); + } + Ok(()) }
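
For reference, the temp-file write, fsync, rename, directory-fsync sequence that `save_meta` above applies (and that `save_session`/`save_session_sync` approximate with an in-place write plus fsync), as a standalone sketch; the helper name `write_atomic_durable` is illustrative and not part of this change:

```rust
use std::fs::{self, File};
use std::io::Write;
use std::path::Path;

fn write_atomic_durable(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
    let tmp = path.with_extension("tmp");

    // Write the payload to a temporary file and fsync it before the rename,
    // so the rename can never expose a partially written file.
    let mut file = File::create(&tmp)?;
    file.write_all(bytes)?;
    file.sync_all()?;

    // rename() atomically replaces the destination on POSIX filesystems.
    fs::rename(&tmp, path)?;

    // fsync the parent directory so the new directory entry itself survives a crash.
    #[cfg(unix)]
    if let Some(parent) = path.parent() {
        File::open(parent)?.sync_all()?;
    }

    Ok(())
}
```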