diff --git a/src/cortex-agents/src/mention.rs b/src/cortex-agents/src/mention.rs
index 81d9d71..59bb955 100644
--- a/src/cortex-agents/src/mention.rs
+++ b/src/cortex-agents/src/mention.rs
@@ -17,6 +17,46 @@ use regex::Regex;
 use std::sync::LazyLock;
 
+/// Safely get the string slice up to the given byte position.
+///
+/// Returns the slice `&text[..pos]` if `pos` is at a valid UTF-8 character boundary.
+/// If `pos` is inside a multi-byte character, finds the nearest valid boundary
+/// by searching backwards.
+fn safe_slice_up_to(text: &str, pos: usize) -> &str {
+    if pos >= text.len() {
+        return text;
+    }
+    if text.is_char_boundary(pos) {
+        return &text[..pos];
+    }
+    // Find the nearest valid boundary by searching backwards
+    let mut valid_pos = pos;
+    while valid_pos > 0 && !text.is_char_boundary(valid_pos) {
+        valid_pos -= 1;
+    }
+    &text[..valid_pos]
+}
+
+/// Safely get the string slice from the given byte position to the end.
+///
+/// Returns the slice `&text[pos..]` if `pos` is at a valid UTF-8 character boundary.
+/// If `pos` is inside a multi-byte character, finds the nearest valid boundary
+/// by searching forwards.
+fn safe_slice_from(text: &str, pos: usize) -> &str {
+    if pos >= text.len() {
+        return "";
+    }
+    if text.is_char_boundary(pos) {
+        return &text[pos..];
+    }
+    // Find the nearest valid boundary by searching forwards
+    let mut valid_pos = pos;
+    while valid_pos < text.len() && !text.is_char_boundary(valid_pos) {
+        valid_pos += 1;
+    }
+    &text[valid_pos..]
+}
+
 /// A parsed agent mention from user input.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct AgentMention {
@@ -108,10 +148,10 @@ pub fn extract_mention_and_text(
 ) -> Option<(AgentMention, String)> {
     let mention = find_first_valid_mention(text, valid_agents)?;
 
-    // Remove the mention from text
+    // Remove the mention from text, using safe slicing for UTF-8 boundaries
     let mut remaining = String::with_capacity(text.len());
-    remaining.push_str(&text[..mention.start]);
-    remaining.push_str(&text[mention.end..]);
+    remaining.push_str(safe_slice_up_to(text, mention.start));
+    remaining.push_str(safe_slice_from(text, mention.end));
 
     // Trim and normalize whitespace
     let remaining = remaining.trim().to_string();
@@ -123,7 +163,8 @@ pub fn starts_with_mention(text: &str, valid_agents: &[&str]) -> bool {
     let text = text.trim();
     if let Some(mention) = find_first_valid_mention(text, valid_agents) {
-        mention.start == 0 || text[..mention.start].trim().is_empty()
+        // Use safe slicing to handle UTF-8 boundaries
+        mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty()
     } else {
         false
     }
@@ -196,8 +237,8 @@ pub fn parse_message_for_agent(text: &str, valid_agents: &[&str]) -> ParsedAgent
     // Check if message starts with @agent
     if let Some((mention, remaining)) = extract_mention_and_text(text, valid_agents) {
-        // Only trigger if mention is at the start
-        if mention.start == 0 || text[..mention.start].trim().is_empty() {
+        // Only trigger if mention is at the start, using safe slicing for UTF-8 boundaries
+        if mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty() {
             return ParsedAgentMessage::for_agent(mention.agent_name, remaining, text.to_string());
         }
     }
@@ -318,4 +359,99 @@ mod tests {
         assert_eq!(mentions[0].agent_name, "my-agent");
         assert_eq!(mentions[1].agent_name, "my_agent");
     }
+
+    // UTF-8 boundary safety tests
+    #[test]
+    fn test_safe_slice_up_to_ascii() {
+        let text = "hello world";
+        assert_eq!(safe_slice_up_to(text, 5), "hello");
+        assert_eq!(safe_slice_up_to(text, 0), "");
+        assert_eq!(safe_slice_up_to(text, 100), "hello world");
+    }
+
+    #[test]
+    fn test_safe_slice_up_to_multibyte() {
+        // "こんにちは" - each character is 3 bytes
+        let text = "こんにちは";
+        assert_eq!(safe_slice_up_to(text, 3), "こ"); // Valid boundary
+        assert_eq!(safe_slice_up_to(text, 6), "こん"); // Valid boundary
+        // Position 4 is inside the second character, should return "こ"
+        assert_eq!(safe_slice_up_to(text, 4), "こ");
+        assert_eq!(safe_slice_up_to(text, 5), "こ");
+    }
+
+    #[test]
+    fn test_safe_slice_from_multibyte() {
+        let text = "こんにちは";
+        assert_eq!(safe_slice_from(text, 3), "んにちは"); // Valid boundary
+        // Position 4 is inside second character, should skip to position 6
+        assert_eq!(safe_slice_from(text, 4), "にちは");
+        assert_eq!(safe_slice_from(text, 5), "にちは");
+    }
+
+    #[test]
+    fn test_extract_mention_with_multibyte_prefix() {
+        let valid = vec!["general"];
+
+        // Multi-byte characters before mention
+        let result = extract_mention_and_text("日本語 @general search files", &valid);
+        assert!(result.is_some());
+        let (mention, remaining) = result.unwrap();
+        assert_eq!(mention.agent_name, "general");
+        // The prefix should be preserved without panicking
+        assert!(remaining.contains("search files"));
+    }
+
+    #[test]
+    fn test_starts_with_mention_multibyte() {
+        let valid = vec!["general"];
+
+        // Whitespace with multi-byte characters should not cause panic
+        assert!(starts_with_mention(" @general task", &valid));
+
+        // Multi-byte characters before mention - should return false, not panic
+        assert!(!starts_with_mention("日本語 @general task", &valid));
+    }
+
+    #[test]
+    fn test_parse_message_for_agent_multibyte() {
+        let valid = vec!["general"];
+
+        // Multi-byte prefix - should not panic
+        let parsed = parse_message_for_agent("日本語 @general find files", &valid);
+        // Since mention is not at the start, should not invoke task
+        assert!(!parsed.should_invoke_task);
+
+        // Multi-byte in the prompt (after mention)
+        let parsed = parse_message_for_agent("@general 日本語を検索", &valid);
+        assert!(parsed.should_invoke_task);
+        assert_eq!(parsed.agent, Some("general".to_string()));
+        assert_eq!(parsed.prompt, "日本語を検索");
+    }
+
+    #[test]
+    fn test_extract_mention_with_emoji() {
+        let valid = vec!["general"];
+
+        // Emojis are 4 bytes each
+        let result = extract_mention_and_text("🎉 @general celebrate", &valid);
+        assert!(result.is_some());
+        let (mention, remaining) = result.unwrap();
+        assert_eq!(mention.agent_name, "general");
+        assert!(remaining.contains("celebrate"));
+    }
+
+    #[test]
+    fn test_mixed_multibyte_and_ascii() {
+        let valid = vec!["general"];
+
+        // Mix of ASCII, CJK, and emoji
+        let text = "Hello 世界 🌍 @general search for 日本語";
+        let result = extract_mention_and_text(text, &valid);
+        assert!(result.is_some());
+        let (mention, remaining) = result.unwrap();
+        assert_eq!(mention.agent_name, "general");
+        // Should not panic and produce valid output
+        assert!(!remaining.is_empty());
+    }
 }
diff --git a/src/cortex-app-server/src/auth.rs b/src/cortex-app-server/src/auth.rs
index 414f36f..4f240c3 100644
--- a/src/cortex-app-server/src/auth.rs
+++ b/src/cortex-app-server/src/auth.rs
@@ -45,7 +45,7 @@ impl Claims {
     pub fn new(user_id: impl Into<String>, expiry_seconds: u64) -> Self {
         let now = SystemTime::now()
             .duration_since(UNIX_EPOCH)
-            .unwrap()
+            .unwrap_or_default()
             .as_secs();
 
         Self {
@@ -75,7 +75,7 @@ impl Claims {
     pub fn is_expired(&self) -> bool {
         let now = SystemTime::now()
             .duration_since(UNIX_EPOCH)
-            .unwrap()
+            .unwrap_or_default()
             .as_secs();
         self.exp < now
     }
@@ -187,7 +187,7 @@ impl AuthService {
     pub async fn cleanup_revoked_tokens(&self) {
         let now = SystemTime::now()
             .duration_since(UNIX_EPOCH)
-            .unwrap()
+            .unwrap_or_default()
             .as_secs();
 
         let mut revoked = self.revoked_tokens.write().await;
diff --git a/src/cortex-app-server/src/config.rs b/src/cortex-app-server/src/config.rs
index 35ac75b..92be050 100644
--- a/src/cortex-app-server/src/config.rs
+++ b/src/cortex-app-server/src/config.rs
@@ -49,12 +49,18 @@ pub struct ServerConfig {
     pub max_body_size: usize,
 
     /// Request timeout in seconds (applies to full request lifecycle).
+    ///
+    /// See `cortex_common::http_client` module documentation for the complete
+    /// timeout hierarchy across Cortex services.
     #[serde(default = "default_request_timeout")]
     pub request_timeout: u64,
 
     /// Read timeout for individual chunks in seconds.
     /// Applies to chunked transfer encoding to prevent indefinite hangs
     /// when clients disconnect without sending the terminal chunk.
+    ///
+    /// See `cortex_common::http_client` module documentation for the complete
+    /// timeout hierarchy across Cortex services.
     #[serde(default = "default_read_timeout")]
     pub read_timeout: u64,
 
@@ -71,12 +77,16 @@ pub struct ServerConfig {
     pub cors_origins: Vec<String>,
 
     /// Graceful shutdown timeout in seconds.
+    ///
+    /// See `cortex_common::http_client` module documentation for the complete
+    /// timeout hierarchy across Cortex services.
     #[serde(default = "default_shutdown_timeout")]
     pub shutdown_timeout: u64,
 }
 
 fn default_shutdown_timeout() -> u64 {
     30 // 30 seconds for graceful shutdown
+    // See cortex_common::http_client for timeout hierarchy documentation
 }
 
 fn default_listen_addr() -> String {
diff --git a/src/cortex-app-server/src/middleware.rs b/src/cortex-app-server/src/middleware.rs
index a997157..45d4406 100644
--- a/src/cortex-app-server/src/middleware.rs
+++ b/src/cortex-app-server/src/middleware.rs
@@ -40,7 +40,8 @@ pub async fn request_id_middleware(mut request: Request, next: Next) -> Response
     let mut response = next.run(request).await;
     response.headers_mut().insert(
         REQUEST_ID_HEADER,
-        HeaderValue::from_str(&request_id).unwrap(),
+        HeaderValue::from_str(&request_id)
+            .unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")),
     );
 
     response
diff --git a/src/cortex-app-server/src/storage.rs b/src/cortex-app-server/src/storage.rs
index 6c5d44e..1aa617f 100644
--- a/src/cortex-app-server/src/storage.rs
+++ b/src/cortex-app-server/src/storage.rs
@@ -47,8 +47,6 @@ pub struct StoredToolCall {
 
 /// Session storage manager.
 pub struct SessionStorage {
-    #[allow(dead_code)]
-    base_dir: PathBuf,
     sessions_dir: PathBuf,
     history_dir: PathBuf,
 }
@@ -66,7 +64,6 @@ impl SessionStorage {
         info!("Session storage initialized at {:?}", base_dir);
 
         Ok(Self {
-            base_dir,
             sessions_dir,
             history_dir,
         })
diff --git a/src/cortex-apply-patch/src/hunk.rs b/src/cortex-apply-patch/src/hunk.rs
index ea67a97..ab5b1f1 100644
--- a/src/cortex-apply-patch/src/hunk.rs
+++ b/src/cortex-apply-patch/src/hunk.rs
@@ -250,9 +250,6 @@ pub struct SearchReplace {
     pub search: String,
     /// The text to replace with.
     pub replace: String,
-    /// Replace all occurrences (true) or just the first (false).
-    #[allow(dead_code)]
-    pub replace_all: bool,
 }
 
 impl SearchReplace {
@@ -266,16 +263,8 @@ impl SearchReplace {
             path: path.into(),
             search: search.into(),
             replace: replace.into(),
-            replace_all: false,
         }
     }
-
-    /// Set whether to replace all occurrences.
-    #[allow(dead_code)]
-    pub fn with_replace_all(mut self, replace_all: bool) -> Self {
-        self.replace_all = replace_all;
-        self
-    }
 }
 
 #[cfg(test)]
diff --git a/src/cortex-common/src/http_client.rs b/src/cortex-common/src/http_client.rs
index b181ac8..3b290ff 100644
--- a/src/cortex-common/src/http_client.rs
+++ b/src/cortex-common/src/http_client.rs
@@ -9,6 +9,54 @@
 //!
 //! DNS caching is configured with reasonable TTL to allow failover and load
 //! balancer updates (#2177).
+//!
+//! # Timeout Configuration Guide
+//!
+//! This section documents the timeout hierarchy across the Cortex codebase. Use this
+//! as a reference when configuring timeouts for new features or debugging timeout issues.
+//!
+//! ## Timeout Hierarchy
+//!
+//! | Use Case                    | Timeout | Constant/Location                                   | Rationale                               |
+//! |-----------------------------|---------|-----------------------------------------------------|-----------------------------------------|
+//! | Health checks               | 5s      | `HEALTH_CHECK_TIMEOUT` (this module)                | Quick validation of service status      |
+//! | Standard HTTP requests      | 30s     | `DEFAULT_TIMEOUT` (this module)                     | Normal API calls with reasonable margin |
+//! | Per-chunk read (streaming)  | 30s     | `read_timeout` (cortex-app-server/config)           | Individual chunk timeout during stream  |
+//! | Pool idle timeout           | 60s     | `POOL_IDLE_TIMEOUT` (this module)                   | DNS re-resolution for failover          |
+//! | LLM Request (non-streaming) | 120s    | `DEFAULT_REQUEST_TIMEOUT_SECS` (cortex-exec/runner) | Model inference takes time              |
+//! | LLM Streaming total         | 300s    | `STREAMING_TIMEOUT` (this module)                   | Long-running streaming responses        |
+//! | Server request lifecycle    | 300s    | `request_timeout` (cortex-app-server/config)        | Full HTTP request/response cycle        |
+//! | Entire exec session         | 600s    | `DEFAULT_TIMEOUT_SECS` (cortex-exec/runner)         | Multi-turn conversation limit           |
+//! | Graceful shutdown           | 30s     | `shutdown_timeout` (cortex-app-server/config)       | Time for cleanup on shutdown            |
+//!
+//! ## Module-Specific Timeouts
+//!
+//! ### cortex-common (this module)
+//! - `DEFAULT_TIMEOUT` (30s): Use for standard API calls.
+//! - `STREAMING_TIMEOUT` (300s): Use for LLM streaming endpoints.
+//! - `HEALTH_CHECK_TIMEOUT` (5s): Use for health/readiness checks.
+//! - `POOL_IDLE_TIMEOUT` (60s): Connection pool cleanup for DNS freshness.
+//!
+//! ### cortex-exec (runner.rs)
+//! - `DEFAULT_TIMEOUT_SECS` (600s): Maximum duration for entire exec session.
+//! - `DEFAULT_REQUEST_TIMEOUT_SECS` (120s): Single LLM request timeout.
+//!
+//! ### cortex-app-server (config.rs)
+//! - `request_timeout` (300s): Full request lifecycle timeout.
+//! - `read_timeout` (30s): Per-chunk timeout for streaming reads.
+//! - `shutdown_timeout` (30s): Graceful shutdown duration.
+//!
+//! ### cortex-engine (api_client.rs)
+//! - Re-exports constants from this module for consistency.
+//!
+//! ## Recommendations
+//!
+//! When adding new timeout configurations:
+//! 1. Use constants from this module when possible for consistency.
+//! 2. Document any new timeout constants with their rationale.
+//! 3. Consider the timeout hierarchy - inner timeouts should be shorter than outer ones.
+//! 4. For LLM operations, use longer timeouts (120s-300s) to accommodate model inference.
+//! 5. For health checks and quick validations, use short timeouts (5s-10s).
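+//!
+//! ## Example
+//!
+//! An illustrative sketch (not this module's actual constructor) of how the
+//! hierarchy applies when hand-building a `reqwest` client; the durations
+//! mirror `DEFAULT_TIMEOUT` and `POOL_IDLE_TIMEOUT` above:
+//!
+//! ```no_run
+//! use std::time::Duration;
+//!
+//! let client = reqwest::Client::builder()
+//!     // Outer, full-request deadline; per-chunk read timeouts sit inside it.
+//!     .timeout(Duration::from_secs(30))
+//!     // Drop idle pooled connections so DNS is re-resolved for failover.
+//!     .pool_idle_timeout(Duration::from_secs(60))
+//!     .build()
+//!     .expect("static client configuration is valid");
+//! ```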
 
 use reqwest::Client;
 use std::time::Duration;
diff --git a/src/cortex-engine/src/async_utils.rs b/src/cortex-engine/src/async_utils.rs
index f7b0490..ed63a6f 100644
--- a/src/cortex-engine/src/async_utils.rs
+++ b/src/cortex-engine/src/async_utils.rs
@@ -147,13 +147,17 @@ impl ConcurrencyLimiter {
     }
 
     /// Execute with limit.
-    pub async fn execute<F, Fut, T>(&self, f: F) -> T
+    ///
+    /// Returns an error if the semaphore is closed.
+    pub async fn execute<F, Fut, T>(&self, f: F) -> Result<T>
     where
         F: FnOnce() -> Fut,
         Fut: Future<Output = T>,
     {
-        let _permit = self.semaphore.acquire().await.unwrap();
-        f().await
+        let _permit = self.semaphore.acquire().await.map_err(|_| {
+            CortexError::Internal("concurrency limiter semaphore closed unexpectedly".into())
+        })?;
+        Ok(f().await)
     }
 
     /// Get available permits.
@@ -178,26 +182,36 @@ impl<T: Clone> AsyncOnce<T> {
     }
 
     /// Get or initialize.
-    pub async fn get_or_init<F, Fut>(&self, init: F) -> T
+    ///
+    /// Returns an error if the internal state is inconsistent (value missing after init flag set).
+    pub async fn get_or_init<F, Fut>(&self, init: F) -> Result<T>
     where
         F: FnOnce() -> Fut,
         Fut: Future<Output = T>,
     {
         // Fast path
         if *self.initialized.read().await {
-            return self.value.read().await.clone().unwrap();
+            return self.value.read().await.clone().ok_or_else(|| {
+                CortexError::Internal(
+                    "AsyncOnce: value missing despite initialized flag being set".into(),
+                )
+            });
         }
 
         // Slow path
         let mut initialized = self.initialized.write().await;
         if *initialized {
-            return self.value.read().await.clone().unwrap();
+            return self.value.read().await.clone().ok_or_else(|| {
+                CortexError::Internal(
+                    "AsyncOnce: value missing despite initialized flag being set".into(),
+                )
+            });
         }
 
         let value = init().await;
         *self.value.write().await = Some(value.clone());
         *initialized = true;
-        value
+        Ok(value)
     }
 
     /// Check if initialized.
@@ -399,7 +413,12 @@ impl AsyncCache {
 }
 
 /// Run futures concurrently with limit.
-pub async fn concurrent<F, Fut, T>(items: impl IntoIterator<Item = F>, limit: usize) -> Vec<T>
+///
+/// Returns an error if the semaphore is closed unexpectedly.
+pub async fn concurrent<F, Fut, T>(
+    items: impl IntoIterator<Item = F>,
+    limit: usize,
+) -> Result<Vec<T>>
 where
     F: FnOnce() -> Fut,
     Fut: Future<Output = T>,
@@ -410,12 +429,17 @@ where
     for item in items {
         let sem = semaphore.clone();
         handles.push(async move {
-            let _permit = sem.acquire().await.unwrap();
-            item().await
+            let _permit = sem.acquire().await.map_err(|_| {
+                CortexError::Internal("concurrent execution semaphore closed unexpectedly".into())
+            })?;
+            Ok(item().await)
         });
     }
 
-    futures::future::join_all(handles).await
+    futures::future::join_all(handles)
+        .await
+        .into_iter()
+        .collect()
 }
 
 /// Select the first future to complete.
@@ -503,7 +527,10 @@ mod tests {
             }));
         }
 
-        futures::future::join_all(handles).await;
+        let results: Vec<_> = futures::future::join_all(handles).await;
+        for result in results {
+            assert!(result.is_ok());
+        }
 
         assert_eq!(*counter.lock().await, 5);
     }
 
@@ -511,8 +538,8 @@
     async fn test_async_once() {
        let once: AsyncOnce<i32> = AsyncOnce::new();
 
-        let v1 = once.get_or_init(|| async { 42 }).await;
-        let v2 = once.get_or_init(|| async { 100 }).await;
+        let v1 = once.get_or_init(|| async { 42 }).await.unwrap();
+        let v2 = once.get_or_init(|| async { 100 }).await.unwrap();
 
         assert_eq!(v1, 42);
         assert_eq!(v2, 42);
@@ -560,7 +587,7 @@
             Box::new(|| Box::pin(async { 2 })),
             Box::new(|| Box::pin(async { 3 })),
         ];
-        let results = concurrent(items, 2).await;
+        let results = concurrent(items, 2).await.unwrap();
 
         assert_eq!(results.len(), 3);
     }
diff --git a/src/cortex-engine/src/ratelimit.rs b/src/cortex-engine/src/ratelimit.rs
index 5423512..e15ea4a 100644
--- a/src/cortex-engine/src/ratelimit.rs
+++ b/src/cortex-engine/src/ratelimit.rs
@@ -341,9 +341,13 @@ impl ConcurrencyLimiter {
     }
 
     /// Acquire a permit.
-    pub async fn acquire(&self) -> ConcurrencyPermit {
-        let permit = self.semaphore.clone().acquire_owned().await.unwrap();
-        ConcurrencyPermit { _permit: permit }
+    ///
+    /// Returns an error if the semaphore is closed.
+    pub async fn acquire(&self) -> Result<ConcurrencyPermit> {
+        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| {
+            CortexError::Internal("concurrency limiter semaphore closed unexpectedly".into())
+        })?;
+        Ok(ConcurrencyPermit { _permit: permit })
     }
 
     /// Try to acquire a permit.
@@ -595,8 +599,8 @@ mod tests {
     async fn test_concurrency_limiter() {
         let limiter = ConcurrencyLimiter::new(2);
 
-        let _p1 = limiter.acquire().await;
-        let _p2 = limiter.acquire().await;
+        let _p1 = limiter.acquire().await.unwrap();
+        let _p2 = limiter.acquire().await.unwrap();
 
         // Third should fail immediately
         assert!(limiter.try_acquire().is_none());
diff --git a/src/cortex-engine/src/tools/handlers/grep.rs b/src/cortex-engine/src/tools/handlers/grep.rs
index 26d2561..ecef2d9 100644
--- a/src/cortex-engine/src/tools/handlers/grep.rs
+++ b/src/cortex-engine/src/tools/handlers/grep.rs
@@ -29,6 +29,7 @@ struct GrepArgs {
     glob_pattern: Option<String>,
     #[serde(default = "default_output_mode")]
     output_mode: String,
+    #[serde(alias = "head_limit")]
     max_results: Option<usize>,
     #[serde(default)]
     multiline: bool,
diff --git a/src/cortex-engine/src/tools/mod.rs b/src/cortex-engine/src/tools/mod.rs
index 9f5b6e1..f693894 100644
--- a/src/cortex-engine/src/tools/mod.rs
+++ b/src/cortex-engine/src/tools/mod.rs
@@ -30,6 +30,7 @@ pub mod artifacts;
 pub mod context;
 pub mod handlers;
 pub mod registry;
+pub mod response_store;
 pub mod router;
 pub mod spec;
 pub mod unified_executor;
@@ -45,6 +46,11 @@ pub use artifacts::{
 pub use context::ToolContext;
 pub use handlers::*;
 pub use registry::{PluginTool, ToolRegistry};
+pub use response_store::{
+    CLEANUP_INTERVAL, DEFAULT_TTL, MAX_STORE_SIZE, StoreInfo, StoreStats, StoredResponse,
+    ToolResponseStore, ToolResponseStoreConfig, create_shared_store,
+    create_shared_store_with_config,
+};
 pub use router::ToolRouter;
 pub use spec::{ToolCall, ToolDefinition, ToolHandler, ToolResult};
 pub use unified_executor::{ExecutorConfig, UnifiedToolExecutor};
diff --git a/src/cortex-engine/src/tools/response_store.rs b/src/cortex-engine/src/tools/response_store.rs
new file mode 100644
index 0000000..9220c86
--- /dev/null
+++ b/src/cortex-engine/src/tools/response_store.rs
@@ -0,0 +1,537 @@
+//! Tool response storage with bounded capacity and automatic cleanup.
+//!
+//! This module provides a bounded storage for tool execution results that:
+//! - Limits maximum number of stored responses to prevent unbounded memory growth
+//! - Removes entries when they are consumed (read and take)
+//! - Periodically cleans up stale entries based on TTL
+//!
+//! Fixes #5292 (unbounded growth) and #5293 (missing removal on read).
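+//!
+//! # Example
+//!
+//! A minimal sketch of the intended store-then-consume flow (the call ID and
+//! tool name are illustrative):
+//!
+//! ```ignore
+//! let store = ToolResponseStore::new();
+//! store.store("call-1", "Read", ToolResult::success("ok")).await;
+//!
+//! // Taking the response removes it, so consumed entries cannot pile up (#5293).
+//! assert!(store.take("call-1").await.is_some());
+//! assert!(!store.contains("call-1").await);
+//! ```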
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use tokio::sync::RwLock;
+use tracing::debug;
+
+use crate::tools::spec::ToolResult;
+
+/// Maximum number of responses to store before eviction.
+/// This prevents unbounded memory growth from accumulated tool responses.
+pub const MAX_STORE_SIZE: usize = 500;
+
+/// Default time-to-live for stored responses (5 minutes).
+pub const DEFAULT_TTL: Duration = Duration::from_secs(300);
+
+/// Interval for periodic cleanup of stale entries (1 minute).
+pub const CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
+
+/// A stored tool response with metadata.
+#[derive(Debug, Clone)]
+pub struct StoredResponse {
+    /// The tool execution result.
+    pub result: ToolResult,
+    /// Tool name that produced this result.
+    pub tool_name: String,
+    /// When the response was stored.
+    pub stored_at: Instant,
+    /// Whether this response has been read (but not yet consumed).
+    pub read: bool,
+}
+
+impl StoredResponse {
+    /// Create a new stored response.
+    pub fn new(tool_name: impl Into<String>, result: ToolResult) -> Self {
+        Self {
+            result,
+            tool_name: tool_name.into(),
+            stored_at: Instant::now(),
+            read: false,
+        }
+    }
+
+    /// Check if the response has expired.
+    pub fn is_expired(&self, ttl: Duration) -> bool {
+        self.stored_at.elapsed() > ttl
+    }
+
+    /// Get the age of this response.
+    pub fn age(&self) -> Duration {
+        self.stored_at.elapsed()
+    }
+}
+
+/// Configuration for the tool response store.
+#[derive(Debug, Clone)]
+pub struct ToolResponseStoreConfig {
+    /// Maximum number of responses to store.
+    pub max_size: usize,
+    /// Time-to-live for stored responses.
+    pub ttl: Duration,
+    /// Whether to remove entries on read (peek vs consume).
+    pub remove_on_read: bool,
+}
+
+impl Default for ToolResponseStoreConfig {
+    fn default() -> Self {
+        Self {
+            max_size: MAX_STORE_SIZE,
+            ttl: DEFAULT_TTL,
+            remove_on_read: true,
+        }
+    }
+}
+
+impl ToolResponseStoreConfig {
+    /// Create a config with custom max size.
+    pub fn with_max_size(mut self, max_size: usize) -> Self {
+        self.max_size = max_size;
+        self
+    }
+
+    /// Create a config with custom TTL.
+    pub fn with_ttl(mut self, ttl: Duration) -> Self {
+        self.ttl = ttl;
+        self
+    }
+
+    /// Set whether to remove entries on read.
+    pub fn with_remove_on_read(mut self, remove: bool) -> Self {
+        self.remove_on_read = remove;
+        self
+    }
+}
+
+/// Bounded storage for tool execution responses.
+///
+/// This store prevents unbounded memory growth by:
+/// 1. Enforcing a maximum number of stored responses
+/// 2. Removing entries when they are consumed
+/// 3. Periodically cleaning up stale entries
+///
+/// # Thread Safety
+///
+/// The store uses `RwLock` for interior mutability and is safe to share
+/// across threads via `Arc`.
+#[derive(Debug)]
+pub struct ToolResponseStore {
+    /// Stored responses keyed by tool call ID.
+    responses: RwLock<HashMap<String, StoredResponse>>,
+    /// Configuration.
+    config: ToolResponseStoreConfig,
+    /// Last cleanup time.
+    last_cleanup: RwLock<Instant>,
+    /// Statistics.
+    stats: RwLock<StoreStats>,
+}
+
+impl ToolResponseStore {
+    /// Create a new tool response store with default configuration.
+    pub fn new() -> Self {
+        Self::with_config(ToolResponseStoreConfig::default())
+    }
+
+    /// Create a tool response store with custom configuration.
+    pub fn with_config(config: ToolResponseStoreConfig) -> Self {
+        Self {
+            responses: RwLock::new(HashMap::new()),
+            config,
+            last_cleanup: RwLock::new(Instant::now()),
+            stats: RwLock::new(StoreStats::default()),
+        }
+    }
+
+    /// Store a tool response.
+    ///
+    /// If the store is at capacity, the oldest entry will be evicted.
+    /// Returns `true` if an entry was evicted to make room.
+    pub async fn store(
+        &self,
+        call_id: impl Into<String>,
+        tool_name: impl Into<String>,
+        result: ToolResult,
+    ) -> bool {
+        let call_id = call_id.into();
+        let tool_name = tool_name.into();
+        let mut evicted = false;
+
+        // Perform periodic cleanup if needed
+        self.maybe_cleanup().await;
+
+        let mut responses = self.responses.write().await;
+
+        // Evict oldest entry if at capacity
+        if responses.len() >= self.config.max_size {
+            if let Some(oldest_key) = self.find_oldest_key(&responses) {
+                responses.remove(&oldest_key);
+                evicted = true;
+                debug!(
+                    evicted_key = %oldest_key,
+                    "Evicted oldest response to make room"
+                );
+            }
+        }
+
+        let response = StoredResponse::new(tool_name, result);
+        responses.insert(call_id.clone(), response);
+
+        // Update stats
+        let mut stats = self.stats.write().await;
+        stats.total_stored += 1;
+        if evicted {
+            stats.evictions += 1;
+        }
+
+        evicted
+    }
+
+    /// Get a response without removing it (peek).
+    ///
+    /// Marks the response as read but does not consume it.
+    pub async fn get(&self, call_id: &str) -> Option<ToolResult> {
+        let mut responses = self.responses.write().await;
+
+        if let Some(response) = responses.get_mut(call_id) {
+            response.read = true;
+            let mut stats = self.stats.write().await;
+            stats.reads += 1;
+            Some(response.result.clone())
+        } else {
+            None
+        }
+    }
+
+    /// Take (consume) a response, removing it from the store.
+    ///
+    /// This is the primary method for retrieving responses as it ensures
+    /// entries are cleaned up after being consumed (#5293).
+    pub async fn take(&self, call_id: &str) -> Option<ToolResult> {
+        let mut responses = self.responses.write().await;
+
+        if let Some(response) = responses.remove(call_id) {
+            let mut stats = self.stats.write().await;
+            stats.takes += 1;
+            Some(response.result)
+        } else {
+            None
+        }
+    }
+
+    /// Check if a response exists for the given call ID.
+    pub async fn contains(&self, call_id: &str) -> bool {
+        self.responses.read().await.contains_key(call_id)
+    }
+
+    /// Get the current number of stored responses.
+    pub async fn len(&self) -> usize {
+        self.responses.read().await.len()
+    }
+
+    /// Check if the store is empty.
+    pub async fn is_empty(&self) -> bool {
+        self.responses.read().await.is_empty()
+    }
+
+    /// Remove all expired entries.
+    ///
+    /// Returns the number of entries removed.
+    pub async fn cleanup_expired(&self) -> usize {
+        let mut responses = self.responses.write().await;
+        let ttl = self.config.ttl;
+        let before = responses.len();
+
+        responses.retain(|_, v| !v.is_expired(ttl));
+
+        let removed = before - responses.len();
+        if removed > 0 {
+            debug!(removed, "Cleaned up expired responses");
+            let mut stats = self.stats.write().await;
+            stats.expired_cleanups += removed as u64;
+        }
+
+        removed
+    }
+
+    /// Remove all read entries that haven't been consumed.
+    ///
+    /// This is useful for cleaning up entries that were peeked but never taken.
+    pub async fn cleanup_read(&self) -> usize {
+        let mut responses = self.responses.write().await;
+        let before = responses.len();
+
+        responses.retain(|_, v| !v.read);
+
+        let removed = before - responses.len();
+        if removed > 0 {
+            debug!(removed, "Cleaned up read-but-not-consumed responses");
+        }
+
+        removed
+    }
+
+    /// Clear all stored responses.
+    pub async fn clear(&self) {
+        self.responses.write().await.clear();
+    }
+
+    /// Get store statistics.
+    pub async fn stats(&self) -> StoreStats {
+        self.stats.read().await.clone()
+    }
+
+    /// Get detailed store info including current size and config.
+    pub async fn info(&self) -> StoreInfo {
+        let responses = self.responses.read().await;
+        let stats = self.stats.read().await;
+
+        StoreInfo {
+            current_size: responses.len(),
+            max_size: self.config.max_size,
+            ttl_secs: self.config.ttl.as_secs(),
+            oldest_age_secs: responses
+                .values()
+                .map(|r| r.age().as_secs())
+                .max()
+                .unwrap_or(0),
+            stats: stats.clone(),
+        }
+    }
+
+    // Internal helpers
+
+    /// Find the key of the oldest entry.
+    fn find_oldest_key(&self, responses: &HashMap<String, StoredResponse>) -> Option<String> {
+        responses
+            .iter()
+            .min_by_key(|(_, v)| v.stored_at)
+            .map(|(k, _)| k.clone())
+    }
+
+    /// Perform cleanup if enough time has passed since last cleanup.
+    async fn maybe_cleanup(&self) {
+        let should_cleanup = {
+            let last = self.last_cleanup.read().await;
+            last.elapsed() > CLEANUP_INTERVAL
+        };
+
+        if should_cleanup {
+            *self.last_cleanup.write().await = Instant::now();
+            let removed = self.cleanup_expired().await;
+            if removed > 0 {
+                debug!(removed, "Periodic cleanup removed expired entries");
+            }
+        }
+    }
+}
+
+impl Default for ToolResponseStore {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Statistics for the tool response store.
+#[derive(Debug, Clone, Default)]
+pub struct StoreStats {
+    /// Total responses stored.
+    pub total_stored: u64,
+    /// Number of get (peek) operations.
+    pub reads: u64,
+    /// Number of take (consume) operations.
+    pub takes: u64,
+    /// Number of evictions due to capacity limit.
+    pub evictions: u64,
+    /// Number of entries removed by TTL cleanup.
+    pub expired_cleanups: u64,
+}
+
+/// Detailed store information.
+#[derive(Debug, Clone)]
+pub struct StoreInfo {
+    /// Current number of stored responses.
+    pub current_size: usize,
+    /// Maximum allowed size.
+    pub max_size: usize,
+    /// TTL in seconds.
+    pub ttl_secs: u64,
+    /// Age of oldest entry in seconds.
+    pub oldest_age_secs: u64,
+    /// Store statistics.
+    pub stats: StoreStats,
+}
+
+/// Create a shared tool response store.
+pub fn create_shared_store() -> Arc<ToolResponseStore> {
+    Arc::new(ToolResponseStore::new())
+}
+
+/// Create a shared tool response store with custom configuration.
+pub fn create_shared_store_with_config(config: ToolResponseStoreConfig) -> Arc<ToolResponseStore> {
+    Arc::new(ToolResponseStore::with_config(config))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_store_and_take() {
+        let store = ToolResponseStore::new();
+
+        let result = ToolResult::success("test output");
+        store.store("call-1", "Read", result.clone()).await;
+
+        assert!(store.contains("call-1").await);
+        assert_eq!(store.len().await, 1);
+
+        let taken = store.take("call-1").await;
+        assert!(taken.is_some());
+        assert_eq!(taken.unwrap().output, "test output");
+
+        // After take, entry should be gone
+        assert!(!store.contains("call-1").await);
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_store_and_get() {
+        let store = ToolResponseStore::new();
+
+        let result = ToolResult::success("test output");
+        store.store("call-1", "Read", result).await;
+
+        // Get should return result but not remove it
+        let got = store.get("call-1").await;
+        assert!(got.is_some());
+        assert!(store.contains("call-1").await);
+
+        // Second get should still work
+        let got2 = store.get("call-1").await;
+        assert!(got2.is_some());
+    }
+
+    #[tokio::test]
+    async fn test_capacity_eviction() {
+        let config = ToolResponseStoreConfig::default().with_max_size(3);
+        let store = ToolResponseStore::with_config(config);
+
+        // Fill to capacity
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store
+            .store("call-2", "Read", ToolResult::success("2"))
+            .await;
+        store
+            .store("call-3", "Read", ToolResult::success("3"))
+            .await;
+
+        assert_eq!(store.len().await, 3);
+
+        // Add one more, should evict oldest
+        let evicted = store
+            .store("call-4", "Read", ToolResult::success("4"))
+            .await;
+        assert!(evicted);
+        assert_eq!(store.len().await, 3);
+
+        // call-1 should be evicted (oldest)
+        assert!(!store.contains("call-1").await);
+        assert!(store.contains("call-4").await);
+    }
+
+    #[tokio::test]
+    async fn test_expired_cleanup() {
+        let config = ToolResponseStoreConfig::default().with_ttl(Duration::from_millis(50));
+        let store = ToolResponseStore::with_config(config);
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        assert_eq!(store.len().await, 1);
+
+        // Wait for expiration
+        tokio::time::sleep(Duration::from_millis(100)).await;
+
+        let removed = store.cleanup_expired().await;
+        assert_eq!(removed, 1);
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_cleanup_read() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store
+            .store("call-2", "Read", ToolResult::success("2"))
+            .await;
+
+        // Read one entry
+        store.get("call-1").await;
+
+        // Cleanup read entries
+        let removed = store.cleanup_read().await;
+        assert_eq!(removed, 1);
+        assert_eq!(store.len().await, 1);
+        assert!(!store.contains("call-1").await);
+        assert!(store.contains("call-2").await);
+    }
+
+    #[tokio::test]
+    async fn test_stats() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store.get("call-1").await;
+        store.take("call-1").await;
+
+        let stats = store.stats().await;
+        assert_eq!(stats.total_stored, 1);
+        assert_eq!(stats.reads, 1);
+        assert_eq!(stats.takes, 1);
+    }
+
+    #[tokio::test]
+    async fn test_nonexistent_key() {
+        let store = ToolResponseStore::new();
+
+        assert!(store.get("nonexistent").await.is_none());
+        assert!(store.take("nonexistent").await.is_none());
+        assert!(!store.contains("nonexistent").await);
+    }
+
+    #[tokio::test]
+    async fn test_clear() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+        store
+            .store("call-2", "Read", ToolResult::success("2"))
+            .await;
+
+        assert_eq!(store.len().await, 2);
+
+        store.clear().await;
+        assert_eq!(store.len().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_info() {
+        let store = ToolResponseStore::new();
+
+        store
+            .store("call-1", "Read", ToolResult::success("1"))
+            .await;
+
+        let info = store.info().await;
+        assert_eq!(info.current_size, 1);
+        assert_eq!(info.max_size, MAX_STORE_SIZE);
+    }
+}
diff --git a/src/cortex-exec/src/runner.rs b/src/cortex-exec/src/runner.rs
index e831324..a802236 100644
--- a/src/cortex-exec/src/runner.rs
+++ b/src/cortex-exec/src/runner.rs
@@ -27,9 +27,17 @@ use cortex_protocol::ConversationId;
 use crate::output::{OutputFormat, OutputWriter};
 
 /// Default timeout for the entire execution (10 minutes).
+///
+/// This is the maximum duration for a multi-turn exec session.
+/// See `cortex_common::http_client` module documentation for the complete
+/// timeout hierarchy across Cortex services.
 const DEFAULT_TIMEOUT_SECS: u64 = 600;
 
 /// Default timeout for a single LLM request (2 minutes).
+///
+/// Allows sufficient time for model inference while preventing indefinite hangs.
+/// See `cortex_common::http_client` module documentation for the complete
+/// timeout hierarchy across Cortex services.
 const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 120;
 
 /// Maximum retries for transient errors.
@@ -187,7 +195,10 @@ impl ExecRunner {
             self.client = Some(client);
         }
 
-        Ok(self.client.as_ref().unwrap().as_ref())
+        self.client
+            .as_ref()
+            .map(|c| c.as_ref())
+            .ok_or_else(|| CortexError::Internal("LLM client not initialized".to_string()))
     }
 
     /// Get filtered tool definitions based on options.
diff --git a/src/cortex-mcp-client/src/transport.rs b/src/cortex-mcp-client/src/transport.rs
index 22152cf..0ee141d 100644
--- a/src/cortex-mcp-client/src/transport.rs
+++ b/src/cortex-mcp-client/src/transport.rs
@@ -20,8 +20,7 @@ use cortex_mcp_types::{
 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
 use tokio::process::{Child, Command};
 use tokio::sync::{Mutex, RwLock};
-use tokio::time::sleep;
-use tracing::{debug, error, info, warn};
+use tracing::{debug, info, warn};
 
 // ============================================================================
 // Transport Trait
@@ -199,61 +198,6 @@ impl StdioTransport {
         Ok(())
     }
 
-    /// Reconnect with exponential backoff.
-    ///
-    /// Properly cleans up existing connections before each attempt to prevent
-    /// file descriptor leaks (#2198).
-    #[allow(dead_code)]
-    async fn reconnect(&self) -> Result<()> {
-        if !self.reconnect_config.enabled {
-            return Err(anyhow!("Reconnection disabled"));
-        }
-
-        let mut attempt = 0;
-        let mut delay = self.reconnect_config.initial_delay;
-
-        while attempt < self.reconnect_config.max_attempts {
-            attempt += 1;
-            info!(
-                attempt,
-                max = self.reconnect_config.max_attempts,
-                "Attempting reconnection"
-            );
-
-            // Clean up any existing connection before attempting reconnect
-            // This prevents file descriptor leaks on repeated failures (#2198)
-            {
-                let mut process_guard = self.process.lock().await;
-                if let Some(mut child) = process_guard.take() {
-                    // Kill the process and wait for it to clean up
-                    let _ = child.kill().await;
-                    // Wait a short time for resources to be released
-                    drop(child);
-                }
-                self.connected.store(false, Ordering::SeqCst);
-            }
-
-            // Clear any stale pending responses
-            self.pending_responses.write().await.clear();
-
-            match self.connect().await {
-                Ok(()) => {
-                    info!("Reconnection successful");
-                    return Ok(());
-                }
-                Err(e) => {
-                    error!(error = %e, attempt, "Reconnection failed");
-                    if attempt < self.reconnect_config.max_attempts {
-                        sleep(delay).await;
-                        delay = (delay * 2).min(self.reconnect_config.max_delay);
-                    }
-                }
-            }
-        }
-
-        Err(anyhow!("Failed to reconnect after {} attempts", attempt))
-    }
-
     /// Send a request and wait for response.
     async fn send_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
         // Ensure connected
@@ -516,51 +460,6 @@ impl HttpTransport {
     fn next_request_id(&self) -> RequestId {
         RequestId::Number(self.request_id.fetch_add(1, Ordering::SeqCst) as i64)
     }
-
-    /// Test connection.
-    #[allow(dead_code)]
-    async fn test_connection(&self) -> Result<()> {
-        let request = JsonRpcRequest::new(self.next_request_id(), methods::PING);
-        self.send_request(request).await?;
-        Ok(())
-    }
-
-    /// Reconnect with exponential backoff.
-    #[allow(dead_code)]
-    async fn reconnect(&self) -> Result<()> {
-        if !self.reconnect_config.enabled {
-            return Err(anyhow!("Reconnection disabled"));
-        }
-
-        let mut attempt = 0;
-        let mut delay = self.reconnect_config.initial_delay;
-
-        while attempt < self.reconnect_config.max_attempts {
-            attempt += 1;
-            info!(
-                attempt,
-                max = self.reconnect_config.max_attempts,
-                "Attempting HTTP reconnection"
-            );
-
-            match self.test_connection().await {
-                Ok(()) => {
-                    info!("HTTP reconnection successful");
-                    self.connected.store(true, Ordering::SeqCst);
-                    return Ok(());
-                }
-                Err(e) => {
-                    error!(error = %e, attempt, "HTTP reconnection failed");
-                    if attempt < self.reconnect_config.max_attempts {
-                        sleep(delay).await;
-                        delay = (delay * 2).min(self.reconnect_config.max_delay);
-                    }
-                }
-            }
-        }
-
-        Err(anyhow!("Failed to reconnect after {} attempts", attempt))
-    }
 }
 
 #[async_trait]
diff --git a/src/cortex-mcp-server/src/server.rs b/src/cortex-mcp-server/src/server.rs
index 96fb8d8..266b9e4 100644
--- a/src/cortex-mcp-server/src/server.rs
+++ b/src/cortex-mcp-server/src/server.rs
@@ -222,14 +222,17 @@ impl McpServer {
     }
 
     async fn handle_initialize(&self, params: Option<Value>) -> Result<Value, JsonRpcError> {
-        // Check state
-        let current_state = *self.state.read().await;
-        if current_state != ServerState::Uninitialized {
-            return Err(JsonRpcError::invalid_request("Server already initialized"));
+        // Atomic check-and-transition: hold write lock during entire state check and modification
+        // to prevent TOCTOU race conditions where multiple concurrent initialize requests
+        // could both pass the uninitialized check before either sets the state
+        {
+            let mut state_guard = self.state.write().await;
+            if *state_guard != ServerState::Uninitialized {
+                return Err(JsonRpcError::invalid_request("Server already initialized"));
+            }
+            *state_guard = ServerState::Initializing;
         }
 
-        *self.state.write().await = ServerState::Initializing;
-
         // Parse params
         let init_params: InitializeParams = params
             .map(serde_json::from_value)
diff --git a/src/cortex-plugins/src/registry.rs b/src/cortex-plugins/src/registry.rs
index 79f961e..f9bb235 100644
--- a/src/cortex-plugins/src/registry.rs
+++ b/src/cortex-plugins/src/registry.rs
@@ -674,18 +674,21 @@ impl PluginRegistry {
         let info = plugin.info().clone();
         let id = info.id.clone();
 
-        {
-            let plugins = self.plugins.read().await;
-            if plugins.contains_key(&id) {
-                return Err(PluginError::AlreadyExists(id));
-            }
-        }
-
+        // Use entry API to atomically check-and-insert within a single write lock
+        // to prevent TOCTOU race conditions where multiple concurrent registrations
+        // could both pass the contains_key check before either inserts
        let handle = PluginHandle::new(plugin);
 
-        {
         let mut plugins = self.plugins.write().await;
-        plugins.insert(id.clone(), handle);
+        use std::collections::hash_map::Entry;
+        match plugins.entry(id.clone()) {
+            Entry::Occupied(_) => {
+                return Err(PluginError::AlreadyExists(id));
+            }
+            Entry::Vacant(entry) => {
+                entry.insert(handle);
+            }
+        }
     }
 
     {
diff --git a/src/cortex-tui/src/cards/commands.rs b/src/cortex-tui/src/cards/commands.rs
index b777c5a..25226b4 100644
--- a/src/cortex-tui/src/cards/commands.rs
+++ b/src/cortex-tui/src/cards/commands.rs
@@ -225,8 +225,9 @@ impl CardView for CommandsCard {
     fn desired_height(&self, max_height: u16, _width: u16) -> u16 {
         // Base height for list items + search bar + some padding
-        let command_count = self.commands.len() as u16;
-        let content_height = command_count + 2; // +2 for search bar and padding
+        // Use saturating conversion to prevent overflow when count > u16::MAX
+        let command_count = u16::try_from(self.commands.len()).unwrap_or(u16::MAX);
+        let content_height = command_count.saturating_add(2); // +2 for search bar and padding
 
         // Clamp between min 5 and max 14, respecting max_height
         content_height.clamp(5, 14).min(max_height)
diff --git a/src/cortex-tui/src/cards/models.rs b/src/cortex-tui/src/cards/models.rs
index a5d0e48..a7abf75 100644
--- a/src/cortex-tui/src/cards/models.rs
+++ b/src/cortex-tui/src/cards/models.rs
@@ -147,8 +147,9 @@ impl CardView for ModelsCard {
     fn desired_height(&self, max_height: u16, _width: u16) -> u16 {
         // Base height for list items + search bar + some padding
-        let model_count = self.models.len() as u16;
-        let content_height = model_count + 2; // +2 for search bar and padding
+        // Use saturating conversion to prevent overflow when count > u16::MAX
+        let model_count = u16::try_from(self.models.len()).unwrap_or(u16::MAX);
+        let content_height = model_count.saturating_add(2); // +2 for search bar and padding
 
         // Clamp between min 5 and max 12, respecting max_height
         content_height.clamp(5, 12).min(max_height)
diff --git a/src/cortex-tui/src/cards/sessions.rs b/src/cortex-tui/src/cards/sessions.rs
index 76c67a0..b856f91 100644
--- a/src/cortex-tui/src/cards/sessions.rs
+++ b/src/cortex-tui/src/cards/sessions.rs
@@ -207,7 +207,9 @@ impl CardView for SessionsCard {
     fn desired_height(&self, max_height: u16, _width: u16) -> u16 {
         // Base height: sessions + header + search bar + padding
-        let content_height = self.sessions.len() as u16 + 3;
+        // Use saturating conversion to prevent overflow when count > u16::MAX
+        let session_count = u16::try_from(self.sessions.len()).unwrap_or(u16::MAX);
+        let content_height = session_count.saturating_add(3);
         let min_height = 5;
         let max_desired = 15;
         content_height