From 154ed4e916ee8e85088ed8080bd8cc329c4a6c35 Mon Sep 17 00:00:00 2001 From: Wenjie Zhang Date: Tue, 20 Jan 2026 02:14:51 +0800 Subject: [PATCH 01/22] =?UTF-8?q?=E5=8A=9F=E8=83=BD=EF=BC=9A=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA=E7=9F=A5=E8=AF=86=E5=BA=93=E8=AE=BF=E9=97=AE=E6=8E=A7?= =?UTF-8?q?=E5=88=B6=E5=92=8C=E5=85=B1=E4=BA=AB=E9=85=8D=E7=BD=AE=20#480?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在save_agent_config中实现了基于用户角色的知识库访问控制。 - 添加了新端点以检索当前用户可访问的数据库。 - 在数据库创建和更新过程中引入了share_config以管理共享设置。 - 创建了ShareConfigForm组件用于在UI中管理共享设置。 - 更新了部门和知识API以支持新的访问控制功能。 - 增强了用户管理组件以显示用户角色和部门名称。 - 重构了数据库视图以包含共享配置设置。 --- server/routers/auth_router.py | 80 +++++-- server/routers/chat_router.py | 27 +++ server/routers/department_router.py | 23 +- server/routers/knowledge_router.py | 50 ++++- server/routers/mindmap_router.py | 5 +- src/knowledge/base.py | 1 + src/knowledge/manager.py | 164 +++++++++++++- web/src/apis/department_api.js | 6 +- web/src/apis/knowledge_api.js | 8 + web/src/components/KnowledgeBaseCard.vue | 127 +++++++++-- web/src/components/ShareConfigForm.vue | 208 ++++++++++++++++++ .../components/UserManagementComponent.vue | 61 +++-- web/src/stores/user.js | 2 +- web/src/views/DataBaseView.vue | 25 +++ 14 files changed, 680 insertions(+), 107 deletions(-) create mode 100644 web/src/components/ShareConfigForm.vue diff --git a/server/routers/auth_router.py b/server/routers/auth_router.py index b14592f96..63f4a1463 100644 --- a/server/routers/auth_router.py +++ b/server/routers/auth_router.py @@ -30,6 +30,8 @@ class Token(BaseModel): phone_number: str | None = None avatar: str | None = None role: str + department_id: int | None = None + department_name: str | None = None class UserCreate(BaseModel): @@ -176,6 +178,12 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( # 记录登录操作 await log_operation(db, user.id, "登录") + # 获取部门名称 + department_name = None + if user.department_id: + result = await db.execute(select(Department.name).filter(Department.id == user.department_id)) + department_name = result.scalar_one_or_none() + return { "access_token": access_token, "token_type": "bearer", @@ -186,6 +194,7 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( "avatar": user.avatar, "role": user.role, "department_id": user.department_id, + "department_name": department_name, } @@ -275,9 +284,15 @@ async def initialize_admin(admin_data: InitializeAdmin, db: AsyncSession = Depen @auth.get("/me", response_model=UserResponse) -async def read_users_me(current_user: User = Depends(get_current_user)): +async def read_users_me(current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db)): """获取当前登录用户的个人信息""" - return current_user.to_dict() + user_dict = current_user.to_dict() + + if current_user.department_id: + result = await db.execute(select(Department.name).filter(Department.id == current_user.department_id)) + user_dict["department_name"] = result.scalar_one_or_none() + + return user_dict # 路由:更新个人资料 @@ -393,11 +408,11 @@ async def create_user( hashed_password = AuthUtils.hash_password(user_data.password) # 检查角色权限 - # 超级管理员可以创建任何类型的用户 - if user_data.role == "superadmin" and current_user.role != "superadmin": + # 禁止创建超级管理员账户(系统只能有一个超级管理员) + if user_data.role == "superadmin": raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="只有超级管理员才能创建超级管理员账户", + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能创建超级管理员账户", ) # 管理员只能创建普通用户 @@ -490,6 +505,18 @@ async def read_user(user_id: int, current_user: User = Depends(get_admin_user), return user.to_dict() +async def check_department_admin_count(db: AsyncSession, department_id: int, exclude_user_id: int) -> int: + """检查部门中管理员数量(排除指定用户)""" + result = await db.execute( + select(func.count(User.id)).filter( + User.department_id == department_id, + User.role == "admin", + User.id != exclude_user_id, + ) + ) + return result.scalar() + + # 路由:更新用户信息(管理员权限) @auth.put("/users/{user_id}", response_model=UserResponse) async def update_user( @@ -541,6 +568,14 @@ async def update_user( update_details.append("密码已更新") if user_data.role is not None: + # 检查是否将管理员降级为普通用户 + if user.role == "admin" and user_data.role == "user" and user.department_id is not None: + admin_count = await check_department_admin_count(db, user.department_id, user_id) + if admin_count <= 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能将管理员降级为普通用户,因为该用户是当前部门的唯一管理员", + ) user.role = user_data.role update_details.append(f"角色: {user_data.role}") @@ -553,12 +588,22 @@ async def update_user( update_details.append(f"头像: {user_data.avatar or '已清空'}") # 部门修改权限控制(只有超级管理员可以修改用户部门) - if user_data.department_id is not None: + if user_data.department_id is not None and user_data.department_id != user.department_id: if current_user.role != "superadmin": raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="只有超级管理员才能修改用户部门", ) + + # 检查该用户是否是当前部门的唯一管理员 + if user.role == "admin" and user.department_id is not None: + admin_count = await check_department_admin_count(db, user.department_id, user_id) + if admin_count <= 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能修改该用户的部门,因为该用户是当前部门的唯一管理员", + ) + user.department_id = user_data.department_id update_details.append(f"部门ID: {user_data.department_id}") @@ -583,25 +628,12 @@ async def delete_user( detail="用户不存在", ) - # 检查权限 + # 不能删除超级管理员账户 if user.role == "superadmin": - # 只有超级管理员可以删除超级管理员 - if current_user.role != "superadmin": - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="只有超级管理员才能删除超级管理员账户", - ) - - # 检查是否是最后一个超级管理员 - result = await db.execute( - select(func.count(User.id)).filter(User.role == "superadmin", User.is_deleted == 0) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="不能删除超级管理员账户", ) - superadmin_count = result.scalar() - if superadmin_count <= 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="不能删除最后一个超级管理员账户", - ) # 检查是否是部门的唯一管理员 if user.role == "admin": diff --git a/server/routers/chat_router.py b/server/routers/chat_router.py index 030ede903..6081886e0 100644 --- a/server/routers/chat_router.py +++ b/server/routers/chat_router.py @@ -859,6 +859,31 @@ async def save_agent_config( if not (agent := agent_manager.get_agent(agent_id)): raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在") + # === 校验知识库权限 === + from src import knowledge_base + + if "knowledges" in config and config["knowledges"]: + # 获取用户有权访问的知识库 ID + try: + user_info = {"role": current_user.role, "department_id": current_user.department_id} + accessible_databases = knowledge_base.get_databases_by_user(user_info) + accessible_db_ids = {db.get("db_id") for db in accessible_databases.get("databases", []) if db.get("db_id")} + except Exception as db_error: + logger.warning(f"获取知识库列表失败: {db_error}") + # 如果获取失败,superadmin 可以访问所有,非 superadmin 无法访问任何 + if current_user.role != "superadmin": + raise HTTPException(status_code=500, detail="无法获取知识库列表") + accessible_db_ids = set(knowledge_base.global_databases_meta.keys()) + + # 检查配置中的知识库是否都可用 + invalid_kbs = [kb for kb in config["knowledges"] if kb not in accessible_db_ids] + if invalid_kbs: + raise HTTPException( + status_code=403, + detail=f"无权访问以下知识库: {', '.join(invalid_kbs)}" + ) + # === 校验结束 === + # 使用配置类的save_to_file方法保存配置 result = agent.context_schema.save_to_file(config, agent.module_name) @@ -869,6 +894,8 @@ async def save_agent_config( else: raise HTTPException(status_code=500, detail="保存智能体配置失败") + except HTTPException: + raise except Exception as e: logger.error(f"保存智能体配置出错: {e}, {traceback.format_exc()}") raise HTTPException(status_code=500, detail=f"保存智能体配置出错: {str(e)}") diff --git a/server/routers/department_router.py b/server/routers/department_router.py index 057e051d3..c4eca3b71 100644 --- a/server/routers/department_router.py +++ b/server/routers/department_router.py @@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from src.storage.db.models import Department, User -from server.utils.auth_middleware import get_superadmin_user, get_db +from server.utils.auth_middleware import get_superadmin_user, get_admin_user, get_db from server.utils.auth_utils import AuthUtils from server.utils.common_utils import log_operation from server.utils.user_utils import is_valid_phone_number @@ -60,27 +60,16 @@ class DepartmentResponse(BaseModel): user_count: int = 0 -class DepartmentSimpleResponse(BaseModel): - """部门简单响应(不含用户数量)""" - - id: int - name: str - description: str | None = None - created_at: str - - # ============================================================================= # === 部门管理路由 === # ============================================================================= -@department.get("", response_model=list[DepartmentResponse]) -async def get_departments(current_user: User = Depends(get_superadmin_user), db: AsyncSession = Depends(get_db)): - """获取所有部门列表""" +async def _get_departments_with_user_count(db: AsyncSession) -> list[dict]: + """获取所有部门列表,包含用户数量(内部辅助函数)""" result = await db.execute(select(Department).order_by(Department.created_at.desc())) departments = result.scalars().all() - # 获取每个部门的用户数量 department_list = [] for dep in departments: user_count_result = await db.execute( @@ -92,6 +81,12 @@ async def get_departments(current_user: User = Depends(get_superadmin_user), db: return department_list +@department.get("", response_model=list[DepartmentResponse]) +async def get_departments(current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db)): + """获取所有部门列表(管理员可访问)""" + return await _get_departments_with_user_count(db) + + @department.get("/{department_id}", response_model=DepartmentResponse) async def get_department( department_id: int, current_user: User = Depends(get_superadmin_user), db: AsyncSession = Depends(get_db) diff --git a/server/routers/knowledge_router.py b/server/routers/knowledge_router.py index 754814e44..289674265 100644 --- a/server/routers/knowledge_router.py +++ b/server/routers/knowledge_router.py @@ -1,4 +1,5 @@ import asyncio +import json import os import textwrap import traceback @@ -7,10 +8,11 @@ import aiofiles from fastapi import APIRouter, Body, Depends, File, HTTPException, Query, Request, UploadFile from fastapi.responses import FileResponse +from pydantic import BaseModel from starlette.responses import StreamingResponse from server.services.tasker import TaskContext, tasker -from server.utils.auth_middleware import get_admin_user +from server.utils.auth_middleware import get_admin_user, get_required_user from src import config, knowledge_base from src.knowledge.indexing import SUPPORTED_FILE_EXTENSIONS, is_supported_file_extension, process_file_to_markdown from src.knowledge.utils import calculate_content_hash @@ -18,6 +20,7 @@ from src.storage.db.models import User from src.storage.minio.client import StorageError, aupload_file_to_minio, get_minio_client from src.utils import logger +from src.utils.datetime_utils import utc_isoformat knowledge = APIRouter(prefix="/knowledge", tags=["knowledge"]) @@ -63,10 +66,10 @@ @knowledge.get("/databases") async def get_databases(current_user: User = Depends(get_admin_user)): - """获取所有知识库""" + """获取所有知识库(根据用户权限过滤)""" try: - database = knowledge_base.get_databases() - return database + user_info = {"role": current_user.role, "department_id": current_user.department_id} + return knowledge_base.get_databases_by_user(user_info) except Exception as e: logger.error(f"获取数据库列表失败 {e}, {traceback.format_exc()}") return {"message": f"获取数据库列表失败 {e}", "databases": []} @@ -80,13 +83,14 @@ async def create_database( kb_type: str = Body("lightrag"), additional_params: dict = Body({}), llm_info: dict = Body(None), + share_config: dict = Body(None), current_user: User = Depends(get_admin_user), ): """创建知识库""" logger.debug( f"Create database {database_name} with kb_type {kb_type}, " f"additional_params {additional_params}, llm_info {llm_info}, " - f"embed_model_name {embed_model_name}" + f"embed_model_name {embed_model_name}, share_config {share_config}" ) try: # 先检查名称是否已存在 @@ -117,7 +121,7 @@ def remove_reranker_config(kb: str, params: dict) -> None: embed_info = config.embed_model_names[embed_model_name] database_info = await knowledge_base.create_database( - database_name, description, kb_type=kb_type, embed_info=embed_info, llm_info=llm_info, **additional_params + database_name, description, kb_type=kb_type, embed_info=embed_info, llm_info=llm_info, share_config=share_config, **additional_params ) # 需要重新加载所有智能体,因为工具刷新了 @@ -133,6 +137,28 @@ def remove_reranker_config(kb: str, params: dict) -> None: raise HTTPException(status_code=400, detail=f"创建数据库失败: {e}") +@knowledge.get("/databases/accessible") +async def get_accessible_databases(current_user: User = Depends(get_required_user)): + """获取当前用户有权访问的知识库列表(用于智能体配置)""" + try: + user_info = {"role": current_user.role, "department_id": current_user.department_id} + databases = knowledge_base.get_databases_by_user(user_info) + + accessible = [ + { + "name": db.get("name", ""), + "db_id": db.get("db_id"), + "description": db.get("description", ""), + } + for db in databases.get("databases", []) + ] + + return {"databases": accessible} + except Exception as e: + logger.error(f"获取可访问知识库列表失败: {e}, {traceback.format_exc()}") + return {"message": f"获取可访问知识库列表失败: {str(e)}", "databases": []} + + @knowledge.get("/databases/{db_id}") async def get_database_info(db_id: str, current_user: User = Depends(get_admin_user)): """获取知识库详细信息""" @@ -148,13 +174,14 @@ async def update_database_info( name: str = Body(...), description: str = Body(...), llm_info: dict = Body(None), - additional_params: dict = Body({}), # Now accepts a dict + additional_params: dict = Body({}), + share_config: dict = Body(None), current_user: User = Depends(get_admin_user), ): """更新知识库信息""" logger.debug( - f"Update database {db_id} info: {name}, {description}, llm_info: {llm_info}, " - f"additional_params: {additional_params}" + f"[update_database_info] 接收到的参数: name={name}, llm_info={llm_info}, " + f"additional_params={additional_params}, share_config={share_config}" ) try: database = await knowledge_base.update_database( @@ -162,7 +189,8 @@ async def update_database_info( name, description, llm_info, - additional_params=additional_params, # Pass the dict to the manager + additional_params=additional_params, + share_config=share_config, ) return {"message": "更新成功", "database": database} except Exception as e: @@ -872,8 +900,6 @@ async def generate_sample_questions( 生成的问题列表 """ try: - import json - from src.models import select_model # 从请求体中提取参数 diff --git a/server/routers/mindmap_router.py b/server/routers/mindmap_router.py index 81fa27a16..b32cc7c02 100644 --- a/server/routers/mindmap_router.py +++ b/server/routers/mindmap_router.py @@ -282,13 +282,14 @@ async def generate_mindmap( @mindmap.get("/databases") async def get_databases_overview(current_user: User = Depends(get_admin_user)): """ - 获取所有知识库的概览信息,用于思维导图界面选择 + 获取所有知识库的概览信息,用于思维导图界面选择(根据用户权限过滤) Returns: 知识库列表 """ try: - databases = knowledge_base.get_databases() + user_info = {"role": current_user.role, "department_id": current_user.department_id} + databases = knowledge_base.get_databases_by_user(user_info) # databases["databases"] 是一个列表,每个元素已经包含了基本信息 db_list_raw = databases.get("databases", []) diff --git a/src/knowledge/base.py b/src/knowledge/base.py index 1092c7bf7..ba1fb8c95 100644 --- a/src/knowledge/base.py +++ b/src/knowledge/base.py @@ -357,6 +357,7 @@ def create_database( database_name: 数据库名称 description: 数据库描述 embed_info: 嵌入模型信息 + llm_info: LLM配置信息 **kwargs: 其他配置参数 Returns: diff --git a/src/knowledge/manager.py b/src/knowledge/manager.py index abf097133..750449034 100644 --- a/src/knowledge/manager.py +++ b/src/knowledge/manager.py @@ -49,6 +49,12 @@ def __init__(self, work_dir: str): except Exception as e: logger.warning(f"Failed to migrate query_params: {e}") + # 迁移 share_config 到 global metadata + try: + self._migrate_share_config() + except Exception as e: + logger.warning(f"Failed to migrate share_config: {e}") + logger.info("KnowledgeBaseManager initialized") # 在后台运行数据一致性检测(不阻塞初始化) @@ -77,7 +83,9 @@ def _load_global_metadata(self): with open(meta_file, encoding="utf-8") as f: data = json.load(f) self.global_databases_meta = data.get("databases", {}) - logger.info(f"Loaded global metadata for {len(self.global_databases_meta)} databases") + logger.info(f"[_load_global_metadata] 加载了 {len(self.global_databases_meta)} 个数据库的全局元数据") + for db_id, meta in self.global_databases_meta.items(): + logger.debug(f" [{db_id}] share_config: {meta.get('share_config')}") except Exception as e: logger.error(f"Failed to load global metadata: {e}") # 尝试从备份恢复 @@ -111,6 +119,7 @@ def _save_global_metadata(self): # 准备数据 data = {"databases": self.global_databases_meta, "updated_at": utc_isoformat(), "version": "2.0"} + logger.debug(f"[_save_global_metadata] 保存数据: databases count={len(self.global_databases_meta)}") # 原子性写入(使用临时文件) with tempfile.NamedTemporaryFile( @@ -122,6 +131,12 @@ def _save_global_metadata(self): os.replace(temp_path, meta_file) logger.debug("Saved global metadata") + # 验证写入 + if os.path.exists(meta_file): + with open(meta_file, encoding="utf-8") as f: + saved_data = json.load(f) + logger.debug(f"[_save_global_metadata] 验证: 保存了 {len(saved_data.get('databases', {}))} 个数据库") + except Exception as e: logger.error(f"Failed to save global metadata: {e}") # 尝试恢复备份 @@ -229,10 +244,95 @@ def get_databases(self) -> dict: # 收集所有知识库的数据库信息 for kb_type, kb_instance in self.kb_instances.items(): kb_databases = kb_instance.get_databases()["databases"] + + # 合并全局元数据 + for db in kb_databases: + db_id = db.get("db_id") + if db_id and db_id in self.global_databases_meta: + global_meta = self.global_databases_meta[db_id] + + # 合并 share_config + db["share_config"] = global_meta.get("share_config", {"is_shared": True, "accessible_departments": []}) + + # 合并 additional_params + # 注意:kb_instance 返回的 metadata 字段可能已经包含了部分参数, + # 但 global_databases_meta 中的 additional_params 是我们在 create/update 时保存的 + db["additional_params"] = global_meta.get("additional_params", {}) + all_databases.extend(kb_databases) return {"databases": all_databases} + def check_accessible(self, user: dict, db_id: str) -> bool: + """检查用户是否有权限访问数据库 + + Args: + user: 用户信息字典 + db_id: 数据库ID + + Returns: + bool: 是否有权限 + """ + # 超级管理员有权访问所有 + if user.get("role") == "superadmin": + return True + + if db_id not in self.global_databases_meta: + return False + + share_config = self.global_databases_meta[db_id].get("share_config", {}) + is_shared = share_config.get("is_shared", True) + + # 如果是全员共享,则有权限 + if is_shared: + return True + + # 检查部门权限 + user_department_id = user.get("department_id") + accessible_departments = share_config.get("accessible_departments", []) + + if user_department_id is None: + return False + + # 转换为整数进行比较(前端可能传递字符串,后端存储为整数) + try: + user_department_id = int(user_department_id) + accessible_departments = [int(d) for d in accessible_departments] + except (ValueError, TypeError): + return False + + if user_department_id in accessible_departments: + return True + + return False + + def get_databases_by_user(self, user: dict) -> dict: + """根据用户权限获取知识库列表 + + Args: + user: 用户信息字典,包含 role 和 department_id + + Returns: + 过滤后的知识库列表 + """ + all_databases = self.get_databases().get("databases", []) + + # 超级管理员可以看到所有知识库 + if user.get("role") == "superadmin": + return {"databases": all_databases} + + filtered_databases = [] + + for db in all_databases: + db_id = db.get("db_id") + if not db_id: + continue + + if self.check_accessible(user, db_id): + filtered_databases.append(db) + + return {"databases": filtered_databases} + def database_name_exists(self, database_name: str) -> bool: """检查知识库名称是否已存在 @@ -253,7 +353,7 @@ async def create_folder(self, db_id: str, folder_name: str, parent_id: str = Non return kb_instance.create_folder(db_id, folder_name, parent_id) async def create_database( - self, database_name: str, description: str, kb_type: str = "lightrag", embed_info: dict | None = None, **kwargs + self, database_name: str, description: str, kb_type: str = "lightrag", embed_info: dict | None = None, share_config: dict | None = None, **kwargs ) -> dict: """ 创建数据库 @@ -263,6 +363,7 @@ async def create_database( description: 数据库描述 kb_type: 知识库类型,默认为lightrag embed_info: 嵌入模型信息 + share_config: 共享配置 **kwargs: 其他配置参数,包括chunk_size和chunk_overlap Returns: @@ -275,9 +376,16 @@ async def create_database( # 检查名称是否已存在 if self.database_name_exists(database_name): raise ValueError(f"知识库名称 '{database_name}' 已存在,请使用其他名称") + + # 默认共享配置 + if share_config is None: + share_config = {"is_shared": True, "accessible_departments": []} kb_instance = self._get_or_create_kb_instance(kb_type) + # 注意:不再传递 share_config 给 kb_instance,因为它由 Manager 管理 + # 但为了兼容性,如果 kb_instance.create_database 签名还没改,可能会有问题? + # Base KB create_database 接受 **kwargs,所以没问题。我们这里不传 share_config 给它。 db_info = kb_instance.create_database(database_name, description, embed_info, **kwargs) db_id = db_info["db_id"] @@ -288,10 +396,15 @@ async def create_database( "kb_type": kb_type, "created_at": utc_isoformat(), "additional_params": kwargs.copy(), + "share_config": share_config, } + logger.debug(f"[create_database] 保存 global_databases_meta[{db_id}]: {self.global_databases_meta[db_id]}") self._save_global_metadata() + logger.debug(f"[create_database] _save_global_metadata 完成") logger.info(f"Created {kb_type} database: {database_name} ({db_id}) with {kwargs}") + # 返回信息中包含 share_config + db_info["share_config"] = share_config return db_info async def delete_database(self, db_id: str) -> dict: @@ -377,6 +490,31 @@ def _migrate_all_query_params(self): self._save_global_metadata() logger.info(f"Successfully migrated query_params for {migration_count} databases") + def _migrate_share_config(self): + """将 share_config 从 instance metadata 迁移到 global metadata""" + migration_count = 0 + + for kb_type, kb_instance in self.kb_instances.items(): + for db_id, instance_meta in kb_instance.databases_meta.items(): + if db_id not in self.global_databases_meta: + continue + + global_meta = self.global_databases_meta[db_id] + + # 如果 global metadata 中没有 share_config,但 instance metadata 中有 + if "share_config" not in global_meta and "share_config" in instance_meta: + global_meta["share_config"] = instance_meta["share_config"] + # 可选:从 instance metadata 中移除?暂时保留以防万一,或者清理 + # del instance_meta["share_config"] + migration_count += 1 + + if migration_count > 0: + self._save_global_metadata() + # 保存所有修改过的实例元数据 + for kb_instance in self.kb_instances.values(): + kb_instance._save_metadata() + logger.info(f"Successfully migrated share_config for {migration_count} databases") + def get_database_info(self, db_id: str) -> dict | None: """获取数据库详细信息""" try: @@ -384,7 +522,7 @@ def get_database_info(self, db_id: str) -> dict | None: db_info = kb_instance.get_database_info(db_id) # 添加全局元数据中的additional_params信息 - if db_info and db_id in self.global_databases_meta: + if db_id in self.global_databases_meta: global_meta = self.global_databases_meta[db_id] additional_params = global_meta.get("additional_params", {}).copy() @@ -393,6 +531,9 @@ def get_database_info(self, db_id: str) -> dict | None: additional_params["auto_generate_questions"] = False db_info["additional_params"] = additional_params + + # 添加 share_config + db_info["share_config"] = global_meta.get("share_config", {"is_shared": True, "accessible_departments": []}) return db_info except KBNotFoundError: @@ -531,11 +672,12 @@ async def file_existed_in_db(self, db_id: str | None, content_hash: str | None) return False async def update_database( - self, db_id: str, name: str, description: str, llm_info: dict = None, additional_params: dict | None = None + self, db_id: str, name: str, description: str, llm_info: dict = None, additional_params: dict | None = None, share_config: dict | None = None ) -> dict: """更新数据库""" kb_instance = self._get_kb_for_database(db_id) - result = kb_instance.update_database(db_id, name, description, llm_info) + # 注意:这里 kb_instance.update_database 返回的是实例的元数据,不包含 global metadata 中的 share_config + kb_instance.update_database(db_id, name, description, llm_info) async with self._metadata_lock: if db_id in self.global_databases_meta: @@ -551,9 +693,19 @@ async def update_database( # 清理旧的 top-level key (如果存在) self.global_databases_meta[db_id].pop("auto_generate_questions", None) + # 更新共享配置 + if share_config is not None: + existing_share_config = self.global_databases_meta[db_id].get("share_config", {}) + logger.debug(f"[update_database] 原始 share_config: {existing_share_config}") + logger.debug(f"[update_database] 新 share_config: {share_config}") + existing_share_config.update(share_config) + self.global_databases_meta[db_id]["share_config"] = existing_share_config + logger.debug(f"[update_database] 合并后 share_config: {existing_share_config}") + self._save_global_metadata() - return result + # 返回包含 global metadata 的完整信息 + return self.get_database_info(db_id) def get_retrievers(self) -> dict[str, dict]: """获取所有检索器""" diff --git a/web/src/apis/department_api.js b/web/src/apis/department_api.js index 5013e82ef..1c4641e00 100644 --- a/web/src/apis/department_api.js +++ b/web/src/apis/department_api.js @@ -2,16 +2,16 @@ * 部门管理 API */ -import { apiSuperAdminGet, apiSuperAdminPost, apiSuperAdminPut, apiSuperAdminDelete } from './base' +import { apiAdminGet, apiSuperAdminGet, apiSuperAdminPost, apiSuperAdminPut, apiSuperAdminDelete } from './base' const BASE_URL = '/api/departments' /** - * 获取部门列表 + * 获取部门列表(普通管理员可访问) * @returns {Promise} 部门列表 */ export const getDepartments = () => { - return apiSuperAdminGet(BASE_URL) + return apiAdminGet(BASE_URL) } /** diff --git a/web/src/apis/knowledge_api.js b/web/src/apis/knowledge_api.js index 031db8839..122306939 100644 --- a/web/src/apis/knowledge_api.js +++ b/web/src/apis/knowledge_api.js @@ -68,6 +68,14 @@ export const databaseApi = { current_description: currentDescription, file_list: fileList }) + }, + + /** + * 获取当前用户有权访问的知识库列表(用于智能体配置) + * @returns {Promise} - 可访问的知识库列表 + */ + getAccessibleDatabases: async () => { + return apiAdminGet('/api/knowledge/databases/accessible') } } diff --git a/web/src/components/KnowledgeBaseCard.vue b/web/src/components/KnowledgeBaseCard.vue index 484e5980f..4a2048c44 100644 --- a/web/src/components/KnowledgeBaseCard.vue +++ b/web/src/components/KnowledgeBaseCard.vue @@ -90,26 +90,94 @@ style="width: 100%" /> + + + + + + + + + + + + + diff --git a/web/src/components/UserManagementComponent.vue b/web/src/components/UserManagementComponent.vue index 8ea829e71..bfe2d44ac 100644 --- a/web/src/components/UserManagementComponent.vue +++ b/web/src/components/UserManagementComponent.vue @@ -41,20 +41,18 @@