diff --git a/Cargo.lock b/Cargo.lock
index de5594c..130925f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1502,6 +1502,7 @@ dependencies = [
  "serde_json",
  "strum_macros",
  "tempfile",
+ "tracing",
  "uuid",
 ]
 
diff --git a/src/cortex-engine/src/streaming.rs b/src/cortex-engine/src/streaming.rs
index 35bfcef..52c2a83 100644
--- a/src/cortex-engine/src/streaming.rs
+++ b/src/cortex-engine/src/streaming.rs
@@ -15,6 +15,10 @@ use futures::Stream;
 use serde::{Deserialize, Serialize};
 use tokio::sync::mpsc;
 
+/// Maximum number of events to buffer before dropping old ones.
+/// Prevents unbounded memory growth if drain_events() is not called regularly.
+const MAX_BUFFER_SIZE: usize = 10_000;
+
 /// Token usage for streaming.
 #[derive(Debug, Clone, Default, Serialize, Deserialize)]
 pub struct StreamTokenUsage {
@@ -213,7 +217,7 @@ impl StreamProcessor {
         Self {
             state: StreamState::Idle,
             content: StreamContent::new(),
-            buffer: VecDeque::new(),
+            buffer: VecDeque::with_capacity(1024), // Pre-allocate reasonable capacity
             start_time: None,
             first_token_time: None,
             last_event_time: None,
@@ -284,6 +288,10 @@ impl StreamProcessor {
             }
         }
 
+        // Enforce buffer size limit to prevent unbounded memory growth
+        if self.buffer.len() >= MAX_BUFFER_SIZE {
+            self.buffer.pop_front();
+        }
         self.buffer.push_back(event);
     }
 
diff --git a/src/cortex-protocol/Cargo.toml b/src/cortex-protocol/Cargo.toml
index f162dff..489e953 100644
--- a/src/cortex-protocol/Cargo.toml
+++ b/src/cortex-protocol/Cargo.toml
@@ -20,6 +20,7 @@ uuid = { workspace = true, features = ["serde", "v4"] }
 chrono = { workspace = true }
 strum_macros = "0.27"
 base64 = { workspace = true }
+tracing = { workspace = true }
 
 [dev-dependencies]
 pretty_assertions = { workspace = true }
diff --git a/src/cortex-protocol/src/protocol/message_parts.rs b/src/cortex-protocol/src/protocol/message_parts.rs
index 67d238f..608f708 100644
--- a/src/cortex-protocol/src/protocol/message_parts.rs
+++ b/src/cortex-protocol/src/protocol/message_parts.rs
@@ -182,6 +182,45 @@ pub enum ToolState {
     },
 }
 
+
+impl ToolState {
+    /// Check if transitioning to the given state is valid.
+    ///
+    /// Valid transitions:
+    /// - Pending -> Running, Completed, Error
+    /// - Running -> Completed, Error
+    /// - Completed -> (terminal, no transitions)
+    /// - Error -> (terminal, no transitions)
+    ///
+    /// State machine:
+    /// ```text
+    /// Pending -> Running -> Completed
+    ///    |          |
+    ///    |          +-> Error
+    ///    +-> Completed
+    ///    +-> Error
+    /// ```
+    pub fn can_transition_to(&self, target: &ToolState) -> bool {
+        match (self, target) {
+            // From Pending, can go to any non-Pending state
+            (ToolState::Pending { .. }, ToolState::Running { .. }) => true,
+            (ToolState::Pending { .. }, ToolState::Completed { .. }) => true,
+            (ToolState::Pending { .. }, ToolState::Error { .. }) => true,
+
+            // From Running, can go to Completed or Error
+            (ToolState::Running { .. }, ToolState::Completed { .. }) => true,
+            (ToolState::Running { .. }, ToolState::Error { .. }) => true,
+
+            // Terminal states cannot transition
+            (ToolState::Completed { .. }, _) => false,
+            (ToolState::Error { .. }, _) => false,
+
+            // Any other transition is invalid
+            _ => false,
+        }
+    }
+}
+
 /// Subtask execution status.
 #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
 #[serde(rename_all = "snake_case")]
@@ -552,6 +591,8 @@ impl MessageWithParts {
     }
 
     /// Update a tool state by call ID.
+    ///
+    /// Logs a warning if the state transition is invalid (e.g., from a terminal state).
     pub fn update_tool_state(&mut self, call_id: &str, new_state: ToolState) -> bool {
         for part in &mut self.parts {
             if let MessagePart::Tool {
@@ -561,6 +602,14 @@ impl MessageWithParts {
             } = &mut part.part
             {
                 if cid == call_id {
+                    if !state.can_transition_to(&new_state) {
+                        tracing::warn!(
+                            "Invalid ToolState transition from {:?} to {:?} for call_id {}",
+                            state,
+                            new_state,
+                            call_id
+                        );
+                    }
                     *state = new_state;
                     return true;
                 }