diff --git a/src/cortex-cli/src/utils/paths.rs b/src/cortex-cli/src/utils/paths.rs
index 8cdf03e..d0bbcdf 100644
--- a/src/cortex-cli/src/utils/paths.rs
+++ b/src/cortex-cli/src/utils/paths.rs
@@ -34,10 +34,16 @@ pub fn get_cortex_home() -> PathBuf {
 /// // Returns: /home/user/documents/file.txt
 /// ```
 pub fn expand_tilde(path: &str) -> String {
-    if path.starts_with("~/")
-        && let Some(home) = dirs::home_dir()
-    {
-        return home.join(&path[2..]).to_string_lossy().to_string();
+    if path == "~" {
+        // Handle bare "~" - return home directory
+        if let Some(home) = dirs::home_dir() {
+            return home.to_string_lossy().to_string();
+        }
+    } else if path.starts_with("~/") {
+        // Handle "~/" prefix - expand to home directory + rest of path
+        if let Some(home) = dirs::home_dir() {
+            return home.join(&path[2..]).to_string_lossy().to_string();
+        }
     }
     path.to_string()
 }
@@ -58,8 +64,12 @@ pub fn expand_tilde(path: &str) -> String {
 pub fn validate_path_safety(path: &Path, base_dir: Option<&Path>) -> Result<(), String> {
     let path_str = path.to_string_lossy();
 
-    // Check for path traversal attempts
-    if path_str.contains("..") {
+    // Check for path traversal attempts by examining path components
+    // This correctly handles filenames containing ".." like "file..txt"
+    if path
+        .components()
+        .any(|c| matches!(c, std::path::Component::ParentDir))
+    {
         return Err("Path contains traversal sequence '..'".to_string());
     }
 
@@ -257,8 +267,15 @@ mod tests {
 
     #[test]
    fn test_expand_tilde_with_tilde_only() {
-        // Test tilde alone - should remain unchanged (not "~/")
-        assert_eq!(expand_tilde("~"), "~");
+        // Test bare "~" - should expand to home directory
+        let result = expand_tilde("~");
+        if let Some(home) = dirs::home_dir() {
+            let expected = home.to_string_lossy().to_string();
+            assert_eq!(result, expected);
+        } else {
+            // If no home dir, the original string is returned
+            assert_eq!(result, "~");
+        }
     }
 
     #[test]
@@ -320,20 +337,30 @@ mod tests {
 
     #[test]
     fn test_validate_path_safety_detects_various_traversal_patterns() {
-        // Different traversal patterns
-        let patterns = ["foo/../bar", "...", "foo/bar/../baz", "./foo/../../../etc"];
+        // Patterns that ARE path traversal (contain ".." as a component)
+        let traversal_patterns = ["foo/../bar", "foo/bar/../baz", "./foo/../../../etc", ".."];
 
-        for pattern in patterns {
+        for pattern in traversal_patterns {
             let path = Path::new(pattern);
             let result = validate_path_safety(path, None);
-            // Only patterns containing ".." should fail
-            if pattern.contains("..") {
-                assert!(
-                    result.is_err(),
-                    "Expected traversal detection for: {}",
-                    pattern
-                );
-            }
+            assert!(
+                result.is_err(),
+                "Expected traversal detection for: {}",
+                pattern
+            );
+        }
+
+        // Patterns that are NOT path traversal (contain ".." in filenames only)
+        let safe_patterns = ["file..txt", "..hidden", "test...file", "foo/bar..baz/file"];
+
+        for pattern in safe_patterns {
+            let path = Path::new(pattern);
+            let result = validate_path_safety(path, None);
+            assert!(
+                result.is_ok(),
+                "False positive: '{}' should not be detected as traversal",
+                pattern
+            );
+        }
     }
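
The component-based check is the behavioral core of this change: `Path::components()` yields `ParentDir` only for a genuine `..` segment, so file names that merely contain consecutive dots pass. A minimal std-only sketch of the distinction:

```rust
use std::path::{Component, Path};

/// True only when the path contains a real ".." segment.
fn has_parent_dir(path: &Path) -> bool {
    path.components().any(|c| matches!(c, Component::ParentDir))
}

fn main() {
    // Real traversal: ".." is its own component.
    assert!(has_parent_dir(Path::new("foo/../bar")));
    // Not traversal: ".." is embedded inside a single file name.
    assert!(!has_parent_dir(Path::new("file..txt")));
    // The old substring check would have rejected both.
    assert!("file..txt".contains(".."));
}
```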
diff --git a/src/cortex-compact/src/config.rs b/src/cortex-compact/src/config.rs
index 099c126..f90b9d9 100644
--- a/src/cortex-compact/src/config.rs
+++ b/src/cortex-compact/src/config.rs
@@ -1,6 +1,6 @@
 //! Compaction configuration.
 
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 /// Configuration for auto-compaction.
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -8,8 +8,8 @@ pub struct CompactionConfig {
     /// Whether auto-compaction is enabled.
     #[serde(default = "default_true")]
     pub enabled: bool,
-    /// Token threshold to trigger compaction (percentage of max).
-    #[serde(default = "default_threshold")]
+    /// Token threshold to trigger compaction (ratio 0.0-1.0 of max context).
+    #[serde(default = "default_threshold", deserialize_with = "deserialize_threshold_percent")]
     pub threshold_percent: f32,
     /// Minimum tokens to keep after compaction.
     #[serde(default = "default_min_tokens")]
@@ -25,6 +25,20 @@ pub struct CompactionConfig {
     pub preserve_recent_turns: usize,
 }
 
+/// Deserialize threshold_percent with validation (must be 0.0-1.0).
+fn deserialize_threshold_percent<'de, D>(deserializer: D) -> Result<f32, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let value = f32::deserialize(deserializer)?;
+    if !(0.0..=1.0).contains(&value) {
+        return Err(serde::de::Error::custom(
+            "threshold_percent must be between 0.0 and 1.0",
+        ));
+    }
+    Ok(value)
+}
+
 fn default_true() -> bool {
     true
 }
diff --git a/src/cortex-engine/src/config/mod.rs b/src/cortex-engine/src/config/mod.rs
index e4b9f65..dd99214 100644
--- a/src/cortex-engine/src/config/mod.rs
+++ b/src/cortex-engine/src/config/mod.rs
@@ -46,9 +46,9 @@ pub struct Config {
     /// Provider configuration.
     pub model_provider: ModelProviderInfo,
     /// Context window size.
-    pub model_context_window: Option<i64>,
+    pub model_context_window: Option<u64>,
     /// Auto-compact token limit.
-    pub model_auto_compact_token_limit: Option<i64>,
+    pub model_auto_compact_token_limit: Option<u64>,
     /// Approval policy.
     pub approval_policy: AskForApproval,
     /// Sandbox policy.
diff --git a/src/cortex-engine/src/config/types.rs b/src/cortex-engine/src/config/types.rs
index 3e5e4fe..ee3b39a 100644
--- a/src/cortex-engine/src/config/types.rs
+++ b/src/cortex-engine/src/config/types.rs
@@ -69,8 +69,8 @@ pub struct PermissionConfig {
 pub struct ConfigToml {
     pub model: Option<String>,
     pub model_provider: Option<String>,
-    pub model_context_window: Option<i64>,
-    pub model_auto_compact_token_limit: Option<i64>,
+    pub model_context_window: Option<u64>,
+    pub model_auto_compact_token_limit: Option<u64>,
     pub approval_policy: Option<AskForApproval>,
     pub sandbox_mode: Option<SandboxMode>,
     pub sandbox_workspace_write: Option<SandboxWorkspaceWrite>,
diff --git a/src/cortex-engine/src/tools/mod.rs b/src/cortex-engine/src/tools/mod.rs
index 9f5b6e1..43103b9 100644
--- a/src/cortex-engine/src/tools/mod.rs
+++ b/src/cortex-engine/src/tools/mod.rs
@@ -30,6 +30,7 @@ pub mod artifacts;
 pub mod context;
 pub mod handlers;
 pub mod registry;
+pub mod response_store;
 pub mod router;
 pub mod spec;
 pub mod unified_executor;
@@ -45,6 +46,10 @@ pub use artifacts::{
 pub use context::ToolContext;
 pub use handlers::*;
 pub use registry::{PluginTool, ToolRegistry};
+pub use response_store::{
+    CLEANUP_INTERVAL, DEFAULT_TTL, MAX_STORE_SIZE, StoreInfo, StoreStats, StoredResponse,
+    ToolResponseStore, ToolResponseStoreConfig, create_shared_store,
+    create_shared_store_with_config,
+};
 pub use router::ToolRouter;
 pub use spec::{ToolCall, ToolDefinition, ToolHandler, ToolResult};
 pub use unified_executor::{ExecutorConfig, UnifiedToolExecutor};
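
To see the new `threshold_percent` validation reject bad input, round-trip a config through serde. A sketch, assuming the crate exposes `CompactionConfig` at `cortex_compact::config` and that the remaining fields all carry `#[serde(default = ...)]` values (both assumptions inferred from the hunk above):

```rust
// Hypothetical usage sketch; the import path is assumed from the file location.
use cortex_compact::config::CompactionConfig;

fn main() {
    // In range: deserializes normally.
    let ok: Result<CompactionConfig, _> = serde_json::from_str(r#"{"threshold_percent": 0.8}"#);
    assert!(ok.is_ok());

    // Out of range: rejected at load time instead of silently producing a
    // threshold that would never (or always) trigger compaction.
    let bad: Result<CompactionConfig, _> = serde_json::from_str(r#"{"threshold_percent": 80.0}"#);
    assert!(bad.is_err());
}
```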
diff --git a/src/cortex-engine/src/tools/response_store.rs b/src/cortex-engine/src/tools/response_store.rs
new file mode 100644
index 0000000..c402fb2
--- /dev/null
+++ b/src/cortex-engine/src/tools/response_store.rs
@@ -0,0 +1,511 @@
+//! Tool response storage with bounded capacity and automatic cleanup.
+//!
+//! This module provides bounded storage for tool execution results that:
+//! - Limits the maximum number of stored responses to prevent unbounded memory growth
+//! - Removes entries when they are consumed (taken via `take`)
+//! - Periodically cleans up stale entries based on TTL
+//!
+//! Fixes #5292 (unbounded growth) and #5293 (missing removal on read).
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use tokio::sync::RwLock;
+use tracing::debug;
+
+use crate::tools::spec::ToolResult;
+
+/// Maximum number of responses to store before eviction.
+/// This prevents unbounded memory growth from accumulated tool responses.
+pub const MAX_STORE_SIZE: usize = 500;
+
+/// Default time-to-live for stored responses (5 minutes).
+pub const DEFAULT_TTL: Duration = Duration::from_secs(300);
+
+/// Interval for periodic cleanup of stale entries (1 minute).
+pub const CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
+
+/// A stored tool response with metadata.
+#[derive(Debug, Clone)]
+pub struct StoredResponse {
+    /// The tool execution result.
+    pub result: ToolResult,
+    /// Tool name that produced this result.
+    pub tool_name: String,
+    /// When the response was stored.
+    pub stored_at: Instant,
+    /// Whether this response has been read (but not yet consumed).
+    pub read: bool,
+}
+
+impl StoredResponse {
+    /// Create a new stored response.
+    pub fn new(tool_name: impl Into<String>, result: ToolResult) -> Self {
+        Self {
+            result,
+            tool_name: tool_name.into(),
+            stored_at: Instant::now(),
+            read: false,
+        }
+    }
+
+    /// Check if the response has expired.
+    pub fn is_expired(&self, ttl: Duration) -> bool {
+        self.stored_at.elapsed() > ttl
+    }
+
+    /// Get the age of this response.
+    pub fn age(&self) -> Duration {
+        self.stored_at.elapsed()
+    }
+}
+
+/// Configuration for the tool response store.
+#[derive(Debug, Clone)]
+pub struct ToolResponseStoreConfig {
+    /// Maximum number of responses to store.
+    pub max_size: usize,
+    /// Time-to-live for stored responses.
+    pub ttl: Duration,
+    /// Whether to remove entries on read (peek vs consume).
+    pub remove_on_read: bool,
+}
+
+impl Default for ToolResponseStoreConfig {
+    fn default() -> Self {
+        Self {
+            max_size: MAX_STORE_SIZE,
+            ttl: DEFAULT_TTL,
+            remove_on_read: true,
+        }
+    }
+}
+
+impl ToolResponseStoreConfig {
+    /// Create a config with a custom max size.
+    pub fn with_max_size(mut self, max_size: usize) -> Self {
+        self.max_size = max_size;
+        self
+    }
+
+    /// Create a config with a custom TTL.
+    pub fn with_ttl(mut self, ttl: Duration) -> Self {
+        self.ttl = ttl;
+        self
+    }
+
+    /// Set whether to remove entries on read.
+    pub fn with_remove_on_read(mut self, remove: bool) -> Self {
+        self.remove_on_read = remove;
+        self
+    }
+}
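
The builder methods above consume and return `self`, so they chain left to right. A short sketch of composing a custom config (values are illustrative; the import path through `cortex_engine::tools` comes from the `mod.rs` re-exports above):

```rust
use std::time::Duration;

use cortex_engine::tools::{ToolResponseStore, ToolResponseStoreConfig};

fn main() {
    // Tighter bounds for a memory-constrained deployment.
    let config = ToolResponseStoreConfig::default()
        .with_max_size(100)
        .with_ttl(Duration::from_secs(60));

    // The store now evicts at 100 entries and expires entries after 60s.
    let _store = ToolResponseStore::with_config(config);
}
```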
+/// Bounded storage for tool execution responses.
+///
+/// This store prevents unbounded memory growth by:
+/// 1. Enforcing a maximum number of stored responses
+/// 2. Removing entries when they are consumed
+/// 3. Periodically cleaning up stale entries
+///
+/// # Thread Safety
+///
+/// The store uses `RwLock` for interior mutability and is safe to share
+/// across threads via `Arc`.
+#[derive(Debug)]
+pub struct ToolResponseStore {
+    /// Stored responses keyed by tool call ID.
+    responses: RwLock<HashMap<String, StoredResponse>>,
+    /// Configuration.
+    config: ToolResponseStoreConfig,
+    /// Last cleanup time.
+    last_cleanup: RwLock<Instant>,
+    /// Statistics.
+    stats: RwLock<StoreStats>,
+}
+
+impl ToolResponseStore {
+    /// Create a new tool response store with default configuration.
+    pub fn new() -> Self {
+        Self::with_config(ToolResponseStoreConfig::default())
+    }
+
+    /// Create a tool response store with custom configuration.
+    pub fn with_config(config: ToolResponseStoreConfig) -> Self {
+        Self {
+            responses: RwLock::new(HashMap::new()),
+            config,
+            last_cleanup: RwLock::new(Instant::now()),
+            stats: RwLock::new(StoreStats::default()),
+        }
+    }
+
+    /// Store a tool response.
+    ///
+    /// If the store is at capacity, the oldest entry will be evicted.
+    /// Returns `true` if an entry was evicted to make room.
+    pub async fn store(
+        &self,
+        call_id: impl Into<String>,
+        tool_name: impl Into<String>,
+        result: ToolResult,
+    ) -> bool {
+        let call_id = call_id.into();
+        let tool_name = tool_name.into();
+        let mut evicted = false;
+
+        // Perform periodic cleanup if needed
+        self.maybe_cleanup().await;
+
+        let mut responses = self.responses.write().await;
+
+        // Evict the oldest entry if at capacity
+        if responses.len() >= self.config.max_size {
+            if let Some(oldest_key) = self.find_oldest_key(&responses) {
+                responses.remove(&oldest_key);
+                evicted = true;
+                debug!(
+                    evicted_key = %oldest_key,
+                    "Evicted oldest response to make room"
+                );
+            }
+        }
+
+        let response = StoredResponse::new(tool_name, result);
+        responses.insert(call_id.clone(), response);
+
+        // Update stats
+        let mut stats = self.stats.write().await;
+        stats.total_stored += 1;
+        if evicted {
+            stats.evictions += 1;
+        }
+
+        evicted
+    }
+
+    /// Get a response without removing it (peek).
+    ///
+    /// Marks the response as read but does not consume it.
+    pub async fn get(&self, call_id: &str) -> Option<ToolResult> {
+        let mut responses = self.responses.write().await;
+
+        if let Some(response) = responses.get_mut(call_id) {
+            response.read = true;
+            let mut stats = self.stats.write().await;
+            stats.reads += 1;
+            Some(response.result.clone())
+        } else {
+            None
+        }
+    }
+
+    /// Take (consume) a response, removing it from the store.
+    ///
+    /// This is the primary method for retrieving responses, as it ensures
+    /// entries are cleaned up after being consumed (#5293).
+    pub async fn take(&self, call_id: &str) -> Option<ToolResult> {
+        let mut responses = self.responses.write().await;
+
+        if let Some(response) = responses.remove(call_id) {
+            let mut stats = self.stats.write().await;
+            stats.takes += 1;
+            Some(response.result)
+        } else {
+            None
+        }
+    }
+
+    /// Check if a response exists for the given call ID.
+    pub async fn contains(&self, call_id: &str) -> bool {
+        self.responses.read().await.contains_key(call_id)
+    }
+
+    /// Get the current number of stored responses.
+    pub async fn len(&self) -> usize {
+        self.responses.read().await.len()
+    }
+
+    /// Check if the store is empty.
+    pub async fn is_empty(&self) -> bool {
+        self.responses.read().await.is_empty()
+    }
+
+    /// Remove all expired entries.
+    ///
+    /// Returns the number of entries removed.
+    pub async fn cleanup_expired(&self) -> usize {
+        let mut responses = self.responses.write().await;
+        let ttl = self.config.ttl;
+        let before = responses.len();
+
+        responses.retain(|_, v| !v.is_expired(ttl));
+
+        let removed = before - responses.len();
+        if removed > 0 {
+            debug!(removed, "Cleaned up expired responses");
+            let mut stats = self.stats.write().await;
+            stats.expired_cleanups += removed as u64;
+        }
+
+        removed
+    }
+
+    /// Remove all read entries that haven't been consumed.
+    ///
+    /// This is useful for cleaning up entries that were peeked but never taken.
+    pub async fn cleanup_read(&self) -> usize {
+        let mut responses = self.responses.write().await;
+        let before = responses.len();
+
+        responses.retain(|_, v| !v.read);
+
+        let removed = before - responses.len();
+        if removed > 0 {
+            debug!(removed, "Cleaned up read-but-not-consumed responses");
+        }
+
+        removed
+    }
+
+    /// Clear all stored responses.
+    pub async fn clear(&self) {
+        self.responses.write().await.clear();
+    }
+
+    /// Get store statistics.
+    pub async fn stats(&self) -> StoreStats {
+        self.stats.read().await.clone()
+    }
+
+    /// Get detailed store info including current size and config.
+    pub async fn info(&self) -> StoreInfo {
+        let responses = self.responses.read().await;
+        let stats = self.stats.read().await;
+
+        StoreInfo {
+            current_size: responses.len(),
+            max_size: self.config.max_size,
+            ttl_secs: self.config.ttl.as_secs(),
+            oldest_age_secs: responses
+                .values()
+                .map(|r| r.age().as_secs())
+                .max()
+                .unwrap_or(0),
+            stats: stats.clone(),
+        }
+    }
+
+    // Internal helpers
+
+    /// Find the key of the oldest entry.
+    fn find_oldest_key(&self, responses: &HashMap<String, StoredResponse>) -> Option<String> {
+        responses
+            .iter()
+            .min_by_key(|(_, v)| v.stored_at)
+            .map(|(k, _)| k.clone())
+    }
+
+    /// Perform cleanup if enough time has passed since the last cleanup.
+    async fn maybe_cleanup(&self) {
+        let should_cleanup = {
+            let last = self.last_cleanup.read().await;
+            last.elapsed() > CLEANUP_INTERVAL
+        };
+
+        if should_cleanup {
+            *self.last_cleanup.write().await = Instant::now();
+            let removed = self.cleanup_expired().await;
+            if removed > 0 {
+                debug!(removed, "Periodic cleanup removed expired entries");
+            }
+        }
+    }
+}
+
+impl Default for ToolResponseStore {
+    fn default() -> Self {
+        Self::new()
+    }
+}
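
The peek/consume split is the heart of the #5293 fix: `get` marks an entry as read but leaves it in place, while `take` removes it so consumed responses cannot accumulate. A usage sketch against the re-exported API (assuming `ToolResult::success` as used in the tests below):

```rust
use cortex_engine::tools::{ToolResult, create_shared_store};

#[tokio::main]
async fn main() {
    let store = create_shared_store();
    store.store("call-1", "Read", ToolResult::success("ok")).await;

    // Peek: the entry survives and is merely marked as read.
    assert!(store.get("call-1").await.is_some());
    assert!(store.contains("call-1").await);

    // Consume: the entry is removed, so memory cannot accumulate (#5293).
    assert!(store.take("call-1").await.is_some());
    assert!(!store.contains("call-1").await);
}
```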
+/// Statistics for the tool response store.
+#[derive(Debug, Clone, Default)]
+pub struct StoreStats {
+    /// Total responses stored.
+    pub total_stored: u64,
+    /// Number of get (peek) operations.
+    pub reads: u64,
+    /// Number of take (consume) operations.
+    pub takes: u64,
+    /// Number of evictions due to the capacity limit.
+    pub evictions: u64,
+    /// Number of entries removed by TTL cleanup.
+    pub expired_cleanups: u64,
+}
+
+/// Detailed store information.
+#[derive(Debug, Clone)]
+pub struct StoreInfo {
+    /// Current number of stored responses.
+    pub current_size: usize,
+    /// Maximum allowed size.
+    pub max_size: usize,
+    /// TTL in seconds.
+    pub ttl_secs: u64,
+    /// Age of the oldest entry in seconds.
+    pub oldest_age_secs: u64,
+    /// Store statistics.
+    pub stats: StoreStats,
+}
+
+/// Create a shared tool response store.
+pub fn create_shared_store() -> Arc<ToolResponseStore> {
+    Arc::new(ToolResponseStore::new())
+}
+
+/// Create a shared tool response store with custom configuration.
+pub fn create_shared_store_with_config(config: ToolResponseStoreConfig) -> Arc<ToolResponseStore> {
+    Arc::new(ToolResponseStore::with_config(config))
+}
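
Because every method takes `&self` and locks internally, the `Arc<ToolResponseStore>` returned by `create_shared_store` can be cloned into concurrent tasks. A sketch, assuming `ToolResult: Send + 'static` (not confirmed by the hunk):

```rust
use cortex_engine::tools::{ToolResult, create_shared_store};

#[tokio::main]
async fn main() {
    let store = create_shared_store();

    // Producer task stores a response through one cloned handle...
    let producer = {
        let store = store.clone();
        tokio::spawn(async move {
            store.store("call-42", "Bash", ToolResult::success("done")).await;
        })
    };
    producer.await.unwrap();

    // ...and a consumer takes it through another.
    let consumer = {
        let store = store.clone();
        tokio::spawn(async move { store.take("call-42").await })
    };
    assert!(consumer.await.unwrap().is_some());
}
```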
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_store_and_take() {
+        let store = ToolResponseStore::new();
+
+        let result = ToolResult::success("test output");
+        store.store("call-1", "Read", result.clone()).await;
+
+        assert!(store.contains("call-1").await);
+        assert_eq!(store.len().await, 1);
+
+        let taken = store.take("call-1").await;
+        assert!(taken.is_some());
+        assert_eq!(taken.unwrap().output, "test output");
+
+        // After take, the entry should be gone
+        assert!(!store.contains("call-1").await);
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_store_and_get() {
+        let store = ToolResponseStore::new();
+
+        let result = ToolResult::success("test output");
+        store.store("call-1", "Read", result).await;
+
+        // Get should return the result but not remove it
+        let got = store.get("call-1").await;
+        assert!(got.is_some());
+        assert!(store.contains("call-1").await);
+
+        // A second get should still work
+        let got2 = store.get("call-1").await;
+        assert!(got2.is_some());
+    }
+
+    #[tokio::test]
+    async fn test_capacity_eviction() {
+        let config = ToolResponseStoreConfig::default().with_max_size(3);
+        let store = ToolResponseStore::with_config(config);
+
+        // Fill to capacity
+        store.store("call-1", "Read", ToolResult::success("1")).await;
+        store.store("call-2", "Read", ToolResult::success("2")).await;
+        store.store("call-3", "Read", ToolResult::success("3")).await;
+
+        assert_eq!(store.len().await, 3);
+
+        // Add one more; the oldest entry should be evicted
+        let evicted = store.store("call-4", "Read", ToolResult::success("4")).await;
+        assert!(evicted);
+        assert_eq!(store.len().await, 3);
+
+        // call-1 should be evicted (oldest)
+        assert!(!store.contains("call-1").await);
+        assert!(store.contains("call-4").await);
+    }
+
+    #[tokio::test]
+    async fn test_expired_cleanup() {
+        let config = ToolResponseStoreConfig::default().with_ttl(Duration::from_millis(50));
+        let store = ToolResponseStore::with_config(config);
+
+        store.store("call-1", "Read", ToolResult::success("1")).await;
+        assert_eq!(store.len().await, 1);
+
+        // Wait for expiration
+        tokio::time::sleep(Duration::from_millis(100)).await;
+
+        let removed = store.cleanup_expired().await;
+        assert_eq!(removed, 1);
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_cleanup_read() {
+        let store = ToolResponseStore::new();
+
+        store.store("call-1", "Read", ToolResult::success("1")).await;
+        store.store("call-2", "Read", ToolResult::success("2")).await;
+
+        // Read one entry
+        store.get("call-1").await;
+
+        // Clean up read entries
+        let removed = store.cleanup_read().await;
+        assert_eq!(removed, 1);
+        assert_eq!(store.len().await, 1);
+        assert!(!store.contains("call-1").await);
+        assert!(store.contains("call-2").await);
+    }
+
+    #[tokio::test]
+    async fn test_stats() {
+        let store = ToolResponseStore::new();
+
+        store.store("call-1", "Read", ToolResult::success("1")).await;
+        store.get("call-1").await;
+        store.take("call-1").await;
+
+        let stats = store.stats().await;
+        assert_eq!(stats.total_stored, 1);
+        assert_eq!(stats.reads, 1);
+        assert_eq!(stats.takes, 1);
+    }
+
+    #[tokio::test]
+    async fn test_nonexistent_key() {
+        let store = ToolResponseStore::new();
+
+        assert!(store.get("nonexistent").await.is_none());
+        assert!(store.take("nonexistent").await.is_none());
+        assert!(!store.contains("nonexistent").await);
+    }
+
+    #[tokio::test]
+    async fn test_clear() {
+        let store = ToolResponseStore::new();
+
+        store.store("call-1", "Read", ToolResult::success("1")).await;
+        store.store("call-2", "Read", ToolResult::success("2")).await;
+
+        assert_eq!(store.len().await, 2);
+
+        store.clear().await;
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_info() {
+        let store = ToolResponseStore::new();
+
+        store.store("call-1", "Read", ToolResult::success("1")).await;
+
+        let info = store.info().await;
+        assert_eq!(info.current_size, 1);
+        assert_eq!(info.max_size, MAX_STORE_SIZE);
+    }
+}
diff --git a/src/cortex-engine/src/validation.rs b/src/cortex-engine/src/validation.rs
index a8afbec..e1b472a 100644
--- a/src/cortex-engine/src/validation.rs
+++ b/src/cortex-engine/src/validation.rs
@@ -269,6 +269,34 @@ pub struct CommandValidator {
     pub allow_shell_operators: bool,
 }
 
+/// Normalize a command string for consistent validation.
+///
+/// This function handles bypass attempts such as:
+/// - Extra whitespace: "rm  -rf" → "rm -rf"
+/// - Quoted parts: "'rm' -rf" → "rm -rf"
+/// - Path variants: "/bin/rm -rf" → "rm -rf"
+fn normalize_command(cmd: &str) -> String {
+    cmd.split_whitespace()
+        .enumerate()
+        .map(|(idx, part)| {
+            // Remove surrounding quotes (single and double)
+            let unquoted = part.trim_matches(|c| c == '\'' || c == '"');
+
+            // For the first part (the command), extract the basename to handle path variants
+            if idx == 0 {
+                Path::new(unquoted)
+                    .file_name()
+                    .and_then(|name| name.to_str())
+                    .unwrap_or(unquoted)
+            } else {
+                unquoted
+            }
+        })
+        .collect::<Vec<_>>()
+        .join(" ")
+}
+
 impl CommandValidator {
     /// Create a new validator.
     pub fn new() -> Self {
@@ -332,9 +360,12 @@ impl CommandValidator {
             ));
         }
 
-        // Check allowed list
+        // Normalize the command for consistent validation
+        let normalized = normalize_command(command);
+
+        // Check allowed list using the normalized command
         if let Some(ref allowed) = self.allowed {
-            let cmd = command.split_whitespace().next().unwrap_or("");
+            let cmd = normalized.split_whitespace().next().unwrap_or("");
             if !allowed.contains(cmd) {
                 result.add_error(ValidationError::new(
                     "command",
@@ -343,9 +374,10 @@ impl CommandValidator {
             }
         }
 
-        // Check blocked commands
+        // Check blocked commands against the normalized form
         for blocked in &self.blocked {
-            if command.contains(blocked) {
+            let normalized_blocked = normalize_command(blocked);
+            if normalized.contains(&normalized_blocked) {
                 result.add_error(ValidationError::new(
                     "command",
                     "Command contains blocked pattern",
@@ -354,9 +386,9 @@ impl CommandValidator {
             }
         }
 
-        // Check blocked patterns
+        // Check blocked patterns against both the original and normalized forms
         for pattern in &self.blocked_patterns {
-            if command.contains(pattern) {
+            if command.contains(pattern) || normalized.contains(pattern) {
                 result.add_error(ValidationError::new(
                     "command",
                     "Command contains dangerous pattern",
@@ -700,6 +732,70 @@ mod tests {
         assert!(result.valid);
     }
 
+    #[test]
+    fn test_command_validation_whitespace_bypass() {
+        let validator = CommandValidator::new();
+
+        // Extra whitespace should not bypass validation
+        let result = validator.validate("rm  -rf /");
+        assert!(!result.valid, "Extra whitespace should not bypass blocked command");
+
+        let result = validator.validate("rm    -rf /");
+        assert!(!result.valid, "Multiple spaces should not bypass blocked command");
+    }
+
+    #[test]
+    fn test_command_validation_quote_bypass() {
+        let validator = CommandValidator::new();
+
+        // Quoted commands should not bypass validation
+        let result = validator.validate("'rm' -rf /");
+        assert!(!result.valid, "Single quotes should not bypass blocked command");
+
+        let result = validator.validate("\"rm\" -rf /");
+        assert!(!result.valid, "Double quotes should not bypass blocked command");
+
+        let result = validator.validate("'rm' '-rf' '/'");
+        assert!(!result.valid, "Fully quoted command should not bypass blocked command");
+    }
+
+    #[test]
+    fn test_command_validation_path_bypass() {
+        let validator = CommandValidator::new();
+
+        // Path variants should not bypass validation
+        let result = validator.validate("/bin/rm -rf /");
+        assert!(!result.valid, "Absolute path should not bypass blocked command");
+
+        let result = validator.validate("/usr/bin/rm -rf /");
+        assert!(!result.valid, "Full path should not bypass blocked command");
+
+        let result = validator.validate("./rm -rf /");
+        assert!(!result.valid, "Relative path should not bypass blocked command");
+    }
+
+    #[test]
+    fn test_command_validation_combined_bypass() {
+        let validator = CommandValidator::new();
+
+        // Combined bypass attempts
+        let result = validator.validate("'/bin/rm' -rf /");
+        assert!(!result.valid, "Combined path and quoting should not bypass");
+
+        let result = validator.validate("\"/usr/bin/rm\" '-rf' '/'");
+        assert!(!result.valid, "Combined quotes, path, and whitespace should not bypass");
+    }
+
+    #[test]
+    fn test_normalize_command() {
+        // Test the normalize function directly
+        assert_eq!(normalize_command("rm -rf /"), "rm -rf /");
+        assert_eq!(normalize_command("rm    -rf /"), "rm -rf /");
+        assert_eq!(normalize_command("'rm' -rf /"), "rm -rf /");
+        assert_eq!(normalize_command("/bin/rm -rf /"), "rm -rf /");
+        assert_eq!(normalize_command("'/usr/bin/rm' '-rf' '/'"), "rm -rf /");
+    }
+
     #[test]
     fn test_url_validation() {
         let validator = UrlValidator::new();
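
One subtle piece of `normalize_command` is the basename step: `Path::file_name` returns `None` for inputs like `..` or a bare `/`, and the `unwrap_or(unquoted)` fallback keeps such tokens intact instead of dropping them. A std-only sketch of just that step:

```rust
use std::path::Path;

/// Basename of the command word, falling back to the input unchanged.
fn basename_or(input: &str) -> &str {
    Path::new(input)
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or(input)
}

fn main() {
    assert_eq!(basename_or("/usr/bin/rm"), "rm"); // absolute path collapses
    assert_eq!(basename_or("./rm"), "rm"); // relative path too
    assert_eq!(basename_or(".."), ".."); // file_name() is None: fallback applies
    assert_eq!(basename_or("rm"), "rm"); // plain command is unchanged
}
```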
assert!(!result.valid, "Single quotes should not bypass blocked command"); + + let result = validator.validate("\"rm\" -rf /"); + assert!(!result.valid, "Double quotes should not bypass blocked command"); + + let result = validator.validate("'rm' '-rf' '/'"); + assert!(!result.valid, "Fully quoted command should not bypass blocked command"); + } + + #[test] + fn test_command_validation_path_bypass() { + let validator = CommandValidator::new(); + + // Path variants should not bypass validation + let result = validator.validate("/bin/rm -rf /"); + assert!(!result.valid, "Absolute path should not bypass blocked command"); + + let result = validator.validate("/usr/bin/rm -rf /"); + assert!(!result.valid, "Full path should not bypass blocked command"); + + let result = validator.validate("./rm -rf /"); + assert!(!result.valid, "Relative path should not bypass blocked command"); + } + + #[test] + fn test_command_validation_combined_bypass() { + let validator = CommandValidator::new(); + + // Combined bypass attempts + let result = validator.validate("'/bin/rm' -rf /"); + assert!(!result.valid, "Combined path and whitespace should not bypass"); + + let result = validator.validate("\"/usr/bin/rm\" '-rf' '/'"); + assert!(!result.valid, "Combined quotes, path, and whitespace should not bypass"); + } + + #[test] + fn test_normalize_command() { + // Test the normalize function directly + assert_eq!(normalize_command("rm -rf /"), "rm -rf /"); + assert_eq!(normalize_command("rm -rf /"), "rm -rf /"); + assert_eq!(normalize_command("'rm' -rf /"), "rm -rf /"); + assert_eq!(normalize_command("/bin/rm -rf /"), "rm -rf /"); + assert_eq!(normalize_command("'/usr/bin/rm' '-rf' '/'"), "rm -rf /"); + } + #[test] fn test_url_validation() { let validator = UrlValidator::new(); diff --git a/src/cortex-mcp-server/src/server.rs b/src/cortex-mcp-server/src/server.rs index 96fb8d8..266b9e4 100644 --- a/src/cortex-mcp-server/src/server.rs +++ b/src/cortex-mcp-server/src/server.rs @@ -222,14 +222,17 @@ impl McpServer { } async fn handle_initialize(&self, params: Option) -> Result { - // Check state - let current_state = *self.state.read().await; - if current_state != ServerState::Uninitialized { - return Err(JsonRpcError::invalid_request("Server already initialized")); + // Atomic check-and-transition: hold write lock during entire state check and modification + // to prevent TOCTOU race conditions where multiple concurrent initialize requests + // could both pass the uninitialized check before either sets the state + { + let mut state_guard = self.state.write().await; + if *state_guard != ServerState::Uninitialized { + return Err(JsonRpcError::invalid_request("Server already initialized")); + } + *state_guard = ServerState::Initializing; } - *self.state.write().await = ServerState::Initializing; - // Parse params let init_params: InitializeParams = params .map(serde_json::from_value) diff --git a/src/cortex-otel/src/config.rs b/src/cortex-otel/src/config.rs index 2c40b06..b66d26e 100644 --- a/src/cortex-otel/src/config.rs +++ b/src/cortex-otel/src/config.rs @@ -1,6 +1,6 @@ //! OpenTelemetry configuration. -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; /// OpenTelemetry settings. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -30,7 +30,7 @@ pub struct OtelSettings { pub propagate_context: bool, /// Sampling ratio (0.0 to 1.0). 
- #[serde(default = "default_sampling_ratio")] + #[serde(default = "default_sampling_ratio", deserialize_with = "deserialize_sampling_ratio")] pub sampling_ratio: f64, /// Export timeout in seconds. @@ -38,6 +38,20 @@ pub struct OtelSettings { pub export_timeout_secs: u64, } +/// Deserialize sampling_ratio with validation (must be 0.0-1.0). +fn deserialize_sampling_ratio<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let value = f64::deserialize(deserializer)?; + if !(0.0..=1.0).contains(&value) { + return Err(serde::de::Error::custom( + "sampling_ratio must be between 0.0 and 1.0", + )); + } + Ok(value) +} + impl Default for OtelSettings { fn default() -> Self { OtelSettings { @@ -84,7 +98,8 @@ impl OtelSettings { } if let Ok(ratio) = std::env::var("OTEL_TRACES_SAMPLER_ARG") - && let Ok(ratio) = ratio.parse() + && let Ok(ratio) = ratio.parse::() + && (0.0..=1.0).contains(&ratio) { settings.sampling_ratio = ratio; } diff --git a/src/cortex-plugins/src/registry.rs b/src/cortex-plugins/src/registry.rs index 79f961e..f9bb235 100644 --- a/src/cortex-plugins/src/registry.rs +++ b/src/cortex-plugins/src/registry.rs @@ -674,18 +674,21 @@ impl PluginRegistry { let info = plugin.info().clone(); let id = info.id.clone(); - { - let plugins = self.plugins.read().await; - if plugins.contains_key(&id) { - return Err(PluginError::AlreadyExists(id)); - } - } - + // Use entry API to atomically check-and-insert within a single write lock + // to prevent TOCTOU race conditions where multiple concurrent registrations + // could both pass the contains_key check before either inserts let handle = PluginHandle::new(plugin); - { let mut plugins = self.plugins.write().await; - plugins.insert(id.clone(), handle); + use std::collections::hash_map::Entry; + match plugins.entry(id.clone()) { + Entry::Occupied(_) => { + return Err(PluginError::AlreadyExists(id)); + } + Entry::Vacant(entry) => { + entry.insert(handle); + } + } } { diff --git a/src/cortex-protocol/src/protocol/event_payloads.rs b/src/cortex-protocol/src/protocol/event_payloads.rs index 78fb491..7e692b7 100644 --- a/src/cortex-protocol/src/protocol/event_payloads.rs +++ b/src/cortex-protocol/src/protocol/event_payloads.rs @@ -40,7 +40,7 @@ pub struct SessionConfiguredEvent { #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)] pub struct TaskStartedEvent { - pub model_context_window: Option, + pub model_context_window: Option, } #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)] diff --git a/src/cortex-protocol/src/protocol/tokens.rs b/src/cortex-protocol/src/protocol/tokens.rs index f2af7af..b2a1db0 100644 --- a/src/cortex-protocol/src/protocol/tokens.rs +++ b/src/cortex-protocol/src/protocol/tokens.rs @@ -18,7 +18,7 @@ pub struct TokenUsage { pub struct TokenUsageInfo { pub total_token_usage: TokenUsage, pub last_token_usage: TokenUsage, - pub model_context_window: Option, + pub model_context_window: Option, #[serde(default)] pub context_tokens: i64, } diff --git a/src/cortex-storage/src/sessions/storage.rs b/src/cortex-storage/src/sessions/storage.rs index dd750b0..c1dedc1 100644 --- a/src/cortex-storage/src/sessions/storage.rs +++ b/src/cortex-storage/src/sessions/storage.rs @@ -59,48 +59,123 @@ impl SessionStorage { // ======================================================================== /// List all sessions, sorted by most recent first. + /// + /// Returns both successfully loaded sessions and logs any errors encountered during listing. 
diff --git a/src/cortex-protocol/src/protocol/event_payloads.rs b/src/cortex-protocol/src/protocol/event_payloads.rs
index 78fb491..7e692b7 100644
--- a/src/cortex-protocol/src/protocol/event_payloads.rs
+++ b/src/cortex-protocol/src/protocol/event_payloads.rs
@@ -40,7 +40,7 @@ pub struct SessionConfiguredEvent {
 
 #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
 pub struct TaskStartedEvent {
-    pub model_context_window: Option<i64>,
+    pub model_context_window: Option<u64>,
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
diff --git a/src/cortex-protocol/src/protocol/tokens.rs b/src/cortex-protocol/src/protocol/tokens.rs
index f2af7af..b2a1db0 100644
--- a/src/cortex-protocol/src/protocol/tokens.rs
+++ b/src/cortex-protocol/src/protocol/tokens.rs
@@ -18,7 +18,7 @@ pub struct TokenUsage {
 pub struct TokenUsageInfo {
     pub total_token_usage: TokenUsage,
     pub last_token_usage: TokenUsage,
-    pub model_context_window: Option<i64>,
+    pub model_context_window: Option<u64>,
     #[serde(default)]
     pub context_tokens: i64,
 }
diff --git a/src/cortex-storage/src/sessions/storage.rs b/src/cortex-storage/src/sessions/storage.rs
index dd750b0..c1dedc1 100644
--- a/src/cortex-storage/src/sessions/storage.rs
+++ b/src/cortex-storage/src/sessions/storage.rs
@@ -59,48 +59,123 @@ impl SessionStorage {
     // ========================================================================
 
     /// List all sessions, sorted by most recent first.
+    ///
+    /// Returns the sessions that loaded successfully and logs any errors
+    /// encountered along the way, so callers still get the available sessions
+    /// when individual entries are corrupt.
     pub async fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
         let mut sessions = Vec::new();
+        let mut errors: Vec<String> = Vec::new();
 
         if !self.paths.sessions_dir.exists() {
             return Ok(sessions);
         }
 
-        let mut entries = fs::read_dir(&self.paths.sessions_dir).await?;
-        while let Some(entry) = entries.next_entry().await? {
-            let path = entry.path();
-            if path.extension().map(|e| e == "json").unwrap_or(false) {
-                match self.load_session_from_path(&path).await {
-                    Ok(session) => sessions.push(session.into()),
-                    Err(e) => warn!(path = %path.display(), error = %e, "Failed to load session"),
+        let mut entries = fs::read_dir(&self.paths.sessions_dir).await.map_err(|e| {
+            StorageError::Io(std::io::Error::new(
+                e.kind(),
+                format!(
+                    "Failed to read sessions directory {:?}: {}",
+                    self.paths.sessions_dir, e
+                ),
+            ))
+        })?;
+
+        loop {
+            match entries.next_entry().await {
+                Ok(Some(entry)) => {
+                    let path = entry.path();
+                    if path.extension().map(|e| e == "json").unwrap_or(false) {
+                        match self.load_session_from_path(&path).await {
+                            Ok(session) => sessions.push(session.into()),
+                            Err(e) => {
+                                errors.push(format!(
+                                    "Failed to load session {}: {}",
+                                    path.display(),
+                                    e
+                                ));
+                                warn!(path = %path.display(), error = %e, "Failed to load session");
+                            }
+                        }
+                    }
+                }
+                Ok(None) => break,
+                Err(e) => {
+                    errors.push(format!("Failed to read directory entry: {}", e));
+                    warn!(error = %e, "Failed to read directory entry during session listing");
+                }
+            }
+        }
+
+        // Log an aggregate count if any errors occurred
+        if !errors.is_empty() {
+            warn!(
+                error_count = errors.len(),
+                "Encountered {} error(s) while listing sessions",
+                errors.len()
+            );
+        }
+
         // Sort by updated_at descending (newest first)
         sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
         Ok(sessions)
     }
 
     /// List all sessions synchronously.
+    ///
+    /// Returns the sessions that loaded successfully and logs any errors
+    /// encountered along the way, so callers still get the available sessions
+    /// when individual entries are corrupt.
     pub fn list_sessions_sync(&self) -> Result<Vec<SessionSummary>> {
         let mut sessions = Vec::new();
+        let mut errors: Vec<String> = Vec::new();
 
         if !self.paths.sessions_dir.exists() {
             return Ok(sessions);
         }
 
-        for entry in std::fs::read_dir(&self.paths.sessions_dir)? {
-            let entry = entry?;
+        let entries = std::fs::read_dir(&self.paths.sessions_dir).map_err(|e| {
+            StorageError::Io(std::io::Error::new(
+                e.kind(),
+                format!(
+                    "Failed to read sessions directory {:?}: {}",
+                    self.paths.sessions_dir, e
+                ),
+            ))
+        })?;
+
+        for entry_result in entries {
+            let entry = match entry_result {
+                Ok(e) => e,
+                Err(e) => {
+                    errors.push(format!("Failed to read directory entry: {}", e));
+                    warn!(error = %e, "Failed to read directory entry during session listing");
+                    continue;
+                }
+            };
+
             let path = entry.path();
             if path.extension().map(|e| e == "json").unwrap_or(false) {
                 match self.load_session_from_path_sync(&path) {
                     Ok(session) => sessions.push(session.into()),
-                    Err(e) => warn!(path = %path.display(), error = %e, "Failed to load session"),
+                    Err(e) => {
+                        errors.push(format!("Failed to load session {}: {}", path.display(), e));
+                        warn!(path = %path.display(), error = %e, "Failed to load session");
+                    }
                 }
             }
         }
 
+        // Log an aggregate count if any errors occurred
+        if !errors.is_empty() {
+            warn!(
+                error_count = errors.len(),
+                "Encountered {} error(s) while listing sessions",
+                errors.len()
+            );
+        }
+
         sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
         Ok(sessions)
     }
@@ -438,7 +513,9 @@ impl SessionStorage {
 
 impl Default for SessionStorage {
     fn default() -> Self {
-        Self::new().expect("Failed to create session storage")
+        Self::new().expect(
+            "SessionStorage initialization failed - check directory permissions and disk space",
+        )
     }
 }
diff --git a/src/cortex-tui/src/session/manager.rs b/src/cortex-tui/src/session/manager.rs
index 9f0dead..c2d03fa 100644
--- a/src/cortex-tui/src/session/manager.rs
+++ b/src/cortex-tui/src/session/manager.rs
@@ -179,10 +179,18 @@ impl CortexSession {
     }
 
     /// Updates token counts in metadata.
+    ///
+    /// Saves metadata to disk. If the save fails, the in-memory state is still
+    /// updated but marked as modified for a later retry.
     pub fn add_tokens(&mut self, input: i64, output: i64) {
         self.meta.add_tokens(input, output);
         if let Err(e) = self.storage.save_meta(&self.meta) {
-            tracing::error!("Failed to save metadata: {}", e);
+            tracing::error!(
+                session_id = %self.meta.id,
+                error = %e,
+                "Failed to save metadata after token update"
+            );
+            self.modified = true;
         }
     }
 
@@ -201,79 +209,135 @@ impl CortexSession {
 
     /// Removes and returns the last exchange (user + assistant messages).
     /// Returns None if there are fewer than 2 messages.
+    ///
+    /// Only updates in-memory state after successful storage operations.
+    /// If storage fails, the messages are restored to maintain consistency.
     pub fn pop_last_exchange(&mut self) -> Option<Vec<cortex_core::widgets::Message>> {
         if self.messages.len() < 2 {
             return None;
         }
 
+        // Pop messages from memory temporarily
+        let last = self.messages.pop();
+        let prev = self.messages.pop();
+
+        // Build the result in conversation order (prev first, then last)
         let mut result = Vec::new();
 
-        // Pop the last message (should be assistant)
-        if let Some(last) = self.messages.pop() {
-            let role = match last.role.as_str() {
+        if let Some(ref msg) = prev {
+            let role = match msg.role.as_str() {
                 "assistant" => cortex_core::widgets::MessageRole::Assistant,
                 "user" => cortex_core::widgets::MessageRole::User,
                 _ => cortex_core::widgets::MessageRole::System,
             };
             result.push(cortex_core::widgets::Message {
                 role,
-                content: last.content,
+                content: msg.content.clone(),
                 timestamp: None,
                 is_streaming: false,
                 tool_name: None,
             });
         }
 
-        // Pop the previous message (should be user)
-        if let Some(prev) = self.messages.pop() {
-            let role = match prev.role.as_str() {
+        if let Some(ref msg) = last {
+            let role = match msg.role.as_str() {
                 "assistant" => cortex_core::widgets::MessageRole::Assistant,
                 "user" => cortex_core::widgets::MessageRole::User,
                 _ => cortex_core::widgets::MessageRole::System,
             };
-            result.insert(
-                0,
-                cortex_core::widgets::Message {
-                    role,
-                    content: prev.content,
-                    timestamp: None,
-                    is_streaming: false,
-                    tool_name: None,
-                },
-            );
+            result.push(cortex_core::widgets::Message {
+                role,
+                content: msg.content.clone(),
+                timestamp: None,
+                is_streaming: false,
+                tool_name: None,
+            });
         }
 
-        // Update metadata
-        self.meta.message_count = self.messages.len() as u32;
-        self.modified = true;
+        // Try to persist the updated message list before committing the change
+        match self.storage.rewrite_messages(&self.meta.id, &self.messages) {
+            Ok(()) => {
+                // Storage succeeded, update metadata
+                self.meta.message_count = self.messages.len() as u32;
+
+                if let Err(e) = self.storage.save_meta(&self.meta) {
+                    tracing::error!(
+                        session_id = %self.meta.id,
+                        error = %e,
+                        "Failed to save metadata after undo - history is updated but metadata may be stale"
+                    );
+                    self.modified = true;
+                }
 
-        // Save updated state (messages need to be rewritten)
-        if let Err(e) = self.storage.rewrite_messages(&self.meta.id, &self.messages) {
-            tracing::error!("Failed to rewrite messages after undo: {}", e);
-        }
-        if let Err(e) = self.storage.save_meta(&self.meta) {
-            tracing::error!("Failed to save metadata after undo: {}", e);
-        }
+                Some(result)
+            }
+            Err(e) => {
+                // Storage failed - restore the popped messages to maintain consistency
+                tracing::error!(
+                    session_id = %self.meta.id,
+                    error = %e,
+                    "Failed to rewrite messages after undo - restoring original state"
+                );
 
-        Some(result)
+                // Restore in reverse pop order (prev was popped second, so push it first)
+                if let Some(msg) = prev {
+                    self.messages.push(msg);
+                }
+                if let Some(msg) = last {
+                    self.messages.push(msg);
+                }
+
+                // Return None to indicate the operation failed
+                None
+            }
+        }
     }
 
     /// Internal method to add a message and persist it.
     ///
     /// Attempts the storage write first so that, on success, disk and memory
     /// stay consistent; on failure the message is kept in memory only and the
     /// session is marked modified.
     fn add_message_internal(&mut self, message: StoredMessage) -> &StoredMessage {
-        // Append to storage first
-        if let Err(e) = self.storage.append_message(&self.meta.id, &message) {
-            tracing::error!("Failed to save message: {}", e);
-        }
+        // Try to append to storage first
+        match self.storage.append_message(&self.meta.id, &message) {
+            Ok(()) => {
+                // Storage succeeded, now update metadata
+                self.meta.increment_messages();
+
+                // Try to save metadata - if this fails, we still keep the message
+                // since it was already persisted to history
+                if let Err(e) = self.storage.save_meta(&self.meta) {
+                    tracing::error!(
+                        session_id = %self.meta.id,
+                        error = %e,
+                        "Failed to save metadata after message append - message is saved but metadata may be stale"
+                    );
+                }
 
-        // Update metadata
-        self.meta.increment_messages();
-        if let Err(e) = self.storage.save_meta(&self.meta) {
-            tracing::error!("Failed to save metadata: {}", e);
+                // Add to the in-memory list only after storage success
+                self.messages.push(message);
+            }
+            Err(e) => {
+                // Storage failed - log the error but still add to memory so the
+                // conversation can continue even if persistence fails
+                tracing::error!(
+                    session_id = %self.meta.id,
+                    error = %e,
+                    "Failed to save message to storage - message exists only in memory"
+                );
+
+                self.messages.push(message);
+                self.modified = true; // Mark as modified since we have unsaved changes
+            }
         }
 
-        // Add to in-memory list
-        self.messages.push(message);
-        self.messages.last().unwrap()
+        self.messages.last().expect("message was just added")
     }
 
     /// Converts messages to API format for completion requests.
diff --git a/src/cortex-tui/src/session/storage.rs b/src/cortex-tui/src/session/storage.rs
index 7e1621e..c12b828 100644
--- a/src/cortex-tui/src/session/storage.rs
+++ b/src/cortex-tui/src/session/storage.rs
@@ -220,23 +220,55 @@ impl SessionStorage {
     // ========================================================================
 
     /// Lists all sessions (sorted by updated_at descending).
+    ///
+    /// Returns the sessions that loaded successfully and logs any errors
+    /// encountered during listing, so callers still get the available sessions
+    /// when individual entries are corrupt.
     pub fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
         self.ensure_base_dir()?;
 
         let mut summaries = Vec::new();
+        let mut errors: Vec<String> = Vec::new();
 
-        if let Ok(entries) = fs::read_dir(&self.base_dir) {
-            for entry in entries.flatten() {
-                let path = entry.path();
-                if path.is_dir()
-                    && let Some(session_id) = path.file_name().and_then(|n| n.to_str())
-                    && let Ok(meta) = self.load_meta(session_id)
-                {
-                    summaries.push(SessionSummary::from(&meta));
+        let entries = fs::read_dir(&self.base_dir)
+            .with_context(|| format!("Failed to read sessions directory: {:?}", self.base_dir))?;
+
+        for entry_result in entries {
+            let entry = match entry_result {
+                Ok(e) => e,
+                Err(e) => {
+                    errors.push(format!("Failed to read directory entry: {}", e));
+                    continue;
+                }
+            };
+
+            let path = entry.path();
+            if !path.is_dir() {
+                continue;
+            }
+
+            let session_id = match path.file_name().and_then(|n| n.to_str()) {
+                Some(id) => id,
+                None => continue,
+            };
+
+            match self.load_meta(session_id) {
+                Ok(meta) => summaries.push(SessionSummary::from(&meta)),
+                Err(e) => {
+                    errors.push(format!("Failed to load session '{}': {}", session_id, e));
                 }
             }
         }
 
+        // Log errors but don't fail - return what we could load
+        if !errors.is_empty() {
+            tracing::warn!(
+                error_count = errors.len(),
+                "Encountered {} error(s) while listing sessions: {:?}",
+                errors.len(),
+                errors
+            );
+        }
+
         // Sort by updated_at descending (most recent first)
         summaries.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
 
@@ -323,7 +355,9 @@ impl SessionStorage {
 
 impl Default for SessionStorage {
     fn default() -> Self {
-        Self::new().expect("Failed to create session storage")
+        Self::new().expect(
+            "SessionStorage initialization failed - check directory permissions and disk space",
+        )
     }
 }
 
@@ -347,11 +381,9 @@ mod tests {
         let (storage, _temp) = create_test_storage();
         let session_id = "test-session-123";
 
-        assert!(
-            storage
-                .session_dir(session_id)
-                .ends_with("test-session-123")
-        );
+        assert!(storage
+            .session_dir(session_id)
+            .ends_with("test-session-123"));
         assert!(storage.meta_path(session_id).ends_with("meta.json"));
         assert!(storage.history_path(session_id).ends_with("history.jsonl"));
     }
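
Since per-entry failures are logged rather than propagated, callers of either `list_sessions` variant only handle the directory-level error. A hypothetical caller sketch (the field names `id` and `updated_at` on `SessionSummary` are assumptions inferred from the sort above):

```rust
// Hypothetical caller; SessionStorage and SessionSummary as in this file.
fn print_sessions(storage: &SessionStorage) {
    match storage.list_sessions() {
        Ok(summaries) => {
            // Corrupt entries were skipped and reported via tracing::warn!.
            for s in summaries {
                println!("{} (updated {})", s.id, s.updated_at);
            }
        }
        Err(e) => eprintln!("could not read sessions directory: {e}"),
    }
}
```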