diff --git a/AGENTS.md b/AGENTS.md index 72a6f2be2..3706ab943 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,7 +25,6 @@ Don't create helpers, utilities, or abstractions for one-time operations. Don't - Icon 应该从 @ant-design/icons-vue 或者 lucide-vue-next (推荐,但是需要注意尺寸) - Vue 中的样式使用 less,非必要情况必须使用[base.css](web/src/assets/css/base.css) 中的颜色变量。 - UI风格要简洁,同时要保持一致性,不要悬停位移,不要过度使用阴影以及渐变色。 -- 开发完成后,可以在 docker 的 web 文件夹下,运行 npm run format 格式化代码 ### 后端开发规范 @@ -40,12 +39,11 @@ docker compose exec api uv run python test/your_script.py # 放在 test 文件 ``` 注意: - - Python 代码要符合 Python 的规范,符合 pythonic 风格 - 尽量使用较新的语法,避免使用旧版本的语法(版本兼容到 3.12+) +- 修改完成后使用 YUXI_SUPER_ADMIN_NAME / YUXI_SUPER_ADMIN_PASSWORD 环境变量调试接口(读取 .env 文件) **其他**: -- 使用 YUXI_SUPER_ADMIN_NAME / YUXI_SUPER_ADMIN_PASSWORD 调试接口 - 如果需要新建说明文档(仅开发者可见,非必要不创建),则保存在 `docs/vibe` 文件夹下面 - 代码更新后要检查文档部分是否有需要更新的地方,文档的目录定义在 `docs/.vitepress/config.mts` 中。文档应该更新最新版(`docs/latest`) diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index 99d90c9ac..4a6403ec4 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -16,6 +16,7 @@ services: - .env.prod environment: - HOST_IP=${HOST_IP:-} + - POSTGRES_URL=${POSTGRES_URL:-postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@postgres:5432/${POSTGRES_DB:-yuxi_know}} - NEO4J_URI=${NEO4J_URI:-bolt://graph:7687} - NEO4J_USERNAME=${NEO4J_USERNAME:-neo4j} - NEO4J_PASSWORD=${NEO4J_PASSWORD:-0123456789} @@ -150,6 +151,27 @@ services: - app-network restart: unless-stopped + postgres: + image: postgres:16 + container_name: postgres + ports: + - "5432:5432" + environment: + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-postgres} + - POSTGRES_DB=${POSTGRES_DB:-yuxi_know} + - TZ=${TZ:-Asia/Shanghai} + volumes: + - ./docker/volumes/postgresql:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres} -d ${POSTGRES_DB:-yuxi_know} || exit 1"] + interval: 5s + timeout: 3s + 
retries: 30 + networks: + - app-network + restart: unless-stopped + mineru-vllm-server: build: context: . diff --git a/docker-compose.yml b/docker-compose.yml index 67bef0cf6..1f27cbb8e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -28,6 +28,7 @@ services: # region api_envs environment: - HOST_IP=${HOST_IP:-} + - POSTGRES_URL=${POSTGRES_URL:-postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@postgres:5432/${POSTGRES_DB:-yuxi_know}} - NEO4J_URI=${NEO4J_URI:-bolt://graph:7687} - NEO4J_USERNAME=${NEO4J_USERNAME:-neo4j} - NEO4J_PASSWORD=${NEO4J_PASSWORD:-0123456789} @@ -51,8 +52,8 @@ services: retries: 8 start_period: 180s depends_on: - # postgres: - # condition: service_healthy + postgres: + condition: service_healthy milvus: condition: service_healthy minio: @@ -181,15 +182,25 @@ services: - app-network restart: unless-stopped - sqlite-web: - image: coleifer/sqlite-web:latest - container_name: sqlite-web - volumes: - - ./saves/database:/data + postgres: + image: postgres:16 + container_name: postgres ports: - - "9092:8080" + - "5432:5432" environment: - SQLITE_DATABASE: server.db + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-postgres} + - POSTGRES_DB=${POSTGRES_DB:-yuxi_know} + - TZ=${TZ:-Asia/Shanghai} + volumes: + - ./docker/volumes/postgresql:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres} -d ${POSTGRES_DB:-yuxi_know} || exit 1"] + interval: 5s + timeout: 3s + retries: 30 + networks: + - app-network restart: unless-stopped # lastest version: wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/docker/compose.yaml diff --git a/docs/latest/advanced/misc.md b/docs/latest/advanced/misc.md index afb2a3303..333ab97e7 100644 --- a/docs/latest/advanced/misc.md +++ b/docs/latest/advanced/misc.md @@ -42,10 +42,10 @@ | **7474/7687** | Neo4j | graph | 图数据库 | | **9000/9001** | MinIO | milvus-minio | 对象存储 | | **19530/9091** | Milvus 
| milvus | 向量数据库 | +| **5432** | postgres | postgres | PostgreSQL 数据库 | +| **30000** | MinerU | mineru | PDF 解析(可选)| +| **8080** | PP-StructureV3 | paddlex-ocr | OCR 服务(可选)| +| **8081** | vLLM | - | 本地推理(可选)| -| **9092** | SQLite Web UI | sqlite-web | SQLite 数据库管理界面(可选)| ::: tip 端口访问 - Web 界面: `http://localhost:5173` diff --git a/docs/latest/changelog/migrate_to_v0-5.md b/docs/latest/changelog/migrate_to_v0-5.md new file mode 100644 index 000000000..283420dbf --- /dev/null +++ b/docs/latest/changelog/migrate_to_v0-5.md @@ -0,0 +1,108 @@ +# v0.5 数据迁移指南 + +v0.5 版本将数据存储从 SQLite + JSON 文件迁移到 PostgreSQL。本指南帮助你完成数据迁移。 + +::: warning +迁移脚本可能会存在问题,不建议在生产环境下尝试,生产环境下,请新建或仔细检查迁移脚本,慎重迁移。 +::: + +## 迁移内容 + +| 数据类型 | 源 | 目标 | 存储内容 | +|---------|-----|------|---------| +| 业务数据 | SQLite (`saves/database/server.db`) | PostgreSQL | 用户、部门、对话、消息、工具调用、MCP 服务器等 | +| 知识库元数据 | JSON 文件 (`saves/knowledge_base_data/`) | PostgreSQL | 知识库配置、文件信息、评估数据 | +| Tasker 任务记录 | JSON 文件 (`saves/tasks/tasks.json`) | PostgreSQL | 后台任务状态、进度、结果(独立存储) | + +## 迁移前准备 + +### 1. 启动服务 + +```bash +docker compose up -d --build +``` + +### 2. 备份数据 + +**重要:** 迁移前必须备份数据! + +```bash +# 备份 saves 目录(包含 SQLite 数据库和知识库元数据) +cp -r saves saves_backup_$(date +%Y%m%d) + +# 如果使用外部数据库,也请备份 PostgreSQL +pg_dump -U postgres -d yuxi_know > pg_backup_$(date +%Y%m%d).sql +``` + +### 3. 确保 PostgreSQL 已启动 + +```bash +docker compose up -d postgres +# 等待健康检查通过 +``` + +## 执行迁移 + +### 方式一:使用统一迁移脚本(推荐) + +```bash +# 1. 预览迁移(不执行) +docker compose exec api python scripts/migrate_all.py --dry-run + +# 2. 执行迁移 +docker compose exec api python scripts/migrate_all.py --execute + +# 3. 
验证迁移结果 +docker compose exec api python scripts/migrate_all.py --verify +``` + +### 方式二:分阶段迁移 + +```bash +# 只迁移业务数据 +docker compose exec api python scripts/migrate_all.py --execute --stage business + +# 只迁移知识库元数据 +docker compose exec api python scripts/migrate_all.py --execute --stage knowledge + +# 只迁移 Tasker 任务记录 +docker compose exec api python scripts/migrate_all.py --execute --stage tasker +``` + +## 重启服务 + +```bash +docker compose up -d +``` + +## 验证迁移 + +### 检查服务状态 + +```bash +# 查看 API 服务日志 +docker logs api-dev --tail 50 + +# 检查健康状态 +curl http://localhost:5050/api/system/health +``` + +### 验证数据 + +```bash +# 使用迁移脚本验证 +docker compose exec api python scripts/migrate_all.py --verify +``` + +预期输出: + +``` +============================================================ +📊 验证结果汇总 +============================================================ +✅ departments: SQLite=X, PostgreSQL=X +✅ users: SQLite=X, PostgreSQL=X +✅ conversations: SQLite=X, PostgreSQL=X +... +全部匹配: ✅ 是 +``` diff --git a/docs/latest/changelog/roadmap.md b/docs/latest/changelog/roadmap.md index 8569a9a61..753f2dd1a 100644 --- a/docs/latest/changelog/roadmap.md +++ b/docs/latest/changelog/roadmap.md @@ -14,6 +14,13 @@ - 系统层面添加 apikey,在智能体、知识库调用中支持 apikey 以支持外部调用 - 支持更多类型的文档源的导入功能 - 检查非 Agent 场景下的知识库的可见情况 +- Tasker 新增删除任务的接口 +- 部分场景应该使用默认模型作为默认值而不是空值 +- CommonRAG 添加更多检索类型,比如 BM25,关键词,grep 等 +- 文件上传解析后,如何提示用户需要入库 +- 检索测试中,添加问答 +- 优化 MySQL 查询工具中,自动从环境变量中加载知识库的表述作为工具的描述 + ### Bugs - 部分异常状态下,智能体的模型名称出现重叠[#279](https://github.com/xerrors/Yuxi-Know/issues/279) @@ -24,6 +31,7 @@ - 工具传递给模型的时候,使用英文,但部分模型不支持中文函数名(如gpt-4o-mini) - 当前的 upload 图谱查询为同步操作,可能会导致页面卡顿 - FileTable 的自动刷新失效 +- 生成基准测试会把所有的向量都计算一遍不合理 ## v0.5 @@ -33,7 +41,6 @@ - 优化 RAG 检索,支持根据文件 pattern 来检索(Agentic Mode) - 重构智能体对于“工具变更/模型变更”的处理逻辑,无需导入更复杂的中间件 - 重构知识库的 Agentic 配置逻辑,与 Tools 解耦 -- 新增Sqlite Web UI 方便通过Web页面管理数据库中数据[#463](https://github.com/xerrors/Yuxi-Know/pull/463) - 将工具与知识库解耦,在 context 中就完成解耦,虽然最终都是在 Agent 中的 get_tools 中获取 - 优化chunk逻辑,移除 QA 
分割,集成到普通分块中,并优化可视化逻辑 - 重构知识库处理逻辑,分为 上传—解析—入库 三个阶段 diff --git a/docs/latest/intro/knowledge-base.md b/docs/latest/intro/knowledge-base.md index 96740c7ac..ea666a7cc 100644 --- a/docs/latest/intro/knowledge-base.md +++ b/docs/latest/intro/knowledge-base.md @@ -24,6 +24,20 @@ 在前端界面中,默认完成前两步,即上传后会自动解析,如果想要实现解析后还继续入库的话,需要在上传的时候勾选自动入库。否则需要在上传后手动点击入库。 +## 知识库权限控制 + +每个知识库可以配置独立的访问权限: + +- **共享模式**: 设置知识库是否全局共享 +- **部门访问**: 配置允许访问该知识库的部门范围 + +权限规则: +- **超级管理员**: 可访问所有知识库 +- **管理员**: 可访问共享以及本部门所有知识库 +- **普通用户**: 仅能访问已授权的知识库(通过部门或全局共享) + +创建/编辑知识库时,可在"分享配置"中设置权限。 + ## 文档管理 本系统的“上传 → 解析入库 → 检索/可视化”流程既可通过 Web 界面完成,也可使用 API/脚本批量处理。详见[文档解析](../advanced/document-processing.md) diff --git a/docs/latest/intro/project-overview.md b/docs/latest/intro/project-overview.md index e3a898f8a..b4db6425b 100644 --- a/docs/latest/intro/project-overview.md +++ b/docs/latest/intro/project-overview.md @@ -8,7 +8,7 @@ Yuxi-Know(语析)是一个基于知识图谱和向量数据库的智能知 - **后端服务**: [FastAPI](https://github.com/tiangolo/fastapi) + Python 3.12+ - **前端界面**: [Vue.js 3](https://github.com/vuejs/vue) + [Ant Design Vue](https://github.com/vueComponent/ant-design-vue) -- **数据库存储**: [SQLite](https://github.com/sqlite/sqlite) + [MinIO](https://github.com/minio/minio) +- **数据库存储**: [PostgreSQL](https://github.com/postgres/postgres) + [MinIO](https://github.com/minio/minio) - **知识存储**: [Milvus](https://github.com/milvus-io/milvus)(向量数据库)+ [Neo4j](https://github.com/neo4j/neo4j)(图数据库) - **智能体框架**: [LangGraph](https://github.com/langchain-ai/langgraph) - **文档解析**: [LightRAG](https://github.com/HKUDS/LightRAG) + [MinerU](https://github.com/HKUDS/MinerU) + [PP-Structure-V3](https://github.com/PaddlePaddle/PaddleOCR) @@ -20,5 +20,5 @@ Yuxi-Know(语析)是一个基于知识图谱和向量数据库的智能知 - **知识库管理**: 支持多种存储形式(Milvus、LightRAG) - **知识图谱**: 自动构建和可视化知识图谱,支持图查询 - **文档解析**: 支持 PDF、Word、图片等多种格式的智能解析 -- **权限管理**: 三级权限体系(超级管理员、管理员、普通用户) +- **权限管理**: 基于部门的知识库访问控制 - **内容安全**: 内置内容审查机制,保障服务合规性 diff --git a/docs/latest/intro/quick-start.md 
b/docs/latest/intro/quick-start.md index f3ac42e56..396672c2e 100644 --- a/docs/latest/intro/quick-start.md +++ b/docs/latest/intro/quick-start.md @@ -16,12 +16,13 @@ ```bash # 克隆稳定版本 -git clone --branch v0.4.3 --depth 1 https://github.com/xerrors/Yuxi-Know.git +git clone --branch v0.5.0-beta --depth 1 https://github.com/xerrors/Yuxi-Know.git cd Yuxi-Know ``` ::: warning 版本说明 -- `v0.4.3`: 稳定版本 +- `v0.4.4`: 稳定版本 +- `v0.5.0-beta`: 由于数据库重构使用 postgres,可能会存在数据库迁移问题,建议新用户使用。 - `main`: 最新开发版本(不稳定,新特性可能会导致新 bug) ::: @@ -106,7 +107,7 @@ docker compose down ## 故障排除 ::: tip 调试面板 -前端有个**调试面板**,长按侧边栏空白处触发,生产环境建议删除此特性,在 `AppLayout.vue` 中注释掉相关代码。 +前端有个**调试面板**,在头像选项里,生产环境建议删除此特性,在 `AppLayout.vue` 中注释掉相关代码。 ::: #### 查看服务状态 diff --git a/scripts/migrate_all.py b/scripts/migrate_all.py new file mode 100644 index 000000000..c80e7cf98 --- /dev/null +++ b/scripts/migrate_all.py @@ -0,0 +1,1471 @@ +""" +统一数据迁移脚本 + +功能: +- 阶段化执行迁移,支持单独运行某个阶段 +- 详细的日志输出和进度追踪 +- 支持预览、回滚和验证 +- 数据完整性检查 + +使用方式: + # 预览所有迁移 + python scripts/migrate_all.py --dry-run + + # 执行所有迁移 + python scripts/migrate_all.py --execute + + # 只迁移业务数据 (SQLite -> PostgreSQL) + python scripts/migrate_all.py --execute --stage business + + # 只迁移知识库元数据 (JSON -> PostgreSQL) + python scripts/migrate_all.py --execute --stage knowledge + + # 验证迁移结果 + python scripts/migrate_all.py --verify + + # 回滚所有迁移 + python scripts/migrate_all.py --rollback + + # 回滚指定阶段 + python scripts/migrate_all.py --rollback --stage business +""" + +import argparse +import asyncio +import glob +import json +import os +import sys +from dataclasses import dataclass, field +from datetime import datetime, UTC +from typing import Any, Callable + +# 确保路径正确 +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +os.environ.setdefault("YUXI_SKIP_APP_INIT", "1") + +from sqlalchemy import Column, DateTime, Integer, String, Text, create_engine, select, text +from sqlalchemy.orm import declarative_base, sessionmaker + +from src import config +from 
src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import ( + Department, + User, + Conversation, + Message, + ToolCall, + ConversationStats, + OperationLog, + MessageFeedback, + MCPServer, + AgentConfig, + TaskRecord, +) +from src.storage.postgres.models_knowledge import ( + KnowledgeBase, + KnowledgeFile, + EvaluationBenchmark, + EvaluationResult, + EvaluationResultDetail, +) +from src.utils import logger + + +# ============================================================ +# 迁移阶段定义 +# ============================================================ + +@dataclass +class MigrationStage: + """迁移阶段""" + name: str # 阶段名称 + description: str # 阶段描述 + migrate_fn: Callable # 迁移函数 + rollback_fn: Callable | None = None # 回滚函数 + verify_fn: Callable | None = None # 验证函数 + depends_on: list[str] = field(default_factory=list) # 依赖阶段 + + +@dataclass +class MigrationResult: + """迁移结果""" + stage_name: str + success: bool + dry_run: bool + records_total: int = 0 + records_migrated: int = 0 + records_skipped: int = 0 + error: str | None = None + duration_ms: float = 0.0 + + +# ============================================================ +# SQLite 模型定义 (仅用于迁移) +# ============================================================ + +SQLiteBase = declarative_base() + + +class SqliteDepartment(SQLiteBase): + __tablename__ = "departments" + + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + description = Column(Text) + created_at = Column(DateTime) + + +class SqliteUser(SQLiteBase): + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + username = Column(String(50), unique=True, nullable=False) + user_id = Column(String(50), unique=True) + phone_number = Column(String(20)) + avatar = Column(String(500)) + password_hash = Column(String(255)) + role = Column(String(20), default="user") + department_id = Column(Integer) + created_at = Column(DateTime) + last_login = Column(DateTime) + login_failed_count = 
Column(Integer, default=0) + last_failed_login = Column(DateTime) + login_locked_until = Column(DateTime) + is_deleted = Column(Integer, default=0) + deleted_at = Column(DateTime) + + +class SqliteConversation(SQLiteBase): + __tablename__ = "conversations" + + id = Column(Integer, primary_key=True) + thread_id = Column(String(50), unique=True) + user_id = Column(String(64), nullable=False) + agent_id = Column(String(50)) + title = Column(String(255)) + status = Column(String(20), default="active") + created_at = Column(DateTime) + updated_at = Column(DateTime) + extra_metadata = Column(Text) + + +class SqliteMessage(SQLiteBase): + __tablename__ = "messages" + + id = Column(Integer, primary_key=True) + conversation_id = Column(Integer, nullable=False) + role = Column(String(20), nullable=False) + content = Column(Text) + message_type = Column(String(20), default="text") + created_at = Column(DateTime) + token_count = Column(Integer) + extra_metadata = Column(Text) + image_content = Column(Text) + + +class SqliteToolCall(SQLiteBase): + __tablename__ = "tool_calls" + + id = Column(Integer, primary_key=True) + message_id = Column(Integer, nullable=False) + langgraph_tool_call_id = Column(String(100)) + tool_name = Column(String(100)) + tool_input = Column(Text) + tool_output = Column(Text) + status = Column(String(20), default="pending") + error_message = Column(Text) + created_at = Column(DateTime) + + +class SqliteConversationStats(SQLiteBase): + __tablename__ = "conversation_stats" + + id = Column(Integer, primary_key=True) + conversation_id = Column(Integer, nullable=False) + message_count = Column(Integer, default=0) + total_tokens = Column(Integer, default=0) + model_used = Column(String(100)) + user_feedback = Column(String(20)) + created_at = Column(DateTime) + updated_at = Column(DateTime) + + +class SqliteOperationLog(SQLiteBase): + __tablename__ = "operation_logs" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer) + operation = 
Column(String(100)) + details = Column(Text) + ip_address = Column(String(50)) + timestamp = Column(DateTime) + + +class SqliteMessageFeedback(SQLiteBase): + __tablename__ = "message_feedbacks" + + id = Column(Integer, primary_key=True) + message_id = Column(Integer, nullable=False) + user_id = Column(String(64), nullable=False) + rating = Column(String(20)) + reason = Column(Text) + created_at = Column(DateTime) + + +class SqliteMCPServer(SQLiteBase): + __tablename__ = "mcp_servers" + + # 注意:SQLite 中 mcp_servers 表没有 id 列,主键是 name + name = Column(String(100), unique=True, nullable=False, primary_key=True) + description = Column(Text) + transport = Column(String(20), default="sse") + url = Column(String(500)) + command = Column(String(255)) + args = Column(Text) + headers = Column(Text) + timeout = Column(Integer) + sse_read_timeout = Column(Integer) + tags = Column(Text) + icon = Column(String(500)) + enabled = Column(Integer, default=1) + disabled_tools = Column(Text) + created_by = Column(String(100), nullable=False) + updated_by = Column(String(100), nullable=False) + created_at = Column(DateTime) + updated_at = Column(DateTime) + + +# ============================================================ +# 工具函数 +# ============================================================ + +def _utc_dt(value: Any) -> datetime | None: + """转换各种 datetime 格式为 naive UTC datetime""" + if not value: + return None + if isinstance(value, datetime): + if value.tzinfo is None: + return value + return value.astimezone(UTC).replace(tzinfo=None) + if isinstance(value, (int, float)): + return datetime.fromtimestamp(value, tz=UTC).replace(tzinfo=None) + if isinstance(value, str): + v = value.strip() + if not v: + return None + try: + dt_val = datetime.fromisoformat(v.replace("Z", "+00:00")) + if dt_val.tzinfo is None: + return dt_val + return dt_val.astimezone(UTC).replace(tzinfo=None) + except ValueError: + return None + return None + + +def _load_json(path: str) -> dict[str, Any]: + """加载 JSON 
文件""" + if not os.path.exists(path): + return {} + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def _parse_json(value: Any) -> Any: + """解析 JSON 字符串或返回原值""" + if value is None: + return None + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + return None + return value + + +def _log_separator(title: str = "", char: str = "=", width: int = 60) -> str: + """生成分隔线""" + if title: + return f"{char * ((width - len(title) - 2) // 2)} {title} {char * ((width - len(title) - 2) // 2)}" + return char * width + + +# ============================================================ +# SQLite 读取器 +# ============================================================ + +class SQLiteReader: + """SQLite 数据读取器""" + + def __init__(self): + db_path = os.path.join(config.save_dir, "database", "server.db") + if not os.path.exists(db_path): + raise FileNotFoundError(f"SQLite 数据库不存在: {db_path}") + self.engine = create_engine(f"sqlite:///{db_path}") + self.Session = sessionmaker(bind=self.engine) + + def get_session(self): + return self.Session() + + def count_table(self, table_name: str) -> int: + with self.get_session() as session: + result = session.execute(text(f"SELECT COUNT(*) FROM {table_name}")) + return result.scalar() or 0 + + def read_all(self, model): + with self.get_session() as session: + return session.execute(select(model)).scalars().all() + + +# ============================================================ +# 迁移阶段实现 +# ============================================================ + +class MigrationRunner: + """迁移运行器""" + + def __init__(self, dry_run: bool = False): + self.dry_run = dry_run + self.results: list[MigrationResult] = [] + self.start_time: datetime | None = None + + def log(self, message: str, level: str = "INFO"): + """带时间戳的日志输出""" + now = datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S") + prefix = { + "INFO": "ℹ️", + "WARN": "⚠️", + "ERROR": "❌", + "SUCCESS": "✅", + 
"STAGE": "🔄", + }.get(level, "ℹ️") + logger.info(f"[{now}] {prefix} {message}") + + async def run_stage(self, stage: MigrationStage) -> MigrationResult: + """执行单个迁移阶段""" + start = datetime.now() + result = MigrationResult(stage_name=stage.name, success=False, dry_run=self.dry_run) + + self.log(_log_separator(f"阶段: {stage.name}")) + self.log(stage.description) + + try: + if self.dry_run: + self.log("[DRY-RUN] 预览模式,跳过实际迁移") + result.success = True + else: + result = await stage.migrate_fn(result) + result.success = True + + except Exception as e: + result.error = str(e) + result.success = False + self.log(f"迁移失败: {e}", level="ERROR") + + result.duration_ms = (datetime.now() - start).total_seconds() * 1000 + self.results.append(result) + + # 输出阶段结果 + status = "✅ 成功" if result.success else "❌ 失败" + self.log(f"阶段完成: {status} ({result.duration_ms:.1f}ms)") + if result.records_total > 0: + self.log(f" 记录: {result.records_migrated}/{result.records_total} 迁移, {result.records_skipped} 跳过") + + return result + + # ----- 业务数据迁移阶段 ----- + + async def migrate_business_departments(self, result: MigrationResult) -> MigrationResult: + """迁移部门数据""" + sqlite_reader = SQLiteReader() + sqlite_depts = sqlite_reader.read_all(SqliteDepartment) + result.records_total = len(sqlite_depts) + + if self.dry_run: + for d in sqlite_depts: + self.log(f"[DRY-RUN] 将创建部门: {d.name}") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_dept in sqlite_depts: + existing = await session.execute(select(Department).where(Department.id == sqlite_dept.id)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + dept = Department( + id=sqlite_dept.id, + name=sqlite_dept.name, + description=sqlite_dept.description, + created_at=_utc_dt(sqlite_dept.created_at), + ) + session.add(dept) + result.records_migrated += 1 + + return result + + async def migrate_business_users(self, result: MigrationResult) -> MigrationResult: + """迁移用户数据""" + 
sqlite_reader = SQLiteReader() + sqlite_users = sqlite_reader.read_all(SqliteUser) + result.records_total = len(sqlite_users) + + if self.dry_run: + for u in sqlite_users: + self.log(f"[DRY-RUN] 将创建用户: {u.username} ({u.user_id})") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_user in sqlite_users: + existing = await session.execute(select(User).where(User.id == sqlite_user.id)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + user = User( + id=sqlite_user.id, + username=sqlite_user.username, + user_id=sqlite_user.user_id, + phone_number=sqlite_user.phone_number, + avatar=sqlite_user.avatar, + password_hash=sqlite_user.password_hash, + role=sqlite_user.role, + department_id=sqlite_user.department_id, + created_at=_utc_dt(sqlite_user.created_at), + last_login=_utc_dt(sqlite_user.last_login), + login_failed_count=sqlite_user.login_failed_count, + last_failed_login=_utc_dt(sqlite_user.last_failed_login), + login_locked_until=_utc_dt(sqlite_user.login_locked_until), + is_deleted=sqlite_user.is_deleted, + deleted_at=_utc_dt(sqlite_user.deleted_at), + ) + session.add(user) + result.records_migrated += 1 + + return result + + async def migrate_business_conversations(self, result: MigrationResult) -> MigrationResult: + """迁移对话数据""" + sqlite_reader = SQLiteReader() + sqlite_convs = sqlite_reader.read_all(SqliteConversation) + result.records_total = len(sqlite_convs) + + if self.dry_run: + for c in sqlite_convs: + self.log(f"[DRY-RUN] 将创建对话: {c.thread_id}") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_conv in sqlite_convs: + existing = await session.execute(select(Conversation).where(Conversation.id == sqlite_conv.id)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + title = sqlite_conv.title + if title and len(title) > 255: + title = title[:255] + conv = Conversation( + id=sqlite_conv.id, + 
thread_id=sqlite_conv.thread_id, + user_id=sqlite_conv.user_id, + agent_id=sqlite_conv.agent_id, + title=title, + status=sqlite_conv.status, + created_at=_utc_dt(sqlite_conv.created_at), + updated_at=_utc_dt(sqlite_conv.updated_at), + extra_metadata=_parse_json(sqlite_conv.extra_metadata), + ) + session.add(conv) + result.records_migrated += 1 + + return result + + async def migrate_business_messages(self, result: MigrationResult) -> MigrationResult: + """迁移消息数据""" + sqlite_reader = SQLiteReader() + sqlite_msgs = sqlite_reader.read_all(SqliteMessage) + result.records_total = len(sqlite_msgs) + + if self.dry_run: + self.log(f"[DRY-RUN] 将创建 {len(sqlite_msgs)} 条消息") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_msg in sqlite_msgs: + existing = await session.execute(select(Message).where(Message.id == sqlite_msg.id)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + msg = Message( + id=sqlite_msg.id, + conversation_id=sqlite_msg.conversation_id, + role=sqlite_msg.role, + content=sqlite_msg.content, + message_type=sqlite_msg.message_type, + created_at=_utc_dt(sqlite_msg.created_at), + token_count=sqlite_msg.token_count, + extra_metadata=_parse_json(sqlite_msg.extra_metadata), + image_content=sqlite_msg.image_content, + ) + session.add(msg) + result.records_migrated += 1 + + return result + + async def migrate_business_tool_calls(self, result: MigrationResult) -> MigrationResult: + """迁移工具调用数据""" + sqlite_reader = SQLiteReader() + sqlite_calls = sqlite_reader.read_all(SqliteToolCall) + result.records_total = len(sqlite_calls) + + if self.dry_run: + self.log(f"[DRY-RUN] 将创建 {len(sqlite_calls)} 个工具调用") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_call in sqlite_calls: + existing = await session.execute(select(ToolCall).where(ToolCall.id == sqlite_call.id)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + call = 
ToolCall( + id=sqlite_call.id, + message_id=sqlite_call.message_id, + langgraph_tool_call_id=sqlite_call.langgraph_tool_call_id, + tool_name=sqlite_call.tool_name, + tool_input=_parse_json(sqlite_call.tool_input), + tool_output=sqlite_call.tool_output, + status=sqlite_call.status, + error_message=sqlite_call.error_message, + created_at=_utc_dt(sqlite_call.created_at), + ) + session.add(call) + result.records_migrated += 1 + + return result + + async def migrate_business_stats(self, result: MigrationResult) -> MigrationResult: + """迁移对话统计数据""" + sqlite_reader = SQLiteReader() + sqlite_stats = sqlite_reader.read_all(SqliteConversationStats) + result.records_total = len(sqlite_stats) + + if self.dry_run: + self.log(f"[DRY-RUN] 将创建 {len(sqlite_stats)} 条对话统计") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_stat in sqlite_stats: + existing = await session.execute( + select(ConversationStats).where(ConversationStats.id == sqlite_stat.id) + ) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + stat = ConversationStats( + id=sqlite_stat.id, + conversation_id=sqlite_stat.conversation_id, + message_count=sqlite_stat.message_count, + total_tokens=sqlite_stat.total_tokens, + model_used=sqlite_stat.model_used, + user_feedback=sqlite_stat.user_feedback, + created_at=_utc_dt(sqlite_stat.created_at), + updated_at=_utc_dt(sqlite_stat.updated_at), + ) + session.add(stat) + result.records_migrated += 1 + + return result + + async def migrate_business_operation_logs(self, result: MigrationResult) -> MigrationResult: + """迁移操作日志""" + sqlite_reader = SQLiteReader() + sqlite_logs = sqlite_reader.read_all(SqliteOperationLog) + result.records_total = len(sqlite_logs) + + if self.dry_run: + self.log(f"[DRY-RUN] 将创建 {len(sqlite_logs)} 条操作日志") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_log in sqlite_logs: + existing = await 
session.execute(select(OperationLog).where(OperationLog.id == sqlite_log.id)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + log = OperationLog( + id=sqlite_log.id, + user_id=sqlite_log.user_id, + operation=sqlite_log.operation, + details=sqlite_log.details, + ip_address=sqlite_log.ip_address, + timestamp=_utc_dt(sqlite_log.timestamp), + ) + session.add(log) + result.records_migrated += 1 + + return result + + async def migrate_business_feedbacks(self, result: MigrationResult) -> MigrationResult: + """迁移消息反馈""" + sqlite_reader = SQLiteReader() + sqlite_fbs = sqlite_reader.read_all(SqliteMessageFeedback) + result.records_total = len(sqlite_fbs) + + if self.dry_run: + self.log(f"[DRY-RUN] 将创建 {len(sqlite_fbs)} 条消息反馈") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_fb in sqlite_fbs: + existing = await session.execute(select(MessageFeedback).where(MessageFeedback.id == sqlite_fb.id)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + fb = MessageFeedback( + id=sqlite_fb.id, + message_id=sqlite_fb.message_id, + user_id=sqlite_fb.user_id, + rating=sqlite_fb.rating, + reason=sqlite_fb.reason, + created_at=_utc_dt(sqlite_fb.created_at), + ) + session.add(fb) + result.records_migrated += 1 + + return result + + async def migrate_business_mcp_servers(self, result: MigrationResult) -> MigrationResult: + """迁移 MCP 服务器""" + sqlite_reader = SQLiteReader() + sqlite_servers = sqlite_reader.read_all(SqliteMCPServer) + result.records_total = len(sqlite_servers) + + if self.dry_run: + for s in sqlite_servers: + self.log(f"[DRY-RUN] 将创建 MCP 服务器: {s.name}") + return result + + async with pg_manager.get_async_session_context() as session: + for sqlite_server in sqlite_servers: + existing = await session.execute(select(MCPServer).where(MCPServer.name == sqlite_server.name)) + if existing.scalar_one_or_none(): + result.records_skipped += 1 + continue + server = MCPServer( + 
name=sqlite_server.name, + description=sqlite_server.description, + transport=sqlite_server.transport, + url=sqlite_server.url, + command=sqlite_server.command, + args=sqlite_server.args, + headers=sqlite_server.headers, + timeout=sqlite_server.timeout, + sse_read_timeout=sqlite_server.sse_read_timeout, + tags=sqlite_server.tags, + icon=sqlite_server.icon, + enabled=sqlite_server.enabled, + disabled_tools=sqlite_server.disabled_tools, + created_by=sqlite_server.created_by, + updated_by=sqlite_server.updated_by, + created_at=_utc_dt(sqlite_server.created_at), + updated_at=_utc_dt(sqlite_server.updated_at), + ) + session.add(server) + result.records_migrated += 1 + + return result + + # ----- 知识库迁移阶段 ----- + + async def migrate_knowledge_bases(self, result: MigrationResult) -> MigrationResult: + """迁移知识库""" + base_dir = os.path.join(config.save_dir, "knowledge_base_data") + global_meta_path = os.path.join(base_dir, "global_metadata.json") + global_meta = _load_json(global_meta_path).get("databases", {}) + + kb_rows = [] + kb_type_dirs = [ + p for p in glob.glob(os.path.join(base_dir, "*_data")) + if os.path.isdir(p) and os.path.basename(p) != "uploads" + ] + + for kb_dir in kb_type_dirs: + kb_type = os.path.basename(kb_dir)[: -len("_data")] + meta_file = os.path.join(kb_dir, f"metadata_{kb_type}.json") + meta = _load_json(meta_file) + databases_meta = meta.get("databases", {}) + + for db_id, db_meta in databases_meta.items(): + g = global_meta.get(db_id, {}) + created_at = _utc_dt(g.get("created_at") or db_meta.get("created_at")) + updated_at = _utc_dt(g.get("updated_at")) or created_at + kb_rows.append({ + "db_id": db_id, + "name": g.get("name") or db_meta.get("name") or db_id, + "description": g.get("description") or db_meta.get("description"), + "kb_type": g.get("kb_type") or db_meta.get("kb_type") or kb_type, + "embed_info": db_meta.get("embed_info") or g.get("embed_info"), + "llm_info": db_meta.get("llm_info") or g.get("llm_info"), + "query_params": 
db_meta.get("query_params") or g.get("query_params"), + "additional_params": g.get("additional_params") or db_meta.get("metadata") or {}, + "share_config": {"is_shared": True, "accessible_departments": []}, + "mindmap": g.get("mindmap"), + "sample_questions": g.get("sample_questions") or [], + "created_at": created_at, + "updated_at": updated_at, + }) + + result.records_total = len(kb_rows) + + if self.dry_run: + for kb in kb_rows: + self.log(f"[DRY-RUN] 将创建知识库: {kb['name']} ({kb['db_id']})") + return result + + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + kb_repo = KnowledgeBaseRepository() + + for payload in kb_rows: + db_id = payload["db_id"] + existing = await kb_repo.get_by_id(db_id) + if existing: + result.records_skipped += 1 + continue + await kb_repo.create(payload) + result.records_migrated += 1 + + return result + + async def migrate_knowledge_files(self, result: MigrationResult) -> MigrationResult: + """迁移知识文件""" + base_dir = os.path.join(config.save_dir, "knowledge_base_data") + + file_rows = [] + kb_type_dirs = [ + p for p in glob.glob(os.path.join(base_dir, "*_data")) + if os.path.isdir(p) and os.path.basename(p) != "uploads" + ] + + for kb_dir in kb_type_dirs: + meta_file = os.path.join(kb_dir, f"metadata_{os.path.basename(kb_dir)[:-5]}.json") + meta = _load_json(meta_file) + files_meta = meta.get("files", {}) + + for file_id, fmeta in files_meta.items(): + db_id = fmeta.get("database_id") + if not db_id: + continue + file_rows.append({ + "file_id": file_id, + "db_id": db_id, + "parent_id": fmeta.get("parent_id"), + "filename": fmeta.get("filename") or "", + "original_filename": fmeta.get("original_filename") or fmeta.get("file_name"), + "file_type": fmeta.get("file_type") or fmeta.get("type"), + "path": fmeta.get("path"), + "minio_url": fmeta.get("minio_url"), + "markdown_file": fmeta.get("markdown_file"), + "status": fmeta.get("status"), + "content_hash": fmeta.get("content_hash"), + "file_size": 
fmeta.get("size") or fmeta.get("file_size"), + "content_type": fmeta.get("content_type"), + "processing_params": fmeta.get("processing_params"), + "is_folder": bool(fmeta.get("is_folder", False)), + "error_message": fmeta.get("error") or fmeta.get("error_message"), + "created_by": str(fmeta.get("created_by")) if fmeta.get("created_by") else None, + "updated_by": str(fmeta.get("updated_by")) if fmeta.get("updated_by") else None, + "created_at": _utc_dt(fmeta.get("created_at")), + "updated_at": _utc_dt(fmeta.get("updated_at")) or _utc_dt(fmeta.get("created_at")), + }) + + result.records_total = len(file_rows) + + if self.dry_run: + folders = [f for f in file_rows if f["is_folder"]] + files = [f for f in file_rows if not f["is_folder"]] + self.log(f"[DRY-RUN] 将创建 {len(folders)} 个文件夹和 {len(files)} 个文件") + return result + + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + file_repo = KnowledgeFileRepository() + + # 先插入文件夹 + folders = [(f["file_id"], f) for f in file_rows if f["is_folder"]] + files = [(f["file_id"], f) for f in file_rows if not f["is_folder"]] + + for file_id, data in folders: + data = data.copy() + data.pop("file_id", None) # 移除重复的 file_id + await file_repo.upsert(file_id=file_id, data=data) + result.records_migrated += 1 + + for file_id, data in files: + data = data.copy() + data.pop("file_id", None) # 移除重复的 file_id + await file_repo.upsert(file_id=file_id, data=data) + result.records_migrated += 1 + + return result + + async def migrate_knowledge_evaluations(self, result: MigrationResult) -> MigrationResult: + """迁移评估数据""" + base_dir = os.path.join(config.save_dir, "knowledge_base_data") + total_migrated = 0 + + kb_type_dirs = [ + p for p in glob.glob(os.path.join(base_dir, "*_data")) + if os.path.isdir(p) and os.path.basename(p) != "uploads" + ] + + from src.repositories.evaluation_repository import EvaluationRepository + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + eval_repo = 
EvaluationRepository() + kb_repo = KnowledgeBaseRepository() + + # 迁移评估基准 + benchmark_rows = [] + for kb_dir in kb_type_dirs: + kb_type = os.path.basename(kb_dir)[: -len("_data")] + meta_file = os.path.join(kb_dir, f"metadata_{kb_type}.json") + meta = _load_json(meta_file) + benchmarks_meta = meta.get("benchmarks", {}) + + for db_id, bmap in benchmarks_meta.items(): + if not isinstance(bmap, dict): + continue + for benchmark_id, bmeta in bmap.items(): + benchmark_rows.append({ + "benchmark_id": benchmark_id, + "db_id": db_id, + "name": bmeta.get("name") or benchmark_id, + "description": bmeta.get("description"), + "question_count": int(bmeta.get("question_count") or 0), + "has_gold_chunks": bool(bmeta.get("has_gold_chunks")), + "has_gold_answers": bool(bmeta.get("has_gold_answers")), + "data_file_path": bmeta.get("benchmark_file") or bmeta.get("data_file_path"), + "created_by": str(bmeta.get("created_by")) if bmeta.get("created_by") else None, + "created_at": _utc_dt(bmeta.get("created_at")), + "updated_at": _utc_dt(bmeta.get("updated_at")) or _utc_dt(bmeta.get("created_at")), + }) + + result.records_total += len(benchmark_rows) + + if self.dry_run: + self.log(f"[DRY-RUN] 将创建 {len(benchmark_rows)} 个评估基准") + return result + + for payload in benchmark_rows: + existing = await eval_repo.get_benchmark(payload["benchmark_id"]) + if existing: + result.records_skipped += 1 + continue + # 检查知识库是否存在 + kb = await kb_repo.get_by_id(payload["db_id"]) + if kb is None: + self.log(f" 跳过评估基准 {payload['benchmark_id']}: 知识库 {payload['db_id']} 不存在") + result.records_skipped += 1 + continue + await eval_repo.create_benchmark(payload) + total_migrated += 1 + + # 迁移评估结果 + result_rows = [] + result_detail_rows = [] + + for kb_dir in kb_type_dirs: + kb_type = os.path.basename(kb_dir)[: -len("_data")] + meta_file = os.path.join(kb_dir, f"metadata_{kb_type}.json") + meta = _load_json(meta_file) + databases_meta = meta.get("databases", {}) + + for db_id in databases_meta.keys(): + result_dir 
= os.path.join(kb_dir, db_id, "results") + if not os.path.isdir(result_dir): + continue + for result_path in glob.glob(os.path.join(result_dir, "*.json")): + try: + data = _load_json(result_path) + except Exception: + continue + task_id = data.get("task_id") or os.path.splitext(os.path.basename(result_path))[0] + benchmark_id = data.get("benchmark_id") + started_at = _utc_dt(data.get("started_at")) + result_rows.append({ + "task_id": task_id, + "db_id": db_id, + "benchmark_id": benchmark_id, + "status": data.get("status") or "completed", + "retrieval_config": data.get("retrieval_config") or {}, + "metrics": data.get("metrics") or {}, + "overall_score": data.get("overall_score"), + "total_questions": int(data.get("total_questions") or 0), + "completed_questions": int(data.get("completed_questions") or 0), + "started_at": started_at, + "completed_at": _utc_dt(data.get("completed_at")) or started_at, + "created_by": str(data.get("created_by")) if data.get("created_by") else None, + }) + interim = data.get("interim_results") or data.get("results") or [] + for idx, item in enumerate(interim): + result_detail_rows.append({ + "task_id": task_id, + "query_index": idx, + "query_text": item.get("query") or item.get("query_text") or "", + "gold_chunk_ids": item.get("gold_chunk_ids"), + "gold_answer": item.get("gold_answer"), + "generated_answer": item.get("generated_answer"), + "retrieved_chunks": item.get("retrieved_chunks"), + "metrics": item.get("metrics") or {}, + }) + + result.records_total += len(result_rows) + len(result_detail_rows) + + if self.dry_run: + self.log(f"[DRY-RUN] 将创建 {len(result_rows)} 个评估结果和 {len(result_detail_rows)} 条详情") + return result + + for payload in result_rows: + existing = await eval_repo.get_result(payload["task_id"]) + if existing: + result.records_skipped += 1 + continue + # 检查知识库是否存在 + kb = await kb_repo.get_by_id(payload["db_id"]) + if kb is None: + self.log(f" 跳过评估结果 {payload['task_id']}: 知识库 {payload['db_id']} 不存在") + 
result.records_skipped += 1 + continue + await eval_repo.create_result(payload) + total_migrated += 1 + + for detail in result_detail_rows: + await eval_repo.upsert_result_detail( + task_id=detail["task_id"], + query_index=detail["query_index"], + data={ + "query_text": detail["query_text"], + "gold_chunk_ids": detail["gold_chunk_ids"], + "gold_answer": detail["gold_answer"], + "generated_answer": detail["generated_answer"], + "retrieved_chunks": detail["retrieved_chunks"], + "metrics": detail["metrics"], + }, + ) + total_migrated += 1 + + result.records_migrated = total_migrated + return result + + async def migrate_knowledge_tasks(self, result: MigrationResult) -> MigrationResult: + """迁移任务记录""" + tasks_json_path = os.path.join(config.save_dir, "tasks", "tasks.json") + task_items = _load_json(tasks_json_path).get("tasks", []) or [] + result.records_total = len(task_items) + + if self.dry_run: + self.log(f"[DRY-RUN] 将迁移 {len(task_items)} 个任务记录") + return result + + from src.repositories.task_repository import TaskRepository + task_repo = TaskRepository() + + for item in task_items: + task_id = item.get("id") + if not task_id: + continue + payload = item.get("payload") or {} + await task_repo.upsert( + task_id, + { + "name": item.get("name") or "Unnamed Task", + "type": item.get("type") or "general", + "status": item.get("status") or "pending", + "progress": float(item.get("progress") or 0.0), + "message": item.get("message") or "", + "payload": payload, + "result": item.get("result"), + "error": item.get("error"), + "cancel_requested": 1 if item.get("cancel_requested") else 0, + "created_at": _utc_dt(item.get("created_at")), + "updated_at": _utc_dt(item.get("updated_at")) or _utc_dt(item.get("created_at")), + "started_at": _utc_dt(item.get("started_at")), + "completed_at": _utc_dt(item.get("completed_at")), + }, + ) + result.records_migrated += 1 + + return result + + # ----- 回滚函数 ----- + + async def rollback_business(self) -> None: + """回滚业务数据""" + 
self.log(_log_separator("回滚: 业务数据"), level="WARN") + + tables = [ + MessageFeedback, + OperationLog, + ConversationStats, + ToolCall, + Message, + Conversation, + User, + Department, + MCPServer, + ] + + for model in tables: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(model)) + records = result.scalars().all() + for record in records: + await session.delete(record) + self.log(f" 已删除 {len(records)} 条 {model.__tablename__}") + + async def reset_sequences(self) -> None: + """重置 PostgreSQL 序列值,防止主键冲突 + + 迁移时直接使用了 SQLite 的原始 id 值,导致 PostgreSQL 的序列未同步。 + 此方法将序列值重置为当前最大 id + 1。 + """ + self.log(_log_separator("重置: PostgreSQL 序列"), level="WARN") + + tables_with_sequences = [ + ("departments", "id"), + ("users", "id"), + ("conversations", "id"), + ("messages", "id"), + ("tool_calls", "id"), + ("conversation_stats", "id"), + ("operation_logs", "id"), + ("message_feedbacks", "id"), + ("mcp_servers", None), # name 是主键,不是 serial + ("knowledge_bases", "id"), + ("knowledge_files", "id"), + ("evaluation_benchmarks", "id"), + ("evaluation_results", "id"), + ("evaluation_result_details", "id"), + ] + + async with pg_manager.get_async_session_context() as session: + for table_name, pk_column in tables_with_sequences: + if pk_column is None: + continue # 非自增主键,跳过 + try: + # 使用单条 SQL 获取 max_id 并重置序列 + await session.execute( + text(f""" + SELECT setval( + pg_get_serial_sequence('{table_name}', '{pk_column}'), + COALESCE((SELECT MAX({pk_column}) FROM {table_name}), 0) + 1 + ) + """) + ) + self.log(f" {table_name}: 序列已重置") + except Exception as e: + self.log(f" {table_name}: 重置序列失败 - {e}", level="WARN") + + async def rollback_knowledge(self) -> None: + """回滚知识库数据""" + self.log(_log_separator("回滚: 知识库数据"), level="WARN") + + from src.repositories.evaluation_repository import EvaluationRepository + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + from src.repositories.knowledge_file_repository import 
KnowledgeFileRepository + + eval_repo = EvaluationRepository() + kb_repo = KnowledgeBaseRepository() + file_repo = KnowledgeFileRepository() + + # 回滚顺序:子表 -> 父表 + await eval_repo.delete_all() + self.log(" 已删除所有评估数据") + + rows = await kb_repo.get_all() + for row in rows: + await file_repo.delete_by_db_id(row.db_id) + await kb_repo.delete(row.db_id) + self.log(f" 已删除 {len(rows)} 个知识库及其文件") + + async def rollback_tasker(self) -> None: + """回滚 Tasker 任务记录""" + self.log(_log_separator("回滚: Tasker 任务记录"), level="WARN") + + from src.repositories.task_repository import TaskRepository + + task_repo = TaskRepository() + await task_repo.delete_all() + self.log(" 已删除所有任务记录") + + # ----- 验证函数 ----- + + async def verify_business(self) -> dict: + """验证业务数据""" + self.log(_log_separator("验证: 业务数据")) + results = {} + + try: + sqlite_reader = SQLiteReader() + except FileNotFoundError: + self.log("SQLite 数据库不存在,跳过验证", level="WARN") + return {} + + sqlite_tables = { + "departments": SqliteDepartment, + "users": SqliteUser, + "conversations": SqliteConversation, + "messages": SqliteMessage, + "tool_calls": SqliteToolCall, + "conversation_stats": SqliteConversationStats, + "operation_logs": SqliteOperationLog, + "message_feedbacks": SqliteMessageFeedback, + "mcp_servers": SqliteMCPServer, + } + + pg_models = { + "departments": Department, + "users": User, + "conversations": Conversation, + "messages": Message, + "tool_calls": ToolCall, + "conversation_stats": ConversationStats, + "operation_logs": OperationLog, + "message_feedbacks": MessageFeedback, + "mcp_servers": (MCPServer, "name"), # MCPServer 主键是 name + } + + for table_name, sqlite_model in sqlite_tables.items(): + sqlite_count = sqlite_reader.count_table(table_name) + + pg_model_info = pg_models[table_name] + # 支持 (Model, pk_column) 元组形式 + if isinstance(pg_model_info, tuple): + pg_model, pk_column = pg_model_info + else: + pg_model, pk_column = pg_model_info, "id" + + async with pg_manager.get_async_session_context() as session: 
+ from sqlalchemy import func + result = await session.execute(select(func.count(getattr(pg_model, pk_column)))) + pg_count = result.scalar() or 0 + + match = sqlite_count == pg_count + status = "✅" if match else "❌" + results[table_name] = {"sqlite": sqlite_count, "pg": pg_count, "match": match} + self.log(f" {status} {table_name}: SQLite={sqlite_count}, PG={pg_count}") + + return results + + async def verify_knowledge(self) -> dict: + """验证知识库数据""" + self.log(_log_separator("验证: 知识库数据")) + results = {} + + base_dir = os.path.join(config.save_dir, "knowledge_base_data") + + # 统计 JSON 文件中的数据 + json_kb_count = 0 + json_file_count = 0 + + kb_type_dirs = [ + p for p in glob.glob(os.path.join(base_dir, "*_data")) + if os.path.isdir(p) and os.path.basename(p) != "uploads" + ] + + for kb_dir in kb_type_dirs: + kb_type = os.path.basename(kb_dir)[: -len("_data")] + meta_file = os.path.join(kb_dir, f"metadata_{kb_type}.json") + meta = _load_json(meta_file) + json_kb_count += len(meta.get("databases", {})) + json_file_count += len(meta.get("files", {})) + + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + + kb_repo = KnowledgeBaseRepository() + file_repo = KnowledgeFileRepository() + + pg_kb_count = len(await kb_repo.get_all()) + # 统计文件数量 + all_files = [] + rows = await kb_repo.get_all() + for row in rows: + files = await file_repo.list_by_db_id(row.db_id) + all_files.extend(files) + + pg_file_count = len(all_files) + + results["knowledge_bases"] = {"json": json_kb_count, "pg": pg_kb_count, "match": json_kb_count == pg_kb_count} + results["knowledge_files"] = {"json": json_file_count, "pg": pg_file_count, "match": json_file_count == pg_file_count} + + status_kb = "✅" if results["knowledge_bases"]["match"] else "❌" + status_file = "✅" if results["knowledge_files"]["match"] else "❌" + + self.log(f" {status_kb} knowledge_bases: JSON={json_kb_count}, 
PG={pg_kb_count}") + self.log(f" {status_file} knowledge_files: JSON={json_file_count}, PG={pg_file_count}") + + return results + + +# ============================================================ +# 阶段定义 +# ============================================================ + +def get_stages() -> dict[str, MigrationStage]: + """获取所有迁移阶段""" + runner = MigrationRunner() + + return { + # 业务数据阶段 (按外键依赖顺序) + "business-departments": MigrationStage( + name="business-departments", + description="迁移部门数据 (departments)", + migrate_fn=runner.migrate_business_departments, + rollback_fn=None, # 依赖业务回滚整体处理 + ), + "business-users": MigrationStage( + name="business-users", + description="迁移用户数据 (users),依赖 departments", + migrate_fn=runner.migrate_business_users, + depends_on=["business-departments"], + ), + "business-conversations": MigrationStage( + name="business-conversations", + description="迁移对话数据 (conversations)", + migrate_fn=runner.migrate_business_conversations, + depends_on=["business-users"], + ), + "business-messages": MigrationStage( + name="business-messages", + description="迁移消息数据 (messages),依赖 conversations", + migrate_fn=runner.migrate_business_messages, + depends_on=["business-conversations"], + ), + "business-tool-calls": MigrationStage( + name="business-tool-calls", + description="迁移工具调用数据 (tool_calls),依赖 messages", + migrate_fn=runner.migrate_business_tool_calls, + depends_on=["business-messages"], + ), + "business-stats": MigrationStage( + name="business-stats", + description="迁移对话统计数据 (conversation_stats)", + migrate_fn=runner.migrate_business_stats, + depends_on=["business-conversations"], + ), + "business-operation-logs": MigrationStage( + name="business-operation-logs", + description="迁移操作日志 (operation_logs)", + migrate_fn=runner.migrate_business_operation_logs, + depends_on=["business-users"], + ), + "business-feedbacks": MigrationStage( + name="business-feedbacks", + description="迁移消息反馈 (message_feedbacks)", + migrate_fn=runner.migrate_business_feedbacks, + 
depends_on=["business-messages"], + ), + "business-mcp-servers": MigrationStage( + name="business-mcp-servers", + description="迁移 MCP 服务器配置 (mcp_servers)", + migrate_fn=runner.migrate_business_mcp_servers, + ), + # 知识库阶段 + "knowledge-bases": MigrationStage( + name="knowledge-bases", + description="迁移知识库元数据 (knowledge_bases)", + migrate_fn=runner.migrate_knowledge_bases, + ), + "knowledge-files": MigrationStage( + name="knowledge-files", + description="迁移知识文件元数据 (knowledge_files),依赖 knowledge_bases", + migrate_fn=runner.migrate_knowledge_files, + depends_on=["knowledge-bases"], + ), + "knowledge-evaluations": MigrationStage( + name="knowledge-evaluations", + description="迁移评估数据 (benchmarks, results)", + migrate_fn=runner.migrate_knowledge_evaluations, + depends_on=["knowledge-bases"], + ), + # Tasker 阶段(独立于知识库) + "tasker-tasks": MigrationStage( + name="tasker-tasks", + description="迁移 Tasker 任务记录 (tasks)", + migrate_fn=runner.migrate_knowledge_tasks, + ), + } + + +def get_stage_groups() -> dict[str, list[str]]: + """获取阶段组(批量执行)""" + return { + "business": [ + "business-departments", + "business-users", + "business-conversations", + "business-messages", + "business-tool-calls", + "business-stats", + "business-operation-logs", + "business-feedbacks", + "business-mcp-servers", + ], + "knowledge": [ + "knowledge-bases", + "knowledge-files", + "knowledge-evaluations", + ], + "tasker": [ + "tasker-tasks", + ], + "all": list(get_stages().keys()), + } + + +# ============================================================ +# 主函数 +# ============================================================ + +async def main() -> None: + parser = argparse.ArgumentParser(description="统一数据迁移脚本") + parser.add_argument("--dry-run", action="store_true", help="预览迁移,不执行") + parser.add_argument("--execute", action="store_true", help="执行迁移") + parser.add_argument("--verify", action="store_true", help="验证迁移结果") + parser.add_argument("--rollback", action="store_true", help="回滚迁移") + 
parser.add_argument("--stage", type=str, help="指定阶段或阶段组 (如: business, knowledge, business-users)") + + args = parser.parse_args() + + # 默认dry-run + if not any([args.dry_run, args.execute, args.verify, args.rollback]): + args.dry_run = True + + # 初始化 PostgreSQL + pg_manager.initialize() + await pg_manager.create_tables() + + runner = MigrationRunner(dry_run=args.dry_run) + + # 打印标题 + if args.dry_run: + mode = "预览模式" + elif args.execute: + mode = "执行模式" + elif args.verify: + mode = "验证模式" + else: + mode = "回滚模式" + + print("\n" + "=" * 60) + print(f"🔧 数据迁移工具 | 模式: {mode}") + print("=" * 60) + + if args.verify: + # 验证模式 + results_business = await runner.verify_business() + results_knowledge = await runner.verify_knowledge() + + print("\n" + "=" * 60) + print("📊 验证结果汇总") + print("=" * 60) + + all_match = True + for table_name, counts in {**results_business, **results_knowledge}.items(): + if not counts.get("match", True): + all_match = False + + print(f"全部匹配: {'✅ 是' if all_match else '❌ 否'}") + return + + if args.rollback: + # 回滚模式 + if args.stage == "business" or args.stage is None: + await runner.rollback_business() + if args.stage == "knowledge" or args.stage is None: + await runner.rollback_knowledge() + if args.stage == "tasker" or args.stage is None: + await runner.rollback_tasker() + + if args.stage == "business": + print("\n✅ 已回滚业务数据") + elif args.stage == "knowledge": + print("\n✅ 已回滚知识库数据") + elif args.stage == "tasker": + print("\n✅ 已回滚 Tasker 任务记录") + else: + print("\n✅ 已回滚所有迁移数据") + return + + # 迁移模式 + stages = get_stages() + stage_groups = get_stage_groups() + + # 确定要执行的阶段 + if args.stage and args.stage in stage_groups: + stage_names = stage_groups[args.stage] + elif args.stage and args.stage in stages: + stage_names = [args.stage] + else: + stage_names = stage_groups["all"] + + # 按依赖顺序排序 + sorted_stages = [] + resolved = set() + while sorted_stages.__len__() < len(stage_names): + progress = False + for name in stage_names: + if name in resolved: + 
continue + stage = stages[name] + if all(dep in resolved for dep in stage.depends_on): + sorted_stages.append(name) + resolved.add(name) + progress = True + if not progress: + raise ValueError(f"无法解析依赖: {set(stage_names) - resolved}") + + print(f"\n📋 将执行 {len(sorted_stages)} 个阶段:") + for name in sorted_stages: + print(f" - {name}") + print() + + # 执行迁移 + total_start = datetime.now() + + for stage_name in sorted_stages: + stage = stages[stage_name] + await runner.run_stage(stage) + + # 重置 PostgreSQL 序列,防止后续插入时主键冲突 + if not args.dry_run: + await runner.reset_sequences() + + total_duration = (datetime.now() - total_start).total_seconds() + + # 输出汇总 + print("\n" + "=" * 60) + print("📊 迁移汇总") + print("=" * 60) + + total_migrated = sum(r.records_migrated for r in runner.results) + total_skipped = sum(r.records_skipped for r in runner.results) + failed = [r for r in runner.results if not r.success] + + print(f"总耗时: {total_duration:.1f}s") + print(f"迁移记录: {total_migrated}") + print(f"跳过记录: {total_skipped}") + print(f"失败阶段: {len(failed)}") + + if failed: + print("\n失败详情:") + for r in failed: + print(f" ❌ {r.stage_name}: {r.error}") + + if not args.dry_run: + print("\n💡 建议运行 --verify 验证数据完整性") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/migrate_business_from_sqlite.py b/scripts/migrate_business_from_sqlite.py new file mode 100644 index 000000000..252ee2d0e --- /dev/null +++ b/scripts/migrate_business_from_sqlite.py @@ -0,0 +1,707 @@ +""" +SQLite 到 PostgreSQL 业务数据迁移脚本 + +将用户、部门、对话等业务数据从 SQLite 迁移到 PostgreSQL。 +迁移顺序(按外键依赖): +1. departments (无依赖) +2. users (依赖 departments) +3. conversations (依赖 users) +4. messages (依赖 conversations) +5. tool_calls (依赖 messages) +6. conversation_stats (依赖 conversations) +7. operation_logs (依赖 users) +8. message_feedbacks (依赖 messages) +9. 
mcp_servers (无依赖) + +用法: + python scripts/migrate_business_from_sqlite.py --dry-run # 预览迁移 + python scripts/migrate_business_from_sqlite.py --execute # 执行迁移 + python scripts/migrate_business_from_sqlite.py --verify # 验证数据 + python scripts/migrate_business_from_sqlite.py --rollback # 回滚迁移 +""" + +import argparse +import asyncio +import os +import sys +from datetime import datetime, UTC +from typing import Any + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +os.environ.setdefault("YUXI_SKIP_APP_INIT", "1") + +from sqlalchemy import Column, DateTime, Integer, String, Text, create_engine, select, text +from sqlalchemy.orm import declarative_base, sessionmaker + +from src import config +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import ( + Department, + User, + Conversation, + Message, + ToolCall, + ConversationStats, + OperationLog, + MessageFeedback, + MCPServer, +) +from src.utils import logger + + +# ============================================================ +# SQLite 模型定义 (仅用于迁移脚本,内部使用) +# ============================================================ +Base = declarative_base() + + +class SqliteDepartment(Base): + __tablename__ = "departments" + + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + description = Column(Text) + created_at = Column(DateTime) + + +class SqliteUser(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + username = Column(String(50), unique=True, nullable=False) + user_id = Column(String(50), unique=True) + phone_number = Column(String(20)) + avatar = Column(String(500)) + password_hash = Column(String(255)) + role = Column(String(20), default="user") + department_id = Column(Integer) + created_at = Column(DateTime) + last_login = Column(DateTime) + login_failed_count = Column(Integer, default=0) + last_failed_login = Column(DateTime) + login_locked_until = Column(DateTime) + is_deleted = Column(Integer, 
default=0) # 0=否, 1=是 + deleted_at = Column(DateTime) + + +class SqliteConversation(Base): + __tablename__ = "conversations" + + id = Column(Integer, primary_key=True) + thread_id = Column(String(50), unique=True) + user_id = Column(String(64), nullable=False) + agent_id = Column(String(50)) + title = Column(String(255)) + status = Column(String(20), default="active") + created_at = Column(DateTime) + updated_at = Column(DateTime) + extra_metadata = Column(Text) + + +class SqliteMessage(Base): + __tablename__ = "messages" + + id = Column(Integer, primary_key=True) + conversation_id = Column(Integer, nullable=False) + role = Column(String(20), nullable=False) + content = Column(Text) + message_type = Column(String(20), default="text") + created_at = Column(DateTime) + token_count = Column(Integer) + extra_metadata = Column(Text) + image_content = Column(Text) + + +class SqliteToolCall(Base): + __tablename__ = "tool_calls" + + id = Column(Integer, primary_key=True) + message_id = Column(Integer, nullable=False) + langgraph_tool_call_id = Column(String(100)) + tool_name = Column(String(100)) + tool_input = Column(Text) + tool_output = Column(Text) + status = Column(String(20), default="pending") + error_message = Column(Text) + created_at = Column(DateTime) + + +class SqliteConversationStats(Base): + __tablename__ = "conversation_stats" + + id = Column(Integer, primary_key=True) + conversation_id = Column(Integer, nullable=False) + message_count = Column(Integer, default=0) + total_tokens = Column(Integer, default=0) + model_used = Column(String(100)) + user_feedback = Column(String(20)) + created_at = Column(DateTime) + updated_at = Column(DateTime) + + +class SqliteOperationLog(Base): + __tablename__ = "operation_logs" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer) # 外键到 users.id + operation = Column(String(100)) + details = Column(Text) + ip_address = Column(String(50)) + timestamp = Column(DateTime) + + +class SqliteMessageFeedback(Base): + 
__tablename__ = "message_feedbacks" + + id = Column(Integer, primary_key=True) + message_id = Column(Integer, nullable=False) + user_id = Column(String(64), nullable=False) + rating = Column(String(20)) + reason = Column(Text) + created_at = Column(DateTime) + + +class SqliteMCPServer(Base): + __tablename__ = "mcp_servers" + + id = Column(Integer, primary_key=True) + name = Column(String(100), unique=True, nullable=False) + description = Column(Text) + transport = Column(String(20), default="sse") + url = Column(String(500)) + command = Column(String(255)) + args = Column(Text) + headers = Column(Text) + timeout = Column(Integer) + sse_read_timeout = Column(Integer) + tags = Column(Text) + icon = Column(String(500)) + enabled = Column(Integer, default=1) # 1=是, 0=否 + disabled_tools = Column(Text) + created_by = Column(String(100), nullable=False) # 创建人用户名 + updated_by = Column(String(100), nullable=False) # 修改人用户名 + created_at = Column(DateTime) + updated_at = Column(DateTime) + + +def _utc_dt(value: Any) -> datetime | None: + """Convert various datetime formats to naive UTC datetime.""" + if not value: + return None + if isinstance(value, datetime): + if value.tzinfo is None: + return value + return value.astimezone(UTC).replace(tzinfo=None) + if isinstance(value, (int, float)): + return datetime.fromtimestamp(value, tz=UTC).replace(tzinfo=None) + if isinstance(value, str): + v = value.strip() + if not v: + return None + try: + dt_val = datetime.fromisoformat(v.replace("Z", "+00:00")) + if dt_val.tzinfo is None: + return dt_val + return dt_val.astimezone(UTC).replace(tzinfo=None) + except ValueError: + return None + return None + + +class SQLiteReader: + """SQLite 数据读取器""" + + def __init__(self): + db_path = os.path.join(config.save_dir, "database", "server.db") + self.engine = create_engine(f"sqlite:///{db_path}") + self.Session = sessionmaker(bind=self.engine) + + def get_session(self): + return self.Session() + + def read_departments(self) -> 
list[SqliteDepartment]: + with self.get_session() as session: + return session.execute(select(SqliteDepartment)).scalars().all() + + def read_users(self) -> list[SqliteUser]: + with self.get_session() as session: + return session.execute(select(SqliteUser)).scalars().all() + + def read_conversations(self) -> list[SqliteConversation]: + with self.get_session() as session: + return session.execute(select(SqliteConversation)).scalars().all() + + def read_messages(self) -> list[SqliteMessage]: + with self.get_session() as session: + return session.execute(select(SqliteMessage)).scalars().all() + + def read_tool_calls(self) -> list[SqliteToolCall]: + with self.get_session() as session: + return session.execute(select(SqliteToolCall)).scalars().all() + + def read_conversation_stats(self) -> list[SqliteConversationStats]: + with self.get_session() as session: + return session.execute(select(SqliteConversationStats)).scalars().all() + + def read_operation_logs(self) -> list[SqliteOperationLog]: + with self.get_session() as session: + return session.execute(select(SqliteOperationLog)).scalars().all() + + def read_message_feedbacks(self) -> list[SqliteMessageFeedback]: + with self.get_session() as session: + return session.execute(select(SqliteMessageFeedback)).scalars().all() + + def read_mcp_servers(self) -> list[SqliteMCPServer]: + with self.get_session() as session: + return session.execute(select(SqliteMCPServer)).scalars().all() + + def count_table(self, table_name: str) -> int: + with self.get_session() as session: + result = session.execute(text(f"SELECT COUNT(*) FROM {table_name}")) + return result.scalar() or 0 + + +async def migrate_departments(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, int]: + """迁移部门数据""" + sqlite_depts = sqlite_reader.read_departments() + logger.info(f"准备迁移 {len(sqlite_depts)} 个部门") + + created = 0 + if dry_run: + for sqlite_dept in sqlite_depts: + logger.info(f"[DRY-RUN] 将创建部门: {sqlite_dept.name}") + elif execute: 
+ async with pg_manager.get_async_session_context() as session: + for sqlite_dept in sqlite_depts: + # 检查是否已存在 + existing = await session.execute(select(Department).where(Department.id == sqlite_dept.id)) + if existing.scalar_one_or_none() is None: + dept = Department( + id=sqlite_dept.id, + name=sqlite_dept.name, + description=sqlite_dept.description, + created_at=_utc_dt(sqlite_dept.created_at), + ) + session.add(dept) + created += 1 + + return {"total": len(sqlite_depts), "created": created} + + +async def migrate_users(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, int]: + """迁移用户数据""" + sqlite_users = sqlite_reader.read_users() + logger.info(f"准备迁移 {len(sqlite_users)} 个用户") + + created = 0 + if dry_run: + for sqlite_user in sqlite_users: + logger.info(f"[DRY-RUN] 将创建用户: {sqlite_user.username} ({sqlite_user.user_id})") + elif execute: + async with pg_manager.get_async_session_context() as session: + for sqlite_user in sqlite_users: + existing = await session.execute(select(User).where(User.id == sqlite_user.id)) + if existing.scalar_one_or_none() is None: + user = User( + id=sqlite_user.id, + username=sqlite_user.username, + user_id=sqlite_user.user_id, + phone_number=sqlite_user.phone_number, + avatar=sqlite_user.avatar, + password_hash=sqlite_user.password_hash, + role=sqlite_user.role, + department_id=sqlite_user.department_id, + created_at=_utc_dt(sqlite_user.created_at), + last_login=_utc_dt(sqlite_user.last_login), + login_failed_count=sqlite_user.login_failed_count, + last_failed_login=_utc_dt(sqlite_user.last_failed_login), + login_locked_until=_utc_dt(sqlite_user.login_locked_until), + is_deleted=sqlite_user.is_deleted, + deleted_at=_utc_dt(sqlite_user.deleted_at), + ) + session.add(user) + created += 1 + + return {"total": len(sqlite_users), "created": created} + + +async def migrate_conversations(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, int]: + """迁移对话数据""" + sqlite_convs = 
sqlite_reader.read_conversations() + logger.info(f"准备迁移 {len(sqlite_convs)} 个对话") + + created = 0 + if dry_run: + for sqlite_conv in sqlite_convs: + logger.info(f"[DRY-RUN] 将创建对话: {sqlite_conv.thread_id}") + elif execute: + async with pg_manager.get_async_session_context() as session: + for sqlite_conv in sqlite_convs: + existing = await session.execute(select(Conversation).where(Conversation.id == sqlite_conv.id)) + if existing.scalar_one_or_none() is None: + # 截断过长的 title + title = sqlite_conv.title + if title and len(title) > 255: + title = title[:255] + logger.warning(f"截断对话标题 (id={sqlite_conv.id}): 原始长度={len(sqlite_conv.title)}") + conv = Conversation( + id=sqlite_conv.id, + thread_id=sqlite_conv.thread_id, + user_id=sqlite_conv.user_id, + agent_id=sqlite_conv.agent_id, + title=title, + status=sqlite_conv.status, + created_at=_utc_dt(sqlite_conv.created_at), + updated_at=_utc_dt(sqlite_conv.updated_at), + extra_metadata=sqlite_conv.extra_metadata, + ) + session.add(conv) + created += 1 + + return {"total": len(sqlite_convs), "created": created} + + +async def migrate_messages(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, int]: + """迁移消息数据""" + sqlite_messages = sqlite_reader.read_messages() + logger.info(f"准备迁移 {len(sqlite_messages)} 条消息") + + created = 0 + if dry_run: + for sqlite_msg in sqlite_messages: + logger.info(f"[DRY-RUN] 将创建消息: id={sqlite_msg.id}, conversation={sqlite_msg.conversation_id}") + elif execute: + async with pg_manager.get_async_session_context() as session: + for sqlite_msg in sqlite_messages: + existing = await session.execute(select(Message).where(Message.id == sqlite_msg.id)) + if existing.scalar_one_or_none() is None: + msg = Message( + id=sqlite_msg.id, + conversation_id=sqlite_msg.conversation_id, + role=sqlite_msg.role, + content=sqlite_msg.content, + message_type=sqlite_msg.message_type, + created_at=_utc_dt(sqlite_msg.created_at), + token_count=sqlite_msg.token_count, + 
async def migrate_tool_calls(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, int]:
    """Copy tool-call rows from SQLite into PostgreSQL.

    Idempotent: a row is only inserted when no PostgreSQL row with the same
    primary key exists yet.

    Args:
        sqlite_reader: Reader over the legacy SQLite database.
        dry_run: Log what would be created without writing anything.
        execute: Perform the inserts (only honoured when dry_run is False).

    Returns:
        {"total": rows read, "created": rows inserted}.
    """
    rows = sqlite_reader.read_tool_calls()
    logger.info(f"准备迁移 {len(rows)} 个工具调用")

    created = 0
    if dry_run:
        for row in rows:
            logger.info(f"[DRY-RUN] 将创建工具调用: id={row.id}, tool={row.tool_name}")
    elif execute:
        async with pg_manager.get_async_session_context() as session:
            for row in rows:
                found = await session.execute(select(ToolCall).where(ToolCall.id == row.id))
                if found.scalar_one_or_none() is not None:
                    continue  # already migrated
                session.add(
                    ToolCall(
                        id=row.id,
                        message_id=row.message_id,
                        langgraph_tool_call_id=row.langgraph_tool_call_id,
                        tool_name=row.tool_name,
                        tool_input=row.tool_input,
                        tool_output=row.tool_output,
                        status=row.status,
                        error_message=row.error_message,
                        created_at=_utc_dt(row.created_at),
                    )
                )
                created += 1

    return {"total": len(rows), "created": created}
async def migrate_operation_logs(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, int]:
    """Copy operation-log rows from SQLite into PostgreSQL.

    Skips rows whose primary key is already present, so repeated runs are
    safe.

    Returns:
        {"total": rows read, "created": rows inserted}.
    """
    rows = sqlite_reader.read_operation_logs()
    logger.info(f"准备迁移 {len(rows)} 条操作日志")

    created = 0
    if dry_run:
        for row in rows:
            logger.info(f"[DRY-RUN] 将创建操作日志: id={row.id}, operation={row.operation}")
    elif execute:
        async with pg_manager.get_async_session_context() as session:
            for row in rows:
                found = await session.execute(select(OperationLog).where(OperationLog.id == row.id))
                if found.scalar_one_or_none() is not None:
                    continue  # already migrated
                session.add(
                    OperationLog(
                        id=row.id,
                        user_id=row.user_id,
                        operation=row.operation,
                        details=row.details,
                        ip_address=row.ip_address,
                        timestamp=_utc_dt(row.timestamp),
                    )
                )
                created += 1

    return {"total": len(rows), "created": created}
async def migrate_mcp_servers(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, int]:
    """Copy MCP server definitions from SQLite into PostgreSQL.

    Uniqueness is keyed on the server name (MCPServer has no surrogate id
    here); a server whose name already exists is not re-inserted.

    Returns:
        {"total": servers read, "created": servers inserted}.
    """
    servers = sqlite_reader.read_mcp_servers()
    logger.info(f"准备迁移 {len(servers)} 个 MCP 服务器")

    created = 0
    if dry_run:
        for server in servers:
            logger.info(f"[DRY-RUN] 将创建 MCP 服务器: {server.name}")
    elif execute:
        async with pg_manager.get_async_session_context() as session:
            for server in servers:
                found = await session.execute(select(MCPServer).where(MCPServer.name == server.name))
                if found.scalar_one_or_none() is not None:
                    continue  # a server with this name already exists
                session.add(
                    MCPServer(
                        name=server.name,
                        description=server.description,
                        transport=server.transport,
                        url=server.url,
                        command=server.command,
                        args=server.args,
                        headers=server.headers,
                        timeout=server.timeout,
                        sse_read_timeout=server.sse_read_timeout,
                        tags=server.tags,
                        icon=server.icon,
                        enabled=server.enabled,
                        disabled_tools=server.disabled_tools,
                        created_by=server.created_by,
                        updated_by=server.updated_by,
                        created_at=_utc_dt(server.created_at),
                        updated_at=_utc_dt(server.updated_at),
                    )
                )
                created += 1

    return {"total": len(servers), "created": created}
("tool_calls", ToolCall, "id"), + ("conversation_stats", ConversationStats, "id"), + ("operation_logs", OperationLog, "id"), + ("message_feedbacks", MessageFeedback, "id"), + ("mcp_servers", MCPServer, "name"), # MCPServer 使用 name 作为主键 + ] + + results = {} + for table_name, model, pk_column in tables: + sqlite_count = sqlite_reader.count_table(table_name) + + async with pg_manager.get_async_session_context() as session: + from sqlalchemy import func + + pk_attr = getattr(model, pk_column) + result = await session.execute(select(func.count(pk_attr))) + pg_count = result.scalar() or 0 + + results[table_name] = { + "sqlite": sqlite_count, + "postgresql": pg_count, + "match": sqlite_count == pg_count, + } + + return results + + +async def rollback_migration() -> None: + """回滚迁移 - 删除所有业务数据表""" + logger.warning("开始回滚迁移...") + + # 按外键依赖顺序删除 + tables_to_delete = [ + MessageFeedback, + OperationLog, + ConversationStats, + ToolCall, + Message, + Conversation, + User, + Department, + MCPServer, + ] + + for model in tables_to_delete: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(model)) + records = result.scalars().all() + for record in records: + await session.delete(record) + + logger.warning("回滚完成 - 已删除所有迁移的业务数据") + + +async def migrate_all(sqlite_reader: SQLiteReader, dry_run: bool, execute: bool) -> dict[str, Any]: + """执行所有迁移""" + results = {} + + # 按外键依赖顺序迁移 + results["departments"] = await migrate_departments(sqlite_reader, dry_run, execute) + results["users"] = await migrate_users(sqlite_reader, dry_run, execute) + results["conversations"] = await migrate_conversations(sqlite_reader, dry_run, execute) + results["messages"] = await migrate_messages(sqlite_reader, dry_run, execute) + results["tool_calls"] = await migrate_tool_calls(sqlite_reader, dry_run, execute) + results["conversation_stats"] = await migrate_conversation_stats(sqlite_reader, dry_run, execute) + results["operation_logs"] = await 
migrate_operation_logs(sqlite_reader, dry_run, execute) + results["message_feedbacks"] = await migrate_message_feedbacks(sqlite_reader, dry_run, execute) + results["mcp_servers"] = await migrate_mcp_servers(sqlite_reader, dry_run, execute) + + return results + + +async def main() -> None: + parser = argparse.ArgumentParser(description="SQLite 到 PostgreSQL 业务数据迁移") + parser.add_argument("--dry-run", action="store_true", help="预览迁移,不执行") + parser.add_argument("--execute", action="store_true", help="执行迁移") + parser.add_argument("--verify", action="store_true", help="验证迁移结果") + parser.add_argument("--rollback", action="store_true", help="回滚迁移") + parser.add_argument("--migrate-all", action="store_true", help="迁移所有业务数据") + parser.add_argument("--init-tables", action="store_true", help="仅初始化业务表结构") + + args = parser.parse_args() + + if not any([args.dry_run, args.execute, args.verify, args.rollback, args.migrate_all, args.init_tables]): + args.dry_run = True + + # 初始化 PostgreSQL 管理器 + pg_manager.initialize() + logger.info("PostgreSQL manager initialized") + + if args.init_tables: + # 仅初始化表结构 + await pg_manager.create_business_tables() + logger.info("业务表结构初始化完成") + return + + if args.verify: + # 验证模式 + sqlite_reader = SQLiteReader() + results = await verify_migration(sqlite_reader) + + logger.info("=" * 60) + logger.info("迁移验证结果:") + logger.info("=" * 60) + all_match = True + for table_name, counts in results.items(): + status = "✓" if counts["match"] else "✗" + logger.info(f"{status} {table_name}: SQLite={counts['sqlite']}, PostgreSQL={counts['postgresql']}") + if not counts["match"]: + all_match = False + logger.info("=" * 60) + logger.info(f"全部匹配: {'是' if all_match else '否'}") + return + + if args.rollback: + # 回滚模式 + if args.dry_run: + logger.info("[DRY-RUN] 将回滚所有迁移的业务数据") + else: + await rollback_migration() + return + + # 迁移模式 + sqlite_reader = SQLiteReader() + + if args.migrate_all: + # 检查是否需要初始化表结构 + logger.info("检查业务表结构...") + await 
pg_manager.create_business_tables() + logger.info("业务表结构就绪") + + results = await migrate_all(sqlite_reader, args.dry_run, args.execute) + + logger.info("=" * 60) + logger.info("迁移完成:") + for table_name, counts in results.items(): + logger.info(f" {table_name}: {counts['created']}/{counts['total']}") + logger.info("=" * 60) + + if not args.dry_run: + logger.info("建议运行 --verify 验证数据完整性") + else: + logger.info("使用 --migrate-all 执行迁移,或使用 --verify 验证数据") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/migrate_kb_metadata_to_db.py b/scripts/migrate_kb_metadata_to_db.py new file mode 100644 index 000000000..7229d70be --- /dev/null +++ b/scripts/migrate_kb_metadata_to_db.py @@ -0,0 +1,339 @@ +import argparse +import asyncio +import glob +import json +import os +import sys +from datetime import datetime, UTC +from typing import Any + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +os.environ.setdefault("YUXI_SKIP_APP_INIT", "1") + +from src import config +from src.repositories.evaluation_repository import EvaluationRepository +from src.repositories.knowledge_base_repository import KnowledgeBaseRepository +from src.repositories.knowledge_file_repository import KnowledgeFileRepository +from src.repositories.task_repository import TaskRepository +from src.utils import logger + + +def _load_json(path: str) -> dict[str, Any]: + if not os.path.exists(path): + return {} + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def _utc_dt(value: Any) -> datetime | None: + """Convert various datetime formats to naive UTC datetime (consistent with model).""" + if not value: + return None + if isinstance(value, datetime): + # 转换为 UTC 并移除时区信息(模型使用 DateTime 无时区) + if value.tzinfo is None: + return value + return value.astimezone(UTC).replace(tzinfo=None) + if isinstance(value, (int, float)): + # 时间戳转换为 UTC 时间 + return datetime.fromtimestamp(value, tz=UTC).replace(tzinfo=None) + if isinstance(value, str): + v = value.strip() + if 
not v: + return None + try: + # 解析 ISO 格式并转换为 UTC + dt_val = datetime.fromisoformat(v.replace("Z", "+00:00")) + if dt_val.tzinfo is None: + return dt_val + return dt_val.astimezone(UTC).replace(tzinfo=None) + except ValueError: + return None + return None + + +def _default_share_config(meta: dict[str, Any]) -> dict[str, Any]: + share_config = meta.get("share_config") or {} + if "is_shared" not in share_config: + share_config["is_shared"] = True + if "accessible_departments" not in share_config: + share_config["accessible_departments"] = [] + return share_config + + +async def rollback_all() -> None: + eval_repo = EvaluationRepository() + kb_repo = KnowledgeBaseRepository() + file_repo = KnowledgeFileRepository() + task_repo = TaskRepository() + + await task_repo.delete_all() + await eval_repo.delete_all() + + rows = await kb_repo.get_all() + for row in rows: + await file_repo.delete_by_db_id(row.db_id) + await kb_repo.delete(row.db_id) + + +async def migrate(dry_run: bool, execute: bool, rollback: bool) -> None: + from src.storage.postgres.manager import pg_manager + + base_dir = os.path.join(config.save_dir, "knowledge_base_data") + global_meta_path = os.path.join(base_dir, "global_metadata.json") + global_meta = _load_json(global_meta_path).get("databases", {}) + + if rollback: + if dry_run: + logger.info("Dry-run rollback: would delete all knowledge metadata tables") + return + await rollback_all() + logger.info("Rollback completed") + return + + # 初始化表结构 + pg_manager.initialize() + await pg_manager.create_tables() + logger.info("知识库表结构初始化完成") + + kb_repo = KnowledgeBaseRepository() + file_repo = KnowledgeFileRepository() + eval_repo = EvaluationRepository() + task_repo = TaskRepository() + + kb_rows: list[dict[str, Any]] = [] + file_rows: list[tuple[str, dict[str, Any]]] = [] + benchmark_rows: list[dict[str, Any]] = [] + result_rows: list[dict[str, Any]] = [] + result_detail_rows: list[tuple[str, int, dict[str, Any]]] = [] + + kb_type_dirs = [ + p for p in 
glob.glob(os.path.join(base_dir, "*_data")) if os.path.isdir(p) and os.path.basename(p) != "uploads" + ] + + for kb_dir in kb_type_dirs: + kb_type = os.path.basename(kb_dir)[: -len("_data")] + meta_file = os.path.join(kb_dir, f"metadata_{kb_type}.json") + meta = _load_json(meta_file) + databases_meta: dict[str, Any] = meta.get("databases", {}) + files_meta: dict[str, Any] = meta.get("files", {}) + benchmarks_meta: dict[str, Any] = meta.get("benchmarks", {}) + + for db_id, db_meta in databases_meta.items(): + g = global_meta.get(db_id, {}) + created_at = _utc_dt(g.get("created_at") or db_meta.get("created_at")) + updated_at = _utc_dt(g.get("updated_at")) or created_at + kb_rows.append( + { + "db_id": db_id, + "name": g.get("name") or db_meta.get("name") or db_id, + "description": g.get("description") or db_meta.get("description"), + "kb_type": g.get("kb_type") or db_meta.get("kb_type") or kb_type, + "embed_info": db_meta.get("embed_info") or g.get("embed_info"), + "llm_info": db_meta.get("llm_info") or g.get("llm_info"), + "query_params": db_meta.get("query_params") or g.get("query_params"), + "additional_params": g.get("additional_params") or db_meta.get("metadata") or {}, + "share_config": _default_share_config(g or {}), + "mindmap": g.get("mindmap"), + "sample_questions": g.get("sample_questions") or [], + "created_at": created_at, + "updated_at": updated_at, + } + ) + + for file_id, fmeta in files_meta.items(): + db_id = fmeta.get("database_id") + if not db_id: + continue + file_rows.append( + ( + file_id, + { + "db_id": db_id, + "parent_id": fmeta.get("parent_id"), + "filename": fmeta.get("filename") or "", + "original_filename": fmeta.get("original_filename") or fmeta.get("file_name"), + "file_type": fmeta.get("file_type") or fmeta.get("type"), + "path": fmeta.get("path"), + "minio_url": fmeta.get("minio_url"), + "markdown_file": fmeta.get("markdown_file"), + "status": fmeta.get("status"), + "content_hash": fmeta.get("content_hash"), + "file_size": 
fmeta.get("size") or fmeta.get("file_size"), + "content_type": fmeta.get("content_type"), + "processing_params": fmeta.get("processing_params"), + "is_folder": bool(fmeta.get("is_folder", False)), + "error_message": fmeta.get("error") or fmeta.get("error_message"), + "created_by": str(fmeta.get("created_by")) if fmeta.get("created_by") else None, + "updated_by": str(fmeta.get("updated_by")) if fmeta.get("updated_by") else None, + "created_at": _utc_dt(fmeta.get("created_at")), + "updated_at": _utc_dt(fmeta.get("updated_at")) or _utc_dt(fmeta.get("created_at")), + }, + ) + ) + + for db_id, bmap in benchmarks_meta.items(): + if not isinstance(bmap, dict): + continue + for benchmark_id, bmeta in bmap.items(): + benchmark_rows.append( + { + "benchmark_id": benchmark_id, + "db_id": db_id, + "name": bmeta.get("name") or benchmark_id, + "description": bmeta.get("description"), + "question_count": int(bmeta.get("question_count") or 0), + "has_gold_chunks": bool(bmeta.get("has_gold_chunks")), + "has_gold_answers": bool(bmeta.get("has_gold_answers")), + "data_file_path": bmeta.get("benchmark_file") or bmeta.get("data_file_path"), + "created_by": str(bmeta.get("created_by")) if bmeta.get("created_by") else None, + "created_at": _utc_dt(bmeta.get("created_at")), + "updated_at": _utc_dt(bmeta.get("updated_at")) or _utc_dt(bmeta.get("created_at")), + } + ) + + for db_id in databases_meta.keys(): + result_dir = os.path.join(kb_dir, db_id, "results") + if not os.path.isdir(result_dir): + continue + for result_path in glob.glob(os.path.join(result_dir, "*.json")): + try: + data = _load_json(result_path) + except Exception as exc: + logger.warning(f"Skip invalid result file {result_path}: {exc}") + continue + task_id = data.get("task_id") or os.path.splitext(os.path.basename(result_path))[0] + benchmark_id = data.get("benchmark_id") + started_at = _utc_dt(data.get("started_at")) + result_rows.append( + { + "task_id": task_id, + "db_id": db_id, + "benchmark_id": benchmark_id, + 
"status": data.get("status") or "completed", + "retrieval_config": data.get("retrieval_config") or {}, + "metrics": data.get("metrics") or {}, + "overall_score": data.get("overall_score"), + "total_questions": int(data.get("total_questions") or 0), + "completed_questions": int(data.get("completed_questions") or 0), + "started_at": started_at, + "completed_at": _utc_dt(data.get("completed_at")) or started_at, + "created_by": str(data.get("created_by")) if data.get("created_by") else None, + } + ) + interim = data.get("interim_results") or data.get("results") or [] + for idx, item in enumerate(interim): + result_detail_rows.append( + ( + task_id, + idx, + { + "query_text": item.get("query") or item.get("query_text") or "", + "gold_chunk_ids": item.get("gold_chunk_ids"), + "gold_answer": item.get("gold_answer"), + "generated_answer": item.get("generated_answer"), + "retrieved_chunks": item.get("retrieved_chunks"), + "metrics": item.get("metrics") or {}, + }, + ) + ) + + tasks_json_path = os.path.join(config.save_dir, "tasks", "tasks.json") + task_rows: list[dict[str, Any]] = _load_json(tasks_json_path).get("tasks", []) or [] + + logger.info( + f"Prepared: knowledge_bases={len(kb_rows)}, knowledge_files={len(file_rows)}, " + f"benchmarks={len(benchmark_rows)}, results={len(result_rows)}, result_details={len(result_detail_rows)}, " + f"tasks={len(task_rows)}" + ) + + if dry_run and not execute: + return + + for payload in kb_rows: + db_id = payload["db_id"] + existing = await kb_repo.get_by_id(db_id) + data = payload.copy() + if existing is None: + await kb_repo.create(data) + else: + await kb_repo.update(db_id, data) + + # 先插入文件夹,再插入普通文件(确保父文件夹先存在) + folders = [(fid, data) for fid, data in file_rows if data.get("is_folder")] + files = [(fid, data) for fid, data in file_rows if not data.get("is_folder")] + + for file_id, data in folders: + await file_repo.upsert(file_id=file_id, data=data) + + for file_id, data in files: + await file_repo.upsert(file_id=file_id, 
data=data) + + for payload in benchmark_rows: + # 检查知识库是否存在 + kb = await kb_repo.get_by_id(payload["db_id"]) + if kb is None: + logger.warning(f"Skipping benchmark {payload['benchmark_id']}: knowledge base {payload['db_id']} not found") + continue + existing = await eval_repo.get_benchmark(payload["benchmark_id"]) + if existing is None: + await eval_repo.create_benchmark(payload) + + for payload in result_rows: + # 检查知识库是否存在 + kb = await kb_repo.get_by_id(payload["db_id"]) + if kb is None: + logger.warning(f"Skipping result {payload['task_id']}: knowledge base {payload['db_id']} not found") + continue + existing = await eval_repo.get_result(payload["task_id"]) + if existing is None: + await eval_repo.create_result(payload) + else: + await eval_repo.update_result(payload["task_id"], payload) + + for task_id, idx, data in result_detail_rows: + await eval_repo.upsert_result_detail(task_id=task_id, query_index=idx, data=data) + + for item in task_rows: + task_id = item.get("id") + if not task_id: + continue + payload = item.get("payload") or {} + result = item.get("result") + await task_repo.upsert( + task_id, + { + "name": item.get("name") or "Unnamed Task", + "type": item.get("type") or "general", + "status": item.get("status") or "pending", + "progress": float(item.get("progress") or 0.0), + "message": item.get("message") or "", + "payload": payload, + "result": result, + "error": item.get("error"), + "cancel_requested": 1 if item.get("cancel_requested") else 0, + "created_at": _utc_dt(item.get("created_at")), + "updated_at": _utc_dt(item.get("updated_at")) or _utc_dt(item.get("created_at")), + "started_at": _utc_dt(item.get("started_at")), + "completed_at": _utc_dt(item.get("completed_at")), + }, + ) + + logger.info("Migration completed") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", action="store_true") + parser.add_argument("--execute", action="store_true") + parser.add_argument("--rollback", action="store_true") 
+ args = parser.parse_args() + + if not args.dry_run and not args.execute and not args.rollback: + args.dry_run = True + + asyncio.run(migrate(dry_run=args.dry_run, execute=args.execute, rollback=args.rollback)) + + +if __name__ == "__main__": + main() diff --git a/scripts/migrate_user_fields.py b/scripts/migrate_user_fields.py deleted file mode 100644 index 7a55addb7..000000000 --- a/scripts/migrate_user_fields.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -""" -用户表字段迁移脚本 -为现有用户添加新字段:user_id, phone_number, avatar -将现有的 username 作为 user_id 的默认值 -""" - -# ruff: noqa: E402 - -import sys -from pathlib import Path - -from sqlalchemy import text - -# 添加项目根目录到Python路径 -PROJECT_ROOT = Path(__file__).parent -sys.path.insert(0, str(PROJECT_ROOT)) - -from src.storage.db.manager import db_manager -from src.storage.db.models import User as User - - -def migrate_user_fields(): - """执行用户字段迁移""" - print("开始用户字段迁移...") - - try: - # 获取数据库会话 - db = db_manager.get_session() - - # 1. 添加新字段(如果不存在) - print("检查并添加新字段...") - - # 检查字段是否存在的SQL - check_columns_sql = """ - SELECT column_name FROM information_schema.columns - WHERE table_name = 'users' AND table_schema = DATABASE() - """ - - try: - result = db.execute(text(check_columns_sql)) - existing_columns = [row[0] for row in result.fetchall()] - print(f"现有字段: {existing_columns}") - - # 添加缺失的字段 - if "user_id" not in existing_columns: - print("添加 user_id 字段...") - db.execute(text("ALTER TABLE users ADD COLUMN user_id VARCHAR(255)")) - - if "phone_number" not in existing_columns: - print("添加 phone_number 字段...") - db.execute(text("ALTER TABLE users ADD COLUMN phone_number VARCHAR(255)")) - - if "avatar" not in existing_columns: - print("添加 avatar 字段...") - db.execute(text("ALTER TABLE users ADD COLUMN avatar VARCHAR(500)")) - - db.commit() - print("字段添加完成") - - except Exception as e: - print(f"字段检查/添加失败: {e}") - # 对于SQLite,尝试直接添加字段 - try: - db.execute(text("ALTER TABLE users ADD COLUMN user_id TEXT")) - db.execute(text("ALTER 
TABLE users ADD COLUMN phone_number TEXT")) - db.execute(text("ALTER TABLE users ADD COLUMN avatar TEXT")) - db.commit() - print("字段添加完成(SQLite模式)") - except Exception as sqlite_e: - print(f"SQLite字段添加也失败: {sqlite_e}") - print("字段可能已存在,继续执行...") - - # 2. 为现有用户设置默认 user_id - print("为现有用户设置默认 user_id...") - - # 查询所有没有 user_id 的用户 - users_without_user_id = db.execute( - text("SELECT id, username FROM users WHERE user_id IS NULL OR user_id = ''") - ).fetchall() - - print(f"找到 {len(users_without_user_id)} 个需要设置 user_id 的用户") - - for user_id, username in users_without_user_id: - # 将 username 作为默认的 user_id - print(f"为用户 {username} (ID: {user_id}) 设置 user_id: {username}") - db.execute(text("UPDATE users SET user_id = :user_id WHERE id = :id"), {"user_id": username, "id": user_id}) - - db.commit() - - # 3. 添加唯一索引 - print("添加唯一索引...") - try: - # 先检查索引是否存在 - try: - db.execute(text("CREATE UNIQUE INDEX idx_users_user_id ON users(user_id)")) - print("创建 user_id 唯一索引") - except Exception: - print("user_id 索引可能已存在") - - try: - db.execute( - text( - "CREATE UNIQUE INDEX idx_users_phone_number ON users(phone_number) " - "WHERE phone_number IS NOT NULL" - ) - ) - print("创建 phone_number 唯一索引") - except Exception: - print("phone_number 索引可能已存在") - - db.commit() - - except Exception as e: - print(f"索引创建失败: {e}") - print("继续执行...") - - # 4. 
验证迁移结果 - print("验证迁移结果...") - total_users = db.execute(text("SELECT COUNT(*) FROM users")).scalar() - users_with_user_id = db.execute( - text("SELECT COUNT(*) FROM users WHERE user_id IS NOT NULL AND user_id != ''") - ).scalar() - - print(f"总用户数: {total_users}") - print(f"已设置 user_id 的用户数: {users_with_user_id}") - - if total_users == users_with_user_id: - print("✅ 迁移成功完成!") - else: - print("❌ 迁移可能有问题,请检查数据库") - - except Exception as e: - print(f"迁移过程中发生错误: {e}") - if "db" in locals(): - db.rollback() - raise - finally: - if "db" in locals(): - db.close() - - -if __name__ == "__main__": - migrate_user_fields() diff --git a/scripts/migrate_user_soft_delete.py b/scripts/migrate_user_soft_delete.py deleted file mode 100644 index 304c8525b..000000000 --- a/scripts/migrate_user_soft_delete.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python3 -""" -用户表软删除字段迁移脚本 - -该脚本用于为历史数据库添加 `is_deleted` 与 `deleted_at` 字段, -同时会执行现有的数据库迁移逻辑,确保用户表结构与最新模型保持一致。 -""" - -from __future__ import annotations - -import sys -from pathlib import Path - -# 将项目根目录加入到 Python 路径,便于脚本在容器中执行 -PROJECT_ROOT = Path(__file__).resolve().parent.parent -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - -from server.utils.migrate import DatabaseMigrator # noqa: E402 -from src import config # noqa: E402 - - -def main() -> None: - db_path = Path(config.save_dir) / "database" / "server.db" - migrator = DatabaseMigrator(str(db_path)) - - print(f"检测数据库: {db_path}") - current_version = migrator.get_current_version() - latest_version = migrator.get_latest_migration_version() - print(f"当前迁移版本: v{current_version}, 最新版本: v{latest_version}") - - try: - migrator.run_migrations() - print("✅ 迁移完成,数据库结构已更新") - except Exception as exc: - print(f"❌ 迁移失败: {exc}") - raise - - -if __name__ == "__main__": - main() diff --git a/scripts/test_agent_configs.py b/scripts/test_agent_configs.py new file mode 100644 index 000000000..d21d5cb00 --- /dev/null +++ b/scripts/test_agent_configs.py @@ -0,0 
+1,236 @@ +import os +import sys +import time +from dataclasses import dataclass + +import requests + + +@dataclass +class Account: + username: str + password: str + label: str + + +ACCOUNTS = { + "superadmin": Account("zwj", "zwj12138", "superadmin"), + "dept_admin": Account("ceshizhuguan", "test_admin123", "dept_admin"), + "dept_user": Account("food2025", "jnufood", "dept_user"), +} + + +def _base_url() -> str: + return os.getenv("BASE_URL", "http://localhost:8000").rstrip("/") + + +def _request(method: str, path: str, *, token: str | None = None, json_data=None): + url = f"{_base_url()}{path}" + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + resp = requests.request(method, url, headers=headers, json=json_data, timeout=60) + return resp + + +def login(account: Account) -> str: + url = f"{_base_url()}/api/auth/token" + resp = requests.post( + url, + data={"username": account.username, "password": account.password}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=60, + ) + if resp.status_code != 200: + raise RuntimeError(f"login failed {account.label}: {resp.status_code} {resp.text}") + data = resp.json() + return data["access_token"], data.get("department_id"), data.get("user_id") + + +def get_first_agent_id(token: str) -> str: + resp = _request("GET", "/api/chat/agent", token=token) + if resp.status_code != 200: + raise RuntimeError(f"get agents failed: {resp.status_code} {resp.text}") + agents = resp.json().get("agents", []) + if not agents: + raise RuntimeError("no agents returned") + return agents[0]["id"] + + +def list_configs(token: str, agent_id: str) -> list[dict]: + resp = _request("GET", f"/api/chat/agent/{agent_id}/configs", token=token) + if resp.status_code != 200: + raise RuntimeError(f"list configs failed: {resp.status_code} {resp.text}") + return resp.json().get("configs", []) + + +def get_config(token: str, agent_id: str, config_id: int) -> dict: + resp = _request("GET", 
f"/api/chat/agent/{agent_id}/configs/{config_id}", token=token) + if resp.status_code != 200: + raise RuntimeError(f"get config failed: {resp.status_code} {resp.text}") + return resp.json()["config"] + + +def create_config(token: str, agent_id: str, name: str, set_default: bool = False) -> dict: + payload = { + "name": name, + "description": f"created-by-test {int(time.time())}", + "icon": None, + "pics": [], + "examples": ["hello"], + "config_json": {"context": {"system_prompt": f"system_prompt::{name}"}}, + "set_default": set_default, + } + resp = _request("POST", f"/api/chat/agent/{agent_id}/configs", token=token, json_data=payload) + if resp.status_code != 200: + raise RuntimeError(f"create config failed: {resp.status_code} {resp.text}") + return resp.json()["config"] + + +def update_config(token: str, agent_id: str, config_id: int, context_updates: dict) -> dict: + payload = {"config_json": {"context": context_updates}} + resp = _request("PUT", f"/api/chat/agent/{agent_id}/configs/{config_id}", token=token, json_data=payload) + if resp.status_code != 200: + raise RuntimeError(f"update config failed: {resp.status_code} {resp.text}") + return resp.json()["config"] + + +def set_default(token: str, agent_id: str, config_id: int) -> dict: + resp = _request("POST", f"/api/chat/agent/{agent_id}/configs/{config_id}/set_default", token=token, json_data={}) + if resp.status_code != 200: + raise RuntimeError(f"set default failed: {resp.status_code} {resp.text}") + return resp.json()["config"] + + +def delete_config(token: str, agent_id: str, config_id: int) -> None: + resp = _request("DELETE", f"/api/chat/agent/{agent_id}/configs/{config_id}", token=token) + if resp.status_code != 200: + raise RuntimeError(f"delete config failed: {resp.status_code} {resp.text}") + + +def chat_smoke(token: str, agent_id: str, config_id: int) -> None: + url = f"{_base_url()}/api/chat/agent/{agent_id}" + resp = requests.post( + url, + headers={"Authorization": f"Bearer {token}", 
"Content-Type": "application/json"}, + json={"query": "ping", "config": {"thread_id": None, "agent_config_id": config_id}}, + stream=True, + timeout=120, + ) + if resp.status_code != 200: + raise RuntimeError(f"chat failed: {resp.status_code} {resp.text}") + lines = [] + for line in resp.iter_lines(decode_unicode=True): + if not line: + continue + lines.append(line) + if '"status": "finished"' in line: + break + if not any('"status": "init"' in s for s in lines): + raise RuntimeError("chat stream missing init chunk") + if not any('"status": "finished"' in s for s in lines): + raise RuntimeError("chat stream missing finished chunk") + + +def assert_forbidden(resp: requests.Response, label: str): + if resp.status_code != 403: + raise RuntimeError(f"expected 403 for {label}, got {resp.status_code}: {resp.text}") + + +def main(): + super_token, super_dept_id, super_user_id = login(ACCOUNTS["superadmin"]) + test_admin_token, test_admin_dept_id, _ = login(ACCOUNTS["dept_admin"]) + default_dept_token, default_dept_id, _ = login(ACCOUNTS["dept_user"]) + + agent_id = get_first_agent_id(default_dept_token) + print("agent_id", agent_id) + + def run_dept_flow(token: str, dept_label: str): + cfgs = list_configs(token, agent_id) + if not cfgs: + raise RuntimeError(f"{dept_label}: configs should have default created") + default_cfg = next((c for c in cfgs if c.get("is_default")), cfgs[0]) + print(dept_label, "default_config", default_cfg["id"], default_cfg["name"]) + + created = create_config(token, agent_id, f"{dept_label}-测试配置A", set_default=False) + print(dept_label, "created_config", created["id"], created["name"]) + + dup = create_config(token, agent_id, f"{dept_label}-测试配置A", set_default=False) + print(dept_label, "created_duplicate_config", dup["id"], dup["name"]) + if dup["name"] == f"{dept_label}-测试配置A": + raise RuntimeError(f"{dept_label}: duplicate name should be auto-renamed with -副本") + + updated_default = set_default(token, agent_id, created["id"]) + if not 
updated_default.get("is_default"): + raise RuntimeError(f"{dept_label}: set_default should mark config as default") + + cfgs2 = list_configs(token, agent_id) + defaults = [c for c in cfgs2 if c.get("is_default")] + if len(defaults) != 1: + raise RuntimeError(f"{dept_label}: default must be unique, got {len(defaults)}") + + cfg_payload = get_config(token, agent_id, created["id"]) + if cfg_payload["id"] != created["id"]: + raise RuntimeError(f"{dept_label}: get config mismatch") + + updated = update_config( + token, + agent_id, + created["id"], + { + "system_prompt": f"system_prompt::{dept_label}::updated", + "tools": [], + "knowledges": [], + "mcps": [], + }, + ) + if (updated.get("config_json") or {}).get("context", {}).get( + "system_prompt" + ) != f"system_prompt::{dept_label}::updated": + raise RuntimeError(f"{dept_label}: update did not persist system_prompt") + + chat_smoke(token, agent_id, created["id"]) + + delete_config(token, agent_id, created["id"]) + delete_config(token, agent_id, dup["id"]) + + cfgs3 = list_configs(token, agent_id) + if not cfgs3: + raise RuntimeError(f"{dept_label}: configs should not be empty after delete; default should exist") + + run_dept_flow(default_dept_token, "default_dept") + run_dept_flow(test_admin_token, "test_dept") + + if super_dept_id is None or super_user_id is None: + raise RuntimeError("superadmin token missing department_id/user_id") + + tmp_user_payload = { + "username": f"tmp_user_{int(time.time())}", + "password": "tmp_pass_123", + "role": "user", + "department_id": int(super_dept_id), + } + created_user = _request("POST", "/api/auth/users", token=super_token, json_data=tmp_user_payload) + if created_user.status_code != 200: + raise RuntimeError(f"create tmp user failed: {created_user.status_code} {created_user.text}") + tmp_user = created_user.json() + tmp_user_login = tmp_user["user_id"] + tmp_user_id = tmp_user["id"] + + tmp_token, _, _ = login(Account(tmp_user_login, "tmp_pass_123", "tmp_user")) + forbidden = 
_request("POST", f"/api/chat/agent/{agent_id}/configs", token=tmp_token, json_data={"name": "x"}) + assert_forbidden(forbidden, "user create config") + + deleted_user = _request("DELETE", f"/api/auth/users/{tmp_user_id}", token=super_token) + if deleted_user.status_code != 200: + raise RuntimeError(f"delete tmp user failed: {deleted_user.status_code} {deleted_user.text}") + + print("OK") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print("FAILED:", e) + sys.exit(1) diff --git a/server/routers/auth_router.py b/server/routers/auth_router.py index b14592f96..2f4eebc90 100644 --- a/server/routers/auth_router.py +++ b/server/routers/auth_router.py @@ -1,5 +1,6 @@ import re import uuid +from src.utils import logger from fastapi import APIRouter, Depends, HTTPException, Request, status, UploadFile, File from fastapi.security import OAuth2PasswordRequestForm @@ -7,14 +8,22 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from src.storage.db.manager import db_manager -from src.storage.db.models import User, Department -from server.utils.auth_middleware import get_admin_user, get_current_user, get_db, get_required_user +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import User, Department +from src.repositories.user_repository import UserRepository +from src.repositories.department_repository import DepartmentRepository +from server.utils.auth_middleware import ( + get_admin_user, + get_superadmin_user, + get_current_user, + get_db, + get_required_user, +) from server.utils.auth_utils import AuthUtils from server.utils.user_utils import generate_unique_user_id, validate_username, is_valid_phone_number from server.utils.common_utils import log_operation from src.storage.minio import aupload_file_to_minio -from src.utils.datetime_utils import utc_now +from src.utils.datetime_utils import utc_now_naive # 创建路由器 auth = APIRouter(prefix="/auth", 
tags=["authentication"]) @@ -30,6 +39,8 @@ class Token(BaseModel): phone_number: str | None = None avatar: str | None = None role: str + department_id: int | None = None + department_name: str | None = None class UserCreate(BaseModel): @@ -166,7 +177,7 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( # 登录成功,重置失败计数器 user.reset_failed_login() - user.last_login = utc_now() + user.last_login = utc_now_naive() await db.commit() # 生成访问令牌 @@ -176,6 +187,12 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( # 记录登录操作 await log_operation(db, user.id, "登录") + # 获取部门名称 + department_name = None + if user.department_id: + result = await db.execute(select(Department.name).filter(Department.id == user.department_id)) + department_name = result.scalar_one_or_none() + return { "access_token": access_token, "token_type": "bearer", @@ -186,13 +203,14 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( "avatar": user.avatar, "role": user.role, "department_id": user.department_id, + "department_name": department_name, } # 路由:校验是否需要初始化管理员 @auth.get("/check-first-run") async def check_first_run(): - is_first_run = await db_manager.async_check_first_run() + is_first_run = await pg_manager.async_check_first_run() return {"first_run": is_first_run} @@ -200,7 +218,7 @@ async def check_first_run(): @auth.post("/initialize", response_model=Token) async def initialize_admin(admin_data: InitializeAdmin, db: AsyncSession = Depends(get_db)): # 检查是否是首次运行 - if not await db_manager.async_check_first_run(): + if not await pg_manager.async_check_first_run(): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="系统已经初始化,无法再次创建初始管理员", @@ -230,24 +248,28 @@ async def initialize_admin(admin_data: InitializeAdmin, db: AsyncSession = Depen user_id = admin_data.user_id # 创建默认部门 - default_department = Department(name="默认部门", description="系统初始化时创建的默认部门") - db.add(default_department) - await 
db.flush() # 获取部门ID - - new_admin = User( - username=admin_data.user_id, # username和user_id设置为相同值 - user_id=user_id, - phone_number=admin_data.phone_number, - avatar=None, # 初始化时头像为空 - password_hash=hashed_password, - role="superadmin", - department_id=default_department.id, - last_login=utc_now(), + dept_repo = DepartmentRepository() + default_department = await dept_repo.create( + { + "name": "默认部门", + "description": "系统初始化时创建的默认部门", + } ) - db.add(new_admin) - await db.commit() - await db.refresh(new_admin) + # 创建管理员用户 + user_repo = UserRepository() + new_admin = await user_repo.create( + { + "username": admin_data.user_id, + "user_id": user_id, + "phone_number": admin_data.phone_number, + "avatar": None, + "password_hash": hashed_password, + "role": "superadmin", + "department_id": default_department.id, + "last_login": utc_now_naive(), + } + ) # 生成访问令牌 token_data = {"sub": str(new_admin.id)} @@ -275,9 +297,15 @@ async def initialize_admin(admin_data: InitializeAdmin, db: AsyncSession = Depen @auth.get("/me", response_model=UserResponse) -async def read_users_me(current_user: User = Depends(get_current_user)): +async def read_users_me(current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)): """获取当前登录用户的个人信息""" - return current_user.to_dict() + user_dict = current_user.to_dict() + + if current_user.department_id: + result = await db.execute(select(Department.name).filter(Department.id == current_user.department_id)) + user_dict["department_name"] = result.scalar_one_or_none() + + return user_dict # 路由:更新个人资料 @@ -356,6 +384,7 @@ async def create_user( db: AsyncSession = Depends(get_db), ): """创建新用户(管理员权限)""" + user_repo = UserRepository() # 验证用户名 is_valid, error_msg = validate_username(user_data.username) @@ -366,9 +395,8 @@ async def create_user( ) # 检查用户名是否已存在 - result = await db.execute(select(User).filter(User.username == user_data.username)) - existing_user = result.scalar_one_or_none() - if existing_user: + users = await 
user_repo.list_users() + if any(u.username == user_data.username for u in users): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="用户名已存在", @@ -376,28 +404,25 @@ async def create_user( # 检查手机号是否已存在(如果提供了) if user_data.phone_number: - result = await db.execute(select(User).filter(User.phone_number == user_data.phone_number)) - existing_phone = result.scalar_one_or_none() - if existing_phone: + if await user_repo.exists_by_phone(user_data.phone_number): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="手机号已存在", ) # 生成唯一的user_id - result = await db.execute(select(User.user_id)) - existing_user_ids = [user_id for (user_id,) in result.all()] + existing_user_ids = await user_repo.get_all_user_ids() user_id = generate_unique_user_id(user_data.username, existing_user_ids) # 创建新用户 hashed_password = AuthUtils.hash_password(user_data.password) # 检查角色权限 - # 超级管理员可以创建任何类型的用户 - if user_data.role == "superadmin" and current_user.role != "superadmin": + # 禁止创建超级管理员账户(系统只能有一个超级管理员) + if user_data.role == "superadmin": raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="只有超级管理员才能创建超级管理员账户", + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能创建超级管理员账户", ) # 管理员只能创建普通用户 @@ -412,7 +437,11 @@ async def create_user( # 超级管理员创建用户时,使用指定的部门或默认部门 department_id = user_data.department_id if department_id is None: - department_id = await get_default_department_id(db) + # 获取默认部门 + dept_repo = DepartmentRepository() + departments = await dept_repo.list_departments() + default_dept = next((d for d in departments if d.name == "默认部门"), None) + department_id = default_dept.id if default_dept else None else: # 普通管理员创建用户时,自动继承该管理员的部门 department_id = current_user.department_id @@ -423,19 +452,17 @@ async def create_user( detail="普通管理员不能指定部门", ) - new_user = User( - username=user_data.username, - user_id=user_id, - phone_number=user_data.phone_number, - password_hash=hashed_password, - role=user_data.role, - department_id=department_id, + 
new_user = await user_repo.create( + { + "username": user_data.username, + "user_id": user_id, + "phone_number": user_data.phone_number, + "password_hash": hashed_password, + "role": user_data.role, + "department_id": department_id, + } ) - db.add(new_user) - await db.commit() - await db.refresh(new_user) - # 记录操作 await log_operation( db, current_user.id, "创建用户", f"创建用户: {user_data.username}, 角色: {user_data.role}", request @@ -449,28 +476,20 @@ async def create_user( async def read_users( skip: int = 0, limit: int = 100, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db) ): + user_repo = UserRepository() + # 部门隔离逻辑 if current_user.role == "superadmin": - # 超级管理员可以看到所有用户,使用 JOIN 获取部门名称 - result = await db.execute( - select(User, Department.name.label("department_name")) - .outerjoin(Department, User.department_id == Department.id) - .filter(User.is_deleted == 0) - .offset(skip) - .limit(limit) - ) + # 超级管理员可以看到所有用户 + users_with_dept = await user_repo.list_with_department(skip=skip, limit=limit) else: # 普通管理员只能看到本部门用户 - result = await db.execute( - select(User, Department.name.label("department_name")) - .outerjoin(Department, User.department_id == Department.id) - .filter(User.is_deleted == 0, User.department_id == current_user.department_id) - .offset(skip) - .limit(limit) + users_with_dept = await user_repo.list_with_department( + skip=skip, limit=limit, department_id=current_user.department_id ) - rows = result.all() + users = [] - for user, dept_name in rows: + for user, dept_name in users_with_dept: user_dict = user.to_dict() user_dict["department_name"] = dept_name users.append(user_dict) @@ -490,6 +509,18 @@ async def read_user(user_id: int, current_user: User = Depends(get_admin_user), return user.to_dict() +async def check_department_admin_count(db: AsyncSession, department_id: int, exclude_user_id: int) -> int: + """检查部门中管理员数量(排除指定用户)""" + result = await db.execute( + select(func.count(User.id)).filter( + User.department_id == 
department_id, + User.role == "admin", + User.id != exclude_user_id, + ) + ) + return result.scalar() + + # 路由:更新用户信息(管理员权限) @auth.put("/users/{user_id}", response_model=UserResponse) async def update_user( @@ -541,6 +572,14 @@ async def update_user( update_details.append("密码已更新") if user_data.role is not None: + # 检查是否将管理员降级为普通用户 + if user.role == "admin" and user_data.role == "user" and user.department_id is not None: + admin_count = await check_department_admin_count(db, user.department_id, user_id) + if admin_count <= 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能将管理员降级为普通用户,因为该用户是当前部门的唯一管理员", + ) user.role = user_data.role update_details.append(f"角色: {user_data.role}") @@ -553,12 +592,22 @@ async def update_user( update_details.append(f"头像: {user_data.avatar or '已清空'}") # 部门修改权限控制(只有超级管理员可以修改用户部门) - if user_data.department_id is not None: + if user_data.department_id is not None and user_data.department_id != user.department_id: if current_user.role != "superadmin": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="只有超级管理员才能修改用户部门", ) + + # 检查该用户是否是当前部门的唯一管理员 + if user.role == "admin" and user.department_id is not None: + admin_count = await check_department_admin_count(db, user.department_id, user_id) + if admin_count <= 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能修改该用户的部门,因为该用户是当前部门的唯一管理员", + ) + user.department_id = user_data.department_id update_details.append(f"部门ID: {user_data.department_id}") @@ -583,33 +632,18 @@ async def delete_user( detail="用户不存在", ) - # 检查权限 + # 不能删除超级管理员账户 if user.role == "superadmin": - # 只有超级管理员可以删除超级管理员 - if current_user.role != "superadmin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="只有超级管理员才能删除超级管理员账户", - ) - - # 检查是否是最后一个超级管理员 - result = await db.execute( - select(func.count(User.id)).filter(User.role == "superadmin", User.is_deleted == 0) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + 
detail="不能删除超级管理员账户", ) - superadmin_count = result.scalar() - if superadmin_count <= 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="不能删除最后一个超级管理员账户", - ) # 检查是否是部门的唯一管理员 if user.role == "admin": result = await db.execute( select(func.count(User.id)).filter( - User.department_id == user.department_id, - User.role == "admin", - User.is_deleted == 0 + User.department_id == user.department_id, User.role == "admin", User.is_deleted == 0 ) ) admin_count = result.scalar() @@ -639,10 +673,10 @@ async def delete_user( import hashlib # 生成4位哈希(基于user_id保证唯一性) - hash_suffix = hashlib.md5(user.user_id.encode()).hexdigest()[:4] + hash_suffix = hashlib.sha256(user.user_id.encode()).hexdigest()[:4] user.is_deleted = 1 - user.deleted_at = utc_now() + user.deleted_at = utc_now_naive() user.username = f"已注销用户-{hash_suffix}" user.phone_number = None # 清空手机号,释放该手机号供其他用户使用 user.password_hash = "DELETED" # 禁止登录 @@ -737,3 +771,58 @@ async def upload_user_avatar( except Exception as e: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"头像上传失败: {str(e)}") + + +# 路由:模拟用户登录(超级管理员专用) +@auth.post("/impersonate/{user_id}", response_model=Token) +async def impersonate_user( + user_id: int, + request: Request, + current_user: User = Depends(get_superadmin_user), + db: AsyncSession = Depends(get_db), +): + """超级管理员模拟其他用户登录""" + # 查找目标用户 + result = await db.execute(select(User).filter(User.id == user_id, User.is_deleted == 0)) + target_user = result.scalar_one_or_none() + if target_user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="用户不存在", + ) + + # 不能模拟超级管理员 + if target_user.role == "superadmin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="不能模拟超级管理员账户", + ) + + # 生成访问令牌 + token_data = {"sub": str(target_user.id)} + access_token = AuthUtils.create_access_token(token_data) + + # 获取部门名称 + department_name = None + if target_user.department_id: + result = await 
db.execute(select(Department.name).filter(Department.id == target_user.department_id)) + department_name = result.scalar_one_or_none() + + # 记录操作(危险操作标记) + await log_operation(db, current_user.id, "⚠️ 危险操作-模拟用户", f"模拟用户: {target_user.username}", request) + + # 控制台警告日志 + logger.warning(f"⚠️ [危险操作] 超级管理员 {current_user.username} 模拟登录用户: {target_user.username}") + + return { + "access_token": access_token, + "token_type": "bearer", + "user_id": target_user.id, + "username": target_user.username, + "user_id_login": target_user.user_id, + "phone_number": target_user.phone_number, + "avatar": target_user.avatar, + "role": target_user.role, + "department_id": target_user.department_id, + "department_name": department_name, + } diff --git a/server/routers/chat_router.py b/server/routers/chat_router.py index 030ede903..c56c273bc 100644 --- a/server/routers/chat_router.py +++ b/server/routers/chat_router.py @@ -1,32 +1,32 @@ import asyncio -import json import traceback import uuid from fastapi import APIRouter, Body, Depends, HTTPException, Query, UploadFile, File from fastapi.responses import StreamingResponse -from langchain.messages import AIMessageChunk, HumanMessage, AIMessage -from langgraph.types import Command from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select -from src.storage.db.models import User, MessageFeedback, Message, Conversation -from src.storage.conversation import ConversationManager -from src.storage.db.manager import db_manager +from src.storage.postgres.models_business import User from server.routers.auth_router import get_admin_user from server.utils.auth_middleware import get_db, get_required_user from src import executor from src import config as conf from src.agents import agent_manager from src.models import select_model -from src.plugins.guard import content_guard -from src.services.doc_converter import ( - ATTACHMENT_ALLOWED_EXTENSIONS, - MAX_ATTACHMENT_SIZE_BYTES, - 
convert_upload_to_markdown, +from src.services.chat_stream_service import get_agent_state_view, stream_agent_chat, stream_agent_resume +from src.services.conversation_service import ( + create_thread_view, + delete_thread_attachment_view, + delete_thread_view, + list_thread_attachments_view, + list_threads_view, + update_thread_view, + upload_thread_attachment_view, ) -from src.utils.datetime_utils import utc_isoformat +from src.services.feedback_service import get_message_feedback_view, submit_message_feedback_view +from src.services.history_query_service import get_agent_history_view +from src.repositories.agent_config_repository import AgentConfigRepository from src.utils.logging_config import logger from src.utils.image_processor import process_uploaded_image @@ -44,6 +44,25 @@ class ImageUploadResponse(BaseModel): error: str | None = None +class AgentConfigCreate(BaseModel): + name: str + description: str | None = None + icon: str | None = None + pics: list[str] | None = None + examples: list[str] | None = None + config_json: dict | None = None + set_default: bool = False + + +class AgentConfigUpdate(BaseModel): + name: str | None = None + description: str | None = None + icon: str | None = None + pics: list[str] | None = None + examples: list[str] | None = None + config_json: dict | None = None + + chat = APIRouter(prefix="/chat", tags=["chat"]) # ============================================================================= @@ -96,283 +115,6 @@ async def set_default_agent(request_data: dict = Body(...), current_user=Depends raise HTTPException(status_code=500, detail=f"设置默认智能体出错: {str(e)}") -# ============================================================================= -# > === 对话分组 === -# ============================================================================= - - -async def _get_langgraph_messages(agent_instance, config_dict): - graph = await agent_instance.get_graph() - state = await graph.aget_state(config_dict) - - if not state or not state.values: 
- logger.warning("No state found in LangGraph") - return None - - return state.values.get("messages", []) - - -def _extract_agent_state(values: dict) -> dict: - if not isinstance(values, dict): - return {} - - def _norm_list(v): - if v is None: - return [] - if isinstance(v, (list, tuple)): - return list(v) - return [v] - - result = {} - result["todos"] = _norm_list(values.get("todos"))[:20] - result["files"] = _norm_list(values.get("files"))[:50] - - return result - - -async def _get_existing_message_ids(conv_mgr, thread_id): - """获取已保存的消息ID集合""" - existing_messages = await conv_mgr.get_messages_by_thread_id(thread_id) - return {msg.extra_metadata["id"] for msg in existing_messages if msg.extra_metadata and "id" in msg.extra_metadata} - - -async def _save_ai_message(conv_mgr, thread_id, msg_dict): - """保存AI消息和相关的工具调用""" - content = msg_dict.get("content", "") - tool_calls_data = msg_dict.get("tool_calls", []) - - # 保存AI消息 - ai_msg = await conv_mgr.add_message_by_thread_id( - thread_id=thread_id, - role="assistant", - content=content, - message_type="text", - extra_metadata=msg_dict, - ) - - # 保存工具调用 - if tool_calls_data: - logger.debug(f"Saving {len(tool_calls_data)} tool calls from AI message") - for tc in tool_calls_data: - await conv_mgr.add_tool_call( - message_id=ai_msg.id, - tool_name=tc.get("name", "unknown"), - tool_input=tc.get("args", {}), - status="pending", - langgraph_tool_call_id=tc.get("id"), - ) - - logger.debug(f"Saved AI message {ai_msg.id} with {len(tool_calls_data)} tool calls") - - -async def _save_tool_message(conv_mgr, msg_dict): - """保存工具执行结果""" - tool_call_id = msg_dict.get("tool_call_id") - content = msg_dict.get("content", "") - name = msg_dict.get("name", "") - - if not tool_call_id: - return - - # 确保tool_output是字符串类型 - if isinstance(content, list): - tool_output = json.dumps(content) if content else "" - else: - tool_output = str(content) - - # 更新工具调用结果 - updated_tc = await conv_mgr.update_tool_call_output( - 
langgraph_tool_call_id=tool_call_id, - tool_output=tool_output, - status="success", - ) - - if updated_tc: - logger.debug(f"Updated tool_call {tool_call_id} ({name}) with output") - else: - logger.warning(f"Tool call {tool_call_id} not found for update") - - -async def _require_user_conversation(conv_mgr: ConversationManager, thread_id: str, user_id: str) -> Conversation: - conversation = await conv_mgr.get_conversation_by_thread_id(thread_id) - if not conversation or conversation.user_id != str(user_id) or conversation.status == "deleted": - raise HTTPException(status_code=404, detail="对话线程不存在") - return conversation - - -def _serialize_attachment(record: dict) -> dict: - return { - "file_id": record.get("file_id"), - "file_name": record.get("file_name"), - "file_type": record.get("file_type"), - "file_size": record.get("file_size", 0), - "status": record.get("status", "parsed"), - "uploaded_at": record.get("uploaded_at"), - "truncated": record.get("truncated", False), - } - - -async def save_partial_message(conv_mgr, thread_id, full_msg=None, error_message=None, error_type="interrupted"): - """ - 统一保存AI消息到数据库的函数 - - Args: - conv_mgr: 对话管理器 - thread_id: 线程ID - full_msg: 完整的AI消息对象(可选) - error_message: 纯错误消息文本(当full_msg为空时使用) - error_type: 错误类型标识 - """ - try: - extra_metadata = { - "error_type": error_type, - "is_error": True, - "error_message": error_message or f"发生错误: {error_type}", - } - if full_msg: - # 保存部分生成的AI消息 - msg_dict = full_msg.model_dump() if hasattr(full_msg, "model_dump") else {} - content = full_msg.content if hasattr(full_msg, "content") else str(full_msg) - extra_metadata = msg_dict | extra_metadata - else: - content = "" - - saved_msg = await conv_mgr.add_message_by_thread_id( - thread_id=thread_id, - role="assistant", - content=content, - message_type="text", - extra_metadata=extra_metadata, - ) - - logger.info(f"Saved message due to {error_type}: {saved_msg.id}") - return saved_msg - - except Exception as e: - logger.error(f"Error saving 
message: {e}") - logger.error(traceback.format_exc()) - return None - - -async def save_messages_from_langgraph_state( - agent_instance, - thread_id, - conv_mgr, - config_dict, -): - """ - 从 LangGraph state 中读取完整消息并保存到数据库 - 这样可以获得完整的 tool_calls 参数 - """ - try: - messages = await _get_langgraph_messages(agent_instance, config_dict) - if messages is None: - return - - logger.debug(f"Retrieved {len(messages)} messages from LangGraph state") - existing_ids = await _get_existing_message_ids(conv_mgr, thread_id) - - for msg in messages: - msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else {} - msg_type = msg_dict.get("type", "unknown") - - if msg_type == "human" or msg.id in existing_ids: - continue - - if msg_type == "ai": - await _save_ai_message(conv_mgr, thread_id, msg_dict) - elif msg_type == "tool": - await _save_tool_message(conv_mgr, msg_dict) - else: - logger.warning(f"Unknown message type: {msg_type}, skipping") - continue - - logger.debug(f"Processed message type={msg_type}") - - logger.info("Saved messages from LangGraph state") - - except Exception as e: - logger.error(f"Error saving messages from LangGraph state: {e}") - logger.error(traceback.format_exc()) - - -async def check_and_handle_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): - """检查并处理 LangGraph 中断状态,发送人工审批请求到前端""" - try: - # 获取 agent 的 graph 对象 - graph = await agent.get_graph() - - # 获取当前状态,检查是否有中断 - state = await graph.aget_state(langgraph_config) - - if not state or not state.values: - logger.debug("No state found when checking for interrupts") - return - - # 检查是否有中断信息 - # LangGraph 中断信息通常在 state.tasks 或 __interrupt__ 字段中 - interrupt_info = None - - # 方法1: 检查 state.tasks 中的中断 - if hasattr(state, "tasks") and state.tasks: - for task in state.tasks: - if hasattr(task, "interrupts") and task.interrupts: - interrupt_info = task.interrupts[0] # 取第一个中断 - break - - # 方法2: 检查 state.values 中的 __interrupt__ 字段 - if not interrupt_info and state.values: - interrupt_data = 
state.values.get("__interrupt__") - if interrupt_data and isinstance(interrupt_data, list) and len(interrupt_data) > 0: - interrupt_info = interrupt_data[0] - - # 方法3: 检查 state.next 字段,如果指向中断节点 - if not interrupt_info and hasattr(state, "next") and state.next: - # 如果 next 指向某个需要审批的节点,可能需要额外处理 - logger.debug(f"State next nodes: {state.next}") - - if interrupt_info: - logger.info(f"Human approval interrupt detected: {interrupt_info}") - - # 提取中断信息 - question = "是否批准以下操作?" - operation = "需要人工审批的操作" - - if isinstance(interrupt_info, dict): - question = interrupt_info.get("question", question) - operation = interrupt_info.get("operation", operation) - elif isinstance(interrupt_info, (list, tuple)) and len(interrupt_info) > 0: - # 有些情况下中断信息可能是元组形式 - first_interrupt = interrupt_info[0] - if isinstance(first_interrupt, dict): - question = first_interrupt.get("question", question) - operation = first_interrupt.get("operation", operation) - else: - operation = str(first_interrupt) - else: - operation = str(interrupt_info) - - # 发送人工审批请求到前端 - logger.info(f"Sending human approval request - question: {question}, operation: {operation}") - - yield make_chunk( - status="human_approval_required", - thread_id=thread_id, - interrupt_info={"question": question, "operation": operation}, - ) - - else: - logger.debug("No human approval interrupt detected") - - except Exception as e: - logger.error(f"Error checking for interrupts: {e}") - logger.error(traceback.format_exc()) - # 不抛出异常,避免影响主流程 - - -# ============================================================================= - - @chat.post("/call") async def call(query: str = Body(...), meta: dict = Body(None), current_user: User = Depends(get_required_user)): """调用模型进行简单问答(需要登录)""" @@ -447,6 +189,172 @@ async def get_single_agent(agent_id: str, current_user: User = Depends(get_requi raise HTTPException(status_code=500, detail=f"获取智能体信息出错: {str(e)}") +@chat.get("/agent/{agent_id}/configs") +async def list_agent_configs( + agent_id: str, 
+ current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + if not current_user.department_id: + raise HTTPException(status_code=400, detail="当前用户未绑定部门") + + if not agent_manager.get_agent(agent_id): + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + repo = AgentConfigRepository(db) + items = await repo.list_by_department_agent(department_id=current_user.department_id, agent_id=agent_id) + if not items: + await repo.get_or_create_default( + department_id=current_user.department_id, + agent_id=agent_id, + created_by=str(current_user.id), + ) + items = await repo.list_by_department_agent(department_id=current_user.department_id, agent_id=agent_id) + + configs = [ + { + "id": item.id, + "name": item.name, + "description": item.description, + "icon": item.icon, + "pics": item.pics or [], + "examples": item.examples or [], + "is_default": bool(item.is_default), + } + for item in items + ] + return {"configs": configs} + + +@chat.get("/agent/{agent_id}/configs/{config_id}") +async def get_agent_config_profile( + agent_id: str, + config_id: int, + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + if not current_user.department_id: + raise HTTPException(status_code=400, detail="当前用户未绑定部门") + + if not agent_manager.get_agent(agent_id): + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + repo = AgentConfigRepository(db) + item = await repo.get_by_id(config_id) + if not item or item.agent_id != agent_id or item.department_id != current_user.department_id: + raise HTTPException(status_code=404, detail="配置不存在") + + return {"config": item.to_dict()} + + +@chat.post("/agent/{agent_id}/configs") +async def create_agent_config_profile( + agent_id: str, + payload: AgentConfigCreate, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + if not current_user.department_id: + raise HTTPException(status_code=400, 
detail="当前用户未绑定部门") + + if not agent_manager.get_agent(agent_id): + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + repo = AgentConfigRepository(db) + item = await repo.create( + department_id=current_user.department_id, + agent_id=agent_id, + name=payload.name, + description=payload.description, + icon=payload.icon, + pics=payload.pics, + examples=payload.examples, + config_json=payload.config_json, + is_default=payload.set_default, + created_by=str(current_user.id), + ) + if payload.set_default: + item = await repo.set_default(config=item, updated_by=str(current_user.id)) + + return {"config": item.to_dict()} + + +@chat.put("/agent/{agent_id}/configs/{config_id}") +async def update_agent_config_profile( + agent_id: str, + config_id: int, + payload: AgentConfigUpdate, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + if not current_user.department_id: + raise HTTPException(status_code=400, detail="当前用户未绑定部门") + + if not agent_manager.get_agent(agent_id): + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + repo = AgentConfigRepository(db) + item = await repo.get_by_id(config_id) + if not item or item.agent_id != agent_id or item.department_id != current_user.department_id: + raise HTTPException(status_code=404, detail="配置不存在") + + updated = await repo.update( + item, + name=payload.name, + description=payload.description, + icon=payload.icon, + pics=payload.pics, + examples=payload.examples, + config_json=payload.config_json, + updated_by=str(current_user.id), + ) + return {"config": updated.to_dict()} + + +@chat.post("/agent/{agent_id}/configs/{config_id}/set_default") +async def set_agent_config_default( + agent_id: str, + config_id: int, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + if not current_user.department_id: + raise HTTPException(status_code=400, detail="当前用户未绑定部门") + + if not agent_manager.get_agent(agent_id): + raise 
HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + repo = AgentConfigRepository(db) + item = await repo.get_by_id(config_id) + if not item or item.agent_id != agent_id or item.department_id != current_user.department_id: + raise HTTPException(status_code=404, detail="配置不存在") + + updated = await repo.set_default(config=item, updated_by=str(current_user.id)) + return {"config": updated.to_dict()} + + +@chat.delete("/agent/{agent_id}/configs/{config_id}") +async def delete_agent_config_profile( + agent_id: str, + config_id: int, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + if not current_user.department_id: + raise HTTPException(status_code=400, detail="当前用户未绑定部门") + + if not agent_manager.get_agent(agent_id): + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + repo = AgentConfigRepository(db) + item = await repo.get_by_id(config_id) + if not item or item.agent_id != agent_id or item.department_id != current_user.department_id: + raise HTTPException(status_code=404, detail="配置不存在") + + await repo.delete(config=item, updated_by=str(current_user.id)) + return {"success": True} + + @chat.post("/agent/{agent_id}") async def chat_agent( agent_id: str, @@ -458,8 +366,6 @@ async def chat_agent( db: AsyncSession = Depends(get_db), ): """使用特定智能体进行对话(需要登录)""" - start_time = asyncio.get_event_loop().time() - logger.info(f"agent_id: {agent_id}, query: {query}, config: {config}, meta: {meta}") logger.info(f"image_content present: {image_content is not None}") if image_content: @@ -480,231 +386,18 @@ async def chat_agent( "has_image": bool(image_content), } ) - - # 将meta和thread_id整合到config中 - def make_chunk(content=None, **kwargs): - return ( - json.dumps( - {"request_id": meta.get("request_id"), "response": content, **kwargs}, ensure_ascii=False - ).encode("utf-8") - + b"\n" - ) - - async def stream_messages(): - # 构建多模态消息 - if image_content: - # 多模态消息格式 - human_message = HumanMessage( - content=[ - 
{"type": "text", "text": query}, - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_content}"}}, - ] - ) - message_type = "multimodal_image" - else: - # 普通文本消息 - human_message = HumanMessage(content=query) - message_type = "text" - - # 代表服务端已经收到了请求,发送前端友好的消息格式 - init_msg = {"role": "user", "content": query, "type": "human"} - - # 如果有图片,添加图片相关信息 - if image_content: - init_msg["message_type"] = "multimodal_image" - init_msg["image_content"] = image_content - else: - init_msg["message_type"] = "text" - - yield make_chunk(status="init", meta=meta, msg=init_msg) - - # Input guard - if conf.enable_content_guard and await content_guard.check(query): - yield make_chunk( - status="error", error_type="content_guard_blocked", error_message="输入内容包含敏感词", meta=meta - ) - return - - try: - agent = agent_manager.get_agent(agent_id) - except Exception as e: - logger.error(f"Error getting agent {agent_id}: {e}, {traceback.format_exc()}") - yield make_chunk( - status="error", - error_type="agent_error", - error_message=f"智能体 {agent_id} 获取失败: {str(e)}", - meta=meta, - ) - return - - messages = [human_message] - - # 构造运行时配置,如果没有thread_id则生成一个 - user_id = str(current_user.id) - thread_id = config.get("thread_id") - input_context = {"user_id": user_id, "thread_id": thread_id} - - if not thread_id: - thread_id = str(uuid.uuid4()) - logger.warning(f"No thread_id provided, generated new thread_id: {thread_id}") - - try: - async with db_manager.get_async_session_context() as db: - # Initialize conversation manager - conv_manager = ConversationManager(db) - - # Save user message - try: - await conv_manager.add_message_by_thread_id( - thread_id=thread_id, - role="user", - content=query, - message_type=message_type, - image_content=image_content, - extra_metadata={"raw_message": human_message.model_dump()}, - ) - except Exception as e: - logger.error(f"Error saving user message: {e}") - - try: - assert thread_id, "thread_id is required" - attachments = await 
conv_manager.get_attachments_by_thread_id(thread_id) - input_context["attachments"] = attachments - logger.debug(f"Loaded {len(attachments)} attachments for thread_id={thread_id}") - except Exception as e: - logger.error(f"Error loading attachments for thread_id={thread_id}: {e}") - input_context["attachments"] = [] - - full_msg = None - accumulated_content = [] - langgraph_config = {"configurable": input_context} - async for msg, metadata in agent.stream_messages(messages, input_context=input_context): - if isinstance(msg, AIMessageChunk): - accumulated_content.append(msg.content) - - content_for_check = "".join(accumulated_content[-10:]) - if conf.enable_content_guard and await content_guard.check_with_keywords(content_for_check): - logger.warning("Sensitive content detected in stream") - full_msg = AIMessage(content="".join(accumulated_content)) - await save_partial_message(conv_manager, thread_id, full_msg, "content_guard_blocked") - meta["time_cost"] = asyncio.get_event_loop().time() - start_time - yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta) - return - - yield make_chunk(content=msg.content, msg=msg.model_dump(), metadata=metadata, status="loading") - - else: - msg_dict = msg.model_dump() - yield make_chunk(msg=msg_dict, metadata=metadata, status="loading") - - try: - if msg_dict.get("type") == "tool": - graph = await agent.get_graph() - state = await graph.aget_state(langgraph_config) - agent_state = _extract_agent_state(getattr(state, "values", {})) if state else {} - if agent_state: - yield make_chunk(status="agent_state", agent_state=agent_state, meta=meta) - except Exception as e: - logger.error(f"Error processing tool message: {e}") - pass - - if not full_msg and accumulated_content: - full_msg = AIMessage(content="".join(accumulated_content)) - - if ( - conf.enable_content_guard - and hasattr(full_msg, "content") - and await content_guard.check(full_msg.content) - ): - logger.warning("Sensitive content detected in final 
message") - await save_partial_message(conv_manager, thread_id, full_msg, "content_guard_blocked") - meta["time_cost"] = asyncio.get_event_loop().time() - start_time - yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta) - return - - # After streaming finished, check for interrupts and save messages - - # Check for human approval interrupts - async for chunk in check_and_handle_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): - yield chunk - - meta["time_cost"] = asyncio.get_event_loop().time() - start_time - try: - graph = await agent.get_graph() - state = await graph.aget_state(langgraph_config) - agent_state = _extract_agent_state(getattr(state, "values", {})) if state else {} - except Exception: - agent_state = {} - - if agent_state: - yield make_chunk(status="agent_state", agent_state=agent_state, meta=meta) - - yield make_chunk(status="finished", meta=meta) - - # Save all messages from LangGraph state - await save_messages_from_langgraph_state( - agent_instance=agent, - thread_id=thread_id, - conv_mgr=conv_manager, - config_dict=langgraph_config, - ) - - except (asyncio.CancelledError, ConnectionError) as e: - # 客户端主动中断连接,检查中断并保存已生成的部分内容 - logger.warning(f"Client disconnected, cancelling stream: {e}") - - # Run save in a separate task to avoid cancellation - async def save_cleanup(): - nonlocal full_msg - if not full_msg and accumulated_content: - full_msg = AIMessage(content="".join(accumulated_content)) - - async with db_manager.get_async_session_context() as new_db: - new_conv_manager = ConversationManager(new_db) - await save_partial_message( - new_conv_manager, - thread_id, - full_msg=full_msg, - error_message="对话已中断" if not full_msg else None, - error_type="interrupted", - ) - - # Create a task and await it, shielding it from cancellation - # ensuring the DB operation completes even if the stream is cancelled - cleanup_task = asyncio.create_task(save_cleanup()) - try: - await asyncio.shield(cleanup_task) - except 
asyncio.CancelledError: - pass - except Exception as exc: - logger.error(f"Error during cleanup save: {exc}") - - # 通知前端中断(可能发送不到,但用于一致性) - yield make_chunk(status="interrupted", message="对话已中断", meta=meta) - - except Exception as e: - logger.error(f"Error streaming messages: {e}, {traceback.format_exc()}") - - error_msg = f"Error streaming messages: {e}" - error_type = "unexpected_error" - - if not full_msg and accumulated_content: - full_msg = AIMessage(content="".join(accumulated_content)) - - # 保存错误消息到数据库 - async with db_manager.get_async_session_context() as new_db: - new_conv_manager = ConversationManager(new_db) - await save_partial_message( - new_conv_manager, - thread_id, - full_msg=full_msg, - error_message=error_msg, - error_type=error_type, - ) - - yield make_chunk(status="error", error_type=error_type, error_message=error_msg, meta=meta) - - return StreamingResponse(stream_messages(), media_type="application/json") + return StreamingResponse( + stream_agent_chat( + agent_id=agent_id, + query=query, + config=config, + meta=meta, + image_content=image_content, + current_user=current_user, + db=db, + ), + media_type="application/json", + ) # ============================================================================= @@ -732,11 +425,11 @@ async def resume_agent_chat( agent_id: str, thread_id: str = Body(...), approved: bool = Body(...), + config: dict = Body({}), current_user: User = Depends(get_required_user), db: AsyncSession = Depends(get_db), ): """恢复被人工审批中断的对话(需要登录)""" - start_time = asyncio.get_event_loop().time() logger.info(f"Resuming agent_id: {agent_id}, thread_id: {thread_id}, approved: {approved}") meta = { @@ -747,103 +440,18 @@ async def resume_agent_chat( } if "request_id" not in meta or not meta.get("request_id"): meta["request_id"] = str(uuid.uuid4()) - - async def stream_resume(): - # 定义resume专用的make_chunk函数,与主聊天端点保持一致 - def make_resume_chunk(content=None, **kwargs): - return ( - json.dumps( - {"request_id": meta.get("request_id"), 
"response": content, **kwargs}, ensure_ascii=False - ).encode("utf-8") - + b"\n" - ) - - try: - agent = agent_manager.get_agent(agent_id) - except Exception as e: - logger.error(f"Error getting agent {agent_id}: {e}, {traceback.format_exc()}") - yield ( - f'{{"request_id": "{meta.get("request_id")}", "message": ' - f'"Error getting agent {agent_id}: {e}", "status": "error"}}\n' - ) - return - - # 发送init状态块,与主聊天端点保持一致 - init_msg = {"type": "system", "content": f"Resume with approved: {approved}"} - yield make_resume_chunk(status="init", meta=meta, msg=init_msg) - - # 使用 Command(resume=approved) 恢复执行 - resume_command = Command(resume=approved) - graph = await agent.get_graph() - - # 加载 context(包含 tools, model 等配置) - input_context = {"user_id": str(current_user.id), "thread_id": thread_id} - context = agent.context_schema.from_file(module_name=agent.module_name, input_context=input_context) - logger.debug(f"Resume with context: {context}") - - # 创建流式数据源 - stream_source = graph.astream( - resume_command, context=context, config={"configurable": input_context}, stream_mode="messages" - ) - - try: - async with db_manager.get_async_session_context() as db: - async for msg, metadata in stream_source: - # 确保msg有正确的ID结构 - msg_dict = msg.model_dump() - if "id" not in msg_dict: - msg_dict["id"] = str(uuid.uuid4()) - - yield make_resume_chunk( - content=getattr(msg, "content", ""), msg=msg_dict, metadata=metadata, status="loading" - ) - - # Check for new interrupts (support multiple human in the loop) - langgraph_config = {"configurable": input_context} - async for chunk in check_and_handle_interrupts( - agent, langgraph_config, make_resume_chunk, meta, thread_id - ): - yield chunk - - meta["time_cost"] = asyncio.get_event_loop().time() - start_time - yield make_resume_chunk(status="finished", meta=meta) - - # 保存消息到数据库 - conv_manager = ConversationManager(db) - await save_messages_from_langgraph_state( - agent_instance=agent, - thread_id=thread_id, - conv_mgr=conv_manager, - 
config_dict=langgraph_config, - ) - - except (asyncio.CancelledError, ConnectionError) as e: - # 客户端主动中断连接 - logger.warning(f"Client disconnected during resume: {e}") - - # 保存中断消息到数据库 - async with db_manager.get_async_session_context() as new_db: - new_conv_manager = ConversationManager(new_db) - await save_partial_message( - new_conv_manager, thread_id, error_message="对话恢复已中断", error_type="resume_interrupted" - ) - - yield make_resume_chunk(status="interrupted", message="对话恢复已中断", meta=meta) - - except Exception as e: - # 处理其他异常 - logger.error(f"Error during resume: {e}, {traceback.format_exc()}") - - # 保存错误消息到数据库 - async with db_manager.get_async_session_context() as new_db: - new_conv_manager = ConversationManager(new_db) - await save_partial_message( - new_conv_manager, thread_id, error_message=f"Error during resume: {e}", error_type="resume_error" - ) - - yield make_resume_chunk(message=f"Error during resume: {e}", status="error") - - return StreamingResponse(stream_resume(), media_type="application/json") + return StreamingResponse( + stream_agent_resume( + agent_id=agent_id, + thread_id=thread_id, + approved=approved, + meta=meta, + config=config, + current_user=current_user, + db=db, + ), + media_type="application/json", + ) @chat.post("/agent/{agent_id}/config") @@ -859,6 +467,35 @@ async def save_agent_config( if not (agent := agent_manager.get_agent(agent_id)): raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + # === 校验知识库权限 === + from src import knowledge_base + + if "knowledges" in config and config["knowledges"]: + # 获取用户有权访问的知识库名称 + try: + user_info = {"role": current_user.role, "department_id": current_user.department_id} + accessible_databases = await knowledge_base.get_databases_by_user(user_info) + accessible_kb_names = { + db.get("name") for db in accessible_databases.get("databases", []) if db.get("name") + } + except Exception as db_error: + logger.warning(f"获取知识库列表失败: {db_error}") + # 如果获取失败,superadmin 可以访问所有,非 superadmin 
无法访问任何 + if current_user.role != "superadmin": + raise HTTPException(status_code=500, detail="无法获取知识库列表") + # 回退:获取所有数据库名称 + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + kb_repo = KnowledgeBaseRepository() + rows = await kb_repo.get_all() + accessible_kb_names = {row.name for row in rows if row.name} + + # 检查配置中的知识库是否都可用 + invalid_kbs = [kb for kb in config["knowledges"] if kb not in accessible_kb_names] + if invalid_kbs: + raise HTTPException(status_code=403, detail=f"无权访问以下知识库: {', '.join(invalid_kbs)}") + # === 校验结束 === + # 使用配置类的save_to_file方法保存配置 result = agent.context_schema.save_to_file(config, agent.module_name) @@ -869,6 +506,8 @@ async def save_agent_config( else: raise HTTPException(status_code=500, detail="保存智能体配置失败") + except HTTPException: + raise except Exception as e: logger.error(f"保存智能体配置出错: {e}, {traceback.format_exc()}") raise HTTPException(status_code=500, detail=f"保存智能体配置出错: {str(e)}") @@ -880,69 +519,12 @@ async def get_agent_history( ): """获取智能体历史消息(需要登录)- 包含用户反馈状态""" try: - # 获取Agent实例验证 - if not agent_manager.get_agent(agent_id): - raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") - - # Use new storage system ONLY - conv_manager = ConversationManager(db) - await _require_user_conversation(conv_manager, thread_id, str(current_user.id)) - messages = await conv_manager.get_messages_by_thread_id(thread_id) - - # 当前用户ID - 用于过滤反馈 - current_user_id = str(current_user.id) - - # Convert to frontend-compatible format - history = [] - for msg in messages: - # Map role to type that frontend expects - role_type_map = {"user": "human", "assistant": "ai", "tool": "tool", "system": "system"} - - # 查找当前用户的反馈 - user_feedback = None - if msg.feedbacks: - for feedback in msg.feedbacks: - if feedback.user_id == current_user_id: - user_feedback = { - "id": feedback.id, - "rating": feedback.rating, - "reason": feedback.reason, - "created_at": feedback.created_at.isoformat() if feedback.created_at else None, 
- } - break - - msg_dict = { - "id": msg.id, # Include message ID for feedback - "type": role_type_map.get(msg.role, msg.role), # human/ai/tool/system - "content": msg.content, - "created_at": msg.created_at.isoformat() if msg.created_at else None, - "error_type": msg.extra_metadata.get("error_type") if msg.extra_metadata else None, - "error_message": msg.extra_metadata.get("error_message") if msg.extra_metadata else None, - "extra_metadata": msg.extra_metadata, # 保留完整的metadata以备前端需要 - "message_type": msg.message_type, # 添加消息类型字段 - "image_content": msg.image_content, # 添加图片内容字段 - "feedback": user_feedback, # 添加当前用户反馈状态 - } - - # Add tool calls if present (for AI messages) - if msg.tool_calls and len(msg.tool_calls) > 0: - msg_dict["tool_calls"] = [ - { - "id": str(tc.id), - "name": tc.tool_name, - "function": {"name": tc.tool_name}, - "args": tc.tool_input or {}, - "tool_call_result": {"content": (tc.tool_output or "")} if tc.status == "success" else None, - "status": tc.status, - "error_message": tc.error_message, - } - for tc in msg.tool_calls - ] - - history.append(msg_dict) - - logger.info(f"Loaded {len(history)} messages with feedback for thread {thread_id}") - return {"history": history} + return await get_agent_history_view( + agent_id=agent_id, + thread_id=thread_id, + current_user_id=str(current_user.id), + db=db, + ) except Exception as e: logger.error(f"获取智能体历史消息出错: {e}, {traceback.format_exc()}") @@ -958,19 +540,12 @@ async def get_agent_state( ): """获取智能体当前状态(需要登录)""" try: - if not agent_manager.get_agent(agent_id): - raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") - - conv_manager = ConversationManager(db) - await _require_user_conversation(conv_manager, thread_id, str(current_user.id)) - - agent = agent_manager.get_agent(agent_id) - graph = await agent.get_graph() - langgraph_config = {"configurable": {"user_id": str(current_user.id), "thread_id": thread_id}} - state = await graph.aget_state(langgraph_config) - agent_state = 
_extract_agent_state(getattr(state, "values", {})) if state else {} - - return {"agent_state": agent_state} + return await get_agent_state_view( + agent_id=agent_id, + thread_id=thread_id, + current_user_id=str(current_user.id), + db=db, + ) except HTTPException: raise except Exception as e: @@ -1043,59 +618,21 @@ async def create_thread( thread: ThreadCreate, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_required_user) ): """创建新对话线程 (使用新存储系统)""" - thread_id = str(uuid.uuid4()) - logger.debug(f"thread.agent_id: {thread.agent_id}") - - # Create conversation using new storage system - conv_manager = ConversationManager(db) - conversation = await conv_manager.create_conversation( - user_id=str(current_user.id), + return await create_thread_view( agent_id=thread.agent_id, - title=thread.title or "新的对话", - thread_id=thread_id, + title=thread.title, metadata=thread.metadata, + db=db, + current_user_id=str(current_user.id), ) - logger.info(f"Created conversation with thread_id: {thread_id}") - - return { - "id": conversation.thread_id, - "user_id": conversation.user_id, - "agent_id": conversation.agent_id, - "title": conversation.title, - "created_at": conversation.created_at.isoformat(), - "updated_at": conversation.updated_at.isoformat(), - } - @chat.get("/threads", response_model=list[ThreadResponse]) async def list_threads( agent_id: str, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_required_user) ): """获取用户的所有对话线程 (使用新存储系统)""" - assert agent_id, "agent_id 不能为空" - - logger.debug(f"agent_id: {agent_id}") - - # Use new storage system - conv_manager = ConversationManager(db) - conversations = await conv_manager.list_conversations( - user_id=str(current_user.id), - agent_id=agent_id, - status="active", - ) - - return [ - { - "id": conv.thread_id, - "user_id": conv.user_id, - "agent_id": conv.agent_id, - "title": conv.title, - "created_at": conv.created_at.isoformat(), - "updated_at": conv.updated_at.isoformat(), - } - for 
conv in conversations - ] + return await list_threads_view(agent_id=agent_id, db=db, current_user_id=str(current_user.id)) @chat.delete("/thread/{thread_id}") @@ -1103,20 +640,7 @@ async def delete_thread( thread_id: str, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_required_user) ): """删除对话线程 (使用新存储系统)""" - # Use new storage system - conv_manager = ConversationManager(db) - conversation = await conv_manager.get_conversation_by_thread_id(thread_id) - - if not conversation or conversation.user_id != str(current_user.id): - raise HTTPException(status_code=404, detail="对话线程不存在") - - # Soft delete - success = await conv_manager.delete_conversation(thread_id, soft_delete=True) - - if not success: - raise HTTPException(status_code=500, detail="删除失败") - - return {"message": "删除成功"} + return await delete_thread_view(thread_id=thread_id, db=db, current_user_id=str(current_user.id)) class ThreadUpdate(BaseModel): @@ -1131,31 +655,13 @@ async def update_thread( current_user: User = Depends(get_required_user), ): """更新对话线程信息 (使用新存储系统)""" - # Use new storage system - conv_manager = ConversationManager(db) - conversation = await conv_manager.get_conversation_by_thread_id(thread_id) - - if not conversation or conversation.user_id != str(current_user.id) or conversation.status == "deleted": - raise HTTPException(status_code=404, detail="对话线程不存在") - - # Update conversation - updated_conv = await conv_manager.update_conversation( + return await update_thread_view( thread_id=thread_id, title=thread_update.title, + db=db, + current_user_id=str(current_user.id), ) - if not updated_conv: - raise HTTPException(status_code=500, detail="更新失败") - - return { - "id": updated_conv.thread_id, - "user_id": updated_conv.user_id, - "agent_id": updated_conv.agent_id, - "title": updated_conv.title, - "created_at": updated_conv.created_at.isoformat(), - "updated_at": updated_conv.updated_at.isoformat(), - } - @chat.post("/thread/{thread_id}/attachments", 
response_model=AttachmentResponse) async def upload_thread_attachment( @@ -1165,30 +671,12 @@ async def upload_thread_attachment( current_user: User = Depends(get_required_user), ): """上传并解析附件为 Markdown,附加到指定对话线程。""" - conv_manager = ConversationManager(db) - conversation = await _require_user_conversation(conv_manager, thread_id, str(current_user.id)) - - try: - conversion = await convert_upload_to_markdown(file) - except ValueError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - except Exception as exc: # noqa: BLE001 - logger.error(f"附件解析失败: {exc}") - raise HTTPException(status_code=500, detail="附件解析失败,请稍后重试") from exc - - attachment_record = { - "file_id": conversion.file_id, - "file_name": conversion.file_name, - "file_type": conversion.file_type, - "file_size": conversion.file_size, - "status": "parsed", - "markdown": conversion.markdown, - "uploaded_at": utc_isoformat(), - "truncated": conversion.truncated, - } - await conv_manager.add_attachment(conversation.id, attachment_record) - - return _serialize_attachment(attachment_record) + return await upload_thread_attachment_view( + thread_id=thread_id, + file=file, + db=db, + current_user_id=str(current_user.id), + ) @chat.get("/thread/{thread_id}/attachments", response_model=AttachmentListResponse) @@ -1198,16 +686,11 @@ async def list_thread_attachments( current_user: User = Depends(get_required_user), ): """列出当前对话线程的所有附件元信息。""" - conv_manager = ConversationManager(db) - conversation = await _require_user_conversation(conv_manager, thread_id, str(current_user.id)) - attachments = await conv_manager.get_attachments(conversation.id) - return { - "attachments": [_serialize_attachment(item) for item in attachments], - "limits": { - "allowed_extensions": sorted(ATTACHMENT_ALLOWED_EXTENSIONS), - "max_size_bytes": MAX_ATTACHMENT_SIZE_BYTES, - }, - } + return await list_thread_attachments_view( + thread_id=thread_id, + db=db, + current_user_id=str(current_user.id), + ) 
@chat.delete("/thread/{thread_id}/attachments/{file_id}") @@ -1218,12 +701,12 @@ async def delete_thread_attachment( current_user: User = Depends(get_required_user), ): """移除指定附件。""" - conv_manager = ConversationManager(db) - conversation = await _require_user_conversation(conv_manager, thread_id, str(current_user.id)) - removed = await conv_manager.remove_attachment(conversation.id, file_id) - if not removed: - raise HTTPException(status_code=404, detail="附件不存在或已被删除") - return {"message": "附件已删除"} + return await delete_thread_attachment_view( + thread_id=thread_id, + file_id=file_id, + db=db, + current_user_id=str(current_user.id), + ) # ============================================================================= @@ -1252,61 +735,14 @@ async def submit_message_feedback( current_user: User = Depends(get_required_user), ): """提交消息反馈(需要登录)""" - try: - # Validate rating - if feedback_data.rating not in ["like", "dislike"]: - raise HTTPException(status_code=422, detail="Rating must be 'like' or 'dislike'") - - # Verify message exists and get conversation to check permissions - message_result = await db.execute(select(Message).filter_by(id=message_id)) - message = message_result.scalar_one_or_none() - - if not message: - raise HTTPException(status_code=404, detail="Message not found") - - # Verify user has access to this message (through conversation) - conversation_result = await db.execute(select(Conversation).filter_by(id=message.conversation_id)) - conversation = conversation_result.scalar_one_or_none() - if not conversation or conversation.user_id != str(current_user.id): - raise HTTPException(status_code=403, detail="Access denied") - - # Check if feedback already exists (user can only submit once) - existing_feedback_result = await db.execute( - select(MessageFeedback).filter_by(message_id=message_id, user_id=str(current_user.id)) - ) - existing_feedback = existing_feedback_result.scalar_one_or_none() - - if existing_feedback: - raise 
HTTPException(status_code=409, detail="Feedback already submitted for this message") - - # Create new feedback - new_feedback = MessageFeedback( - message_id=message_id, - user_id=str(current_user.id), - rating=feedback_data.rating, - reason=feedback_data.reason, - ) - - db.add(new_feedback) - await db.commit() - await db.refresh(new_feedback) - - logger.info(f"User {current_user.id} submitted {feedback_data.rating} feedback for message {message_id}") - - return MessageFeedbackResponse( - id=new_feedback.id, - message_id=new_feedback.message_id, - rating=new_feedback.rating, - reason=new_feedback.reason, - created_at=new_feedback.created_at.isoformat(), - ) - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error submitting message feedback: {e}, {traceback.format_exc()}") - await db.rollback() - raise HTTPException(status_code=500, detail=f"Failed to submit feedback: {str(e)}") + result = await submit_message_feedback_view( + message_id=message_id, + rating=feedback_data.rating, + reason=feedback_data.reason, + db=db, + current_user_id=str(current_user.id), + ) + return MessageFeedbackResponse(**result) @chat.get("/message/{message_id}/feedback") @@ -1316,29 +752,11 @@ async def get_message_feedback( current_user: User = Depends(get_required_user), ): """获取指定消息的用户反馈(需要登录)""" - try: - # Get user's feedback for this message - feedback_result = await db.execute( - select(MessageFeedback).filter_by(message_id=message_id, user_id=str(current_user.id)) - ) - feedback = feedback_result.scalar_one_or_none() - - if not feedback: - return {"has_feedback": False, "feedback": None} - - return { - "has_feedback": True, - "feedback": { - "id": feedback.id, - "rating": feedback.rating, - "reason": feedback.reason, - "created_at": feedback.created_at.isoformat(), - }, - } - - except Exception as e: - logger.error(f"Error getting message feedback: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get feedback: {str(e)}") + return await 
get_message_feedback_view( + message_id=message_id, + db=db, + current_user_id=str(current_user.id), + ) # ============================================================================= diff --git a/server/routers/dashboard_router.py b/server/routers/dashboard_router.py index 8d04cb11f..13eb9ccb3 100644 --- a/server/routers/dashboard_router.py +++ b/server/routers/dashboard_router.py @@ -8,16 +8,17 @@ import traceback from datetime import datetime, timedelta +from typing import Any from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel -from sqlalchemy import String, cast, distinct, func, or_, select +from sqlalchemy import Integer, String, cast, distinct, func, or_, select, text from sqlalchemy.ext.asyncio import AsyncSession from server.routers.auth_router import get_admin_user from server.utils.auth_middleware import get_db -from src.storage.conversation import ConversationManager -from src.storage.db.models import User +from src.repositories.conversation_repository import ConversationRepository +from src.storage.postgres.models_business import User from src.utils.datetime_utils import UTC, ensure_shanghai, shanghai_now, utc_now from src.utils.logging_config import logger @@ -25,6 +26,25 @@ dashboard = APIRouter(prefix="/dashboard", tags=["Dashboard"]) +def _get_time_group_format(column, time_range: str) -> Any: + """ + 根据数据库类型生成时间分组格式化表达式。 + PostgreSQL 使用 to_char + INTERVAL,SQLite 使用 datetime + strftime。 + """ + # 检查是否是 PostgreSQL(通过检测 engine 或使用方言) + # 这里直接使用 PostgreSQL 语法,因为所有业务数据现在都在 PostgreSQL 上 + if time_range == "14hours": + # 每小时: YYYY-MM-DD HH:00 + time_expr = func.to_char(column + text("INTERVAL '8 hours'"), "YYYY-MM-DD HH24:00") + elif time_range == "14weeks": + # 每周: YYYY-WW + time_expr = func.to_char(column + text("INTERVAL '8 hours'"), "YYYY-IW") + else: # 14days + # 每天: YYYY-MM-DD + time_expr = func.to_char(column + text("INTERVAL '8 hours'"), "YYYY-MM-DD") + return time_expr + + # 
============================================================================= # Response Models # ============================================================================= @@ -116,7 +136,7 @@ async def get_all_conversations( current_user: User = Depends(get_admin_user), ): """获取所有对话(管理员权限)""" - from src.storage.db.models import Conversation, ConversationStats + from src.storage.postgres.models_business import Conversation, ConversationStats try: # Build query @@ -165,7 +185,7 @@ async def get_conversation_detail( ): """获取指定对话详情(管理员权限)""" try: - conv_manager = ConversationManager(db) + conv_manager = ConversationRepository(db) conversation = await conv_manager.get_conversation_by_thread_id(thread_id) if not conversation: @@ -233,9 +253,11 @@ async def get_user_activity_stats( ): """获取用户活动统计(管理员权限)""" try: - from src.storage.db.models import User, Conversation + from src.storage.postgres.models_business import Conversation, User now = utc_now() + # PostgreSQL with asyncpg requires naive datetime for naive DateTime columns + naive_now = now.replace(tzinfo=None) # Conversations may store either the numeric user primary key or the login user_id string. # Join condition accounts for both representations. 
@@ -253,7 +275,7 @@ async def get_user_activity_stats( select(func.count(distinct(User.id))) .select_from(Conversation) .join(User, user_join_condition) - .filter(Conversation.updated_at >= now - timedelta(days=1), User.is_deleted == 0) + .filter(Conversation.updated_at >= naive_now - timedelta(days=1), User.is_deleted == 0) ) active_users_24h = active_users_24h_result.scalar() or 0 @@ -261,14 +283,14 @@ async def get_user_activity_stats( select(func.count(distinct(User.id))) .select_from(Conversation) .join(User, user_join_condition) - .filter(Conversation.updated_at >= now - timedelta(days=30), User.is_deleted == 0) + .filter(Conversation.updated_at >= naive_now - timedelta(days=30), User.is_deleted == 0) ) active_users_30d = active_users_30d_result.scalar() or 0 # 最近7天每日活跃用户(排除已删除用户) daily_active_users = [] for i in range(7): - day_start = now - timedelta(days=i + 1) - day_end = now - timedelta(days=i) + day_start = naive_now - timedelta(days=i + 1) + day_end = naive_now - timedelta(days=i) active_count_result = await db.execute( select(func.count(distinct(User.id))) @@ -305,9 +327,11 @@ async def get_tool_call_stats( ): """获取工具调用统计(管理员权限)""" try: - from src.storage.db.models import ToolCall + from src.storage.postgres.models_business import ToolCall now = utc_now() + # PostgreSQL with asyncpg requires naive datetime for naive DateTime columns + naive_now = now.replace(tzinfo=None) # 基础工具调用统计 total_calls_result = await db.execute(select(func.count(ToolCall.id))) @@ -340,8 +364,8 @@ async def get_tool_call_stats( # 最近7天每日工具调用数 daily_tool_calls = [] for i in range(7): - day_start = now - timedelta(days=i + 1) - day_end = now - timedelta(days=i) + day_start = naive_now - timedelta(days=i + 1) + day_end = naive_now - timedelta(days=i) daily_count_result = await db.execute( select(func.count(ToolCall.id)).filter(ToolCall.created_at >= day_start, ToolCall.created_at < day_end) @@ -378,108 +402,66 @@ async def get_knowledge_stats( ): """获取知识库统计(管理员权限)""" try: - from 
src.knowledge.manager import KnowledgeBaseManager - import json - import os - - # 从知识库管理系统获取数据 - kb_manager = KnowledgeBaseManager(work_dir="/app/saves/knowledge_base_data") - - # 读取全局元数据文件 - metadata_file = "/app/saves/knowledge_base_data/global_metadata.json" - if os.path.exists(metadata_file): - with open(metadata_file, encoding="utf-8") as f: - global_metadata = json.load(f) - - databases = global_metadata.get("databases", {}) - total_databases = len(databases) - - # 统计不同类型的知识库 - databases_by_type = {} - files_by_type = {} - total_files = 0 - total_nodes = 0 - total_storage_size = 0 - - # 文件类型映射到中文友好名称 - file_type_mapping = { - "txt": "文本文件", - "pdf": "PDF文档", - "docx": "Word文档", - "doc": "Word文档", - "md": "Markdown", - "html": "HTML网页", - "htm": "HTML网页", - "json": "JSON数据", - "csv": "CSV表格", - "xlsx": "Excel表格", - "xls": "Excel表格", - "pptx": "PowerPoint", - "ppt": "PowerPoint", - "png": "PNG图片", - "jpg": "JPEG图片", - "jpeg": "JPEG图片", - "gif": "GIF图片", - "svg": "SVG图片", - "mp4": "MP4视频", - "mp3": "MP3音频", - "zip": "ZIP压缩包", - "rar": "RAR压缩包", - "7z": "7Z压缩包", - } - - # 统计文件:改为基于各知识库实现中的 files_meta,更加准确 - # 注意:部分记录可能来源于 URL,此时无法统计物理大小 - for kb_instance in kb_manager.kb_instances.values(): - files_meta = getattr(kb_instance, "files_meta", {}) or {} - total_files += len(files_meta) - - for _fid, finfo in files_meta.items(): - file_ext = (finfo.get("file_type") or "").lower() - # 统一映射显示名 - display_name = file_type_mapping.get(file_ext, file_ext.upper() + "文件" if file_ext else "其他") - files_by_type[display_name] = files_by_type.get(display_name, 0) + 1 - - # 估算大小(如果路径存在且是本地文件) - path = finfo.get("path") or "" - try: - if path and os.path.exists(path) and os.path.isfile(path): - total_storage_size += os.path.getsize(path) - except Exception: - # 忽略无法访问的路径 - pass - - # 统计知识库类型分布 - for kb_id, kb_info in databases.items(): - kb_type = kb_info.get("kb_type", "unknown") - display_type = { - "lightrag": "LightRAG", - "faiss": "FAISS", - "milvus": "Milvus", - "qdrant": 
"Qdrant", - "elasticsearch": "Elasticsearch", - "unknown": "未知类型", - }.get(kb_type.lower(), kb_type) - databases_by_type[display_type] = databases_by_type.get(display_type, 0) + 1 - - # 尝试从各个知识库系统获取更详细的统计 - try: - kb_instance = kb_manager.get_kb(kb_id) - if kb_instance and hasattr(kb_instance, "get_stats"): - stats = kb_instance.get_stats() - total_nodes += stats.get("node_count", 0) - except Exception as e: - logger.warning(f"Failed to get stats for KB {kb_id}: {e}") - continue + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + + kb_repo = KnowledgeBaseRepository() + file_repo = KnowledgeFileRepository() + + kb_rows = await kb_repo.get_all() + total_databases = len(kb_rows) + + databases_by_type: dict[str, int] = {} + files_by_type: dict[str, int] = {} + total_files = 0 + total_nodes = 0 + total_storage_size = 0 + + file_type_mapping = { + "txt": "文本文件", + "pdf": "PDF文档", + "docx": "Word文档", + "doc": "Word文档", + "md": "Markdown", + "html": "HTML网页", + "htm": "HTML网页", + "json": "JSON数据", + "csv": "CSV表格", + "xlsx": "Excel表格", + "xls": "Excel表格", + "pptx": "PowerPoint", + "ppt": "PowerPoint", + "png": "PNG图片", + "jpg": "JPEG图片", + "jpeg": "JPEG图片", + "gif": "GIF图片", + "svg": "SVG图片", + "mp4": "MP4视频", + "mp3": "MP3音频", + "zip": "ZIP压缩包", + "rar": "RAR压缩包", + "7z": "7Z压缩包", + } - else: - # 如果没有元数据文件,返回空数据 - total_databases = 0 - total_files = 0 - total_nodes = 0 - total_storage_size = 0 - databases_by_type = {} - files_by_type = {} + for kb in kb_rows: + kb_type = (kb.kb_type or "unknown").lower() + display_type = { + "lightrag": "LightRAG", + "faiss": "FAISS", + "milvus": "Milvus", + "qdrant": "Qdrant", + "elasticsearch": "Elasticsearch", + "unknown": "未知类型", + }.get(kb_type, kb.kb_type or "未知类型") + databases_by_type[display_type] = databases_by_type.get(display_type, 0) + 1 + + files = await file_repo.list_by_db_id(kb.db_id) + total_files += len(files) 
+ for record in files: + file_ext = (record.file_type or "").lower() + display_name = file_type_mapping.get(file_ext, file_ext.upper() + "文件" if file_ext else "其他") + files_by_type[display_name] = files_by_type.get(display_name, 0) + 1 + total_storage_size += int(record.file_size or 0) return KnowledgeStats( total_databases=total_databases, @@ -508,7 +490,7 @@ async def get_agent_analytics( ): """获取智能体分析(管理员权限)""" try: - from src.storage.db.models import Conversation, MessageFeedback, Message, ToolCall + from src.storage.postgres.models_business import Conversation, Message, MessageFeedback, ToolCall # 获取所有智能体 agents_result = await db.execute( @@ -604,7 +586,7 @@ async def get_dashboard_stats( current_user: User = Depends(get_admin_user), ): """获取基础统计(管理员权限)""" - from src.storage.db.models import Conversation, Message, MessageFeedback + from src.storage.postgres.models_business import Conversation, Message, MessageFeedback try: # Basic counts @@ -678,7 +660,7 @@ async def get_all_feedbacks( current_user: User = Depends(get_admin_user), ): """获取所有反馈记录(管理员权限)""" - from src.storage.db.models import MessageFeedback, Message, Conversation, User + from src.storage.postgres.models_business import Conversation, Message, MessageFeedback, User try: # Build query with joins including User table @@ -689,7 +671,7 @@ async def get_all_feedbacks( .join(Conversation, Message.conversation_id == Conversation.id) .outerjoin( User, - (MessageFeedback.user_id == User.id) | (MessageFeedback.user_id == User.user_id), + (MessageFeedback.user_id == cast(User.id, String)) | (MessageFeedback.user_id == User.user_id), ) ) @@ -756,7 +738,7 @@ async def get_call_timeseries_stats( ): """获取调用分析时间序列统计(管理员权限)""" try: - from src.storage.db.models import Conversation, Message, ToolCall + from src.storage.postgres.models_business import Conversation, Message, ToolCall # 计算时间范围(使用北京时间 UTC+8) now = utc_now() @@ -766,7 +748,7 @@ async def get_call_timeseries_stats( intervals = 14 # 包含当前小时:从13小时前开始 
start_time = now - timedelta(hours=intervals - 1) - group_format = func.strftime("%Y-%m-%d %H:00", func.datetime(Message.created_at, "+8 hours")) + group_format = _get_time_group_format(Message.created_at, time_range) base_local_time = ensure_shanghai(start_time) elif time_range == "14weeks": intervals = 14 @@ -775,40 +757,40 @@ async def get_call_timeseries_stats( local_start = local_start - timedelta(days=local_start.weekday()) local_start = local_start.replace(hour=0, minute=0, second=0, microsecond=0) start_time = local_start.astimezone(UTC) - group_format = func.strftime("%Y-%W", func.datetime(Message.created_at, "+8 hours")) + group_format = _get_time_group_format(Message.created_at, time_range) base_local_time = local_start else: # 14days (default) intervals = 14 # 包含当前天:从13天前开始 start_time = now - timedelta(days=intervals - 1) - group_format = func.strftime("%Y-%m-%d", func.datetime(Message.created_at, "+8 hours")) + group_format = _get_time_group_format(Message.created_at, time_range) base_local_time = ensure_shanghai(start_time) + # Convert start_time to naive UTC datetime for PostgreSQL query + # PostgreSQL with asyncpg and naive DateTime columns requires naive datetime objects + query_start_time = start_time.replace(tzinfo=None) + # 根据类型查询数据 if type == "models": # 模型调用统计(基于消息数量,按模型分组) # 从message的extra_metadata中提取模型信息 + category_expr = cast(Message.extra_metadata["response_metadata"]["model_name"], String) query_result = await db.execute( select( group_format.label("date"), func.count(Message.id).label("count"), - func.json_extract(Message.extra_metadata, "$.response_metadata.model_name").label("category"), + category_expr.label("category"), ) - .filter(Message.role == "assistant", Message.created_at >= start_time) + .filter(Message.role == "assistant", Message.created_at >= query_start_time) .filter(Message.extra_metadata.isnot(None)) - .group_by(group_format, func.json_extract(Message.extra_metadata, "$.response_metadata.model_name")) + 
.group_by(group_format, category_expr) .order_by(group_format) ) query = query_result.all() elif type == "agents": # 智能体调用统计(基于对话更新时间,按智能体分组) - # 为对话创建独立的时间格式化器 - if time_range == "14hours": - conv_group_format = func.strftime("%Y-%m-%d %H:00", func.datetime(Conversation.updated_at, "+8 hours")) - elif time_range == "14weeks": - conv_group_format = func.strftime("%Y-%W", func.datetime(Conversation.updated_at, "+8 hours")) - else: # 14days - conv_group_format = func.strftime("%Y-%m-%d", func.datetime(Conversation.updated_at, "+8 hours")) + # 为对话创建独立的时间格式化器(使用 PostgreSQL 兼容的 to_char + INTERVAL) + conv_group_format = _get_time_group_format(Conversation.updated_at, time_range) query_result = await db.execute( select( @@ -817,7 +799,7 @@ async def get_call_timeseries_stats( Conversation.agent_id.label("category"), ) .filter(Conversation.updated_at.isnot(None)) - .filter(Conversation.updated_at >= start_time) + .filter(Conversation.updated_at >= query_start_time) .group_by(conv_group_format, Conversation.agent_id) .order_by(conv_group_format) ) @@ -831,14 +813,16 @@ async def get_call_timeseries_stats( select( group_format.label("date"), func.sum( - func.coalesce(func.json_extract(Message.extra_metadata, "$.usage_metadata.input_tokens"), 0) + func.coalesce( + cast(cast(Message.extra_metadata["usage_metadata"]["input_tokens"], String), Integer), 0 + ) ).label("count"), literal("input_tokens").label("category"), ) .filter( - Message.created_at >= start_time, + Message.created_at >= query_start_time, Message.extra_metadata.isnot(None), - func.json_extract(Message.extra_metadata, "$.usage_metadata").isnot(None), + Message.extra_metadata["usage_metadata"].isnot(None), ) .group_by(group_format) .order_by(group_format) @@ -850,14 +834,16 @@ async def get_call_timeseries_stats( select( group_format.label("date"), func.sum( - func.coalesce(func.json_extract(Message.extra_metadata, "$.usage_metadata.output_tokens"), 0) + func.coalesce( + 
cast(cast(Message.extra_metadata["usage_metadata"]["output_tokens"], String), Integer), 0 + ) ).label("count"), literal("output_tokens").label("category"), ) .filter( - Message.created_at >= start_time, + Message.created_at >= query_start_time, Message.extra_metadata.isnot(None), - func.json_extract(Message.extra_metadata, "$.usage_metadata").isnot(None), + Message.extra_metadata["usage_metadata"].isnot(None), ) .group_by(group_format) .order_by(group_format) @@ -870,13 +856,8 @@ async def get_call_timeseries_stats( results = input_results + output_results elif type == "tools": # 工具调用统计(按工具名称分组) - # 为工具调用创建独立的时间格式化器 - if time_range == "14hours": - tool_group_format = func.strftime("%Y-%m-%d %H:00", func.datetime(ToolCall.created_at, "+8 hours")) - elif time_range == "14weeks": - tool_group_format = func.strftime("%Y-%W", func.datetime(ToolCall.created_at, "+8 hours")) - else: # 14days - tool_group_format = func.strftime("%Y-%m-%d", func.datetime(ToolCall.created_at, "+8 hours")) + # 为工具调用创建独立的时间格式化器(使用 PostgreSQL 兼容的 to_char + INTERVAL) + tool_group_format = _get_time_group_format(ToolCall.created_at, time_range) query_result = await db.execute( select( @@ -884,7 +865,7 @@ async def get_call_timeseries_stats( func.count(ToolCall.id).label("count"), ToolCall.tool_name.label("category"), ) - .filter(ToolCall.created_at >= start_time) + .filter(ToolCall.created_at >= query_start_time) .group_by(tool_group_format, ToolCall.tool_name) .order_by(tool_group_format) ) @@ -971,7 +952,7 @@ def normalize_week_key(raw_key: str) -> str: # 计算统计指标 if type == "tools": # 对于工具调用,显示所有时间的总数(与ToolStatsComponent保持一致) - from src.storage.db.models import ToolCall + from src.storage.postgres.models_business import ToolCall total_count_result = await db.execute(select(func.count(ToolCall.id))) total_count = total_count_result.scalar() or 0 diff --git a/server/routers/department_router.py b/server/routers/department_router.py index 057e051d3..11a2d86d6 100644 --- 
a/server/routers/department_router.py +++ b/server/routers/department_router.py @@ -10,8 +10,10 @@ from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession -from src.storage.db.models import Department, User -from server.utils.auth_middleware import get_superadmin_user, get_db +from src.storage.postgres.models_business import Department, User +from src.repositories.department_repository import DepartmentRepository +from src.repositories.user_repository import UserRepository +from server.utils.auth_middleware import get_superadmin_user, get_admin_user, get_db from server.utils.auth_utils import AuthUtils from server.utils.common_utils import log_operation from server.utils.user_utils import is_valid_phone_number @@ -60,36 +62,16 @@ class DepartmentResponse(BaseModel): user_count: int = 0 -class DepartmentSimpleResponse(BaseModel): - """部门简单响应(不含用户数量)""" - - id: int - name: str - description: str | None = None - created_at: str - - # ============================================================================= # === 部门管理路由 === # ============================================================================= @department.get("", response_model=list[DepartmentResponse]) -async def get_departments(current_user: User = Depends(get_superadmin_user), db: AsyncSession = Depends(get_db)): - """获取所有部门列表""" - result = await db.execute(select(Department).order_by(Department.created_at.desc())) - departments = result.scalars().all() - - # 获取每个部门的用户数量 - department_list = [] - for dep in departments: - user_count_result = await db.execute( - select(func.count(User.id)).filter(User.department_id == dep.id, User.is_deleted == 0) - ) - user_count = user_count_result.scalar() - department_list.append({**dep.to_dict(), "user_count": user_count}) - - return department_list +async def get_departments(current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)): + """获取所有部门列表(管理员可访问)""" + dept_repo = DepartmentRepository() + return await 
dept_repo.list_with_user_count() @department.get("/{department_id}", response_model=DepartmentResponse) @@ -120,10 +102,11 @@ async def create_department( db: AsyncSession = Depends(get_db), ): """创建新部门,同时创建该部门的管理员""" + dept_repo = DepartmentRepository() + user_repo = UserRepository() + # 检查部门名称是否已存在 - result = await db.execute(select(Department).filter(Department.name == department_data.name)) - existing = result.scalar_one_or_none() - if existing: + if await dept_repo.exists_by_name(department_data.name): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="部门名称已存在") # 验证管理员 user_id 格式 @@ -141,9 +124,7 @@ async def create_department( ) # 检查 user_id 是否已存在 - result = await db.execute(select(User).filter(User.user_id == admin_user_id)) - existing_user = result.scalar_one_or_none() - if existing_user: + if await user_repo.exists_by_user_id(admin_user_id): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="用户ID已存在", @@ -154,33 +135,32 @@ async def create_department( if admin_phone: if not is_valid_phone_number(admin_phone): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="手机号格式不正确") - result = await db.execute(select(User).filter(User.phone_number == admin_phone)) - existing_phone = result.scalar_one_or_none() - if existing_phone: + if await user_repo.exists_by_phone(admin_phone): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="手机号已存在", ) - new_department = Department(name=department_data.name, description=department_data.description) - - db.add(new_department) - await db.flush() # 获取部门ID + # 创建部门 + new_department = await dept_repo.create( + { + "name": department_data.name, + "description": department_data.description, + } + ) # 创建管理员用户 hashed_password = AuthUtils.hash_password(department_data.admin_password) - new_admin = User( - username=admin_user_id, # username 和 user_id 设置为相同值 - user_id=admin_user_id, - phone_number=admin_phone, - password_hash=hashed_password, - role="admin", - 
department_id=new_department.id, + await user_repo.create( + { + "username": admin_user_id, + "user_id": admin_user_id, + "phone_number": admin_phone, + "password_hash": hashed_password, + "role": "admin", + "department_id": new_department.id, + } ) - db.add(new_admin) - - await db.commit() - await db.refresh(new_department) # 记录操作 await log_operation( diff --git a/server/routers/evaluation_router.py b/server/routers/evaluation_router.py index 1db786c6f..5fee1ff25 100644 --- a/server/routers/evaluation_router.py +++ b/server/routers/evaluation_router.py @@ -1,7 +1,7 @@ import traceback from fastapi import APIRouter, HTTPException, Depends, File, Form, Body, UploadFile -from src.storage.db.models import User +from src.storage.postgres.models_business import User from server.utils.auth_middleware import get_admin_user from src.utils import logger diff --git a/server/routers/graph_router.py b/server/routers/graph_router.py index 7cca6ea44..ccbad5d81 100644 --- a/server/routers/graph_router.py +++ b/server/routers/graph_router.py @@ -6,7 +6,7 @@ from src import graph_base, knowledge_base from src.knowledge.adapters.base import GraphAdapter from src.knowledge.adapters.factory import GraphAdapterFactory -from src.storage.db.models import User +from src.storage.postgres.models_business import User from src.storage.minio.client import StorageError from src.utils.logging_config import logger @@ -31,12 +31,12 @@ async def _get_graph_adapter(db_id: str) -> GraphAdapter: # 检查图数据库服务状态 (仅对 Upload 类型需要) if not graph_base.is_running(): # 先尝试检测图谱类型,如果是不需要 graph_base 的类型则允许 - graph_type = GraphAdapterFactory.detect_graph_type(db_id, knowledge_base) + graph_type = await GraphAdapterFactory.detect_graph_type(db_id, knowledge_base) if graph_type == "upload": raise HTTPException(status_code=503, detail="Graph database service is not running") # 使用工厂方法自动创建适配器 - return GraphAdapterFactory.create_adapter_by_db_id( + return await GraphAdapterFactory.create_adapter_by_db_id( db_id=db_id, 
knowledge_base_manager=knowledge_base, graph_db_instance=graph_base ) @@ -83,7 +83,7 @@ async def get_graphs(current_user: User = Depends(get_admin_user)): ) # 2. 获取 LightRAG 数据库信息 - lightrag_dbs = knowledge_base.get_lightrag_databases() + lightrag_dbs = await knowledge_base.get_lightrag_databases() # 直接使用 LightRAG 适配器的默认 metadata from src.knowledge.adapters.lightrag import LightRAGGraphAdapter diff --git a/server/routers/knowledge_router.py b/server/routers/knowledge_router.py index 754814e44..ae133301c 100644 --- a/server/routers/knowledge_router.py +++ b/server/routers/knowledge_router.py @@ -1,4 +1,5 @@ import asyncio +import json import os import textwrap import traceback @@ -9,13 +10,13 @@ from fastapi.responses import FileResponse from starlette.responses import StreamingResponse -from server.services.tasker import TaskContext, tasker -from server.utils.auth_middleware import get_admin_user +from src.services.task_service import TaskContext, tasker +from server.utils.auth_middleware import get_admin_user, get_required_user from src import config, knowledge_base from src.knowledge.indexing import SUPPORTED_FILE_EXTENSIONS, is_supported_file_extension, process_file_to_markdown from src.knowledge.utils import calculate_content_hash from src.models.embed import test_all_embedding_models_status, test_embedding_model_status -from src.storage.db.models import User +from src.storage.postgres.models_business import User from src.storage.minio.client import StorageError, aupload_file_to_minio, get_minio_client from src.utils import logger @@ -63,10 +64,10 @@ @knowledge.get("/databases") async def get_databases(current_user: User = Depends(get_admin_user)): - """获取所有知识库""" + """获取所有知识库(根据用户权限过滤)""" try: - database = knowledge_base.get_databases() - return database + user_info = {"role": current_user.role, "department_id": current_user.department_id} + return await knowledge_base.get_databases_by_user(user_info) except Exception as e: logger.error(f"获取数据库列表失败 {e}, 
{traceback.format_exc()}") return {"message": f"获取数据库列表失败 {e}", "databases": []} @@ -80,17 +81,18 @@ async def create_database( kb_type: str = Body("lightrag"), additional_params: dict = Body({}), llm_info: dict = Body(None), + share_config: dict = Body(None), current_user: User = Depends(get_admin_user), ): """创建知识库""" logger.debug( f"Create database {database_name} with kb_type {kb_type}, " f"additional_params {additional_params}, llm_info {llm_info}, " - f"embed_model_name {embed_model_name}" + f"embed_model_name {embed_model_name}, share_config {share_config}" ) try: # 先检查名称是否已存在 - if knowledge_base.database_name_exists(database_name): + if await knowledge_base.database_name_exists(database_name): raise HTTPException( status_code=409, detail=f"知识库名称 '{database_name}' 已存在,请使用其他名称", @@ -116,8 +118,16 @@ def remove_reranker_config(kb: str, params: dict) -> None: remove_reranker_config(kb_type, additional_params) embed_info = config.embed_model_names[embed_model_name] + # 将Pydantic模型转换为字典以便JSON序列化 + embed_info_dict = embed_info.model_dump() if hasattr(embed_info, "model_dump") else embed_info.dict() database_info = await knowledge_base.create_database( - database_name, description, kb_type=kb_type, embed_info=embed_info, llm_info=llm_info, **additional_params + database_name, + description, + kb_type=kb_type, + embed_info=embed_info_dict, + llm_info=llm_info, + share_config=share_config, + **additional_params, ) # 需要重新加载所有智能体,因为工具刷新了 @@ -133,10 +143,32 @@ def remove_reranker_config(kb: str, params: dict) -> None: raise HTTPException(status_code=400, detail=f"创建数据库失败: {e}") +@knowledge.get("/databases/accessible") +async def get_accessible_databases(current_user: User = Depends(get_required_user)): + """获取当前用户有权访问的知识库列表(用于智能体配置)""" + try: + user_info = {"role": current_user.role, "department_id": current_user.department_id} + databases = await knowledge_base.get_databases_by_user(user_info) + + accessible = [ + { + "name": db.get("name", ""), + "db_id": 
db.get("db_id"), + "description": db.get("description", ""), + } + for db in databases.get("databases", []) + ] + + return {"databases": accessible} + except Exception as e: + logger.error(f"获取可访问知识库列表失败: {e}, {traceback.format_exc()}") + return {"message": f"获取可访问知识库列表失败: {str(e)}", "databases": []} + + @knowledge.get("/databases/{db_id}") async def get_database_info(db_id: str, current_user: User = Depends(get_admin_user)): """获取知识库详细信息""" - database = knowledge_base.get_database_info(db_id) + database = await knowledge_base.get_database_info(db_id) if database is None: raise HTTPException(status_code=404, detail="Database not found") return database @@ -148,13 +180,14 @@ async def update_database_info( name: str = Body(...), description: str = Body(...), llm_info: dict = Body(None), - additional_params: dict = Body({}), # Now accepts a dict + additional_params: dict = Body({}), + share_config: dict = Body(None), current_user: User = Depends(get_admin_user), ): """更新知识库信息""" logger.debug( - f"Update database {db_id} info: {name}, {description}, llm_info: {llm_info}, " - f"additional_params: {additional_params}" + f"[update_database_info] 接收到的参数: name={name}, llm_info={llm_info}, " + f"additional_params={additional_params}, share_config={share_config}" ) try: database = await knowledge_base.update_database( @@ -162,7 +195,8 @@ async def update_database_info( name, description, llm_info, - additional_params=additional_params, # Pass the dict to the manager + additional_params=additional_params, + share_config=share_config, ) return {"message": "更新成功", "database": database} except Exception as e: @@ -272,7 +306,7 @@ async def run_ingest(context: TaskContext): try: # 1. 
Add file record (UPLOADED) file_meta = await knowledge_base.add_file_record( - db_id, item, params=params, operator_id=current_user.id + db_id, item, params=params, operator_id=current_user.user_id ) file_id = file_meta["file_id"] added_files[item] = (file_id, file_meta) @@ -304,7 +338,7 @@ async def run_ingest(context: TaskContext): try: # 2. Parse file (PARSING -> PARSED) - file_meta = await knowledge_base.parse_file(db_id, file_id, operator_id=current_user.id) + file_meta = await knowledge_base.parse_file(db_id, file_id, operator_id=current_user.user_id) processed_items.append(file_meta) parse_success_count += 1 except Exception as parse_error: @@ -336,10 +370,10 @@ async def run_ingest(context: TaskContext): try: # 1. 更新入库参数 await knowledge_base.update_file_params( - db_id, file_id, indexing_params, operator_id=current_user.id + db_id, file_id, indexing_params, operator_id=current_user.user_id ) # 2. 执行入库 - result = await knowledge_base.index_file(db_id, file_id, operator_id=current_user.id) + result = await knowledge_base.index_file(db_id, file_id, operator_id=current_user.user_id) processed_items.append(result) except Exception as index_error: logger.error(f"自动入库失败 {item} (file_id={file_id}): {index_error}") @@ -381,7 +415,7 @@ async def run_ingest(context: TaskContext): return summary | {"items": processed_items} try: - database = knowledge_base.get_database_info(db_id) + database = await knowledge_base.get_database_info(db_id) task = await tasker.enqueue( name=f"知识库文档处理 ({database['name']})", task_type="knowledge_ingest", @@ -422,7 +456,7 @@ async def run_parse(context: TaskContext): await context.set_progress(progress, f"正在解析第 {idx}/{total} 个文档") try: - result = await knowledge_base.parse_file(db_id, file_id, operator_id=current_user.id) + result = await knowledge_base.parse_file(db_id, file_id, operator_id=current_user.user_id) processed_items.append(result) except Exception as e: logger.error(f"Parse failed for {file_id}: {e}") @@ -439,7 +473,7 @@ async 
def run_parse(context: TaskContext): return {"items": processed_items} try: - database = knowledge_base.get_database_info(db_id) + database = await knowledge_base.get_database_info(db_id) task = await tasker.enqueue( name=f"文档解析 ({database['name']})", task_type="knowledge_parse", @@ -516,7 +550,7 @@ async def run_index(context: TaskContext): return {"items": processed_items} try: - database = knowledge_base.get_database_info(db_id) + database = await knowledge_base.get_database_info(db_id) task = await tasker.enqueue( name=f"文档入库 ({database['name']})", task_type="knowledge_index", @@ -767,7 +801,7 @@ async def update_knowledge_base_query_params( """更新知识库查询参数配置""" try: # 获取知识库实例 - kb_instance = knowledge_base.get_kb(db_id) + kb_instance = await knowledge_base._get_kb_for_database(db_id) if not kb_instance: raise HTTPException(status_code=404, detail="Knowledge base not found") @@ -777,10 +811,13 @@ async def update_knowledge_base_query_params( if db_id not in kb_instance.databases_meta: raise HTTPException(status_code=404, detail="Database not found in instance metadata") - # 使用 setdefault 简化嵌套字典的初始化 - options = kb_instance.databases_meta[db_id].setdefault("query_params", {}).setdefault("options", {}) + # 确保 query_params 不为 None + if kb_instance.databases_meta[db_id].get("query_params") is None: + kb_instance.databases_meta[db_id]["query_params"] = {} + + options = kb_instance.databases_meta[db_id]["query_params"].setdefault("options", {}) options.update(params) - kb_instance._save_metadata() + await kb_instance._save_metadata() logger.info(f"更新知识库 {db_id} 查询参数: {params}") @@ -796,7 +833,7 @@ async def get_knowledge_base_query_params(db_id: str, current_user: User = Depen """获取知识库类型特定的查询参数""" try: # 获取知识库实例 - kb_instance = knowledge_base._get_kb_for_database(db_id) + kb_instance = await knowledge_base._get_kb_for_database(db_id) # 调用知识库实例的方法获取配置 params = kb_instance.get_query_params_config( @@ -872,15 +909,13 @@ async def generate_sample_questions( 生成的问题列表 """ try: 
- import json - from src.models import select_model # 从请求体中提取参数 count = request_body.get("count", 10) # 获取知识库信息 - db_info = knowledge_base.get_database_info(db_id) + db_info = await knowledge_base.get_database_info(db_id) if not db_info: raise HTTPException(status_code=404, detail=f"知识库 {db_id} 不存在") @@ -953,14 +988,10 @@ async def generate_sample_questions( # 保存问题到知识库元数据 try: - async with knowledge_base._metadata_lock: - # 确保知识库元数据存在 - if db_id not in knowledge_base.global_databases_meta: - knowledge_base.global_databases_meta[db_id] = {} - # 保存问题到对应知识库 - knowledge_base.global_databases_meta[db_id]["sample_questions"] = questions - knowledge_base._save_global_metadata() - logger.info(f"成功保存 {len(questions)} 个问题到知识库 {db_id}") + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + await KnowledgeBaseRepository().update(db_id, {"sample_questions": questions}) + logger.info(f"成功保存 {len(questions)} 个问题到知识库 {db_id}") except Exception as save_error: logger.error(f"保存问题失败: {save_error}") @@ -995,12 +1026,15 @@ async def get_sample_questions(db_id: str, current_user: User = Depends(get_admi 问题列表 """ try: - # 直接从全局元数据中读取 - if db_id not in knowledge_base.global_databases_meta: + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) + + if kb is None: raise HTTPException(status_code=404, detail=f"知识库 {db_id} 不存在") - db_meta = knowledge_base.global_databases_meta[db_id] - questions = db_meta.get("sample_questions", []) + questions = kb.sample_questions or [] return { "message": "success", @@ -1162,7 +1196,7 @@ async def get_knowledge_base_types(current_user: User = Depends(get_admin_user)) async def get_knowledge_base_statistics(current_user: User = Depends(get_admin_user)): """获取知识库统计信息""" try: - stats = knowledge_base.get_statistics() + stats = await knowledge_base.get_statistics() return {"stats": stats, "message": "success"} except Exception 
as e: logger.error(f"获取知识库统计失败 {e}, {traceback.format_exc()}") diff --git a/server/routers/mcp_router.py b/server/routers/mcp_router.py index d886ec019..66b31f0de 100644 --- a/server/routers/mcp_router.py +++ b/server/routers/mcp_router.py @@ -15,7 +15,7 @@ toggle_tool_enabled, update_mcp_server, ) -from src.storage.db.models import User +from src.storage.postgres.models_business import User from src.utils import logger from server.utils.auth_middleware import get_admin_user, get_db diff --git a/server/routers/mindmap_router.py b/server/routers/mindmap_router.py index 81fa27a16..f97d1dec1 100644 --- a/server/routers/mindmap_router.py +++ b/server/routers/mindmap_router.py @@ -14,7 +14,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException -from src.storage.db.models import User +from src.storage.postgres.models_business import User from server.utils.auth_middleware import get_admin_user from src import knowledge_base from src.models import select_model @@ -93,7 +93,7 @@ async def get_database_files(db_id: str, current_user: User = Depends(get_admin_ """ try: # 获取知识库详细信息 - db_info = knowledge_base.get_database_info(db_id) + db_info = await knowledge_base.get_database_info(db_id) if not db_info: raise HTTPException(status_code=404, detail=f"知识库 {db_id} 不存在") @@ -154,7 +154,7 @@ async def generate_mindmap( """ try: # 获取知识库信息 - db_info = knowledge_base.get_database_info(db_id) + db_info = await knowledge_base.get_database_info(db_id) if not db_info: raise HTTPException(status_code=404, detail=f"知识库 {db_id} 不存在") @@ -244,11 +244,10 @@ async def generate_mindmap( # 保存思维导图到知识库元数据 try: - async with knowledge_base._metadata_lock: - if db_id in knowledge_base.global_databases_meta: - knowledge_base.global_databases_meta[db_id]["mindmap"] = mindmap_data - knowledge_base._save_global_metadata() - logger.info(f"思维导图已保存到知识库: {db_id}") + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + await KnowledgeBaseRepository().update(db_id, 
{"mindmap": mindmap_data}) + logger.info(f"思维导图已保存到知识库: {db_id}") except Exception as save_error: logger.error(f"保存思维导图失败: {save_error}") # 不影响返回结果,只记录错误 @@ -282,13 +281,14 @@ async def generate_mindmap( @mindmap.get("/databases") async def get_databases_overview(current_user: User = Depends(get_admin_user)): """ - 获取所有知识库的概览信息,用于思维导图界面选择 + 获取所有知识库的概览信息,用于思维导图界面选择(根据用户权限过滤) Returns: 知识库列表 """ try: - databases = knowledge_base.get_databases() + user_info = {"role": current_user.role, "department_id": current_user.department_id} + databases = await knowledge_base.get_databases_by_user(user_info) # databases["databases"] 是一个列表,每个元素已经包含了基本信息 db_list_raw = databases.get("databases", []) @@ -300,7 +300,7 @@ async def get_databases_overview(current_user: User = Depends(get_admin_user)): continue # 获取详细信息以获取文件数量 - detail_info = knowledge_base.get_database_info(db_id) + detail_info = await knowledge_base.get_database_info(db_id) file_count = len(detail_info.get("files", {})) if detail_info else 0 db_list.append( @@ -341,18 +341,19 @@ async def get_database_mindmap(db_id: str, current_user: User = Depends(get_admi 思维导图数据 """ try: - # 直接从全局元数据中读取思维导图 - if db_id not in knowledge_base.global_databases_meta: - raise HTTPException(status_code=404, detail=f"知识库 {db_id} 不存在") + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository - db_meta = knowledge_base.global_databases_meta[db_id] - mindmap_data = db_meta.get("mindmap") + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) + + if kb is None: + raise HTTPException(status_code=404, detail=f"知识库 {db_id} 不存在") return { "message": "success", - "mindmap": mindmap_data, + "mindmap": kb.mindmap, "db_id": db_id, - "db_name": db_meta.get("name", ""), + "db_name": kb.name, } except HTTPException: diff --git a/server/routers/system_router.py b/server/routers/system_router.py index 7be72c3d3..b19d365b5 100644 --- a/server/routers/system_router.py +++ b/server/routers/system_router.py @@ -5,7 
+5,7 @@ import yaml from fastapi import APIRouter, Body, Depends, HTTPException -from src.storage.db.models import User +from src.storage.postgres.models_business import User from server.utils.auth_middleware import get_admin_user from src import config from src.models.chat import test_chat_model_status, test_all_chat_models_status diff --git a/server/routers/task_router.py b/server/routers/task_router.py index fc692b6b7..848af3f9c 100644 --- a/server/routers/task_router.py +++ b/server/routers/task_router.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query -from src.storage.db.models import User -from server.services.tasker import tasker +from src.storage.postgres.models_business import User +from src.services.task_service import tasker from server.utils.auth_middleware import get_admin_user tasks = APIRouter(prefix="/tasks", tags=["tasks"]) diff --git a/server/services/__init__.py b/server/services/__init__.py deleted file mode 100644 index 5bbcfe716..000000000 --- a/server/services/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tasker import TaskContext, Tasker, tasker - -__all__ = ["TaskContext", "Tasker", "tasker"] diff --git a/server/utils/auth_middleware.py b/server/utils/auth_middleware.py index 2be55704c..04a2bd31d 100644 --- a/server/utils/auth_middleware.py +++ b/server/utils/auth_middleware.py @@ -5,8 +5,8 @@ from jose import JWTError from sqlalchemy.ext.asyncio import AsyncSession -from src.storage.db.manager import db_manager -from src.storage.db.models import User +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import User from server.utils.auth_utils import AuthUtils # 定义OAuth2密码承载器,指定token URL @@ -25,7 +25,7 @@ # 获取数据库会话(异步版本) async def get_db(): - async with db_manager.get_async_session_context() as db: + async with pg_manager.get_async_session_context() as db: yield db @@ -61,7 +61,7 @@ async def get_current_user(token: str | None = Depends(oauth2_scheme), db: Async # 
查找用户(异步版本) from sqlalchemy import select - result = await db.execute(select(User).filter(User.id == user_id)) + result = await db.execute(select(User).filter(User.id == int(user_id))) user = result.scalar_one_or_none() if user is None: raise credentials_exception diff --git a/server/utils/common_utils.py b/server/utils/common_utils.py index 02d685764..0adc252b7 100644 --- a/server/utils/common_utils.py +++ b/server/utils/common_utils.py @@ -5,7 +5,7 @@ from fastapi import Request from sqlalchemy.orm import Session -from src.storage.db.models import OperationLog, User +from src.storage.postgres.models_business import OperationLog, User def setup_logging(): diff --git a/server/utils/lifespan.py b/server/utils/lifespan.py index feb642f83..222feb11c 100644 --- a/server/utils/lifespan.py +++ b/server/utils/lifespan.py @@ -2,20 +2,37 @@ from fastapi import FastAPI -from server.services import tasker +from src.services.task_service import tasker from src.services.mcp_service import init_mcp_servers +from src.storage.postgres.manager import pg_manager +from src.knowledge import knowledge_base from src.utils import logger @asynccontextmanager async def lifespan(app: FastAPI): """FastAPI lifespan事件管理器""" + # 初始化数据库连接 + try: + pg_manager.initialize() + await pg_manager.create_business_tables() + await pg_manager.ensure_knowledge_schema() + except Exception as e: + logger.error(f"Failed to initialize database during startup: {e}") + # 初始化 MCP 服务器配置 try: await init_mcp_servers() except Exception as e: logger.error(f"Failed to initialize MCP servers during startup: {e}") + # 初始化知识库管理器 + try: + await knowledge_base.initialize() + except Exception as e: + logger.error(f"Failed to initialize knowledge base manager: {e}") + await tasker.start() yield await tasker.shutdown() + await pg_manager.close() diff --git a/server/utils/migrate.py b/server/utils/migrate.py index 0a9122456..a598cc136 100644 --- a/server/utils/migrate.py +++ b/server/utils/migrate.py @@ -310,6 +310,91 @@ def 
get_migrations(self) -> list[tuple[int, str, list[str]]]: migrations.append((4, "添加部门功能", v4_commands)) + # 迁移 v5: 补全知识库/评估相关表字段(为历史数据库增加新增列) + v5_commands: list[str] = [] + + # knowledge_bases + if self.check_table_exists("knowledge_bases"): + kb_columns = { + "embed_info": "JSON", + "llm_info": "JSON", + "query_params": "JSON", + "additional_params": "JSON", + "share_config": "JSON", + "mindmap": "JSON", + "sample_questions": "JSON", + "updated_at": "DATETIME", + } + for col, col_type in kb_columns.items(): + if not self.check_column_exists("knowledge_bases", col): + v5_commands.append(f"ALTER TABLE knowledge_bases ADD COLUMN {col} {col_type}") + + # knowledge_files + if self.check_table_exists("knowledge_files"): + kf_columns = { + "parent_id": "VARCHAR(64)", + "original_filename": "VARCHAR(512)", + "file_type": "VARCHAR(64)", + "path": "VARCHAR(1024)", + "minio_url": "VARCHAR(1024)", + "markdown_file": "VARCHAR(1024)", + "status": "VARCHAR(32) DEFAULT 'uploaded'", + "content_hash": "VARCHAR(128)", + "file_size": "BIGINT", + "content_type": "VARCHAR(64)", + "processing_params": "JSON", + "is_folder": "INTEGER NOT NULL DEFAULT 0", + "error_message": "TEXT", + "created_by": "VARCHAR(64)", + "updated_by": "VARCHAR(64)", + "updated_at": "DATETIME", + } + for col, col_type in kf_columns.items(): + if not self.check_column_exists("knowledge_files", col): + v5_commands.append(f"ALTER TABLE knowledge_files ADD COLUMN {col} {col_type}") + + # evaluation_benchmarks + if self.check_table_exists("evaluation_benchmarks"): + eb_columns = { + "data_file_path": "VARCHAR(1024)", + "created_by": "VARCHAR(64)", + "updated_at": "DATETIME", + } + for col, col_type in eb_columns.items(): + if not self.check_column_exists("evaluation_benchmarks", col): + v5_commands.append(f"ALTER TABLE evaluation_benchmarks ADD COLUMN {col} {col_type}") + + # evaluation_results + if self.check_table_exists("evaluation_results"): + er_columns = { + "retrieval_config": "JSON", + "metrics": "JSON", + 
"overall_score": "FLOAT", + "total_questions": "INTEGER NOT NULL DEFAULT 0", + "completed_questions": "INTEGER NOT NULL DEFAULT 0", + "started_at": "DATETIME", + "completed_at": "DATETIME", + "created_by": "VARCHAR(64)", + } + for col, col_type in er_columns.items(): + if not self.check_column_exists("evaluation_results", col): + v5_commands.append(f"ALTER TABLE evaluation_results ADD COLUMN {col} {col_type}") + + # evaluation_result_details + if self.check_table_exists("evaluation_result_details"): + erd_columns = { + "gold_chunk_ids": "JSON", + "gold_answer": "TEXT", + "generated_answer": "TEXT", + "retrieved_chunks": "JSON", + "metrics": "JSON", + } + for col, col_type in erd_columns.items(): + if not self.check_column_exists("evaluation_result_details", col): + v5_commands.append(f"ALTER TABLE evaluation_result_details ADD COLUMN {col} {col_type}") + + migrations.append((5, "补全知识库与评估相关表字段", v5_commands)) + # 未来的迁移可以在这里添加 # migrations.append(( # 2, @@ -368,6 +453,67 @@ def validate_database_schema(db_path: str) -> tuple[bool, list[str]]: "extra_metadata", "image_content", ], + "knowledge_bases": [ + "id", + "db_id", + "name", + "kb_type", + "query_params", + "additional_params", + "share_config", + "mindmap", + "sample_questions", + "created_at", + "updated_at", + ], + "knowledge_files": [ + "id", + "file_id", + "db_id", + "filename", + "file_type", + "status", + "is_folder", + "created_at", + "updated_at", + ], + "evaluation_benchmarks": [ + "id", + "benchmark_id", + "db_id", + "name", + "question_count", + "has_gold_chunks", + "has_gold_answers", + "data_file_path", + "created_at", + "updated_at", + ], + "evaluation_results": [ + "id", + "task_id", + "db_id", + "benchmark_id", + "status", + "retrieval_config", + "metrics", + "overall_score", + "total_questions", + "completed_questions", + "started_at", + "completed_at", + ], + "evaluation_result_details": [ + "id", + "task_id", + "query_index", + "query_text", + "gold_chunk_ids", + "gold_answer", + 
"generated_answer", + "retrieved_chunks", + "metrics", + ], } for table_name, fields in required_fields.items(): diff --git a/src/__init__.py b/src/__init__.py index 1ed71cb4a..a57528607 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -2,10 +2,13 @@ load_dotenv(".env", override=True) +import os # noqa: E402 from concurrent.futures import ThreadPoolExecutor # noqa: E402 from src.config import config as config # noqa: E402 -from src.knowledge import graph_base as graph_base # noqa: E402 -from src.knowledge import knowledge_base as knowledge_base # noqa: E402 + +if os.getenv("YUXI_SKIP_APP_INIT") != "1": + from src.knowledge import graph_base as graph_base # noqa: E402 + from src.knowledge import knowledge_base as knowledge_base # noqa: E402 executor = ThreadPoolExecutor() # noqa: E402 diff --git a/src/agents/chatbot/graph.py b/src/agents/chatbot/graph.py index 810bdb1a8..691b1b9c8 100644 --- a/src/agents/chatbot/graph.py +++ b/src/agents/chatbot/graph.py @@ -3,9 +3,10 @@ from src.agents.common import BaseAgent, load_chat_model from src.agents.common.middlewares import ( + RuntimeConfigMiddleware, inject_attachment_context, ) -from src.agents.common.tools import get_tools_from_context +from src.services.mcp_service import get_tools_from_all_servers class ChatbotAgent(BaseAgent): @@ -18,25 +19,24 @@ def __init__(self, **kwargs): async def get_graph(self, **kwargs): """构建图""" - if self.graph: - return self.graph - - # 获取上下文配置 - context = self.context_schema.from_file(module_name=self.module_name) + context = self.context_schema() + all_mcp_tools = ( + await get_tools_from_all_servers() + ) # 因为异步加载,无法放在 RuntimeConfigMiddleware 的 __init__ 中 # 使用 create_agent 创建智能体 + # 注意:tools 参数由 RuntimeConfigMiddleware 在 wrap_model_call 中动态设置 graph = create_agent( - model=load_chat_model(context.model), # 使用 context 中的模型配置 - tools=await get_tools_from_context(context), + model=load_chat_model(context.model), system_prompt=context.system_prompt, middleware=[ 
inject_attachment_context, # 附件上下文注入 + RuntimeConfigMiddleware(extra_tools=all_mcp_tools), # 运行时配置应用(模型/工具/知识库/MCP/提示词) ModelRetryMiddleware(), # 模型重试中间件 ], checkpointer=await self._get_checkpointer(), ) - self.graph = graph return graph diff --git a/src/agents/common/base.py b/src/agents/common/base.py index 25a61b9ad..90e864753 100644 --- a/src/agents/common/base.py +++ b/src/agents/common/base.py @@ -28,6 +28,7 @@ class BaseAgent: def __init__(self, **kwargs): self.graph = None # will be covered by get_graph self.checkpointer = None + self._async_conn = None self.workdir = Path(sys_config.save_dir) / "agents" / self.module_name self.workdir.mkdir(parents=True, exist_ok=True) self._metadata_cache = None # Cache for metadata to avoid repeated file reads @@ -62,19 +63,30 @@ async def get_config(self): async def stream_values(self, messages: list[str], input_context=None, **kwargs): graph = await self.get_graph() - context = self.context_schema.from_file(module_name=self.module_name, input_context=input_context) + context = self.context_schema() + agent_config = (input_context or {}).get("agent_config") + if isinstance(agent_config, dict): + context.update(agent_config) + context.update(input_context or {}) for event in graph.astream({"messages": messages}, stream_mode="values", context=context): yield event["messages"] async def stream_messages(self, messages: list[str], input_context=None, **kwargs): graph = await self.get_graph() - context = self.context_schema.from_file(module_name=self.module_name, input_context=input_context) + context = self.context_schema() + agent_config = (input_context or {}).get("agent_config") + if isinstance(agent_config, dict): + context.update(agent_config) + context.update(input_context or {}) logger.debug(f"stream_messages: {context}") # TODO Checkpointer 似乎还没有适配最新的 1.0 Context API # 从 input_context 中提取 attachments(如果有) attachments = (input_context or {}).get("attachments", []) - input_config = {"configurable": input_context, 
"recursion_limit": 300} + input_config = { + "configurable": {"thread_id": context.thread_id, "user_id": context.user_id}, + "recursion_limit": 300, + } async for msg, metadata in graph.astream( {"messages": messages, "attachments": attachments}, @@ -86,12 +98,19 @@ async def stream_messages(self, messages: list[str], input_context=None, **kwarg async def invoke_messages(self, messages: list[str], input_context=None, **kwargs): graph = await self.get_graph() - context = self.context_schema.from_file(module_name=self.module_name, input_context=input_context) + context = self.context_schema() + agent_config = (input_context or {}).get("agent_config") + if isinstance(agent_config, dict): + context.update(agent_config) + context.update(input_context or {}) logger.debug(f"invoke_messages: {context}") # 从 input_context 中提取 attachments(如果有) attachments = (input_context or {}).get("attachments", []) - input_config = {"configurable": input_context, "recursion_limit": 100} + input_config = { + "configurable": {"thread_id": context.thread_id, "user_id": context.user_id}, + "recursion_limit": 100, + } msg = await graph.ainvoke( {"messages": messages, "attachments": attachments}, context=context, config=input_config @@ -147,6 +166,9 @@ async def get_graph(self, **kwargs) -> CompiledStateGraph: pass async def _get_checkpointer(self): + if self.checkpointer is not None: + return self.checkpointer + # 创建数据库连接并确保设置 checkpointer checkpointer = None @@ -157,15 +179,20 @@ async def _get_checkpointer(self): logger.error(f"构建 Graph 设置 checkpointer 时出错: {e}, 尝试使用内存存储") checkpointer = InMemorySaver() - return checkpointer + self.checkpointer = checkpointer + return self.checkpointer async def get_async_conn(self) -> aiosqlite.Connection: """获取异步数据库连接""" + if self._async_conn is not None: + return self._async_conn + conn = await aiosqlite.connect(os.path.join(self.workdir, "aio_history.db")) # Patch: langgraph's AsyncSqliteSaver expects is_alive() method which aiosqlite may not have if not 
hasattr(conn, "is_alive"): conn.is_alive = lambda: True - return conn + self._async_conn = conn + return self._async_conn async def get_aio_memory(self) -> AsyncSqliteSaver: """获取异步存储实例""" diff --git a/src/agents/common/context.py b/src/agents/common/context.py index 190fb7b2e..60440f861 100644 --- a/src/agents/common/context.py +++ b/src/agents/common/context.py @@ -9,7 +9,6 @@ import yaml from src import config as sys_config -from src.knowledge import knowledge_base from src.services.mcp_service import get_mcp_server_names from src.utils import logger @@ -43,7 +42,12 @@ def update(self, data: dict): metadata={"name": "用户ID", "configurable": False, "description": "用来唯一标识一个用户"}, ) - system_prompt: str = field( + department_id: int | None = field( + default=None, + metadata={"name": "部门ID", "configurable": False, "description": "用来唯一标识一个部门"}, + ) + + system_prompt: Annotated[str, {"__template_metadata__": {"kind": "prompt"}}] = field( default="You are a helpful assistant.", metadata={"name": "系统提示词", "description": "用来描述智能体的角色和行为"}, ) @@ -66,17 +70,16 @@ def update(self, data: dict): }, ) - knowledges: list[str] = field( + knowledges: Annotated[list[str], {"__template_metadata__": {"kind": "knowledges"}}] = field( default_factory=list, metadata={ "name": "知识库", - "options": lambda: [k["name"] for k in knowledge_base.get_retrievers().values()], "description": "知识库列表,可以在左侧知识库页面中创建知识库。", "type": "list", # Explicitly mark as list type for frontend if needed }, ) - mcps: list[str] = field( + mcps: Annotated[list[str], {"__template_metadata__": {"kind": "mcps"}}] = field( default_factory=list, metadata={ "name": "MCP服务器", diff --git a/src/agents/common/middlewares/__init__.py b/src/agents/common/middlewares/__init__.py index 51c3cf2f3..74e9adcd8 100644 --- a/src/agents/common/middlewares/__init__.py +++ b/src/agents/common/middlewares/__init__.py @@ -1,9 +1,11 @@ from .attachment_middleware import inject_attachment_context from .context_middlewares import 
context_aware_prompt, context_based_model from .dynamic_tool_middleware import DynamicToolMiddleware +from .runtime_config_middleware import RuntimeConfigMiddleware __all__ = [ "DynamicToolMiddleware", + "RuntimeConfigMiddleware", "context_aware_prompt", "context_based_model", "inject_attachment_context", diff --git a/src/agents/common/middlewares/attachment_middleware.py b/src/agents/common/middlewares/attachment_middleware.py index 520a7a0b8..ad2f1c1eb 100644 --- a/src/agents/common/middlewares/attachment_middleware.py +++ b/src/agents/common/middlewares/attachment_middleware.py @@ -73,12 +73,21 @@ async def awrap_model_call( if attachment_prompt: logger.debug(f"Injecting {len(attachments)} attachments into model request") - # Inject attachment context at the beginning (as SystemMessage) - # 注意:这是 transient update,不会修改 state,只影响本次模型调用 - messages = [ - {"role": "system", "content": attachment_prompt}, - *request.messages, - ] + messages = list(request.messages) + insert_idx = 0 + for idx, msg in enumerate(messages): + if isinstance(msg, dict): + role = msg.get("role") or msg.get("type") + is_system = role == "system" + else: + msg_type = getattr(msg, "type", None) or getattr(msg, "role", None) + is_system = msg_type == "system" + + if not is_system: + break + insert_idx = idx + 1 + + messages.insert(insert_idx, {"role": "system", "content": attachment_prompt}) request = request.override(messages=messages) return await handler(request) diff --git a/src/agents/common/middlewares/runtime_config_middleware.py b/src/agents/common/middlewares/runtime_config_middleware.py new file mode 100644 index 000000000..051ecebcb --- /dev/null +++ b/src/agents/common/middlewares/runtime_config_middleware.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse + +from src.agents.common import load_chat_model +from 
src.agents.common.tools import get_kb_based_tools +from src.services.mcp_service import get_enabled_mcp_tools +from src.utils.logging_config import logger + + +def _is_system_message(msg: Any) -> bool: + if isinstance(msg, dict): + role = msg.get("role") or msg.get("type") + return role == "system" + msg_type = getattr(msg, "type", None) or getattr(msg, "role", None) + return msg_type == "system" + + +def _get_message_content(msg: Any) -> str | None: + if isinstance(msg, dict): + content = msg.get("content") + return str(content) if content is not None else None + content = getattr(msg, "content", None) + return str(content) if content is not None else None + + +class RuntimeConfigMiddleware(AgentMiddleware): + """运行时配置中间件 - 应用模型/工具/知识库/MCP/提示词配置 + + 注意:所有可能用到的知识库工具必须在初始化时预加载并注册到 self.tools + 运行时根据配置从 self.tools 中筛选工具,不能动态添加新工具 + """ + + def __init__(self, *, extra_tools: list[Any] | None = None): + """初始化中间件 + + Args: + extra_tools: 额外工具列表(从 create_agent 的 tools 参数传入) + """ + super().__init__() + # 这里的工具只是提供给 langchain 调用,并不是真正的绑定在模型上 + self.kb_tools = get_kb_based_tools() + self.tools = self.kb_tools + (extra_tools or []) + logger.debug(f"Initialized tools: {len(self.tools)}") + + async def awrap_model_call( + self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: + runtime_context = request.runtime.context + + model = load_chat_model(getattr(runtime_context, "model", None)) + enabled_tools = await self.get_tools_from_context(runtime_context) + system_prompt = getattr(runtime_context, "system_prompt", None) + + existing_systems: list[Any] = [] + remaining: list[Any] = [] + in_prefix = True + for msg in request.messages: + if in_prefix and _is_system_message(msg): + existing_systems.append(msg) + else: + in_prefix = False + remaining.append(msg) + + existing_contents = [_get_message_content(m) for m in existing_systems] + + new_systems: list[Any] = [] + if system_prompt: + try: + idx = 
existing_contents.index(system_prompt) + except ValueError: + new_systems.append({"role": "system", "content": system_prompt}) + else: + new_systems.append(existing_systems.pop(idx)) + existing_contents.pop(idx) + + messages = [*new_systems, *existing_systems, *remaining] + + request = request.override(model=model, tools=enabled_tools, messages=messages) + return await handler(request) + + async def get_tools_from_context(self, context) -> list: + """从上下文配置中获取工具列表""" + # 1. 基础工具 (从 context.tools 中筛选) + selected_tools = [] + + if context.tools: + # 创建工具映射表 + tools_map = {t.name: t for t in self.tools} + for tool_name in context.tools: + if tool_name in tools_map: + selected_tools.append(tools_map[tool_name]) + + # 2. 知识库工具 + if context.knowledges: + kb_tools = get_kb_based_tools(db_names=context.knowledges) + selected_tools.extend(kb_tools) + + # 3. MCP 工具(使用统一入口,自动过滤 disabled_tools) + if context.mcps: + for server_name in context.mcps: + mcp_tools = await get_enabled_mcp_tools(server_name) + selected_tools.extend(mcp_tools) + + return selected_tools diff --git a/src/agents/common/tools.py b/src/agents/common/tools.py index d4dc0791d..c547ba6ce 100644 --- a/src/agents/common/tools.py +++ b/src/agents/common/tools.py @@ -192,12 +192,15 @@ async def async_retriever_wrapper( try: logger.debug(f"Getting mindmap for database {db_id}") - # 从知识库元数据中获取思维导图 - if db_id not in knowledge_base.global_databases_meta: + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) + + if kb is None: return f"知识库 {retriever_info['name']} 不存在" - db_meta = knowledge_base.global_databases_meta[db_id] - mindmap_data = db_meta.get("mindmap") + mindmap_data = kb.mindmap if not mindmap_data: return f"知识库 {retriever_info['name']} 还没有生成思维导图。" diff --git a/src/agents/deep_agent/context.py b/src/agents/deep_agent/context.py index 34e957cf6..b496dfcba 100644 --- a/src/agents/deep_agent/context.py +++ 
b/src/agents/deep_agent/context.py @@ -99,10 +99,11 @@ class DeepContext(BaseContext): """ # 深度分析专用的系统提示词 - system_prompt: str = field( + system_prompt: Annotated[str, {"__template_metadata__": {"kind": "prompt"}}] = field( default=DEEP_PROMPT, metadata={"name": "系统提示词", "description": "Deep智能体的角色和行为指导"}, ) + subagents_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field( default="siliconflow/deepseek-ai/DeepSeek-V3.2", metadata={ diff --git a/src/agents/deep_agent/graph.py b/src/agents/deep_agent/graph.py index 8a63cb012..460e7952f 100644 --- a/src/agents/deep_agent/graph.py +++ b/src/agents/deep_agent/graph.py @@ -4,11 +4,17 @@ from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware from deepagents.middleware.subagents import SubAgentMiddleware from langchain.agents import create_agent -from langchain.agents.middleware import ModelRequest, SummarizationMiddleware, TodoListMiddleware, dynamic_prompt +from langchain.agents.middleware import ( + ModelRequest, + SummarizationMiddleware, + TodoListMiddleware, + dynamic_prompt, +) from src.agents.common import BaseAgent, load_chat_model -from src.agents.common.middlewares import inject_attachment_context +from src.agents.common.middlewares import RuntimeConfigMiddleware, inject_attachment_context from src.agents.common.tools import get_tavily_search +from src.services.mcp_service import get_tools_from_all_servers from .context import DeepContext from .prompts import DEEP_PROMPT @@ -88,32 +94,31 @@ async def get_tools(self): async def get_graph(self, **kwargs): """构建 Deep Agent 的图""" - if self.graph: - return self.graph - # 获取上下文配置 context = self.context_schema.from_file(module_name=self.module_name) model = load_chat_model(context.model) sub_model = load_chat_model(context.subagents_model) - tools = await self.get_tools() + search_tools = await self.get_tools() + all_mcp_tools = await get_tools_from_all_servers() + # 合并搜索工具和 MCP 工具 # Build subagents with search tools - 
research_sub_agent = _get_research_sub_agent(tools) + research_sub_agent = _get_research_sub_agent(search_tools) # 使用 create_deep_agent 创建深度智能体 graph = create_agent( model=model, - tools=tools, system_prompt=context.system_prompt, middleware=[ context_aware_prompt, # 动态系统提示词 inject_attachment_context, # 附件上下文注入 + RuntimeConfigMiddleware(extra_tools=all_mcp_tools), TodoListMiddleware(), FilesystemMiddleware(), SubAgentMiddleware( default_model=sub_model, - default_tools=tools, + default_tools=search_tools, subagents=[critique_sub_agent, research_sub_agent], default_middleware=[ TodoListMiddleware(), # 子智能体也有 todo 列表 @@ -139,5 +144,4 @@ async def get_graph(self, **kwargs): checkpointer=await self._get_checkpointer(), ) - self.graph = graph return graph diff --git a/src/agents/mini_agent/graph.py b/src/agents/mini_agent/graph.py index 6271508b8..937daae49 100644 --- a/src/agents/mini_agent/graph.py +++ b/src/agents/mini_agent/graph.py @@ -1,7 +1,10 @@ from langchain.agents import create_agent from src.agents.common import BaseAgent, load_chat_model -from src.agents.common.tools import get_tools_from_context +from src.agents.common.middlewares import ( + RuntimeConfigMiddleware, +) +from src.services.mcp_service import get_tools_from_all_servers class MiniAgent(BaseAgent): @@ -12,18 +15,17 @@ def __init__(self, **kwargs): super().__init__(**kwargs) async def get_graph(self, **kwargs): - if self.graph: - return self.graph - + """构建图""" context = self.context_schema.from_file(module_name=self.module_name) + all_mcp_tools = await get_tools_from_all_servers() - # 创建 MiniAgent graph = create_agent( model=load_chat_model(context.model), system_prompt=context.system_prompt, - tools=await get_tools_from_context(context), + middleware=[ + RuntimeConfigMiddleware(extra_tools=all_mcp_tools), + ], checkpointer=await self._get_checkpointer(), ) - self.graph = graph return graph diff --git a/src/agents/reporter/graph.py b/src/agents/reporter/graph.py index 15cbcdd2f..2d25e8201 
100644 --- a/src/agents/reporter/graph.py +++ b/src/agents/reporter/graph.py @@ -4,8 +4,12 @@ from langchain.agents import create_agent from src.agents.common import BaseAgent, BaseContext, load_chat_model +from src.agents.common.middlewares import ( + RuntimeConfigMiddleware, +) from src.agents.common.toolkits.mysql import get_mysql_tools -from src.agents.common.tools import gen_tool_info, get_buildin_tools, get_tools_from_context +from src.agents.common.tools import gen_tool_info, get_buildin_tools +from src.services.mcp_service import get_tools_from_all_servers from src.utils import logger @@ -35,19 +39,20 @@ def __init__(self, **kwargs): super().__init__(**kwargs) async def get_graph(self, **kwargs): - if self.graph: - return self.graph - + """构建图""" context = self.context_schema.from_file(module_name=self.module_name) + all_mcp_tools = await get_tools_from_all_servers() + # 合并 MySQL 工具和 MCP 工具 + extra_tools = get_mysql_tools() + all_mcp_tools - # 创建 SqlReporterAgent graph = create_agent( - model=load_chat_model(context.model), # 使用 context 中的模型配置 + model=load_chat_model(context.model), system_prompt=context.system_prompt, - tools=await get_tools_from_context(context, extra_tools=get_mysql_tools()), + middleware=[ + RuntimeConfigMiddleware(extra_tools=extra_tools), + ], checkpointer=await self._get_checkpointer(), ) - self.graph = graph logger.info("SqlReporterAgent 构建成功") return graph diff --git a/src/knowledge/adapters/factory.py b/src/knowledge/adapters/factory.py index 06c98eabb..349cae452 100644 --- a/src/knowledge/adapters/factory.py +++ b/src/knowledge/adapters/factory.py @@ -34,7 +34,7 @@ def get_supported_types(cls) -> dict[str, str]: } @classmethod - def detect_graph_type(cls, db_id: str, knowledge_base_manager=None) -> str: + async def detect_graph_type(cls, db_id: str, knowledge_base_manager=None) -> str: """ 自动检测图谱类型 @@ -47,7 +47,7 @@ def detect_graph_type(cls, db_id: str, knowledge_base_manager=None) -> str: """ # 1. 
首先检查是否是 LightRAG 数据库 (通过知识库管理器) if knowledge_base_manager: - db_info = knowledge_base_manager.get_database_info(db_id) + db_info = await knowledge_base_manager.get_database_info(db_id) if db_info: # 有信息表示是 LightRAG 数据库 return "lightrag" @@ -59,7 +59,9 @@ def detect_graph_type(cls, db_id: str, knowledge_base_manager=None) -> str: return "upload" @classmethod - def create_adapter_by_db_id(cls, db_id: str, knowledge_base_manager=None, graph_db_instance=None) -> GraphAdapter: + async def create_adapter_by_db_id( + cls, db_id: str, knowledge_base_manager=None, graph_db_instance=None + ) -> GraphAdapter: """ 根据数据库ID自动创建对应的适配器 @@ -71,7 +73,7 @@ def create_adapter_by_db_id(cls, db_id: str, knowledge_base_manager=None, graph_ Returns: 对应的图谱适配器 """ - graph_type = cls.detect_graph_type(db_id, knowledge_base_manager) + graph_type = await cls.detect_graph_type(db_id, knowledge_base_manager) if graph_type == "lightrag": # LightRAG 类型,使用 kb_id 作为配置 @@ -81,8 +83,10 @@ def create_adapter_by_db_id(cls, db_id: str, knowledge_base_manager=None, graph_ return cls.create_adapter("upload", graph_db_instance=graph_db_instance, config={"kgdb_name": db_id}) @classmethod - def create_adapter_for_db_id(cls, db_id: str, knowledge_base_manager=None, graph_db_instance=None) -> GraphAdapter: + async def create_adapter_for_db_id( + cls, db_id: str, knowledge_base_manager=None, graph_db_instance=None + ) -> GraphAdapter: """ 兼容性方法,调用 create_adapter_by_db_id """ - return cls.create_adapter_by_db_id(db_id, knowledge_base_manager, graph_db_instance) + return await cls.create_adapter_by_db_id(db_id, knowledge_base_manager, graph_db_instance) diff --git a/src/knowledge/base.py b/src/knowledge/base.py index 1092c7bf7..f39d5915c 100644 --- a/src/knowledge/base.py +++ b/src/knowledge/base.py @@ -1,8 +1,5 @@ import asyncio -import json import os -import shutil -import tempfile from abc import ABC, abstractmethod from typing import Any @@ -61,6 +58,7 @@ def __init__(self, work_dir: str): self.databases_meta: 
dict[str, dict] = {} self.files_meta: dict[str, dict] = {} self.benchmarks_meta: dict[str, dict] = {} + self._metadata_loaded = False # 标记元数据是否已加载 # 初始化类级别的锁 if KnowledgeBase._processing_lock is None: @@ -68,9 +66,47 @@ def __init__(self, work_dir: str): os.makedirs(work_dir, exist_ok=True) - # 自动加载元数据 - self._load_metadata() + # 注意:不在 __init__ 中加载元数据,由 KnowledgeBaseManager 统一管理加载 + + def load_metadata( + self, global_databases_meta: dict[str, dict], files_meta: dict[str, dict], benchmarks_meta: dict[str, dict] + ): + """由 KnowledgeBaseManager 调用,同步加载元数据""" + # 过滤出当前 kb_type 的知识库 + self.databases_meta = {} + for db_id, meta in global_databases_meta.items(): + if meta.get("kb_type") == self.kb_type: + self.databases_meta[db_id] = { + "name": meta.get("name"), + "description": meta.get("description"), + "kb_type": meta.get("kb_type"), + "embed_info": meta.get("embed_info"), + "llm_info": meta.get("llm_info"), + "query_params": meta.get("query_params"), + "metadata": meta.get("additional_params", {}), + "created_at": meta.get("created_at"), + } + + # 过滤文件 + self.files_meta = {} + for file_id, meta in files_meta.items(): + if meta.get("database_id") in self.databases_meta: + self.files_meta[file_id] = meta + + # 过滤评估基准 + self.benchmarks_meta = {} + for kb_id, benchmarks in benchmarks_meta.items(): + if kb_id in self.databases_meta: + self.benchmarks_meta[kb_id] = benchmarks + self._normalize_metadata_state() + self._metadata_loaded = True + logger.info(f"{self.kb_type}: 加载了 {len(self.databases_meta)} 个数据库的元数据") + + def _ensure_metadata_loaded(self): + """确保元数据已加载(延迟加载)""" + if not self._metadata_loaded: + logger.warning(f"{self.kb_type}: 元数据尚未加载,请确保 KnowledgeBaseManager 已调用 load_metadata()") @staticmethod def _normalize_timestamp(value: Any) -> str | None: @@ -172,7 +208,7 @@ async def add_file_record( # Save to metadata self.files_meta[file_id] = metadata - self._save_metadata() + await self._save_metadata() return metadata @@ -220,7 +256,7 @@ async def 
parse_file(self, db_id: str, file_id: str, operator_id: str | None = N self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() # Add to processing queue self._add_to_processing_queue(file_id) @@ -244,7 +280,7 @@ async def parse_file(self, db_id: str, file_id: str, operator_id: str | None = N self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() return self.files_meta[file_id] @@ -257,7 +293,7 @@ async def parse_file(self, db_id: str, file_id: str, operator_id: str | None = N self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() raise @@ -289,7 +325,7 @@ async def update_file_params(self, db_id: str, file_id: str, params: dict, opera logger.debug(f"[update_file_params] file_id={file_id}, updated_params={current_params}") - self._save_metadata() + await self._save_metadata() async def _save_markdown_to_minio(self, db_id: str, file_id: str, content: str) -> str: """Save markdown content to MinIO and return HTTP URL""" @@ -342,7 +378,7 @@ async def index_file(self, db_id: str, file_id: str, operator_id: str | None = N """ pass - def create_database( + async def create_database( self, database_name: str, description: str, @@ -357,6 +393,7 @@ def create_database( database_name: 数据库名称 description: 数据库描述 embed_info: 嵌入模型信息 + llm_info: LLM配置信息 **kwargs: 其他配置参数 Returns: @@ -381,7 +418,7 @@ def create_database( "metadata": kwargs, "created_at": utc_isoformat(), } - self._save_metadata() + await self._save_metadata() # 创建工作目录 working_dir = os.path.join(self.work_dir, db_id) @@ -394,7 +431,7 @@ def create_database( return db_dict - def delete_database(self, db_id: str) -> dict: + async def 
delete_database(self, db_id: str) -> dict: """ 删除数据库 @@ -405,6 +442,8 @@ def delete_database(self, db_id: str) -> dict: 操作结果 """ if db_id in self.databases_meta: + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + # 删除相关文件记录 files_to_delete = [fid for fid, finfo in self.files_meta.items() if finfo.get("database_id") == db_id] for file_id in files_to_delete: @@ -412,7 +451,8 @@ def delete_database(self, db_id: str) -> dict: # 删除数据库记录 del self.databases_meta[db_id] - self._save_metadata() + await KnowledgeBaseRepository().delete(db_id) + await self._save_metadata() # 删除工作目录 working_dir = os.path.join(self.work_dir, db_id) @@ -426,7 +466,7 @@ def delete_database(self, db_id: str) -> dict: return {"message": "删除成功"} - def create_folder(self, db_id: str, folder_name: str, parent_id: str | None = None) -> dict: + async def create_folder(self, db_id: str, folder_name: str, parent_id: str | None = None) -> dict: """Create a folder in the database.""" import uuid @@ -443,7 +483,7 @@ def create_folder(self, db_id: str, folder_name: str, parent_id: str | None = No "path": folder_name, "file_type": "folder", } - self._save_metadata() + await self._save_metadata() return self.files_meta[folder_id] @abstractmethod @@ -524,12 +564,16 @@ def query(self, query_text: str, db_id: str, **kwargs) -> list[dict]: import asyncio logger.warning("query is deprecated, use aquery instead") - return asyncio.run(self.aquery(query_text, db_id, **kwargs)) + try: + loop = asyncio.get_running_loop() + return loop.run_until_complete(self.aquery(query_text, db_id, **kwargs)) + except RuntimeError: + return asyncio.run(self.aquery(query_text, db_id, **kwargs)) def _get_query_params(self, db_id: str) -> dict: """从实例元数据中加载查询参数""" if db_id in self.databases_meta: - query_params_meta = self.databases_meta[db_id].get("query_params", {}) + query_params_meta = self.databases_meta[db_id].get("query_params") or {} return query_params_meta.get("options", {}) return {} @@ -591,6 
+635,9 @@ def get_databases(self) -> dict: Returns: 数据库列表 """ + # 确保元数据已加载(延迟加载机制) + self._ensure_metadata_loaded() + databases = [] for db_id, meta in self.databases_meta.items(): # 检查并修复异常的processing状态 @@ -710,7 +757,6 @@ def _check_and_fix_processing_status(self, db_id: str) -> None: # 如果有状态变更,保存元数据 if status_changed: - self._save_metadata() logger.info(f"Fixed interrupted processing status for database {db_id}") except Exception as e: @@ -779,7 +825,7 @@ async def move_file(self, db_id: str, file_id: str, new_parent_id: str | None) - current = parent_meta.get("parent_id") meta["parent_id"] = new_parent_id - self._save_metadata() + await self._save_metadata() return meta @abstractmethod @@ -877,7 +923,7 @@ def update_database(self, db_id: str, name: str, description: str, llm_info: dic if llm_info is not None: self.databases_meta[db_id]["llm_info"] = llm_info - self._save_metadata() + asyncio.create_task(self._save_metadata()) return self.get_database_info(db_id) @@ -905,90 +951,159 @@ async def retriever(query_text, **kwargs): } return retrievers - def _load_metadata(self): - """加载元数据""" - meta_file = os.path.join(self.work_dir, f"metadata_{self.kb_type}.json") + async def _load_metadata(self) -> None: + from src.repositories.evaluation_repository import EvaluationRepository + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + + kb_repo = KnowledgeBaseRepository() + file_repo = KnowledgeFileRepository() + eval_repo = EvaluationRepository() + + databases = [kb for kb in await kb_repo.get_all() if kb.kb_type == self.kb_type] + self.databases_meta = { + kb.db_id: { + "name": kb.name, + "description": kb.description, + "kb_type": kb.kb_type, + "embed_info": kb.embed_info, + "llm_info": kb.llm_info, + "query_params": kb.query_params, + "metadata": kb.additional_params or {}, + "created_at": utc_isoformat(kb.created_at) if kb.created_at else utc_isoformat(), 
+ } + for kb in databases + } + + self.files_meta = {} + for kb in databases: + for record in await file_repo.list_by_db_id(kb.db_id): + self.files_meta[record.file_id] = { + "file_id": record.file_id, + "database_id": record.db_id, + "parent_id": record.parent_id, + "filename": record.filename, + "file_type": record.file_type, + "path": record.path, + "markdown_file": record.markdown_file, + "status": record.status, + "content_hash": record.content_hash, + "size": record.file_size, + "content_type": record.content_type, + "processing_params": record.processing_params, + "is_folder": record.is_folder, + "error": record.error_message, + "created_by": record.created_by, + "updated_by": record.updated_by, + "created_at": utc_isoformat(record.created_at) if record.created_at else None, + "updated_at": utc_isoformat(record.updated_at) if record.updated_at else None, + "original_filename": record.original_filename, + "minio_url": record.minio_url, + } + + self.benchmarks_meta = {} + for kb in databases: + benchmarks = await eval_repo.list_benchmarks(kb.db_id) + if not benchmarks: + continue + self.benchmarks_meta[kb.db_id] = {} + for bench in benchmarks: + self.benchmarks_meta[kb.db_id][bench.benchmark_id] = { + "id": bench.benchmark_id, + "benchmark_id": bench.benchmark_id, + "name": bench.name, + "description": bench.description, + "db_id": bench.db_id, + "question_count": bench.question_count, + "has_gold_chunks": bench.has_gold_chunks, + "has_gold_answers": bench.has_gold_answers, + "benchmark_file": bench.data_file_path, + "created_by": bench.created_by, + "created_at": utc_isoformat(bench.created_at) if bench.created_at else None, + "updated_at": utc_isoformat(bench.updated_at) if bench.updated_at else None, + } + + logger.info(f"Loaded {self.kb_type} metadata from database for {len(self.databases_meta)} databases") + + async def _save_metadata(self) -> None: + from src.repositories.evaluation_repository import EvaluationRepository + from 
src.repositories.knowledge_base_repository import KnowledgeBaseRepository + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + + kb_repo = KnowledgeBaseRepository() + file_repo = KnowledgeFileRepository() + eval_repo = EvaluationRepository() - if os.path.exists(meta_file): - try: - with open(meta_file, encoding="utf-8") as f: - data = json.load(f) - self.databases_meta = data.get("databases", {}) - self.files_meta = data.get("files", {}) - self.benchmarks_meta = data.get("benchmarks", {}) - logger.info(f"Loaded {self.kb_type} metadata for {len(self.databases_meta)} databases") - except Exception as e: - logger.error(f"Failed to load {self.kb_type} metadata: {e}") - # 尝试从备份恢复 - backup_file = f"{meta_file}.backup" - if os.path.exists(backup_file): - try: - with open(backup_file, encoding="utf-8") as f: - data = json.load(f) - self.databases_meta = data.get("databases", {}) - self.files_meta = data.get("files", {}) - self.benchmarks_meta = data.get("benchmarks", {}) - logger.info(f"Loaded {self.kb_type} metadata from backup") - # 恢复备份文件 - shutil.copy2(backup_file, meta_file) - return - except Exception as backup_e: - logger.error(f"Failed to load backup: {backup_e}") - - # 如果加载失败,初始化为空状态 - logger.warning(f"Initializing empty {self.kb_type} metadata") - self.databases_meta = {} - self.files_meta = {} - self.benchmarks_meta = {} - - def _serialize_metadata(self, obj): - """递归序列化元数据中的 Pydantic 模型""" - if hasattr(obj, "dict"): - return obj.dict() - elif isinstance(obj, dict): - return {k: self._serialize_metadata(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [self._serialize_metadata(item) for item in obj] - else: - return obj - - def _save_metadata(self): - """保存元数据""" self._normalize_metadata_state() - meta_file = os.path.join(self.work_dir, f"metadata_{self.kb_type}.json") - backup_file = f"{meta_file}.backup" - try: - # 创建简单备份 - if os.path.exists(meta_file): - shutil.copy2(meta_file, backup_file) - - # 准备数据并序列化 
Pydantic 模型 - data = { - "databases": self._serialize_metadata(self.databases_meta), - "files": self._serialize_metadata(self.files_meta), - "benchmarks": self._serialize_metadata(self.benchmarks_meta), - "kb_type": self.kb_type, - "updated_at": utc_isoformat(), + for db_id, meta in self.databases_meta.items(): + existing = await kb_repo.get_by_id(db_id) + payload = { + "db_id": db_id, + "name": meta.get("name") or db_id, + "description": meta.get("description"), + "kb_type": meta.get("kb_type") or self.kb_type, + "embed_info": meta.get("embed_info"), + "llm_info": meta.get("llm_info"), + "query_params": meta.get("query_params"), + "additional_params": meta.get("metadata") or {}, } + if existing is None: + await kb_repo.create(payload) + else: + await kb_repo.update( + db_id, + { + "name": payload["name"], + "description": payload["description"], + "kb_type": payload["kb_type"], + "embed_info": payload["embed_info"], + "llm_info": payload["llm_info"], + "query_params": payload["query_params"], + "additional_params": payload["additional_params"], + }, + ) - # 原子性写入(使用临时文件) - with tempfile.NamedTemporaryFile( - mode="w", dir=os.path.dirname(meta_file), prefix=".tmp_", suffix=".json", delete=False - ) as tmp_file: - json.dump(data, tmp_file, ensure_ascii=False, indent=2) - temp_path = tmp_file.name - - os.replace(temp_path, meta_file) - logger.debug(f"Saved {self.kb_type} metadata") + for file_id, meta in self.files_meta.items(): + db_id = meta.get("database_id") + if not db_id: + continue + await file_repo.upsert( + file_id=file_id, + data={ + "db_id": db_id, + "parent_id": meta.get("parent_id"), + "filename": meta.get("filename") or "", + "original_filename": meta.get("original_filename"), + "file_type": meta.get("file_type"), + "path": meta.get("path"), + "minio_url": meta.get("minio_url"), + "markdown_file": meta.get("markdown_file"), + "status": meta.get("status"), + "content_hash": meta.get("content_hash"), + "file_size": meta.get("size"), + "content_type": 
meta.get("content_type"), + "processing_params": meta.get("processing_params"), + "is_folder": meta.get("is_folder", False), + "error_message": meta.get("error"), + "created_by": str(meta.get("created_by")) if meta.get("created_by") else None, + "updated_by": str(meta.get("updated_by")) if meta.get("updated_by") else None, + }, + ) - except Exception as e: - logger.error(f"Failed to save {self.kb_type} metadata: {e}") - # 尝试恢复备份 - if os.path.exists(backup_file): - try: - shutil.copy2(backup_file, meta_file) - logger.info("Restored metadata from backup") - except Exception as restore_e: - logger.error(f"Failed to restore backup: {restore_e}") - raise e + for db_id, benchmarks in self.benchmarks_meta.items(): + for benchmark_id, meta in benchmarks.items(): + existing = await eval_repo.get_benchmark(benchmark_id) + payload = { + "benchmark_id": benchmark_id, + "db_id": db_id, + "name": meta.get("name") or benchmark_id, + "description": meta.get("description"), + "question_count": int(meta.get("question_count") or 0), + "has_gold_chunks": bool(meta.get("has_gold_chunks")), + "has_gold_answers": bool(meta.get("has_gold_answers")), + "data_file_path": meta.get("benchmark_file"), + "created_by": str(meta.get("created_by")) if meta.get("created_by") else None, + } + if existing is None: + await eval_repo.create_benchmark(payload) diff --git a/src/knowledge/implementations/lightrag.py b/src/knowledge/implementations/lightrag.py index 057c6926d..0c73f307b 100644 --- a/src/knowledge/implementations/lightrag.py +++ b/src/knowledge/implementations/lightrag.py @@ -284,7 +284,7 @@ async def index_file(self, db_id: str, file_id: str, operator_id: str | None = N self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() # Add to processing queue self._add_to_processing_queue(file_id) @@ -307,7 +307,7 @@ async def index_file(self, db_id: str, file_id: str, 
operator_id: str | None = N self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() return self.files_meta[file_id] @@ -318,7 +318,7 @@ async def index_file(self, db_id: str, file_id: str, operator_id: str | None = N self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() raise finally: @@ -360,7 +360,7 @@ async def update_content(self, db_id: str, file_ids: list[str], params: dict | N # 更新状态为处理中 self.files_meta[file_id]["processing_params"] = params.copy() self.files_meta[file_id]["status"] = "processing" - self._save_metadata() + await self._save_metadata() # 重新解析文件为 markdown if content_type != "file": @@ -379,7 +379,7 @@ async def update_content(self, db_id: str, file_ids: list[str], params: dict | N # 更新元数据状态 self.files_meta[file_id]["status"] = "done" - self._save_metadata() + await self._save_metadata() # 从处理队列中移除 self._remove_from_processing_queue(file_id) @@ -395,7 +395,7 @@ async def update_content(self, db_id: str, file_ids: list[str], params: dict | N logger.error(f"更新{content_type} {file_path} 失败: {error_msg}, {traceback.format_exc()}") self.files_meta[file_id]["status"] = "failed" self.files_meta[file_id]["error"] = error_msg - self._save_metadata() + await self._save_metadata() # 从处理队列中移除 self._remove_from_processing_queue(file_id) @@ -505,7 +505,10 @@ async def delete_file(self, db_id: str, file_id: str) -> None: # 删除文件记录 if file_id in self.files_meta: del self.files_meta[file_id] - self._save_metadata() + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + + await KnowledgeFileRepository().delete(file_id) + await self._save_metadata() async def get_file_basic_info(self, db_id: str, file_id: str) -> dict: """获取文件基本信息(仅元数据)""" diff --git a/src/knowledge/implementations/milvus.py 
b/src/knowledge/implementations/milvus.py index 4300f3a3f..0a05ebd33 100644 --- a/src/knowledge/implementations/milvus.py +++ b/src/knowledge/implementations/milvus.py @@ -280,7 +280,7 @@ async def index_file(self, db_id: str, file_id: str, operator_id: str | None = N self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() # Read processing params inside lock to ensure we get the latest values params = file_meta.get("processing_params", {}) or {} @@ -333,7 +333,7 @@ def _insert_records(): self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() return self.files_meta[file_id] except Exception as e: @@ -344,7 +344,7 @@ def _insert_records(): self.files_meta[file_id]["updated_at"] = utc_isoformat() if operator_id: self.files_meta[file_id]["updated_by"] = operator_id - self._save_metadata() + await self._save_metadata() raise finally: @@ -392,7 +392,7 @@ async def update_content(self, db_id: str, file_ids: list[str], params: dict | N async with self._metadata_lock: self.files_meta[file_id]["processing_params"] = params.copy() self.files_meta[file_id]["status"] = "processing" - self._save_metadata() + await self._save_metadata() # 重新解析文件为 markdown if content_type != "file": @@ -430,7 +430,7 @@ def _insert_records(): # 更新元数据状态 async with self._metadata_lock: self.files_meta[file_id]["status"] = "done" - self._save_metadata() + await self._save_metadata() # 从处理队列中移除 self._remove_from_processing_queue(file_id) @@ -445,7 +445,7 @@ def _insert_records(): logger.error(f"更新{content_type} {file_path} 失败: {e}, {traceback.format_exc()}") async with self._metadata_lock: self.files_meta[file_id]["status"] = "failed" - self._save_metadata() + await self._save_metadata() # 从处理队列中移除 self._remove_from_processing_queue(file_id) @@ -614,7 
+614,10 @@ async def delete_file(self, db_id: str, file_id: str) -> None: async with self._metadata_lock: if file_id in self.files_meta: del self.files_meta[file_id] - self._save_metadata() + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + + await KnowledgeFileRepository().delete(file_id) + await self._save_metadata() async def get_file_basic_info(self, db_id: str, file_id: str) -> dict: """获取文件基本信息(仅元数据)""" diff --git a/src/knowledge/manager.py b/src/knowledge/manager.py index abf097133..1f262bc9f 100644 --- a/src/knowledge/manager.py +++ b/src/knowledge/manager.py @@ -1,20 +1,17 @@ import asyncio -import json import os -import shutil -import tempfile from src.knowledge.base import KBNotFoundError, KnowledgeBase from src.knowledge.factory import KnowledgeBaseFactory from src.utils import logger -from src.utils.datetime_utils import coerce_any_to_utc_datetime, utc_isoformat +from src.utils.datetime_utils import utc_isoformat class KnowledgeBaseManager: """ 知识库管理器 - 统一管理多种类型的知识库实例,提供统一的外部接口 + 统一管理多种类型的知识库实例,直接通过 Repository 访问数据库,不维护冗余缓存。 """ def __init__(self, work_dir: str): @@ -30,134 +27,53 @@ def __init__(self, work_dir: str): # 知识库实例缓存 {kb_type: kb_instance} self.kb_instances: dict[str, KnowledgeBase] = {} - # 全局数据库元信息 {db_id: metadata_with_kb_type} - self.global_databases_meta: dict[str, dict] = {} - # 元数据锁 self._metadata_lock = asyncio.Lock() - # 加载全局元数据 - self._load_global_metadata() - self._normalize_global_metadata() - + async def initialize(self): + """异步初始化""" # 初始化已存在的知识库实例 self._initialize_existing_kbs() - - # 迁移 query_params 到 instance metadata - try: - self._migrate_all_query_params() - except Exception as e: - logger.warning(f"Failed to migrate query_params: {e}") - logger.info("KnowledgeBaseManager initialized") - # 在后台运行数据一致性检测(不阻塞初始化) - # try: - # # 尝试获取当前事件循环,如果没有则创建新的 - # try: - # loop = asyncio.get_event_loop() - # if loop.is_running(): - # # 如果已经在事件循环中,创建任务 - # 
asyncio.create_task(self.detect_data_inconsistencies()) - # else: - # # 如果事件循环未运行,直接运行 - # loop.run_until_complete(self.detect_data_inconsistencies()) - # except RuntimeError: - # # 没有事件循环,创建一个来运行检测 - # asyncio.run(self.detect_data_inconsistencies()) - # except Exception as e: - # logger.warning(f"初始化时运行数据一致性检测失败: {e}") - - def _load_global_metadata(self): - """加载全局元数据""" - meta_file = os.path.join(self.work_dir, "global_metadata.json") - - if os.path.exists(meta_file): - try: - with open(meta_file, encoding="utf-8") as f: - data = json.load(f) - self.global_databases_meta = data.get("databases", {}) - logger.info(f"Loaded global metadata for {len(self.global_databases_meta)} databases") - except Exception as e: - logger.error(f"Failed to load global metadata: {e}") - # 尝试从备份恢复 - backup_file = f"{meta_file}.backup" - if os.path.exists(backup_file): - try: - with open(backup_file, encoding="utf-8") as f: - data = json.load(f) - self.global_databases_meta = data.get("databases", {}) - logger.info("Loaded global metadata from backup") - # 恢复备份文件 - shutil.copy2(backup_file, meta_file) - return - except Exception as backup_e: - logger.error(f"Failed to load backup: {backup_e}") - - # 如果加载失败,初始化为空状态 - logger.warning("Initializing empty global metadata") - self.global_databases_meta = {} - - def _save_global_metadata(self): - """保存全局元数据""" - self._normalize_global_metadata() - meta_file = os.path.join(self.work_dir, "global_metadata.json") - backup_file = f"{meta_file}.backup" + async def _load_all_metadata(self): + """异步加载所有元数据 - 保留兼容性的空方法,现在由 KB 实例自行加载""" + pass - try: - # 创建简单备份 - if os.path.exists(meta_file): - shutil.copy2(meta_file, backup_file) + def _initialize_existing_kbs(self): + """初始化已存在的知识库实例""" + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository - # 准备数据 - data = {"databases": self.global_databases_meta, "updated_at": utc_isoformat(), "version": "2.0"} + async def _async_init(): + kb_repo = KnowledgeBaseRepository() + rows = 
await kb_repo.get_all() - # 原子性写入(使用临时文件) - with tempfile.NamedTemporaryFile( - mode="w", dir=os.path.dirname(meta_file), prefix=".tmp_", suffix=".json", delete=False - ) as tmp_file: - json.dump(data, tmp_file, ensure_ascii=False, indent=2) - temp_path = tmp_file.name + kb_types_in_use = set() + for row in rows: + kb_type = row.kb_type or "lightrag" + kb_types_in_use.add(kb_type) - os.replace(temp_path, meta_file) - logger.debug("Saved global metadata") + logger.info(f"[InitializeKB] 发现 {len(kb_types_in_use)} 种知识库类型: {kb_types_in_use}") - except Exception as e: - logger.error(f"Failed to save global metadata: {e}") - # 尝试恢复备份 - if os.path.exists(backup_file): - try: - shutil.copy2(backup_file, meta_file) - logger.info("Restored global metadata from backup") - except Exception as restore_e: - logger.error(f"Failed to restore backup: {restore_e}") - raise e - - def _normalize_global_metadata(self) -> None: - """Normalize stored timestamps within the global metadata cache.""" - for meta in self.global_databases_meta.values(): - if "created_at" in meta: + # 为每种使用中的知识库类型创建实例并加载元数据 + for kb_type in kb_types_in_use: try: - dt_value = coerce_any_to_utc_datetime(meta.get("created_at")) - if dt_value: - meta["created_at"] = utc_isoformat(dt_value) - continue - except Exception as exc: # noqa: BLE001 - logger.warning(f"Failed to normalize database metadata timestamp {meta.get('created_at')!r}: {exc}") + kb_instance = self._get_or_create_kb_instance(kb_type) + # 让 KB 实例自行加载元数据 + await kb_instance._load_metadata() + logger.info(f"[InitializeKB] {kb_type} 实例已初始化") + except Exception as e: + logger.error(f"Failed to initialize {kb_type} knowledge base: {e}") + import traceback - def _initialize_existing_kbs(self): - """初始化已存在的知识库实例""" - kb_types_in_use = set() - for db_meta in self.global_databases_meta.values(): - kb_type = db_meta.get("kb_type", "lightrag") # 默认为lightrag - kb_types_in_use.add(kb_type) + logger.error(traceback.format_exc()) - # 为每种使用中的知识库类型创建实例 - for kb_type in 
kb_types_in_use: - try: - self._get_or_create_kb_instance(kb_type) - except Exception as e: - logger.error(f"Failed to initialize {kb_type} knowledge base: {e}") + # 在事件循环中运行异步初始化 + try: + loop = asyncio.get_running_loop() + loop.create_task(_async_init()) + except RuntimeError: + asyncio.run(_async_init()) def _get_or_create_kb_instance(self, kb_type: str) -> KnowledgeBase: """ @@ -184,10 +100,10 @@ async def move_file(self, db_id: str, file_id: str, new_parent_id: str | None) - """ 移动文件/文件夹 """ - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.move_file(db_id, file_id, new_parent_id) - def _get_kb_for_database(self, db_id: str) -> KnowledgeBase: + async def _get_kb_for_database(self, db_id: str) -> KnowledgeBase: """ 根据数据库ID获取对应的知识库实例 @@ -200,60 +116,173 @@ def _get_kb_for_database(self, db_id: str) -> KnowledgeBase: Raises: KBNotFoundError: 数据库不存在或知识库类型不支持 """ - if db_id not in self.global_databases_meta: + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) + + if kb is None: raise KBNotFoundError(f"Database {db_id} not found") - kb_type = self.global_databases_meta[db_id].get("kb_type", "lightrag") + kb_type = kb.kb_type or "lightrag" if not KnowledgeBaseFactory.is_type_supported(kb_type): raise KBNotFoundError(f"Unsupported knowledge base type: {kb_type}") return self._get_or_create_kb_instance(kb_type) + def _get_kb_for_database_sync(self, db_id: str) -> KnowledgeBase: + """同步版本的 _get_kb_for_database,用于兼容同步调用""" + try: + loop = asyncio.get_running_loop() + return loop.run_until_complete(self._get_kb_for_database(db_id)) + except RuntimeError: + return asyncio.run(self._get_kb_for_database(db_id)) + # ============================================================================= # 统一的外部接口 - 与原始 LightRagBasedKB 兼容 # 
============================================================================= + async def aget_kb(self, db_id: str) -> KnowledgeBase: + """异步获取知识库实例 + + Args: + db_id: 数据库ID + + Returns: + 知识库实例 + """ + return await self._get_kb_for_database(db_id) + def get_kb(self, db_id: str) -> KnowledgeBase: - """Public accessor to fetch the underlying knowledge base instance by database id. + """同步获取知识库实例(兼容性方法,用于同步上下文) - This provides a simple compatibility layer for callers that expect a - `get_kb` method on the manager. + Args: + db_id: 数据库ID + + Returns: + 知识库实例 """ - return self._get_kb_for_database(db_id) + return self._get_kb_for_database_sync(db_id) - def get_databases(self) -> dict: + async def get_databases(self) -> dict: """获取所有数据库信息""" - all_databases = [] + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository - # 收集所有知识库的数据库信息 - for kb_type, kb_instance in self.kb_instances.items(): - kb_databases = kb_instance.get_databases()["databases"] - all_databases.extend(kb_databases) + kb_repo = KnowledgeBaseRepository() + rows = await kb_repo.get_all() + all_databases = [] + for row in rows: + kb_instance = self._get_or_create_kb_instance(row.kb_type or "lightrag") + db_info = kb_instance.get_database_info(row.db_id) + if db_info: + # 补充 share_config 和 additional_params + db_info["share_config"] = row.share_config or {"is_shared": True, "accessible_departments": []} + db_info["additional_params"] = row.additional_params or {} + all_databases.append(db_info) return {"databases": all_databases} - def database_name_exists(self, database_name: str) -> bool: - """检查知识库名称是否已存在 + async def check_accessible(self, user: dict, db_id: str) -> bool: + """检查用户是否有权限访问数据库 Args: - database_name: 要检查的知识库名称 + user: 用户信息字典 + db_id: 数据库ID Returns: - True 如果名称已存在,False 否则 + bool: 是否有权限 """ - for db_id, meta in self.global_databases_meta.items(): - if meta.get("name", "").lower() == database_name.lower(): + # 超级管理员有权访问所有 + if user.get("role") == "superadmin": + 
return True + + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) + if kb is None: + return False + + share_config = kb.share_config or {} + is_shared = share_config.get("is_shared", True) + + # 如果是全员共享,则有权限 + if is_shared: + return True + + # 检查部门权限 + user_department_id = user.get("department_id") + accessible_departments = share_config.get("accessible_departments", []) + + if user_department_id is None: + return False + + # 转换为整数进行比较(前端可能传递字符串,后端存储为整数) + try: + user_department_id = int(user_department_id) + accessible_departments = [int(d) for d in accessible_departments] + except (ValueError, TypeError): + return False + + return user_department_id in accessible_departments + + async def get_databases_by_user(self, user: dict) -> dict: + """根据用户权限获取知识库列表 + + Args: + user: 用户信息字典,包含 role 和 department_id + + Returns: + 过滤后的知识库列表 + """ + all_databases = (await self.get_databases()).get("databases", []) + + # 超级管理员可以看到所有知识库 + if user.get("role") == "superadmin": + return {"databases": all_databases} + + filtered_databases = [] + + for db in all_databases: + db_id = db.get("db_id") + if not db_id: + continue + + if await self.check_accessible(user, db_id): + filtered_databases.append(db) + + return {"databases": filtered_databases} + + async def database_name_exists(self, database_name: str) -> bool: + """检查知识库名称是否已存在""" + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + from src.storage.postgres.manager import pg_manager + + # 确保 pg_manager 已初始化 + if not pg_manager._initialized: + pg_manager.initialize() + + kb_repo = KnowledgeBaseRepository() + rows = await kb_repo.get_all() + for row in rows: + if (row.name or "").lower() == database_name.lower(): return True return False async def create_folder(self, db_id: str, folder_name: str, parent_id: str = None) -> dict: """Create a folder in the database.""" - kb_instance = 
self._get_kb_for_database(db_id) - return kb_instance.create_folder(db_id, folder_name, parent_id) + kb_instance = await self._get_kb_for_database(db_id) + return await kb_instance.create_folder(db_id, folder_name, parent_id) async def create_database( - self, database_name: str, description: str, kb_type: str = "lightrag", embed_info: dict | None = None, **kwargs + self, + database_name: str, + description: str, + kb_type: str = "lightrag", + embed_info: dict | None = None, + share_config: dict | None = None, + **kwargs, ) -> dict: """ 创建数据库 @@ -263,6 +292,7 @@ async def create_database( description: 数据库描述 kb_type: 知识库类型,默认为lightrag embed_info: 嵌入模型信息 + share_config: 共享配置 **kwargs: 其他配置参数,包括chunk_size和chunk_overlap Returns: @@ -273,37 +303,50 @@ async def create_database( raise ValueError(f"Unsupported knowledge base type: {kb_type}. Available types: {available_types}") # 检查名称是否已存在 - if self.database_name_exists(database_name): + if await self.database_name_exists(database_name): raise ValueError(f"知识库名称 '{database_name}' 已存在,请使用其他名称") - kb_instance = self._get_or_create_kb_instance(kb_type) + # 默认共享配置 + if share_config is None: + share_config = {"is_shared": True, "accessible_departments": []} - db_info = kb_instance.create_database(database_name, description, embed_info, **kwargs) + kb_instance = self._get_or_create_kb_instance(kb_type) + db_info = await kb_instance.create_database(database_name, description, embed_info, **kwargs) db_id = db_info["db_id"] - async with self._metadata_lock: - self.global_databases_meta[db_id] = { - "name": database_name, - "description": description, - "kb_type": kb_type, - "created_at": utc_isoformat(), - "additional_params": kwargs.copy(), - } - self._save_global_metadata() + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + + kb_repo = KnowledgeBaseRepository() + updated = await kb_repo.update(db_id, {"share_config": share_config}) + if updated is None: + await kb_repo.create( + { + "db_id": 
db_id, + "name": database_name, + "description": description, + "kb_type": kb_type, + "embed_info": embed_info, + "llm_info": db_info.get("llm_info"), + "additional_params": kwargs.copy(), + "share_config": share_config, + } + ) logger.info(f"Created {kb_type} database: {database_name} ({db_id}) with {kwargs}") + db_info["share_config"] = share_config return db_info async def delete_database(self, db_id: str) -> dict: """删除数据库""" + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + try: - kb_instance = self._get_kb_for_database(db_id) - result = kb_instance.delete_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) + result = await kb_instance.delete_database(db_id) - async with self._metadata_lock: - if db_id in self.global_databases_meta: - del self.global_databases_meta[db_id] - self._save_global_metadata() + # 删除数据库记录 + kb_repo = KnowledgeBaseRepository() + await kb_repo.delete(db_id) return result except KBNotFoundError as e: @@ -314,125 +357,106 @@ async def add_file_record( self, db_id: str, item: str, params: dict | None = None, operator_id: str | None = None ) -> dict: """Add file record to metadata""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.add_file_record(db_id, item, params, operator_id) async def parse_file(self, db_id: str, file_id: str, operator_id: str | None = None) -> dict: """Parse file to Markdown""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.parse_file(db_id, file_id, operator_id) async def index_file(self, db_id: str, file_id: str, operator_id: str | None = None) -> dict: """Index parsed file""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.index_file(db_id, file_id, operator_id) async def update_file_params(self, db_id: str, file_id: str, 
params: dict, operator_id: str | None = None) -> None: """Update file processing params""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) await kb_instance.update_file_params(db_id, file_id, params, operator_id) async def aquery(self, query_text: str, db_id: str, **kwargs) -> str: """异步查询知识库""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.aquery(query_text, db_id, **kwargs) async def export_data(self, db_id: str, format: str = "zip", **kwargs) -> str: """导出知识库数据""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.export_data(db_id, format=format, **kwargs) def query(self, query_text: str, db_id: str, **kwargs) -> str: """同步查询知识库(兼容性方法)""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = self._get_kb_for_database_sync(db_id) return kb_instance.query(query_text, db_id, **kwargs) - def _migrate_all_query_params(self): - """将所有 query_params 从 global metadata 迁移到 instance metadata""" - migration_count = 0 - - for db_id, global_meta in list(self.global_databases_meta.items()): - if "query_params" not in global_meta: - continue - - kb_type = global_meta.get("kb_type", "lightrag") - kb_instance = self.kb_instances.get(kb_type) - - if not kb_instance or db_id not in kb_instance.databases_meta: - logger.warning(f"Cannot migrate query_params for {db_id}, skipping") - continue - - # 检查是否已迁移 - if "query_params" in kb_instance.databases_meta[db_id]: - # 已经迁移过,直接清理 global metadata 并跳过 - del global_meta["query_params"] - continue - - # 执行迁移 - kb_instance.databases_meta[db_id]["query_params"] = global_meta["query_params"] - del global_meta["query_params"] - migration_count += 1 + async def get_database_info(self, db_id: str) -> dict | None: + """获取数据库详细信息""" + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository - if 
migration_count > 0: - self._save_global_metadata() - logger.info(f"Successfully migrated query_params for {migration_count} databases") + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) + if kb is None: + return None - def get_database_info(self, db_id: str) -> dict | None: - """获取数据库详细信息""" try: - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) db_info = kb_instance.get_database_info(db_id) + except KBNotFoundError: + db_info = { + "db_id": db_id, + "name": kb.name, + "description": kb.description, + "kb_type": kb.kb_type, + "files": {}, + "row_count": 0, + "status": "已连接", + } - # 添加全局元数据中的additional_params信息 - if db_info and db_id in self.global_databases_meta: - global_meta = self.global_databases_meta[db_id] - additional_params = global_meta.get("additional_params", {}).copy() - - # 确保 auto_generate_questions 存在,默认为 False - if "auto_generate_questions" not in additional_params: - additional_params["auto_generate_questions"] = False - - db_info["additional_params"] = additional_params + # 添加数据库中的附加字段 + db_info["additional_params"] = kb.additional_params or {} + db_info["share_config"] = kb.share_config or {"is_shared": True, "accessible_departments": []} + db_info["mindmap"] = kb.mindmap + db_info["sample_questions"] = kb.sample_questions or [] + db_info["query_params"] = kb.query_params - return db_info - except KBNotFoundError: - return None + return db_info async def delete_folder(self, db_id: str, folder_id: str) -> None: """递归删除文件夹""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) await kb_instance.delete_folder(db_id, folder_id) async def delete_file(self, db_id: str, file_id: str) -> None: """删除文件""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) await kb_instance.delete_file(db_id, file_id) async def update_content(self, db_id: str, file_ids: list[str], params: 
dict | None = None) -> list[dict]: """更新内容(重新分块)""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.update_content(db_id, file_ids, params or {}) async def get_file_basic_info(self, db_id: str, file_id: str) -> dict: """获取文件基本信息(仅元数据)""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.get_file_basic_info(db_id, file_id) async def get_file_content(self, db_id: str, file_id: str) -> dict: """获取文件内容信息(chunks和lines)""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.get_file_content(db_id, file_id) async def get_file_info(self, db_id: str, file_id: str) -> dict: """获取文件完整信息(基本信息+内容信息)- 保持向后兼容""" - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) return await kb_instance.get_file_info(db_id, file_id) def get_db_upload_path(self, db_id: str | None = None) -> str: """获取数据库上传路径""" if db_id: try: - kb_instance = self._get_kb_for_database(db_id) + kb_instance = self._get_kb_for_database_sync(db_id) return kb_instance.get_db_upload_path(db_id) except KBNotFoundError: # 如果数据库不存在,创建通用上传路径 @@ -448,7 +472,7 @@ async def file_name_existed_in_db(self, db_id: str | None, file_name: str | None if not db_id or not file_name: return False try: - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) except KBNotFoundError: return False @@ -481,7 +505,7 @@ async def get_same_name_files(self, db_id: str, filename: str) -> list[dict]: if not db_id or not filename: return [] try: - kb_instance = self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) except KBNotFoundError: return [] @@ -516,7 +540,7 @@ async def file_existed_in_db(self, db_id: str | None, content_hash: str | None) return False try: - kb_instance = 
self._get_kb_for_database(db_id) + kb_instance = await self._get_kb_for_database(db_id) except KBNotFoundError: return False @@ -531,29 +555,37 @@ async def file_existed_in_db(self, db_id: str | None, content_hash: str | None) return False async def update_database( - self, db_id: str, name: str, description: str, llm_info: dict = None, additional_params: dict | None = None + self, + db_id: str, + name: str, + description: str, + llm_info: dict = None, + additional_params: dict | None = None, + share_config: dict | None = None, ) -> dict: """更新数据库""" - kb_instance = self._get_kb_for_database(db_id) - result = kb_instance.update_database(db_id, name, description, llm_info) - - async with self._metadata_lock: - if db_id in self.global_databases_meta: - self.global_databases_meta[db_id]["name"] = name - self.global_databases_meta[db_id]["description"] = description + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository - # 合并现有的 additional_params 和新的 additional_params - existing_additional_params = self.global_databases_meta[db_id].get("additional_params", {}) - if additional_params: - existing_additional_params.update(additional_params) - self.global_databases_meta[db_id]["additional_params"] = existing_additional_params + kb_instance = await self._get_kb_for_database(db_id) + kb_instance.update_database(db_id, name, description, llm_info) - # 清理旧的 top-level key (如果存在) - self.global_databases_meta[db_id].pop("auto_generate_questions", None) + # 准备更新数据 + update_data: dict = { + "name": name, + "description": description, + } + if llm_info is not None: + update_data["llm_info"] = llm_info + if additional_params is not None: + update_data["additional_params"] = additional_params + if share_config is not None: + update_data["share_config"] = share_config - self._save_global_metadata() + # 保存到数据库 + kb_repo = KnowledgeBaseRepository() + await kb_repo.update(db_id, update_data) - return result + return await self.get_database_info(db_id) def 
get_retrievers(self) -> dict[str, dict]: """获取所有检索器""" @@ -585,20 +617,27 @@ def get_kb_instance_info(self) -> dict[str, dict]: } return info - def get_statistics(self) -> dict: + async def get_statistics(self) -> dict: """获取统计信息""" - stats = {"total_databases": len(self.global_databases_meta), "kb_types": {}, "total_files": 0} + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + from src.repositories.knowledge_file_repository import KnowledgeFileRepository + + kb_repo = KnowledgeBaseRepository() + rows = await kb_repo.get_all() + + stats = {"total_databases": len(rows), "kb_types": {}, "total_files": 0} # 按知识库类型统计 - for db_meta in self.global_databases_meta.values(): - kb_type = db_meta.get("kb_type", "lightrag") + for row in rows: + kb_type = row.kb_type or "lightrag" if kb_type not in stats["kb_types"]: stats["kb_types"][kb_type] = 0 stats["kb_types"][kb_type] += 1 # 统计文件总数 - for kb_instance in self.kb_instances.values(): - stats["total_files"] += len(kb_instance.files_meta) + file_repo = KnowledgeFileRepository() + files = await file_repo.get_all() + stats["total_files"] = len(files) return stats @@ -619,34 +658,29 @@ async def _get_lightrag_instance(self, db_id: str): Raises: ValueError: 如果数据库不存在或不是 lightrag 类型 """ - try: - # 检查数据库是否存在 - if db_id not in self.global_databases_meta: - logger.error(f"Database {db_id} not found in global metadata") - return None + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository - # 检查是否是 LightRAG 类型 - kb_type = self.global_databases_meta[db_id].get("kb_type", "lightrag") - if kb_type != "lightrag": - logger.error(f"Database {db_id} is not a LightRAG type (actual type: {kb_type})") - raise ValueError(f"Database {db_id} is not a LightRAG knowledge base") + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) - # 获取 LightRAG 知识库实例 - kb_instance = self._get_kb_for_database(db_id) + if kb is None: + logger.error(f"Database {db_id} not found in global 
metadata") + return None - # 如果不是 LightRagKB 实例,返回错误 - if not hasattr(kb_instance, "_get_lightrag_instance"): - logger.error(f"Knowledge base instance for {db_id} is not LightRagKB") - return None + kb_type = kb.kb_type or "lightrag" + if kb_type != "lightrag": + logger.error(f"Database {db_id} is not a LightRAG type (actual type: {kb_type})") + raise ValueError(f"Database {db_id} is not a LightRAG knowledge base") - # 调用 LightRagKB 的方法获取 LightRAG 实例 - return await kb_instance._get_lightrag_instance(db_id) + kb_instance = await self._get_kb_for_database(db_id) - except Exception as e: - logger.error(f"Failed to get LightRAG instance for {db_id}: {e}") + if not hasattr(kb_instance, "_get_lightrag_instance"): + logger.error(f"Knowledge base instance for {db_id} is not LightRagKB") return None - def is_lightrag_database(self, db_id: str) -> bool: + return await kb_instance._get_lightrag_instance(db_id) + + async def is_lightrag_database(self, db_id: str) -> bool: """ 检查数据库是否是 LightRAG 类型 @@ -656,27 +690,23 @@ def is_lightrag_database(self, db_id: str) -> bool: Returns: 是否是 LightRAG 类型的数据库 """ - if db_id not in self.global_databases_meta: - return False + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository - kb_type = self.global_databases_meta[db_id].get("kb_type", "lightrag") - return kb_type == "lightrag" + kb_repo = KnowledgeBaseRepository() + kb = await kb_repo.get_by_id(db_id) + if kb is None: + return False + return (kb.kb_type or "lightrag") == "lightrag" - def get_lightrag_databases(self) -> list[dict]: + async def get_lightrag_databases(self) -> list[dict]: """ 获取所有 LightRAG 类型的数据库 Returns: LightRAG 数据库列表 """ - lightrag_databases = [] - - all_databases = self.get_databases()["databases"] - for db in all_databases: - if db.get("kb_type", "lightrag") == "lightrag": - lightrag_databases.append(db) - - return lightrag_databases + all_databases = (await self.get_databases())["databases"] + return [db for db in all_databases if 
db.get("kb_type", "lightrag") == "lightrag"] # ============================================================================= # 数据一致性检测方法 @@ -714,6 +744,8 @@ async def detect_data_inconsistencies(self) -> dict: async def _detect_milvus_inconsistencies(self) -> dict: """检测 Milvus 中的数据不一致""" + from src.repositories.knowledge_base_repository import KnowledgeBaseRepository + inconsistencies = {"missing_collections": [], "missing_files": []} milvus_kb = self.kb_instances["milvus"] @@ -724,16 +756,40 @@ async def _detect_milvus_inconsistencies(self) -> dict: # 获取 Milvus 中所有实际的集合 actual_collection_names = set(utility.list_collections(using=milvus_kb.connection_alias)) - # 获取 metadata 中记录的数据库ID - metadata_collection_names = set(milvus_kb.databases_meta.keys()) + # 从数据库获取所有已知的数据库ID + kb_repo = KnowledgeBaseRepository() + rows = await kb_repo.get_all() + all_known_db_ids = {row.db_id for row in rows} + + lightrag_suffixes = ["_chunks", "_relationships", "_entities"] # 找出存在于 Milvus 但不在 metadata 中的集合 - missing_collections = actual_collection_names - metadata_collection_names - for collection_name in missing_collections: + # missing_collections = actual_collection_names - metadata_collection_names + for collection_name in actual_collection_names: # 跳过一些系统集合 if not collection_name.startswith("kb_"): continue + # 检查集合是否属于已知数据库 + is_known = False + + # 1. 精确匹配 (Milvus 类型的知识库) + if collection_name in all_known_db_ids: + is_known = True + # 2. 
后缀匹配 (LightRAG 类型的知识库) + else: + for suffix in lightrag_suffixes: + if collection_name.endswith(suffix): + potential_db_id = collection_name[: -len(suffix)] + if potential_db_id in all_known_db_ids: + is_known = True + break + + # 如果是已知集合,跳过 + if is_known: + continue + + # 如果是未知集合,记录下来 collection_info = {"collection_name": collection_name, "detected_at": utc_isoformat()} # 尝试获取集合的基本信息 @@ -753,6 +809,9 @@ async def _detect_milvus_inconsistencies(self) -> dict: f"(实体数: {collection_info['count']})" ) + # 获取 metadata 中记录的数据库ID(仅 Milvus 类型,用于检查文件一致性) + metadata_collection_names = set(milvus_kb.databases_meta.keys()) + # 检查文件级别的不一致(针对已知的数据库) for db_id in metadata_collection_names: try: diff --git a/src/repositories/__init__.py b/src/repositories/__init__.py new file mode 100644 index 000000000..a9a2c5b3b --- /dev/null +++ b/src/repositories/__init__.py @@ -0,0 +1 @@ +__all__ = [] diff --git a/src/repositories/agent_config_repository.py b/src/repositories/agent_config_repository.py new file mode 100644 index 000000000..e2b6b589c --- /dev/null +++ b/src/repositories/agent_config_repository.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from src.storage.postgres.models_business import AgentConfig +from src.utils.datetime_utils import utc_now_naive + +# 默认配置名称 +DEFAULT_CONFIG_NAME = "初始配置" + + +class AgentConfigRepository: + def __init__(self, db_session: AsyncSession): + self.db = db_session + + async def list_by_department_agent(self, *, department_id: int, agent_id: str) -> list[AgentConfig]: + result = await self.db.execute( + select(AgentConfig) + .where(AgentConfig.department_id == department_id, AgentConfig.agent_id == agent_id) + .order_by(AgentConfig.is_default.desc(), AgentConfig.id.asc()) + ) + return list(result.scalars().all()) + + async def get_by_id(self, config_id: int) -> AgentConfig | None: + result = await 
self.db.execute(select(AgentConfig).where(AgentConfig.id == config_id)) + return result.scalar_one_or_none() + + async def set_default(self, *, config: AgentConfig, updated_by: str | None = None) -> AgentConfig: + if config.is_default: + return config + + now = utc_now_naive() + + # 先清空该部门+智能体的所有默认配置 + await self.db.execute( + update(AgentConfig) + .where( + AgentConfig.department_id == config.department_id, + AgentConfig.agent_id == config.agent_id, + ) + .values(is_default=False, updated_at=now, updated_by=updated_by) + ) + + # 再设置目标配置为默认 + config.is_default = True + config.updated_at = now + config.updated_by = updated_by + + await self.db.commit() + await self.db.refresh(config) + return config + + async def get_default(self, *, department_id: int, agent_id: str) -> AgentConfig | None: + result = await self.db.execute( + select(AgentConfig).where( + AgentConfig.department_id == department_id, + AgentConfig.agent_id == agent_id, + AgentConfig.is_default.is_(True), + ) + ) + return result.scalar_one_or_none() + + async def get_or_create_default( + self, + *, + department_id: int, + agent_id: str, + created_by: str | None = None, + ) -> AgentConfig: + existing = await self.get_default(department_id=department_id, agent_id=agent_id) + if existing: + return existing + + items = await self.list_by_department_agent(department_id=department_id, agent_id=agent_id) + if items: + return items[0] + + config = AgentConfig( + department_id=department_id, + agent_id=agent_id, + name=DEFAULT_CONFIG_NAME, + description=None, + icon=None, + pics=[], + examples=[], + config_json={}, + is_default=True, + created_by=created_by, + updated_by=created_by, + created_at=utc_now_naive(), + updated_at=utc_now_naive(), + ) + self.db.add(config) + await self.db.commit() + await self.db.refresh(config) + return config + + async def _name_exists(self, *, department_id: int, agent_id: str, name: str, exclude_id: int | None) -> bool: + stmt = select(AgentConfig.id).where( + 
AgentConfig.department_id == department_id, + AgentConfig.agent_id == agent_id, + AgentConfig.name == name, + ) + if exclude_id is not None: + stmt = stmt.where(AgentConfig.id != exclude_id) + result = await self.db.execute(stmt) + return result.scalar_one_or_none() is not None + + async def ensure_unique_name( + self, + *, + department_id: int, + agent_id: str, + desired_name: str, + exclude_id: int | None = None, + ) -> str: + candidate = desired_name.strip() or "未命名配置" + if not await self._name_exists( + department_id=department_id, agent_id=agent_id, name=candidate, exclude_id=exclude_id + ): + return candidate + + base = f"{candidate}-副本" + if not await self._name_exists( + department_id=department_id, agent_id=agent_id, name=base, exclude_id=exclude_id + ): + return base + + idx = 2 + while True: + candidate2 = f"{base}{idx}" + if not await self._name_exists( + department_id=department_id, agent_id=agent_id, name=candidate2, exclude_id=exclude_id + ): + return candidate2 + idx += 1 + + async def create( + self, + *, + department_id: int, + agent_id: str, + name: str, + description: str | None = None, + icon: str | None = None, + pics: list[str] | None = None, + examples: list[str] | None = None, + config_json: dict | None = None, + is_default: bool = False, + created_by: str | None = None, + ) -> AgentConfig: + unique_name = await self.ensure_unique_name( + department_id=department_id, + agent_id=agent_id, + desired_name=name, + exclude_id=None, + ) + config = AgentConfig( + department_id=department_id, + agent_id=agent_id, + name=unique_name, + description=description, + icon=icon, + pics=pics or [], + examples=examples or [], + config_json=config_json or {}, + is_default=bool(is_default), + created_by=created_by, + updated_by=created_by, + created_at=utc_now_naive(), + updated_at=utc_now_naive(), + ) + self.db.add(config) + await self.db.commit() + await self.db.refresh(config) + return config + + async def update( + self, + config: AgentConfig, + *, + 
name: str | None = None, + description: str | None = None, + icon: str | None = None, + pics: list[str] | None = None, + examples: list[str] | None = None, + config_json: dict | None = None, + updated_by: str | None = None, + ) -> AgentConfig: + if name is not None: + config.name = await self.ensure_unique_name( + department_id=config.department_id, + agent_id=config.agent_id, + desired_name=name, + exclude_id=config.id, + ) + if description is not None: + config.description = description + if icon is not None: + config.icon = icon + if pics is not None: + config.pics = pics + if examples is not None: + config.examples = examples + if config_json is not None: + config.config_json = config_json + + config.updated_by = updated_by + config.updated_at = utc_now_naive() + await self.db.commit() + await self.db.refresh(config) + return config + + async def delete(self, *, config: AgentConfig, updated_by: str | None = None) -> None: + department_id = config.department_id + agent_id = config.agent_id + was_default = bool(config.is_default) + + await self.db.delete(config) + await self.db.commit() + + remaining = await self.list_by_department_agent(department_id=department_id, agent_id=agent_id) + if not remaining: + await self.get_or_create_default(department_id=department_id, agent_id=agent_id, created_by=updated_by) + return + + if was_default: + await self.set_default(config=remaining[0], updated_by=updated_by) diff --git a/src/storage/conversation/manager.py b/src/repositories/conversation_repository.py similarity index 61% rename from src/storage/conversation/manager.py rename to src/repositories/conversation_repository.py index a3e0e30fb..4b0cd3b76 100644 --- a/src/storage/conversation/manager.py +++ b/src/repositories/conversation_repository.py @@ -1,24 +1,19 @@ """ -Conversation Storage Manager (Async) - -Manages conversation data storage including messages, tool calls, and statistics. -All database operations are now asynchronous for improved performance. 
+对话域持久化 Repository(Async) """ -import uuid +import uuid as uuid_lib from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from src.storage.db.models import Conversation, ConversationStats, Message, ToolCall +from src.storage.postgres.models_business import Conversation, ConversationStats, Message, ToolCall from src.utils import logger -from src.utils.datetime_utils import utc_now - +from src.utils.datetime_utils import utc_now_naive -class ConversationManager: - """Async Manager for conversation storage operations""" +class ConversationRepository: def __init__(self, db_session: AsyncSession): self.db = db_session @@ -30,21 +25,8 @@ async def create_conversation( thread_id: str | None = None, metadata: dict | None = None, ) -> Conversation: - """ - Create a new conversation - - Args: - user_id: User ID - agent_id: Agent ID - title: Conversation title - thread_id: Optional thread ID (will generate UUID if not provided) - metadata: Optional additional metadata - - Returns: - Created Conversation object - """ if not thread_id: - thread_id = str(uuid.uuid4()) + thread_id = str(uuid_lib.uuid4()) metadata = (metadata or {}).copy() metadata.setdefault("attachments", []) @@ -59,10 +41,8 @@ async def create_conversation( ) self.db.add(conversation) - # Flush to assign primary key without committing await self.db.flush() - # Create associated stats record and commit once stats = ConversationStats(conversation_id=conversation.id) self.db.add(stats) await self.db.commit() @@ -72,36 +52,21 @@ async def create_conversation( return conversation async def get_conversation_by_thread_id(self, thread_id: str) -> Conversation | None: - """ - Get conversation by thread ID - - Args: - thread_id: Thread ID - - Returns: - Conversation object or None if not found - """ - result = await self.db.execute(select(Conversation).filter(Conversation.thread_id == thread_id)) + result = await 
self.db.execute(select(Conversation).where(Conversation.thread_id == thread_id)) return result.scalar_one_or_none() async def _get_conversation_by_id(self, conversation_id: int) -> Conversation | None: - result = await self.db.execute(select(Conversation).filter(Conversation.id == conversation_id)) + result = await self.db.execute(select(Conversation).where(Conversation.id == conversation_id)) return result.scalar_one_or_none() def _ensure_metadata(self, conversation: Conversation) -> dict: - """ - Return a shallow copy of conversation metadata with a standalone attachments list. - - We copy here because SQLAlchemy's JSON type does not automatically detect in-place - mutations. By assigning a fresh dict/list back we ensure the ORM marks the row dirty. - """ metadata = dict(conversation.extra_metadata or {}) metadata["attachments"] = list(metadata.get("attachments", [])) return metadata async def _save_metadata(self, conversation: Conversation, metadata: dict) -> None: conversation.extra_metadata = metadata - conversation.updated_at = utc_now() + conversation.updated_at = utc_now_naive() await self.db.commit() await self.db.refresh(conversation) @@ -114,20 +79,6 @@ async def add_message( extra_metadata: dict | None = None, image_content: str | None = None, ) -> Message: - """ - Add a message to a conversation - - Args: - conversation_id: Conversation ID - role: Message role (user/assistant/system/tool) - content: Message content - message_type: Message type (text/tool_call/tool_result/multimodal_image) - extra_metadata: Additional metadata (complete message dump) - image_content: Base64 encoded image content for multimodal messages - - Returns: - Created Message object - """ message = Message( conversation_id=conversation_id, role=role, @@ -138,15 +89,13 @@ async def add_message( ) self.db.add(message) - # Mark the parent conversation as active for sorting/analytics conversation = await self._get_conversation_by_id(conversation_id) if conversation: - 
conversation.updated_at = utc_now() + conversation.updated_at = utc_now_naive() await self.db.commit() await self.db.refresh(message) - # Update conversation stats await self._update_message_count(conversation_id) logger.debug(f"Added {role} message to conversation {conversation_id}") @@ -161,20 +110,6 @@ async def add_message_by_thread_id( extra_metadata: dict | None = None, image_content: str | None = None, ) -> Message | None: - """ - Add a message to a conversation by thread ID - - Args: - thread_id: Thread ID - role: Message role (user/assistant/system/tool) - content: Message content - message_type: Message type (text/tool_call/tool_result/multimodal_image) - extra_metadata: Additional metadata (complete message dump) - image_content: Base64 encoded image content for multimodal messages - - Returns: - Created Message object or None if conversation not found - """ conversation = await self.get_conversation_by_thread_id(thread_id) if not conversation: logger.warning(f"Conversation not found for thread_id: {thread_id}") @@ -199,21 +134,6 @@ async def add_tool_call( error_message: str | None = None, langgraph_tool_call_id: str | None = None, ) -> ToolCall: - """ - Add a tool call record - - Args: - message_id: Message ID - tool_name: Tool name - tool_input: Tool input parameters - tool_output: Tool execution result - status: Status (pending/success/error) - error_message: Error message if failed - langgraph_tool_call_id: LangGraph tool_call_id for precise matching - - Returns: - Created ToolCall object - """ tool_call = ToolCall( message_id=message_id, tool_name=tool_name, @@ -232,24 +152,13 @@ async def add_tool_call( return tool_call async def get_messages(self, conversation_id: int, limit: int | None = None, offset: int = 0) -> list[Message]: - """ - Get messages for a conversation - - Args: - conversation_id: Conversation ID - limit: Maximum number of messages to return - offset: Number of messages to skip - - Returns: - List of Message objects with preloaded 
tool_calls and feedbacks - """ query = ( select(Message) .options( - selectinload(Message.tool_calls), # Preload tool calls - selectinload(Message.feedbacks), # Preload feedbacks for UI state + selectinload(Message.tool_calls), + selectinload(Message.feedbacks), ) - .filter(Message.conversation_id == conversation_id) + .where(Message.conversation_id == conversation_id) .order_by(Message.created_at.asc()) ) @@ -257,22 +166,11 @@ async def get_messages(self, conversation_id: int, limit: int | None = None, off query = query.limit(limit).offset(offset) result = await self.db.execute(query) - return result.scalars().unique().all() + return list(result.scalars().unique().all()) async def get_messages_by_thread_id( self, thread_id: str, limit: int | None = None, offset: int = 0 ) -> list[Message]: - """ - Get messages for a conversation by thread ID - - Args: - thread_id: Thread ID - limit: Maximum number of messages to return - offset: Number of messages to skip - - Returns: - List of Message objects - """ conversation = await self.get_conversation_by_thread_id(thread_id) if not conversation: logger.warning(f"Conversation not found for thread_id: {thread_id}") @@ -283,29 +181,16 @@ async def get_messages_by_thread_id( async def list_conversations( self, user_id: str | None = None, agent_id: str | None = None, status: str = "active" ) -> list[Conversation]: - """ - List conversations for a user or all users - - Args: - user_id: User ID (optional, if None or empty string, returns all users' conversations) - agent_id: Optional agent ID filter - status: Conversation status filter + query = select(Conversation).where(Conversation.status == status) - Returns: - List of Conversation objects - """ - query = select(Conversation).filter(Conversation.status == status) - - # Only filter by user_id if it's provided and not empty if user_id: - query = query.filter(Conversation.user_id == str(user_id)) - + query = query.where(Conversation.user_id == str(user_id)) if agent_id: - query = 
query.filter(Conversation.agent_id == agent_id) + query = query.where(Conversation.agent_id == agent_id) query = query.order_by(Conversation.updated_at.desc()) result = await self.db.execute(query) - return result.scalars().all() + return list(result.scalars().all()) async def update_conversation( self, @@ -314,18 +199,6 @@ async def update_conversation( status: str | None = None, metadata: dict | None = None, ) -> Conversation | None: - """ - Update conversation information - - Args: - thread_id: Thread ID - title: New title - status: New status - metadata: Additional metadata to merge - - Returns: - Updated Conversation object or None if not found - """ conversation = await self.get_conversation_by_thread_id(thread_id) if not conversation: return None @@ -335,13 +208,12 @@ async def update_conversation( if status is not None: conversation.status = status - # Handle metadata updates if metadata is not None: current_metadata = conversation.extra_metadata or {} current_metadata.update(metadata) conversation.extra_metadata = current_metadata - conversation.updated_at = utc_now() + conversation.updated_at = utc_now_naive() await self.db.commit() await self.db.refresh(conversation) @@ -349,16 +221,6 @@ async def update_conversation( return conversation async def delete_conversation(self, thread_id: str, soft_delete: bool = True) -> bool: - """ - Delete a conversation - - Args: - thread_id: Thread ID - soft_delete: If True, mark as deleted; if False, permanently delete - - Returns: - True if successful, False otherwise - """ conversation = await self.get_conversation_by_thread_id(thread_id) if not conversation: return False @@ -375,17 +237,8 @@ async def delete_conversation(self, thread_id: str, soft_delete: bool = True) -> return True async def get_stats(self, conversation_id: int) -> ConversationStats | None: - """ - Get conversation statistics - - Args: - conversation_id: Conversation ID - - Returns: - ConversationStats object or None if not found - """ result = 
await self.db.execute( - select(ConversationStats).filter(ConversationStats.conversation_id == conversation_id) + select(ConversationStats).where(ConversationStats.conversation_id == conversation_id) ) return result.scalar_one_or_none() @@ -396,18 +249,6 @@ async def update_stats( model_used: str | None = None, user_feedback: dict | None = None, ) -> ConversationStats | None: - """ - Update conversation statistics - - Args: - conversation_id: Conversation ID - tokens_used: Number of tokens to add - model_used: Model name - user_feedback: User feedback data - - Returns: - Updated ConversationStats object or None if not found - """ stats = await self.get_stats(conversation_id) if not stats: return None @@ -419,24 +260,15 @@ async def update_stats( if user_feedback is not None: stats.user_feedback = user_feedback - stats.updated_at = utc_now() + stats.updated_at = utc_now_naive() await self.db.commit() await self.db.refresh(stats) return stats async def get_tool_call_by_langgraph_id(self, langgraph_tool_call_id: str) -> ToolCall | None: - """ - Get tool call by LangGraph tool_call_id - - Args: - langgraph_tool_call_id: LangGraph tool_call_id - - Returns: - ToolCall object or None if not found - """ result = await self.db.execute( - select(ToolCall).filter(ToolCall.langgraph_tool_call_id == langgraph_tool_call_id) + select(ToolCall).where(ToolCall.langgraph_tool_call_id == langgraph_tool_call_id) ) return result.scalar_one_or_none() @@ -447,18 +279,6 @@ async def update_tool_call_output( status: str = "success", error_message: str | None = None, ) -> ToolCall | None: - """ - Update tool call output by LangGraph tool_call_id - - Args: - langgraph_tool_call_id: LangGraph tool_call_id - tool_output: Tool execution result - status: Status (success/error) - error_message: Error message if failed - - Returns: - Updated ToolCall object or None if not found - """ tool_call = await self.get_tool_call_by_langgraph_id(langgraph_tool_call_id) if not tool_call: 
logger.warning(f"Tool call not found for langgraph_tool_call_id: {langgraph_tool_call_id}") @@ -476,25 +296,15 @@ async def update_tool_call_output( return tool_call async def _update_message_count(self, conversation_id: int) -> None: - """ - Update message count in conversation stats - - Args: - conversation_id: Conversation ID - """ from sqlalchemy import func stats = await self.get_stats(conversation_id) if stats: - result = await self.db.execute(select(func.count()).filter(Message.conversation_id == conversation_id)) + result = await self.db.execute(select(func.count()).where(Message.conversation_id == conversation_id)) message_count = result.scalar() stats.message_count = message_count await self.db.commit() - # ------------------------------------------------------------------------- - # Attachment helpers - # ------------------------------------------------------------------------- - async def get_attachments(self, conversation_id: int) -> list[dict]: conversation = await self._get_conversation_by_id(conversation_id) if not conversation: diff --git a/src/repositories/department_repository.py b/src/repositories/department_repository.py new file mode 100644 index 000000000..df8ff80c0 --- /dev/null +++ b/src/repositories/department_repository.py @@ -0,0 +1,95 @@ +"""部门数据访问层 - Repository""" + +from typing import Any + +from sqlalchemy import func, select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import Department + + +class DepartmentRepository: + """部门数据访问层""" + + async def get_by_id(self, id: int) -> Department | None: + """根据 ID 获取部门""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(Department).where(Department.id == id)) + return result.scalar_one_or_none() + + async def get_by_name(self, name: str) -> Department | None: + """根据名称获取部门""" + async with pg_manager.get_async_session_context() as session: + result = await 
session.execute(select(Department).where(Department.name == name)) + return result.scalar_one_or_none() + + async def list_departments(self) -> list[Department]: + """获取所有部门列表""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(Department).order_by(Department.created_at.desc())) + return list(result.scalars().all()) + + async def list_with_user_count(self) -> list[dict[str, Any]]: + """获取所有部门列表,包含用户数量""" + async with pg_manager.get_async_session_context() as session: + from src.storage.postgres.models_business import User + + result = await session.execute(select(Department).order_by(Department.created_at.desc())) + departments = result.scalars().all() + + department_list = [] + for dep in departments: + user_count_result = await session.execute( + select(func.count(User.id)).where(User.department_id == dep.id, User.is_deleted == 0) + ) + user_count = user_count_result.scalar() + dep_dict = dep.to_dict() + dep_dict["user_count"] = user_count + department_list.append(dep_dict) + + return department_list + + async def create(self, data: dict[str, Any]) -> Department: + """创建部门""" + async with pg_manager.get_async_session_context() as session: + department = Department(**data) + session.add(department) + return department + + async def update(self, id: int, data: dict[str, Any]) -> Department | None: + """更新部门""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(Department).where(Department.id == id)) + department = result.scalar_one_or_none() + if department is None: + return None + for key, value in data.items(): + if key != "id": + setattr(department, key, value) + return department + + async def delete(self, id: int) -> bool: + """删除部门""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(Department).where(Department.id == id)) + department = result.scalar_one_or_none() + if department is None: + return 
False + await session.delete(department) + return True + + async def count_users(self, id: int) -> int: + """统计部门用户数量""" + async with pg_manager.get_async_session_context() as session: + from src.storage.postgres.models_business import User + + result = await session.execute( + select(func.count(User.id)).where(User.department_id == id, User.is_deleted == 0) + ) + return result.scalar() or 0 + + async def exists_by_name(self, name: str) -> bool: + """检查部门名称是否存在""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(Department.id).where(Department.name == name)) + return result.scalar_one_or_none() is not None diff --git a/src/repositories/evaluation_repository.py b/src/repositories/evaluation_repository.py new file mode 100644 index 000000000..c51d9cbe4 --- /dev/null +++ b/src/repositories/evaluation_repository.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from typing import Any + +from sqlalchemy import delete, select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_knowledge import EvaluationBenchmark, EvaluationResult, EvaluationResultDetail + + +class EvaluationRepository: + async def get_all_benchmarks(self) -> list[EvaluationBenchmark]: + """获取所有评估基准""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(EvaluationBenchmark)) + return list(result.scalars().all()) + + async def create_benchmark(self, data: dict[str, Any]) -> EvaluationBenchmark: + benchmark = EvaluationBenchmark(**data) + async with pg_manager.get_async_session_context() as session: + session.add(benchmark) + return benchmark + + async def get_benchmark(self, benchmark_id: str) -> EvaluationBenchmark | None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(EvaluationBenchmark).where(EvaluationBenchmark.benchmark_id == benchmark_id) + ) + return result.scalar_one_or_none() + + async 
def list_benchmarks(self, db_id: str) -> list[EvaluationBenchmark]: + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(EvaluationBenchmark) + .where(EvaluationBenchmark.db_id == db_id) + .order_by(EvaluationBenchmark.created_at.desc()) + ) + return list(result.scalars().all()) + + async def delete_benchmark(self, benchmark_id: str) -> None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(EvaluationBenchmark).where(EvaluationBenchmark.benchmark_id == benchmark_id) + ) + record = result.scalar_one_or_none() + if record is not None: + await session.delete(record) + + async def create_result(self, data: dict[str, Any]) -> EvaluationResult: + result_row = EvaluationResult(**data) + async with pg_manager.get_async_session_context() as session: + session.add(result_row) + return result_row + + async def get_result(self, task_id: str) -> EvaluationResult | None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(EvaluationResult).where(EvaluationResult.task_id == task_id)) + return result.scalar_one_or_none() + + async def list_results(self, db_id: str) -> list[EvaluationResult]: + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(EvaluationResult) + .where(EvaluationResult.db_id == db_id) + .order_by(EvaluationResult.started_at.desc()) + ) + return list(result.scalars().all()) + + async def update_result(self, task_id: str, data: dict[str, Any]) -> EvaluationResult | None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(EvaluationResult).where(EvaluationResult.task_id == task_id)) + record = result.scalar_one_or_none() + if record is None: + return None + for key, value in data.items(): + setattr(record, key, value) + return record + + async def delete_result(self, task_id: str) -> None: + async 
with pg_manager.get_async_session_context() as session: + await session.execute(delete(EvaluationResultDetail).where(EvaluationResultDetail.task_id == task_id)) + result = await session.execute(select(EvaluationResult).where(EvaluationResult.task_id == task_id)) + record = result.scalar_one_or_none() + if record is not None: + await session.delete(record) + + async def upsert_result_detail( + self, task_id: str, query_index: int, data: dict[str, Any] + ) -> EvaluationResultDetail: + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(EvaluationResultDetail).where( + (EvaluationResultDetail.task_id == task_id) & (EvaluationResultDetail.query_index == query_index) + ) + ) + record = result.scalar_one_or_none() + if record is None: + record = EvaluationResultDetail(task_id=task_id, query_index=query_index, **data) + session.add(record) + return record + for key, value in data.items(): + setattr(record, key, value) + return record + + async def list_result_details(self, task_id: str) -> list[EvaluationResultDetail]: + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(EvaluationResultDetail) + .where(EvaluationResultDetail.task_id == task_id) + .order_by(EvaluationResultDetail.query_index.asc()) + ) + return list(result.scalars().all()) + + async def delete_all(self) -> None: + async with pg_manager.get_async_session_context() as session: + await session.execute(delete(EvaluationResultDetail)) + await session.execute(delete(EvaluationResult)) + await session.execute(delete(EvaluationBenchmark)) diff --git a/src/repositories/knowledge_base_repository.py b/src/repositories/knowledge_base_repository.py new file mode 100644 index 000000000..ef5917504 --- /dev/null +++ b/src/repositories/knowledge_base_repository.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select + +from src.storage.postgres.manager import 
pg_manager +from src.storage.postgres.models_knowledge import KnowledgeBase + + +class KnowledgeBaseRepository: + async def get_all(self) -> list[KnowledgeBase]: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeBase)) + return list(result.scalars().all()) + + async def get_by_id(self, db_id: str) -> KnowledgeBase | None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeBase).where(KnowledgeBase.db_id == db_id)) + return result.scalar_one_or_none() + + async def create(self, data: dict[str, Any]) -> KnowledgeBase: + kb = KnowledgeBase(**data) + async with pg_manager.get_async_session_context() as session: + session.add(kb) + return kb + + async def update(self, db_id: str, data: dict[str, Any]) -> KnowledgeBase | None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeBase).where(KnowledgeBase.db_id == db_id)) + kb = result.scalar_one_or_none() + if kb is None: + return None + for key, value in data.items(): + setattr(kb, key, value) + return kb + + async def delete(self, db_id: str) -> None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeBase).where(KnowledgeBase.db_id == db_id)) + kb = result.scalar_one_or_none() + if kb is not None: + await session.delete(kb) diff --git a/src/repositories/knowledge_file_repository.py b/src/repositories/knowledge_file_repository.py new file mode 100644 index 000000000..520a54484 --- /dev/null +++ b/src/repositories/knowledge_file_repository.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_knowledge import KnowledgeFile + + +class KnowledgeFileRepository: + async def get_all(self) -> list[KnowledgeFile]: + """获取所有文件记录""" + 
async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeFile)) + return list(result.scalars().all()) + + async def get_by_file_id(self, file_id: str) -> KnowledgeFile | None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeFile).where(KnowledgeFile.file_id == file_id)) + return result.scalar_one_or_none() + + async def list_by_db_id(self, db_id: str) -> list[KnowledgeFile]: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeFile).where(KnowledgeFile.db_id == db_id)) + return list(result.scalars().all()) + + async def upsert(self, file_id: str, data: dict[str, Any]) -> KnowledgeFile: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeFile).where(KnowledgeFile.file_id == file_id)) + existing = result.scalar_one_or_none() + if existing is None: + record = KnowledgeFile(file_id=file_id, **data) + session.add(record) + return record + for key, value in data.items(): + setattr(existing, key, value) + return existing + + async def delete(self, file_id: str) -> None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeFile).where(KnowledgeFile.file_id == file_id)) + record = result.scalar_one_or_none() + if record is not None: + await session.delete(record) + + async def delete_by_db_id(self, db_id: str) -> None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(KnowledgeFile).where(KnowledgeFile.db_id == db_id)) + for record in result.scalars().all(): + await session.delete(record) diff --git a/src/repositories/mcp_server_repository.py b/src/repositories/mcp_server_repository.py new file mode 100644 index 000000000..04feae117 --- /dev/null +++ b/src/repositories/mcp_server_repository.py @@ -0,0 +1,81 @@ +"""MCP 服务器数据访问层 
- Repository""" + +from typing import Any + +from sqlalchemy import select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import MCPServer + + +class MCPServerRepository: + """MCP 服务器数据访问层""" + + async def get_by_name(self, name: str) -> MCPServer | None: + """根据名称获取 MCP 服务器""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MCPServer).where(MCPServer.name == name)) + return result.scalar_one_or_none() + + async def list(self) -> list[MCPServer]: + """获取所有 MCP 服务器""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MCPServer)) + return list(result.scalars().all()) + + async def list_enabled(self) -> list[MCPServer]: + """获取所有启用的 MCP 服务器""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MCPServer).where(MCPServer.enabled == 1)) + return list(result.scalars().all()) + + async def create(self, data: dict[str, Any]) -> MCPServer: + """创建 MCP 服务器""" + async with pg_manager.get_async_session_context() as session: + server = MCPServer(**data) + session.add(server) + return server + + async def update(self, name: str, data: dict[str, Any]) -> MCPServer | None: + """更新 MCP 服务器""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MCPServer).where(MCPServer.name == name)) + server = result.scalar_one_or_none() + if server is None: + return None + for key, value in data.items(): + if key != "name": + setattr(server, key, value) + return server + + async def delete(self, name: str) -> bool: + """删除 MCP 服务器""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MCPServer).where(MCPServer.name == name)) + server = result.scalar_one_or_none() + if server is None: + return False + await session.delete(server) + return True + + async def upsert(self, data: 
dict[str, Any]) -> MCPServer: + """插入或更新 MCP 服务器""" + name = data.get("name") + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MCPServer).where(MCPServer.name == name)) + existing = result.scalar_one_or_none() + if existing is None: + server = MCPServer(**data) + session.add(server) + else: + for key, value in data.items(): + if key != "name": + setattr(existing, key, value) + server = existing + return server + + async def exists_by_name(self, name: str) -> bool: + """检查 MCP 服务器是否存在""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MCPServer.id).where(MCPServer.name == name)) + return result.scalar_one_or_none() is not None diff --git a/src/repositories/message_feedback_repository.py b/src/repositories/message_feedback_repository.py new file mode 100644 index 000000000..45bf9bddd --- /dev/null +++ b/src/repositories/message_feedback_repository.py @@ -0,0 +1,41 @@ +"""消息反馈数据访问层 - Repository""" + +from typing import Any + +from sqlalchemy import select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import MessageFeedback + + +class MessageFeedbackRepository: + """消息反馈数据访问层""" + + async def get_by_id(self, id: int) -> MessageFeedback | None: + """根据 ID 获取消息反馈""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MessageFeedback).where(MessageFeedback.id == id)) + return result.scalar_one_or_none() + + async def list_by_message(self, message_id: int) -> list[MessageFeedback]: + """获取消息的反馈列表""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(MessageFeedback).where(MessageFeedback.message_id == message_id)) + return list(result.scalars().all()) + + async def create(self, data: dict[str, Any]) -> MessageFeedback: + """创建消息反馈""" + async with pg_manager.get_async_session_context() as session: + feedback = 
MessageFeedback(**data) + session.add(feedback) + return feedback + + async def exists_by_message_and_user(self, message_id: int, user_id: str) -> bool: + """检查用户是否已对消息反馈""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(MessageFeedback.id).where( + MessageFeedback.message_id == message_id, MessageFeedback.user_id == user_id + ) + ) + return result.scalar_one_or_none() is not None diff --git a/src/repositories/operation_log_repository.py b/src/repositories/operation_log_repository.py new file mode 100644 index 000000000..6730402e2 --- /dev/null +++ b/src/repositories/operation_log_repository.py @@ -0,0 +1,45 @@ +"""操作日志数据访问层 - Repository""" + +from typing import Any + +from sqlalchemy import select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import OperationLog + + +class OperationLogRepository: + """操作日志数据访问层""" + + async def get_by_id(self, id: int) -> OperationLog | None: + """根据 ID 获取操作日志""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(OperationLog).where(OperationLog.id == id)) + return result.scalar_one_or_none() + + async def list_by_user(self, user_id: int, skip: int = 0, limit: int = 100) -> list[OperationLog]: + """获取用户的操作日志列表""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute( + select(OperationLog) + .where(OperationLog.user_id == user_id) + .order_by(OperationLog.timestamp.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def create(self, data: dict[str, Any]) -> OperationLog: + """创建操作日志""" + async with pg_manager.get_async_session_context() as session: + log = OperationLog(**data) + session.add(log) + return log + + async def count_by_user(self, user_id: int) -> int: + """统计用户操作日志数量""" + from sqlalchemy import func + + async with pg_manager.get_async_session_context() as session: + result = 
await session.execute(select(func.count(OperationLog.id)).where(OperationLog.user_id == user_id)) + return result.scalar() or 0 diff --git a/src/repositories/task_repository.py b/src/repositories/task_repository.py new file mode 100644 index 000000000..5f089f64a --- /dev/null +++ b/src/repositories/task_repository.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import Any + +from sqlalchemy import delete, select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import TaskRecord + + +class TaskRepository: + async def get_by_id(self, task_id: str) -> TaskRecord | None: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(TaskRecord).where(TaskRecord.id == task_id)) + return result.scalar_one_or_none() + + async def list(self, status: str | None = None, limit: int = 100) -> list[TaskRecord]: + async with pg_manager.get_async_session_context() as session: + stmt = select(TaskRecord) + if status: + stmt = stmt.where(TaskRecord.status == status) + stmt = stmt.order_by(TaskRecord.created_at.desc()).limit(max(limit, 0)) + result = await session.execute(stmt) + return list(result.scalars().all()) + + async def list_all(self) -> list[TaskRecord]: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(TaskRecord).order_by(TaskRecord.created_at.desc())) + return list(result.scalars().all()) + + async def upsert(self, task_id: str, data: dict[str, Any]) -> TaskRecord: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(TaskRecord).where(TaskRecord.id == task_id)) + record = result.scalar_one_or_none() + if record is None: + record = TaskRecord(id=task_id, **data) + session.add(record) + return record + for key, value in data.items(): + setattr(record, key, value) + return record + + async def delete_all(self) -> None: + async with 
pg_manager.get_async_session_context() as session: + await session.execute(delete(TaskRecord)) diff --git a/src/repositories/user_repository.py b/src/repositories/user_repository.py new file mode 100644 index 000000000..c3f37798b --- /dev/null +++ b/src/repositories/user_repository.py @@ -0,0 +1,147 @@ +"""用户数据访问层 - Repository""" + +from datetime import UTC +from datetime import datetime as dt +from typing import Annotated, Any + +from sqlalchemy import func, select + +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_business import User + +# 使用 naive datetime 以兼容 PostgreSQL TIMESTAMP WITHOUT TIME ZONE 列 +def _utc_now(): return dt.now(UTC).replace(tzinfo=None) + + +class UserRepository: + """用户数据访问层""" + + async def get_by_id(self, id: int) -> User | None: + """根据 ID 获取用户""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User).where(User.id == id)) + return result.scalar_one_or_none() + + async def get_by_user_id(self, user_id: str) -> User | None: + """根据 user_id 获取用户""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User).where(User.user_id == user_id)) + return result.scalar_one_or_none() + + async def get_by_phone(self, phone: str) -> User | None: + """根据手机号获取用户""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User).where(User.phone_number == phone)) + return result.scalar_one_or_none() + + async def list_users( + self, skip: int = 0, limit: int = 100, department_id: int | None = None, role: str | None = None + ) -> list[User]: + """获取用户列表""" + async with pg_manager.get_async_session_context() as session: + query = select(User).where(User.is_deleted == 0) + if department_id is not None: + query = query.where(User.department_id == department_id) + if role is not None: + query = query.where(User.role == role) + query = 
query.order_by(User.id.asc()).offset(skip).limit(limit) + result = await session.execute(query) + return list(result.scalars().all()) + + async def list_with_department( + self, skip: int = 0, limit: int = 100, department_id: int | None = None, role: str | None = None + ) -> Annotated[list[tuple[User, str | None]], "用户列表,包含部门名称"]: + """获取用户列表,包含部门名称""" + async with pg_manager.get_async_session_context() as session: + from src.storage.postgres.models_business import Department + + query = ( + select(User, Department.name.label("department_name")) + .outerjoin(Department, User.department_id == Department.id) + .where(User.is_deleted == 0) + ) + if department_id is not None: + query = query.where(User.department_id == department_id) + if role is not None: + query = query.where(User.role == role) + query = query.order_by(User.id.asc()).offset(skip).limit(limit) + result = await session.execute(query) + return list(result.all()) + + async def create(self, data: dict[str, Any]) -> User: + """创建用户""" + async with pg_manager.get_async_session_context() as session: + user = User(**data) + session.add(user) + await session.commit() + await session.refresh(user) + return user + + async def update(self, id: int, data: dict[str, Any]) -> User | None: + """更新用户""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User).where(User.id == id, User.is_deleted == 0)) + user = result.scalar_one_or_none() + if user is None: + return None + for key, value in data.items(): + if key != "id": + setattr(user, key, value) + return user + + async def soft_delete(self, id: int, username: str | None = None, phone_number: str | None = None) -> bool: + """软删除用户""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User).where(User.id == id, User.is_deleted == 0)) + user = result.scalar_one_or_none() + if user is None: + return False + user.is_deleted = 1 + + user.deleted_at = _utc_now() + if 
username: + import hashlib + + hash_suffix = hashlib.sha256(user.user_id.encode()).hexdigest()[:4] + user.username = f"已注销用户-{hash_suffix}" + if phone_number: + user.phone_number = None + return True + + async def exists_by_user_id(self, user_id: str) -> bool: + """检查 user_id 是否存在""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User.id).where(User.user_id == user_id)) + return result.scalar_one_or_none() is not None + + async def exists_by_phone(self, phone: str) -> bool: + """检查手机号是否存在""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User.id).where(User.phone_number == phone)) + return result.scalar_one_or_none() is not None + + async def count(self, department_id: int | None = None) -> int: + """统计用户数量""" + async with pg_manager.get_async_session_context() as session: + query = select(func.count(User.id)).where(User.is_deleted == 0) + if department_id is not None: + query = query.where(User.department_id == department_id) + result = await session.execute(query) + return result.scalar() or 0 + + async def get_all_user_ids(self) -> list[str]: + """获取所有用户 ID""" + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(User.user_id)) + return [uid for (uid,) in result.all()] + + async def get_admin_count_in_department(self, department_id: int, exclude_user_id: int | None = None) -> int: + """统计部门中管理员数量""" + async with pg_manager.get_async_session_context() as session: + query = select(func.count(User.id)).where( + User.department_id == department_id, User.role == "admin", User.is_deleted == 0 + ) + if exclude_user_id is not None: + query = query.where(User.id != exclude_user_id) + result = await session.execute(query) + return result.scalar() or 0 diff --git a/src/services/chat_stream_service.py b/src/services/chat_stream_service.py new file mode 100644 index 000000000..11e27c12f --- /dev/null +++ 
b/src/services/chat_stream_service.py @@ -0,0 +1,619 @@ +import asyncio +import json +import traceback +import uuid +from collections.abc import AsyncIterator + +from langchain.messages import AIMessage, AIMessageChunk, HumanMessage +from langgraph.types import Command + +from src import config as conf +from src import knowledge_base +from src.agents import agent_manager +from src.plugins.guard import content_guard +from src.repositories.agent_config_repository import AgentConfigRepository +from src.repositories.conversation_repository import ConversationRepository +from src.storage.postgres.manager import pg_manager +from src.utils.logging_config import logger + + +async def _get_langgraph_messages(agent_instance, config_dict): + graph = await agent_instance.get_graph() + state = await graph.aget_state(config_dict) + + if not state or not state.values: + logger.warning("No state found in LangGraph") + return None + + return state.values.get("messages", []) + + +def extract_agent_state(values: dict) -> dict: + if not isinstance(values, dict): + return {} + + def _norm_list(v): + if v is None: + return [] + if isinstance(v, (list, tuple)): + return list(v) + return [v] + + result = {} + result["todos"] = _norm_list(values.get("todos"))[:20] + result["files"] = _norm_list(values.get("files"))[:50] + + return result + + +async def _get_existing_message_ids(conv_repo: ConversationRepository, thread_id: str) -> set[str]: + existing_messages = await conv_repo.get_messages_by_thread_id(thread_id) + return { + msg.extra_metadata["id"] + for msg in existing_messages + if msg.extra_metadata and "id" in msg.extra_metadata and isinstance(msg.extra_metadata["id"], str) + } + + +async def _save_ai_message(conv_repo: ConversationRepository, thread_id: str, msg_dict: dict) -> None: + content = msg_dict.get("content", "") + tool_calls_data = msg_dict.get("tool_calls", []) + + ai_msg = await conv_repo.add_message_by_thread_id( + thread_id=thread_id, + role="assistant", + 
content=content, + message_type="text", + extra_metadata=msg_dict, + ) + + if ai_msg and tool_calls_data: + for tc in tool_calls_data: + await conv_repo.add_tool_call( + message_id=ai_msg.id, + tool_name=tc.get("name", "unknown"), + tool_input=tc.get("args", {}), + status="pending", + langgraph_tool_call_id=tc.get("id"), + ) + + +async def _save_tool_message(conv_repo: ConversationRepository, msg_dict: dict) -> None: + tool_call_id = msg_dict.get("tool_call_id") + content = msg_dict.get("content", "") + + if not tool_call_id: + return + + if isinstance(content, list): + tool_output = json.dumps(content) if content else "" + else: + tool_output = str(content) + + await conv_repo.update_tool_call_output( + langgraph_tool_call_id=tool_call_id, + tool_output=tool_output, + status="success", + ) + + +async def save_partial_message( + conv_repo: ConversationRepository, + thread_id: str, + full_msg=None, + error_message: str | None = None, + error_type: str = "interrupted", +): + try: + extra_metadata = { + "error_type": error_type, + "is_error": True, + "error_message": error_message or f"发生错误: {error_type}", + } + if full_msg: + msg_dict = full_msg.model_dump() if hasattr(full_msg, "model_dump") else {} + content = full_msg.content if hasattr(full_msg, "content") else str(full_msg) + extra_metadata = msg_dict | extra_metadata + else: + content = "" + + return await conv_repo.add_message_by_thread_id( + thread_id=thread_id, + role="assistant", + content=content, + message_type="text", + extra_metadata=extra_metadata, + ) + + except Exception as e: + logger.error(f"Error saving message: {e}") + logger.error(traceback.format_exc()) + return None + + +async def save_messages_from_langgraph_state( + agent_instance, + thread_id: str, + conv_repo: ConversationRepository, + config_dict: dict, +) -> None: + try: + messages = await _get_langgraph_messages(agent_instance, config_dict) + if messages is None: + return + + existing_ids = await _get_existing_message_ids(conv_repo, 
thread_id) + + for msg in messages: + msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else {} + msg_type = msg_dict.get("type", "unknown") + + if msg_type == "human" or getattr(msg, "id", None) in existing_ids: + continue + + if msg_type == "ai": + await _save_ai_message(conv_repo, thread_id, msg_dict) + elif msg_type == "tool": + await _save_tool_message(conv_repo, msg_dict) + + except Exception as e: + logger.error(f"Error saving messages from LangGraph state: {e}") + logger.error(traceback.format_exc()) + + +async def check_and_handle_interrupts( + agent, + langgraph_config: dict, + make_chunk, + meta: dict, + thread_id: str, +) -> AsyncIterator[bytes]: + try: + graph = await agent.get_graph() + state = await graph.aget_state(langgraph_config) + + if not state or not state.values: + return + + interrupt_info = None + + if hasattr(state, "tasks") and state.tasks: + for task in state.tasks: + if hasattr(task, "interrupts") and task.interrupts: + interrupt_info = task.interrupts[0] + break + + if not interrupt_info and state.values: + interrupt_data = state.values.get("__interrupt__") + if interrupt_data and isinstance(interrupt_data, list) and len(interrupt_data) > 0: + interrupt_info = interrupt_data[0] + + if interrupt_info: + question = "是否批准以下操作?" 
+ operation = "需要人工审批的操作" + if isinstance(interrupt_info, dict): + question = interrupt_info.get("question", question) + operation = interrupt_info.get("operation", operation) + elif hasattr(interrupt_info, "question"): + question = getattr(interrupt_info, "question", question) + operation = getattr(interrupt_info, "operation", operation) + + meta["interrupt"] = { + "question": question, + "operation": operation, + "thread_id": thread_id, + } + yield make_chunk(status="interrupted", message=question, meta=meta) + + except Exception as e: + logger.error(f"Error checking interrupts: {e}") + logger.error(traceback.format_exc()) + + +async def stream_agent_chat( + *, + agent_id: str, + query: str, + config: dict, + meta: dict, + image_content: str | None, + current_user, + db, +) -> AsyncIterator[bytes]: + start_time = asyncio.get_event_loop().time() + + def make_chunk(content=None, **kwargs): + return ( + json.dumps( + {"request_id": meta.get("request_id"), "response": content, **kwargs}, ensure_ascii=False + ).encode("utf-8") + + b"\n" + ) + + if image_content: + human_message = HumanMessage( + content=[ + {"type": "text", "text": query}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_content}"}}, + ] + ) + message_type = "multimodal_image" + else: + human_message = HumanMessage(content=query) + message_type = "text" + + init_msg = {"role": "user", "content": query, "type": "human"} + if image_content: + init_msg["message_type"] = "multimodal_image" + init_msg["image_content"] = image_content + else: + init_msg["message_type"] = "text" + + yield make_chunk(status="init", meta=meta, msg=init_msg) + + if conf.enable_content_guard and await content_guard.check(query): + yield make_chunk( + status="error", error_type="content_guard_blocked", error_message="输入内容包含敏感词", meta=meta + ) + return + + try: + agent = agent_manager.get_agent(agent_id) + except Exception as e: + logger.error(f"Error getting agent {agent_id}: {e}, 
{traceback.format_exc()}") + yield make_chunk( + status="error", + error_type="agent_error", + error_message=f"智能体 {agent_id} 获取失败: {str(e)}", + meta=meta, + ) + return + + messages = [human_message] + + user_id = str(current_user.id) + department_id = current_user.department_id + if not department_id: + yield make_chunk(status="error", error_type="no_department", error_message="当前用户未绑定部门", meta=meta) + return + + agent_config_id = config.get("agent_config_id") + config_repo = AgentConfigRepository(db) + config_item = None + if agent_config_id is not None: + try: + config_item = await config_repo.get_by_id(int(agent_config_id)) + except Exception: + logger.warning(f"Failed to fetch agent config {agent_config_id}: {traceback.format_exc()}") + config_item = None + if config_item is not None and (config_item.department_id != department_id or config_item.agent_id != agent_id): + config_item = None + + if config_item is None: + config_item = await config_repo.get_or_create_default( + department_id=department_id, agent_id=agent_id, created_by=user_id + ) + agent_config_id = config_item.id + + if not (thread_id := config.get("thread_id")): + thread_id = str(uuid.uuid4()) + logger.warning(f"No thread_id provided, generated new thread_id: {thread_id}") + + agent_config = (config_item.config_json or {}).get("context", {}) + input_context = { + "user_id": user_id, + "thread_id": thread_id, + "department_id": department_id, + "agent_config_id": agent_config_id, + "agent_config": agent_config, + } + + try: + conv_repo = ConversationRepository(db) + + try: + await conv_repo.add_message_by_thread_id( + thread_id=thread_id, + role="user", + content=query, + message_type=message_type, + image_content=image_content, + extra_metadata={"raw_message": human_message.model_dump()}, + ) + except Exception as e: + logger.error(f"Error saving user message: {e}") + + try: + assert thread_id, "thread_id is required" + attachments = await conv_repo.get_attachments_by_thread_id(thread_id) + 
input_context["attachments"] = attachments + except Exception as e: + logger.error(f"Error loading attachments for thread_id={thread_id}: {e}") + input_context["attachments"] = [] + + # 根据用户权限过滤知识库 + requested_knowledge_names = input_context["agent_config"].get("knowledges") + logger.info(f"Requesting knowledges: {requested_knowledge_names}") + if requested_knowledge_names and isinstance(requested_knowledge_names, list) and requested_knowledge_names: + user_info = {"role": "user", "department_id": department_id} + accessible_databases = await knowledge_base.get_databases_by_user(user_info) + accessible_kb_names = { + db.get("name") + for db in accessible_databases.get("databases", []) + if isinstance(db, dict) and db.get("name") + } + logger.info(f"Accessible knowledges: {accessible_kb_names}") + + filtered_knowledge_names = [kb for kb in requested_knowledge_names if kb in accessible_kb_names] + blocked_knowledge_names = [kb for kb in requested_knowledge_names if kb not in accessible_kb_names] + if blocked_knowledge_names: + logger.warning(f"用户 {user_id} 无权访问知识库: {blocked_knowledge_names}, 已自动过滤") + input_context["agent_config"]["knowledges"] = filtered_knowledge_names + + full_msg = None + accumulated_content = [] + langgraph_config = {"configurable": {"thread_id": thread_id, "user_id": user_id}} + async for msg, metadata in agent.stream_messages(messages, input_context=input_context): + if isinstance(msg, AIMessageChunk): + accumulated_content.append(msg.content) + + content_for_check = "".join(accumulated_content[-10:]) + if conf.enable_content_guard and await content_guard.check_with_keywords(content_for_check): + full_msg = AIMessage(content="".join(accumulated_content)) + await save_partial_message(conv_repo, thread_id, full_msg, "content_guard_blocked") + meta["time_cost"] = asyncio.get_event_loop().time() - start_time + yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta) + return + + yield make_chunk(content=msg.content, 
msg=msg.model_dump(), metadata=metadata, status="loading") + else: + msg_dict = msg.model_dump() + yield make_chunk(msg=msg_dict, metadata=metadata, status="loading") + + try: + if msg_dict.get("type") == "tool": + graph = await agent.get_graph() + state = await graph.aget_state(langgraph_config) + agent_state = extract_agent_state(getattr(state, "values", {})) if state else {} + if agent_state: + yield make_chunk(status="agent_state", agent_state=agent_state, meta=meta) + except Exception as e: + logger.error(f"Error processing tool message: {e}") + + if not full_msg and accumulated_content: + full_msg = AIMessage(content="".join(accumulated_content)) + + if conf.enable_content_guard and hasattr(full_msg, "content") and await content_guard.check(full_msg.content): + await save_partial_message(conv_repo, thread_id, full_msg, "content_guard_blocked") + meta["time_cost"] = asyncio.get_event_loop().time() - start_time + yield make_chunk(status="interrupted", message="检测到敏感内容,已中断输出", meta=meta) + return + + async for chunk in check_and_handle_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): + yield chunk + + meta["time_cost"] = asyncio.get_event_loop().time() - start_time + try: + graph = await agent.get_graph() + state = await graph.aget_state(langgraph_config) + agent_state = extract_agent_state(getattr(state, "values", {})) if state else {} + except Exception: + agent_state = {} + + if agent_state: + yield make_chunk(status="agent_state", agent_state=agent_state, meta=meta) + + yield make_chunk(status="finished", meta=meta) + + await save_messages_from_langgraph_state( + agent_instance=agent, + thread_id=thread_id, + conv_repo=conv_repo, + config_dict=langgraph_config, + ) + + except (asyncio.CancelledError, ConnectionError) as e: + logger.warning(f"Client disconnected, cancelling stream: {e}") + + async def save_cleanup(): + nonlocal full_msg + if not full_msg and accumulated_content: + full_msg = AIMessage(content="".join(accumulated_content)) + + 
async with pg_manager.get_async_session_context() as new_db: + new_conv_repo = ConversationRepository(new_db) + await save_partial_message( + new_conv_repo, + thread_id, + full_msg=full_msg, + error_message="对话已中断" if not full_msg else None, + error_type="interrupted", + ) + + cleanup_task = asyncio.create_task(save_cleanup()) + try: + await asyncio.shield(cleanup_task) + except asyncio.CancelledError: + pass + except Exception as exc: + logger.error(f"Error during cleanup save: {exc}") + + yield make_chunk(status="interrupted", message="对话已中断", meta=meta) + + except Exception as e: + logger.error(f"Error streaming messages: {e}, {traceback.format_exc()}") + + error_msg = f"Error streaming messages: {e}" + error_type = "unexpected_error" + + if not full_msg and accumulated_content: + full_msg = AIMessage(content="".join(accumulated_content)) + + async with pg_manager.get_async_session_context() as new_db: + new_conv_repo = ConversationRepository(new_db) + await save_partial_message( + new_conv_repo, + thread_id, + full_msg=full_msg, + error_message=error_msg, + error_type=error_type, + ) + + yield make_chunk(status="error", error_type=error_type, error_message=error_msg, meta=meta) + + +async def stream_agent_resume( + *, + agent_id: str, + thread_id: str, + approved: bool, + meta: dict, + config: dict, + current_user, + db, +) -> AsyncIterator[bytes]: + start_time = asyncio.get_event_loop().time() + + def make_resume_chunk(content=None, **kwargs): + return ( + json.dumps( + {"request_id": meta.get("request_id"), "response": content, **kwargs}, ensure_ascii=False + ).encode("utf-8") + + b"\n" + ) + + try: + agent = agent_manager.get_agent(agent_id) + except Exception as e: + logger.error(f"Error getting agent {agent_id}: {e}, {traceback.format_exc()}") + yield ( + f'{{"request_id": "{meta.get("request_id")}", "message": ' + f'"Error getting agent {agent_id}: {e}", "status": "error"}}\n' + ) + return + + init_msg = {"type": "system", "content": f"Resume with 
approved: {approved}"} + yield make_resume_chunk(status="init", meta=meta, msg=init_msg) + + resume_command = Command(resume=approved) + graph = await agent.get_graph() + + user_id = str(current_user.id) + department_id = current_user.department_id + if not department_id: + yield make_resume_chunk( + status="error", error_type="no_department", error_message="当前用户未绑定部门", meta=meta + ) + return + + agent_config_id = (config or {}).get("agent_config_id") + config_repo = AgentConfigRepository(db) + config_item = None + if agent_config_id is not None: + try: + config_item = await config_repo.get_by_id(int(agent_config_id)) + except Exception: + logger.warning(f"Failed to fetch agent config {agent_config_id}: {traceback.format_exc()}") + config_item = None + if config_item is not None and (config_item.department_id != department_id or config_item.agent_id != agent_id): + config_item = None + + if config_item is None: + config_item = await config_repo.get_or_create_default( + department_id=department_id, agent_id=agent_id, created_by=user_id + ) + agent_config_id = config_item.id + + input_context = { + "user_id": user_id, + "thread_id": thread_id, + "department_id": department_id, + "agent_config_id": agent_config_id, + "agent_config": (config_item.config_json or {}).get("context", config_item.config_json or {}), + } + context = agent.context_schema() + agent_config = input_context.get("agent_config") + if isinstance(agent_config, dict): + context.update(agent_config) + context.update(input_context) + + stream_source = graph.astream( + resume_command, + context=context, + config={"configurable": {"thread_id": thread_id, "user_id": user_id}}, + stream_mode="messages", + ) + + try: + async for msg, metadata in stream_source: + msg_dict = msg.model_dump() + if "id" not in msg_dict: + msg_dict["id"] = str(uuid.uuid4()) + + yield make_resume_chunk( + content=getattr(msg, "content", ""), msg=msg_dict, metadata=metadata, status="loading" + ) + + langgraph_config = 
{"configurable": {"thread_id": thread_id, "user_id": str(current_user.id)}} + async for chunk in check_and_handle_interrupts(agent, langgraph_config, make_resume_chunk, meta, thread_id): + yield chunk + + meta["time_cost"] = asyncio.get_event_loop().time() - start_time + yield make_resume_chunk(status="finished", meta=meta) + + conv_repo = ConversationRepository(db) + await save_messages_from_langgraph_state( + agent_instance=agent, + thread_id=thread_id, + conv_repo=conv_repo, + config_dict=langgraph_config, + ) + + except (asyncio.CancelledError, ConnectionError) as e: + logger.warning(f"Client disconnected during resume: {e}") + + async with pg_manager.get_async_session_context() as new_db: + new_conv_repo = ConversationRepository(new_db) + await save_partial_message( + new_conv_repo, thread_id, error_message="对话恢复已中断", error_type="resume_interrupted" + ) + + yield make_resume_chunk(status="interrupted", message="对话恢复已中断", meta=meta) + + except Exception as e: + logger.error(f"Error during resume: {e}, {traceback.format_exc()}") + + async with pg_manager.get_async_session_context() as new_db: + new_conv_repo = ConversationRepository(new_db) + await save_partial_message( + new_conv_repo, thread_id, error_message=f"Error during resume: {e}", error_type="resume_error" + ) + + yield make_resume_chunk(message=f"Error during resume: {e}", status="error") + + +async def get_agent_state_view( + *, + agent_id: str, + thread_id: str, + current_user_id: str, + db, +) -> dict: + if not agent_manager.get_agent(agent_id): + from fastapi import HTTPException + + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + conv_repo = ConversationRepository(db) + conversation = await conv_repo.get_conversation_by_thread_id(thread_id) + if not conversation or conversation.user_id != str(current_user_id) or conversation.status == "deleted": + from fastapi import HTTPException + + raise HTTPException(status_code=404, detail="对话线程不存在") + + agent = 
agent_manager.get_agent(agent_id) + graph = await agent.get_graph() + langgraph_config = {"configurable": {"user_id": str(current_user_id), "thread_id": thread_id}} + state = await graph.aget_state(langgraph_config) + agent_state = extract_agent_state(getattr(state, "values", {})) if state else {} + return {"agent_state": agent_state} diff --git a/src/services/conversation_service.py b/src/services/conversation_service.py new file mode 100644 index 000000000..af23a364d --- /dev/null +++ b/src/services/conversation_service.py @@ -0,0 +1,191 @@ +import uuid + +from fastapi import HTTPException, UploadFile +from sqlalchemy.ext.asyncio import AsyncSession + +from src.repositories.conversation_repository import ConversationRepository +from src.services.doc_converter import ( + ATTACHMENT_ALLOWED_EXTENSIONS, + MAX_ATTACHMENT_SIZE_BYTES, + convert_upload_to_markdown, +) +from src.utils.datetime_utils import utc_isoformat +from src.utils.logging_config import logger + + +async def require_user_conversation(conv_repo: ConversationRepository, thread_id: str, user_id: str): + conversation = await conv_repo.get_conversation_by_thread_id(thread_id) + if not conversation or conversation.user_id != str(user_id) or conversation.status == "deleted": + raise HTTPException(status_code=404, detail="对话线程不存在") + return conversation + + +def serialize_attachment(record: dict) -> dict: + return { + "file_id": record.get("file_id"), + "file_name": record.get("file_name"), + "file_type": record.get("file_type"), + "file_size": record.get("file_size", 0), + "status": record.get("status", "parsed"), + "uploaded_at": record.get("uploaded_at"), + "truncated": record.get("truncated", False), + } + + +async def create_thread_view( + *, + agent_id: str, + title: str | None, + metadata: dict | None, + db: AsyncSession, + current_user_id: str, +) -> dict: + thread_id = str(uuid.uuid4()) + conv_repo = ConversationRepository(db) + conversation = await conv_repo.create_conversation( + 
user_id=str(current_user_id), + agent_id=agent_id, + title=title or "新的对话", + thread_id=thread_id, + metadata=metadata, + ) + + return { + "id": conversation.thread_id, + "user_id": conversation.user_id, + "agent_id": conversation.agent_id, + "title": conversation.title, + "created_at": conversation.created_at.isoformat(), + "updated_at": conversation.updated_at.isoformat(), + } + + +async def list_threads_view( + *, + agent_id: str, + db: AsyncSession, + current_user_id: str, +) -> list[dict]: + if not agent_id: + raise HTTPException(status_code=422, detail="agent_id 不能为空") + + conv_repo = ConversationRepository(db) + conversations = await conv_repo.list_conversations( + user_id=str(current_user_id), + agent_id=agent_id, + status="active", + ) + + return [ + { + "id": conv.thread_id, + "user_id": conv.user_id, + "agent_id": conv.agent_id, + "title": conv.title, + "created_at": conv.created_at.isoformat(), + "updated_at": conv.updated_at.isoformat(), + } + for conv in conversations + ] + + +async def delete_thread_view( + *, + thread_id: str, + db: AsyncSession, + current_user_id: str, +) -> dict: + conv_repo = ConversationRepository(db) + await require_user_conversation(conv_repo, thread_id, str(current_user_id)) + deleted = await conv_repo.delete_conversation(thread_id, soft_delete=True) + if not deleted: + raise HTTPException(status_code=404, detail="对话线程不存在") + return {"message": "删除成功"} + + +async def update_thread_view( + *, + thread_id: str, + title: str | None, + db: AsyncSession, + current_user_id: str, +) -> dict: + conv_repo = ConversationRepository(db) + await require_user_conversation(conv_repo, thread_id, str(current_user_id)) + updated_conv = await conv_repo.update_conversation(thread_id, title=title) + if not updated_conv: + raise HTTPException(status_code=500, detail="更新失败") + return { + "id": updated_conv.thread_id, + "user_id": updated_conv.user_id, + "agent_id": updated_conv.agent_id, + "title": updated_conv.title, + "created_at": 
updated_conv.created_at.isoformat(), + "updated_at": updated_conv.updated_at.isoformat(), + } + + +async def upload_thread_attachment_view( + *, + thread_id: str, + file: UploadFile, + db: AsyncSession, + current_user_id: str, +) -> dict: + conv_repo = ConversationRepository(db) + conversation = await require_user_conversation(conv_repo, thread_id, str(current_user_id)) + + try: + conversion = await convert_upload_to_markdown(file) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + logger.error(f"附件解析失败: {exc}") + raise HTTPException(status_code=500, detail="附件解析失败,请稍后重试") from exc + + attachment_record = { + "file_id": conversion.file_id, + "file_name": conversion.file_name, + "file_type": conversion.file_type, + "file_size": conversion.file_size, + "status": "parsed", + "markdown": conversion.markdown, + "uploaded_at": utc_isoformat(), + "truncated": conversion.truncated, + } + await conv_repo.add_attachment(conversation.id, attachment_record) + + return serialize_attachment(attachment_record) + + +async def list_thread_attachments_view( + *, + thread_id: str, + db: AsyncSession, + current_user_id: str, +) -> dict: + conv_repo = ConversationRepository(db) + conversation = await require_user_conversation(conv_repo, thread_id, str(current_user_id)) + attachments = await conv_repo.get_attachments(conversation.id) + return { + "attachments": [serialize_attachment(item) for item in attachments], + "limits": { + "allowed_extensions": sorted(ATTACHMENT_ALLOWED_EXTENSIONS), + "max_size_bytes": MAX_ATTACHMENT_SIZE_BYTES, + }, + } + + +async def delete_thread_attachment_view( + *, + thread_id: str, + file_id: str, + db: AsyncSession, + current_user_id: str, +) -> dict: + conv_repo = ConversationRepository(db) + conversation = await require_user_conversation(conv_repo, thread_id, str(current_user_id)) + removed = await conv_repo.remove_attachment(conversation.id, file_id) + if not removed: + raise 
HTTPException(status_code=404, detail="附件不存在或已被删除") + return {"message": "附件已删除"} diff --git a/src/services/evaluation_service.py b/src/services/evaluation_service.py index 604099a00..1280fac7f 100644 --- a/src/services/evaluation_service.py +++ b/src/services/evaluation_service.py @@ -1,5 +1,4 @@ import asyncio -import glob import json import os import re @@ -7,9 +6,11 @@ from datetime import datetime from typing import Any -from server.services.tasker import TaskContext, tasker from src.knowledge import knowledge_base from src.models import select_model +from src.repositories.evaluation_repository import EvaluationRepository +from src.repositories.knowledge_base_repository import KnowledgeBaseRepository +from src.services.task_service import TaskContext, tasker from src.utils import logger from src.utils.evaluation_metrics import EvaluationMetricsCalculator @@ -18,17 +19,20 @@ class EvaluationService: """RAG评估服务""" def __init__(self): - pass + self.eval_repo = EvaluationRepository() + self.kb_repo = KnowledgeBaseRepository() - def _get_benchmark_dir(self, db_id: str) -> str: - kb_instance = knowledge_base.get_kb(db_id) + async def _get_benchmark_dir(self, db_id: str) -> str: + """获取评估基准目录""" + kb_instance = await knowledge_base.aget_kb(db_id) base_dir = os.path.join(kb_instance.work_dir, db_id) path = os.path.join(base_dir, "benchmarks") os.makedirs(path, exist_ok=True) return path - def _get_result_dir(self, db_id: str) -> str: - kb_instance = knowledge_base.get_kb(db_id) + async def _get_result_dir(self, db_id: str) -> str: + """获取评估结果目录""" + kb_instance = await knowledge_base.aget_kb(db_id) base_dir = os.path.join(kb_instance.work_dir, db_id) path = os.path.join(base_dir, "results") os.makedirs(path, exist_ok=True) @@ -68,7 +72,7 @@ async def upload_benchmark( raise ValueError("文件中没有有效的问题数据") benchmark_id = f"benchmark_{uuid.uuid4().hex[:8]}" - benchmark_dir = self._get_benchmark_dir(db_id) + benchmark_dir = await self._get_benchmark_dir(db_id) # 保存数据文件 
(.jsonl) data_file_path = os.path.join(benchmark_dir, f"{benchmark_id}.jsonl") @@ -89,11 +93,19 @@ async def upload_benchmark( "created_at": datetime.utcnow().isoformat(), "updated_at": datetime.utcnow().isoformat(), } - kb_instance = knowledge_base.get_kb(db_id) - if db_id not in kb_instance.benchmarks_meta: - kb_instance.benchmarks_meta[db_id] = {} - kb_instance.benchmarks_meta[db_id][benchmark_id] = meta - kb_instance._save_metadata() + await self.eval_repo.create_benchmark( + { + "benchmark_id": benchmark_id, + "db_id": db_id, + "name": name, + "description": description, + "question_count": len(questions), + "has_gold_chunks": has_gold_chunks, + "has_gold_answers": has_gold_answers, + "data_file_path": data_file_path, + "created_by": created_by, + } + ) return meta except Exception as e: @@ -103,11 +115,24 @@ async def upload_benchmark( async def get_benchmarks(self, db_id: str) -> list[dict[str, Any]]: """获取知识库的评估基准列表""" try: - kb_instance = knowledge_base.get_kb(db_id) - benchmarks_map = kb_instance.benchmarks_meta.get(db_id, {}) - benchmarks = list(benchmarks_map.values()) - benchmarks.sort(key=lambda x: x.get("created_at", ""), reverse=True) - return benchmarks + rows = await self.eval_repo.list_benchmarks(db_id) + return [ + { + "id": row.benchmark_id, + "benchmark_id": row.benchmark_id, + "name": row.name, + "description": row.description, + "db_id": row.db_id, + "question_count": row.question_count, + "has_gold_chunks": row.has_gold_chunks, + "has_gold_answers": row.has_gold_answers, + "benchmark_file": row.data_file_path, + "created_by": row.created_by, + "created_at": row.created_at.isoformat() if row.created_at else None, + "updated_at": row.updated_at.isoformat() if row.updated_at else None, + } + for row in rows + ] except Exception as e: logger.error(f"获取评估基准列表失败: {e}") @@ -116,21 +141,30 @@ async def get_benchmarks(self, db_id: str) -> list[dict[str, Any]]: async def get_benchmark_detail(self, benchmark_id: str) -> dict[str, Any]: """获取评估基准详情 
(包含问题列表)""" try: - for kb_instance in knowledge_base.kb_instances.values(): - for db_id, m in kb_instance.benchmarks_meta.items(): - if benchmark_id in m: - meta = m[benchmark_id] - data_file_path = meta.get("benchmark_file") - questions = [] - if data_file_path and os.path.exists(data_file_path): - with open(data_file_path, encoding="utf-8") as f: - for line in f: - if line.strip(): - questions.append(json.loads(line)) - meta_with_q = meta.copy() - meta_with_q["questions"] = questions - return meta_with_q - raise ValueError("Benchmark not found") + row = await self.eval_repo.get_benchmark(benchmark_id) + if row is None: + raise ValueError("Benchmark not found") + questions = [] + if row.data_file_path and os.path.exists(row.data_file_path): + with open(row.data_file_path, encoding="utf-8") as f: + for line in f: + if line.strip(): + questions.append(json.loads(line)) + return { + "id": row.benchmark_id, + "benchmark_id": row.benchmark_id, + "name": row.name, + "description": row.description, + "db_id": row.db_id, + "question_count": row.question_count, + "has_gold_chunks": row.has_gold_chunks, + "has_gold_answers": row.has_gold_answers, + "benchmark_file": row.data_file_path, + "created_by": row.created_by, + "created_at": row.created_at.isoformat() if row.created_at else None, + "updated_at": row.updated_at.isoformat() if row.updated_at else None, + "questions": questions, + } except Exception as e: logger.error(f"获取评估基准详情失败: {e}") @@ -141,15 +175,11 @@ async def get_benchmark_detail_by_db( ) -> dict[str, Any]: """根据 db_id 获取评估基准详情(支持分页)""" try: - kb_instance = knowledge_base.get_kb(db_id) - benchmarks_map = kb_instance.benchmarks_meta.get(db_id, {}) - if benchmark_id not in benchmarks_map: + row = await self.eval_repo.get_benchmark(benchmark_id) + if row is None or row.db_id != db_id: raise ValueError("Benchmark not found") - meta = benchmarks_map[benchmark_id] - data_file_path = meta.get("benchmark_file") - - # 获取总问题数和分页数据 - total_questions = 
meta.get("question_count", 0) + data_file_path = row.data_file_path + total_questions = row.question_count or 0 questions = [] if data_file_path and os.path.exists(data_file_path): @@ -175,21 +205,29 @@ async def get_benchmark_detail_by_db( # 计算分页信息 total_pages = (total_questions + page_size - 1) // page_size - meta_with_q = meta.copy() - meta_with_q.update( - { - "questions": questions, - "pagination": { - "current_page": page, - "page_size": page_size, - "total_questions": total_questions, - "total_pages": total_pages, - "has_next": page < total_pages, - "has_prev": page > 1, - }, - } - ) - return meta_with_q + return { + "id": row.benchmark_id, + "benchmark_id": row.benchmark_id, + "name": row.name, + "description": row.description, + "db_id": row.db_id, + "question_count": row.question_count, + "has_gold_chunks": row.has_gold_chunks, + "has_gold_answers": row.has_gold_answers, + "benchmark_file": data_file_path, + "created_by": row.created_by, + "created_at": row.created_at.isoformat() if row.created_at else None, + "updated_at": row.updated_at.isoformat() if row.updated_at else None, + "questions": questions, + "pagination": { + "current_page": page, + "page_size": page_size, + "total_questions": total_questions, + "total_pages": total_pages, + "has_next": page < total_pages, + "has_prev": page > 1, + }, + } except Exception as e: logger.error(f"获取评估基准详情失败: {e}") raise @@ -197,27 +235,24 @@ async def get_benchmark_detail_by_db( async def delete_benchmark(self, benchmark_id: str) -> None: """删除评估基准""" try: - # 在所有KB中查找并删除 - for kb_instance in knowledge_base.kb_instances.values(): - for db_id, m in list(kb_instance.benchmarks_meta.items()): - if benchmark_id in m: - meta = m[benchmark_id] - data_file_path = meta.get("benchmark_file") - if data_file_path and os.path.exists(data_file_path): - os.remove(data_file_path) - del kb_instance.benchmarks_meta[db_id][benchmark_id] - kb_instance._save_metadata() - logger.info(f"成功删除评估基准: {benchmark_id}") - return - raise 
ValueError("Benchmark not found") + row = await self.eval_repo.get_benchmark(benchmark_id) + if row is None: + raise ValueError("Benchmark not found") + if row.data_file_path and os.path.exists(row.data_file_path): + os.remove(row.data_file_path) + await self.eval_repo.delete_benchmark(benchmark_id) + logger.info(f"成功删除评估基准: {benchmark_id}") + return except Exception as e: logger.error(f"删除评估基准失败: {e}") raise - async def delete_evaluation_result(self, task_id: str) -> None: + async def delete_evaluation_result(self, task_id: str, db_id: str) -> None: """删除评估结果""" - raise ValueError("Endpoint requires db_id; use delete_evaluation_result_by_db") + if not task_id: + raise ValueError("task_id is required") + await self.delete_evaluation_result_by_db(db_id, task_id) async def generate_benchmark(self, db_id: str, params: dict[str, Any], created_by: str) -> dict[str, Any]: task_id = f"gen_benchmark_{uuid.uuid4().hex[:8]}" @@ -251,7 +286,7 @@ async def _generate_benchmark_task(self, context: TaskContext): if neighbors_count > 10: neighbors_count = 10 - kb_instance = knowledge_base.get_kb(db_id) + kb_instance = await knowledge_base.aget_kb(db_id) if not kb_instance: await context.set_message("知识库不存在") raise ValueError("Knowledge Base not found") @@ -314,7 +349,7 @@ def cosine(a, b, na, nb): llm = select_model(model_spec=llm_model_spec) benchmark_id = f"benchmark_{uuid.uuid4().hex[:8]}" - bench_dir = self._get_benchmark_dir(db_id) + bench_dir = await self._get_benchmark_dir(db_id) data_file_path = os.path.join(bench_dir, f"{benchmark_id}.jsonl") generated = 0 @@ -381,25 +416,19 @@ def cosine(a, b, na, nb): logger.warning(f"Benchmark generation failed for one item: {e}") continue - meta = { - "id": benchmark_id, - "benchmark_id": benchmark_id, - "name": name, - "description": description, - "db_id": db_id, - "question_count": generated, - "has_gold_chunks": True, - "has_gold_answers": True, - "benchmark_file": data_file_path, - "created_by": payload.get("created_by"), - 
"created_at": datetime.utcnow().isoformat(), - "updated_at": datetime.utcnow().isoformat(), - } - kb_instance = knowledge_base.get_kb(db_id) - if db_id not in kb_instance.benchmarks_meta: - kb_instance.benchmarks_meta[db_id] = {} - kb_instance.benchmarks_meta[db_id][benchmark_id] = meta - kb_instance._save_metadata() + await self.eval_repo.create_benchmark( + { + "benchmark_id": benchmark_id, + "db_id": db_id, + "name": name, + "description": description, + "question_count": generated, + "has_gold_chunks": True, + "has_gold_answers": True, + "data_file_path": data_file_path, + "created_by": payload.get("created_by"), + } + ) await context.set_progress(100, "完成") @@ -410,18 +439,16 @@ async def run_evaluation( try: task_id = f"eval_{uuid.uuid4().hex[:8]}" - kb_instance = knowledge_base.get_kb(db_id) - bm = kb_instance.benchmarks_meta.get(db_id, {}).get(benchmark_id) - if not bm: + benchmark_row = await self.eval_repo.get_benchmark(benchmark_id) + if benchmark_row is None or benchmark_row.db_id != db_id: raise ValueError("Benchmark not found") - benchmark_meta = bm # 从知识库元数据中获取检索配置 retrieval_config = {} try: - kb_meta = knowledge_base.global_databases_meta.get(db_id, {}) - query_params = kb_meta.get("query_params", {}) - retrieval_config = query_params.get("options", {}) + kb_row = await self.kb_repo.get_by_id(db_id) + query_params = (kb_row.query_params if kb_row else None) or {} + retrieval_config = query_params.get("options", {}) if isinstance(query_params, dict) else {} logger.info(f"从知识库 {db_id} 加载检索配置: {list(retrieval_config.keys())}") except Exception as e: logger.error(f"获取知识库检索配置失败: {e}") @@ -431,30 +458,25 @@ async def run_evaluation( if model_config: retrieval_config.update(model_config) - # 初始化结果文件 (Status: running) - result_dir = self._get_result_dir(db_id) - result_file_path = os.path.join(result_dir, f"{task_id}.json") - - initial_result = { - "id": task_id, # for compatibility - "task_id": task_id, - "benchmark_id": benchmark_id, - "db_id": db_id, - 
"retrieval_config": retrieval_config, - "metrics": {}, - "status": "running", - "total_questions": benchmark_meta.get("question_count", 0), - "completed_questions": 0, - "started_at": datetime.utcnow().isoformat(), - "completed_at": None, - "interim_results": [], - } - - with open(result_file_path, "w", encoding="utf-8") as f: - json.dump(initial_result, f, ensure_ascii=False, indent=2) + await self.eval_repo.create_result( + { + "task_id": task_id, + "db_id": db_id, + "benchmark_id": benchmark_id, + "status": "running", + "retrieval_config": retrieval_config, + "metrics": {}, + "overall_score": None, + "total_questions": benchmark_row.question_count or 0, + "completed_questions": 0, + "started_at": datetime.utcnow(), + "completed_at": None, + "created_by": created_by, + } + ) await tasker.enqueue( - name=f"RAG评估({benchmark_meta.get('name')})", + name=f"RAG评估({benchmark_row.name})", task_type="rag_evaluation", payload={ "task_id": task_id, @@ -487,11 +509,10 @@ async def _run_evaluation_task(self, context: TaskContext): # 加载基准数据 await context.set_progress(5, "加载基准数据") - kb_instance = knowledge_base.get_kb(db_id) - benchmark_meta = kb_instance.benchmarks_meta.get(db_id, {}).get(benchmark_id) - if not benchmark_meta: + benchmark_row = await self.eval_repo.get_benchmark(benchmark_id) + if benchmark_row is None or benchmark_row.db_id != db_id: raise ValueError("Benchmark not found") - data_path = benchmark_meta.get("benchmark_file") + data_path = benchmark_row.data_file_path if not data_path or not os.path.exists(data_path): raise ValueError("Benchmark file not found") @@ -502,7 +523,7 @@ async def _run_evaluation_task(self, context: TaskContext): benchmark_data.append(json.loads(line)) # 开始评估 - kb_instance = knowledge_base.get_kb(db_id) + kb_instance = await knowledge_base.aget_kb(db_id) if not kb_instance: raise ValueError(f"Knowledge Base {db_id} not found") @@ -511,7 +532,7 @@ async def _run_evaluation_task(self, context: TaskContext): # 初始化 Judge LLM judge_llm = 
None - if benchmark_meta.get("has_gold_answers"): + if benchmark_row.has_gold_answers: # 优先使用配置中的 judge_llm,否则回退到 answer_llm,或者默认 judge_model_spec = retrieval_config.get("judge_llm") or retrieval_config.get("answer_llm") if judge_model_spec: @@ -522,36 +543,25 @@ async def _run_evaluation_task(self, context: TaskContext): logger.error(f"Failed to load judge LLM: {e}") total_questions = len(benchmark_data) - interim_results = [] all_retrieval_metrics = [] all_answer_metrics = [] - # 更新结果文件 helper - result_file_path = os.path.join(self._get_result_dir(db_id), f"{task_id}.json") - - def update_result_file(status="running", completed=0, metrics=None, interim=None, final_score=None): - try: - if os.path.exists(result_file_path): - with open(result_file_path, encoding="utf-8") as f: - data = json.load(f) - else: - data = {} # Should have been created in run_evaluation - - data["status"] = status - data["completed_questions"] = completed - if metrics: - data["metrics"] = metrics - if interim is not None: - data["interim_results"] = interim - if final_score is not None: - data["overall_score"] = final_score + async def update_result_db( + status: str | None = None, completed: int | None = None, metrics=None, final_score=None + ): + payload = {} + if status is not None: + payload["status"] = status if status in ["completed", "failed"]: - data["completed_at"] = datetime.utcnow().isoformat() - - with open(result_file_path, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - except Exception as e: - logger.error(f"Failed to update result file: {e}") + payload["completed_at"] = datetime.utcnow() + if completed is not None: + payload["completed_questions"] = completed + if metrics is not None: + payload["metrics"] = metrics + if final_score is not None: + payload["overall_score"] = final_score + if payload: + await self.eval_repo.update_result(task_id, payload) for i, question_data in enumerate(benchmark_data): # 检查任务是否被取消 @@ -610,14 +620,14 @@ def 
update_result_file(status="running", completed=0, metrics=None, interim=None retrieval_scores = {} answer_scores = {} - if benchmark_meta.get("has_gold_chunks") and question_data.get("gold_chunk_ids"): + if benchmark_row.has_gold_chunks and question_data.get("gold_chunk_ids"): retrieval_scores = EvaluationMetricsCalculator.calculate_retrieval_metrics( retrieved_chunks, question_data["gold_chunk_ids"] ) current_metrics.update(retrieval_scores) all_retrieval_metrics.append(retrieval_scores) - if benchmark_meta.get("has_gold_answers") and question_data.get("gold_answer"): + if benchmark_row.has_gold_answers and question_data.get("gold_answer"): if judge_llm: # 评判过程包含 LLM 调用,使用 asyncio.to_thread 避免阻塞 answer_scores = await asyncio.to_thread( @@ -632,15 +642,17 @@ def update_result_file(status="running", completed=0, metrics=None, interim=None else: logger.warning("需要计算答案指标但未配置 Judge LLM") - interim_results.append( - { - "query": question_data["query"], + await self.eval_repo.upsert_result_detail( + task_id=task_id, + query_index=i, + data={ + "query_text": question_data["query"], "gold_chunk_ids": question_data.get("gold_chunk_ids"), "gold_answer": question_data.get("gold_answer"), "generated_answer": generated_answer, "retrieved_chunks": retrieved_chunks, "metrics": current_metrics, - } + }, ) # 计算当前累计指标 @@ -666,7 +678,7 @@ def update_result_file(status="running", completed=0, metrics=None, interim=None # 定期更新文件 (每5个或最后一个) if (i + 1) % 5 == 0 or (i + 1) == total_questions: - update_result_file(completed=i + 1, interim=interim_results) + await update_result_db(completed=i + 1) # 最终计算 await context.set_progress(95, "计算最终指标") @@ -690,75 +702,53 @@ def update_result_file(status="running", completed=0, metrics=None, interim=None ) overall_metrics["overall_score"] = overall_score - update_result_file( + await update_result_db( status="completed", completed=total_questions, metrics=overall_metrics, - interim=interim_results, final_score=overall_score, ) await 
context.set_progress(100, "完成") except Exception as e: logger.error(f"Task failed: {e}") - # Try to update status to failed try: - # Need to find the file path again or pass it around. - # Re-deriving from payload if available if "payload" in locals(): - path = os.path.join(self._get_result_dir(payload["db_id"]), f"{payload['task_id']}.json") - if os.path.exists(path): - with open(path, encoding="utf-8") as f: - d = json.load(f) - d["status"] = "failed" - d["error"] = str(e) - with open(path, "w", encoding="utf-8") as f: - json.dump(d, f, ensure_ascii=False, indent=2) - except Exception as e: - logger.error(f"Error updating result file: {e}") - pass + await self.eval_repo.update_result( + payload["task_id"], + {"status": "failed", "metrics": {"error": str(e)}, "completed_at": datetime.utcnow()}, + ) + except Exception as exc: + logger.error(f"Error updating result record: {exc}") await context.set_message(f"Error: {str(e)}") raise - async def get_evaluation_results(self, task_id: str) -> dict[str, Any]: + async def get_evaluation_results(self, task_id: str, db_id: str) -> dict[str, Any]: """获取评估结果""" - raise ValueError("Endpoint requires db_id; use get_evaluation_results_by_db") + if not task_id: + raise ValueError("task_id is required") + return await self.get_evaluation_results_by_db(db_id, task_id) async def get_evaluation_history(self, db_id: str) -> list[dict[str, Any]]: """获取知识库的评估历史记录""" try: - result_dir = self._get_result_dir(db_id) - history = [] - - # 查找所有 .json 文件 - result_files = glob.glob(os.path.join(result_dir, "*.json")) - for result_file in result_files: - try: - with open(result_file, encoding="utf-8") as f: - data = json.load(f) - # 只返回摘要信息,不返回详细的interim_results - summary = { - "task_id": data.get("task_id"), - "benchmark_id": data.get("benchmark_id"), - "status": data.get("status"), - "started_at": data.get("started_at"), - "completed_at": data.get("completed_at"), - "total_questions": data.get("total_questions"), - "completed_questions": 
data.get("completed_questions"), - "overall_score": data.get("overall_score"), - # 包含检索配置 - "retrieval_config": data.get("retrieval_config", {}), - # 也可以带上部分 metrics 摘要 - "metrics": data.get("metrics"), - } - history.append(summary) - except Exception as e: - logger.error(f"Failed to load result file {result_file}: {e}") - - # 按开始时间倒序 - history.sort(key=lambda x: x.get("started_at", ""), reverse=True) - return history + rows = await self.eval_repo.list_results(db_id) + return [ + { + "task_id": row.task_id, + "benchmark_id": row.benchmark_id, + "status": row.status, + "started_at": row.started_at.isoformat() if row.started_at else None, + "completed_at": row.completed_at.isoformat() if row.completed_at else None, + "total_questions": row.total_questions, + "completed_questions": row.completed_questions, + "overall_score": row.overall_score, + "retrieval_config": row.retrieval_config or {}, + "metrics": row.metrics or {}, + } + for row in rows + ] except Exception as e: logger.error(f"获取评估历史失败: {e}") @@ -768,81 +758,70 @@ async def get_evaluation_history(self, db_id: str) -> list[dict[str, Any]]: async def get_evaluation_results_by_db( self, db_id: str, task_id: str, page: int = 1, page_size: int = 20, error_only: bool = False ) -> dict[str, Any]: - # Validate task_id format to prevent path traversal if not re.match(r"^eval_[a-f0-9]{8}$", task_id): raise ValueError("Invalid task_id format") - result_file_path = os.path.join(self._get_result_dir(db_id), f"{task_id}.json") - if not os.path.exists(result_file_path): + row = await self.eval_repo.get_result(task_id) + if row is None or row.db_id != db_id: task = await tasker.get_task(task_id) if task: - return { - "task_id": task_id, - "status": task.status, - "progress": task.progress, - "message": task.message, - } + return {"task_id": task_id, "status": task.status, "progress": task.progress, "message": task.message} raise ValueError(f"Result not found for task {task_id}") - # 加载JSON文件 - with open(result_file_path, 
encoding="utf-8") as f: - data = json.load(f) - - # 如果是分页请求,处理详细结果 - if page and page_size: - all_results = data.get("interim_results", data.get("results", [])) - - # 如果只要错误结果,先过滤 - if error_only: - filtered_results = [] - for item in all_results: - # 检查答案评分是否为错误(score <= 0.5) - if item.get("metrics", {}).get("score", 1.0) <= 0.5: - filtered_results.append(item) - continue - - # 检查检索指标是否明显偏低 - metrics = item.get("metrics", {}) - has_low_recall = any(metrics.get(k, 1.0) < 0.3 for k in metrics if k.startswith("recall@")) - if has_low_recall: - filtered_results.append(item) - all_results = filtered_results - - # 计算分页 - total = len(all_results) - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size - paged_results = all_results[start_idx:end_idx] - - # 返回分页数据 - return { - "task_id": data.get("task_id", task_id), - "status": data.get("status"), - "started_at": data.get("started_at"), - "completed_at": data.get("completed_at"), - "total_questions": data.get("total_questions", 0), - "completed_questions": data.get("completed_questions", 0), - "overall_score": data.get("overall_score"), - "retrieval_config": data.get("retrieval_config"), - "interim_results": paged_results, - "pagination": { - "current_page": page, - "page_size": page_size, - "total": total, - "total_pages": (total + page_size - 1) // page_size, - "error_only": error_only, - }, + details = await self.eval_repo.list_result_details(task_id) + all_results = [ + { + "query": d.query_text, + "gold_chunk_ids": d.gold_chunk_ids, + "gold_answer": d.gold_answer, + "generated_answer": d.generated_answer, + "retrieved_chunks": d.retrieved_chunks, + "metrics": d.metrics or {}, } - - # 非分页请求,返回完整数据(保持向后兼容) - return data + for d in details + ] + + if error_only: + filtered_results = [] + for item in all_results: + if item.get("metrics", {}).get("score", 1.0) <= 0.5: + filtered_results.append(item) + continue + metrics = item.get("metrics", {}) + has_low_recall = any(metrics.get(k, 1.0) < 0.3 for k in 
metrics if k.startswith("recall@")) + if has_low_recall: + filtered_results.append(item) + all_results = filtered_results + + total = len(all_results) + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + paged_results = all_results[start_idx:end_idx] + + return { + "task_id": row.task_id, + "status": row.status, + "started_at": row.started_at.isoformat() if row.started_at else None, + "completed_at": row.completed_at.isoformat() if row.completed_at else None, + "total_questions": row.total_questions or 0, + "completed_questions": row.completed_questions or 0, + "overall_score": row.overall_score, + "retrieval_config": row.retrieval_config or {}, + "interim_results": paged_results, + "pagination": { + "current_page": page, + "page_size": page_size, + "total": total, + "total_pages": (total + page_size - 1) // page_size, + "error_only": error_only, + }, + } async def delete_evaluation_result_by_db(self, db_id: str, task_id: str) -> None: - # Validate task_id format to prevent path traversal if not re.match(r"^eval_[a-f0-9]{8}$", task_id): raise ValueError("Invalid task_id format") - result_file_path = os.path.join(self._get_result_dir(db_id), f"{task_id}.json") - if os.path.exists(result_file_path): - os.remove(result_file_path) - logger.info(f"成功删除评估结果: {task_id}") - return - raise ValueError("Result not found") + row = await self.eval_repo.get_result(task_id) + if row is None or row.db_id != db_id: + raise ValueError("Result not found") + await self.eval_repo.delete_result(task_id) + logger.info(f"成功删除评估结果: {task_id}") + return diff --git a/src/services/feedback_service.py b/src/services/feedback_service.py new file mode 100644 index 000000000..3db6c2cd2 --- /dev/null +++ b/src/services/feedback_service.py @@ -0,0 +1,96 @@ +import traceback + +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.storage.postgres.models_business import Conversation, Message, MessageFeedback 
+from src.utils.logging_config import logger + + +async def submit_message_feedback_view( + *, + message_id: int, + rating: str, + reason: str | None, + db: AsyncSession, + current_user_id: str, +) -> dict: + if rating not in ["like", "dislike"]: + raise HTTPException(status_code=422, detail="Rating must be 'like' or 'dislike'") + + try: + message_result = await db.execute(select(Message).filter_by(id=message_id)) + message = message_result.scalar_one_or_none() + if not message: + raise HTTPException(status_code=404, detail="Message not found") + + conversation_result = await db.execute(select(Conversation).filter_by(id=message.conversation_id)) + conversation = conversation_result.scalar_one_or_none() + if not conversation or conversation.user_id != str(current_user_id): + raise HTTPException(status_code=403, detail="Access denied") + + existing_feedback_result = await db.execute( + select(MessageFeedback).filter_by(message_id=message_id, user_id=str(current_user_id)) + ) + existing_feedback = existing_feedback_result.scalar_one_or_none() + if existing_feedback: + raise HTTPException(status_code=409, detail="Feedback already submitted for this message") + + new_feedback = MessageFeedback( + message_id=message_id, + user_id=str(current_user_id), + rating=rating, + reason=reason, + ) + + db.add(new_feedback) + await db.commit() + await db.refresh(new_feedback) + + logger.info(f"User {current_user_id} submitted {rating} feedback for message {message_id}") + + return { + "id": new_feedback.id, + "message_id": new_feedback.message_id, + "rating": new_feedback.rating, + "reason": new_feedback.reason, + "created_at": new_feedback.created_at.isoformat(), + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error submitting message feedback: {e}, {traceback.format_exc()}") + await db.rollback() + raise HTTPException(status_code=500, detail=f"Failed to submit feedback: {str(e)}") + + +async def get_message_feedback_view( + *, + message_id: int, + 
db: AsyncSession, + current_user_id: str, +) -> dict: + try: + feedback_result = await db.execute( + select(MessageFeedback).filter_by(message_id=message_id, user_id=str(current_user_id)) + ) + feedback = feedback_result.scalar_one_or_none() + + if not feedback: + return {"has_feedback": False, "feedback": None} + + return { + "has_feedback": True, + "feedback": { + "id": feedback.id, + "rating": feedback.rating, + "reason": feedback.reason, + "created_at": feedback.created_at.isoformat(), + }, + } + + except Exception as e: + logger.error(f"Error getting message feedback: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get feedback: {str(e)}") diff --git a/src/services/history_query_service.py b/src/services/history_query_service.py new file mode 100644 index 000000000..2c89f1c1b --- /dev/null +++ b/src/services/history_query_service.py @@ -0,0 +1,72 @@ +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from src.agents import agent_manager +from src.repositories.conversation_repository import ConversationRepository +from src.utils.logging_config import logger + + +async def get_agent_history_view( + *, + agent_id: str, + thread_id: str, + current_user_id: str, + db: AsyncSession, +) -> dict: + if not agent_manager.get_agent(agent_id): + raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + + conv_repo = ConversationRepository(db) + conversation = await conv_repo.get_conversation_by_thread_id(thread_id) + if not conversation or conversation.user_id != str(current_user_id) or conversation.status == "deleted": + raise HTTPException(status_code=404, detail="对话线程不存在") + + messages = await conv_repo.get_messages_by_thread_id(thread_id) + + history: list[dict] = [] + role_type_map = {"user": "human", "assistant": "ai", "tool": "tool", "system": "system"} + + for msg in messages: + user_feedback = None + if msg.feedbacks: + for feedback in msg.feedbacks: + if feedback.user_id == str(current_user_id): + 
user_feedback = { + "id": feedback.id, + "rating": feedback.rating, + "reason": feedback.reason, + "created_at": feedback.created_at.isoformat() if feedback.created_at else None, + } + break + + msg_dict = { + "id": msg.id, + "type": role_type_map.get(msg.role, msg.role), + "content": msg.content, + "created_at": msg.created_at.isoformat() if msg.created_at else None, + "error_type": msg.extra_metadata.get("error_type") if msg.extra_metadata else None, + "error_message": msg.extra_metadata.get("error_message") if msg.extra_metadata else None, + "extra_metadata": msg.extra_metadata, + "message_type": msg.message_type, + "image_content": msg.image_content, + "feedback": user_feedback, + } + + if msg.tool_calls: + msg_dict["tool_calls"] = [ + { + "id": str(tc.id), + "name": tc.tool_name, + "function": {"name": tc.tool_name}, + "args": tc.tool_input or {}, + "tool_call_result": {"content": (tc.tool_output or "")} if tc.status == "success" else None, + "status": tc.status, + "error_message": tc.error_message, + } + for tc in msg.tool_calls + ] + + history.append(msg_dict) + + logger.info(f"Loaded {len(history)} messages with feedback for thread {thread_id}") + return {"history": history} diff --git a/src/services/mcp_service.py b/src/services/mcp_service.py index 90651a537..d3d9da0ed 100644 --- a/src/services/mcp_service.py +++ b/src/services/mcp_service.py @@ -17,7 +17,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from src.storage.db.models import MCPServer +from src.storage.postgres.models_business import MCPServer from src.utils import logger # ============================================================================= @@ -65,10 +65,10 @@ async def load_mcp_servers_from_db() -> None: global MCP_SERVERS # Delayed import to avoid circular references - from src.storage.db.manager import db_manager + from src.storage.postgres.manager import pg_manager try: - async with db_manager.get_async_session_context() as session: + 
async with pg_manager.get_async_session_context() as session: result = await session.execute(select(MCPServer).filter(MCPServer.enabled == 1)) servers = result.scalars().all() @@ -111,10 +111,10 @@ async def init_mcp_servers() -> None: Also ensures all built-in MCP servers are present in the database. """ # Delayed import to avoid circular references - from src.storage.db.manager import db_manager + from src.storage.postgres.manager import pg_manager try: - async with db_manager.get_async_session_context() as session: + async with pg_manager.get_async_session_context() as session: # Check if database has MCP configurations result = await session.execute(select(func.count(MCPServer.name))) count = result.scalar() diff --git a/server/services/tasker.py b/src/services/task_service.py similarity index 78% rename from server/services/tasker.py rename to src/services/task_service.py index 1bdfdb35e..d4741b6c7 100644 --- a/server/services/tasker.py +++ b/src/services/task_service.py @@ -1,23 +1,23 @@ import asyncio -import json -import os import uuid +from collections import Counter +from collections.abc import Awaitable, Callable from dataclasses import asdict, dataclass, field -from pathlib import Path +from datetime import datetime from typing import Any -from collections.abc import Awaitable, Callable -from collections import Counter -from src.config import config +from src.repositories.task_repository import TaskRepository +from src.utils.datetime_utils import coerce_any_to_utc_datetime, utc_isoformat from src.utils.logging_config import logger -from src.utils.datetime_utils import utc_isoformat TaskCoroutine = Callable[["TaskContext"], Awaitable[Any]] TERMINAL_STATUSES = {"success", "failed", "cancelled"} -def _utc_timestamp() -> str: - return utc_isoformat() +def _iso_to_utc_naive(value: str | None) -> datetime | None: + if not value: + return None + return coerce_any_to_utc_datetime(value).replace(tzinfo=None) @dataclass @@ -28,8 +28,8 @@ class Task: status: str = 
"pending" progress: float = 0.0 message: str = "" - created_at: str = field(default_factory=_utc_timestamp) - updated_at: str = field(default_factory=_utc_timestamp) + created_at: str = field(default_factory=utc_isoformat) + updated_at: str = field(default_factory=utc_isoformat) started_at: str | None = None completed_at: str | None = None payload: dict[str, Any] = field(default_factory=dict) @@ -38,8 +38,7 @@ class Task: cancel_requested: bool = False def to_dict(self) -> dict[str, Any]: - data = asdict(self) - return data + return asdict(self) def to_summary_dict(self) -> dict[str, Any]: data = asdict(self) @@ -56,8 +55,8 @@ def from_dict(cls, data: dict[str, Any]) -> "Task": status=data.get("status", "pending"), progress=data.get("progress", 0.0), message=data.get("message", ""), - created_at=data.get("created_at", _utc_timestamp()), - updated_at=data.get("updated_at", _utc_timestamp()), + created_at=data.get("created_at", utc_isoformat()), + updated_at=data.get("updated_at", utc_isoformat()), started_at=data.get("started_at"), completed_at=data.get("completed_at"), payload=data.get("payload", {}), @@ -100,9 +99,8 @@ def __init__(self, worker_count: int = 2): self._tasks: dict[str, Task] = {} self._lock = asyncio.Lock() self._workers: list[asyncio.Task[Any]] = [] - self._storage_path = Path(config.save_dir) / "tasks" / "tasks.json" - os.makedirs(self._storage_path.parent, exist_ok=True) self._started = False + self._repo = TaskRepository() async def start(self) -> None: async with self._lock: @@ -123,7 +121,6 @@ async def shutdown(self) -> None: worker.cancel() await asyncio.gather(*self._workers, return_exceptions=True) self._workers.clear() - await self._persist_state() self._started = False logger.info("Tasker shutdown complete") @@ -139,7 +136,7 @@ async def enqueue( task = Task(id=task_id, name=name, type=task_type, payload=payload or {}) async with self._lock: self._tasks[task_id] = task - await self._persist_state() + await self._persist_task(task) await 
self._queue.put((task_id, coroutine)) logger.info("Enqueued task {} ({})", task_id, name) return task @@ -180,11 +177,11 @@ async def cancel_task(self, task_id: str) -> bool: task = self._tasks.get(task_id) if not task: return False - if task.status in {"success", "failed", "cancelled"}: + if task.status in TERMINAL_STATUSES: return False task.cancel_requested = True - task.updated_at = _utc_timestamp() - await self._persist_state() + task.updated_at = utc_isoformat() + await self._persist_task(task) logger.info("Cancellation requested for task {}", task_id) return True @@ -200,7 +197,7 @@ async def _worker_loop(self) -> None: await self._mark_cancelled(task_id, "Task was cancelled before execution") continue await self._update_task( - task_id, status="running", progress=0.0, message="任务开始执行", started_at=_utc_timestamp() + task_id, status="running", progress=0.0, message="任务开始执行", started_at=utc_isoformat() ) context = TaskContext(self, task_id) try: @@ -214,7 +211,7 @@ async def _worker_loop(self) -> None: progress=100.0, message="任务已完成", result=result, - completed_at=_utc_timestamp(), + completed_at=utc_isoformat(), ) except asyncio.CancelledError: await self._mark_cancelled(task_id, "任务被取消") @@ -226,7 +223,7 @@ async def _worker_loop(self) -> None: progress=100.0, message="任务执行失败", error=str(exc), - completed_at=_utc_timestamp(), + completed_at=utc_isoformat(), ) finally: self._queue.task_done() @@ -245,7 +242,7 @@ async def _mark_cancelled(self, task_id: str, message: str) -> None: status="cancelled", progress=100.0, message=message, - completed_at=_utc_timestamp(), + completed_at=utc_isoformat(), ) async def _update_task( @@ -278,49 +275,51 @@ async def _update_task( task.started_at = started_at if completed_at is not None: task.completed_at = completed_at - task.updated_at = _utc_timestamp() - await self._persist_state() + task.updated_at = utc_isoformat() + await self._persist_task(task) def _is_cancel_requested(self, task_id: str) -> bool: task = 
self._tasks.get(task_id) return bool(task and task.cancel_requested) async def _load_state(self) -> None: - if not self._storage_path.exists(): - return - try: - content = await asyncio.to_thread(self._storage_path.read_text, encoding="utf-8") - if not content.strip(): - return - data = json.loads(content) - tasks = data.get("tasks", []) - for item in tasks: - task = Task.from_dict(item) - if task.status == "running": - task.status = "failed" - task.message = "服务重启时任务中断" - task.updated_at = _utc_timestamp() - elif task.status not in TERMINAL_STATUSES: - task.status = "failed" - task.message = "服务重启时任务未继续执行" - task.updated_at = _utc_timestamp() - self._tasks[task.id] = task - logger.info("Loaded {} task records from storage", len(tasks)) - except Exception as exc: # noqa: BLE001 - logger.exception("Failed to load task state: {}", exc) - - async def _persist_state(self) -> None: - tasks = [task.to_dict() for task in self._tasks.values()] - payload = {"tasks": tasks, "updated_at": _utc_timestamp()} - - def _write() -> None: - self._storage_path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = self._storage_path.with_suffix(".tmp") - with open(tmp_path, "w", encoding="utf-8") as fh: - json.dump(payload, fh, ensure_ascii=False, indent=2) - os.replace(tmp_path, self._storage_path) - - await asyncio.to_thread(_write) + records = await self._repo.list_all() + updated: list[Task] = [] + for record in records: + task = Task.from_dict(record.to_dict()) + if task.status == "running": + task.status = "failed" + task.message = "服务重启时任务中断" + task.updated_at = utc_isoformat() + updated.append(task) + elif task.status not in TERMINAL_STATUSES: + task.status = "failed" + task.message = "服务重启时任务未继续执行" + task.updated_at = utc_isoformat() + updated.append(task) + self._tasks[task.id] = task + for task in updated: + await self._persist_task(task) + if records: + logger.info("Loaded {} task records from storage", len(records)) + + async def _persist_task(self, task: Task) -> None: 
+ data: dict[str, Any] = { + "name": task.name, + "type": task.type, + "status": task.status, + "progress": task.progress, + "message": task.message, + "payload": task.payload, + "result": task.result, + "error": task.error, + "cancel_requested": 1 if task.cancel_requested else 0, + "created_at": _iso_to_utc_naive(task.created_at), + "updated_at": _iso_to_utc_naive(task.updated_at), + "started_at": _iso_to_utc_naive(task.started_at), + "completed_at": _iso_to_utc_naive(task.completed_at), + } + await self._repo.upsert(task.id, data) tasker = Tasker() diff --git a/src/storage/conversation/__init__.py b/src/storage/conversation/__init__.py deleted file mode 100644 index 547b201c4..000000000 --- a/src/storage/conversation/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .manager import ConversationManager - -__all__ = ["ConversationManager"] diff --git a/src/storage/db/__init__.py b/src/storage/db/__init__.py deleted file mode 100644 index 05c4dc588..000000000 --- a/src/storage/db/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from sqlalchemy.ext.declarative import declarative_base - -Base = declarative_base() - -# 导入所有模型,确保它们被注册到 Base.metadata -# 导入数据库管理器 -from src.storage.db.manager import db_manager # noqa: E402 -from src.storage.db.models import ( # noqa: E402, F401 - OperationLog, - User, -) # noqa: E402 - -__all__ = [ - "Base", - "User", - "OperationLog", - "db_manager", -] diff --git a/src/storage/db/manager.py b/src/storage/db/manager.py deleted file mode 100644 index 1330c6cd3..000000000 --- a/src/storage/db/manager.py +++ /dev/null @@ -1,152 +0,0 @@ -import asyncio -import json -import os -import pathlib -from contextlib import asynccontextmanager, contextmanager - -from sqlalchemy import create_engine, func, select -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import sessionmaker - -from server.utils.singleton import SingletonMeta -from src import config -from src.storage.db.models import Base, 
User -from src.utils import logger - -try: - from server.utils.migrate import DatabaseMigrator, validate_database_schema -except ImportError: - DatabaseMigrator = None - - # 如果迁移工具不存在,使用简单的占位函数 - def validate_database_schema(db_path): - return True, [] - - -class DBManager(metaclass=SingletonMeta): - """数据库管理器 - 提供异步数据库连接和会话管理""" - - def __init__(self): - self.db_path = os.path.join(config.save_dir, "database", "server.db") - self.ensure_db_dir() - - # 创建异步SQLAlchemy引擎,配置JSON序列化器以支持中文 - # 使用 ensure_ascii=False 确保中文字符不被转义为 Unicode 序列 - self.async_engine = create_async_engine( - f"sqlite+aiosqlite:///{self.db_path}", - json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), - json_deserializer=json.loads, - ) - - # 创建异步会话工厂 - self.AsyncSession = async_sessionmaker(bind=self.async_engine, class_=AsyncSession, expire_on_commit=False) - - # 保留同步引擎用于迁移等特殊操作 - self.engine = create_engine( - f"sqlite:///{self.db_path}", - json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), - json_deserializer=json.loads, - ) - self.Session = sessionmaker(bind=self.engine) - - # 首先创建基本表结构 - self.create_tables() - - # 然后检查并执行数据库迁移 - self.run_migrations() - - def ensure_db_dir(self): - """确保数据库目录存在""" - db_dir = os.path.dirname(self.db_path) - pathlib.Path(db_dir).mkdir(parents=True, exist_ok=True) - - def create_tables(self): - """创建数据库表""" - # 确保所有表都会被创建 - Base.metadata.create_all(self.engine) - logger.info("Database tables created/checked") - - def run_migrations(self): - """运行数据库迁移""" - if not os.path.exists(self.db_path): - return - - if DatabaseMigrator is not None: - migrator = DatabaseMigrator(self.db_path) - try: - migrator.run_migrations() - except Exception as exc: - logger.error(f"数据库迁移执行失败: {exc}") - else: - logger.warning("数据库迁移工具缺失,无法自动执行迁移") - - is_valid, issues = validate_database_schema(self.db_path) - - if not is_valid: - logger.warning("=" * 60) - logger.warning("检测到数据库结构与当前模型不一致!") - logger.warning("=" * 60) - for issue in issues: - logger.warning(f" 
⚠️ {issue}") - logger.warning("") - logger.warning("请运行 scripts/migrate_user_soft_delete.py 手动修复数据库结构") - logger.warning("=" * 60) - - def get_session(self): - """获取同步数据库会话""" - return self.Session() - - @contextmanager - def get_session_context(self): - """获取同步数据库会话的上下文管理器""" - session = self.Session() - try: - yield session - session.commit() - except Exception as e: - session.rollback() - logger.error(f"Database operation failed: {e}") - raise - finally: - session.close() - - async def get_async_session(self): - """获取异步数据库会话""" - return self.AsyncSession() - - @asynccontextmanager - async def get_async_session_context(self): - """获取异步数据库会话的上下文管理器""" - session = self.AsyncSession() - try: - yield session - await session.commit() - except Exception as e: - await session.rollback() - logger.error(f"Async database operation failed: {e}") - raise - finally: - # Shield close operation to ensure connection is properly closed even if task is cancelled - # This prevents aiosqlite from raising errors during cancellation - await asyncio.shield(session.close()) - - def check_first_run(self): - """检查是否首次运行(同步版本)""" - session = self.get_session() - try: - # 检查是否有任何用户存在 - return session.query(User).count() == 0 - finally: - session.close() - - async def async_check_first_run(self): - """检查是否首次运行(异步版本)""" - async with self.get_async_session_context() as session: - # 检查是否有任何用户存在 - result = await session.execute(select(func.count(User.id))) - count = result.scalar() - return count == 0 - - -# 创建全局数据库管理器实例 -db_manager = DBManager() diff --git a/src/storage/db/models.py b/src/storage/db/models.py index a7e771716..c1edfc3e9 100644 --- a/src/storage/db/models.py +++ b/src/storage/db/models.py @@ -4,20 +4,11 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship -from src.utils.datetime_utils import coerce_datetime, utc_isoformat, utc_now +from src.utils.datetime_utils import coerce_datetime, format_utc_datetime, utc_now Base = 
declarative_base() -def _format_utc_datetime(dt_value): - """Helper to format datetime to UTC ISO string, assuming naive datetimes are UTC.""" - if dt_value is None: - return None - if dt_value.tzinfo is None: - dt_value = dt_value.replace(tzinfo=dt.UTC) - return utc_isoformat(dt_value) - - ## Removed legacy RDBMS knowledge models (KnowledgeDatabase/KnowledgeFile/KnowledgeNode) @@ -39,7 +30,7 @@ def to_dict(self): "id": self.id, "name": self.name, "description": self.description, - "created_at": _format_utc_datetime(self.created_at), + "created_at": format_utc_datetime(self.created_at), } @@ -72,8 +63,8 @@ def to_dict(self): "agent_id": self.agent_id, "title": self.title, "status": self.status, - "created_at": _format_utc_datetime(self.created_at), - "updated_at": _format_utc_datetime(self.updated_at), + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), "metadata": self.extra_metadata or {}, } @@ -106,7 +97,7 @@ def to_dict(self): "role": self.role, "content": self.content, "message_type": self.message_type, - "created_at": _format_utc_datetime(self.created_at), + "created_at": format_utc_datetime(self.created_at), "token_count": self.token_count, "metadata": self.extra_metadata or {}, "image_content": self.image_content, @@ -150,7 +141,7 @@ def to_dict(self): "tool_output": self.tool_output, "status": self.status, "error_message": self.error_message, - "created_at": _format_utc_datetime(self.created_at), + "created_at": format_utc_datetime(self.created_at), } @@ -181,8 +172,8 @@ def to_dict(self): "total_tokens": self.total_tokens, "model_used": self.model_used, "user_feedback": self.user_feedback or {}, - "created_at": _format_utc_datetime(self.created_at), - "updated_at": _format_utc_datetime(self.updated_at), + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), } @@ -227,13 +218,13 @@ def to_dict(self, include_password=False): "avatar": self.avatar, 
"role": self.role, "department_id": self.department_id, - "created_at": _format_utc_datetime(self.created_at), - "last_login": _format_utc_datetime(self.last_login), + "created_at": format_utc_datetime(self.created_at), + "last_login": format_utc_datetime(self.last_login), "login_failed_count": self.login_failed_count, - "last_failed_login": _format_utc_datetime(self.last_failed_login), - "login_locked_until": _format_utc_datetime(self.login_locked_until), + "last_failed_login": format_utc_datetime(self.last_failed_login), + "login_locked_until": format_utc_datetime(self.login_locked_until), "is_deleted": self.is_deleted, - "deleted_at": _format_utc_datetime(self.deleted_at), + "deleted_at": format_utc_datetime(self.deleted_at), } if include_password: result["password_hash"] = self.password_hash @@ -304,7 +295,7 @@ def to_dict(self): "operation": self.operation, "details": self.details, "ip_address": self.ip_address, - "timestamp": _format_utc_datetime(self.timestamp), + "timestamp": format_utc_datetime(self.timestamp), } @@ -332,7 +323,7 @@ def to_dict(self): "user_id": self.user_id, "rating": self.rating, "reason": self.reason, - "created_at": _format_utc_datetime(self.created_at), + "created_at": format_utc_datetime(self.created_at), } @@ -387,12 +378,14 @@ def to_dict(self): "disabled_tools": self.disabled_tools or [], "created_by": self.created_by, "updated_by": self.updated_by, - "created_at": _format_utc_datetime(self.created_at), - "updated_at": _format_utc_datetime(self.updated_at), + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), } def to_mcp_config(self) -> dict: """转换为 MCP 配置格式(用于加载到 MCP_SERVERS 缓存)""" + import json + config = { "transport": self.transport, } @@ -400,10 +393,24 @@ def to_mcp_config(self) -> dict: config["url"] = self.url if self.command: config["command"] = self.command - if self.args: - config["args"] = self.args - if self.headers: - config["headers"] = self.headers + # args 
只用于 stdio 传输类型,必须是列表 + if self.transport == "stdio" and self.args: + if isinstance(self.args, list): + config["args"] = self.args + elif isinstance(self.args, str): + try: + config["args"] = json.loads(self.args) + except json.JSONDecodeError: + pass + # headers 只用于 sse/streamable_http 传输类型 + if self.transport in ("sse", "streamable_http") and self.headers: + if isinstance(self.headers, dict): + config["headers"] = self.headers + elif isinstance(self.headers, str): + try: + config["headers"] = json.loads(self.headers) + except json.JSONDecodeError: + pass if self.timeout is not None: config["timeout"] = self.timeout if self.sse_read_timeout is not None: diff --git a/src/storage/postgres/manager.py b/src/storage/postgres/manager.py new file mode 100644 index 000000000..021394c9e --- /dev/null +++ b/src/storage/postgres/manager.py @@ -0,0 +1,226 @@ +"""PostgreSQL 数据库管理器 - 支持知识库和业务数据""" + +import json +import os +from contextlib import asynccontextmanager + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import declarative_base + +from server.utils.singleton import SingletonMeta +from src.storage.postgres.models_business import Base as BusinessBase +from src.storage.postgres.models_knowledge import Base as KnowledgeBase +from src.utils import logger + +# 合并两个 Base +CombinedBase = declarative_base() + +# 继承所有表 +for module in [KnowledgeBase, BusinessBase]: + for table_name in dir(module): + table = getattr(module, table_name) + if isinstance(table, type) and hasattr(table, "__tablename__"): + setattr(CombinedBase, table_name, table) + + +class PostgresManager(metaclass=SingletonMeta): + """PostgreSQL 数据库管理器 - 支持知识库和业务数据""" + + # 知识库 PostgreSQL URL 环境变量名 + KB_DATABASE_URL_ENV = "POSTGRES_URL" + + def __init__(self): + self.async_engine = None + self.AsyncSession = None + self._initialized = False + + def initialize(self): + """初始化数据库连接""" + if self._initialized: + return + + 
db_url = os.getenv(self.KB_DATABASE_URL_ENV) + if not db_url: + logger.error( + f"环境变量 {self.KB_DATABASE_URL_ENV} 未设置," + "请在 docker-compose.yml 或 .env 中配置 PostgreSQL 连接字符串" + ) + return + + try: + # 创建异步 SQLAlchemy 引擎 + self.async_engine = create_async_engine( + db_url, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), + json_deserializer=json.loads, + pool_pre_ping=True, + pool_recycle=1800, + ) + + # 创建异步会话工厂 + self.AsyncSession = async_sessionmaker( + bind=self.async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + self._initialized = True + logger.info(f"PostgreSQL manager initialized for knowledge base: {db_url.split('@')[0]}://***") + except Exception as e: + logger.error(f"Failed to initialize PostgreSQL manager: {e}") + # 不抛出异常,允许应用启动,但在使用时会报错 + + def _check_initialized(self): + """检查是否已初始化""" + if not self._initialized: + raise RuntimeError("PostgreSQL manager not initialized. Please check configuration.") + + async def create_tables(self): + """创建所有表(知识库和业务表)""" + self._check_initialized() + async with self.async_engine.begin() as conn: + await conn.run_sync(KnowledgeBase.metadata.create_all) + await conn.run_sync(BusinessBase.metadata.create_all) + logger.info("PostgreSQL tables created/checked (knowledge + business)") + + async def create_business_tables(self): + """创建所有业务数据表""" + self._check_initialized() + async with self.async_engine.begin() as conn: + await conn.run_sync(BusinessBase.metadata.create_all) + logger.info("PostgreSQL business tables created/checked") + + async def drop_tables(self): + """删除所有表(慎用!)""" + self._check_initialized() + async with self.async_engine.begin() as conn: + await conn.run_sync(BusinessBase.metadata.drop_all) + await conn.run_sync(KnowledgeBase.metadata.drop_all) + logger.info("PostgreSQL tables dropped") + + async def ensure_knowledge_schema(self): + """确保知识库 schema 包含所有必要字段""" + self._check_initialized() + stmts = [ + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS 
embed_info JSONB", + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS llm_info JSONB", + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS query_params JSONB", + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS additional_params JSONB", + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS share_config JSONB", + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS mindmap JSONB", + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS sample_questions JSONB", + "ALTER TABLE IF EXISTS knowledge_bases ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS parent_id VARCHAR(64)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS original_filename VARCHAR(512)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS file_type VARCHAR(64)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS path VARCHAR(1024)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS minio_url VARCHAR(1024)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS markdown_file VARCHAR(1024)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS status VARCHAR(32)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS content_hash VARCHAR(128)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS file_size BIGINT", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS content_type VARCHAR(64)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS processing_params JSONB", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS is_folder BOOLEAN", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS error_message TEXT", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS created_by VARCHAR(64)", + "ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS updated_by VARCHAR(64)", + 
"ALTER TABLE IF EXISTS knowledge_files ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ", + "ALTER TABLE IF EXISTS evaluation_benchmarks ADD COLUMN IF NOT EXISTS data_file_path VARCHAR(1024)", + "ALTER TABLE IF EXISTS evaluation_benchmarks ADD COLUMN IF NOT EXISTS created_by VARCHAR(64)", + "ALTER TABLE IF EXISTS evaluation_benchmarks ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ", + "ALTER TABLE IF EXISTS evaluation_results ADD COLUMN IF NOT EXISTS metrics JSONB", + "ALTER TABLE IF EXISTS evaluation_results ADD COLUMN IF NOT EXISTS overall_score DOUBLE PRECISION", + "ALTER TABLE IF EXISTS evaluation_results ADD COLUMN IF NOT EXISTS total_questions INTEGER", + "ALTER TABLE IF EXISTS evaluation_results ADD COLUMN IF NOT EXISTS completed_questions INTEGER", + "ALTER TABLE IF EXISTS evaluation_results ADD COLUMN IF NOT EXISTS started_at TIMESTAMPTZ", + "ALTER TABLE IF EXISTS evaluation_results ADD COLUMN IF NOT EXISTS completed_at TIMESTAMPTZ", + "ALTER TABLE IF EXISTS evaluation_results ADD COLUMN IF NOT EXISTS created_by VARCHAR(64)", + "ALTER TABLE IF EXISTS evaluation_result_details ADD COLUMN IF NOT EXISTS gold_chunk_ids JSONB", + "ALTER TABLE IF EXISTS evaluation_result_details ADD COLUMN IF NOT EXISTS gold_answer TEXT", + "ALTER TABLE IF EXISTS evaluation_result_details ADD COLUMN IF NOT EXISTS generated_answer TEXT", + "ALTER TABLE IF EXISTS evaluation_result_details ADD COLUMN IF NOT EXISTS retrieved_chunks JSONB", + "ALTER TABLE IF EXISTS evaluation_result_details ADD COLUMN IF NOT EXISTS metrics JSONB", + "CREATE INDEX IF NOT EXISTS idx_kb_type ON knowledge_bases(kb_type)", + "CREATE INDEX IF NOT EXISTS idx_kb_name ON knowledge_bases(name)", + "CREATE INDEX IF NOT EXISTS idx_kf_db_id ON knowledge_files(db_id)", + "CREATE INDEX IF NOT EXISTS idx_kf_parent ON knowledge_files(parent_id)", + "CREATE INDEX IF NOT EXISTS idx_kf_status ON knowledge_files(status)", + "CREATE INDEX IF NOT EXISTS idx_kf_hash ON knowledge_files(content_hash)", + "CREATE INDEX IF NOT 
EXISTS idx_eb_db_id ON evaluation_benchmarks(db_id)", + "CREATE INDEX IF NOT EXISTS idx_er_db_id ON evaluation_results(db_id)", + "CREATE INDEX IF NOT EXISTS idx_er_status ON evaluation_results(status)", + "CREATE INDEX IF NOT EXISTS idx_er_started ON evaluation_results(started_at DESC)", + "CREATE INDEX IF NOT EXISTS idx_erd_task ON evaluation_result_details(task_id)", + ] + + async with self.async_engine.begin() as conn: + for stmt in stmts: + await conn.execute(text(stmt)) + + @property + def is_postgresql(self) -> bool: + """检查是否是 PostgreSQL 数据库""" + if not self._initialized: + return False + return self.async_engine.dialect.name == "postgresql" + + async def get_async_session(self) -> AsyncSession: + """获取异步数据库会话""" + self._check_initialized() + return self.AsyncSession() + + @asynccontextmanager + async def get_async_session_context(self): + """获取异步数据库会话的上下文管理器""" + self._check_initialized() + session = self.AsyncSession() + try: + yield session + await session.commit() + except Exception as e: + await session.rollback() + logger.error(f"PostgreSQL async operation failed: {e}") + raise + finally: + await session.close() + + async def close(self): + """关闭引擎""" + if self.async_engine: + await self.async_engine.dispose() + + async def async_check_first_run(self): + """检查是否首次运行(异步版本)- 检查用户表是否有数据""" + from sqlalchemy import func, select + + self._check_initialized() + async with self.get_async_session_context() as session: + from src.storage.postgres.models_business import User + + result = await session.execute(select(func.count(User.id))) + count = result.scalar() + return count == 0 + + async def execute(self, statement): + """直接执行 SQL 语句(用于迁移脚本)""" + self._check_initialized() + async with self.get_async_session_context() as session: + return await session.execute(statement) + + async def add(self, instance): + """添加实例到会话(用于迁移脚本)""" + self._check_initialized() + async with self.get_async_session_context() as session: + session.add(instance) + + async def 
commit(self): + """提交当前会话""" + self._check_initialized() + async with self.get_async_session_context(): + pass # commit is automatic in context manager + + +# 创建全局 PostgreSQL 管理器实例 +pg_manager = PostgresManager() diff --git a/src/storage/postgres/models_business.py b/src/storage/postgres/models_business.py new file mode 100644 index 000000000..fd5c9ebe9 --- /dev/null +++ b/src/storage/postgres/models_business.py @@ -0,0 +1,501 @@ +"""PostgreSQL 业务数据模型 - 用户、部门、对话等相关表""" + +from typing import Any + +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Index, + Integer, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +from src.utils.datetime_utils import format_utc_datetime, utc_now_naive + +Base = declarative_base() + + +class Department(Base): + """部门模型""" + + __tablename__ = "departments" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(50), nullable=False, unique=True, index=True) + description = Column(String(255), nullable=True) + created_at = Column(DateTime, default=utc_now_naive) + + # 关联关系 + users = relationship("User", back_populates="department", cascade="all, delete-orphan") + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "created_at": format_utc_datetime(self.created_at), + } + + +class User(Base): + """用户模型""" + + __tablename__ = "users" + + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(String, nullable=False, unique=True, index=True) # 显示名称 + user_id = Column(String, nullable=False, unique=True, index=True) # 登录ID + phone_number = Column(String, nullable=True, unique=True, index=True) # 手机号 + avatar = Column(String, nullable=True) # 头像URL + password_hash = Column(String, nullable=False) + role = Column(String, nullable=False, default="user") # 角色: superadmin, admin, user + 
department_id = Column(Integer, ForeignKey("departments.id"), nullable=True) # 部门ID + created_at = Column(DateTime, default=utc_now_naive) + last_login = Column(DateTime, nullable=True) + + # 登录失败限制相关字段 + login_failed_count = Column(Integer, nullable=False, default=0) # 登录失败次数 + last_failed_login = Column(DateTime, nullable=True) # 最后一次登录失败时间 + login_locked_until = Column(DateTime, nullable=True) # 锁定到什么时候 + + # 软删除相关字段 + is_deleted = Column(Integer, nullable=False, default=0, index=True) # 是否已删除:0=否,1=是 + deleted_at = Column(DateTime, nullable=True) # 删除时间 + + # 关联操作日志 + operation_logs = relationship("OperationLog", back_populates="user", cascade="all, delete-orphan") + + # 关联部门 + department = relationship("Department", back_populates="users") + + def to_dict(self, include_password: bool = False) -> dict[str, Any]: + result = { + "id": self.id, + "username": self.username, + "user_id": self.user_id, + "phone_number": self.phone_number, + "avatar": self.avatar, + "role": self.role, + "department_id": self.department_id, + "created_at": format_utc_datetime(self.created_at), + "last_login": format_utc_datetime(self.last_login), + "login_failed_count": self.login_failed_count, + "last_failed_login": format_utc_datetime(self.last_failed_login), + "login_locked_until": format_utc_datetime(self.login_locked_until), + "is_deleted": self.is_deleted, + "deleted_at": format_utc_datetime(self.deleted_at), + } + if include_password: + result["password_hash"] = self.password_hash + return result + + def is_login_locked(self) -> bool: + """检查用户是否处于登录锁定状态""" + if self.login_locked_until is None: + return False + return utc_now_naive() < self.login_locked_until + + def get_remaining_lock_time(self) -> int: + """获取剩余锁定时间(秒)""" + if self.login_locked_until is None: + return 0 + remaining = int((self.login_locked_until - utc_now_naive()).total_seconds()) + return max(0, remaining) + + def reset_failed_login(self): + """重置登录失败相关字段""" + self.login_failed_count = 0 + 
self.last_failed_login = None + self.login_locked_until = None + + +class AgentConfig(Base): + """智能体配置(按部门共享,多份可切换)""" + + __tablename__ = "agent_configs" + + id = Column(Integer, primary_key=True, autoincrement=True) + department_id = Column(Integer, ForeignKey("departments.id"), nullable=False, index=True) + agent_id = Column(String(64), nullable=False, index=True) + + name = Column(String(100), nullable=False) + description = Column(String(255), nullable=True) + icon = Column(String(255), nullable=True) + + pics = Column(JSON, nullable=False, default=list) + examples = Column(JSON, nullable=False, default=list) + config_json = Column(JSON, nullable=False, default=dict) + + is_default = Column(Boolean, nullable=False, default=False, index=True) + + created_by = Column(String(64), nullable=True) + updated_by = Column(String(64), nullable=True) + created_at = Column(DateTime, default=utc_now_naive) + updated_at = Column(DateTime, default=utc_now_naive, onupdate=utc_now_naive) + + __table_args__ = ( + UniqueConstraint("department_id", "agent_id", "name", name="uq_agent_configs_department_agent_name"), + Index( + "uq_agent_configs_department_agent_default", + "department_id", + "agent_id", + unique=True, + postgresql_where=is_default.is_(True), + ), + ) + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "department_id": self.department_id, + "agent_id": self.agent_id, + "name": self.name, + "description": self.description, + "icon": self.icon, + "pics": self.pics or [], + "examples": self.examples or [], + "config_json": self.config_json or {}, + "is_default": bool(self.is_default), + "created_by": self.created_by, + "updated_by": self.updated_by, + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), + } + + +class Conversation(Base): + """Conversation table - 对话表""" + + __tablename__ = "conversations" + + id = Column(Integer, primary_key=True, autoincrement=True, comment="Primary key") + 
thread_id = Column(String(64), unique=True, index=True, nullable=False, comment="Thread ID (UUID)") + user_id = Column(String(64), index=True, nullable=False, comment="User ID") + agent_id = Column(String(64), index=True, nullable=False, comment="Agent ID") + title = Column(String(255), nullable=True, comment="Conversation title") + status = Column(String(20), default="active", comment="Status: active/archived/deleted") + created_at = Column(DateTime, default=utc_now_naive, comment="Creation time") + updated_at = Column(DateTime, default=utc_now_naive, onupdate=utc_now_naive, comment="Update time") + extra_metadata = Column(JSON, nullable=True, comment="Additional metadata") + + # Relationships + messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") + stats = relationship( + "ConversationStats", back_populates="conversation", uselist=False, cascade="all, delete-orphan" + ) + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "thread_id": self.thread_id, + "user_id": self.user_id, + "agent_id": self.agent_id, + "title": self.title, + "status": self.status, + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), + "metadata": self.extra_metadata or {}, + } + + +class Message(Base): + """Message table - 消息表""" + + __tablename__ = "messages" + + id = Column(Integer, primary_key=True, autoincrement=True, comment="Primary key") + conversation_id = Column( + Integer, ForeignKey("conversations.id"), nullable=False, index=True, comment="Conversation ID" + ) + role = Column(String(20), nullable=False, comment="Message role: user/assistant/system/tool") + content = Column(Text, nullable=False, comment="Message content") + message_type = Column(String(30), default="text", comment="Message type: text/tool_call/tool_result") + created_at = Column(DateTime, default=utc_now_naive, comment="Creation time") + token_count = Column(Integer, nullable=True, comment="Token 
count (optional)") + extra_metadata = Column(JSON, nullable=True, comment="Additional metadata (complete message dump)") + image_content = Column(Text, nullable=True, comment="Base64 encoded image content for multimodal messages") + + # Relationships + conversation = relationship("Conversation", back_populates="messages") + tool_calls = relationship("ToolCall", back_populates="message", cascade="all, delete-orphan") + feedbacks = relationship("MessageFeedback", back_populates="message", cascade="all, delete-orphan") + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "conversation_id": self.conversation_id, + "role": self.role, + "content": self.content, + "message_type": self.message_type, + "created_at": format_utc_datetime(self.created_at), + "token_count": self.token_count, + "metadata": self.extra_metadata or {}, + "image_content": self.image_content, + "tool_calls": [tc.to_dict() for tc in self.tool_calls] if self.tool_calls else [], + } + + def to_simple_dict(self) -> dict[str, Any]: + return { + "role": self.role, + "content": self.content, + } + + +class ToolCall(Base): + """ToolCall table - 工具调用表""" + + __tablename__ = "tool_calls" + + id = Column(Integer, primary_key=True, autoincrement=True, comment="Primary key") + message_id = Column(Integer, ForeignKey("messages.id"), nullable=False, index=True, comment="Message ID") + langgraph_tool_call_id = Column(String(100), nullable=True, index=True, comment="LangGraph tool_call_id") + tool_name = Column(String(100), nullable=False, comment="Tool name") + tool_input = Column(JSON, nullable=True, comment="Tool input parameters") + tool_output = Column(Text, nullable=True, comment="Tool execution result") + status = Column(String(20), default="pending", comment="Status: pending/success/error") + error_message = Column(Text, nullable=True, comment="Error message if failed") + created_at = Column(DateTime, default=utc_now_naive, comment="Creation time") + + # Relationships + message = 
relationship("Message", back_populates="tool_calls") + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "message_id": self.message_id, + "langgraph_tool_call_id": self.langgraph_tool_call_id, + "tool_name": self.tool_name, + "tool_input": self.tool_input or {}, + "tool_output": self.tool_output, + "status": self.status, + "error_message": self.error_message, + "created_at": format_utc_datetime(self.created_at), + } + + +class ConversationStats(Base): + """ConversationStats table - 对话统计表""" + + __tablename__ = "conversation_stats" + + id = Column(Integer, primary_key=True, autoincrement=True, comment="Primary key") + conversation_id = Column( + Integer, ForeignKey("conversations.id"), unique=True, nullable=False, comment="Conversation ID" + ) + message_count = Column(Integer, default=0, comment="Total message count") + total_tokens = Column(Integer, default=0, comment="Total tokens used") + model_used = Column(String(100), nullable=True, comment="Model used") + user_feedback = Column(JSON, nullable=True, comment="User feedback") + created_at = Column(DateTime, default=utc_now_naive, comment="Creation time") + updated_at = Column(DateTime, default=utc_now_naive, onupdate=utc_now_naive, comment="Update time") + + # Relationships + conversation = relationship("Conversation", back_populates="stats") + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "conversation_id": self.conversation_id, + "message_count": self.message_count, + "total_tokens": self.total_tokens, + "model_used": self.model_used, + "user_feedback": self.user_feedback or {}, + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), + } + + +class OperationLog(Base): + """操作日志模型""" + + __tablename__ = "operation_logs" + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + operation = Column(String, nullable=False) + details = Column(Text, 
nullable=True) + ip_address = Column(String, nullable=True) + timestamp = Column(DateTime, default=utc_now_naive) + + # 关联用户 + user = relationship("User", back_populates="operation_logs") + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "user_id": self.user_id, + "operation": self.operation, + "details": self.details, + "ip_address": self.ip_address, + "timestamp": format_utc_datetime(self.timestamp), + } + + +class MessageFeedback(Base): + """Message feedback table - 消息反馈表""" + + __tablename__ = "message_feedbacks" + + id = Column(Integer, primary_key=True, autoincrement=True, comment="Primary key") + message_id = Column( + Integer, ForeignKey("messages.id"), nullable=False, index=True, comment="Message ID being rated" + ) + user_id = Column(String(64), nullable=False, index=True, comment="User ID who provided feedback") + rating = Column(String(10), nullable=False, comment="Feedback rating: like or dislike") + reason = Column(Text, nullable=True, comment="Optional reason for dislike feedback") + created_at = Column(DateTime, default=utc_now_naive, comment="Feedback creation time") + + # Relationships + message = relationship("Message", back_populates="feedbacks") + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "message_id": self.message_id, + "user_id": self.user_id, + "rating": self.rating, + "reason": self.reason, + "created_at": format_utc_datetime(self.created_at), + } + + +class MCPServer(Base): + """MCP 服务器配置模型""" + + __tablename__ = "mcp_servers" + + # 核心字段 - name 作为主键 + name = Column(String(100), primary_key=True, comment="服务器名称(唯一标识)") + description = Column(String(500), nullable=True, comment="描述") + + # 连接配置 + transport = Column(String(20), nullable=False, comment="传输类型:sse/streamable_http/stdio") + url = Column(String(500), nullable=True, comment="服务器 URL(sse/streamable_http)") + command = Column(String(500), nullable=True, comment="命令(stdio)") + args = Column(JSON, nullable=True, comment="命令参数数组(stdio)") + 
headers = Column(JSON, nullable=True, comment="HTTP 请求头") + timeout = Column(Integer, nullable=True, comment="HTTP 超时时间(秒)") + sse_read_timeout = Column(Integer, nullable=True, comment="SSE 读取超时(秒)") + + # UI 增强字段 + tags = Column(JSON, nullable=True, comment="标签数组") + icon = Column(String(50), nullable=True, comment="图标(emoji)") + + # 状态字段 + enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1=是,0=否") + disabled_tools = Column(JSON, nullable=True, comment="禁用的工具名称列表") + + # 用户追踪 + created_by = Column(String(100), nullable=False, comment="创建人用户名") + updated_by = Column(String(100), nullable=False, comment="修改人用户名") + + # 时间戳 + created_at = Column(DateTime, default=utc_now_naive, comment="创建时间") + updated_at = Column(DateTime, default=utc_now_naive, onupdate=utc_now_naive, comment="更新时间") + + def to_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "transport": self.transport, + "url": self.url, + "command": self.command, + "args": self.args or [], + "headers": self.headers or {}, + "timeout": self.timeout, + "sse_read_timeout": self.sse_read_timeout, + "tags": self.tags or [], + "icon": self.icon, + "enabled": bool(self.enabled), + "disabled_tools": self.disabled_tools or [], + "created_by": self.created_by, + "updated_by": self.updated_by, + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), + } + + def to_mcp_config(self) -> dict[str, Any]: + """转换为 MCP 配置格式(用于加载到 MCP_SERVERS 缓存)""" + import json + + config = {"transport": self.transport} + if self.url: + config["url"] = self.url + if self.command: + config["command"] = self.command + # args 只用于 stdio 传输类型,必须是列表 + if self.transport == "stdio" and self.args: + if isinstance(self.args, list): + config["args"] = self.args + elif isinstance(self.args, str): + try: + config["args"] = json.loads(self.args) + except json.JSONDecodeError: + pass + # headers 只用于 sse/streamable_http 传输类型 + if 
self.transport in ("sse", "streamable_http") and self.headers: + if isinstance(self.headers, dict): + config["headers"] = self.headers + elif isinstance(self.headers, str): + try: + config["headers"] = json.loads(self.headers) + except json.JSONDecodeError: + pass + if self.timeout is not None: + config["timeout"] = self.timeout + if self.sse_read_timeout is not None: + config["sse_read_timeout"] = self.sse_read_timeout + if self.disabled_tools: + config["disabled_tools"] = self.disabled_tools + return config + + +class TaskRecord(Base): + __tablename__ = "tasks" + + id = Column(String(32), primary_key=True) + name = Column(String(255), nullable=False) + type = Column(String(64), nullable=False, index=True) + status = Column(String(32), nullable=False, default="pending", index=True) + progress = Column(Float, nullable=False, default=0.0) + message = Column(Text, nullable=False, default="") + payload = Column(JSON, nullable=True) + result = Column(JSON, nullable=True) + error = Column(Text, nullable=True) + cancel_requested = Column(Integer, nullable=False, default=0) + created_at = Column(DateTime, default=utc_now_naive, index=True) + updated_at = Column(DateTime, default=utc_now_naive, onupdate=utc_now_naive) + started_at = Column(DateTime, nullable=True) + completed_at = Column(DateTime, nullable=True) + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "type": self.type, + "status": self.status, + "progress": self.progress, + "message": self.message, + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), + "started_at": format_utc_datetime(self.started_at), + "completed_at": format_utc_datetime(self.completed_at), + "payload": self.payload or {}, + "result": self.result, + "error": self.error, + "cancel_requested": bool(self.cancel_requested), + } + + def to_summary_dict(self) -> dict[str, Any]: + data = self.to_dict() + data.pop("payload", None) + data.pop("result", 
None) + return data diff --git a/src/storage/postgres/models_knowledge.py b/src/storage/postgres/models_knowledge.py new file mode 100644 index 000000000..1fed512a3 --- /dev/null +++ b/src/storage/postgres/models_knowledge.py @@ -0,0 +1,139 @@ +"""PostgreSQL 知识库模型 - KnowledgeBase、KnowledgeFile、评估相关表""" + +from sqlalchemy import ( + JSON, + BigInteger, + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Integer, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.dialects.postgresql import JSONB + +from src.storage.postgres.models_business import Base +from src.utils.datetime_utils import utc_now_naive + +JSON_VALUE = JSON().with_variant(JSONB, "postgresql") + + +class KnowledgeBase(Base): + """知识库模型""" + + __tablename__ = "knowledge_bases" + __table_args__ = (UniqueConstraint("db_id", name="uq_knowledge_bases_db_id"),) + + id = Column(Integer, primary_key=True, autoincrement=True) + db_id = Column(String(64), unique=True, nullable=False, index=True) + name = Column(String(255), nullable=False, index=True) + description = Column(Text) + kb_type = Column(String(32), nullable=False, index=True) + embed_info = Column(JSON_VALUE) + llm_info = Column(JSON_VALUE) + query_params = Column(JSON_VALUE) + additional_params = Column(JSON_VALUE) + share_config = Column(JSON_VALUE) + mindmap = Column(JSON_VALUE) + sample_questions = Column(JSON_VALUE) + created_at = Column(DateTime(timezone=True), default=utc_now_naive) + updated_at = Column(DateTime(timezone=True), default=utc_now_naive, onupdate=utc_now_naive) + + +class KnowledgeFile(Base): + """知识文件模型""" + + __tablename__ = "knowledge_files" + __table_args__ = (UniqueConstraint("file_id", name="uq_knowledge_files_file_id"),) + + id = Column(Integer, primary_key=True, autoincrement=True) + file_id = Column(String(64), unique=True, nullable=False, index=True) + db_id = Column(String(64), ForeignKey("knowledge_bases.db_id", ondelete="CASCADE"), nullable=False, index=True) + parent_id = Column(String(64), 
ForeignKey("knowledge_files.file_id", ondelete="SET NULL"), index=True) + filename = Column(String(512), nullable=False) + original_filename = Column(String(512)) + file_type = Column(String(64)) + path = Column(String(1024)) + minio_url = Column(String(1024)) + markdown_file = Column(String(1024)) + status = Column(String(32), default="uploaded", index=True) + content_hash = Column(String(128), index=True) + file_size = Column(BigInteger) + content_type = Column(String(64)) + processing_params = Column(JSON_VALUE) + is_folder = Column(Boolean, default=False) + error_message = Column(Text) + created_by = Column(String(64)) + updated_by = Column(String(64)) + created_at = Column(DateTime(timezone=True), default=utc_now_naive) + updated_at = Column(DateTime(timezone=True), default=utc_now_naive, onupdate=utc_now_naive) + + +class EvaluationBenchmark(Base): + """评估基准模型""" + + __tablename__ = "evaluation_benchmarks" + __table_args__ = (UniqueConstraint("benchmark_id", name="uq_evaluation_benchmarks_benchmark_id"),) + + id = Column(Integer, primary_key=True, autoincrement=True) + benchmark_id = Column(String(64), unique=True, nullable=False, index=True) + db_id = Column(String(64), ForeignKey("knowledge_bases.db_id", ondelete="CASCADE"), nullable=False, index=True) + name = Column(String(255), nullable=False) + description = Column(Text) + question_count = Column(Integer, default=0) + has_gold_chunks = Column(Boolean, default=False) + has_gold_answers = Column(Boolean, default=False) + data_file_path = Column(String(1024)) + created_by = Column(String(64)) + created_at = Column(DateTime(timezone=True), default=utc_now_naive) + updated_at = Column(DateTime(timezone=True), default=utc_now_naive, onupdate=utc_now_naive) + + +class EvaluationResult(Base): + """评估结果模型""" + + __tablename__ = "evaluation_results" + __table_args__ = (UniqueConstraint("task_id", name="uq_evaluation_results_task_id"),) + + id = Column(Integer, primary_key=True, autoincrement=True) + task_id = 
Column(String(64), unique=True, nullable=False, index=True) + db_id = Column(String(64), ForeignKey("knowledge_bases.db_id", ondelete="CASCADE"), nullable=False, index=True) + benchmark_id = Column( + String(64), + ForeignKey("evaluation_benchmarks.benchmark_id", ondelete="SET NULL"), + index=True, + ) + status = Column(String(32), default="running", index=True) + retrieval_config = Column(JSON_VALUE) + metrics = Column(JSON_VALUE) + overall_score = Column(Float) + total_questions = Column(Integer, default=0) + completed_questions = Column(Integer, default=0) + started_at = Column(DateTime(timezone=True), default=utc_now_naive, index=True) + completed_at = Column(DateTime(timezone=True)) + created_by = Column(String(64)) + + +class EvaluationResultDetail(Base): + """评估结果详情模型""" + + __tablename__ = "evaluation_result_details" + __table_args__ = (UniqueConstraint("task_id", "query_index", name="uq_evaluation_result_details_task_query"),) + + id = Column(Integer, primary_key=True, autoincrement=True) + task_id = Column( + String(64), + ForeignKey("evaluation_results.task_id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + query_index = Column(Integer, nullable=False) + query_text = Column(Text, nullable=False) + gold_chunk_ids = Column(JSON_VALUE) + gold_answer = Column(Text) + generated_answer = Column(Text) + retrieved_chunks = Column(JSON_VALUE) + metrics = Column(JSON_VALUE) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 7c9016ae3..3c2b8ab58 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,6 +1,7 @@ import hashlib import os import time +import uuid from src.utils.logging_config import logger @@ -42,10 +43,11 @@ def hashstr(input_string, length=None, with_salt=False, salt=None): if with_salt: if not salt: - salt = str(time.time()) + # 使用时间戳+随机数的组合作为salt,确保唯一性 + salt = f"{time.time()}_{uuid.uuid4().hex[:8]}" encoded_string = (encoded_string.decode("utf-8") + salt).encode("utf-8") - hash = 
hashlib.md5(encoded_string).hexdigest() + hash = hashlib.sha256(encoded_string).hexdigest() if length: return hash[:length] return hash diff --git a/src/utils/datetime_utils.py b/src/utils/datetime_utils.py index 23361a4f8..084596104 100644 --- a/src/utils/datetime_utils.py +++ b/src/utils/datetime_utils.py @@ -22,6 +22,11 @@ def utc_now() -> dt.datetime: return dt.datetime.now(UTC) +def utc_now_naive() -> dt.datetime: + """Return the current UTC time as a naive datetime (for legacy DB fields).""" + return dt.datetime.now(UTC).replace(tzinfo=None) + + def shanghai_now() -> dt.datetime: """Return the current Asia/Shanghai time as an aware datetime.""" return utc_now().astimezone(SHANGHAI_TZ) @@ -110,10 +115,23 @@ def normalize_iterable_to_utc(values: Iterable[dt.datetime | None]) -> list[dt.d return [coerce_datetime(item) if isinstance(item, dt.datetime) else None for item in values] +def format_utc_datetime(value: dt.datetime | None) -> str | None: + """ + Format a datetime to UTC ISO 8601 string, handling naive datetimes. + + Returns None for None input. + Naive datetimes are assumed to be in UTC (legacy behavior). 
+ """ + if value is None: + return None + return utc_isoformat(value) + + __all__ = [ "UTC", "SHANGHAI_TZ", "utc_now", + "utc_now_naive", "shanghai_now", "ensure_utc", "ensure_shanghai", @@ -122,4 +140,5 @@ def normalize_iterable_to_utc(values: Iterable[dt.datetime | None]) -> list[dt.d "coerce_datetime", "coerce_any_to_utc_datetime", "normalize_iterable_to_utc", + "format_utc_datetime", ] diff --git a/test/.env.test.example b/test/.env.test.example index f69467e29..66d0b8d6d 100644 --- a/test/.env.test.example +++ b/test/.env.test.example @@ -8,8 +8,6 @@ TEST_BASE_URL=http://localhost:5050 TEST_USERNAME= TEST_PASSWORD= -# 数据库连接(如果需要) -# DATABASE_URL=sqlite:///test.db # 其他测试相关配置 # ENABLE_DEBUG=true \ No newline at end of file diff --git a/test/api/test_dashboard_router.py b/test/api/test_dashboard_router.py index d05578629..83d7107b3 100644 --- a/test/api/test_dashboard_router.py +++ b/test/api/test_dashboard_router.py @@ -23,3 +23,35 @@ async def test_admin_can_fetch_conversations(test_client, admin_headers): response = await test_client.get("/api/dashboard/conversations", headers=admin_headers) assert response.status_code == 200, response.text assert isinstance(response.json(), list) + + +async def test_admin_can_fetch_stats(test_client, admin_headers): + """Test that all stats endpoints return 200 and don't crash on DB queries.""" + + # Test call timeseries stats for all types + types = ["models", "agents", "tokens", "tools"] + for stats_type in types: + response = await test_client.get( + f"/api/dashboard/stats/calls/timeseries?type={stats_type}&time_range=14days", headers=admin_headers + ) + assert response.status_code == 200, f"{stats_type} stats failed: {response.text}" + data = response.json() + assert "data" in data + assert "categories" in data + + # Test user activity stats + response = await test_client.get("/api/dashboard/stats/users", headers=admin_headers) + assert response.status_code == 200, f"user stats failed: {response.text}" + assert 
"total_users" in response.json() + + # Test tool call stats + response = await test_client.get("/api/dashboard/stats/tools", headers=admin_headers) + assert response.status_code == 200, f"tool stats failed: {response.text}" + assert "total_calls" in response.json() + + +async def test_admin_can_fetch_feedbacks(test_client, admin_headers): + """Test that feedback endpoint returns 200 and handles the User join correctly.""" + response = await test_client.get("/api/dashboard/feedbacks", headers=admin_headers) + assert response.status_code == 200, f"feedbacks failed: {response.text}" + assert isinstance(response.json(), list) diff --git a/test/api/test_graph_router_list.py b/test/api/test_graph_router_list.py new file mode 100644 index 000000000..5c70cadc8 --- /dev/null +++ b/test/api/test_graph_router_list.py @@ -0,0 +1,30 @@ +""" +Integration tests for graph router list endpoint. +""" + +from __future__ import annotations + +import pytest + +pytestmark = [pytest.mark.asyncio, pytest.mark.integration] + + +async def test_admin_can_list_graphs(test_client, admin_headers): + """Test that listing graphs returns 200 and a list of graphs.""" + response = await test_client.get("/api/graph/list", headers=admin_headers) + assert response.status_code == 200, f"Failed to list graphs: {response.text}" + data = response.json() + + # Check if response is wrapped + if isinstance(data, dict) and "data" in data: + graphs = data["data"] + else: + graphs = data + + assert isinstance(graphs, list) + # Check structure of returned items if list is not empty + if graphs: + item = graphs[0] + assert "id" in item + assert "name" in item + assert "type" in item diff --git a/test/api/test_knowledge_router.py b/test/api/test_knowledge_router.py index 482b1f901..29fa366fc 100644 --- a/test/api/test_knowledge_router.py +++ b/test/api/test_knowledge_router.py @@ -1,5 +1,5 @@ """ -Integration tests for knowledge router endpoints. +Integration tests for knowledge router and mindmap router endpoints. 
""" from __future__ import annotations @@ -11,6 +11,14 @@ pytestmark = [pytest.mark.asyncio, pytest.mark.integration] +def _assert_forbidden_response(response): + """验证 403 禁止访问响应的格式""" + assert response.status_code == 403 + payload = response.json() + assert "detail" in payload + assert isinstance(payload["detail"], str) + + async def test_admin_can_manage_knowledge_databases(test_client, admin_headers, knowledge_database): db_id = knowledge_database["db_id"] @@ -44,17 +52,20 @@ async def test_knowledge_routes_enforce_permissions(test_client, standard_user, }, headers=standard_user["headers"], ) - assert forbidden_create.status_code == 403 + _assert_forbidden_response(forbidden_create) forbidden_list = await test_client.get("/api/knowledge/databases", headers=standard_user["headers"]) - assert forbidden_list.status_code == 403 + _assert_forbidden_response(forbidden_list) forbidden_get = await test_client.get(f"/api/knowledge/databases/{db_id}", headers=standard_user["headers"]) - assert forbidden_get.status_code == 403 + _assert_forbidden_response(forbidden_get) async def test_admin_can_create_vector_db_with_reranker(test_client, admin_headers): - """测试创建向量库并配置 reranker 参数(通过 query_params.options)""" + """测试创建向量库并配置 reranker 参数(通过 query_params.options) + + 注意:数据库清理由 conftest.py 中的 session fixture 自动处理。 + """ db_name = f"pytest_rerank_{uuid.uuid4().hex[:6]}" payload = { "database_name": db_name, @@ -70,58 +81,260 @@ async def test_admin_can_create_vector_db_with_reranker(test_client, admin_heade db_payload = create_response.json() db_id = db_payload["db_id"] - try: - # 获取查询参数配置 - params_response = await test_client.get(f"/api/knowledge/databases/{db_id}/query-params", headers=admin_headers) - assert params_response.status_code == 200, params_response.text - - params_payload = params_response.json() - options = params_payload.get("params", {}).get("options", []) - option_keys = {option.get("key") for option in options} - - # 验证新的参数名称 - assert "final_top_k" in 
option_keys - assert "use_reranker" in option_keys - assert "recall_top_k" in option_keys - assert "reranker_model" in option_keys - - # 验证参数配置 - final_top_k_option = next((opt for opt in options if opt.get("key") == "final_top_k"), None) - assert final_top_k_option is not None - assert final_top_k_option.get("default") == 10 - - use_reranker_option = next((opt for opt in options if opt.get("key") == "use_reranker"), None) - assert use_reranker_option is not None - assert use_reranker_option.get("default") is False - - # 保存查询参数(模拟前端配置) - update_params = { - "final_top_k": 5, - "use_reranker": True, - "recall_top_k": 20, - } - update_response = await test_client.put( - f"/api/knowledge/databases/{db_id}/query-params", json=update_params, headers=admin_headers - ) - assert update_response.status_code == 200, update_response.text - - # 再次获取参数,验证保存成功 - params_response2 = await test_client.get( - f"/api/knowledge/databases/{db_id}/query-params", headers=admin_headers - ) - assert params_response2.status_code == 200, params_response2.text - - params_payload2 = params_response2.json() - options2 = params_payload2.get("params", {}).get("options", []) - - # 验证保存的值 - final_top_k_option2 = next((opt for opt in options2 if opt.get("key") == "final_top_k"), None) - assert final_top_k_option2 is not None - assert final_top_k_option2.get("default") == 5 # 保存的值 - - use_reranker_option2 = next((opt for opt in options2 if opt.get("key") == "use_reranker"), None) - assert use_reranker_option2 is not None - assert use_reranker_option2.get("default") is True # 保存的值 - - finally: - await test_client.delete(f"/api/knowledge/databases/{db_id}", headers=admin_headers) + # 获取查询参数配置 + params_response = await test_client.get(f"/api/knowledge/databases/{db_id}/query-params", headers=admin_headers) + assert params_response.status_code == 200, params_response.text + + params_payload = params_response.json() + options = params_payload.get("params", {}).get("options", []) + option_keys = 
{option.get("key") for option in options} + + # 验证新的参数名称 + assert "final_top_k" in option_keys + assert "use_reranker" in option_keys + assert "recall_top_k" in option_keys + assert "reranker_model" in option_keys + + # 验证参数配置 + final_top_k_option = next((opt for opt in options if opt.get("key") == "final_top_k"), None) + assert final_top_k_option is not None + assert final_top_k_option.get("default") == 10 + + use_reranker_option = next((opt for opt in options if opt.get("key") == "use_reranker"), None) + assert use_reranker_option is not None + assert use_reranker_option.get("default") is False + + # 保存查询参数(模拟前端配置) + update_params = { + "final_top_k": 5, + "use_reranker": True, + "recall_top_k": 20, + } + update_response = await test_client.put( + f"/api/knowledge/databases/{db_id}/query-params", json=update_params, headers=admin_headers + ) + assert update_response.status_code == 200, update_response.text + + # 再次获取参数,验证保存成功 + params_response2 = await test_client.get(f"/api/knowledge/databases/{db_id}/query-params", headers=admin_headers) + assert params_response2.status_code == 200, params_response2.text + + params_payload2 = params_response2.json() + options2 = params_payload2.get("params", {}).get("options", []) + + # 验证保存的值 + final_top_k_option2 = next((opt for opt in options2 if opt.get("key") == "final_top_k"), None) + assert final_top_k_option2 is not None + assert final_top_k_option2.get("default") == 5 # 保存的值 + + use_reranker_option2 = next((opt for opt in options2 if opt.get("key") == "use_reranker"), None) + assert use_reranker_option2 is not None + assert use_reranker_option2.get("default") is True # 保存的值 + + +# ============================================================================= +# === Mindmap Router Tests === +# ============================================================================= + + +async def test_get_databases_overview(test_client, admin_headers, knowledge_database): + """测试获取所有知识库概览""" + response = await 
test_client.get("/api/mindmap/databases", headers=admin_headers) + assert response.status_code == 200, response.text + payload = response.json() + assert payload["message"] == "success" + assert "databases" in payload + assert "total" in payload + + # 验证知识库在列表中 + db_ids = [db["db_id"] for db in payload["databases"]] + assert knowledge_database["db_id"] in db_ids + + +async def test_get_database_files(test_client, admin_headers, knowledge_database): + """测试获取知识库文件列表""" + db_id = knowledge_database["db_id"] + response = await test_client.get(f"/api/mindmap/databases/{db_id}/files", headers=admin_headers) + assert response.status_code == 200, response.text + payload = response.json() + assert payload["message"] == "success" + assert payload["db_id"] == db_id + assert "files" in payload + assert "total" in payload + assert payload["db_name"] == knowledge_database["name"] + + +async def test_get_database_files_not_found(test_client, admin_headers): + """测试获取不存在的知识库文件列表""" + response = await test_client.get("/api/mindmap/databases/nonexistent_db_id/files", headers=admin_headers) + assert response.status_code == 404 + + +async def test_generate_mindmap_empty_files(test_client, admin_headers, knowledge_database): + """测试空文件列表生成思维导图""" + db_id = knowledge_database["db_id"] + response = await test_client.post( + "/api/mindmap/generate", + json={"db_id": db_id, "file_ids": [], "user_prompt": ""}, + headers=admin_headers, + ) + # 空文件应该返回400错误 + assert response.status_code == 400 + assert "中没有文件" in response.json()["detail"] + + +async def test_get_database_mindmap_not_exists(test_client, admin_headers, knowledge_database): + """测试获取不存在的思维导图""" + db_id = knowledge_database["db_id"] + response = await test_client.get(f"/api/mindmap/database/{db_id}", headers=admin_headers) + assert response.status_code == 200, response.text + payload = response.json() + assert payload["db_id"] == db_id + assert payload["mindmap"] is None # 尚未生成思维导图 + + +async def 
test_generate_and_get_mindmap(test_client, admin_headers, knowledge_database): + """测试生成并获取思维导图 + + 注意:此测试需要知识库中有文件才能完整测试核心功能。 + 由于没有前置的文件上传 fixture,测试会先验证空文件场景(预期400), + 然后使用 xfail 标记等待后续完善。 + """ + db_id = knowledge_database["db_id"] + + # 空文件场景 - 预期返回400错误 + generate_response = await test_client.post( + "/api/mindmap/generate", + json={"db_id": db_id, "file_ids": [], "user_prompt": ""}, + headers=admin_headers, + ) + assert generate_response.status_code == 400 + assert "中没有文件" in generate_response.json()["detail"] + + # 标记此测试需要文件上传支持才能完整执行 + pytest.skip("需要先上传文件才能完整测试思维导图生成功能") + + +# ============================================================================= +# === Knowledge Router Additional Tests === +# ============================================================================= + + +async def test_get_accessible_databases(test_client, admin_headers, knowledge_database): + """测试获取可访问的知识库列表""" + response = await test_client.get("/api/knowledge/databases/accessible", headers=admin_headers) + assert response.status_code == 200, response.text + payload = response.json() + assert "databases" in payload + + # 验证知识库在列表中 + db_ids = [db["db_id"] for db in payload["databases"]] + assert knowledge_database["db_id"] in db_ids + + +async def test_get_knowledge_base_types(test_client, admin_headers): + """测试获取支持的知识库类型""" + response = await test_client.get("/api/knowledge/types", headers=admin_headers) + assert response.status_code == 200, response.text + payload = response.json() + assert payload["message"] == "success" + assert "kb_types" in payload + + +async def test_get_knowledge_base_statistics(test_client, admin_headers): + """测试获取知识库统计信息""" + response = await test_client.get("/api/knowledge/stats", headers=admin_headers) + assert response.status_code == 200, response.text + payload = response.json() + assert payload["message"] == "success" + assert "stats" in payload + + +async def test_get_supported_file_types(test_client, admin_headers): + """测试获取支持的文件类型""" + 
response = await test_client.get("/api/knowledge/files/supported-types", headers=admin_headers) + assert response.status_code == 200, response.text + payload = response.json() + assert payload["message"] == "success" + assert "file_types" in payload + assert isinstance(payload["file_types"], list) + + +async def test_duplicate_database_name(test_client, admin_headers, knowledge_database): + """测试重复创建同名知识库""" + db_name = knowledge_database["name"] + response = await test_client.post( + "/api/knowledge/databases", + json={ + "database_name": db_name, + "description": "Duplicate name test", + "embed_model_name": "siliconflow/BAAI/bge-m3", + "kb_type": "lightrag", + "additional_params": {}, + }, + headers=admin_headers, + ) + assert response.status_code == 409 + assert "已存在" in response.json()["detail"] + + +async def test_create_milvus_knowledge_base(test_client, admin_headers): + """测试创建 Milvus 知识库 + + 注意:数据库清理由 conftest.py 中的 session fixture 自动处理。 + """ + db_name = f"pytest_milvus_{uuid.uuid4().hex[:6]}" + payload = { + "database_name": db_name, + "description": "Pytest Milvus knowledge base", + "embed_model_name": "siliconflow/BAAI/bge-m3", + "kb_type": "milvus", + "additional_params": {}, + } + + create_response = await test_client.post("/api/knowledge/databases", json=payload, headers=admin_headers) + assert create_response.status_code == 200, create_response.text + + db_payload = create_response.json() + assert db_payload["kb_type"] == "milvus" + + +async def test_sample_questions_endpoints(test_client, admin_headers, knowledge_database): + """测试示例问题接口(空文件时预期返回400)""" + db_id = knowledge_database["db_id"] + + # 获取示例问题(空知识库应该返回空列表) + get_response = await test_client.get(f"/api/knowledge/databases/{db_id}/sample-questions", headers=admin_headers) + assert get_response.status_code == 200, get_response.text + get_payload = get_response.json() + assert get_payload["db_id"] == db_id + assert "questions" in get_payload + assert get_payload["count"] == 0 # 空知识库没有问题 + + 
# 生成示例问题(空知识库应该返回400) + generate_response = await test_client.post( + f"/api/knowledge/databases/{db_id}/sample-questions", + json={"count": 5}, + headers=admin_headers, + ) + assert generate_response.status_code == 400 + assert "中没有文件" in generate_response.json()["detail"] + + +async def test_mindmap_permissions(test_client, standard_user, knowledge_database): + """测试思维导图接口的权限控制""" + db_id = knowledge_database["db_id"] + + # 普通用户应该无法访问 + forbidden_list = await test_client.get("/api/mindmap/databases", headers=standard_user["headers"]) + _assert_forbidden_response(forbidden_list) + + forbidden_files = await test_client.get(f"/api/mindmap/databases/{db_id}/files", headers=standard_user["headers"]) + _assert_forbidden_response(forbidden_files) + + forbidden_generate = await test_client.post( + "/api/mindmap/generate", + json={"db_id": db_id, "file_ids": []}, + headers=standard_user["headers"], + ) + _assert_forbidden_response(forbidden_generate) diff --git a/test/compare_kb_metadata_with_db.py b/test/compare_kb_metadata_with_db.py new file mode 100644 index 000000000..31604fdd3 --- /dev/null +++ b/test/compare_kb_metadata_with_db.py @@ -0,0 +1,158 @@ +import asyncio +import glob +import json +import os +import sys +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import func, select + +os.environ.setdefault("YUXI_SKIP_APP_INIT", "1") +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from src.config import config +from src.storage.postgres.manager import pg_manager +from src.storage.postgres.models_knowledge import ( + EvaluationBenchmark, + EvaluationResult, + EvaluationResultDetail, + KnowledgeBase, + KnowledgeFile, +) + + +def _load_json(path: str) -> dict[str, Any]: + if not os.path.exists(path): + return {} + with open(path, encoding="utf-8") as f: + return json.load(f) + + +@dataclass(frozen=True) +class JsonState: + kb_ids: set[str] + file_ids: set[str] + benchmark_ids: set[str] + result_task_ids: set[str] + 
result_detail_count: int + + +@dataclass(frozen=True) +class DbState: + kb_ids: set[str] + file_ids: set[str] + benchmark_ids: set[str] + result_task_ids: set[str] + result_detail_count: int + + +def load_json_state() -> JsonState: + base_dir = os.path.join(config.save_dir, "knowledge_base_data") + global_meta = _load_json(os.path.join(base_dir, "global_metadata.json")).get("databases", {}) or {} + + kb_ids: set[str] = set(global_meta.keys()) + file_ids: set[str] = set() + benchmark_ids: set[str] = set() + result_task_ids: set[str] = set() + result_detail_count = 0 + + kb_type_dirs = [ + p for p in glob.glob(os.path.join(base_dir, "*_data")) if os.path.isdir(p) and os.path.basename(p) != "uploads" + ] + + for kb_dir in kb_type_dirs: + kb_type = os.path.basename(kb_dir)[: -len("_data")] + meta_file = os.path.join(kb_dir, f"metadata_{kb_type}.json") + meta = _load_json(meta_file) + + databases_meta: dict[str, Any] = meta.get("databases", {}) or {} + files_meta: dict[str, Any] = meta.get("files", {}) or {} + benchmarks_meta: dict[str, Any] = meta.get("benchmarks", {}) or {} + + kb_ids.update(databases_meta.keys()) + file_ids.update(files_meta.keys()) + + for _db_id, bmap in benchmarks_meta.items(): + if not isinstance(bmap, dict): + continue + benchmark_ids.update(bmap.keys()) + + for db_id in databases_meta.keys(): + result_dir = os.path.join(kb_dir, db_id, "results") + if not os.path.isdir(result_dir): + continue + for result_path in glob.glob(os.path.join(result_dir, "*.json")): + try: + data = _load_json(result_path) + except Exception: + continue + task_id = data.get("task_id") or os.path.splitext(os.path.basename(result_path))[0] + result_task_ids.add(task_id) + interim = data.get("interim_results") or data.get("results") or [] + result_detail_count += len(interim) + + return JsonState( + kb_ids=kb_ids, + file_ids=file_ids, + benchmark_ids=benchmark_ids, + result_task_ids=result_task_ids, + result_detail_count=result_detail_count, + ) + + +async def 
load_db_state() -> DbState: + async with pg_manager.get_async_session_context() as session: + kb_ids = set((await session.execute(select(KnowledgeBase.db_id))).scalars().all()) + file_ids = set((await session.execute(select(KnowledgeFile.file_id))).scalars().all()) + benchmark_ids = set((await session.execute(select(EvaluationBenchmark.benchmark_id))).scalars().all()) + result_task_ids = set((await session.execute(select(EvaluationResult.task_id))).scalars().all()) + detail_count = (await session.execute(select(func.count(EvaluationResultDetail.id)))).scalar_one() + + return DbState( + kb_ids=kb_ids, + file_ids=file_ids, + benchmark_ids=benchmark_ids, + result_task_ids=result_task_ids, + result_detail_count=int(detail_count or 0), + ) + + +def _diff(name: str, json_set: set[str], db_set: set[str], limit: int = 30) -> list[str]: + missing = sorted(json_set - db_set) + extra = sorted(db_set - json_set) + lines: list[str] = [] + lines.append(f"{name}: json={len(json_set)} db={len(db_set)}") + if missing: + preview = ", ".join(missing[:limit]) + lines.append(f" missing_in_db({len(missing)}): {preview}") + if extra: + preview = ", ".join(extra[:limit]) + lines.append(f" extra_in_db({len(extra)}): {preview}") + return lines + + +async def main() -> None: + engine_url = pg_manager.async_engine.url.render_as_string(hide_password=True) + print(f"db_url={engine_url}") + + json_state = load_json_state() + db_state = await load_db_state() + + for line in _diff("knowledge_bases.db_id", json_state.kb_ids, db_state.kb_ids): + print(line) + for line in _diff("knowledge_files.file_id", json_state.file_ids, db_state.file_ids): + print(line) + for line in _diff("evaluation_benchmarks.benchmark_id", json_state.benchmark_ids, db_state.benchmark_ids): + print(line) + for line in _diff("evaluation_results.task_id", json_state.result_task_ids, db_state.result_task_ids): + print(line) + + print( + "evaluation_result_details: " + f"json_count={json_state.result_detail_count} 
db_count={db_state.result_detail_count}" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/test/conftest.py b/test/conftest.py index a7b41f89d..47c6fa216 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,6 +8,7 @@ import uuid from collections.abc import AsyncGenerator +import anyio import httpx import pytest import pytest_asyncio @@ -78,6 +79,69 @@ def admin_headers(admin_token: str) -> dict[str, str]: return {"Authorization": f"Bearer {admin_token}"} +@pytest.fixture(scope="session", autouse=True) +def cleanup_test_knowledge_databases(): + """ + Best-effort cleanup for leftover test knowledge databases (e.g. pytest_* / py_test*). + """ + + async def run_cleanup() -> None: + global _ADMIN_TOKEN_CACHE + + if not ADMIN_LOGIN or not ADMIN_PASSWORD: + return + + if not _ADMIN_TOKEN_CACHE: + async with httpx.AsyncClient( + base_url=API_BASE_URL, timeout=HTTP_TIMEOUT, follow_redirects=True + ) as bootstrap_client: + response = await bootstrap_client.post( + "/api/auth/token", data={"username": ADMIN_LOGIN, "password": ADMIN_PASSWORD} + ) + if response.status_code != 200: + return + token = response.json().get("access_token") + if not token: + return + _ADMIN_TOKEN_CACHE = token + + headers = {"Authorization": f"Bearer {_ADMIN_TOKEN_CACHE}"} + + async with httpx.AsyncClient(base_url=API_BASE_URL, timeout=HTTP_TIMEOUT, follow_redirects=True) as client: + try: + list_response = await client.get("/api/knowledge/databases", headers=headers) + except Exception as e: + print(f"Warning: Failed to list knowledge databases for cleanup: {e}") + return + + if list_response.status_code != 200: + return + + databases = list_response.json().get("databases", []) + prefixes = ("pytest_", "py_test") + for entry in databases: + name = entry.get("name") or "" + db_id = entry.get("db_id") + if not db_id or not isinstance(name, str) or not name.startswith(prefixes): + continue + try: + delete_response = await client.delete(f"/api/knowledge/databases/{db_id}", 
headers=headers) + if delete_response.status_code not in (200, 404): + print(f"Warning: Failed to cleanup knowledge database {db_id}: {delete_response.text}") + except Exception as e: + print(f"Warning: Exception during cleanup of {db_id}: {e}") + + try: + anyio.run(run_cleanup) + except Exception as e: + print(f"Warning: Exception during session cleanup startup: {e}") + yield + try: + anyio.run(run_cleanup) + except Exception as e: + print(f"Warning: Exception during session cleanup teardown: {e}") + + @pytest_asyncio.fixture(scope="function") async def standard_user(test_client: httpx.AsyncClient, admin_headers: dict[str, str]) -> dict: """ @@ -126,30 +190,49 @@ async def knowledge_database(test_client: httpx.AsyncClient, admin_headers: dict """ Create a temporary knowledge database for tests that need LightRAG metadata. """ - db_name = f"pytest_kb_{uuid.uuid4().hex[:6]}" - create_response = await test_client.post( - "/api/knowledge/databases", - json={ - "database_name": db_name, - "description": "Pytest managed knowledge base", - "embed_model_name": "siliconflow/BAAI/bge-m3", - "kb_type": "lightrag", - "additional_params": {}, - }, - headers=admin_headers, - ) - if create_response.status_code != 200: - pytest.fail( - f"Failed to create knowledge database (status={create_response.status_code}): {create_response.text}" - ) + import time - db_payload = create_response.json() - db_id = db_payload["db_id"] + # 使用UUID作为数据库名称的一部分,确保唯一性 + unique_id = uuid.uuid4().hex + timestamp = int(time.time() * 1000000) # 微秒级时间戳 + db_name = f"pytest_kb_{timestamp}_{unique_id}" + db_id = None try: - yield db_payload + create_response = await test_client.post( + "/api/knowledge/databases", + json={ + "database_name": db_name, + "description": "Pytest managed knowledge base", + "embed_model_name": "siliconflow/BAAI/bge-m3", + "kb_type": "lightrag", + "additional_params": {}, + }, + headers=admin_headers, + ) + + if create_response.status_code == 200: + db_payload = 
create_response.json() + db_id = db_payload["db_id"] + elif create_response.status_code == 409: + error_detail = create_response.json().get("detail", "") + pytest.fail(f"Knowledge database name conflict: {error_detail}. Please clean up old test databases first.") + else: + pytest.fail( + f"Failed to create knowledge database (status={create_response.status_code}): {create_response.text}" + ) + + yield db_payload if db_id else {"db_id": db_id, "name": db_name} + finally: - await test_client.delete(f"/api/knowledge/databases/{db_id}", headers=admin_headers) + # 确保清理,即使测试失败 + if db_id: + try: + delete_response = await test_client.delete(f"/api/knowledge/databases/{db_id}", headers=admin_headers) + if delete_response.status_code != 200: + print(f"Warning: Failed to cleanup knowledge database {db_id}: {delete_response.text}") + except Exception as e: + print(f"Warning: Exception during cleanup of {db_id}: {e}") def pytest_configure(config: pytest.Config) -> None: diff --git a/web/src/apis/agent_api.js b/web/src/apis/agent_api.js index 98b24d8ee..46908c133 100644 --- a/web/src/apis/agent_api.js +++ b/web/src/apis/agent_api.js @@ -1,4 +1,13 @@ -import { apiGet, apiPost, apiDelete, apiPut, apiAdminGet, apiAdminPost, apiRequest } from './base' +import { + apiGet, + apiPost, + apiDelete, + apiPut, + apiAdminGet, + apiAdminPost, + apiAdminDelete, + apiRequest +} from './base' import { useUserStore } from '@/stores/user' /** @@ -136,6 +145,23 @@ export const agentApi = { return apiAdminPost(url, config) }, + getAgentConfigs: (agentId) => apiGet(`/api/chat/agent/${agentId}/configs`), + + getAgentConfigProfile: (agentId, configId) => + apiGet(`/api/chat/agent/${agentId}/configs/${configId}`), + + createAgentConfigProfile: (agentId, payload) => + apiAdminPost(`/api/chat/agent/${agentId}/configs`, payload), + + updateAgentConfigProfile: (agentId, configId, payload) => + apiPut(`/api/chat/agent/${agentId}/configs/${configId}`, payload), + + setAgentConfigDefault: (agentId, 
configId) => + apiAdminPost(`/api/chat/agent/${agentId}/configs/${configId}/set_default`, {}), + + deleteAgentConfigProfile: (agentId, configId) => + apiAdminDelete(`/api/chat/agent/${agentId}/configs/${configId}`), + /** * 设置默认智能体 * @param {string} agentId - 智能体ID diff --git a/web/src/apis/department_api.js b/web/src/apis/department_api.js index 5013e82ef..5dc98f344 100644 --- a/web/src/apis/department_api.js +++ b/web/src/apis/department_api.js @@ -2,16 +2,22 @@ * 部门管理 API */ -import { apiSuperAdminGet, apiSuperAdminPost, apiSuperAdminPut, apiSuperAdminDelete } from './base' +import { + apiAdminGet, + apiSuperAdminGet, + apiSuperAdminPost, + apiSuperAdminPut, + apiSuperAdminDelete +} from './base' const BASE_URL = '/api/departments' /** - * 获取部门列表 + * 获取部门列表(普通管理员可访问) * @returns {Promise} 部门列表 */ export const getDepartments = () => { - return apiSuperAdminGet(BASE_URL) + return apiAdminGet(BASE_URL) } /** diff --git a/web/src/apis/knowledge_api.js b/web/src/apis/knowledge_api.js index 031db8839..122306939 100644 --- a/web/src/apis/knowledge_api.js +++ b/web/src/apis/knowledge_api.js @@ -68,6 +68,14 @@ export const databaseApi = { current_description: currentDescription, file_list: fileList }) + }, + + /** + * 获取当前用户有权访问的知识库列表(用于智能体配置) + * @returns {Promise} - 可访问的知识库列表 + */ + getAccessibleDatabases: async () => { + return apiAdminGet('/api/knowledge/databases/accessible') } } diff --git a/web/src/components/AgentChatComponent.vue b/web/src/components/AgentChatComponent.vue index be09856e0..fa9a0d779 100644 --- a/web/src/components/AgentChatComponent.vue +++ b/web/src/components/AgentChatComponent.vue @@ -205,7 +205,7 @@ const emit = defineEmits(['open-config', 'open-agent-modal']) // ==================== STORE MANAGEMENT ==================== const agentStore = useAgentStore() const chatUIStore = useChatUIStore() -const { agents, selectedAgentId, defaultAgentId } = storeToRefs(agentStore) +const { agents, selectedAgentId, defaultAgentId, selectedAgentConfigId } = 
storeToRefs(agentStore) // ==================== LOCAL CHAT & UI STATE ==================== const userInput = ref('') @@ -621,7 +621,8 @@ const sendMessage = async ({ const requestData = { query: text, config: { - thread_id: threadId + thread_id: threadId, + ...(selectedAgentConfigId.value ? { agent_config_id: selectedAgentConfigId.value } : {}) } } @@ -871,7 +872,11 @@ const handleApprovalWithStream = async (approved) => { try { // 使用审批 composable 处理审批 - const response = await handleApproval(approved, currentAgentId.value) + const response = await handleApproval( + approved, + currentAgentId.value, + selectedAgentConfigId.value + ) if (!response) return // 如果 handleApproval 抛出错误,这里不会执行 @@ -1363,12 +1368,12 @@ watch( diff --git a/web/src/components/BasicSettingsSection.vue b/web/src/components/BasicSettingsSection.vue index edf646bb9..777e96c9c 100644 --- a/web/src/components/BasicSettingsSection.vue +++ b/web/src/components/BasicSettingsSection.vue @@ -146,20 +146,6 @@ 访问 - - diff --git a/web/src/components/DebugComponent.vue b/web/src/components/DebugComponent.vue index 907b945d7..2ac7ad0ee 100644 --- a/web/src/components/DebugComponent.vue +++ b/web/src/components/DebugComponent.vue @@ -1,125 +1,169 @@ diff --git a/web/src/components/UserInfoComponent.vue b/web/src/components/UserInfoComponent.vue index d2cf8d231..af0c4efaf 100644 --- a/web/src/components/UserInfoComponent.vue +++ b/web/src/components/UserInfoComponent.vue @@ -26,28 +26,36 @@ - + 文档中心 {{ themeStore.isDark ? 
'切换到浅色模式' : '切换到深色模式 (Beta)' }} + + 调试面板(非生产环境) + 系统设置 - + 退出登录 @@ -165,6 +173,9 @@ + + + @@ -172,10 +183,7 @@ import { computed, ref, inject, h } from 'vue' import { useRouter } from 'vue-router' import { useUserStore } from '@/stores/user' -// -// -// -// +import DebugComponent from '@/components/DebugComponent.vue' import { message } from 'ant-design-vue' import { CircleUser, @@ -186,7 +194,8 @@ import { User, LogOut, Upload, - Settings + Settings, + Terminal } from 'lucide-vue-next' import { useThemeStore } from '@/stores/theme' @@ -194,6 +203,17 @@ const router = useRouter() const userStore = useUserStore() const themeStore = useThemeStore() +// 预定义图标组件,避免 Vue 警告 +const BookOpenIcon = h(BookOpen, { size: '16' }) +const SunIcon = h(Sun, { size: '16' }) +const MoonIcon = h(Moon, { size: '16' }) +const TerminalIcon = h(Terminal, { size: '16' }) +const SettingsIcon = h(Settings, { size: '16' }) +const LogOutIcon = h(LogOut, { size: '16' }) + +// 调试面板状态 +const showDebug = ref(false) + // Inject settings modal methods const { openSettingsModal } = inject('settingsModal', {}) diff --git a/web/src/components/UserManagementComponent.vue b/web/src/components/UserManagementComponent.vue index 8ea829e71..80aab46fa 100644 --- a/web/src/components/UserManagementComponent.vue +++ b/web/src/components/UserManagementComponent.vue @@ -41,20 +41,22 @@