/// Returns the prefix of `text` that ends at byte offset `pos`, never splitting
/// a UTF-8 character.
///
/// * If `pos >= text.len()`, the whole of `text` is returned.
/// * If `pos` lies on a character boundary, the result is exactly `&text[..pos]`.
/// * If `pos` falls inside a multi-byte character, the offset is rounded
///   **down** to the start of that character, so the partially covered
///   character is excluded from the result.
///
/// This function never panics, whatever the value of `pos`.
fn safe_slice_up_to(text: &str, pos: usize) -> &str {
    if pos >= text.len() {
        return text;
    }
    // Walk backwards from `pos` to the closest boundary at or before it.
    // Offset 0 is always a boundary, so the search always succeeds; the
    // fallback only exists to avoid an `unwrap`.
    let end = (0..=pos)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    &text[..end]
}

/// Returns the suffix of `text` that starts at byte offset `pos`, never
/// splitting a UTF-8 character.
///
/// * If `pos >= text.len()`, the empty string is returned.
/// * If `pos` lies on a character boundary, the result is exactly `&text[pos..]`.
/// * If `pos` falls inside a multi-byte character, the offset is rounded
///   **up** to the start of the next character, so the partially covered
///   character is excluded from the result.
///
/// This function never panics, whatever the value of `pos`.
fn safe_slice_from(text: &str, pos: usize) -> &str {
    let len = text.len();
    if pos >= len {
        return "";
    }
    // Walk forwards from `pos` to the closest boundary at or after it.
    // `len` itself is always a boundary, so the search always succeeds.
    let start = (pos..=len)
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(len);
    &text[start..]
}
#[derive(Debug, Clone, PartialEq, Eq)] pub struct AgentMention { @@ -108,10 +148,10 @@ pub fn extract_mention_and_text( ) -> Option<(AgentMention, String)> { let mention = find_first_valid_mention(text, valid_agents)?; - // Remove the mention from text + // Remove the mention from text, using safe slicing for UTF-8 boundaries let mut remaining = String::with_capacity(text.len()); - remaining.push_str(&text[..mention.start]); - remaining.push_str(&text[mention.end..]); + remaining.push_str(safe_slice_up_to(text, mention.start)); + remaining.push_str(safe_slice_from(text, mention.end)); // Trim and normalize whitespace let remaining = remaining.trim().to_string(); @@ -123,7 +163,8 @@ pub fn extract_mention_and_text( pub fn starts_with_mention(text: &str, valid_agents: &[&str]) -> bool { let text = text.trim(); if let Some(mention) = find_first_valid_mention(text, valid_agents) { - mention.start == 0 || text[..mention.start].trim().is_empty() + // Use safe slicing to handle UTF-8 boundaries + mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty() } else { false } @@ -196,8 +237,8 @@ pub fn parse_message_for_agent(text: &str, valid_agents: &[&str]) -> ParsedAgent // Check if message starts with @agent if let Some((mention, remaining)) = extract_mention_and_text(text, valid_agents) { - // Only trigger if mention is at the start - if mention.start == 0 || text[..mention.start].trim().is_empty() { + // Only trigger if mention is at the start, using safe slicing for UTF-8 boundaries + if mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty() { return ParsedAgentMessage::for_agent(mention.agent_name, remaining, text.to_string()); } } @@ -318,4 +359,99 @@ mod tests { assert_eq!(mentions[0].agent_name, "my-agent"); assert_eq!(mentions[1].agent_name, "my_agent"); } + + // UTF-8 boundary safety tests + #[test] + fn test_safe_slice_up_to_ascii() { + let text = "hello world"; + assert_eq!(safe_slice_up_to(text, 5), "hello"); + 
assert_eq!(safe_slice_up_to(text, 0), ""); + assert_eq!(safe_slice_up_to(text, 100), "hello world"); + } + + #[test] + fn test_safe_slice_up_to_multibyte() { + // "こんにちは" - each character is 3 bytes + let text = "こんにちは"; + assert_eq!(safe_slice_up_to(text, 3), "こ"); // Valid boundary + assert_eq!(safe_slice_up_to(text, 6), "こん"); // Valid boundary + // Position 4 is inside the second character, should return "こ" + assert_eq!(safe_slice_up_to(text, 4), "こ"); + assert_eq!(safe_slice_up_to(text, 5), "こ"); + } + + #[test] + fn test_safe_slice_from_multibyte() { + let text = "こんにちは"; + assert_eq!(safe_slice_from(text, 3), "んにちは"); // Valid boundary + // Position 4 is inside second character, should skip to position 6 + assert_eq!(safe_slice_from(text, 4), "にちは"); + assert_eq!(safe_slice_from(text, 5), "にちは"); + } + + #[test] + fn test_extract_mention_with_multibyte_prefix() { + let valid = vec!["general"]; + + // Multi-byte characters before mention + let result = extract_mention_and_text("日本語 @general search files", &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + // The prefix should be preserved without panicking + assert!(remaining.contains("search files")); + } + + #[test] + fn test_starts_with_mention_multibyte() { + let valid = vec!["general"]; + + // Whitespace with multi-byte characters should not cause panic + assert!(starts_with_mention(" @general task", &valid)); + + // Multi-byte characters before mention - should return false, not panic + assert!(!starts_with_mention("日本語 @general task", &valid)); + } + + #[test] + fn test_parse_message_for_agent_multibyte() { + let valid = vec!["general"]; + + // Multi-byte prefix - should not panic + let parsed = parse_message_for_agent("日本語 @general find files", &valid); + // Since mention is not at the start, should not invoke task + assert!(!parsed.should_invoke_task); + + // Multi-byte in the prompt (after mention) + let parsed = 
parse_message_for_agent("@general 日本語を検索", &valid); + assert!(parsed.should_invoke_task); + assert_eq!(parsed.agent, Some("general".to_string())); + assert_eq!(parsed.prompt, "日本語を検索"); + } + + #[test] + fn test_extract_mention_with_emoji() { + let valid = vec!["general"]; + + // Emojis are 4 bytes each + let result = extract_mention_and_text("🎉 @general celebrate", &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + assert!(remaining.contains("celebrate")); + } + + #[test] + fn test_mixed_multibyte_and_ascii() { + let valid = vec!["general"]; + + // Mix of ASCII, CJK, and emoji + let text = "Hello 世界 🌍 @general search for 日本語"; + let result = extract_mention_and_text(text, &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + // Should not panic and produce valid output + assert!(!remaining.is_empty()); + } } diff --git a/src/cortex-compact/src/compactor.rs b/src/cortex-compact/src/compactor.rs index f5cbeb1..00fe23d 100644 --- a/src/cortex-compact/src/compactor.rs +++ b/src/cortex-compact/src/compactor.rs @@ -106,7 +106,10 @@ impl Compactor { }]; new_items.extend(items.into_iter().skip(preserved_start)); - let tokens_after = current_tokens - tokens_in_compacted + summary_tokens; + // Use saturating arithmetic to prevent underflow if tokens_in_compacted > current_tokens + let tokens_after = current_tokens + .saturating_sub(tokens_in_compacted) + .saturating_add(summary_tokens); let result = CompactionResult::success(summary, current_tokens, tokens_after, items_removed);