From df05c85e10cbcb019e340bae74df88be78e57f31 Mon Sep 17 00:00:00 2001 From: Hojin Yang Date: Mon, 2 Mar 2026 11:46:00 +0900 Subject: [PATCH 1/2] feat: add Context Mode v2 - tiered output routing Reduces tool output context bloat while preserving information access: - L0 (< 1KB): Pass through unchanged - L1 (1-10KB): Algorithmic summary + file link - L2 (> 10KB): Agent delegation (falls back to L1) Key changes: - New context_router.py module with ContextStore, OutputRouter, StreamProcessor - Integration hook in opencode_gateway_server.py streaming loop - Configurable via CONTEXT_MODE_* environment variables - Original outputs preserved in /tmp/ctx/ for drill-down access This complements tool selection (input optimization) with output optimization. --- docs/configuration.md | 32 ++ mcp-server/mcp_tool_router/context_router.py | 393 ++++++++++++++++++ .../opencode_gateway_server.py | 4 + .../mcp_tool_router/test_context_router.py | 172 ++++++++ 4 files changed, 601 insertions(+) create mode 100644 mcp-server/mcp_tool_router/context_router.py create mode 100644 mcp-server/mcp_tool_router/test_context_router.py diff --git a/docs/configuration.md b/docs/configuration.md index a6b6a41..8d99909 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -47,6 +47,38 @@ npx mcpflow-router opencode uninstall | `ROUTER_GATEWAY_BUDGET_TOKENS` | `4000` | Token budget for selected tools | | `ROUTER_GATEWAY_LOG` | `$TMPDIR/mcpflow-gateway.log` | Gateway log file path | +## Context Mode v2 (Output Routing) + +Context Mode reduces tool output bloat by routing outputs through tiers: + +| Tier | Threshold | Action | +|---|---|---| +| L0 | < 1 KB | Pass through unchanged | +| L1 | 1-10 KB | Summary + file link | +| L2 | > 10 KB | Agent delegation (falls back to L1) | + +### Context Mode Environment Variables + +| Variable | Default | Description | +|---|---|---| +| `CONTEXT_MODE_DISABLED` | _(unset)_ | Set to any value to disable output routing | +| `CONTEXT_MODE_DIR` | 
`/tmp/ctx` | Storage directory for original outputs | +| `CONTEXT_MODE_MAX_MB` | `100` | Max storage size in MB; oldest files (by modification time) are removed once exceeded | +| `CONTEXT_MODE_L0_THRESHOLD` | `1024` | L0/L1 boundary in bytes | +| `CONTEXT_MODE_L1_THRESHOLD` | `10240` | L1/L2 boundary in bytes | +| `CONTEXT_MODE_TOOL_OVERRIDES` | _(empty)_ | Force tools to tiers, e.g. `glob:L0,playwright:L2` | + +### How Context Mode Works + +When a tool produces output at or above the L0 threshold: + +1. Original output is saved to `/tmp/ctx/{tool}_{hash}.txt` +2. A summary is generated showing first/last lines and file link +3. The summary replaces the full output in the LLM context +4. The model can still access full output via `cat /tmp/ctx/...` + +This complements mcpflow-router's tool selection (input optimization) with output optimization. + ## Other Environment Variables | Variable | Default | Description | diff --git a/mcp-server/mcp_tool_router/context_router.py b/mcp-server/mcp_tool_router/context_router.py new file mode 100644 index 0000000..623392b --- /dev/null +++ b/mcp-server/mcp_tool_router/context_router.py @@ -0,0 +1,393 @@ +""" +Context Mode v2: Tiered Context Router for mcpflow-router + +Routes tool outputs through L0/L1/L2 tiers to reduce context bloat +while preserving information access. + +Integration: This module is called from opencode_gateway_server.py +to process tool outputs in streaming responses. 
+""" + +from __future__ import annotations + +import hashlib +import json +import logging +import os +import re +from dataclasses import dataclass, field +from typing import Callable, Literal, Any + +_log = logging.getLogger("mcpflow-gateway") + +# === Configuration === + +L0_THRESHOLD = int(os.environ.get("CONTEXT_MODE_L0_THRESHOLD", "1024")) # 1 KB +L1_THRESHOLD = int(os.environ.get("CONTEXT_MODE_L1_THRESHOLD", "10240")) # 10 KB + +CTX_DIR = os.environ.get("CONTEXT_MODE_DIR", "/tmp/ctx") +MAX_CTX_SIZE_MB = int(os.environ.get("CONTEXT_MODE_MAX_MB", "100")) + +Tier = Literal["L0", "L1", "L2"] + + +# === Context Store === + +@dataclass +class ContextStore: + """Manages /tmp/ctx/ storage for original outputs.""" + + base_dir: str = CTX_DIR + max_size_bytes: int = MAX_CTX_SIZE_MB * 1024 * 1024 + _index: dict[str, dict[str, Any]] = field(default_factory=dict) + + def __post_init__(self): + os.makedirs(self.base_dir, exist_ok=True) + self._load_index() + + def _index_path(self) -> str: + return os.path.join(self.base_dir, "index.json") + + def _load_index(self): + try: + with open(self._index_path(), 'r') as f: + self._index = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + self._index = {} + + def _save_index(self): + try: + with open(self._index_path(), 'w') as f: + json.dump(self._index, f, indent=2) + except Exception as e: + _log.warning("failed to save context index: %s", e) + + def save(self, content: str, prefix: str = "out", metadata: dict = None) -> str: + """Save content and return path.""" + content_bytes = content.encode('utf-8') + h = hashlib.sha256(content_bytes).hexdigest()[:8] + filename = f"{prefix}_{h}.txt" + path = os.path.join(self.base_dir, filename) + + try: + with open(path, 'w', encoding='utf-8') as f: + f.write(content) + + self._index[filename] = { + "size": len(content_bytes), + "lines": content.count('\n') + 1, + "prefix": prefix, + "metadata": metadata or {}, + } + self._save_index() + self._cleanup_if_needed() + except 
Exception as e: + _log.warning("failed to save context file: %s", e) + + return path + + def _cleanup_if_needed(self): + """Remove oldest files if total size exceeds budget.""" + total = sum(entry.get("size", 0) for entry in self._index.values()) + if total <= self.max_size_bytes: + return + + files_with_mtime = [] + for filename in self._index: + path = os.path.join(self.base_dir, filename) + try: + mtime = os.path.getmtime(path) + files_with_mtime.append((filename, mtime)) + except OSError: + pass + + files_with_mtime.sort(key=lambda x: x[1]) + + for filename, _ in files_with_mtime: + if total <= self.max_size_bytes: + break + path = os.path.join(self.base_dir, filename) + try: + total -= self._index.get(filename, {}).get("size", 0) + os.remove(path) + del self._index[filename] + _log.debug("context cleanup: removed %s", filename) + except OSError: + pass + + self._save_index() + + +# === Output Router === + +@dataclass +class RouterConfig: + """Configuration for output routing.""" + l0_threshold: int = L0_THRESHOLD + l1_threshold: int = L1_THRESHOLD + enabled: bool = True + tool_overrides: dict[str, Tier] = field(default_factory=dict) + + +class OutputRouter: + """Routes tool outputs through tiered context processing.""" + + def __init__( + self, + store: ContextStore, + config: RouterConfig = None, + agent_fn: Callable[[str], str] = None, + ): + self.store = store + self.config = config or RouterConfig() + self.agent_fn = agent_fn + + def route_tier(self, content: str, tool_name: str = None) -> Tier: + """Determine which tier to use for this content.""" + if tool_name and tool_name in self.config.tool_overrides: + return self.config.tool_overrides[tool_name] + + size = len(content.encode('utf-8')) + + if size < self.config.l0_threshold: + return "L0" + elif size < self.config.l1_threshold: + return "L1" + else: + return "L2" + + def process(self, content: str, tool_name: str = "unknown") -> str: + """Process tool output through the appropriate tier.""" + if not 
self.config.enabled: + return content + + tier = self.route_tier(content, tool_name) + + if tier == "L0": + return content + + # Save original for L1 and L2 + path = self.store.save( + content, + prefix=self._sanitize_prefix(tool_name), + metadata={"tool": tool_name, "tier": tier} + ) + + if tier == "L1": + result = self._summarize_l1(content, path, tool_name) + else: + result = self._delegate_l2(content, path, tool_name) + + _log.debug( + "context_router: %s %s -> %s (%d -> %d bytes)", + tool_name, tier, "summarized", len(content), len(result) + ) + + return result + + def _sanitize_prefix(self, tool_name: str) -> str: + """Sanitize tool name for use as filename prefix.""" + return re.sub(r'[^a-zA-Z0-9_-]', '_', tool_name)[:20] + + def _summarize_l1(self, content: str, path: str, tool_name: str) -> str: + """Generate L1 algorithmic summary with file reference.""" + lines = content.splitlines() + line_count = len(lines) + size_bytes = len(content.encode('utf-8')) + size_str = self._human_size(size_bytes) + + output = [] + output.append(f"[{tool_name}: {size_str}, {line_count} lines]") + + if lines: + output.append("┌ " + self._truncate(lines[0], 70)) + for i in range(1, min(3, line_count)): + output.append("│ " + self._truncate(lines[i], 70)) + + if line_count > 6: + output.append("│ ...") + + if line_count > 3: + start = max(3, line_count - 3) + for i in range(start, line_count): + output.append("│ " + self._truncate(lines[i], 70)) + + output.append(f"└ Full: {path}") + + return "\n".join(output) + + def _delegate_l2(self, content: str, path: str, tool_name: str) -> str: + """Delegate to agent for intelligent summarization.""" + if self.agent_fn is None: + # Fallback to L1 if no agent available + return self._summarize_l1(content, path, tool_name) + + lines = len(content.splitlines()) + size_str = self._human_size(len(content.encode('utf-8'))) + + prompt = f"""Analyze the output stored at {path} ({size_str}, {lines} lines). 
+This is output from the "{tool_name}" tool. + +Provide a structured summary: +1. RESULT: Key findings or output (max 500 chars) +2. STATS: Relevant counts, sizes, or metrics +3. NOTABLE: Anything unusual, errors, or important details + +Be concise. The user can access {path} directly for full details.""" + + try: + summary = self.agent_fn(prompt) + return f"[{tool_name}: {size_str} → agent summary]\n\n{summary}\n\n→ Full: {path}" + except Exception as e: + _log.error("L2 agent delegation failed: %s", e) + return self._summarize_l1(content, path, tool_name) + + def _truncate(self, text: str, max_len: int) -> str: + if len(text) <= max_len: + return text + return text[:max_len - 3] + "..." + + def _human_size(self, size_bytes: int) -> str: + if size_bytes < 1024: + return f"{size_bytes} B" + elif size_bytes < 1024 * 1024: + return f"{size_bytes / 1024:.1f} KB" + else: + return f"{size_bytes / (1024 * 1024):.1f} MB" + + +# === Stream Processor === + +class StreamProcessor: + """Processes streaming SSE responses to route tool outputs.""" + + def __init__(self, router: OutputRouter): + self.router = router + self._buffer = b"" + + def process_chunk(self, chunk: bytes) -> bytes: + """Process a streaming chunk, routing tool outputs through tiers.""" + if not chunk: + return chunk + + # Quick check: skip if no tool result markers + if b'"tool_result"' not in chunk and b'"tool.result"' not in chunk: + return chunk + + try: + return self._process_sse_chunk(chunk) + except Exception as e: + _log.warning("stream processing failed: %s", e) + return chunk + + def _process_sse_chunk(self, chunk: bytes) -> bytes: + """Parse and process SSE data lines.""" + text = chunk.decode('utf-8', errors='replace') + lines = text.split('\n') + processed_lines = [] + + for line in lines: + processed_line = self._process_sse_line(line) + processed_lines.append(processed_line) + + return '\n'.join(processed_lines).encode('utf-8') + + def _process_sse_line(self, line: str) -> str: + """Process a 
single SSE line.""" + if not line.startswith('data: '): + return line + + json_str = line[6:] + if not json_str.strip(): + return line + + try: + data = json.loads(json_str) + except json.JSONDecodeError: + return line + + if not isinstance(data, dict): + return line + + event_type = data.get('type', '') + + if event_type in ('tool_result', 'tool.result'): + data = self._process_tool_result(data) + return 'data: ' + json.dumps(data, ensure_ascii=False) + + return line + + def _process_tool_result(self, data: dict) -> dict: + """Route tool result content through the tiered system.""" + content = data.get('content', '') + if not content or not isinstance(content, str): + return data + + tool_name = data.get('tool', data.get('name', 'unknown')) + + # Route through tiered context + processed = self.router.process(content, tool_name=tool_name) + data['content'] = processed + + return data + + +# === Singleton Instances === + +_store_instance: ContextStore | None = None +_router_instance: OutputRouter | None = None +_processor_instance: StreamProcessor | None = None + + +def _get_store() -> ContextStore: + global _store_instance + if _store_instance is None: + _store_instance = ContextStore() + return _store_instance + + +def _get_router() -> OutputRouter: + global _router_instance + if _router_instance is None: + store = _get_store() + config = _load_config() + _router_instance = OutputRouter(store, config) + return _router_instance + + +def _load_config() -> RouterConfig: + """Load config from environment.""" + config = RouterConfig() + + if os.environ.get("CONTEXT_MODE_DISABLED"): + config.enabled = False + + # Tool overrides: "glob:L0,playwright:L2" + overrides_str = os.environ.get("CONTEXT_MODE_TOOL_OVERRIDES", "") + if overrides_str: + for item in overrides_str.split(","): + if ":" in item: + tool, tier = item.split(":", 1) + if tier in ("L0", "L1", "L2"): + config.tool_overrides[tool.strip()] = tier # type: ignore + + return config + + +def get_stream_processor() 
-> StreamProcessor: + """Get or create the singleton stream processor.""" + global _processor_instance + if _processor_instance is None: + router = _get_router() + _processor_instance = StreamProcessor(router) + return _processor_instance + + +def process_stream_chunk(chunk: bytes) -> bytes: + """ + Main entry point for stream processing. + Call this from opencode_gateway_server.py streaming loop. + """ + if os.environ.get("CONTEXT_MODE_DISABLED"): + return chunk + return get_stream_processor().process_chunk(chunk) diff --git a/mcp-server/mcp_tool_router/opencode_gateway_server.py b/mcp-server/mcp_tool_router/opencode_gateway_server.py index 94f5f61..858051c 100644 --- a/mcp-server/mcp_tool_router/opencode_gateway_server.py +++ b/mcp-server/mcp_tool_router/opencode_gateway_server.py @@ -17,6 +17,7 @@ from urllib import request as urlrequest from .hub import ToolRouterHub, _resolve_opencode_server_url +from .context_router import process_stream_chunk def _setup_logging() -> logging.Logger: @@ -224,6 +225,9 @@ def _proxy_request(self, method: str) -> None: chunk = response.read(8192) if not chunk: break + # Context Mode v2: Route tool outputs through tiers + if is_message_post: + chunk = process_stream_chunk(chunk) self.wfile.write(chunk) self.wfile.flush() _chunk_n += 1 diff --git a/mcp-server/mcp_tool_router/test_context_router.py b/mcp-server/mcp_tool_router/test_context_router.py new file mode 100644 index 0000000..ddcad17 --- /dev/null +++ b/mcp-server/mcp_tool_router/test_context_router.py @@ -0,0 +1,172 @@ +""" +Tests for Context Mode v2 +Run with: python -m pytest mcp_tool_router/test_context_router.py -v +""" + +import os +import tempfile +import unittest +from unittest.mock import MagicMock + +from .context_router import ( + ContextStore, + OutputRouter, + RouterConfig, + StreamProcessor, + process_stream_chunk, + L0_THRESHOLD, + L1_THRESHOLD, +) + + +class TestContextStore(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + 
self.store = ContextStore(base_dir=self.temp_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_save_and_retrieve(self): + content = "Hello, World!" + path = self.store.save(content, prefix="test") + + self.assertTrue(os.path.exists(path)) + with open(path, 'r') as f: + self.assertEqual(f.read(), content) + + def test_unique_paths_for_different_content(self): + path1 = self.store.save("content 1", prefix="test") + path2 = self.store.save("content 2", prefix="test") + + self.assertNotEqual(path1, path2) + + +class TestOutputRouter(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.store = ContextStore(base_dir=self.temp_dir) + self.config = RouterConfig() + self.router = OutputRouter(self.store, self.config) + + def tearDown(self): + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_l0_passthrough(self): + """Small outputs should pass through unchanged.""" + content = "small output" + result = self.router.process(content, tool_name="echo") + self.assertEqual(result, content) + + def test_l1_summary(self): + """Medium outputs should get summary with link.""" + lines = [f"line {i}: " + "x" * 50 for i in range(30)] + content = "\n".join(lines) + + self.assertGreater(len(content), L0_THRESHOLD) + self.assertLess(len(content), L1_THRESHOLD) + + result = self.router.process(content, tool_name="test") + + self.assertIn("[test:", result) + self.assertIn("lines]", result) + self.assertIn("└ Full:", result) + self.assertLess(len(result), len(content)) + + def test_l2_fallback_to_l1(self): + """L2 without agent should fallback to L1.""" + content = "x" * 15000 + result = self.router.process(content, tool_name="bigcmd") + + self.assertIn("[bigcmd:", result) + self.assertIn("└ Full:", result) + + def test_tool_override(self): + """Tool-specific overrides should be respected.""" + config = RouterConfig(tool_overrides={"glob": "L0"}) + router = 
OutputRouter(self.store, config) + + content = "x" * 15000 + tier = router.route_tier(content, tool_name="glob") + self.assertEqual(tier, "L0") + + def test_disabled_router(self): + """Disabled router should pass through everything.""" + config = RouterConfig(enabled=False) + router = OutputRouter(self.store, config) + + content = "x" * 15000 + result = router.process(content, tool_name="test") + self.assertEqual(result, content) + + +class TestStreamProcessor(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.store = ContextStore(base_dir=self.temp_dir) + self.router = OutputRouter(self.store, RouterConfig()) + self.processor = StreamProcessor(self.router) + + def tearDown(self): + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_passthrough_non_tool_result(self): + """Non-tool-result chunks should pass through.""" + chunk = b'data: {"type": "message.delta", "content": "hello"}\n\n' + result = self.processor.process_chunk(chunk) + self.assertEqual(result, chunk) + + def test_process_tool_result(self): + """Tool result chunks should be processed.""" + # Create a medium-sized tool result + content = "x" * 5000 + chunk = f'data: {{"type": "tool_result", "tool": "bash", "content": "{content}"}}\n\n' + + result = self.processor.process_chunk(chunk.encode('utf-8')) + result_str = result.decode('utf-8') + + # Should contain summary markers + self.assertIn("[bash:", result_str) + self.assertIn("└ Full:", result_str) + + +class TestTierBoundaries(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.store = ContextStore(base_dir=self.temp_dir) + self.router = OutputRouter(self.store, RouterConfig()) + + def tearDown(self): + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_just_under_l0(self): + self.assertEqual( + self.router.route_tier("x" * (L0_THRESHOLD - 1)), + "L0" + ) + + def test_at_l0_threshold(self): + self.assertEqual( + 
self.router.route_tier("x" * L0_THRESHOLD), + "L1" + ) + + def test_just_under_l1(self): + self.assertEqual( + self.router.route_tier("x" * (L1_THRESHOLD - 1)), + "L1" + ) + + def test_at_l1_threshold(self): + self.assertEqual( + self.router.route_tier("x" * L1_THRESHOLD), + "L2" + ) + + +if __name__ == "__main__": + unittest.main() From 42286bbd1f9cffc8692ceeb4e46d9a788d69830f Mon Sep 17 00:00:00 2001 From: Hojin Yang Date: Mon, 2 Mar 2026 12:37:56 +0900 Subject: [PATCH 2/2] fix: address review feedback for context router - Add SSE event buffering to handle chunk boundaries correctly - Add threading locks for singleton initialization (double-check pattern) - Add threading lock for ContextStore operations - Use atomic write (temp file + rename) for index.json - Clarify L2 behavior in docs (currently falls back to L1) - Add comprehensive tests for: - Thread safety - Chunked event buffering - Unicode/Korean content handling - Agent function callback - Flush of incomplete events --- docs/configuration.md | 6 +- mcp-server/mcp_tool_router/context_router.py | 209 +++++++++++++----- .../mcp_tool_router/test_context_router.py | 153 +++++++++++++ 3 files changed, 312 insertions(+), 56 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8d99909..f660677 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -54,8 +54,10 @@ Context Mode reduces tool output bloat by routing outputs through tiers: | Tier | Threshold | Action | |---|---|---| | L0 | < 1 KB | Pass through unchanged | -| L1 | 1-10 KB | Summary + file link | -| L2 | > 10 KB | Agent delegation (falls back to L1) | +| L1 | 1-10 KB | Algorithmic summary + file link | +| L2 | > 10 KB | L1 summary (agent delegation planned) | + +> **Note:** L2 currently uses the same algorithmic summary as L1. Agent-based intelligent summarization is planned for a future release when the agent interface is finalized. 
### Context Mode Environment Variables diff --git a/mcp-server/mcp_tool_router/context_router.py b/mcp-server/mcp_tool_router/context_router.py index 623392b..1606994 100644 --- a/mcp-server/mcp_tool_router/context_router.py +++ b/mcp-server/mcp_tool_router/context_router.py @@ -15,6 +15,8 @@ import logging import os import re +import tempfile +import threading from dataclasses import dataclass, field from typing import Callable, Literal, Any @@ -30,6 +32,9 @@ Tier = Literal["L0", "L1", "L2"] +# Global lock for singleton initialization +_init_lock = threading.Lock() + # === Context Store === @@ -40,6 +45,7 @@ class ContextStore: base_dir: str = CTX_DIR max_size_bytes: int = MAX_CTX_SIZE_MB * 1024 * 1024 _index: dict[str, dict[str, Any]] = field(default_factory=dict) + _lock: threading.Lock = field(default_factory=threading.Lock) def __post_init__(self): os.makedirs(self.base_dir, exist_ok=True) @@ -49,45 +55,63 @@ def _index_path(self) -> str: return os.path.join(self.base_dir, "index.json") def _load_index(self): - try: - with open(self._index_path(), 'r') as f: - self._index = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - self._index = {} + with self._lock: + try: + with open(self._index_path(), 'r') as f: + self._index = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + self._index = {} def _save_index(self): + """Atomic write to index file to prevent corruption.""" try: - with open(self._index_path(), 'w') as f: - json.dump(self._index, f, indent=2) + # Write to temp file first, then atomic rename + fd, tmp_path = tempfile.mkstemp( + dir=self.base_dir, + prefix=".index_", + suffix=".tmp" + ) + try: + with os.fdopen(fd, 'w') as f: + json.dump(self._index, f, indent=2) + os.replace(tmp_path, self._index_path()) + except Exception: + # Clean up temp file on failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise except Exception as e: _log.warning("failed to save context index: %s", e) def save(self, content: 
str, prefix: str = "out", metadata: dict = None) -> str: - """Save content and return path.""" + """Save content and return path. Thread-safe.""" content_bytes = content.encode('utf-8') h = hashlib.sha256(content_bytes).hexdigest()[:8] filename = f"{prefix}_{h}.txt" path = os.path.join(self.base_dir, filename) - try: - with open(path, 'w', encoding='utf-8') as f: - f.write(content) - - self._index[filename] = { - "size": len(content_bytes), - "lines": content.count('\n') + 1, - "prefix": prefix, - "metadata": metadata or {}, - } - self._save_index() - self._cleanup_if_needed() - except Exception as e: - _log.warning("failed to save context file: %s", e) + with self._lock: + try: + with open(path, 'w', encoding='utf-8') as f: + f.write(content) + + self._index[filename] = { + "size": len(content_bytes), + "lines": content.count('\n') + 1, + "prefix": prefix, + "metadata": metadata or {}, + } + self._save_index() + self._cleanup_if_needed() + except Exception as e: + _log.warning("failed to save context file: %s", e) return path def _cleanup_if_needed(self): - """Remove oldest files if total size exceeds budget.""" + """Remove oldest files if total size exceeds budget. 
Must hold lock.""" total = sum(entry.get("size", 0) for entry in self._index.values()) if total <= self.max_size_bytes: return @@ -176,7 +200,8 @@ def process(self, content: str, tool_name: str = "unknown") -> str: if tier == "L1": result = self._summarize_l1(content, path, tool_name) else: - result = self._delegate_l2(content, path, tool_name) + # L2: Use agent if available, otherwise fall back to L1 + result = self._process_l2(content, path, tool_name) _log.debug( "context_router: %s %s -> %s (%d -> %d bytes)", @@ -216,12 +241,25 @@ def _summarize_l1(self, content: str, path: str, tool_name: str) -> str: return "\n".join(output) - def _delegate_l2(self, content: str, path: str, tool_name: str) -> str: - """Delegate to agent for intelligent summarization.""" - if self.agent_fn is None: - # Fallback to L1 if no agent available - return self._summarize_l1(content, path, tool_name) + def _process_l2(self, content: str, path: str, tool_name: str) -> str: + """ + Process L2 tier output. + + Currently falls back to L1 algorithmic summary. + Agent delegation is planned for future versions. + """ + # TODO: Implement agent delegation when agent_fn interface is finalized + # For now, L2 uses enhanced L1 summary with additional stats + if self.agent_fn is not None: + try: + return self._delegate_to_agent(content, path, tool_name) + except Exception as e: + _log.warning("L2 agent delegation failed, falling back to L1: %s", e) + return self._summarize_l1(content, path, tool_name) + + def _delegate_to_agent(self, content: str, path: str, tool_name: str) -> str: + """Delegate to agent for intelligent summarization.""" lines = len(content.splitlines()) size_str = self._human_size(len(content.encode('utf-8'))) @@ -235,12 +273,8 @@ def _delegate_l2(self, content: str, path: str, tool_name: str) -> str: Be concise. 
The user can access {path} directly for full details.""" - try: - summary = self.agent_fn(prompt) - return f"[{tool_name}: {size_str} → agent summary]\n\n{summary}\n\n→ Full: {path}" - except Exception as e: - _log.error("L2 agent delegation failed: %s", e) - return self._summarize_l1(content, path, tool_name) + summary = self.agent_fn(prompt) + return f"[{tool_name}: {size_str} → agent summary]\n\n{summary}\n\n→ Full: {path}" def _truncate(self, text: str, max_len: int) -> str: if len(text) <= max_len: @@ -259,30 +293,76 @@ def _human_size(self, size_bytes: int) -> str: # === Stream Processor === class StreamProcessor: - """Processes streaming SSE responses to route tool outputs.""" + """ + Processes streaming SSE responses to route tool outputs. + + Handles chunk boundaries correctly by buffering incomplete events. + SSE events are delimited by double newlines (\\n\\n). + """ def __init__(self, router: OutputRouter): self.router = router self._buffer = b"" + self._lock = threading.Lock() def process_chunk(self, chunk: bytes) -> bytes: - """Process a streaming chunk, routing tool outputs through tiers.""" + """ + Process a streaming chunk, routing tool outputs through tiers. + + Buffers incomplete SSE events to handle chunk boundaries correctly. + Thread-safe. 
+ """ if not chunk: return chunk - # Quick check: skip if no tool result markers - if b'"tool_result"' not in chunk and b'"tool.result"' not in chunk: - return chunk + # Quick check: skip if definitely no tool result markers + # (check both buffer and new chunk) + combined = self._buffer + chunk + if b'"tool_result"' not in combined and b'"tool.result"' not in combined: + # No tool results - pass through but still handle buffering for consistency + return self._passthrough_with_buffer(chunk) - try: - return self._process_sse_chunk(chunk) - except Exception as e: - _log.warning("stream processing failed: %s", e) + with self._lock: + return self._process_buffered(chunk) + + def _passthrough_with_buffer(self, chunk: bytes) -> bytes: + """Pass through chunks that don't contain tool results.""" + # If we have buffered data, we need to handle it properly + with self._lock: + if self._buffer: + # Append new chunk and try to extract complete events + self._buffer += chunk + return self._extract_complete_events() return chunk - def _process_sse_chunk(self, chunk: bytes) -> bytes: - """Parse and process SSE data lines.""" - text = chunk.decode('utf-8', errors='replace') + def _process_buffered(self, chunk: bytes) -> bytes: + """Process chunk with proper SSE event buffering.""" + self._buffer += chunk + return self._extract_complete_events() + + def _extract_complete_events(self) -> bytes: + """Extract and process complete SSE events from buffer.""" + output_chunks = [] + + # SSE events are separated by \n\n + while b'\n\n' in self._buffer: + event, self._buffer = self._buffer.split(b'\n\n', 1) + processed = self._process_sse_event(event) + output_chunks.append(processed) + output_chunks.append(b'\n\n') + + return b''.join(output_chunks) + + def _process_sse_event(self, event: bytes) -> bytes: + """Process a single complete SSE event.""" + try: + # Decode with error handling for chunk boundaries + text = event.decode('utf-8') + except UnicodeDecodeError: + # Incomplete UTF-8 
sequence - return as-is + _log.debug("UTF-8 decode error in SSE event, passing through") + return event + lines = text.split('\n') processed_lines = [] @@ -330,6 +410,19 @@ def _process_tool_result(self, data: dict) -> dict: data['content'] = processed return data + + def flush(self) -> bytes: + """ + Flush any remaining buffered data. + Call this when the stream ends to get any incomplete events. + """ + with self._lock: + if self._buffer: + remaining = self._buffer + self._buffer = b"" + # Process remaining as a single event (may be incomplete) + return self._process_sse_event(remaining) + return b"" # === Singleton Instances === @@ -342,16 +435,21 @@ def _process_tool_result(self, data: dict) -> dict: def _get_store() -> ContextStore: global _store_instance if _store_instance is None: - _store_instance = ContextStore() + with _init_lock: + # Double-check after acquiring lock + if _store_instance is None: + _store_instance = ContextStore() return _store_instance def _get_router() -> OutputRouter: global _router_instance if _router_instance is None: - store = _get_store() - config = _load_config() - _router_instance = OutputRouter(store, config) + with _init_lock: + if _router_instance is None: + store = _get_store() + config = _load_config() + _router_instance = OutputRouter(store, config) return _router_instance @@ -368,8 +466,9 @@ def _load_config() -> RouterConfig: for item in overrides_str.split(","): if ":" in item: tool, tier = item.split(":", 1) + tier = tier.strip().upper() if tier in ("L0", "L1", "L2"): - config.tool_overrides[tool.strip()] = tier # type: ignore + config.tool_overrides[tool.strip()] = tier # type: ignore[assignment] return config @@ -378,8 +477,10 @@ def get_stream_processor() -> StreamProcessor: """Get or create the singleton stream processor.""" global _processor_instance if _processor_instance is None: - router = _get_router() - _processor_instance = StreamProcessor(router) + with _init_lock: + if _processor_instance is None: + router 
= _get_router() + _processor_instance = StreamProcessor(router) return _processor_instance diff --git a/mcp-server/mcp_tool_router/test_context_router.py b/mcp-server/mcp_tool_router/test_context_router.py index ddcad17..2bc292c 100644 --- a/mcp-server/mcp_tool_router/test_context_router.py +++ b/mcp-server/mcp_tool_router/test_context_router.py @@ -5,6 +5,7 @@ import os import tempfile +import threading import unittest from unittest.mock import MagicMock @@ -41,6 +42,35 @@ def test_unique_paths_for_different_content(self): path2 = self.store.save("content 2", prefix="test") self.assertNotEqual(path1, path2) + + def test_same_content_same_path(self): + """Same content should produce same path (content-addressed).""" + path1 = self.store.save("identical content", prefix="test") + path2 = self.store.save("identical content", prefix="test") + + self.assertEqual(path1, path2) + + def test_thread_safety(self): + """Test concurrent saves don't corrupt index.""" + results = [] + errors = [] + + def save_content(i): + try: + path = self.store.save(f"content {i}", prefix=f"t{i}") + results.append(path) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=save_content, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(errors), 0) + self.assertEqual(len(results), 10) + self.assertEqual(len(set(results)), 10) # All unique class TestOutputRouter(unittest.TestCase): @@ -83,6 +113,33 @@ def test_l2_fallback_to_l1(self): self.assertIn("[bigcmd:", result) self.assertIn("└ Full:", result) + def test_l2_with_agent(self): + """L2 with agent should use agent summary.""" + mock_agent = MagicMock(return_value="Agent summary here") + router = OutputRouter(self.store, self.config, agent_fn=mock_agent) + + content = "x" * 15000 + result = router.process(content, tool_name="bigcmd") + + mock_agent.assert_called_once() + self.assertIn("agent summary", result) + self.assertIn("Agent summary 
here", result) + + def test_l2_agent_failure_fallback(self): + """L2 should fallback to L1 if agent fails.""" + def failing_agent(prompt): + raise Exception("Agent unavailable") + + router = OutputRouter(self.store, self.config, agent_fn=failing_agent) + + content = "x" * 15000 + result = router.process(content, tool_name="bigcmd") + + # Should fallback to L1 format + self.assertIn("[bigcmd:", result) + self.assertIn("└ Full:", result) + self.assertNotIn("agent summary", result) + def test_tool_override(self): """Tool-specific overrides should be respected.""" config = RouterConfig(tool_overrides={"glob": "L0"}) @@ -131,6 +188,72 @@ def test_process_tool_result(self): # Should contain summary markers self.assertIn("[bash:", result_str) self.assertIn("└ Full:", result_str) + + def test_chunked_event_buffering(self): + """Events split across chunks should be handled correctly.""" + content = "x" * 5000 + full_event = f'data: {{"type": "tool_result", "tool": "bash", "content": "{content}"}}\n\n' + full_bytes = full_event.encode('utf-8') + + # Split in the middle + mid = len(full_bytes) // 2 + chunk1 = full_bytes[:mid] + chunk2 = full_bytes[mid:] + + # First chunk should return empty (buffered) + result1 = self.processor.process_chunk(chunk1) + self.assertEqual(result1, b"") + + # Second chunk should return processed event + result2 = self.processor.process_chunk(chunk2) + result_str = result2.decode('utf-8') + + self.assertIn("[bash:", result_str) + self.assertIn("└ Full:", result_str) + + def test_multiple_events_in_chunk(self): + """Multiple complete events in one chunk should all be processed.""" + event1 = 'data: {"type": "message.delta", "content": "hello"}\n\n' + event2 = 'data: {"type": "message.delta", "content": "world"}\n\n' + chunk = (event1 + event2).encode('utf-8') + + result = self.processor.process_chunk(chunk) + result_str = result.decode('utf-8') + + self.assertIn("hello", result_str) + self.assertIn("world", result_str) + + def 
test_flush_incomplete_event(self): + """Flush should return any remaining buffered data.""" + # Send incomplete event + incomplete = b'data: {"type": "tool_result"' + self.processor.process_chunk(incomplete) + + # Flush should return the incomplete data + flushed = self.processor.flush() + self.assertIn(b"tool_result", flushed) + + def test_thread_safety(self): + """Concurrent chunk processing should be safe.""" + results = [] + errors = [] + + def process_chunk(i): + try: + chunk = f'data: {{"type": "message.delta", "content": "msg{i}"}}\n\n' + result = self.processor.process_chunk(chunk.encode('utf-8')) + results.append(result) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=process_chunk, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(errors), 0) + self.assertEqual(len(results), 10) class TestTierBoundaries(unittest.TestCase): @@ -168,5 +291,35 @@ def test_at_l1_threshold(self): ) +class TestUnicodeHandling(unittest.TestCase): + """Test handling of non-ASCII content.""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.store = ContextStore(base_dir=self.temp_dir) + self.router = OutputRouter(self.store, RouterConfig()) + self.processor = StreamProcessor(self.router) + + def tearDown(self): + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_korean_content(self): + """Korean text should be handled correctly.""" + content = "안녕하세요. 한글 테스트입니다." * 100 + result = self.router.process(content, tool_name="test") + + # Should be summarized (content is > L0) + self.assertIn("[test:", result) + self.assertIn("안녕하세요", result) # First line preview + + def test_emoji_content(self): + """Emoji should be handled correctly.""" + content = "🎉🎊🎁" * 500 + result = self.router.process(content, tool_name="emoji") + + self.assertIn("[emoji:", result) + + if __name__ == "__main__": unittest.main()