diff --git a/.gitignore b/.gitignore index d44dbe1..6f8a414 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ node_modules/ ## Build build/ dist/ +*.egg-info/ ## Logs logs/ @@ -41,4 +42,5 @@ data/ ## .cursor .cursor/ docs/ -tests/ \ No newline at end of file +tests/ +.aider* diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..3047f32 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,165 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Code Graph Knowledge System is a Neo4j-based intelligent knowledge management system that combines vector search, graph databases, and LLM integration for document processing and RAG (Retrieval Augmented Generation). The system processes documents into a knowledge graph and provides intelligent querying capabilities. + +## Architecture + +### Core Components +- **FastAPI Application** (`main.py`, `core/app.py`): Main web server with async request handling +- **Neo4j Knowledge Service** (`services/neo4j_knowledge_service.py`): Primary service handling LlamaIndex + Neo4j integration for knowledge graph operations +- **SQL Parsers** (`services/sql_parser.py`, `services/universal_sql_schema_parser.py`): Database schema analysis and parsing +- **Task Queue System** (`services/task_queue.py`, `monitoring/task_monitor.py`): Async background processing with web monitoring +- **MCP Server** (`mcp_server.py`, `start_mcp.py`): Model Context Protocol integration for AI assistants + +### Multi-Provider LLM Support +The system supports multiple LLM and embedding providers: +- **Ollama**: Local LLM hosting (default) +- **OpenAI**: GPT models and embeddings +- **Google Gemini**: Gemini models and embeddings +- **OpenRouter**: Access to multiple model providers +- **HuggingFace**: Local embedding models + +Configuration is handled via environment variables in `.env` file (see `env.example`). 
+ +## Development Commands + +### Running the Application +```bash +# Start main application +python start.py + +# Start MCP server (for AI assistant integration) +python start_mcp.py + +# Using script entry points (after uv sync) +uv run server +uv run mcp_client + +# Direct FastAPI startup +python main.py +``` + +### Testing +```bash +# Run tests +pytest tests/ + +# Run with coverage +pytest tests/ --cov + +# Run specific test file +pytest tests/test_specific.py +``` + +### Code Quality +```bash +# Format code +black . +isort . + +# Lint code +ruff check . +``` + +### Dependencies +```bash +# Install dependencies +pip install -e . + +# Using uv (recommended) +uv pip install -e . +``` + +## Configuration + +### Environment Setup +1. Copy `env.example` to `.env` +2. Configure Neo4j connection: `NEO4J_URI`, `NEO4J_USER`, `NEO4J_PASSWORD` +3. Choose LLM provider: `LLM_PROVIDER` (ollama/openai/gemini/openrouter) +4. Set embedding provider: `EMBEDDING_PROVIDER` + +### Neo4j Requirements +- Neo4j 5.0+ with APOC plugin +- Default connection: `bolt://localhost:7687` +- Database: `neo4j` (default) + +### Service Dependencies +The application checks service health on startup via `start.py:check_dependencies()`. Required services: +- Neo4j database connection +- LLM provider (Ollama/OpenAI/etc.) + +## Key Development Patterns + +### Service Initialization +All services use async initialization patterns. 
The `Neo4jKnowledgeService` must be initialized before use: +```python +await knowledge_service.initialize() +``` + +### Error Handling +Services return structured responses with `success` field and error details: +```python +result = await service.operation() +if not result.get("success"): + # Handle error from result["error"] +``` + +### Timeout Management +Operations use configurable timeouts from `config.py`: +- `connection_timeout`: Database connections +- `operation_timeout`: Standard operations +- `large_document_timeout`: Large document processing + +### LlamaIndex Integration +The system uses LlamaIndex's `KnowledgeGraphIndex` with Neo4j backend. Global settings are configured in `services/neo4j_knowledge_service.py:initialize()`. + +## API Structure + +### Main Endpoints +- `/api/v1/health`: Service health check +- `/api/v1/knowledge/query`: Query knowledge base with RAG +- `/api/v1/knowledge/search`: Vector similarity search +- `/api/v1/documents/*`: Document management +- `/api/v1/sql/*`: SQL parsing and analysis + +### Real-time Task Monitoring +The system provides multiple approaches for real-time task monitoring: + +#### Web UI Monitoring (`/ui/monitor`) +When `ENABLE_MONITORING=true`, NiceGUI monitoring interface is available with: +- Real-time task status updates via WebSocket +- File upload functionality (50KB size limit) +- Directory batch processing +- Task progress visualization + +#### Server-Sent Events (SSE) API +SSE endpoints for streaming real-time updates: +- `/api/v1/sse/task/{task_id}`: Monitor single task progress +- `/api/v1/sse/tasks`: Monitor all tasks with optional status filtering +- `/api/v1/sse/stats`: Get active SSE connection statistics + +#### MCP Real-time Tools +MCP server provides real-time monitoring tools: +- `watch_task`: Monitor single task with progress history +- `watch_tasks`: Monitor multiple tasks until completion +- Supports custom timeouts and update intervals +- **Note**: These are MCP protocol tools, not HTTP 
endpoints + +#### Client Implementation Examples +- `examples/pure_mcp_client.py`: Pure MCP client using `watch_task` tools +- `examples/hybrid_http_sse_client.py`: HTTP + SSE hybrid approach + +### Large File Handling Strategy +The system handles large documents through multiple approaches: +- **Small files (<10KB)**: Direct synchronous processing +- **Medium files (10-50KB)**: Temporary file strategy with background processing +- **Large files (>50KB)**: UI prompts for directory processing or MCP client usage +- **MCP client**: Automatic temporary file creation for large documents + +## Testing Approach + +Tests are located in `tests/` directory. The system includes comprehensive testing for SQL parsing functionality. Use `pytest` for running tests. \ No newline at end of file diff --git a/README.md b/README.md index 4d2b729..29f79b3 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,15 @@ Code Graph Knowledge System is an enterprise-grade solution that transforms unst - **Universal SQL Schema Parser**: Configurable database schema analysis with industry-specific templates - **Intelligent Query Engine**: Hybrid search combining vector similarity and graph traversal - **Asynchronous Task Processing**: Background processing for large document collections with real-time monitoring -- **Web-based Monitoring Dashboard**: Real-time task queue monitoring with NiceGUI interface +- **Real-time Task Monitoring**: Multiple real-time monitoring solutions + - Web UI Monitoring: NiceGUI interface with file upload and directory batch processing + - SSE Streaming API: HTTP Server-Sent Events for real-time task progress updates + - MCP Real-time Tools: AI assistant integrated task monitoring tools - **Multi-Database Support**: Oracle, MySQL, PostgreSQL, SQL Server schema parsing and analysis - **RESTful API**: Complete API endpoints for document management and knowledge querying - **MCP Protocol Support**: Model Context Protocol integration for AI assistant compatibility -- 
**Multi-provider LLM Support**: Compatible with Ollama, OpenAI, and Gemini models +- **Multi-provider LLM Support**: Compatible with Ollama, OpenAI, Gemini, and OpenRouter models +- **Large File Handling Strategy**: Intelligent file size detection with multiple processing approaches ### Technical Architecture - **FastAPI Backend**: High-performance async web framework @@ -105,12 +109,21 @@ Code Graph Knowledge System is an enterprise-grade solution that transforms unst 5. **Run the Application** ```bash + # Start main service python start.py + # or use script entry points + uv run server + + # Start MCP service (optional) + python start_mcp.py + # or use script entry points + uv run mcp_client ``` 6. **Access the Interface** - API Documentation: http://localhost:8000/docs - Task Monitor: http://localhost:8000/ui/monitor + - Real-time SSE Monitor: http://localhost:8000/api/v1/sse/tasks - Health Check: http://localhost:8000/api/v1/health ## API Usage @@ -154,6 +167,53 @@ response = httpx.post("http://localhost:8000/api/v1/knowledge/search", json={ }) ``` +## Real-time Task Monitoring + +The system provides three real-time task monitoring approaches: + +### 1. Web UI Monitoring Interface +Access http://localhost:8000/ui/monitor for graphical monitoring: +- Real-time task status updates +- File upload functionality (50KB size limit) +- Directory batch processing +- Task progress visualization + +### 2. Server-Sent Events (SSE) API +Real-time monitoring via HTTP streaming endpoints: + +```javascript +// Monitor single task +const eventSource = new EventSource('/api/v1/sse/task/task-id'); +eventSource.onmessage = function(event) { + const data = JSON.parse(event.data); + console.log('Task progress:', data.progress); +}; + +// Monitor all tasks +const allTasksSource = new EventSource('/api/v1/sse/tasks'); +``` + +### 3. 
MCP Real-time Tools +Task monitoring via MCP protocol: + +```python +# Use pure MCP client monitoring +# See examples/pure_mcp_client.py + +# Monitor single task +result = await session.call_tool("watch_task", { + "task_id": task_id, + "timeout": 300, + "interval": 1.0 +}) + +# Monitor multiple tasks +result = await session.call_tool("watch_tasks", { + "task_ids": [task1, task2, task3], + "timeout": 300 +}) +``` + ## MCP Integration The system supports Model Context Protocol (MCP) for seamless integration with AI assistants: @@ -174,6 +234,10 @@ python start_mcp.py } ``` +### Client Implementation Examples +- `examples/pure_mcp_client.py`: Pure MCP client using MCP tools for monitoring +- `examples/hybrid_http_sse_client.py`: HTTP + SSE hybrid approach + ## Configuration Key configuration options in `.env`: diff --git a/README_CN.md b/README_CN.md index 45c59ee..9638322 100644 --- a/README_CN.md +++ b/README_CN.md @@ -13,10 +13,14 @@ - **Neo4j GraphRAG 集成**:使用 Neo4j 原生向量索引的高级图检索增强生成 - **智能查询引擎**:结合向量相似度和图遍历的混合搜索 - **异步任务处理**:支持大型文档集合的后台处理和实时监控 -- **基于Web的监控仪表板**:使用 NiceGUI 界面进行实时任务队列监控 +- **实时任务监控**:多种实时监控方案 + - Web UI监控:NiceGUI界面,支持文件上传和目录批处理 + - SSE流式API:HTTP Server-Sent Events实时任务进度推送 + - MCP实时工具:AI助手集成的任务监控工具 - **RESTful API**:完整的文档管理和知识查询 API 端点 - **MCP 协议支持**:模型上下文协议集成,兼容 AI 助手 -- **多提供商LLM支持**:兼容 Ollama、OpenAI 和 Gemini 模型 +- **多提供商LLM支持**:兼容 Ollama、OpenAI、Gemini 和 OpenRouter 模型 +- **大文件处理策略**:智能文件大小检测和多种处理方案 ### 技术架构 - **FastAPI 后端**:高性能异步网络框架 @@ -92,12 +96,21 @@ 5. **运行应用程序** ```bash + # 启动主服务 python start.py + # 或使用脚本入口点 + uv run server + + # 启动MCP服务(可选) + python start_mcp.py + # 或使用脚本入口点 + uv run mcp_client ``` 6. **访问界面** - API 文档:http://localhost:8000/docs - 任务监控:http://localhost:8000/ui/monitor + - 实时监控SSE:http://localhost:8000/api/v1/sse/tasks - 健康检查:http://localhost:8000/api/v1/health ## API 使用 @@ -141,6 +154,53 @@ response = httpx.post("http://localhost:8000/api/v1/knowledge/search", json={ }) ``` +## 实时任务监控 + +系统提供三种实时任务监控方案: + +### 1. 
Web UI 监控界面 +访问 http://localhost:8000/ui/monitor 使用图形界面: +- 实时任务状态更新 +- 文件上传功能(50KB大小限制) +- 目录批量处理 +- 任务进度可视化 + +### 2. Server-Sent Events (SSE) API +通过 HTTP 流式端点进行实时监控: + +```javascript +// 监控单个任务 +const eventSource = new EventSource('/api/v1/sse/task/task-id'); +eventSource.onmessage = function(event) { + const data = JSON.parse(event.data); + console.log('Task progress:', data.progress); +}; + +// 监控所有任务 +const allTasksSource = new EventSource('/api/v1/sse/tasks'); +``` + +### 3. MCP 实时工具 +通过 MCP 协议进行任务监控: + +```python +# 使用纯MCP客户端监控 +# 参见 examples/pure_mcp_client.py + +# 监控单个任务 +result = await session.call_tool("watch_task", { + "task_id": task_id, + "timeout": 300, + "interval": 1.0 +}) + +# 监控多个任务 +result = await session.call_tool("watch_tasks", { + "task_ids": [task1, task2, task3], + "timeout": 300 +}) +``` + ## MCP 集成 系统支持模型上下文协议(MCP),可与 AI 助手无缝集成: @@ -161,6 +221,10 @@ python start_mcp.py } ``` +### 客户端实现示例 +- `examples/pure_mcp_client.py`: 纯MCP客户端,使用MCP工具进行监控 +- `examples/hybrid_http_sse_client.py`: HTTP + SSE 混合方案 + ## 配置 `.env` 文件中的关键配置选项: diff --git a/api/sse_routes.py b/api/sse_routes.py new file mode 100644 index 0000000..9e123ad --- /dev/null +++ b/api/sse_routes.py @@ -0,0 +1,252 @@ +""" +Server-Sent Events (SSE) routes for real-time task monitoring +""" + +import asyncio +import json +from typing import Optional, Dict, Any +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from loguru import logger + +from services.task_queue import task_queue, TaskStatus + +router = APIRouter(prefix="/sse", tags=["SSE"]) + +# Active SSE connections +active_connections: Dict[str, Dict[str, Any]] = {} + +@router.get("/task/{task_id}") +async def stream_task_progress(task_id: str, request: Request): + """ + Stream task progress via Server-Sent Events + + Args: + task_id: Task ID to monitor + """ + + async def event_generator(): + connection_id = f"{task_id}_{id(request)}" + active_connections[connection_id] = { + "task_id": 
task_id, + "request": request, + "start_time": asyncio.get_event_loop().time() + } + + try: + logger.info(f"Starting SSE stream for task {task_id}") + + # Send initial connection event + yield f"data: {json.dumps({'type': 'connected', 'task_id': task_id, 'timestamp': asyncio.get_event_loop().time()})}\n\n" + + last_progress = -1 + last_status = None + + while True: + # Check if client disconnected + if await request.is_disconnected(): + logger.info(f"Client disconnected from SSE stream for task {task_id}") + break + + # Get task status + task_result = task_queue.get_task_status(task_id) + + if task_result is None: + # Task does not exist + yield f"data: {json.dumps({'type': 'error', 'error': 'Task not found', 'task_id': task_id})}\n\n" + break + + # Check for progress updates + if (task_result.progress != last_progress or + task_result.status.value != last_status): + + event_data = { + "type": "progress", + "task_id": task_id, + "progress": task_result.progress, + "status": task_result.status.value, + "message": task_result.message, + "timestamp": asyncio.get_event_loop().time() + } + + yield f"data: {json.dumps(event_data)}\n\n" + + last_progress = task_result.progress + last_status = task_result.status.value + + # Check if task is completed + if task_result.status.value in ['success', 'failed', 'cancelled']: + completion_data = { + "type": "completed", + "task_id": task_id, + "final_status": task_result.status.value, + "final_progress": task_result.progress, + "final_message": task_result.message, + "result": task_result.result, + "error": task_result.error, + "created_at": task_result.created_at.isoformat(), + "started_at": task_result.started_at.isoformat() if task_result.started_at else None, + "completed_at": task_result.completed_at.isoformat() if task_result.completed_at else None, + "timestamp": asyncio.get_event_loop().time() + } + + yield f"data: {json.dumps(completion_data)}\n\n" + logger.info(f"Task {task_id} completed via SSE: 
{task_result.status.value}") + break + + # Wait 1 second before next check + await asyncio.sleep(1) + + except asyncio.CancelledError: + logger.info(f"SSE stream cancelled for task {task_id}") + except Exception as e: + logger.error(f"Error in SSE stream for task {task_id}: {e}") + yield f"data: {json.dumps({'type': 'error', 'error': str(e), 'task_id': task_id})}\n\n" + finally: + # Clean up connection + if connection_id in active_connections: + del active_connections[connection_id] + logger.info(f"SSE stream ended for task {task_id}") + + return StreamingResponse( + event_generator(), + media_type="text/plain", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Cache-Control" + } + ) + +@router.get("/tasks") +async def stream_all_tasks(request: Request, status_filter: Optional[str] = None): + """ + Stream all tasks progress via Server-Sent Events + + Args: + status_filter: Optional status filter (pending, processing, success, failed, cancelled) + """ + + async def event_generator(): + connection_id = f"all_tasks_{id(request)}" + active_connections[connection_id] = { + "task_id": "all", + "request": request, + "start_time": asyncio.get_event_loop().time(), + "status_filter": status_filter + } + + try: + logger.info(f"Starting SSE stream for all tasks (filter: {status_filter})") + + # Send initial connection event + yield f"data: {json.dumps({'type': 'connected', 'scope': 'all_tasks', 'filter': status_filter, 'timestamp': asyncio.get_event_loop().time()})}\n\n" + + # 发送初始任务列表 + status_enum = None + if status_filter: + try: + status_enum = TaskStatus(status_filter.lower()) + except ValueError: + yield f"data: {json.dumps({'type': 'error', 'error': f'Invalid status filter: {status_filter}'})}\n\n" + return + + last_task_count = 0 + last_task_states = {} + + while True: + # Check if client disconnected + if await 
request.is_disconnected(): + logger.info("Client disconnected from all tasks SSE stream") + break + + # 获取当前任务列表 + tasks = task_queue.get_all_tasks(status_filter=status_enum, limit=50) + current_task_count = len(tasks) + + # 检查任务数量变化 + if current_task_count != last_task_count: + count_data = { + "type": "task_count_changed", + "total_tasks": current_task_count, + "filter": status_filter, + "timestamp": asyncio.get_event_loop().time() + } + yield f"data: {json.dumps(count_data)}\n\n" + last_task_count = current_task_count + + # 检查每个任务的状态变化 + current_states = {} + for task in tasks: + task_key = task.task_id + current_state = { + "status": task.status.value, + "progress": task.progress, + "message": task.message + } + current_states[task_key] = current_state + + # 比较状态变化 + if (task_key not in last_task_states or + last_task_states[task_key] != current_state): + + task_data = { + "type": "task_updated", + "task_id": task.task_id, + "status": task.status.value, + "progress": task.progress, + "message": task.message, + "metadata": task.metadata, + "timestamp": asyncio.get_event_loop().time() + } + yield f"data: {json.dumps(task_data)}\n\n" + + last_task_states = current_states + + # 等待2秒再检查 + await asyncio.sleep(2) + + except asyncio.CancelledError: + logger.info("All tasks SSE stream cancelled") + except Exception as e: + logger.error(f"Error in all tasks SSE stream: {e}") + yield f"data: {json.dumps({'type': 'error', 'error': str(e)})}\n\n" + finally: + # Clean up connection + if connection_id in active_connections: + del active_connections[connection_id] + logger.info("All tasks SSE stream ended") + + return StreamingResponse( + event_generator(), + media_type="text/plain", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Cache-Control" + } + ) + +@router.get("/stats") +async def get_sse_stats(): + """ + Get SSE connection statistics 
+ """ + stats = { + "active_connections": len(active_connections), + "connections": [] + } + + for conn_id, conn_info in active_connections.items(): + stats["connections"].append({ + "connection_id": conn_id, + "task_id": conn_info["task_id"], + "duration": asyncio.get_event_loop().time() - conn_info["start_time"], + "status_filter": conn_info.get("status_filter") + }) + + return stats \ No newline at end of file diff --git a/api/task_routes.py b/api/task_routes.py index c853cd6..9956272 100644 --- a/api/task_routes.py +++ b/api/task_routes.py @@ -12,6 +12,7 @@ from services.task_queue import task_queue, TaskStatus from services.task_storage import TaskType from loguru import logger +from config import settings router = APIRouter(prefix="/tasks", tags=["Task Management"]) @@ -68,6 +69,25 @@ async def create_task(request: CreateTaskRequest): if request.metadata: task_kwargs.update(request.metadata) + # Handle large documents by storing them temporarily + if request.task_type == "document_processing": + document_content = task_kwargs.get("document_content") + if document_content and len(document_content) > settings.max_document_size: + import tempfile + import os + + # Create temporary file for large document + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp_file: + tmp_file.write(document_content) + temp_path = tmp_file.name + + logger.info(f"Large document ({len(document_content)} bytes) saved to temporary file: {temp_path}") + + # Replace content with path reference + task_kwargs["document_path"] = temp_path + task_kwargs["document_content"] = None # Clear large content + task_kwargs["_temp_file"] = True # Mark as temporary file for cleanup + # select processing function based on task type task_func = None if request.task_type == "document_processing": diff --git a/config.py b/config.py index 54d3ee2..c3d5b8c 100644 --- a/config.py +++ b/config.py @@ -26,7 +26,7 @@ class Settings(BaseSettings): neo4j_database: str = 
Field(default="neo4j", description="Neo4j database name") # LLM Provider Configuration - llm_provider: Literal["ollama", "openai", "gemini"] = Field( + llm_provider: Literal["ollama", "openai", "gemini", "openrouter"] = Field( default="ollama", description="LLM provider to use", alias="LLM_PROVIDER" @@ -45,8 +45,14 @@ class Settings(BaseSettings): google_api_key: Optional[str] = Field(default=None, description="Google API key", alias="GOOGLE_API_KEY") gemini_model: str = Field(default="gemini-pro", description="Gemini model name", alias="GEMINI_MODEL") + # OpenRouter Configuration + openrouter_api_key: Optional[str] = Field(default=None, description="OpenRouter API key", alias="OPENROUTER_API_KEY") + openrouter_base_url: str = Field(default="https://openrouter.ai/api/v1", description="OpenRouter API base URL", alias="OPENROUTER_BASE_URL") + openrouter_model: Optional[str] = Field(default="openai/gpt-3.5-turbo", description="OpenRouter model", alias="OPENROUTER_MODEL") + openrouter_max_tokens: int = Field(default=2048, description="OpenRouter max tokens for completion", alias="OPENROUTER_MAX_TOKENS") + # Embedding Provider Configuration - embedding_provider: Literal["ollama", "openai", "gemini", "huggingface"] = Field( + embedding_provider: Literal["ollama", "openai", "gemini", "huggingface", "openrouter"] = Field( default="ollama", description="Embedding provider to use", alias="EMBEDDING_PROVIDER" @@ -64,6 +70,9 @@ class Settings(BaseSettings): # HuggingFace Embedding huggingface_embedding_model: str = Field(default="BAAI/bge-small-en-v1.5", description="HuggingFace embedding model", alias="HF_EMBEDDING_MODEL") + # OpenRouter Embedding + openrouter_embedding_model: str = Field(default="text-embedding-ada-002", description="OpenRouter embedding model", alias="OPENROUTER_EMBEDDING_MODEL") + # Model Parameters temperature: float = Field(default=0.1, description="LLM temperature") max_tokens: int = Field(default=2048, description="Maximum tokens for LLM response") @@ 
-78,6 +87,10 @@ class Settings(BaseSettings): operation_timeout: int = Field(default=120, description="Operation timeout in seconds") large_document_timeout: int = Field(default=300, description="Large document processing timeout in seconds") + # Document Processing Settings + max_document_size: int = Field(default=10 * 1024 * 1024, description="Maximum document size in bytes (10MB)") + max_payload_size: int = Field(default=50 * 1024 * 1024, description="Maximum task payload size for storage (50MB)") + # API Settings cors_origins: list = Field(default=["*"], description="CORS allowed origins") api_key: Optional[str] = Field(default=None, description="API authentication key") @@ -160,6 +173,26 @@ def validate_gemini_connection(): print(f"Gemini connection failed: {e}") return False +def validate_openrouter_connection(): + """Validate OpenRouter API connection""" + if not settings.openrouter_api_key: + print("OpenRouter API key not provided") + return False + try: + import httpx + # We'll use the models endpoint to check the connection + headers = { + "Authorization": f"Bearer {settings.openrouter_api_key}", + # OpenRouter requires these headers for identification + "HTTP-Referer": "CodeGraphKnowledgeService", + "X-Title": "CodeGraph Knowledge Service" + } + response = httpx.get("https://openrouter.ai/api/v1/models", headers=headers) + return response.status_code == 200 + except Exception as e: + print(f"OpenRouter connection failed: {e}") + return False + def get_current_model_info(): """Get information about currently configured models""" return { @@ -167,13 +200,15 @@ def get_current_model_info(): "llm_model": { "ollama": settings.ollama_model, "openai": settings.openai_model, - "gemini": settings.gemini_model + "gemini": settings.gemini_model, + "openrouter": settings.openrouter_model }.get(settings.llm_provider), "embedding_provider": settings.embedding_provider, "embedding_model": { "ollama": settings.ollama_embedding_model, "openai": 
settings.openai_embedding_model, "gemini": settings.gemini_embedding_model, - "huggingface": settings.huggingface_embedding_model + "huggingface": settings.huggingface_embedding_model, + "openrouter": settings.openrouter_embedding_model }.get(settings.embedding_provider) - } \ No newline at end of file + } diff --git a/config/sky.yml b/config/sky.yml new file mode 100644 index 0000000..72d36da --- /dev/null +++ b/config/sky.yml @@ -0,0 +1,104 @@ +project_name: "X Platform" +database_schema: "TEST" + +# SQL dialect settings +statement_separator: "/" +comment_patterns: + - "--.*$" + - "/\\*.*?\\*/" + +# Parsing rules +table_name_pattern: "create\\s+table\\s+(\\w+)\\.(\\w+)" +column_section_pattern: "\\((.*?)\\)" + +# Output settings +include_statistics: true +include_data_types_analysis: true +include_documentation: true + +# Business domain classification rules +business_domains: + policy_management: + - "POLICY" + - "PREMIUM" + - "COVERAGE" + - "CLAIM" + - "BENEFIT" + + customer_management: + - "CLIENT" + - "CUSTOMER" + - "INSURED" + - "CONTACT" + - "HOUSEHOLD" + - "PERSON" + + agent_management: + - "AGENT" + - "ADVISOR" + - "BROKER" + - "SALES" + - "BRANCH" + - "DISTRICT" + + product_management: + - "PRODUCT" + - "PLAN" + - "RIDER" + - "ANNUIT" + - "GIC" + - "WELLNESS" + + fund_management: + - "FD_" + - "FUND" + - "INVESTMENT" + - "PORTFOLIO" + - "SEG_FUND" + - "CANNEX" + + commission_management: + - "COMMISSION" + - "COMM_" + - "PAYMENT" + - "PAYABLE" + - "BONUS" + + underwriting_management: + - "UNDERWRITING" + - "UW_" + - "RATING" + - "RISK" + - "SURVEY" + - "HEIGHT_WEIGHT" + + system_management: + - "TYPE_" + - "CONFIG" + - "PARAM" + - "LOOKUP" + - "SETTING" + - "LICENSE" + - "COUNTRY" + - "PROVINCE" + + report_analysis: + - "SUN_" + - "REPORT" + - "STAT" + - "ANALYTICS" + - "TOTALS" + + compliance_audit: + - "AUDIT" + - "COMPLIANCE" + - "AGREEMENT" + - "LIABILITY" + + web_digital: + - "WEB_" + - "ONLINE" + - "DIGITAL" + - "LOGIN" + - "USER" + - "PROFILE" \ No 
newline at end of file diff --git a/core/logging.py b/core/logging.py index 84a70fa..5725a9b 100644 --- a/core/logging.py +++ b/core/logging.py @@ -10,10 +10,16 @@ def setup_logging(): """configure logging system""" + import logging # remove default log handler logger.remove() + # Suppress NiceGUI WebSocket debug logs + logging.getLogger("websockets").setLevel(logging.WARNING) + logging.getLogger("socketio").setLevel(logging.WARNING) + logging.getLogger("engineio").setLevel(logging.WARNING) + # add console log handler logger.add( sys.stderr, diff --git a/core/routes.py b/core/routes.py index 5727194..3e4e2d8 100644 --- a/core/routes.py +++ b/core/routes.py @@ -8,6 +8,7 @@ from api.neo4j_routes import router as neo4j_router from api.task_routes import router as task_router from api.websocket_routes import router as ws_router +from api.sse_routes import router as sse_router def setup_routes(app: FastAPI) -> None: @@ -17,4 +18,5 @@ def setup_routes(app: FastAPI) -> None: app.include_router(router, prefix="/api/v1", tags=["General"]) app.include_router(neo4j_router, prefix="/api/v1", tags=["Neo4j Knowledge"]) app.include_router(task_router, prefix="/api/v1", tags=["Task Management"]) + app.include_router(sse_router, prefix="/api/v1", tags=["Real-time Updates"]) \ No newline at end of file diff --git a/env.example b/env.example index 415a252..0303589 100644 --- a/env.example +++ b/env.example @@ -13,12 +13,12 @@ NEO4J_URI=bolt://localhost:7687 NEO4J_USER=neo4j NEO4J_PASSWORD=password -# LLM Provider (ollama/openai/gemini) +# LLM Provider (ollama/openai/gemini/openrouter) LLM_PROVIDER=ollama EMBEDDING_PROVIDER=ollama # Ollama Configuration -OLLAMA_HOST=http://localhost:11434 +OLLAMA_HOST=http://localhost:11434 OLLAMA_MODEL=llama3.2 OLLAMA_EMBEDDING_MODEL=nomic-embed-text @@ -35,6 +35,13 @@ OLLAMA_EMBEDDING_MODEL=nomic-embed-text # HuggingFace Configuration (if using) # HF_EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 +# OpenRouter Configuration (if using) +# 
OPENROUTER_API_KEY=your_api_key_here +# OPENROUTER_MODEL=openai/gpt-3.5-turbo +# OPENROUTER_EMBEDDING_MODEL=text-embedding-ada-002 +# OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 +# OPENROUTER_MAX_TOKENS=2048 + # Processing Configuration CHUNK_SIZE=1000 CHUNK_OVERLAP=200 @@ -61,4 +68,4 @@ OPERATION_TIMEOUT=120 LARGE_DOCUMENT_TIMEOUT=300 # Vector Search Settings -VECTOR_DIMENSION=384 \ No newline at end of file +VECTOR_DIMENSION=384 diff --git a/examples/hybrid_http_sse_client.py b/examples/hybrid_http_sse_client.py new file mode 100644 index 0000000..31bc108 --- /dev/null +++ b/examples/hybrid_http_sse_client.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +HTTP SSE客户端示例(混合实现) + +演示如何通过HTTP API提交任务并使用SSE进行实时监控 +注意:这不是纯MCP实现,而是HTTP + SSE的混合方案 +""" + +import asyncio +import json +import aiohttp +from typing import Optional, Dict, Any +import time + +class MCPSSEClient: + """MCP + SSE Combined Client""" + + def __init__(self, mcp_server_url: str = "stdio", sse_base_url: str = "http://localhost:8000/api/v1/sse"): + self.mcp_server_url = mcp_server_url + self.sse_base_url = sse_base_url + self.session: Optional[aiohttp.ClientSession] = None + + async def __aenter__(self): + self.session = aiohttp.ClientSession() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.session: + await self.session.close() + + async def stream_task_progress(self, task_id: str, callback=None): + """ + Stream task progress via SSE + + Args: + task_id: Task ID + callback: Progress callback function + """ + if not self.session: + raise RuntimeError("Client not initialized. 
Use 'async with' statement.") + + url = f"{self.sse_base_url}/task/{task_id}" + + try: + async with self.session.get(url) as response: + print(f"📡 Connected to SSE stream for task {task_id}") + + async for line in response.content: + line = line.decode('utf-8').strip() + + if line.startswith('data: '): + data_str = line[6:] # Remove 'data: ' prefix + + try: + data = json.loads(data_str) + await self._handle_sse_event(data, callback) + + # Exit stream if task completed + if data.get('type') in ['completed', 'error']: + break + + except json.JSONDecodeError: + print(f"⚠️ Invalid JSON data: {data_str}") + + except aiohttp.ClientError as e: + print(f"❌ SSE connection error: {e}") + except Exception as e: + print(f"❌ Unexpected error: {e}") + + async def _handle_sse_event(self, data: Dict[str, Any], callback=None): + """Handle SSE events""" + event_type = data.get('type') + + if event_type == 'connected': + print(f"✅ Connected to task monitoring") + + elif event_type == 'progress': + progress = data.get('progress', 0) + status = data.get('status', 'unknown') + message = data.get('message', '') + print(f"📊 Progress: {progress:.1f}% | Status: {status} | {message}") + + if callback: + await callback('progress', data) + + elif event_type == 'completed': + final_status = data.get('final_status', 'unknown') + final_message = data.get('final_message', '') + print(f"🎉 Task completed: {final_status} | {final_message}") + + if callback: + await callback('completed', data) + + elif event_type == 'error': + error = data.get('error', 'Unknown error') + print(f"❌ Error: {error}") + + if callback: + await callback('error', data) + +async def demo_mcp_with_sse(): + """Demonstrate MCP + SSE combined usage""" + + print("🚀 MCP + SSE Real-time Monitoring Demo") + print("=" * 50) + + # Simulate task submission via MCP (using HTTP request to simulate MCP call) + async with aiohttp.ClientSession() as session: + + # 1. 
Submit document processing task + print("📄 Submitting document processing task...") + + task_data = { + "task_name": "Process Large Document", + "task_type": "document_processing", + "payload": { + "document_content": "This is a large document content..." * 100, # Simulate large document + "document_type": "text" + } + } + + # Submit task via API (simulate MCP call) + async with session.post( + "http://localhost:8000/api/v1/tasks/submit", + json=task_data + ) as response: + if response.status == 200: + result = await response.json() + task_id = result.get('task_id') + print(f"✅ Task submitted successfully! Task ID: {task_id}") + else: + print(f"❌ Failed to submit task: {response.status}") + return + + # 2. Use SSE for real-time task progress monitoring + print(f"\n📡 Starting real-time monitoring for task {task_id}...") + + async def progress_callback(event_type: str, data: Dict[str, Any]): + """Custom progress callback""" + if event_type == 'progress': + # Can add custom logic here, such as updating UI, sending notifications, etc. + pass + elif event_type == 'completed': + print(f"✨ Task result: {data.get('result', {})}") + + # 3. Start SSE stream monitoring + async with MCPSSEClient() as client: + await client.stream_task_progress(task_id, progress_callback) + + print("\n🎯 Demo completed!") + +async def demo_multiple_tasks_monitoring(): + """Demonstrate multiple tasks monitoring""" + + print("🚀 Multiple Tasks Monitoring Demo") + print("=" * 50) + + task_ids = [] + + # Submit multiple tasks + async with aiohttp.ClientSession() as session: + for i in range(3): + task_data = { + "task_name": f"Document {i+1}", + "task_type": "document_processing", + "payload": { + "document_content": f"Document {i+1} content..." 
* 50, + "document_type": "text" + } + } + + async with session.post( + "http://localhost:8000/api/v1/tasks/submit", + json=task_data + ) as response: + if response.status == 200: + result = await response.json() + task_id = result.get('task_id') + task_ids.append(task_id) + print(f"✅ Task {i+1} submitted: {task_id}") + + # Monitor all tasks + print(f"\n📡 Monitoring {len(task_ids)} tasks...") + + async with MCPSSEClient() as client: + url = f"{client.sse_base_url}/tasks" + + async with client.session.get(url) as response: + print("📊 Connected to all tasks stream") + + tasks_completed = 0 + async for line in response.content: + line = line.decode('utf-8').strip() + + if line.startswith('data: '): + data_str = line[6:] + + try: + data = json.loads(data_str) + event_type = data.get('type') + + if event_type == 'task_updated': + task_id = data.get('task_id') + status = data.get('status') + progress = data.get('progress', 0) + + if task_id in task_ids: + print(f"📊 Task {task_id[:8]}: {progress:.1f}% | {status}") + + if status in ['success', 'failed', 'cancelled']: + tasks_completed += 1 + + if tasks_completed >= len(task_ids): + print("🎉 All tasks completed!") + break + + except json.JSONDecodeError: + pass + +def main(): + """Main function""" + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "multi": + asyncio.run(demo_multiple_tasks_monitoring()) + else: + asyncio.run(demo_mcp_with_sse()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/pure_mcp_client.py b/examples/pure_mcp_client.py new file mode 100644 index 0000000..0592244 --- /dev/null +++ b/examples/pure_mcp_client.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +""" +Pure MCP Client Real-time Monitoring Example + +Demonstrates how to perform real-time task monitoring via MCP tools without relying on HTTP SSE API +""" + +import asyncio +import json +from typing import Dict, Any +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import 
stdio_client + +async def demo_pure_mcp_monitoring(): + """Demonstrate pure MCP real-time monitoring""" + + print("🚀 Pure MCP Real-time Monitoring Demo") + print("=" * 50) + + # 连接到MCP服务器 + server_params = StdioServerParameters( + command="python", + args=["mcp_server.py"] + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + print("✅ Connected to MCP server") + + # 1. Submit document processing task + print("\n📄 Submitting document via MCP...") + + result = await session.call_tool( + "add_document", + { + "content": "This is a large document content that will trigger background processing..." * 200, + "title": "Large Test Document", + "metadata": {"source": "mcp_demo", "type": "test"} + } + ) + + if result.isError: + print(f"❌ Failed to submit task: {result.error}") + return + + task_id = result.content[0].text + task_data = json.loads(task_id) + + if not task_data.get("success"): + print(f"❌ Task submission failed: {task_data.get('error')}") + return + + actual_task_id = task_data.get("task_id") + print(f"✅ Task submitted! Task ID: {actual_task_id}") + + # 2. 
Use MCP watch_task tool for real-time monitoring + print(f"\n📡 Starting MCP real-time monitoring for task {actual_task_id}...") + + watch_result = await session.call_tool( + "watch_task", + { + "task_id": actual_task_id, + "timeout": 300, # 5分钟超时 + "interval": 1.0 # 每秒检查一次 + } + ) + + if watch_result.isError: + print(f"❌ Watch task failed: {watch_result.error}") + return + + # Parse monitoring results + watch_data = json.loads(watch_result.content[0].text) + + if watch_data.get("success"): + print(f"\n🎉 Task completed successfully!") + print(f"Final Status: {watch_data.get('final_status')}") + print(f"Final Message: {watch_data.get('final_message')}") + print(f"Total Watch Time: {watch_data.get('total_watch_time', 0):.2f}s") + + # Show progress history + progress_history = watch_data.get('progress_history', []) + if progress_history: + print(f"\n📊 Progress History ({len(progress_history)} updates):") + for i, entry in enumerate(progress_history[-5:]): # Show last 5 updates + print(f" {i+1}. {entry['progress']:.1f}% - {entry['status']} - {entry['message']}") + + # Show final result + final_result = watch_data.get('result') + if final_result: + print(f"\n✨ Final Result: {final_result}") + + else: + print(f"❌ Watch failed: {watch_data.get('error')}") + +async def demo_multiple_tasks_mcp(): + """Demonstrate MCP multiple tasks monitoring""" + + print("\n🚀 Multiple Tasks MCP Monitoring Demo") + print("=" * 50) + + server_params = StdioServerParameters( + command="python", + args=["mcp_server.py"] + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + print("✅ Connected to MCP server") + + # Submit multiple tasks + task_ids = [] + for i in range(3): + print(f"\n📄 Submitting document {i+1}...") + + result = await session.call_tool( + "add_document", + { + "content": f"Document {i+1} content that needs processing..." 
* 50, + "title": f"Test Document {i+1}", + "metadata": {"batch": "demo", "index": i+1} + } + ) + + if not result.isError: + task_data = json.loads(result.content[0].text) + if task_data.get("success"): + task_id = task_data.get("task_id") + task_ids.append(task_id) + print(f"✅ Task {i+1} submitted: {task_id}") + + if not task_ids: + print("❌ No tasks submitted successfully") + return + + # Use watch_tasks to monitor all tasks + print(f"\n📡 Monitoring {len(task_ids)} tasks...") + + watch_result = await session.call_tool( + "watch_tasks", + { + "task_ids": task_ids, + "timeout": 300, + "interval": 2.0 + } + ) + + if watch_result.isError: + print(f"❌ Watch tasks failed: {watch_result.error}") + return + + # Parse batch monitoring results + watch_data = json.loads(watch_result.content[0].text) + + if watch_data.get("success"): + print(f"\n🎉 All tasks monitoring completed!") + + summary = watch_data.get('summary', {}) + print(f"📊 Summary:") + print(f" Total: {summary.get('total_tasks', 0)}") + print(f" Successful: {summary.get('successful', 0)}") + print(f" Failed: {summary.get('failed', 0)}") + print(f" Total Time: {watch_data.get('total_watch_time', 0):.2f}s") + + # Show final status of each task + final_results = watch_data.get('final_results', {}) + print(f"\n📋 Final Results:") + for task_id, result in final_results.items(): + status = result.get('status', 'unknown') + message = result.get('message', '') + print(f" {task_id[:8]}: {status} - {message}") + + else: + print(f"❌ Batch watch failed: {watch_data.get('error')}") + +async def demo_task_listing(): + """Demonstrate task list querying""" + + print("\n🚀 Task Listing Demo") + print("=" * 30) + + server_params = StdioServerParameters( + command="python", + args=["mcp_server.py"] + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + + # List all tasks + result = await session.call_tool( + "list_tasks", + { + "limit": 10 + } + ) + + if result.isError: 
+ print(f"❌ List tasks failed: {result.error}") + return + + tasks_data = json.loads(result.content[0].text) + + if tasks_data.get("success"): + tasks = tasks_data.get('tasks', []) + print(f"📋 Found {len(tasks)} recent tasks:") + + for task in tasks[:5]: # Show first 5 + task_id = task['task_id'] + status = task['status'] + progress = task['progress'] + message = task['message'] + print(f" {task_id[:8]}: {progress:.1f}% - {status} - {message}") + + else: + print(f"❌ Failed to list tasks: {tasks_data.get('error')}") + +def main(): + """Main function""" + import sys + + if len(sys.argv) > 1: + if sys.argv[1] == "multi": + asyncio.run(demo_multiple_tasks_mcp()) + elif sys.argv[1] == "list": + asyncio.run(demo_task_listing()) + else: + print("Usage: python pure_mcp_client.py [multi|list]") + else: + asyncio.run(demo_pure_mcp_monitoring()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mcp_server.py b/mcp_server.py index f552956..38ca856 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -158,14 +158,31 @@ async def add_document( return result else: - # for large documents, use asynchronous task queue - task_id = await submit_document_processing_task( - knowledge_service.add_document, - content=content, - title=title, - metadata=metadata, - task_name=f"Add Document: {title}" - ) + # for large documents (>=10KB), save to temporary file first + import tempfile + import os + + temp_fd, temp_path = tempfile.mkstemp(suffix=f"_{title.replace('/', '_')}.txt", text=True) + try: + with os.fdopen(temp_fd, 'w', encoding='utf-8') as temp_file: + temp_file.write(content) + + # use file path instead of content to avoid payload size issues + task_id = await submit_document_processing_task( + knowledge_service.add_file, # Use add_file instead of add_document + temp_path, + task_name=f"Add Large Document: {title}", + # Add metadata to track this is a temp file that should be cleaned up + _temp_file=True, + _original_title=title, + 
_original_metadata=metadata + ) + except: + # Clean up on error + os.close(temp_fd) + if os.path.exists(temp_path): + os.unlink(temp_path) + raise if ctx: await ctx.info(f"Large document queued for processing. Task ID: {task_id}") @@ -341,6 +358,239 @@ async def get_task_status( "error": error_msg } +# MCP tool: watch task (real-time task monitoring) +@mcp.tool +async def watch_task( + task_id: str, + timeout: int = 300, + interval: float = 1.0, + ctx: Context = None +) -> Dict[str, Any]: + """ + Watch a task progress with real-time updates until completion. + + Args: + task_id: The task ID to watch + timeout: Maximum time to wait in seconds (default: 300) + interval: Check interval in seconds (default: 1.0) + + Returns: + Dict containing final task status and progress history + """ + try: + await ensure_service_initialized() + + if ctx: + await ctx.info(f"Watching task: {task_id} (timeout: {timeout}s, interval: {interval}s)") + + import asyncio + start_time = asyncio.get_event_loop().time() + progress_history = [] + last_progress = -1 + last_status = None + + while True: + current_time = asyncio.get_event_loop().time() + if current_time - start_time > timeout: + return { + "success": False, + "error": "Watch timeout exceeded", + "progress_history": progress_history + } + + task_result = task_queue.get_task_status(task_id) + + if task_result is None: + return { + "success": False, + "error": "Task not found", + "progress_history": progress_history + } + + # Record progress changes + if (task_result.progress != last_progress or + task_result.status.value != last_status): + + progress_entry = { + "timestamp": asyncio.get_event_loop().time(), + "progress": task_result.progress, + "status": task_result.status.value, + "message": task_result.message + } + progress_history.append(progress_entry) + + # Send real-time updates to client + if ctx: + await ctx.info(f"Progress: {task_result.progress:.1f}% - {task_result.message}") + + last_progress = task_result.progress + 
last_status = task_result.status.value + + # Check if task is completed + if task_result.status.value in ['success', 'failed', 'cancelled']: + final_result = { + "success": True, + "task_id": task_result.task_id, + "final_status": task_result.status.value, + "final_progress": task_result.progress, + "final_message": task_result.message, + "created_at": task_result.created_at.isoformat(), + "started_at": task_result.started_at.isoformat() if task_result.started_at else None, + "completed_at": task_result.completed_at.isoformat() if task_result.completed_at else None, + "result": task_result.result, + "error": task_result.error, + "progress_history": progress_history, + "total_watch_time": current_time - start_time + } + + if ctx: + if task_result.status.value == 'success': + await ctx.info(f"Task completed successfully in {current_time - start_time:.1f}s") + else: + await ctx.error(f"Task {task_result.status.value}: {task_result.error or task_result.message}") + + return final_result + + # Wait for next check + await asyncio.sleep(interval) + + except Exception as e: + error_msg = f"Watch task failed: {str(e)}" + logger.error(error_msg) + if ctx: + await ctx.error(error_msg) + return { + "success": False, + "error": error_msg, + "progress_history": progress_history if 'progress_history' in locals() else [] + } + +# MCP tool: watch multiple tasks (batch monitoring) +@mcp.tool +async def watch_tasks( + task_ids: List[str], + timeout: int = 300, + interval: float = 2.0, + ctx: Context = None +) -> Dict[str, Any]: + """ + Watch multiple tasks progress with real-time updates until all complete. 
+ + Args: + task_ids: List of task IDs to watch + timeout: Maximum time to wait in seconds (default: 300) + interval: Check interval in seconds (default: 2.0) + + Returns: + Dict containing all task statuses and progress histories + """ + try: + await ensure_service_initialized() + + if ctx: + await ctx.info(f"Watching {len(task_ids)} tasks (timeout: {timeout}s, interval: {interval}s)") + + import asyncio + start_time = asyncio.get_event_loop().time() + tasks_progress = {task_id: [] for task_id in task_ids} + completed_tasks = set() + + while True: + current_time = asyncio.get_event_loop().time() + if current_time - start_time > timeout: + return { + "success": False, + "error": "Watch timeout exceeded", + "tasks_progress": tasks_progress, + "completed_tasks": list(completed_tasks), + "pending_tasks": list(set(task_ids) - completed_tasks) + } + + # Check all tasks + active_tasks = [] + for task_id in task_ids: + if task_id in completed_tasks: + continue + + task_result = task_queue.get_task_status(task_id) + if task_result is None: + completed_tasks.add(task_id) + continue + + # Record progress + progress_entry = { + "timestamp": current_time, + "progress": task_result.progress, + "status": task_result.status.value, + "message": task_result.message + } + + # Only record changed progress + if (not tasks_progress[task_id] or + tasks_progress[task_id][-1]["progress"] != task_result.progress or + tasks_progress[task_id][-1]["status"] != task_result.status.value): + + tasks_progress[task_id].append(progress_entry) + + if ctx: + await ctx.info(f"Task {task_id}: {task_result.progress:.1f}% - {task_result.message}") + + # Check if completed + if task_result.status.value in ['success', 'failed', 'cancelled']: + completed_tasks.add(task_id) + if ctx: + await ctx.info(f"Task {task_id} completed: {task_result.status.value}") + else: + active_tasks.append(task_id) + + # All tasks completed + if len(completed_tasks) == len(task_ids): + final_results = {} + for task_id in 
task_ids: + task_result = task_queue.get_task_status(task_id) + if task_result: + final_results[task_id] = { + "status": task_result.status.value, + "progress": task_result.progress, + "message": task_result.message, + "result": task_result.result, + "error": task_result.error + } + + if ctx: + success_count = sum(1 for task_id in task_ids + if task_queue.get_task_status(task_id) and + task_queue.get_task_status(task_id).status.value == 'success') + await ctx.info(f"All tasks completed! {success_count}/{len(task_ids)} successful") + + return { + "success": True, + "tasks_progress": tasks_progress, + "final_results": final_results, + "completed_tasks": list(completed_tasks), + "total_watch_time": current_time - start_time, + "summary": { + "total_tasks": len(task_ids), + "successful": sum(1 for r in final_results.values() if r["status"] == "success"), + "failed": sum(1 for r in final_results.values() if r["status"] == "failed"), + "cancelled": sum(1 for r in final_results.values() if r["status"] == "cancelled") + } + } + + # Wait for next check + await asyncio.sleep(interval) + + except Exception as e: + error_msg = f"Watch tasks failed: {str(e)}" + logger.error(error_msg) + if ctx: + await ctx.error(error_msg) + return { + "success": False, + "error": error_msg, + "tasks_progress": tasks_progress if 'tasks_progress' in locals() else {} + } + # MCP tool: list all tasks @mcp.tool async def list_tasks( diff --git a/monitoring/task_monitor.py b/monitoring/task_monitor.py index 8ae0706..9391bc8 100644 --- a/monitoring/task_monitor.py +++ b/monitoring/task_monitor.py @@ -78,40 +78,93 @@ async def load_stats(): async def load_tasks(status_filter: Optional[str] = None): """Load task list""" try: - # Convert status filter - status_enum = None - if status_filter: - status_enum = TaskStatus(status_filter) - - tasks = task_queue.get_all_tasks( - status_filter=status_enum, - limit=50 - ) + # Get tasks from storage for most up-to-date status + if task_queue._storage: + from 
services.task_storage import TaskStatus as StorageTaskStatus, TaskType + + # Convert status filter + storage_status_enum = None + if status_filter: + storage_status_enum = StorageTaskStatus(status_filter) + + # Get tasks from storage + storage_tasks = await task_queue._storage.list_tasks( + status=storage_status_enum, + limit=50 + ) + + # Convert storage tasks to TaskResult objects for display + tasks = [] + for storage_task in storage_tasks: + # Check if task exists in memory, otherwise create it + if storage_task.id in task_queue.tasks: + memory_task = task_queue.tasks[storage_task.id] + # Update memory task with storage status + memory_task.status = storage_task.status + memory_task.progress = storage_task.progress + memory_task.completed_at = storage_task.completed_at + memory_task.started_at = storage_task.started_at + memory_task.error = storage_task.error_message + tasks.append(memory_task) + else: + # Create TaskResult from storage task + from services.task_queue import TaskResult + task_result = TaskResult( + task_id=storage_task.id, + status=storage_task.status, + progress=storage_task.progress, + message="", + error=storage_task.error_message, + created_at=storage_task.created_at, + started_at=storage_task.started_at, + completed_at=storage_task.completed_at, + metadata=storage_task.payload + ) + tasks.append(task_result) + # Also update memory + task_queue.tasks[storage_task.id] = task_result + else: + # Fallback to memory if storage not available + status_enum = None + if status_filter: + status_enum = TaskStatus(status_filter) + + tasks = task_queue.get_all_tasks( + status_filter=status_enum, + limit=50 + ) + update_tasks_display(tasks) except Exception as e: ui.notify(f'Failed to load task list: {str(e)}', type='negative') def update_stats_display(stats): """Update statistics display""" - pending_card.set_content(str(stats['pending_tasks'])) - processing_card.set_content(str(stats['processing_tasks'])) - 
success_card.set_content(str(stats['completed_tasks'])) - failed_card.set_content(str(stats['failed_tasks'])) - cancelled_card.set_content(str(stats['cancelled_tasks'])) - total_card.set_content(str(stats['total_tasks'])) + try: + pending_card.set_text(str(stats['pending_tasks'])) + processing_card.set_text(str(stats['processing_tasks'])) + success_card.set_text(str(stats['completed_tasks'])) + failed_card.set_text(str(stats['failed_tasks'])) + cancelled_card.set_text(str(stats['cancelled_tasks'])) + total_card.set_text(str(stats['total_tasks'])) + except Exception as e: + logger.warning(f"Failed to update stats display: {e}") def update_tasks_display(tasks): """Update task list display""" - task_container.clear() - - if not tasks: + try: + task_container.clear() + + if not tasks: + with task_container: + ui.label('No tasks').classes('text-gray-500 text-center w-full py-8') + return + with task_container: - ui.label('No tasks').classes('text-gray-500 text-center w-full py-8') - return - - with task_container: - for task in tasks: - create_task_card(task) + for task in tasks: + create_task_card(task) + except Exception as e: + logger.warning(f"Failed to update tasks display: {e}") def create_task_card(task): """Create task card""" @@ -202,31 +255,100 @@ async def cancel_task(task_id: str): def view_task_details(task): """View task details""" - with ui.dialog() as dialog, ui.card().classes('w-96'): - ui.label('Task Details').classes('text-h6') - - details = [ - ('Task ID', task.task_id), - ('Task Name', task.metadata.get('task_name', 'Unnamed')), - ('Task Type', TASK_TYPE_MAP.get(task.metadata.get('task_type', ''), 'Unknown')), - ('Status', STATUS_MAP.get(task.status.value, task.status.value)), - ('Progress', f"{task.progress:.1f}%"), - ('Created at', task.created_at.strftime('%Y-%m-%d %H:%M:%S')), - ('Started at', task.started_at.strftime('%Y-%m-%d %H:%M:%S') if task.started_at else 'Not started'), - ('Completed at', task.completed_at.strftime('%Y-%m-%d %H:%M:%S') 
if task.completed_at else 'Not completed'), - ('Message', task.message), - ] - - for label, value in details: - with ui.row().classes('w-full'): - ui.label(f'{label}:').classes('font-bold') - ui.label(str(value)) + with ui.dialog() as dialog, ui.card().classes('w-full max-w-2xl'): + ui.label('Task Details').classes('text-h6 mb-4') - if task.error: - ui.label('Error:').classes('font-bold text-red-600 mt-2') - ui.label(task.error).classes('text-red-600') + # Basic info + with ui.card().classes('w-full mb-4'): + ui.label('Basic Information').classes('font-bold mb-2') + basic_details = [ + ('Task ID', task.task_id), + ('Task Name', task.metadata.get('task_name', 'Unnamed')), + ('Task Type', TASK_TYPE_MAP.get(task.metadata.get('task_type', ''), 'Unknown')), + ('Status', STATUS_MAP.get(task.status.value, task.status.value)), + ('Progress', f"{task.progress:.1f}%"), + ] + + for label, value in basic_details: + with ui.row().classes('w-full justify-between'): + ui.label(f'{label}:').classes('font-medium') + ui.label(str(value)) + + # Timing info + with ui.card().classes('w-full mb-4'): + ui.label('Timing Information').classes('font-bold mb-2') + + created_at = task.created_at.strftime('%Y-%m-%d %H:%M:%S') + started_at = task.started_at.strftime('%Y-%m-%d %H:%M:%S') if task.started_at else 'Not started' + completed_at = task.completed_at.strftime('%Y-%m-%d %H:%M:%S') if task.completed_at else 'Not completed' + + # Calculate durations + from datetime import datetime + now = datetime.now() + total_time = 'N/A' + processing_time = 'N/A' + wait_time = 'N/A' + + if task.started_at: + wait_time = str(task.started_at - task.created_at).split('.')[0] + if task.completed_at: + processing_time = str(task.completed_at - task.started_at).split('.')[0] + total_time = str(task.completed_at - task.created_at).split('.')[0] + elif task.status.value == 'processing': + processing_time = str(now - task.started_at).split('.')[0] + ' (ongoing)' + total_time = str(now - 
task.created_at).split('.')[0] + ' (ongoing)' + else: + wait_time = str(now - task.created_at).split('.')[0] + ' (waiting)' + + timing_details = [ + ('Created at', created_at), + ('Started at', started_at), + ('Completed at', completed_at), + ('Wait time', wait_time), + ('Processing time', processing_time), + ('Total time', total_time), + ] + + for label, value in timing_details: + with ui.row().classes('w-full justify-between'): + ui.label(f'{label}:').classes('font-medium') + ui.label(str(value)) + + # Status and message + if task.message or task.error: + with ui.card().classes('w-full mb-4'): + ui.label('Status & Messages').classes('font-bold mb-2') + + if task.message: + ui.label('Current message:').classes('font-medium') + ui.label(task.message).classes('text-blue-600 mb-2') + + if task.error: + ui.label('Error:').classes('font-medium text-red-600') + ui.label(task.error).classes('text-red-600') - ui.button('Close', on_click=dialog.close).classes('mt-4') + # Metadata + if hasattr(task, 'metadata') and task.metadata: + with ui.card().classes('w-full mb-4'): + ui.label('Metadata').classes('font-bold mb-2') + + # Show relevant metadata + metadata_items = [] + if 'filename' in task.metadata: + metadata_items.append(('Filename', task.metadata['filename'])) + if 'file_size' in task.metadata: + metadata_items.append(('File size', f"{task.metadata['file_size']:,} chars")) + if 'directory_path' in task.metadata: + metadata_items.append(('Directory', task.metadata['directory_path'])) + + for label, value in metadata_items: + with ui.row().classes('w-full justify-between'): + ui.label(f'{label}:').classes('font-medium') + ui.label(str(value)) + + # Close button + with ui.row().classes('w-full justify-end mt-4'): + ui.button('Close', on_click=dialog.close).props('color=primary') dialog.open() @@ -246,6 +368,138 @@ async def refresh_data(): await load_stats() await load_tasks(status_filter_select.value if status_filter_select.value else None) + async def smart_refresh(): + 
"""Smart refresh that only updates when needed""" + try: + # Only refresh if there are active (pending/processing) tasks + active_tasks = task_queue.get_all_tasks(limit=10) + has_active = any(t.status.value in ['pending', 'processing'] for t in active_tasks) + + if has_active: + # Refresh both stats and tasks when there's activity + await load_stats() + await load_tasks(status_filter_select.value if status_filter_select.value else None) + else: + # Only refresh stats when no active tasks + await load_stats() + + except Exception as e: + logger.debug(f"Smart refresh error: {e}") + + def detect_file_type(filename: str) -> str: + """Detect file type based on extension""" + extension = filename.lower().split('.')[-1] if '.' in filename else '' + + type_mapping = { + 'txt': 'text', + 'md': 'markdown', + 'java': 'java', + 'py': 'python', + 'js': 'javascript', + 'sql': 'sql', + 'json': 'json', + 'xml': 'xml', + 'html': 'html', + 'csv': 'csv' + } + + return type_mapping.get(extension, 'text') + + async def handle_file_upload(e): + """Handle file upload for document processing""" + try: + # Get uploaded file content + content = e.content.read().decode('utf-8') + filename = e.name + file_type = detect_file_type(filename) + file_size = len(content) + + # Display file info + file_info_label.set_text(f'📄 {filename} | {file_size:,} chars | Type: {file_type}') + + # Check file size limits + if file_size > 50 * 1024: # 50KB threshold + ui.notify( + f'文件过大 ({file_size:,} 字符)!\n' + f'建议使用以下方式处理大文件:\n' + f'1. 使用"目录处理"功能,将文件放在目录中进行批量处理\n' + f'2. 通过 MCP 客户端处理大文件\n' + f'3. 
将文件分割成较小的片段\n' + f'当前限制:50KB 以下', + type='warning' + ) + return + + # For acceptable file sizes, process directly + task_kwargs = { + "document_content": content, + "document_type": file_type + } + + # Create document processing task + from services.task_processors import process_document_task + task_id = await task_queue.submit_task( + task_func=process_document_task, + task_kwargs=task_kwargs, + task_name=f"Process File: {filename}", + task_type="document_processing", + metadata={"filename": filename, "file_size": file_size}, + priority=1 + ) + + ui.notify(f'File "{filename}" uploaded and processing started! Task ID: {task_id}', type='positive') + + # Update the form + task_type_select.value = 'document_processing' + task_name_input.value = f"Process File: {filename}" + + # Refresh task list + await refresh_data() + + except UnicodeDecodeError: + ui.notify('File encoding error. Please ensure the file is in UTF-8 format.', type='negative') + except Exception as error: + ui.notify(f'File upload failed: {str(error)}', type='negative') + + async def handle_directory_processing(): + """Handle directory batch processing""" + try: + directory_path = directory_input.value.strip() + if not directory_path: + ui.notify('Please enter a directory path', type='warning') + return + + # Parse file patterns + patterns = [p.strip() for p in file_patterns_input.value.split(',') if p.strip()] + if not patterns: + patterns = ['*.*'] + + # Create batch processing task + from services.task_processors import process_batch_processing_task + task_id = await task_queue.submit_task( + task_func=process_batch_processing_task, + task_kwargs={ + "directory_path": directory_path, + "file_patterns": patterns + }, + task_name=f"Process Directory: {directory_path}", + task_type="batch_processing", + metadata={"directory_path": directory_path}, + priority=1 + ) + + ui.notify(f'Directory processing started! 
Task ID: {task_id}', type='positive') + + # Update the form + task_type_select.value = 'batch_processing' + task_name_input.value = f"Process Directory: {directory_path}" + + # Refresh task list + await refresh_data() + + except Exception as error: + ui.notify(f'Directory processing failed: {str(error)}', type='negative') + # Create UI layout with ui.column().classes('w-full max-w-7xl mx-auto p-4'): # Title @@ -298,14 +552,99 @@ async def refresh_data(): task_name_input = ui.input('Task Name', placeholder='Enter task name').classes('w-full') + # File upload and directory input section + with ui.expansion('📁 File & Directory Operations', icon='upload_file').classes('w-full'): + with ui.column().classes('w-full gap-4 p-4'): + # File upload for document processing + ui.label('Upload Small Files (≤ 50KB) for Document Processing').classes('font-bold') + ui.label('For larger files, use Directory Processing below').classes('text-sm text-gray-600 mb-2') + with ui.row().classes('w-full gap-2 items-end'): + file_upload = ui.upload( + on_upload=lambda e: handle_file_upload(e), + multiple=False, + max_file_size=100 * 1024 * 1024 # 100MB browser limit (we check 50KB in code) + ).props('accept=".txt,.md,.java,.py,.js,.sql,.json,.xml,.html,.csv"').classes('flex-1') + + ui.button( + 'Clear', + on_click=lambda: file_upload.reset(), + icon='clear' + ).props('size=sm') + + # File info display + file_info_label = ui.label('').classes('text-sm text-gray-600') + + # Directory path for batch processing + ui.separator().classes('my-4') + ui.label('Directory Path for Batch Processing').classes('font-bold') + + with ui.row().classes('w-full gap-2'): + directory_input = ui.input( + 'Directory Path', + placeholder='/path/to/your/documents' + ).classes('flex-1') + + ui.button( + 'Process Directory', + on_click=lambda: handle_directory_processing(), + icon='folder_open' + ).props('color=blue') + + file_patterns_input = ui.input( + 'File Patterns (comma-separated)', + 
placeholder='*.txt,*.md,*.java,*.py', + value='*.txt,*.md,*.java,*.py,*.js,*.sql' + ).classes('w-full') + + # Help text + with ui.expansion('ℹ️ Help & Supported Formats', icon='help').classes('w-full mt-4'): + ui.html(''' +
+

📁 File Upload (小文件 ≤ 50KB):

+ + +

📂 Directory Processing (推荐用于大文件):

+ + +

🤖 MCP Client (程序化处理):

+ + +

🔧 Features:

+ +
+ ''') + task_payload_input = ui.textarea( 'Payload (JSON)', placeholder='{"document_content": "example content", "document_type": "text"}' ).classes('w-full') task_priority_select = ui.select( - 'Priority', options={'0': 'Normal', '1': 'High', '2': 'Urgent'}, + label='Priority', value='0' ).classes('w-full') @@ -331,7 +670,10 @@ async def refresh_data(): value='', on_change=lambda: asyncio.create_task(refresh_data()) ) - ui.button('Refresh', on_click=refresh_data).props('icon=refresh') + + with ui.row().classes('gap-2'): + ui.button('🔄 Manual Refresh', on_click=refresh_data).props('color=primary size=sm') + ui.badge('🟢 Live Updates', color='green').classes('text-xs') # Task container task_container = ui.column().classes('w-full') @@ -339,5 +681,32 @@ async def refresh_data(): # Initial load data await refresh_data() - # Set auto refresh (every 5 seconds) - ui.timer(5.0, refresh_data) \ No newline at end of file + # Setup smart refresh (less intrusive than before) + ui.timer(5.0, lambda: asyncio.create_task(smart_refresh())) + + @ui.page('/') + async def root_page(): + """Root page that redirects to monitor""" + ui.page_title('Code Graph Knowledge System') + + with ui.column().classes('w-full max-w-4xl mx-auto p-8'): + with ui.card().classes('w-full text-center p-8'): + ui.label('Code Graph Knowledge System').classes('text-3xl font-bold mb-4') + ui.label('Neo4j-based intelligent knowledge management system').classes('text-gray-600 mb-6') + + with ui.row().classes('gap-4 justify-center'): + ui.button('Task Monitor', on_click=lambda: ui.navigate.to('/monitor')).props('size=lg color=primary') + ui.button('API Docs', on_click=lambda: ui.navigate.to('/docs')).props('size=lg color=secondary') + + with ui.expansion('System Information', icon='info').classes('w-full mt-6'): + ui.label('Available Features:').classes('font-bold mb-2') + features = [ + 'Document Processing and Knowledge Extraction', + 'SQL Schema Parsing and Analysis', + 'Knowledge Graph Construction', + 
'Vector Search and RAG Queries', + 'Batch File Processing', + 'Real-time Task Monitoring' + ] + for feature in features: + ui.label(f'• {feature}').classes('ml-4 mb-1') \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index aae65d7..8a6b3ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,4 +35,13 @@ dependencies = [ "google-generativeai>=0.3.0", "matplotlib>=3.10.3", "nicegui>=2.19.0", + "llama-index-llms-openrouter>=0.3.2", ] + +[project.scripts] +server = "start:main" +mcp_client = "start_mcp:main" + +[tool.setuptools] +packages = ["api", "core", "services", "monitoring"] +py-modules = ["start", "start_mcp", "mcp_server", "config", "main"] diff --git a/services/neo4j_knowledge_service.py b/services/neo4j_knowledge_service.py index 548f67d..301f0b3 100644 --- a/services/neo4j_knowledge_service.py +++ b/services/neo4j_knowledge_service.py @@ -22,6 +22,7 @@ from llama_index.llms.ollama import Ollama from llama_index.llms.openai import OpenAI from llama_index.llms.gemini import Gemini +from llama_index.llms.openrouter import OpenRouter # Embedding Providers from llama_index.embeddings.ollama import OllamaEmbedding @@ -84,6 +85,16 @@ def _create_llm(self): temperature=settings.temperature, max_tokens=settings.max_tokens ) + elif provider == "openrouter": + if not settings.openrouter_api_key: + raise ValueError("OpenRouter API key is required for OpenRouter provider") + return OpenRouter( + model=settings.openrouter_model, + api_key=settings.openrouter_api_key, + temperature=settings.temperature, + max_tokens=settings.openrouter_max_tokens, + timeout=self.operation_timeout + ) else: raise ValueError(f"Unsupported LLM provider: {provider}") @@ -117,6 +128,15 @@ def _create_embedding_model(self): return HuggingFaceEmbedding( model_name=settings.huggingface_embedding_model ) + elif provider == "openrouter": + if not settings.openrouter_api_key: + raise ValueError("OpenRouter API key is required for OpenRouter embedding provider") + return 
OpenAIEmbedding( + model=settings.openrouter_embedding_model, + api_key=settings.openrouter_api_key, + api_base=settings.openrouter_base_url, + timeout=self.operation_timeout + ) else: raise ValueError(f"Unsupported embedding provider: {provider}") @@ -659,4 +679,4 @@ def set_timeouts(self, connection_timeout: int = None, operation_timeout: int = logger.info(f"Large document timeout set to {large_document_timeout}s") # global service instance -neo4j_knowledge_service = Neo4jKnowledgeService() \ No newline at end of file +neo4j_knowledge_service = Neo4jKnowledgeService() diff --git a/services/pipeline/embeddings.py b/services/pipeline/embeddings.py index 55f9c89..1c0b7f1 100644 --- a/services/pipeline/embeddings.py +++ b/services/pipeline/embeddings.py @@ -197,6 +197,63 @@ async def generate_embeddings(self, texts: List[str]) -> List[List[float]]: logger.error(f"Failed to generate Ollama embeddings: {e}") raise +class OpenRouterEmbeddingGenerator(EmbeddingGenerator): + """OpenRouter embedding generator""" + + def __init__(self, api_key: str, model: str = "text-embedding-ada-002"): + self.api_key = api_key + self.model = model + self.client = None + + async def _get_client(self): + """get OpenRouter client (which is the same as OpenAI client)""" + if self.client is None: + try: + from openai import AsyncOpenAI + self.client = AsyncOpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=self.api_key, + # OpenRouter requires the HTTP referer header to be set + # We set the referer to the application's name, or use a default + default_headers={ + "HTTP-Referer": "CodeGraphKnowledgeService", + "X-Title": "CodeGraph Knowledge Service" + } + ) + except ImportError: + raise ImportError("Please install openai: pip install openai") + return self.client + + async def generate_embedding(self, text: str) -> List[float]: + """generate single text embedding vector""" + client = await self._get_client() + + try: + response = await client.embeddings.create( + input=text, + 
model=self.model + ) + return response.data[0].embedding + + except Exception as e: + logger.error(f"Failed to generate OpenRouter embedding: {e}") + raise + + async def generate_embeddings(self, texts: List[str]) -> List[List[float]]: + """batch generate embedding vectors""" + client = await self._get_client() + + try: + response = await client.embeddings.create( + input=texts, + model=self.model + ) + return [data.embedding for data in response.data] + + except Exception as e: + logger.error(f"Failed to generate OpenRouter embeddings: {e}") + raise + class EmbeddingGeneratorFactory: """embedding generator factory""" @@ -206,7 +263,7 @@ def create_generator(config: dict) -> EmbeddingGenerator: provider = config.get("provider", "huggingface").lower() if provider == "huggingface": - model_name = config.get("model_name", "BAAI/bge-small-zh-v1.5") + model_name = config.get("model_name", "BAAI/bge-small-zh-v1.5") return HuggingFaceEmbeddingGenerator(model_name=model_name) elif provider == "openai": @@ -221,6 +278,13 @@ def create_generator(config: dict) -> EmbeddingGenerator: model = config.get("model", "nomic-embed-text") return OllamaEmbeddingGenerator(host=host, model=model) + elif provider == "openrouter": + api_key = config.get("api_key") + if not api_key: + raise ValueError("OpenRouter API key is required") + model = config.get("model", "text-embedding-ada-002") + return OpenRouterEmbeddingGenerator(api_key=api_key, model=model) + else: raise ValueError(f"Unsupported embedding provider: {provider}") @@ -240,4 +304,4 @@ def get_default_embedding_generator() -> EmbeddingGenerator: def set_default_embedding_generator(generator: EmbeddingGenerator): """set default embedding generator""" global default_embedding_generator - default_embedding_generator = generator \ No newline at end of file + default_embedding_generator = generator diff --git a/services/task_processors.py b/services/task_processors.py index 68c8b3b..984d06a 100644 --- a/services/task_processors.py
+++ b/services/task_processors.py @@ -4,16 +4,14 @@ """ import asyncio -import logging from typing import Dict, Any, Optional, Callable from abc import ABC, abstractmethod from pathlib import Path import json +from loguru import logger from .task_storage import TaskType, Task -logger = logging.getLogger(__name__) - class TaskProcessor(ABC): """task processor base class""" @@ -38,14 +36,26 @@ async def process(self, task: Task, progress_callback: Optional[Callable] = None payload = task.payload try: + logger.info(f"Task {task.id} - Starting document processing") self._update_progress(progress_callback, 10, "Starting document processing") - # extract parameters from payload - document_content = payload.get("document_content") - document_path = payload.get("document_path") - document_type = payload.get("document_type", "text") + # extract parameters from payload (parameters are nested under "kwargs") + kwargs = payload.get("kwargs", {}) + document_content = kwargs.get("document_content") + document_path = kwargs.get("document_path") + document_type = kwargs.get("document_type", "text") + temp_file_cleanup = kwargs.get("_temp_file", False) + + # Debug logging for large document issues + logger.info(f"Task {task.id} - Content length: {len(document_content) if document_content else 'None'}") + logger.info(f"Task {task.id} - Path provided: {document_path}") + logger.info(f"Task {task.id} - Available kwargs keys: {list(kwargs.keys())}") + logger.info(f"Task {task.id} - Full payload structure: task_name={payload.get('task_name')}, has_kwargs={bool(kwargs)}") if not document_content and not document_path: + logger.error(f"Task {task.id} - Missing document content/path. 
Payload keys: {list(payload.keys())}") + logger.error(f"Task {task.id} - Kwargs content: {kwargs}") + logger.error(f"Task {task.id} - Document content type: {type(document_content)}, Path type: {type(document_path)}") raise ValueError("Either document_content or document_path must be provided") # if path is provided, read file content @@ -84,6 +94,16 @@ async def process(self, task: Task, progress_callback: Optional[Callable] = None except Exception as e: logger.error(f"Document processing failed: {e}") raise + finally: + # Clean up temporary file if it was created + if temp_file_cleanup and document_path: + try: + import os + if os.path.exists(document_path): + os.unlink(document_path) + logger.info(f"Cleaned up temporary file: {document_path}") + except Exception as cleanup_error: + logger.warning(f"Failed to clean up temporary file {document_path}: {cleanup_error}") async def _process_with_neo4j(self, content: str, doc_type: str, progress_callback: Optional[Callable]) -> Dict[str, Any]: """use Neo4j service to process document""" @@ -136,10 +156,11 @@ async def process(self, task: Task, progress_callback: Optional[Callable] = None try: self._update_progress(progress_callback, 10, "Starting schema parsing") - # extract parameters from payload - schema_content = payload.get("schema_content") - schema_path = payload.get("schema_path") - schema_type = payload.get("schema_type", "sql") + # extract parameters from payload (parameters are nested under "kwargs") + kwargs = payload.get("kwargs", {}) + schema_content = kwargs.get("schema_content") + schema_path = kwargs.get("schema_path") + schema_type = kwargs.get("schema_type", "sql") if not schema_content and not schema_path: raise ValueError("Either schema_content or schema_path must be provided") @@ -236,9 +257,10 @@ async def process(self, task: Task, progress_callback: Optional[Callable] = None try: self._update_progress(progress_callback, 10, "Starting knowledge graph construction") - # extract parameters from 
payload - data_sources = payload.get("data_sources", []) - construction_type = payload.get("construction_type", "full") + # extract parameters from payload (parameters are nested under "kwargs") + kwargs = payload.get("kwargs", {}) + data_sources = kwargs.get("data_sources", []) + construction_type = kwargs.get("construction_type", "full") if not data_sources: raise ValueError("No data sources provided for knowledge graph construction") @@ -331,10 +353,11 @@ async def process(self, task: Task, progress_callback: Optional[Callable] = None try: self._update_progress(progress_callback, 10, "Starting batch processing") - # extract parameters from payload - directory_path = payload.get("directory_path") - file_patterns = payload.get("file_patterns", ["*.txt", "*.md", "*.sql"]) - batch_size = payload.get("batch_size", 10) + # extract parameters from payload (parameters are nested under "kwargs") + kwargs = payload.get("kwargs", {}) + directory_path = kwargs.get("directory_path") + file_patterns = kwargs.get("file_patterns", ["*.txt", "*.md", "*.sql"]) + batch_size = kwargs.get("batch_size", 10) if not directory_path: raise ValueError("Directory path is required for batch processing") diff --git a/services/task_queue.py b/services/task_queue.py index 882e842..90faa46 100644 --- a/services/task_queue.py +++ b/services/task_queue.py @@ -11,9 +11,7 @@ from dataclasses import dataclass, field from datetime import datetime import json -import logging - -logger = logging.getLogger(__name__) +from loguru import logger class TaskStatus(Enum): PENDING = "pending" @@ -47,6 +45,7 @@ def __init__(self, max_concurrent_tasks: int = 3): self._cleanup_task = None self._storage = None # delay initialization to avoid circular import self._worker_id = str(uuid.uuid4()) # unique worker ID for locking + self._task_worker = None # task processing worker async def start(self): """start task queue""" @@ -61,7 +60,21 @@ async def start(self): self._cleanup_task = 
asyncio.create_task(self._cleanup_completed_tasks()) # start worker to process pending tasks - asyncio.create_task(self._process_pending_tasks()) + logger.info("About to start task processing worker...") + task_worker = asyncio.create_task(self._process_pending_tasks()) + logger.info("Task processing worker started") + + # Store the task worker reference to keep it alive + self._task_worker = task_worker + + # Test if we can get pending tasks immediately + try: + test_tasks = await self._storage.get_pending_tasks(limit=5) + logger.info(f"Initial pending tasks check: found {len(test_tasks)} tasks") + for task in test_tasks: + logger.info(f" - Task {task.id}: {task.type.value}") + except Exception as e: + logger.error(f"Failed to get initial pending tasks: {e}") logger.info(f"Task queue started with max {self.max_concurrent_tasks} concurrent tasks") @@ -75,6 +88,11 @@ async def stop(self): if task_id in self.tasks: self.tasks[task_id].status = TaskStatus.CANCELLED + # stop task worker + if hasattr(self, '_task_worker') and self._task_worker: + self._task_worker.cancel() + self._task_worker = None + # stop cleanup task if self._cleanup_task: self._cleanup_task.cancel() @@ -90,6 +108,7 @@ async def _restore_tasks_from_storage(self): try: # restore all incomplete tasks stored_tasks = await self._storage.list_tasks(limit=1000) + logger.info(f"Restoring {len(stored_tasks)} tasks from storage") for task in stored_tasks: # create TaskResult object for memory management @@ -178,37 +197,60 @@ async def submit_task(self, async def _process_pending_tasks(self): """continuously process pending tasks""" + logger.info("Task processing loop started") + loop_count = 0 while True: + loop_count += 1 + if loop_count % 60 == 1: # Log every 60 iterations (every minute) + logger.debug(f"Task processing loop iteration {loop_count}") try: + if not self._storage: + if loop_count % 50 == 1: # Log storage issue every 50 iterations + logger.warning("No storage available for task processing") + 
await asyncio.sleep(1) + continue + if self._storage: # 获取待处理的任务 pending_tasks = await self._storage.get_pending_tasks( limit=self.max_concurrent_tasks ) + if loop_count % 10 == 1 and pending_tasks: # Log every 10 iterations if tasks found + logger.info(f"Found {len(pending_tasks)} pending tasks") + elif pending_tasks: # Always log when tasks are found + logger.debug(f"Found {len(pending_tasks)} pending tasks") + for task in pending_tasks: # 检查是否已经在运行 if task.id in self.running_tasks: + logger.debug(f"Task {task.id} already running, skipping") continue + logger.info(f"Attempting to acquire lock for task {task.id}") # 尝试获取任务锁 if await self._storage.acquire_task_lock(task.id, self._worker_id): + logger.info(f"Lock acquired, starting execution for task {task.id}") # 启动任务执行 async_task = asyncio.create_task( self._execute_stored_task(task) ) self.running_tasks[task.id] = async_task + else: + logger.debug(f"Failed to acquire lock for task {task.id}") # 等待一段时间再检查 await asyncio.sleep(1) except Exception as e: logger.error(f"Error in task processing loop: {e}") + logger.exception(f"Full traceback for task processing loop error:") await asyncio.sleep(5) async def _execute_stored_task(self, task): """execute stored task""" task_id = task.id + logger.info(f"Starting execution of stored task {task_id}") task_result = self.tasks.get(task_id) if not task_result: @@ -241,7 +283,9 @@ async def _execute_stored_task(self, task): # here we need to dynamically restore task function based on task type # for now, we use a placeholder, actual implementation needs task registration mechanism + logger.info(f"Task {task_id} about to execute by type: {task.type}") result = await self._execute_task_by_type(task) + logger.info(f"Task {task_id} execution completed with result: {type(result)}") # task completed task_result.status = TaskStatus.SUCCESS @@ -467,9 +511,9 @@ async def submit_document_processing_task( ) -> str: """submit document processing task""" return await task_queue.submit_task( 
- service_method, - args, - kwargs, + task_func=service_method, + task_args=args, + task_kwargs=kwargs, task_name=task_name, task_type="document_processing" ) @@ -482,9 +526,9 @@ async def submit_directory_processing_task( ) -> str: """submit directory processing task""" return await task_queue.submit_task( - service_method, - (directory_path,), - kwargs, + task_func=service_method, + task_args=(directory_path,), + task_kwargs=kwargs, task_name=task_name, task_type="batch_processing" ) \ No newline at end of file diff --git a/services/task_storage.py b/services/task_storage.py index 0b506ce..5b78c8c 100644 --- a/services/task_storage.py +++ b/services/task_storage.py @@ -13,6 +13,7 @@ from dataclasses import dataclass, asdict from pathlib import Path from loguru import logger +from config import settings from .task_queue import TaskResult, TaskStatus @@ -43,16 +44,52 @@ def to_dict(self) -> Dict[str, Any]: data['created_at'] = self.created_at.isoformat() data['started_at'] = self.started_at.isoformat() if self.started_at else None data['completed_at'] = self.completed_at.isoformat() if self.completed_at else None - data['payload'] = json.dumps(self.payload) + + # Add error handling for large payload serialization + try: + payload_json = json.dumps(self.payload) + # Check if payload is too large + if len(payload_json) > settings.max_payload_size: + logger.warning(f"Task {self.id} payload is very large ({len(payload_json)} bytes)") + # For very large payloads, store a summary instead + summary_payload = { + "error": "Payload too large for storage", + "original_size": len(payload_json), + "original_keys": list(self.payload.keys()) if isinstance(self.payload, dict) else str(type(self.payload)), + "truncated_sample": str(self.payload)[:1000] + "..." 
if len(str(self.payload)) > 1000 else str(self.payload) + } + data['payload'] = json.dumps(summary_payload) + else: + data['payload'] = payload_json + except (TypeError, ValueError) as e: + logger.error(f"Failed to serialize payload for task {self.id}: {e}") + # Store a diagnostic summary (keys and error) in place of the unserializable payload + data['payload'] = json.dumps({ + "error": "Payload could not be serialized", + "original_keys": list(self.payload.keys()) if isinstance(self.payload, dict) else str(type(self.payload)), + "serialization_error": str(e) + }) + return data @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'Task': + # Handle payload deserialization with error handling + payload = {} + try: + if isinstance(data['payload'], str): + payload = json.loads(data['payload']) + else: + payload = data['payload'] + except (json.JSONDecodeError, TypeError) as e: + logger.error(f"Failed to deserialize payload for task {data['id']}: {e}") + payload = {"error": "Failed to deserialize payload", "raw_payload": str(data['payload'])[:1000]} + return cls( id=data['id'], type=TaskType(data['type']), status=TaskStatus(data['status']), - payload=json.loads(data['payload']) if isinstance(data['payload'], str) else data['payload'], + payload=payload, created_at=datetime.fromisoformat(data['created_at']), started_at=datetime.fromisoformat(data['started_at']) if data['started_at'] else None, completed_at=datetime.fromisoformat(data['completed_at']) if data['completed_at'] else None, diff --git a/start.py b/start.py index fa2d52f..b3f1004 100644 --- a/start.py +++ b/start.py @@ -11,7 +11,7 @@ # add project root to path sys.path.insert(0, str(Path(__file__).parent)) -from config import settings, validate_neo4j_connection, validate_ollama_connection +from config import settings, validate_neo4j_connection, validate_ollama_connection, validate_openrouter_connection, get_current_model_info from loguru import logger def check_dependencies(): @@ -20,9 +20,16 @@ def check_dependencies(): checks = [ ("Neo4j", 
validate_neo4j_connection), - ("Ollama", validate_ollama_connection) ] + # Conditionally add Ollama if it is the selected LLM or embedding provider + if settings.llm_provider == "ollama" or settings.embedding_provider == "ollama": + checks.append(("Ollama", validate_ollama_connection)) + + # Conditionally add OpenRouter if it is the selected LLM or embedding provider + if settings.llm_provider == "openrouter" or settings.embedding_provider == "openrouter": + checks.append(("OpenRouter", validate_openrouter_connection)) + all_passed = True for service_name, check_func in checks: try: @@ -68,9 +75,10 @@ def print_startup_info(): print(f" Neo4j: {settings.neo4j_uri}") print(f" Ollama: {settings.ollama_base_url}") print() + model_info = get_current_model_info() print("model config:") - print(f" LLM: {settings.ollama_model}") - print(f" Embedding: {settings.embedding_model}") + print(f" LLM: {model_info['llm_model']}") + print(f" Embedding: {model_info['embedding_model']}") print("="*60) print() @@ -108,4 +116,4 @@ def main(): sys.exit(1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/start_mcp.py b/start_mcp.py index e545f04..a1bff77 100644 --- a/start_mcp.py +++ b/start_mcp.py @@ -42,7 +42,7 @@ def check_dependencies(): def check_services(): """check necessary services""" - from config import validate_neo4j_connection, validate_ollama_connection + from config import validate_neo4j_connection, validate_ollama_connection, validate_openrouter_connection, settings logger.info("Checking service connections...") @@ -54,13 +54,22 @@ def check_services(): logger.error("Please ensure Neo4j is running and accessible") return False - # check Ollama connection - if validate_ollama_connection(): - logger.info("✓ Ollama connection successful") - else: - logger.error("✗ Ollama connection failed") - logger.error("Please ensure Ollama is running and accessible") - return False + # Conditionally check LLM provider connections + if 
settings.llm_provider == "ollama" or settings.embedding_provider == "ollama": + if validate_ollama_connection(): + logger.info("✓ Ollama connection successful") + else: + logger.error("✗ Ollama connection failed") + logger.error("Please ensure Ollama is running and accessible") + return False + + if settings.llm_provider == "openrouter" or settings.embedding_provider == "openrouter": + if validate_openrouter_connection(): + logger.info("✓ OpenRouter connection successful") + else: + logger.error("✗ OpenRouter connection failed") + logger.error("Please ensure OpenRouter API key is configured correctly") + return False return True